count serials; automagic TTLs
[zonemaker.git] / zonemaker / zone.py
1 import re
2 from ipaddress import IPv4Address, IPv6Address
3 from typing import List, Dict, Any, Iterator
4
5
6 second = 1
7 minute = 60*second
8 hour = 60*minute
9 day = 24*hour
10 week = 7*day
11
12 def check_hostname(name: str) -> str:
13     # check hostname for validity
14     label = r'[a-zA-Z90-9]([a-zA-Z90-9-]{0,61}[a-zA-Z90-9])?' # must not start or end with hyphen
15     pattern = r'^{0}(\.{0})*\.?'.format(label)
16     if re.match(pattern, name):
17         return name
18     raise Exception(name+" is not a valid hostname")
19
20 def time(time: int) -> str:
21     if time == 0:
22         return "0"
23     elif time % week == 0:
24         return str(time//week)+"w"
25     elif time % day == 0:
26         return str(time//day)+"d"
27     elif time % hour == 0:
28         return str(time//hour)+"h"
29     elif time % minute == 0:
30         return str(time//minute)+"m"
31     else:
32         return str(time)
33
34 class Address:
35     # mypy does not know about the ipaddress types, so leave this class unannotated for now
36     def __init__(self, IPv4 = None, IPv6 = None) -> None:
37         self._IPv4 = None if IPv4 is None else IPv4Address(IPv4)
38         self._IPv6 = None if IPv6 is None else IPv6Address(IPv6)
39     
40     def IPv4(self):
41         return Address(IPv4 = self._IPv4)
42     
43     def IPv6(self):
44         return Address(IPv6 = self._IPv6)
45
46 class Name:
47     def __init__(self, address: Address = None, MX: List = None,
48                  TCP: Dict[int, Any] = None, UDP: Dict[int, Any] = None) -> None:
49         self._address = address
50
51 class Service:
52     def __init__(self, SRV: str = None, TLSA: str=None) -> None:
53         self._SRV = None if SRV is None else check_hostname(SRV)
54         self._TLSA = TLSA
55
56 class CName:
57     def __init__(self, name: str) -> None:
58         self._name = check_hostname(name)
59
60 class Delegation():
61     def __init__(self, NS: str, DS: str = None) -> None:
62         pass
63
64 class Zone:
65     def __init__(self, name: str, serialfile: str, dbfile: str, mail: str, NS: List[str],
66                  secondary_refresh: int, secondary_retry: int, secondary_expire: int,
67                  NX_TTL: int = None, A_TTL: int = None, other_TTL: int = None,
68                  domains: Dict[str, Any] = {}) -> None:
69         self._serialfile = serialfile
70         self._dbfile = dbfile
71         
72         if not name.endswith('.'): raise Exception("Expected an absolute hostname")
73         self._name = check_hostname(name)
74         if not mail.endswith('.'): raise Exception("Mail must be absolute, end with a dot")
75         atpos = mail.find('@')
76         if atpos < 0 or atpos > mail.find('.'): raise Exception("Mail must contain an @ before the first dot")
77         self._mail = check_hostname(mail.replace('@', '.', 1))
78         self._NS = list(map(check_hostname, NS))
79         
80         self._refresh = secondary_refresh
81         self._retry = secondary_retry
82         self._expire = secondary_expire
83         
84         if other_TTL is None: raise Exception("Must give other_TTL")
85         self._NX_TTL = NX_TTL
86         self._A_TTL = A_TTL
87         self._other_TTL = other_TTL
88         
89         self._domains = domains
90     
91     def RR(self, owner: str, recordType: str, data: str) -> str:
92         '''generate given RR, in textual representation'''
93         assert re.match(r'^[A-Z]+$', recordType), "got invalid record type"
94         # figure out TTL
95         attrname = "_"+recordType+"_TTL"
96         TTL = None # type: int
97         if hasattr(self, attrname):
98             TTL = getattr(self, attrname)
99         if TTL is None:
100             TTL = self._other_TTL
101         # be done
102         return "{0}\t{1}\t{2}\t{3}".format(self.abs_hostname(owner), TTL, recordType, data)
103     
104     def abs_hostname(self, name):
105         if name.endswith('.'):
106             return name
107         return name+"."+self._name
108     
109     def inc_serial(self) -> int:
110         # get serial
111         cur_serial = 0
112         try:
113             with open(self._serialfile) as f:
114                 cur_serial = int(f.read())
115         except FileNotFoundError:
116             pass
117         # increment serial
118         cur_serial += 1
119         # save serial
120         with open(self._serialfile, 'w') as f:
121             f.write(str(cur_serial))
122         # be done
123         return cur_serial
124     
125     def generate_rrs(self) -> Iterator:
126         # SOA record
127         serial = self.inc_serial()
128         yield self.RR(self._name, 'SOA',
129                       ('{NS} {mail} ({serial} {refresh} {retry} {expire} {NX_TTL}) ; '+
130                       '(serial refresh retry expire NX_TTL)').format(
131                           NS=self.abs_hostname(self._NS[0]), mail=self._mail, serial=serial,
132                           refresh=time(self._refresh), retry=time(self._retry), expire=time(self._expire),
133                           NX_TTL=time(self._NX_TTL))
134                       )
135         # NS records
136         for ns in self._NS:
137             yield self.RR(self._name, 'NS', self.abs_hostname(ns))
138         
139         # all the rest
140         for name in sorted(self._domains.keys(), key=lambda s: list(reversed(s.split('.')))):
141             print(name)
142             #for rr in self._domains[name].generate_rrs(self):
143                 #yield rr
144     
145     def write(self) -> None:
146         for rr in self.generate_rrs():
147             print(rr)