- def __init__(self, name, mail, NS,
- secondary_refresh, secondary_retry, secondary_discard,
- NX_TTL = None, A_TTL = None, other_TTL = None,
- domains = []):
- assert other_TTL is not None
- self._NX_TTL = other_TTL if NX_TTL is None else NX_TTL
- self._A_TTL = other_TTL if A_TTL is None else A_TTL
- self._other_TTL = other_TTL
-
- def write(self, file):
- raise NotImplementedError()
+    def __init__(self, name: str, serialfile: str, mail: str, NS: 'List[str]', TTLs: 'Dict[str, int]',
+                 secondary_refresh: int, secondary_retry: int, secondary_expire: int,
+                 domains: 'Dict[str, Any]') -> None:
+        '''Describe a zone and validate its static parameters.
+
+        name: absolute (dot-terminated) zone name.
+        serialfile: path of the file holding the zone's current serial (read and rewritten by inc_serial).
+        mail: responsible person's address, absolute (ending in '.') and containing an '@'
+              before the first dot; the '@' is turned into a '.' for the SOA record.
+        NS: name servers; must not be empty, the first one is the primary in the SOA record.
+        TTLs: TTL per record type; the empty key '' is the mandatory default, 'NX' feeds the
+              SOA's NX_TTL field.
+        secondary_refresh, secondary_retry, secondary_expire: SOA timer values.
+        domains: maps owner names to objects providing generate_rrs(name, zone).
+
+        Raises ValueError (a subclass of Exception, as raised before) on invalid arguments.
+        '''
+        if not name.endswith('.'):
+            raise ValueError("Expected an absolute hostname")
+        self._name = check_hostname(name)
+        self._serialfile = serialfile
+
+        if not mail.endswith('.'):
+            raise ValueError("Mail must be absolute, end with a dot")
+        atpos = mail.find('@')
+        # the '@' stands in for the first label separator, so it has to come before any '.'
+        if atpos < 0 or atpos > mail.find('.'):
+            raise ValueError("Mail must contain an @ before the first dot")
+        self._mail = check_hostname(mail.replace('@', '.', 1))
+        self._NS = list(map(check_hostname, NS))
+        # generate_rrs() uses self._NS[0] for the SOA record; fail here instead of there
+        if not self._NS:
+            raise ValueError("Must give at least one name server")
+        if '' not in TTLs:
+            raise ValueError("Must give a default TTL with empty key")
+        self._TTLs = TTLs
+
+        self._refresh = secondary_refresh
+        self._retry = secondary_retry
+        self._expire = secondary_expire
+
+        self._domains = domains
+
+    def getTTL(self, recordType: str) -> int:
+        '''Return the TTL configured for recordType, falling back to the default stored under key ''.'''
+        return self._TTLs.get(recordType, self._TTLs[''])
+
+    def RR(self, owner: str, recordType: str, data: str) -> str:
+        '''generate given RR, in textual representation
+
+        The owner is made absolute via abs_hostname() and the TTL is looked up with
+        getTTL(recordType); the fields are laid out by column_widths() with (32, 8, 8).
+        Raises ValueError if recordType is not an upper-case mnemonic (a letter followed by
+        letters or digits, e.g. A, MX, NSEC3PARAM).
+        '''
+        # Explicit check instead of assert, which disappears under python -O. Digits are
+        # allowed after the first letter so types such as NSEC3 are accepted; \Z instead of $
+        # so that a trailing newline is rejected.
+        if not re.match(r'[A-Z][A-Z0-9]*\Z', recordType):
+            raise ValueError("got invalid record type")
+        return column_widths((self.abs_hostname(owner), time(self.getTTL(recordType)), recordType, data), (32, 8, 8))
+
+    def abs_hostname(self, name):
+        '''Make name absolute within this zone.
+
+        '@' and '.' both stand for the zone itself; a dot-terminated name is already absolute
+        and comes back unchanged; anything else is treated as relative to the zone name.
+        '''
+        if name in ('.', '@'):
+            return self._name
+        return name if name.endswith('.') else name + "." + self._name
+
+    def inc_serial(self) -> int:
+        '''Increment the serial stored in self._serialfile and return the new value.
+
+        A missing file counts as serial 0, so the first call returns 1. The result is reduced
+        modulo 2**32 because the SOA serial is a 32-bit number (RFC 1982 arithmetic). The file
+        is rewritten via a temporary file and a rename, so an interrupted run cannot leave it
+        empty. Read errors other than "file not found" (permissions, a directory, ...) and
+        non-numeric content are raised instead of silently restarting the serial at 1.
+        '''
+        import errno
+        import os
+        # get serial
+        cur_serial = 0
+        try:
+            with open(self._serialfile) as f:
+                cur_serial = int(f.read())
+        except (OSError, IOError) as e: # FileNotFoundError has been added in Python 3.3
+            # only a missing file means "no serial yet"; any other failure must not reset it
+            if e.errno != errno.ENOENT:
+                raise
+        # increment serial, wrapping within the 32-bit serial space
+        cur_serial = (cur_serial + 1) % 2**32
+        # save serial atomically: write a sibling temp file, then rename it over the original
+        tmpfile = self._serialfile + '.tmp'
+        with open(tmpfile, 'w') as f:
+            f.write(str(cur_serial))
+        # os.replace (3.3+) also overwrites on Windows; os.rename replaces atomically on POSIX
+        getattr(os, 'replace', os.rename)(tmpfile, self._serialfile)
+        # be done
+        return cur_serial
+
+    def generate_rrs(self) -> 'Iterator':
+        '''Yield the zone's records as text: the SOA record, one NS record per name server,
+        then every domain's records, domains ordered by their labels read right to left.
+
+        The serial is bumped (inc_serial) as soon as iteration starts.
+        '''
+        def zone_order(domain):
+            # compare names label by label from the right, so a subtree sorts together
+            labels = domain.split('.')
+            labels.reverse()
+            return labels
+
+        serial = self.inc_serial()
+        soa_template = ('{NS} {mail} {serial} {refresh} {retry} {expire} {NX_TTL}' +
+                        ' ; primns mail serial refresh retry expire NX_TTL')
+        soa_fields = {
+            'NS': self.abs_hostname(self._NS[0]),
+            'mail': self._mail,
+            'serial': serial,
+            'refresh': time(self._refresh),
+            'retry': time(self._retry),
+            'expire': time(self._expire),
+            'NX_TTL': time(self.getTTL('NX')),
+        }
+        yield self.RR(self._name, 'SOA', soa_template.format(**soa_fields))
+
+        for ns_host in self._NS:
+            yield NS(ns_host).generate_rr(self._name, self)
+
+        for domain in sorted(self._domains.keys(), key=zone_order):
+            for record in self._domains[domain].generate_rrs(domain, self):
+                yield record
+
+    def write(self, file=None) -> None:
+        '''Print the zone file: a generator comment header, then one resource record per line.
+
+        file: writable text stream to print to; None (the default) means sys.stdout, so
+        existing zone.write() calls behave exactly as before.
+        '''
+        print(";; {0} zone file, generated by zonemaker <https://www.ralfj.de/projects/zonemaker> on {1}".format(self._name, datetime.datetime.now()), file=file)
+        for rr in self.generate_rrs():
+            print(rr, file=file)