support splitting TXT records into bind-sized chunks
[zonemaker.git] / zone.py
diff --git a/zone.py b/zone.py
index 36dae78df4102c18e987e56059cbcd5910d11df7..a74ee1858361619774a8f12ba8748b279ed18d64 100644 (file)
--- a/zone.py
+++ b/zone.py
@@ -33,7 +33,7 @@ week = 7*day
 
 REGEX_label = r'[a-zA-Z90-9]([a-zA-Z90-9-]{0,61}[a-zA-Z90-9])?' # max. 63 characters; must not start or end with hyphen
 REGEX_ipv4  = r'^\d{1,3}(\.\d{1,3}){3}$'
 
 REGEX_label = r'[a-zA-Z90-9]([a-zA-Z90-9-]{0,61}[a-zA-Z90-9])?' # max. 63 characters; must not start or end with hyphen
 REGEX_ipv4  = r'^\d{1,3}(\.\d{1,3}){3}$'
-REGEX_ipv6  = r'^[a-fA-F0-9]{1,4}(:[a-fA-F0-9]{1,4}){7}$'
+REGEX_ipv6  = r'^[a-fA-F0-9]{1,4}(::?[a-fA-F0-9]{1,4}){1,7}$'
 
 def check_label(label: str) -> str:
     label = str(label)
 
 def check_label(label: str) -> str:
     label = str(label)
@@ -90,7 +90,7 @@ def time(time: int) -> str:
         return str(time)
 
 def column_widths(datas: 'Sequence', widths: 'Sequence[int]'):
         return str(time)
 
 def column_widths(datas: 'Sequence', widths: 'Sequence[int]'):
-    assert len(datas) == len(widths)+1, "There must be as one more data points as widths"
+    assert len(datas) == len(widths)+1, "There must be one more data points as there are widths"
     result = ""
     width_sum = 0
     for data, width in zip(datas, widths): # will *not* cover the last point
     result = ""
     width_sum = 0
     for data, width in zip(datas, widths): # will *not* cover the last point
@@ -101,6 +101,20 @@ def column_widths(datas: 'Sequence', widths: 'Sequence[int]'):
     # last data point
     return result+str(datas[-1])
 
     # last data point
     return result+str(datas[-1])
 
+def concatenate(root, path):
+    if path == '' or root == '':
+        raise Exception("Empty domain name is not valid")
+    if path == '@':
+        return root
+    if root == '@' or path.endswith('.'):
+        return path
+    return path+"."+root
+
+def escape_TXT(text):
+    for c in ('\\', '\"'):
+        text = text.replace(c, '\\'+c)
+    return text
+
 
 ## Enums
 class Protocol:
 
 ## Enums
 class Protocol:
@@ -115,21 +129,47 @@ class Digest:
     SHA256 = 2
 
 
     SHA256 = 2
 
 
+## Resource records
+class RR:
+    def __init__(self, path, recordType, data):
+        '''<path> can be relative or absolute.'''
+        assert re.match(r'^[A-Z]+$', recordType), "got invalid record type"
+        self.path = path
+        self.recordType = recordType
+        self.data = data
+        self.TTL = None
+    
+    def mapPath(self, f):
+        '''Run the path through f. Returns self, for nicer chaining.'''
+        self.path = f(self.path)
+        return self
+    
+    def relativize(self, root):
+        return self.mapPath(lambda path: concatenate(root, path))
+    
+    def mapTTL(self, f):
+        '''Run the current TTL and the recordType through f.'''
+        self.TTL = f(self.TTL, self.recordType)
+        return self
+    
+    def __str__(self):
+        return column_widths((self.path, time(self.TTL), self.recordType, self.data), (8*3, 8, 8))
+
 ## Record types
 class A:
     def __init__(self, address: str) -> None:
         self._address = check_ipv4(address)
     
 ## Record types
 class A:
     def __init__(self, address: str) -> None:
         self._address = check_ipv4(address)
     
-    def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
-        return zone.RR(owner, 'A', self._address)
+    def generate_rr(self):
+        return RR('@', 'A', self._address)
 
 
 class AAAA:
     def __init__(self, address: str) -> None:
         self._address = check_ipv6(address)
     
 
 
 class AAAA:
     def __init__(self, address: str) -> None:
         self._address = check_ipv6(address)
     
-    def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
-        return zone.RR(owner, 'AAAA', self._address)
+    def generate_rr(self):
+        return RR('@', 'AAAA', self._address)
 
 
 class MX:
 
 
 class MX:
@@ -137,8 +177,8 @@ class MX:
         self._priority = int(prio)
         self._name = check_hostname(name)
     
         self._priority = int(prio)
         self._name = check_hostname(name)
     
-    def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
-        return zone.RR(owner, 'MX', '{0} {1}'.format(self._priority, zone.abs_hostname(self._name)))
+    def generate_rr(self):
+        return RR('@', 'MX', '{0} {1}'.format(self._priority, self._name))
 
 
 class TXT:
 
 
 class TXT:
@@ -147,13 +187,17 @@ class TXT:
         for c in ('\n', '\r', '\t'):
             if c in text:
                 raise Exception("TXT record {0} contains invalid character")
         for c in ('\n', '\r', '\t'):
             if c in text:
                 raise Exception("TXT record {0} contains invalid character")
-        # escape text
-        for c in ('\\', '\"'):
-            text = text.replace(c, '\\'+c)
         self._text = text
     
         self._text = text
     
-    def generate_rr(self, owner:str, zone: 'Zone') -> 'Any':
-        return zone.RR(owner, 'TXT', '"{0}"'.format(self._text))
+    def generate_rr(self):
+        text = escape_TXT(self._text)
+        # split into chunks of max. 255 characters; be careful not to split right after a backslash
+        chunks = re.findall(r'.{0,254}[^\\]', text)
+        assert sum(len(c) for c in chunks) == len (text)
+        chunksep = '"\n' + ' '*20 + '"'
+        chunked = '( "' + chunksep.join(chunks) + '" )'
+        # generate the chunks
+        return RR('@', 'TXT', chunked)
 
 
 class DKIM(TXT): # helper class to treat DKIM more antively
 
 
 class DKIM(TXT): # helper class to treat DKIM more antively
@@ -170,8 +214,8 @@ class DKIM(TXT): # helper class to treat DKIM more antively
         key = check_base64(key)
         super().__init__("v={0}; k={1}; p={2}".format(version, alg, key))
     
         key = check_base64(key)
         super().__init__("v={0}; k={1}; p={2}".format(version, alg, key))
     
-    def generate_rr(self, owner, zone):
-        return super().generate_rr('{0}._domainkey.{1}'.format(self._selector, owner), zone)
+    def generate_rr(self):
+        return super().generate_rr().relativize('{}._domainkey'.format(self._selector))
 
 
 class SRV:
 
 
 class SRV:
@@ -183,9 +227,9 @@ class SRV:
         self._port = int(port)
         self._name = check_hostname(name)
     
         self._port = int(port)
         self._name = check_hostname(name)
     
-    def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
-        return zone.RR('_{0}._{1}.{2}'.format(self._service, self._protocol, owner), 'SRV',
-                       '{0} {1} {2} {3}'.format(self._priority, self._weight, self._port, zone.abs_hostname(self._name)))
+    def generate_rr(self):
+        return RR('_{}._{}'.format(self._service, self._protocol), 'SRV',
+                       '{} {} {} {}'.format(self._priority, self._weight, self._port, self._name))
 
 
 class TLSA:
 
 
 class TLSA:
@@ -212,24 +256,24 @@ class TLSA:
         self._matching_type = int(matching_type)
         self._data = check_hex(data)
     
         self._matching_type = int(matching_type)
         self._data = check_hex(data)
     
-    def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
-        return zone.RR('_{0}._{1}.{2}'.format(self._port, self._protocol, owner), 'TLSA', '{0} {1} {2} {3}'.format(self._usage, self._selector, self._matching_type, self._data))
+    def generate_rr(self):
+        return RR('_{}._{}'.format(self._port, self._protocol), 'TLSA', '{} {} {} {}'.format(self._usage, self._selector, self._matching_type, self._data))
 
 
 class CNAME:
     def __init__(self, name: str) -> None:
         self._name = check_hostname(name)
     
 
 
 class CNAME:
     def __init__(self, name: str) -> None:
         self._name = check_hostname(name)
     
-    def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
-        return zone.RR(owner, 'CNAME', zone.abs_hostname(self._name))
+    def generate_rr(self):
+        return RR('@', 'CNAME', self._name)
 
 
 class NS:
     def __init__(self, name: str) -> None:
         self._name = check_hostname(name)
     
 
 
 class NS:
     def __init__(self, name: str) -> None:
         self._name = check_hostname(name)
     
-    def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
-        return zone.RR(owner, 'NS', zone.abs_hostname(self._name))
+    def generate_rr(self):
+        return RR('@', 'NS', self._name)
 
 
 class DS:
 
 
 class DS:
@@ -239,34 +283,34 @@ class DS:
         self._alg = int(alg)
         self._digest = int(digest)
     
         self._alg = int(alg)
         self._digest = int(digest)
     
-    def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
-        return zone.RR(owner, 'DS', '{0} {1} {2} {3}'.format(self._tag, self._alg, self._digest, self._key))
+    def generate_rr(self):
+        return RR('@', 'DS', '{} {} {} {}'.format(self._tag, self._alg, self._digest, self._key))
 
 ## Higher-level classes
 class Name:
     def __init__(self, *records: 'List[Any]') -> None:
         self._records = records
     
 
 ## Higher-level classes
 class Name:
     def __init__(self, *records: 'List[Any]') -> None:
         self._records = records
     
-    def generate_rrs(self, owner: str, zone: 'Zone') -> 'Iterator':
+    def generate_rrs(self):
         for record in self._records:
             # this could still be a list
             if isinstance(record, list):
                 for subrecord in record:
         for record in self._records:
             # this could still be a list
             if isinstance(record, list):
                 for subrecord in record:
-                    yield subrecord.generate_rr(owner, zone)
+                    yield subrecord.generate_rr()
             else:
             else:
-                yield record.generate_rr(owner, zone)
+                yield record.generate_rr()
 
 
 def CName(name: str) -> Name:
     return Name(CNAME(name))
 
 
 
 
 def CName(name: str) -> Name:
     return Name(CNAME(name))
 
 
-def Delegation(name: str) -> Name:
-    return Name(NS(name))
+def Delegation(*names) -> Name:
+    return Name(list(map(NS, names)))
 
 
 
 
-def SecureDelegation(name: str, tag: int, alg: int, digest: int, key: str) -> Name:
-    return Name(NS(name), DS(tag, alg, digest, key))
+def SecureDelegation(tag: int, alg: int, digest: int, key: str, *names) -> Name:
+    return Name(DS(tag, alg, digest, key), list(map(NS, names)))
 
 
 class Zone:
 
 
 class Zone:
@@ -291,22 +335,10 @@ class Zone:
         
         self._domains = domains
     
         
         self._domains = domains
     
-    def getTTL(self, recordType: str) -> str:
-        return self._TTLs.get(recordType, self._TTLs[''])
-    
-    def RR(self, owner: str, recordType: str, data: str) -> str:
-        '''generate given RR, in textual representation'''
-        assert re.match(r'^[A-Z]+$', recordType), "got invalid record type"
-        return column_widths((self.abs_hostname(owner), time(self.getTTL(recordType)), recordType, data), (32, 8, 8))
-    
-    def abs_hostname(self, name):
-        if name == '':
-            raise Exception("Empty domain name is not valid")
-        if name == '.' or name == '@':
-            return self._name
-        if name.endswith('.'):
-            return name
-        return name+"."+self._name
+    def getTTL(self, TTL: int, recordType: str) -> int:
+        if TTL is not None: return TTL
+        # TTL is None, so get a global default
+        return int(self._TTLs.get(recordType, self._TTLs['']))
     
     def inc_serial(self) -> int:
         # get serial
     
     def inc_serial(self) -> int:
         # get serial
@@ -324,25 +356,39 @@ class Zone:
         # be done
         return cur_serial
     
         # be done
         return cur_serial
     
+    @staticmethod
+    def generate_rrs_from_dict(root, domains):
+        for name in sorted(domains.keys(), key=lambda s: s.split('.')):
+            if name.endswith('.'):
+                raise Exception("You are trying to add a record outside of your zone. This is not supported. Use '@' for the zone root.")
+            domain = domains[name]
+            name = concatenate(root, name)
+            if isinstance(domain, dict):
+                for rr in Zone.generate_rrs_from_dict(name, domain):
+                    yield rr
+            else:
+                for rr in domain.generate_rrs():
+                    yield rr.relativize(name)
+    
     def generate_rrs(self) -> 'Iterator':
         # SOA record
         serial = self.inc_serial()
     def generate_rrs(self) -> 'Iterator':
         # SOA record
         serial = self.inc_serial()
-        yield self.RR(self._name, 'SOA',
+        yield RR('@', 'SOA',
                       ('{NS} {mail} {serial} {refresh} {retry} {expire} {NX_TTL}'+
                       ' ; primns mail serial refresh retry expire NX_TTL').format(
                       ('{NS} {mail} {serial} {refresh} {retry} {expire} {NX_TTL}'+
                       ' ; primns mail serial refresh retry expire NX_TTL').format(
-                          NS=self.abs_hostname(self._NS[0]), mail=self._mail, serial=serial,
+                          NS=self._NS[0], mail=self._mail, serial=serial,
                           refresh=time(self._refresh), retry=time(self._retry), expire=time(self._expire),
                           refresh=time(self._refresh), retry=time(self._retry), expire=time(self._expire),
-                          NX_TTL=time(self.getTTL('NX')))
+                          NX_TTL=time(self.getTTL(None, 'NX')))
                       )
         # NS records
         for name in self._NS:
                       )
         # NS records
         for name in self._NS:
-            yield NS(name).generate_rr(self._name, self)
+            yield NS(name).generate_rr()
         # all the rest
         # all the rest
-        for name in sorted(self._domains.keys(), key=lambda s: list(reversed(s.split('.')))):
-            for rr in self._domains[name].generate_rrs(name, self):
-                yield rr
+        for rr in Zone.generate_rrs_from_dict('@', self._domains):
+            yield rr
     
     def write(self) -> None:
     
     def write(self) -> None:
-        print(";; {0} zone file, generated by zonemaker <https://www.ralfj.de/projects/zonemaker> on {1}".format(self._name, datetime.datetime.now()))
-        for rr in self.generate_rrs():
+        print(";; {} zone file, generated by zonemaker <https://www.ralfj.de/projects/zonemaker> on {}".format(self._name, datetime.datetime.now()))
+        print("$ORIGIN {}".format(self._name))
+        for rr in map(lambda rr: rr.mapTTL(self.getTTL), self.generate_rrs()):
             print(rr)
             print(rr)