d9c14e98afa7b56ec6ec4dd8f59c479a14948dcc
[zonemaker.git] / zonemaker / zone.py
1 import re, datetime
2 import typing
3
4
5 second = 1
6 minute = 60*second
7 hour = 60*minute
8 day = 24*hour
9 week = 7*day
10
11 REGEX_label = r'[a-zA-Z90-9]([a-zA-Z90-9-]{0,61}[a-zA-Z90-9])?' # max. 63 characters; must not start or end with hyphen
12 REGEX_ipv4  = r'^\d{1,3}(\.\d{1,3}){3}$'
13 REGEX_ipv6  = r'^[a-fA-F0-9]{1,4}(:[a-fA-F0-9]{1,4}){7}$'
14
15 def check_label(label: str) -> str:
16     pattern = r'^{0}$'.format(REGEX_label)
17     if re.match(pattern, label):
18         return label
19     raise Exception(label+" is not a valid label")
20
21 def check_hostname(name: str) -> str:
22     # check hostname for validity
23     pattern = r'^{0}(\.{0})*\.?$'.format(REGEX_label)
24     if re.match(pattern, name):
25         return name
26     raise Exception(name+" is not a valid hostname")
27
28 def check_hex(data: str) -> str:
29     if re.match('^[a-fA-F0-9]+$', data):
30         return data
31     raise Exception(data+" is not valid hex data")
32
33 def check_ipv4(address: str) -> str:
34     if re.match(REGEX_ipv4, address):
35         return address
36     raise Exception(address+" is not a valid IPv4 address")
37
38 def check_ipv6(address: str) -> str:
39     if re.match(REGEX_ipv6, address):
40         return address
41     raise Exception(address+" is not a valid IPv6 address")
42
43 def time(time: int) -> str:
44     if time == 0:
45         return "0"
46     elif time % week == 0:
47         return str(time//week)+"w"
48     elif time % day == 0:
49         return str(time//day)+"d"
50     elif time % hour == 0:
51         return str(time//hour)+"h"
52     elif time % minute == 0:
53         return str(time//minute)+"m"
54     else:
55         return str(time)
56
57 def column_widths(datas: 'Sequence', widths: 'Sequence[int]'):
58     assert len(datas) == len(widths)+1, "There must be as one more data points as widths"
59     result = ""
60     width_sum = 0
61     for data, width in zip(datas, widths): # will *not* cover the last point
62         result += str(data)+" " # add data point, and a minimal space
63         width_sum += width
64         if len(result) < width_sum: # add padding
65             result += (width_sum - len(result))*" "
66     # last data point
67     return result+str(datas[-1])
68
69
70 ## Enums
71 class Protocol:
72     TCP = 'tcp'
73     UDP = 'udp'
74
75 class Algorithm:
76     RSA_SHA256 = 8
77
78 class Digest:
79     SHA1 = 1
80     SHA256 = 2
81
82
83 ## Record types
84 class A:
85     def __init__(self, address: str) -> None:
86         self._address = check_ipv4(address)
87     
88     def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
89         return zone.RR(owner, 'A', self._address)
90
91
92 class AAAA:
93     def __init__(self, address: str) -> None:
94         self._address = check_ipv6(address)
95     
96     def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
97         return zone.RR(owner, 'AAAA', self._address)
98
99
100 class MX:
101     def __init__(self, name: str, prio: int = 10) -> None:
102         self._priority = int(prio)
103         self._name = check_hostname(name)
104     
105     def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
106         return zone.RR(owner, 'MX', '{0} {1}'.format(self._priority, zone.abs_hostname(self._name)))
107
108
109 class SRV:
110     def __init__(self, protocol: str, service: str, name: str, port: int, prio: int, weight: int) -> None:
111         self._service = check_label(service)
112         self._protocol = check_label(protocol)
113         self._priority = int(prio)
114         self._weight = int(weight)
115         self._port = int(port)
116         self._name = check_hostname(name)
117     
118     def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
119         return zone.RR('_{0}._{1}.{2}'.format(self._service, self._protocol, owner), 'SRV',
120                        '{0} {1} {2} {3}'.format(self._priority, self._weight, self._port, zone.abs_hostname(self._name)))
121
122
123 class TLSA:
124     class Usage:
125         CA = 0 # certificate must pass the usual CA check, with the CA specified by the TLSA record
126         EndEntity_PlusCAs = 1 # the certificate must match the TLSA record *and* pass the usual CA check
127         TrustAnchor = 2 # the certificate must pass a check with the TLSA record giving the (only) trust anchor
128         EndEntity = 3 # the certificate must match the TLSA record
129
130     class Selector:
131         Full = 0
132         SubjectPublicKeyInfo = 1
133     
134     class MatchingType:
135         Exact = 0
136         SHA256 = 1
137         SHA512 = 2
138     
139     def __init__(self, protocol: str, port: int, usage: int, selector: int, matching_type: int, data: str) -> None:
140         self._port = int(port)
141         self._protocol = str(protocol)
142         self._usage = int(usage)
143         self._selector = int(selector)
144         self._matching_type = int(matching_type)
145         self._data = check_hex(data)
146     
147     def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
148         return zone.RR('_{0}._{1}.{2}'.format(self._port, self._protocol, owner), 'TLSA', '{0} {1} {2} {3}'.format(self._usage, self._selector, self._matching_type, self._data))
149
150
151 class CNAME:
152     def __init__(self, name: str) -> None:
153         self._name = check_hostname(name)
154     
155     def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
156         return zone.RR(owner, 'CNAME', zone.abs_hostname(self._name))
157
158
159 class NS:
160     def __init__(self, name: str) -> None:
161         self._name = check_hostname(name)
162     
163     def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
164         return zone.RR(owner, 'NS', zone.abs_hostname(self._name))
165
166
167 class DS:
168     def __init__(self, tag: int, alg: int, digest: int, key: str) -> None:
169         self._tag = int(tag)
170         self._key = check_hex(key)
171         self._alg = int(alg)
172         self._digest = int(digest)
173     
174     def generate_rr(self, owner: str, zone: 'Zone') -> 'Any':
175         return zone.RR(owner, 'DS', '{0} {1} {2} {3}'.format(self._tag, self._alg, self._digest, self._key))
176
177 ## Higher-level classes
178 class Name:
179     def __init__(self, *records: 'List[Any]') -> None:
180         self._records = records
181     
182     def generate_rrs(self, owner: str, zone: 'Zone') -> 'Iterator':
183         for record in self._records:
184             # this could still be a list
185             if isinstance(record, list):
186                 for subrecord in record:
187                     yield subrecord.generate_rr(owner, zone)
188             else:
189                 yield record.generate_rr(owner, zone)
190
191
192 def CName(name: str) -> Name:
193     return Name(CNAME(name))
194
195
196 def Delegation(name: str) -> Name:
197     return Name(NS(name))
198
199
200 def SecureDelegation(name: str, tag: int, alg: int, digest: int, key: str) -> Name:
201     return Name(NS(name), DS(tag, alg, digest, key))
202
203
204 class Zone:
205     def __init__(self, name: str, serialfile: str, mail: str, NS: 'List[str]',
206                  secondary_refresh: int, secondary_retry: int, secondary_expire: int,
207                  NX_TTL: int = None, A_TTL: int = None, other_TTL: int = None,
208                  domains: 'Dict[str, Any]' = {}) -> None:
209         self._serialfile = serialfile
210         
211         if not name.endswith('.'): raise Exception("Expected an absolute hostname")
212         self._name = check_hostname(name)
213         if not mail.endswith('.'): raise Exception("Mail must be absolute, end with a dot")
214         atpos = mail.find('@')
215         if atpos < 0 or atpos > mail.find('.'): raise Exception("Mail must contain an @ before the first dot")
216         self._mail = check_hostname(mail.replace('@', '.', 1))
217         self._NS = list(map(check_hostname, NS))
218         
219         self._refresh = secondary_refresh
220         self._retry = secondary_retry
221         self._expire = secondary_expire
222         
223         if other_TTL is None: raise Exception("Must give other_TTL")
224         self._NX_TTL = NX_TTL
225         self._A_TTL = self._AAAA_TTL = A_TTL
226         self._other_TTL = other_TTL
227         
228         self._domains = domains
229     
230     def RR(self, owner: str, recordType: str, data: str) -> str:
231         '''generate given RR, in textual representation'''
232         assert re.match(r'^[A-Z]+$', recordType), "got invalid record type"
233         # figure out TTL
234         attrname = "_"+recordType+"_TTL"
235         TTL = None # type: int
236         if hasattr(self, attrname):
237             TTL = getattr(self, attrname)
238         if TTL is None:
239             TTL = self._other_TTL
240         # be done
241         return column_widths((self.abs_hostname(owner), time(TTL), recordType, data), (32, 8, 8))
242     
243     def abs_hostname(self, name):
244         if name == '.' or name == '@':
245             return self._name
246         if name.endswith('.'):
247             return name
248         return name+"."+self._name
249     
250     def inc_serial(self) -> int:
251         # get serial
252         cur_serial = 0
253         try:
254             with open(self._serialfile) as f:
255                 cur_serial = int(f.read())
256         except FileNotFoundError:
257             pass
258         # increment serial
259         cur_serial += 1
260         # save serial
261         with open(self._serialfile, 'w') as f:
262             f.write(str(cur_serial))
263         # be done
264         return cur_serial
265     
266     def generate_rrs(self) -> 'Iterator':
267         # SOA record
268         serial = self.inc_serial()
269         yield self.RR(self._name, 'SOA',
270                       ('{NS} {mail} {serial} {refresh} {retry} {expire} {NX_TTL}'+
271                       ' ; primns mail serial refresh retry expire NX_TTL').format(
272                           NS=self.abs_hostname(self._NS[0]), mail=self._mail, serial=serial,
273                           refresh=time(self._refresh), retry=time(self._retry), expire=time(self._expire),
274                           NX_TTL=time(self._NX_TTL))
275                       )
276         # NS records
277         for name in self._NS:
278             yield NS(name).generate_rr(self._name, self)
279         # all the rest
280         for name in sorted(self._domains.keys(), key=lambda s: list(reversed(s.split('.')))):
281             for rr in self._domains[name].generate_rrs(name, self):
282                 yield rr
283     
284     def write(self) -> None:
285         print(";; {0} zone file, generated by zonemaker <https://www.ralfj.de/projects/zonemaker> on {1}".format(self._name, datetime.datetime.now()))
286         for rr in self.generate_rrs():
287             print(rr)