From cfd2f5fc1e06dc0fa0894d2b49405cb424d335ff Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Sat, 21 Mar 2015 13:28:12 +0100 Subject: [PATCH] more systematic (and sane) approach to handling relative paths --- db.example.com.py | 2 +- zone.py | 123 ++++++++++++++++++++++++++++------------------ 2 files changed, 75 insertions(+), 50 deletions(-) diff --git a/db.example.com.py b/db.example.com.py index dc34985..daf5327 100644 --- a/db.example.com.py +++ b/db.example.com.py @@ -27,7 +27,7 @@ __zone__ = Zone('example.com.', serialfile = 'db.example.com.srl', secondary_refresh = 6*hour, secondary_retry = 1*hour, secondary_expire = 7*day, # Here come the actual domains. Each takes records as argument, either individually or as lists. domains = { - '.': Name(one, mail), # this will all all records from the list "one" and the list "mail" to this name + '@': Name(one, mail, HTTPS('0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef')), # this will all all records from the list "one" and the list "mail" to this name 'ns': Name(one), 'ipv4.ns': Name(one4), # just a single record 'ipv6.ns': Name(one6), diff --git a/zone.py b/zone.py index c311a21..b0d055e 100644 --- a/zone.py +++ b/zone.py @@ -90,7 +90,7 @@ def time(time: int) -> str: return str(time) def column_widths(datas: 'Sequence', widths: 'Sequence[int]'): - assert len(datas) == len(widths)+1, "There must be as one more data points as widths" + assert len(datas) == len(widths)+1, "There must be one more data points as there are widths" result = "" width_sum = 0 for data, width in zip(datas, widths): # will *not* cover the last point @@ -115,21 +115,55 @@ class Digest: SHA256 = 2 +## Resource records +class RR: + def __init__(self, path, recordType, data): + ''' can be relative or absolute.''' + assert re.match(r'^[A-Z]+$', recordType), "got invalid record type" + self.path = path + self.recordType = recordType + self.data = data + self.TTL = None + + def mapPath(self, f): + '''Run the path through f. Returns self, for nicer chaining.''' + self.path = f(self.path) + return self + + def relativize(self, root): + def _relativize(path): + if path == '' or root == '': + raise Exception("Empty domain name is not valid") + if path == '@': + return root + if root == '@' or path.endswith('.'): + return path + return path+"."+root + return self.mapPath(_relativize) + + def mapTTL(self, f): + '''Run the current TTL and the recordType through f.''' + self.TTL = f(self.TTL, self.recordType) + return self + + def __str__(self): + return column_widths((self.path, time(self.TTL), self.recordType, self.data), (32, 8, 8)) + ## Record types class A: def __init__(self, address: str) -> None: self._address = check_ipv4(address) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR(owner, 'A', self._address) + def generate_rr(self): + return RR('@', 'A', self._address) class AAAA: def __init__(self, address: str) -> None: self._address = check_ipv6(address) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR(owner, 'AAAA', self._address) + def generate_rr(self): + return RR('@', 'AAAA', self._address) class MX: @@ -137,8 +171,8 @@ class MX: self._priority = int(prio) self._name = check_hostname(name) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR(owner, 'MX', '{0} {1}'.format(self._priority, zone.abs_hostname(self._name))) + def generate_rr(self): + return RR('@', 'MX', '{0} {1}'.format(self._priority, self._name)) class TXT: @@ -152,8 +186,8 @@ class TXT: text = text.replace(c, '\\'+c) self._text = text - def generate_rr(self, owner:str, zone: 'Zone') -> 'Any': - return zone.RR(owner, 'TXT', '"{0}"'.format(self._text)) + def generate_rr(self): + return RR('@', 'TXT', '"{0}"'.format(self._text)) class DKIM(TXT): # helper class to treat DKIM more antively @@ -170,8 +204,8 @@ class DKIM(TXT): # helper class to treat DKIM more antively key = check_base64(key) super().__init__("v={0}; k={1}; p={2}".format(version, alg, key)) - def generate_rr(self, owner, zone): - return super().generate_rr('{0}._domainkey.{1}'.format(self._selector, zone.abs_hostname(owner)), zone) + def generate_rr(self): + return super().generate_rr().relativize('{}._domainkey'.format(self._selector)) class SRV: @@ -183,9 +217,9 @@ class SRV: self._port = int(port) self._name = check_hostname(name) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR('_{0}._{1}.{2}'.format(self._service, self._protocol, owner), 'SRV', - '{0} {1} {2} {3}'.format(self._priority, self._weight, self._port, zone.abs_hostname(self._name))) + def generate_rr(self): + return RR('_{}._{}'.format(self._service, self._protocol), 'SRV', + '{} {} {} {}'.format(self._priority, self._weight, self._port, self._name)) class TLSA: @@ -212,24 +246,24 @@ class TLSA: self._matching_type = int(matching_type) self._data = check_hex(data) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR('_{0}._{1}.{2}'.format(self._port, self._protocol, owner), 'TLSA', '{0} {1} {2} {3}'.format(self._usage, self._selector, self._matching_type, self._data)) + def generate_rr(self): + return RR('_{}._{}'.format(self._port, self._protocol), 'TLSA', '{} {} {} {}'.format(self._usage, self._selector, self._matching_type, self._data)) class CNAME: def __init__(self, name: str) -> None: self._name = check_hostname(name) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR(owner, 'CNAME', zone.abs_hostname(self._name)) + def generate_rr(self): + return RR('@', 'CNAME', self._name) class NS: def __init__(self, name: str) -> None: self._name = check_hostname(name) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR(owner, 'NS', zone.abs_hostname(self._name)) + def generate_rr(self): + return RR('@', 'NS', self._name) class DS: @@ -239,22 +273,22 @@ class DS: self._alg = int(alg) self._digest = int(digest) - def generate_rr(self, owner: str, zone: 'Zone') -> 'Any': - return zone.RR(owner, 'DS', '{0} {1} {2} {3}'.format(self._tag, self._alg, self._digest, self._key)) + def generate_rr(self): + return RR('@', 'DS', '{} {} {} {}'.format(self._tag, self._alg, self._digest, self._key)) ## Higher-level classes class Name: def __init__(self, *records: 'List[Any]') -> None: self._records = records - def generate_rrs(self, owner: str, zone: 'Zone') -> 'Iterator': + def generate_rrs(self): for record in self._records: # this could still be a list if isinstance(record, list): for subrecord in record: - yield subrecord.generate_rr(owner, zone) + yield subrecord.generate_rr() else: - yield record.generate_rr(owner, zone) + yield record.generate_rr() def CName(name: str) -> Name: @@ -291,22 +325,10 @@ class Zone: self._domains = domains - def getTTL(self, recordType: str) -> str: - return self._TTLs.get(recordType, self._TTLs['']) - - def RR(self, owner: str, recordType: str, data: str) -> str: - '''generate given RR, in textual representation''' - assert re.match(r'^[A-Z]+$', recordType), "got invalid record type" - return column_widths((self.abs_hostname(owner), time(self.getTTL(recordType)), recordType, data), (32, 8, 8)) - - def abs_hostname(self, name): - if name == '': - raise Exception("Empty domain name is not valid") - if name == '.' or name == '@': - return self._name - if name.endswith('.'): - return name - return name+"."+self._name + def getTTL(self, TTL: int, recordType: str) -> int: + if TTL is not None: return TTL + # TTL is None, so get a global default + return int(self._TTLs.get(recordType, self._TTLs[''])) def inc_serial(self) -> int: # get serial @@ -327,22 +349,25 @@ class Zone: def generate_rrs(self) -> 'Iterator': # SOA record serial = self.inc_serial() - yield self.RR(self._name, 'SOA', + yield RR('@', 'SOA', ('{NS} {mail} {serial} {refresh} {retry} {expire} {NX_TTL}'+ ' ; primns mail serial refresh retry expire NX_TTL').format( - NS=self.abs_hostname(self._NS[0]), mail=self._mail, serial=serial, + NS=self._NS[0], mail=self._mail, serial=serial, refresh=time(self._refresh), retry=time(self._retry), expire=time(self._expire), - NX_TTL=time(self.getTTL('NX'))) + NX_TTL=time(self.getTTL(None, 'NX'))) ) # NS records for name in self._NS: - yield NS(name).generate_rr(self._name, self) + yield NS(name).generate_rr() # all the rest for name in sorted(self._domains.keys(), key=lambda s: list(reversed(s.split('.')))): - for rr in self._domains[name].generate_rrs(name, self): - yield rr + if name.endswith('.'): + raise Exception("You are trying to add a record outside of your zone. This is not supported. Use '@' for the zone root.") + for rr in self._domains[name].generate_rrs(): + yield rr.relativize(name) def write(self) -> None: - print(";; {0} zone file, generated by zonemaker on {1}".format(self._name, datetime.datetime.now())) - for rr in self.generate_rrs(): + print(";; {} zone file, generated by zonemaker on {}".format(self._name, datetime.datetime.now())) + print("$ORIGIN {}".format(self._name)) + for rr in map(lambda rr: rr.relativize(self._name).mapTTL(self.getTTL), self.generate_rrs()): print(rr) -- 2.30.2