refactor how we deal with TTLs
authorRalf Jung <post@ralfj.de>
Tue, 11 Nov 2014 18:39:35 +0000 (19:39 +0100)
committerRalf Jung <post@ralfj.de>
Tue, 11 Nov 2014 18:39:35 +0000 (19:39 +0100)
db.example.com.py
zonemaker/zone.py

index 3187832ac92c0a5c413f1ebf928668940815b8c6..ad8f957db3bc1e69d763ce0a2cff9777303440c7 100644 (file)
@@ -13,11 +13,18 @@ mail = [MX('mx', 10)] # this is first server name, then priority (as in plain DN
 def HTTPS(key):
     return TLSA(Protocol.TCP, 443, TLSA.Usage.EndEntity, TLSA.Selector.Full, TLSA.MatchingType.SHA256, key)
 
+# setup TTLs by record type
+TTLs = {
+    '':     1*day,  # special value: default TTL
+    'NX':   1*hour, # special value: TTL for NXDOMAIN replies
+    'A':    1*hour, # for the rest, just use the type of the resource records
+    'AAAA': 1*hour,
+}
+
 # Now to the actual zone: the header part should be fairly self-explanatory.
-__zone__ = Zone('example.com.', serialfile = 'db.example.com.srl', mail = 'root@example.com.',
-    NS = ['ns', 'ns.example.org.'],
+__zone__ = Zone('example.com.', serialfile = 'db.example.com.srl',
+    mail = 'root@example.com.', NS = ['ns', 'ns.example.org.'], TTLs = TTLs,
     secondary_refresh = 6*hour, secondary_retry = 1*hour, secondary_expire = 7*day,
-    NX_TTL = 1*hour, A_TTL = 1*hour, other_TTL = 1*day,
     # Here come the actual domains. Each takes records as argument, either individually or as lists.
     domains = {
         '.':            Name(one, mail), # this will all all records from the list "one" and the list "mail" to this name
index add2bde8c4df333b6f17baa70318d53435df096a..68a17ad02aae72614b50362859eca9edbd0d6a1c 100644 (file)
@@ -207,43 +207,34 @@ def SecureDelegation(name: str, tag: int, alg: int, digest: int, key: str) -> Na
 
 
 class Zone:
-    def __init__(self, name: str, serialfile: str, mail: str, NS: 'List[str]',
+    def __init__(self, name: str, serialfile: str, mail: str, NS: 'List[str]', TTLs: 'Dict[str, int]',
                  secondary_refresh: int, secondary_retry: int, secondary_expire: int,
-                 NX_TTL: int = None, A_TTL: int = None, other_TTL: int = None,
-                 domains: 'Dict[str, Any]' = {}) -> None:
-        self._serialfile = serialfile
-        
+                 domains: 'Dict[str, Any]') -> None:
         if not name.endswith('.'): raise Exception("Expected an absolute hostname")
         self._name = check_hostname(name)
+        self._serialfile = serialfile
+        
         if not mail.endswith('.'): raise Exception("Mail must be absolute, end with a dot")
         atpos = mail.find('@')
         if atpos < 0 or atpos > mail.find('.'): raise Exception("Mail must contain an @ before the first dot")
         self._mail = check_hostname(mail.replace('@', '.', 1))
         self._NS = list(map(check_hostname, NS))
+        if '' not in TTLs: raise Exception("Must give a default TTL with empty key")
+        self._TTLs = TTLs
         
         self._refresh = secondary_refresh
         self._retry = secondary_retry
         self._expire = secondary_expire
         
-        if other_TTL is None: raise Exception("Must give other_TTL")
-        self._NX_TTL = NX_TTL
-        self._A_TTL = self._AAAA_TTL = A_TTL
-        self._other_TTL = other_TTL
-        
         self._domains = domains
     
+    def getTTL(self, recordType: str) -> str:
+        return self._TTLs.get(recordType, self._TTLs[''])
+    
     def RR(self, owner: str, recordType: str, data: str) -> str:
         '''generate given RR, in textual representation'''
         assert re.match(r'^[A-Z]+$', recordType), "got invalid record type"
-        # figure out TTL
-        attrname = "_"+recordType+"_TTL"
-        TTL = None # type: int
-        if hasattr(self, attrname):
-            TTL = getattr(self, attrname)
-        if TTL is None:
-            TTL = self._other_TTL
-        # be done
-        return column_widths((self.abs_hostname(owner), time(TTL), recordType, data), (32, 8, 8))
+        return column_widths((self.abs_hostname(owner), time(self.getTTL(recordType)), recordType, data), (32, 8, 8))
     
     def abs_hostname(self, name):
         if name == '.' or name == '@':
@@ -276,7 +267,7 @@ class Zone:
                       ' ; primns mail serial refresh retry expire NX_TTL').format(
                           NS=self.abs_hostname(self._NS[0]), mail=self._mail, serial=serial,
                           refresh=time(self._refresh), retry=time(self._retry), expire=time(self._expire),
-                          NX_TTL=time(self._NX_TTL))
+                          NX_TTL=time(self.getTTL('NX')))
                       )
         # NS records
         for name in self._NS: