]> git.ralfj.de Git - zonemaker.git/blobdiff - zonemaker/zone.py
write to stdout only
[zonemaker.git] / zonemaker / zone.py
index 15add06c766036b1b62a9a89997ee3a79065e2ed..f9712e9e5e3aa9a7fa8a94df09a4720fd07998f9 100644 (file)
@@ -1,4 +1,4 @@
-import re
+import re, datetime
 from ipaddress import IPv4Address, IPv6Address
 from typing import List, Dict, Any, Iterator, Tuple, Sequence
 
 from ipaddress import IPv4Address, IPv6Address
 from typing import List, Dict, Any, Iterator, Tuple, Sequence
 
@@ -9,17 +9,26 @@ hour = 60*minute
 day = 24*hour
 week = 7*day
 
 day = 24*hour
 week = 7*day
 
-TCP = 'tcp'
-UDP = 'udp'
+REGEX_label = r'[a-zA-Z90-9]([a-zA-Z90-9-]{0,61}[a-zA-Z90-9])?' # max. 63 characters; must not start or end with hyphen
+
+def check_label(label: str) -> str:
+    pattern = r'^{0}$'.format(REGEX_label)
+    if re.match(pattern, label):
+        return label
+    raise Exception(label+" is not a valid label")
 
 def check_hostname(name: str) -> str:
     # check hostname for validity
 
 def check_hostname(name: str) -> str:
     # check hostname for validity
-    label = r'[a-zA-Z90-9]([a-zA-Z90-9-]{0,61}[a-zA-Z90-9])?' # must not start or end with hyphen
-    pattern = r'^{0}(\.{0})*\.?'.format(label)
+    pattern = r'^{0}(\.{0})*\.?$'.format(REGEX_label)
     if re.match(pattern, name):
         return name
     raise Exception(name+" is not a valid hostname")
 
     if re.match(pattern, name):
         return name
     raise Exception(name+" is not a valid hostname")
 
+def check_hex(data: str) -> str:
+    if re.match('^[a-fA-F0-9]+$', data):
+        return data
+    raise Exception(data+" is not valid hex data")
+
 def time(time: int) -> str:
     if time == 0:
         return "0"
 def time(time: int) -> str:
     if time == 0:
         return "0"
@@ -47,6 +56,19 @@ def column_widths(datas: Sequence, widths: Sequence[int]):
     return result+str(datas[-1])
 
 
     return result+str(datas[-1])
 
 
+## Enums
+class Protocol:
+    TCP = 'tcp'
+    UDP = 'udp'
+
+class Algorithm:
+    RSA_SHA256 = 8
+
+class Digest:
+    SHA1 = 1
+    SHA256 = 2
+
+
 ## Record types
 class A:
     def __init__(self, address: str) -> None:
 ## Record types
 class A:
     def __init__(self, address: str) -> None:
@@ -75,8 +97,8 @@ class MX:
 
 class SRV:
     def __init__(self, protocol: str, service: str, name: str, port: int, prio: int, weight: int) -> None:
 
 class SRV:
     def __init__(self, protocol: str, service: str, name: str, port: int, prio: int, weight: int) -> None:
-        self._service = str(service)
-        self._protocol = str(protocol)
+        self._service = check_label(service)
+        self._protocol = check_label(protocol)
         self._priority = int(prio)
         self._weight = int(weight)
         self._port = int(port)
         self._priority = int(prio)
         self._weight = int(weight)
         self._port = int(port)
@@ -88,14 +110,31 @@ class SRV:
 
 
 class TLSA:
 
 
 class TLSA:
-    def __init__(self, protocol: str, port: int, key: str) -> None:
-        # TODO: fix key stuff
+    class Usage:
+        CA = 0 # certificate must pass the usual CA check, with the CA specified by the TLSA record
+        EndEntity_PlusCAs = 1 # the certificate must match the TLSA record *and* pass the usual CA check
+        TrustAnchor = 2 # the certificate must pass a check with the TLSA record giving the (only) trust anchor
+        EndEntity = 3 # the certificate must match the TLSA record
+
+    class Selector:
+        Full = 0
+        SubjectPublicKeyInfo = 1
+    
+    class MatchingType:
+        Exact = 0
+        SHA256 = 1
+        SHA512 = 2
+    
+    def __init__(self, protocol: str, port: int, usage: int, selector: int, matching_type: int, data: str) -> None:
         self._port = int(port)
         self._protocol = str(protocol)
         self._port = int(port)
         self._protocol = str(protocol)
-        self._key = str(key)
+        self._usage = int(usage)
+        self._selector = int(selector)
+        self._matching_type = int(matching_type)
+        self._data = check_hex(data)
     
     def generate_rr(self, owner: str, zone: 'Zone') -> Any:
     
     def generate_rr(self, owner: str, zone: 'Zone') -> Any:
-        return zone.RR('_{0}._{1}.{2}'.format(self._port, self._protocol, owner), 'TLSA', self._key)
+        return zone.RR('_{0}._{1}.{2}'.format(self._port, self._protocol, owner), 'TLSA', '{0} {1} {2} {3}'.format(self._usage, self._selector, self._matching_type, self._data))
 
 
 class CNAME:
 
 
 class CNAME:
@@ -115,12 +154,14 @@ class NS:
 
 
 class DS:
 
 
 class DS:
-    def __init__(self, key: str) -> None:
-        # TODO: fix key stuff
-        self._key = str(key)
+    def __init__(self, tag: int, alg: int, digest: int, key: str) -> None:
+        self._tag = int(tag)
+        self._key = check_hex(key)
+        self._alg = int(alg)
+        self._digest = int(digest)
     
     def generate_rr(self, owner: str, zone: 'Zone') -> Any:
     
     def generate_rr(self, owner: str, zone: 'Zone') -> Any:
-        return zone.RR(owner, 'DS', self._key)
+        return zone.RR(owner, 'DS', '{0} {1} {2} {3}'.format(self._tag, self._alg, self._digest, self._key))
 
 ## Higher-level classes
 class Name:
 
 ## Higher-level classes
 class Name:
@@ -141,20 +182,20 @@ def CName(name: str) -> Name:
     return Name(CNAME(name))
 
 
     return Name(CNAME(name))
 
 
-def Delegation(name: str, key: str = None) -> Name:
-    records = [NS(name)]
-    if key is not None:
-        records.append(DS(key))
-    return Name(*records)
+def Delegation(name: str) -> Name:
+    return Name(NS(name))
+
+
+def SecureDelegation(name: str, tag: int, alg: int, digest: int, key: str) -> Name:
+    return Name(NS(name), DS(tag, alg, digest, key))
 
 
 class Zone:
 
 
 class Zone:
-    def __init__(self, name: str, serialfile: str, dbfile: str, mail: str, NS: List[str],
+    def __init__(self, name: str, serialfile: str, mail: str, NS: List[str],
                  secondary_refresh: int, secondary_retry: int, secondary_expire: int,
                  NX_TTL: int = None, A_TTL: int = None, other_TTL: int = None,
                  domains: Dict[str, Any] = {}) -> None:
         self._serialfile = serialfile
                  secondary_refresh: int, secondary_retry: int, secondary_expire: int,
                  NX_TTL: int = None, A_TTL: int = None, other_TTL: int = None,
                  domains: Dict[str, Any] = {}) -> None:
         self._serialfile = serialfile
-        self._dbfile = dbfile
         
         if not name.endswith('.'): raise Exception("Expected an absolute hostname")
         self._name = check_hostname(name)
         
         if not name.endswith('.'): raise Exception("Expected an absolute hostname")
         self._name = check_hostname(name)
@@ -215,8 +256,8 @@ class Zone:
         # SOA record
         serial = self.inc_serial()
         yield self.RR(self._name, 'SOA',
         # SOA record
         serial = self.inc_serial()
         yield self.RR(self._name, 'SOA',
-                      ('{NS} {mail} ({serial} {refresh} {retry} {expire} {NX_TTL}) ; '+
-                      '(serial refresh retry expire NX_TTL)').format(
+                      ('{NS} {mail} {serial} {refresh} {retry} {expire} {NX_TTL}'+
+                      ' ; primns mail serial refresh retry expire NX_TTL').format(
                           NS=self.abs_hostname(self._NS[0]), mail=self._mail, serial=serial,
                           refresh=time(self._refresh), retry=time(self._retry), expire=time(self._expire),
                           NX_TTL=time(self._NX_TTL))
                           NS=self.abs_hostname(self._NS[0]), mail=self._mail, serial=serial,
                           refresh=time(self._refresh), retry=time(self._retry), expire=time(self._expire),
                           NX_TTL=time(self._NX_TTL))
@@ -230,7 +271,6 @@ class Zone:
                 yield rr
     
     def write(self) -> None:
                 yield rr
     
     def write(self) -> None:
-        with open(self._dbfile, 'w') as f:
-            for rr in self.generate_rrs():
-                f.write(rr+"\n")
-                print(rr)
+        print(";; {0} zone file, generated by zonemaker <https://www.ralfj.de/projects/zonemaker> on {1}".format(self._name, datetime.datetime.now()))
+        for rr in self.generate_rrs():
+            print(rr)