check more data for sanity; properly support TLSA and DS records
[zonemaker.git] / zonemaker / zone.py
1 import re
2 from ipaddress import IPv4Address, IPv6Address
3 from typing import List, Dict, Any, Iterator, Tuple, Sequence
4
5
6 second = 1
7 minute = 60*second
8 hour = 60*minute
9 day = 24*hour
10 week = 7*day
11
12 REGEX_label = r'[a-zA-Z90-9]([a-zA-Z90-9-]{0,61}[a-zA-Z90-9])?' # max. 63 characters; must not start or end with hyphen
13
14 def check_label(label: str) -> str:
15     pattern = r'^{0}$'.format(REGEX_label)
16     if re.match(pattern, label):
17         return label
18     raise Exception(label+" is not a valid label")
19
20 def check_hostname(name: str) -> str:
21     # check hostname for validity
22     pattern = r'^{0}(\.{0})*\.?$'.format(REGEX_label)
23     if re.match(pattern, name):
24         return name
25     raise Exception(name+" is not a valid hostname")
26
27 def check_hex(data: str) -> str:
28     if re.match('^[a-fA-F0-9]+$', data):
29         return data
30     raise Exception(data+" is not valid hex data")
31
32 def time(time: int) -> str:
33     if time == 0:
34         return "0"
35     elif time % week == 0:
36         return str(time//week)+"w"
37     elif time % day == 0:
38         return str(time//day)+"d"
39     elif time % hour == 0:
40         return str(time//hour)+"h"
41     elif time % minute == 0:
42         return str(time//minute)+"m"
43     else:
44         return str(time)
45
46 def column_widths(datas: Sequence, widths: Sequence[int]):
47     assert len(datas) == len(widths)+1, "There must be as one more data points as widths"
48     result = ""
49     width_sum = 0
50     for data, width in zip(datas, widths): # will *not* cover the last point
51         result += str(data)+" " # add data point, and a minimal space
52         width_sum += width
53         if len(result) < width_sum: # add padding
54             result += (width_sum - len(result))*" "
55     # last data point
56     return result+str(datas[-1])
57
58
59 ## Enums
60 class Protocol:
61     TCP = 'tcp'
62     UDP = 'udp'
63
64 class Algorithm:
65     RSA_SHA256 = 8
66
67 class Digest:
68     SHA1 = 1
69     SHA256 = 2
70
71
72 ## Record types
73 class A:
74     def __init__(self, address: str) -> None:
75         self._address = IPv4Address(address)
76     
77     def generate_rr(self, owner: str, zone: 'Zone') -> Any:
78         return zone.RR(owner, 'A', self._address)
79
80
81 class AAAA:
82     def __init__(self, address: str) -> None:
83         self._address = IPv6Address(address)
84     
85     def generate_rr(self, owner: str, zone: 'Zone') -> Any:
86         return zone.RR(owner, 'AAAA', self._address)
87
88
89 class MX:
90     def __init__(self, name: str, prio: int = 10) -> None:
91         self._priority = int(prio)
92         self._name = check_hostname(name)
93     
94     def generate_rr(self, owner: str, zone: 'Zone') -> Any:
95         return zone.RR(owner, 'MX', '{0} {1}'.format(self._priority, zone.abs_hostname(self._name)))
96
97
98 class SRV:
99     def __init__(self, protocol: str, service: str, name: str, port: int, prio: int, weight: int) -> None:
100         self._service = check_label(service)
101         self._protocol = check_label(protocol)
102         self._priority = int(prio)
103         self._weight = int(weight)
104         self._port = int(port)
105         self._name = check_hostname(name)
106     
107     def generate_rr(self, owner: str, zone: 'Zone') -> Any:
108         return zone.RR('_{0}._{1}.{2}'.format(self._service, self._protocol, owner), 'SRV',
109                        '{0} {1} {2} {3}'.format(self._priority, self._weight, self._port, zone.abs_hostname(self._name)))
110
111
112 class TLSA:
113     class Usage:
114         CA = 0 # certificate must pass the usual CA check, with the CA specified by the TLSA record
115         EndEntity_PlusCAs = 1 # the certificate must match the TLSA record *and* pass the usual CA check
116         TrustAnchor = 2 # the certificate must pass a check with the TLSA record giving the (only) trust anchor
117         EndEntity = 3 # the certificate must match the TLSA record
118
119     class Selector:
120         Full = 0
121         SubjectPublicKeyInfo = 1
122     
123     class MatchingType:
124         Exact = 0
125         SHA256 = 1
126         SHA512 = 2
127     
128     def __init__(self, protocol: str, port: int, usage: int, selector: int, matching_type: int, data: str) -> None:
129         self._port = int(port)
130         self._protocol = str(protocol)
131         self._usage = int(usage)
132         self._selector = int(selector)
133         self._matching_type = int(matching_type)
134         self._data = check_hex(data)
135     
136     def generate_rr(self, owner: str, zone: 'Zone') -> Any:
137         return zone.RR('_{0}._{1}.{2}'.format(self._port, self._protocol, owner), 'TLSA', '{0} {1} {2} {3}'.format(self._usage, self._selector, self._matching_type, self._data))
138
139
140 class CNAME:
141     def __init__(self, name: str) -> None:
142         self._name = check_hostname(name)
143     
144     def generate_rr(self, owner: str, zone: 'Zone') -> Any:
145         return zone.RR(owner, 'CNAME', zone.abs_hostname(self._name))
146
147
148 class NS:
149     def __init__(self, name: str) -> None:
150         self._name = check_hostname(name)
151     
152     def generate_rr(self, owner: str, zone: 'Zone') -> Any:
153         return zone.RR(owner, 'NS', zone.abs_hostname(self._name))
154
155
156 class DS:
157     def __init__(self, tag: int, alg: int, digest: int, key: str) -> None:
158         self._tag = int(tag)
159         self._key = check_hex(key)
160         self._alg = int(alg)
161         self._digest = int(digest)
162     
163     def generate_rr(self, owner: str, zone: 'Zone') -> Any:
164         return zone.RR(owner, 'DS', '{0} {1} {2} {3}'.format(self._tag, self._alg, self._digest, self._key))
165
166 ## Higher-level classes
167 class Name:
168     def __init__(self, *records: List[Any]) -> None:
169         self._records = records
170     
171     def generate_rrs(self, owner: str, zone: 'Zone') -> Iterator:
172         for record in self._records:
173             # this could still be a list
174             if isinstance(record, list):
175                 for subrecord in record:
176                     yield subrecord.generate_rr(owner, zone)
177             else:
178                 yield record.generate_rr(owner, zone)
179
180
181 def CName(name: str) -> Name:
182     return Name(CNAME(name))
183
184
185 def Delegation(name: str) -> Name:
186     return Name(NS(name))
187
188
189 def SecureDelegation(name: str, tag: int, alg: int, digest: int, key: str) -> Name:
190     return Name(NS(name), DS(tag, alg, digest, key))
191
192
193 class Zone:
194     def __init__(self, name: str, serialfile: str, dbfile: str, mail: str, NS: List[str],
195                  secondary_refresh: int, secondary_retry: int, secondary_expire: int,
196                  NX_TTL: int = None, A_TTL: int = None, other_TTL: int = None,
197                  domains: Dict[str, Any] = {}) -> None:
198         self._serialfile = serialfile
199         self._dbfile = dbfile
200         
201         if not name.endswith('.'): raise Exception("Expected an absolute hostname")
202         self._name = check_hostname(name)
203         if not mail.endswith('.'): raise Exception("Mail must be absolute, end with a dot")
204         atpos = mail.find('@')
205         if atpos < 0 or atpos > mail.find('.'): raise Exception("Mail must contain an @ before the first dot")
206         self._mail = check_hostname(mail.replace('@', '.', 1))
207         self._NS = list(map(check_hostname, NS))
208         
209         self._refresh = secondary_refresh
210         self._retry = secondary_retry
211         self._expire = secondary_expire
212         
213         if other_TTL is None: raise Exception("Must give other_TTL")
214         self._NX_TTL = NX_TTL
215         self._A_TTL = self._AAAA_TTL = A_TTL
216         self._other_TTL = other_TTL
217         
218         self._domains = domains
219     
220     def RR(self, owner: str, recordType: str, data: str) -> str:
221         '''generate given RR, in textual representation'''
222         assert re.match(r'^[A-Z]+$', recordType), "got invalid record type"
223         # figure out TTL
224         attrname = "_"+recordType+"_TTL"
225         TTL = None # type: int
226         if hasattr(self, attrname):
227             TTL = getattr(self, attrname)
228         if TTL is None:
229             TTL = self._other_TTL
230         # be done
231         return column_widths((self.abs_hostname(owner), time(TTL), recordType, data), (32, 8, 8))
232     
233     def abs_hostname(self, name):
234         if name == '.' or name == '@':
235             return self._name
236         if name.endswith('.'):
237             return name
238         return name+"."+self._name
239     
240     def inc_serial(self) -> int:
241         # get serial
242         cur_serial = 0
243         try:
244             with open(self._serialfile) as f:
245                 cur_serial = int(f.read())
246         except FileNotFoundError:
247             pass
248         # increment serial
249         cur_serial += 1
250         # save serial
251         with open(self._serialfile, 'w') as f:
252             f.write(str(cur_serial))
253         # be done
254         return cur_serial
255     
256     def generate_rrs(self) -> Iterator:
257         # SOA record
258         serial = self.inc_serial()
259         yield self.RR(self._name, 'SOA',
260                       ('{NS} {mail} ({serial} {refresh} {retry} {expire} {NX_TTL}) ; '+
261                       '(serial refresh retry expire NX_TTL)').format(
262                           NS=self.abs_hostname(self._NS[0]), mail=self._mail, serial=serial,
263                           refresh=time(self._refresh), retry=time(self._retry), expire=time(self._expire),
264                           NX_TTL=time(self._NX_TTL))
265                       )
266         # NS records
267         for name in self._NS:
268             yield NS(name).generate_rr(self._name, self)
269         # all the rest
270         for name in sorted(self._domains.keys(), key=lambda s: list(reversed(s.split('.')))):
271             for rr in self._domains[name].generate_rrs(name, self):
272                 yield rr
273     
274     def write(self) -> None:
275         with open(self._dbfile, 'w') as f:
276             for rr in self.generate_rrs():
277                 f.write(rr+"\n")
278                 print(rr)