# Validation patterns used by the check_* helpers.
# NOTE(review): "90-9" in the character classes below matches the same set as "0-9"
# (a literal 9 plus the range 0-9); presumably a typo for "0-9".
REGEX_label = r'[a-zA-Z90-9]([a-zA-Z90-9-]{0,61}[a-zA-Z90-9])?' # max. 63 characters; must not start or end with hyphen
# Dotted quad with 1-3 digits per group; the numeric range of a group is not restricted by this pattern.
REGEX_ipv4 = r'^\d{1,3}(\.\d{1,3}){3}$'
# NOTE(review): the newer IPv6 pattern permits "::" between any two groups, so it accepts
# more than one "::" (e.g. "1::2::3") and short forms without any "::" (e.g. "1:2"),
# and it rejects addresses that begin or end with "::" (e.g. "::1").
-REGEX_ipv6 = r'^[a-fA-F0-9]{1,4}(:[a-fA-F0-9]{1,4}){7}$'
+REGEX_ipv6 = r'^[a-fA-F0-9]{1,4}(::?[a-fA-F0-9]{1,4}){1,7}$'
def check_label(label: str) -> str:
label = str(label)
return data
raise Exception(data+" is not valid hex data")
def check_base64(data: str) -> str:
    '''Check that <data> looks like base64 text and return it as a string.

    Accepted are one or more characters of the standard base64 alphabet
    (A-Z, a-z, 0-9, '+', '/'), optionally followed by up to two '=' padding
    characters at the very end.

    Raises Exception if <data> is empty, contains any other character
    (including a trailing newline), or has '=' anywhere but at the end.
    '''
    data = str(data)
    # \Z instead of $: "$" would also match just before a trailing newline,
    # letting a line break slip into the generated record.
    if re.match(r'^[a-zA-Z0-9+/]+={0,2}\Z', data):
        return data
    raise Exception(data+" is not valid base64 data")
+
+
def check_ipv4(address: str) -> str:
address = str(address)
if re.match(REGEX_ipv4, address):
return str(time)
def column_widths(datas: 'Sequence', widths: 'Sequence[int]'):
- assert len(datas) == len(widths)+1, "There must be as one more data points as widths"
+ assert len(datas) == len(widths)+1, "There must be one more data points as there are widths"
result = ""
width_sum = 0
for data, width in zip(datas, widths): # will *not* cover the last point
# last data point
return result+str(datas[-1])
def concatenate(root, path):
    '''Combine the name <path> with the name <root> it is given relative to.

    '@' stands for "the current origin": a <path> of '@' yields <root>, and
    a <root> of '@' leaves <path> untouched. A <path> ending in '.' is
    returned unchanged. Otherwise the result is "<path>.<root>".

    Raises Exception if either argument is the empty string.
    '''
    if '' in (path, root):
        raise Exception("Empty domain name is not valid")
    if path == '@':
        return root
    keep_path = root == '@' or path.endswith('.')
    return path if keep_path else '.'.join((path, root))
+
def escape_TXT(text):
    '''Return <text> with every backslash and double quote preceded by a backslash.'''
    # Backslashes must be doubled first, otherwise the backslashes added
    # in front of the quotes would be escaped a second time.
    backslashes_escaped = text.replace('\\', '\\\\')
    return backslashes_escaped.replace('\"', '\\\"')
+
## Enums
class Protocol:
SHA256 = 2
+## Resource records
class RR:
    '''One resource record: owner path, record type, record data and a TTL.

    The TTL starts out as None; mapTTL can be used to fill it in later.
    The mapping methods modify the record in place and return it, so calls
    can be chained.
    '''

    def __init__(self, path, recordType, data):
        '''<path> can be relative or absolute.

        Raises AssertionError if <recordType> is not made up of upper-case
        ASCII letters only.
        '''
        # Raised explicitly (rather than via an assert statement) so the check
        # also runs under "python -O"; \Z rejects a trailing newline, which
        # "$" would let through.
        if not re.match(r'^[A-Z]+\Z', recordType):
            raise AssertionError("got invalid record type")
        self.path = path
        self.recordType = recordType
        self.data = data
        self.TTL = None

    def mapPath(self, f):
        '''Run the path through f. Returns self, for nicer chaining.'''
        self.path = f(self.path)
        return self

    def relativize(self, root):
        '''Replace the path by concatenate(root, path). Returns self.'''
        return self.mapPath(lambda path: concatenate(root, path))

    def mapTTL(self, f):
        '''Run the current TTL and the recordType through f. Returns self.'''
        self.TTL = f(self.TTL, self.recordType)
        return self

    def __str__(self):
        '''Format the record as one line: path, TTL, record type and data in columns.'''
        return column_widths((self.path, time(self.TTL), self.recordType, self.data), (8*3, 8, 8))
+
## Record types
class A:
    '''Record type 'A': holds an address that went through check_ipv4.'''

    def __init__(self, address: str) -> None:
        # keep whatever check_ipv4 hands back for the given address
        self._address = check_ipv4(address)

    def generate_rr(self):
        '''Return the 'A' resource record for this address, with path '@'.'''
        return RR(path='@', recordType='A', data=self._address)
class AAAA:
    '''Record type 'AAAA': holds an address that went through check_ipv6.'''

    def __init__(self, address: str) -> None:
        # keep whatever check_ipv6 hands back for the given address
        self._address = check_ipv6(address)

    def generate_rr(self):
        '''Return the 'AAAA' resource record for this address, with path '@'.'''
        return RR(path='@', recordType='AAAA', data=self._address)
class MX:
self._priority = int(prio)
self._name = check_hostname(name)
- def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
- return zone.RR(owner, 'MX', '{0} {1}'.format(self._priority, zone.abs_hostname(self._name)))
+ def generate_rr(self):
+ return RR('@', 'MX', '{0} {1}'.format(self._priority, self._name))
+
+
class TXT:
    '''Record type 'TXT': free-form text, emitted as quoted chunks of at most 255 characters.'''

    def __init__(self, text: str) -> None:
        '''Store <text>.

        Raises Exception if <text> contains a newline, carriage return or tab.
        '''
        # test for bad characters
        for c in ('\n', '\r', '\t'):
            if c in text:
                # repr() keeps the offending control character visible in the message
                raise Exception("TXT record {0!r} contains invalid character".format(text))
        self._text = text

    @staticmethod
    def _split_chunks(text, limit=255):
        '''Split escaped <text> into pieces of at most <limit> characters.

        A backslash and the character following it are kept in the same
        piece. Always returns at least one (possibly empty) piece, and the
        pieces joined together give back <text>.
        '''
        chunks = []
        current = ''
        # each unit is either a backslash plus the character it escapes, or a single character
        for unit in re.findall(r'\\.|.', text, re.DOTALL):
            if current and len(current) + len(unit) > limit:
                chunks.append(current)
                current = unit
            else:
                current += unit
        chunks.append(current)
        return chunks

    def generate_rr(self):
        '''Return the 'TXT' resource record (path '@') with the escaped, chunked text as data.'''
        text = escape_TXT(self._text)
        # split into chunks of max. 255 characters, never separating a backslash from what it escapes
        chunks = self._split_chunks(text)
        # every chunk is quoted on its own; continuation chunks go on new, indented lines
        chunksep = '"\n' + ' '*20 + '"'
        chunked = '( "' + chunksep.join(chunks) + '" )'
        return RR('@', 'TXT', chunked)
+
+
class DKIM(TXT): # helper class to treat DKIM more natively
    '''TXT record of the form "v=<version>; k=<alg>; p=<key>", placed under <selector>._domainkey.'''

    class Version:
        DKIM1 = "DKIM1"

    class Algorithm:
        RSA = "rsa"

    def __init__(self, selector, version, alg, key):
        self._selector = check_label(selector)
        # validate the three fields in order, then hand the assembled text to TXT
        fields = (check_label(version), check_label(alg), check_base64(key))
        super().__init__("v={0}; k={1}; p={2}".format(*fields))

    def generate_rr(self):
        '''Return the TXT record, with its path combined with "<selector>._domainkey".'''
        record = super().generate_rr()
        owner = '{}._domainkey'.format(self._selector)
        return record.relativize(owner)
class SRV:
self._port = int(port)
self._name = check_hostname(name)
- def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
- return zone.RR('_{0}._{1}.{2}'.format(self._service, self._protocol, owner), 'SRV',
- '{0} {1} {2} {3}'.format(self._priority, self._weight, self._port, zone.abs_hostname(self._name)))
+ def generate_rr(self):
+ return RR('_{}._{}'.format(self._service, self._protocol), 'SRV',
+ '{} {} {} {}'.format(self._priority, self._weight, self._port, self._name))
class TLSA:
self._matching_type = int(matching_type)
self._data = check_hex(data)
- def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
- return zone.RR('_{0}._{1}.{2}'.format(self._port, self._protocol, owner), 'TLSA', '{0} {1} {2} {3}'.format(self._usage, self._selector, self._matching_type, self._data))
+ def generate_rr(self):
+ return RR('_{}._{}'.format(self._port, self._protocol), 'TLSA', '{} {} {} {}'.format(self._usage, self._selector, self._matching_type, self._data))
class CNAME:
    '''Record type 'CNAME': holds a name that went through check_hostname.'''

    def __init__(self, name: str) -> None:
        # keep whatever check_hostname hands back for the given name
        self._name = check_hostname(name)

    def generate_rr(self):
        '''Return the 'CNAME' resource record for this name, with path '@'.'''
        return RR(path='@', recordType='CNAME', data=self._name)
class NS:
    '''Record type 'NS': holds a name that went through check_hostname.'''

    def __init__(self, name: str) -> None:
        # keep whatever check_hostname hands back for the given name
        self._name = check_hostname(name)

    def generate_rr(self):
        '''Return the 'NS' resource record for this name, with path '@'.'''
        return RR(path='@', recordType='NS', data=self._name)
class DS:
self._alg = int(alg)
self._digest = int(digest)
- def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
- return zone.RR(owner, 'DS', '{0} {1} {2} {3}'.format(self._tag, self._alg, self._digest, self._key))
+ def generate_rr(self):
+ return RR('@', 'DS', '{} {} {} {}'.format(self._tag, self._alg, self._digest, self._key))
## Higher-level classes
class Name:
    '''A group of record objects whose resource records are generated together.

    Each positional argument is either a record object (anything offering
    generate_rr()) or a list of such objects.
    '''

    def __init__(self, *records: 'List[Any]') -> None:
        self._records = records

    def generate_rrs(self):
        '''Yield generate_rr() of every record, flattening one level of lists.'''
        for entry in self._records:
            # a bare record is treated like a one-element list
            group = entry if isinstance(entry, list) else [entry]
            for record in group:
                yield record.generate_rr()
def CName(name: str) -> Name:
    '''Return a Name holding a single CNAME record for <name>.'''
    alias = CNAME(name)
    return Name(alias)
def Delegation(*names) -> Name:
    '''Return a Name holding one NS record per given name.'''
    ns_records = [NS(name) for name in names]
    return Name(ns_records)
def SecureDelegation(tag: int, alg: int, digest: int, key: str, *names) -> Name:
    '''Return a Name holding a DS record plus one NS record per given name.'''
    ds_record = DS(tag, alg, digest, key)
    ns_records = [NS(name) for name in names]
    return Name(ds_record, ns_records)
class Zone:
self._domains = domains
- def getTTL(self, recordType: str) -> str:
- return self._TTLs.get(recordType, self._TTLs[''])
-
- def RR(self, owner: str, recordType: str, data: str) -> str:
- '''generate given RR, in textual representation'''
- assert re.match(r'^[A-Z]+$', recordType), "got invalid record type"
- return column_widths((self.abs_hostname(owner), time(self.getTTL(recordType)), recordType, data), (32, 8, 8))
-
- def abs_hostname(self, name):
- if name == '':
- raise Exception("Empty domain name is not valid")
- if name == '.' or name == '@':
- return self._name
- if name.endswith('.'):
- return name
- return name+"."+self._name
def getTTL(self, TTL: int, recordType: str) -> int:
    '''Return <TTL> if it is set, else the zone default for <recordType>.

    The default comes from self._TTLs, falling back to its '' entry when
    <recordType> has no entry of its own.
    '''
    if TTL is None:
        # the '' entry is read unconditionally, so it has to exist
        fallback = self._TTLs['']
        return int(self._TTLs.get(recordType, fallback))
    return TTL
def inc_serial(self) -> int:
# get serial
# be done
return cur_serial
@staticmethod
def generate_rrs_from_dict(root, domains):
    '''Yield the resource records described by the (possibly nested) dict <domains>.

    Keys are names relative to <root>; they are visited sorted by their
    dot-separated labels. A dict value is descended into with the combined
    name as new root; any other value must offer generate_rrs(), and each
    record it yields is relativized to the combined name.

    Raises Exception if a key ends in '.'.
    '''
    ordered = sorted(domains, key=lambda s: s.split('.'))
    for label in ordered:
        if label.endswith('.'):
            raise Exception("You are trying to add a record outside of your zone. This is not supported. Use '@' for the zone root.")
        entry = domains[label]
        full_name = concatenate(root, label)
        if isinstance(entry, dict):
            yield from Zone.generate_rrs_from_dict(full_name, entry)
            continue
        for rr in entry.generate_rrs():
            yield rr.relativize(full_name)
+
def generate_rrs(self) -> 'Iterator':
    '''Yield all records of the zone: the SOA record, the NS records, then everything in self._domains.'''
    # SOA record
    serial = self.inc_serial()
    soa_template = ('{NS} {mail} {serial} {refresh} {retry} {expire} {NX_TTL}'
                    ' ; primns mail serial refresh retry expire NX_TTL')
    soa_data = soa_template.format(
        NS=self._NS[0],
        mail=self._mail,
        serial=serial,
        refresh=time(self._refresh),
        retry=time(self._retry),
        expire=time(self._expire),
        NX_TTL=time(self.getTTL(None, 'NX'))
    )
    yield RR('@', 'SOA', soa_data)
    # NS records
    for nameserver in self._NS:
        yield NS(nameserver).generate_rr()
    # all the rest
    yield from Zone.generate_rrs_from_dict('@', self._domains)
def write(self) -> None:
    '''Print the zone: a header comment, the $ORIGIN line, then every record with its TTL run through getTTL.'''
    header = ";; {} zone file, generated by zonemaker <https://www.ralfj.de/projects/zonemaker> on {}".format(self._name, datetime.datetime.now())
    print(header)
    print("$ORIGIN {}".format(self._name))
    for rr in self.generate_rrs():
        # mapTTL returns the record itself, so the result can be printed directly
        print(rr.mapTTL(self.getTTL))