X-Git-Url: https://git.ralfj.de/zonemaker.git/blobdiff_plain/cb274ac18d094a7f862335464b29d8816f26eeef..53f825fef45e8d09bd04ab44fd4a5e9e6e0c7626:/zone.py diff --git a/zone.py b/zone.py index d7c23ff..a74ee18 100644 --- a/zone.py +++ b/zone.py @@ -33,7 +33,7 @@ week = 7*day REGEX_label = r'[a-zA-Z90-9]([a-zA-Z90-9-]{0,61}[a-zA-Z90-9])?' # max. 63 characters; must not start or end with hyphen REGEX_ipv4 = r'^\d{1,3}(\.\d{1,3}){3}$' -REGEX_ipv6 = r'^[a-fA-F0-9]{1,4}(:[a-fA-F0-9]{1,4}){7}$' +REGEX_ipv6 = r'^[a-fA-F0-9]{1,4}(::?[a-fA-F0-9]{1,4}){1,7}$' def check_label(label: str) -> str: label = str(label) @@ -56,6 +56,13 @@ def check_hex(data: str) -> str: return data raise Exception(data+" is not valid hex data") +def check_base64(data: str) -> str: + data = str(data) + if re.match('^[a-zA-Z0-9+/=]+$', data): + return data + raise Exception(data+" is not valid hex data") + + def check_ipv4(address: str) -> str: address = str(address) if re.match(REGEX_ipv4, address): @@ -83,7 +90,7 @@ def time(time: int) -> str: return str(time) def column_widths(datas: 'Sequence', widths: 'Sequence[int]'): - assert len(datas) == len(widths)+1, "There must be as one more data points as widths" + assert len(datas) == len(widths)+1, "There must be one more data points as there are widths" result = "" width_sum = 0 for data, width in zip(datas, widths): # will *not* cover the last point @@ -94,6 +101,20 @@ def column_widths(datas: 'Sequence', widths: 'Sequence[int]'): # last data point return result+str(datas[-1]) +def concatenate(root, path): + if path == '' or root == '': + raise Exception("Empty domain name is not valid") + if path == '@': + return root + if root == '@' or path.endswith('.'): + return path + return path+"."+root + +def escape_TXT(text): + for c in ('\\', '\"'): + text = text.replace(c, '\\'+c) + return text + ## Enums class Protocol: @@ -108,21 +129,47 @@ class Digest: SHA256 = 2 +## Resource records +class RR: + def __init__(self, path, recordType, data): + ''' can be relative or absolute.''' + assert re.match(r'^[A-Z]+$', recordType), "got invalid record type" + self.path = path + self.recordType = recordType + self.data = data + self.TTL = None + + def mapPath(self, f): + '''Run the path through f. Returns self, for nicer chaining.''' + self.path = f(self.path) + return self + + def relativize(self, root): + return self.mapPath(lambda path: concatenate(root, path)) + + def mapTTL(self, f): + '''Run the current TTL and the recordType through f.''' + self.TTL = f(self.TTL, self.recordType) + return self + + def __str__(self): + return column_widths((self.path, time(self.TTL), self.recordType, self.data), (8*3, 8, 8)) + ## Record types class A: def __init__(self, address: str) -> None: self._address = check_ipv4(address) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR(owner, 'A', self._address) + def generate_rr(self): + return RR('@', 'A', self._address) class AAAA: def __init__(self, address: str) -> None: self._address = check_ipv6(address) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR(owner, 'AAAA', self._address) + def generate_rr(self): + return RR('@', 'AAAA', self._address) class MX: @@ -130,8 +177,45 @@ class MX: self._priority = int(prio) self._name = check_hostname(name) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR(owner, 'MX', '{0} {1}'.format(self._priority, zone.abs_hostname(self._name))) + def generate_rr(self): + return RR('@', 'MX', '{0} {1}'.format(self._priority, self._name)) + + +class TXT: + def __init__(self, text: str) -> None: + # test for bad characters + for c in ('\n', '\r', '\t'): + if c in text: + raise Exception("TXT record {0} contains invalid character") + self._text = text + + def generate_rr(self): + text = escape_TXT(self._text) + # split into chunks of max. 255 characters; be careful not to split right after a backslash + chunks = re.findall(r'.{0,254}[^\\]', text) + assert sum(len(c) for c in chunks) == len (text) + chunksep = '"\n' + ' '*20 + '"' + chunked = '( "' + chunksep.join(chunks) + '" )' + # generate the chunks + return RR('@', 'TXT', chunked) + + +class DKIM(TXT): # helper class to treat DKIM more antively + class Version: + DKIM1 = "DKIM1" + + class Algorithm: + RSA = "rsa" + + def __init__(self, selector, version, alg, key): + self._selector = check_label(selector) + version = check_label(version) + alg = check_label(alg) + key = check_base64(key) + super().__init__("v={0}; k={1}; p={2}".format(version, alg, key)) + + def generate_rr(self): + return super().generate_rr().relativize('{}._domainkey'.format(self._selector)) class SRV: @@ -143,9 +227,9 @@ class SRV: self._port = int(port) self._name = check_hostname(name) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR('_{0}._{1}.{2}'.format(self._service, self._protocol, owner), 'SRV', - '{0} {1} {2} {3}'.format(self._priority, self._weight, self._port, zone.abs_hostname(self._name))) + def generate_rr(self): + return RR('_{}._{}'.format(self._service, self._protocol), 'SRV', + '{} {} {} {}'.format(self._priority, self._weight, self._port, self._name)) class TLSA: @@ -172,24 +256,24 @@ class TLSA: self._matching_type = int(matching_type) self._data = check_hex(data) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR('_{0}._{1}.{2}'.format(self._port, self._protocol, owner), 'TLSA', '{0} {1} {2} {3}'.format(self._usage, self._selector, self._matching_type, self._data)) + def generate_rr(self): + return RR('_{}._{}'.format(self._port, self._protocol), 'TLSA', '{} {} {} {}'.format(self._usage, self._selector, self._matching_type, self._data)) class CNAME: def __init__(self, name: str) -> None: self._name = check_hostname(name) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR(owner, 'CNAME', zone.abs_hostname(self._name)) + def generate_rr(self): + return RR('@', 'CNAME', self._name) class NS: def __init__(self, name: str) -> None: self._name = check_hostname(name) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR(owner, 'NS', zone.abs_hostname(self._name)) + def generate_rr(self): + return RR('@', 'NS', self._name) class DS: @@ -199,34 +283,34 @@ class DS: self._alg = int(alg) self._digest = int(digest) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR(owner, 'DS', '{0} {1} {2} {3}'.format(self._tag, self._alg, self._digest, self._key)) + def generate_rr(self): + return RR('@', 'DS', '{} {} {} {}'.format(self._tag, self._alg, self._digest, self._key)) ## Higher-level classes class Name: def __init__(self, *records: 'List[Any]') -> None: self._records = records - def generate_rrs(self, owner: str, zone: 'Zone') -> 'Iterator': + def generate_rrs(self): for record in self._records: # this could still be a list if isinstance(record, list): for subrecord in record: - yield subrecord.generate_rr(owner, zone) + yield subrecord.generate_rr() else: - yield record.generate_rr(owner, zone) + yield record.generate_rr() def CName(name: str) -> Name: return Name(CNAME(name)) -def Delegation(name: str) -> Name: - return Name(NS(name)) +def Delegation(*names) -> Name: + return Name(list(map(NS, names))) -def SecureDelegation(name: str, tag: int, alg: int, digest: int, key: str) -> Name: - return Name(NS(name), DS(tag, alg, digest, key)) +def SecureDelegation(tag: int, alg: int, digest: int, key: str, *names) -> Name: + return Name(DS(tag, alg, digest, key), list(map(NS, names))) class Zone: @@ -251,22 +335,10 @@ class Zone: self._domains = domains - def getTTL(self, recordType: str) -> str: - return self._TTLs.get(recordType, self._TTLs['']) - - def RR(self, owner: str, recordType: str, data: str) -> str: - '''generate given RR, in textual representation''' - assert re.match(r'^[A-Z]+$', recordType), "got invalid record type" - return column_widths((self.abs_hostname(owner), time(self.getTTL(recordType)), recordType, data), (32, 8, 8)) - - def abs_hostname(self, name): - if name == '': - raise Exception("Empty domain name is not valid") - if name == '.' or name == '@': - return self._name - if name.endswith('.'): - return name - return name+"."+self._name + def getTTL(self, TTL: int, recordType: str) -> int: + if TTL is not None: return TTL + # TTL is None, so get a global default + return int(self._TTLs.get(recordType, self._TTLs[''])) def inc_serial(self) -> int: # get serial @@ -284,25 +356,39 @@ class Zone: # be done return cur_serial + @staticmethod + def generate_rrs_from_dict(root, domains): + for name in sorted(domains.keys(), key=lambda s: s.split('.')): + if name.endswith('.'): + raise Exception("You are trying to add a record outside of your zone. This is not supported. Use '@' for the zone root.") + domain = domains[name] + name = concatenate(root, name) + if isinstance(domain, dict): + for rr in Zone.generate_rrs_from_dict(name, domain): + yield rr + else: + for rr in domain.generate_rrs(): + yield rr.relativize(name) + def generate_rrs(self) -> 'Iterator': # SOA record serial = self.inc_serial() - yield self.RR(self._name, 'SOA', + yield RR('@', 'SOA', ('{NS} {mail} {serial} {refresh} {retry} {expire} {NX_TTL}'+ ' ; primns mail serial refresh retry expire NX_TTL').format( - NS=self.abs_hostname(self._NS[0]), mail=self._mail, serial=serial, + NS=self._NS[0], mail=self._mail, serial=serial, refresh=time(self._refresh), retry=time(self._retry), expire=time(self._expire), - NX_TTL=time(self.getTTL('NX'))) + NX_TTL=time(self.getTTL(None, 'NX'))) ) # NS records for name in self._NS: - yield NS(name).generate_rr(self._name, self) + yield NS(name).generate_rr() # all the rest - for name in sorted(self._domains.keys(), key=lambda s: list(reversed(s.split('.')))): - for rr in self._domains[name].generate_rrs(name, self): - yield rr + for rr in Zone.generate_rrs_from_dict('@', self._domains): + yield rr def write(self) -> None: - print(";; {0} zone file, generated by zonemaker on {1}".format(self._name, datetime.datetime.now())) - for rr in self.generate_rrs(): + print(";; {} zone file, generated by zonemaker on {}".format(self._name, datetime.datetime.now())) + print("$ORIGIN {}".format(self._name)) + for rr in map(lambda rr: rr.mapTTL(self.getTTL), self.generate_rrs()): print(rr)