1999ddc960e3f35ab1a6fcf6e37e8f7290f360fd
[zonemaker.git] / zonemaker / zone.py
1 import re
2 from ipaddress import IPv4Address, IPv6Address
3 from typing import List, Dict, Any, Iterator
4
5
6 second = 1
7 minute = 60*second
8 hour = 60*minute
9 day = 24*hour
10 week = 7*day
11
12 def check_hostname(name: str) -> str:
13     # check hostname for validity
14     label = r'[a-zA-Z90-9]([a-zA-Z90-9-]{0,61}[a-zA-Z90-9])?' # must not start or end with hyphen
15     pattern = r'^{0}(\.{0})*\.?'.format(label)
16     if re.match(pattern, name):
17         return name
18     raise Exception(name+" is not a valid hostname")
19
20 def time(time: int) -> str:
21     if time == 0:
22         return "0"
23     elif time % week == 0:
24         return str(time//week)+"w"
25     elif time % day == 0:
26         return str(time//day)+"d"
27     elif time % hour == 0:
28         return str(time//hour)+"h"
29     elif time % minute == 0:
30         return str(time//minute)+"m"
31     else:
32         return str(time)
33
34 class Address:
35     # mypy does not know about the ipaddress types, so leave this class unannotated for now
36     def __init__(self, IPv4 = None, IPv6 = None) -> None:
37         if IPv4 is None and IPv6 is None:
38             raise Exception("There has to be at least one valid address!")
39         self._IPv4 = None if IPv4 is None else IPv4Address(IPv4)
40         self._IPv6 = None if IPv6 is None else IPv6Address(IPv6)
41     
42     def IPv4(self):
43         return Address(IPv4 = self._IPv4)
44     
45     def IPv6(self):
46         return Address(IPv6 = self._IPv6)
47     
48     def generate_rrs(self, owner: str, zone: 'Zone') -> Iterator:
49         if self._IPv4 is not None:
50             yield zone.RR(owner, 'A', self._IPv4)
51         if self._IPv6 is not None:
52             yield zone.RR(owner, 'AAAA', self._IPv6)
53
54 class Name:
55     def __init__(self, address: Address = None, MX: List = None,
56                  TCP: Dict[int, Any] = None, UDP: Dict[int, Any] = None) -> None:
57         self._address = address
58     
59     def generate_rrs(self, owner: str, zone: 'Zone') -> Iterator:
60         if self._address is not None:
61             for rr in self._address.generate_rrs(owner, zone):
62                 yield rr
63         # TODO
64
65 class Service:
66     def __init__(self, SRV: str = None, TLSA: str=None) -> None:
67         self._SRV = None if SRV is None else check_hostname(SRV)
68         self._TLSA = TLSA
69
70 class CName:
71     def __init__(self, name: str) -> None:
72         self._name = check_hostname(name)
73     
74     def generate_rrs(self, owner: str, zone: 'Zone') -> Iterator:
75         yield zone.RR(owner, 'CNAME', zone.abs_hostname(self._name))
76
77 class Delegation():
78     def __init__(self, NS: str, DS: str = None) -> None:
79         self._NS = NS
80         self._DS = DS
81     
82     def generate_rrs(self, owner: str, zone: 'Zone') -> Iterator:
83         yield zone.RR(owner, 'NS', zone.abs_hostname(self._NS))
84         # TODO DS
85
86 class Zone:
87     def __init__(self, name: str, serialfile: str, dbfile: str, mail: str, NS: List[str],
88                  secondary_refresh: int, secondary_retry: int, secondary_expire: int,
89                  NX_TTL: int = None, A_TTL: int = None, other_TTL: int = None,
90                  domains: Dict[str, Any] = {}) -> None:
91         self._serialfile = serialfile
92         self._dbfile = dbfile
93         
94         if not name.endswith('.'): raise Exception("Expected an absolute hostname")
95         self._name = check_hostname(name)
96         if not mail.endswith('.'): raise Exception("Mail must be absolute, end with a dot")
97         atpos = mail.find('@')
98         if atpos < 0 or atpos > mail.find('.'): raise Exception("Mail must contain an @ before the first dot")
99         self._mail = check_hostname(mail.replace('@', '.', 1))
100         self._NS = list(map(check_hostname, NS))
101         
102         self._refresh = secondary_refresh
103         self._retry = secondary_retry
104         self._expire = secondary_expire
105         
106         if other_TTL is None: raise Exception("Must give other_TTL")
107         self._NX_TTL = NX_TTL
108         self._A_TTL = self._AAAA_TTL = A_TTL
109         self._other_TTL = other_TTL
110         
111         self._domains = domains
112     
113     def RR(self, owner: str, recordType: str, data: str) -> str:
114         '''generate given RR, in textual representation'''
115         assert re.match(r'^[A-Z]+$', recordType), "got invalid record type"
116         # figure out TTL
117         attrname = "_"+recordType+"_TTL"
118         TTL = None # type: int
119         if hasattr(self, attrname):
120             TTL = getattr(self, attrname)
121         if TTL is None:
122             TTL = self._other_TTL
123         # be done
124         return "{0}\t{1}\t{2}\t{3}".format(self.abs_hostname(owner), TTL, recordType, data)
125     
126     def abs_hostname(self, name):
127         if name.endswith('.'):
128             return name
129         return name+"."+self._name
130     
131     def inc_serial(self) -> int:
132         # get serial
133         cur_serial = 0
134         try:
135             with open(self._serialfile) as f:
136                 cur_serial = int(f.read())
137         except FileNotFoundError:
138             pass
139         # increment serial
140         cur_serial += 1
141         # save serial
142         with open(self._serialfile, 'w') as f:
143             f.write(str(cur_serial))
144         # be done
145         return cur_serial
146     
147     def generate_rrs(self) -> Iterator:
148         # SOA record
149         serial = self.inc_serial()
150         yield self.RR(self._name, 'SOA',
151                       ('{NS} {mail} ({serial} {refresh} {retry} {expire} {NX_TTL}) ; '+
152                       '(serial refresh retry expire NX_TTL)').format(
153                           NS=self.abs_hostname(self._NS[0]), mail=self._mail, serial=serial,
154                           refresh=time(self._refresh), retry=time(self._retry), expire=time(self._expire),
155                           NX_TTL=time(self._NX_TTL))
156                       )
157         # NS records
158         for ns in self._NS:
159             yield self.RR(self._name, 'NS', self.abs_hostname(ns))
160         
161         # all the rest
162         for name in sorted(self._domains.keys(), key=lambda s: list(reversed(s.split('.')))):
163             for rr in self._domains[name].generate_rrs(name, self):
164                 yield rr
165     
166     def write(self) -> None:
167         with open(self._dbfile, 'w') as f:
168             for rr in self.generate_rrs():
169                 f.write(rr+"\n")
170                 print(rr)