fix adding DKIM in the root of a zone
[zonemaker.git] / zone.py
diff --git a/zone.py b/zone.py
index b543d31c3c6d0456491519e2c21296f88b953271..c311a214c6843ea1baa5834a0add9d636ef1c91b 100644 (file)
--- a/zone.py
+++ b/zone.py
 # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-# 
-# The views and conclusions contained in the software and documentation are those
-# of the authors and should not be interpreted as representing official policies, 
-# either expressed or implied, of the FreeBSD Project.
 
 import re, datetime
 #from typing import *
@@ -60,6 +56,13 @@ def check_hex(data: str) -> str:
         return data
     raise Exception(data+" is not valid hex data")
 
+def check_base64(data: str) -> str:
+    data = str(data)
+    if re.match('^[a-zA-Z0-9+/=]+$', data):
+        return data
+    raise Exception(data+" is not valid hex data")
+
+
 def check_ipv4(address: str) -> str:
     address = str(address)
     if re.match(REGEX_ipv4, address):
@@ -138,6 +141,39 @@ class MX:
         return zone.RR(owner, 'MX', '{0} {1}'.format(self._priority, zone.abs_hostname(self._name)))
 
 
+class TXT:
+    def __init__(self, text: str) -> None:
+        # test for bad characters
+        for c in ('\n', '\r', '\t'):
+            if c in text:
+                raise Exception("TXT record {0} contains invalid character")
+        # escape text
+        for c in ('\\', '\"'):
+            text = text.replace(c, '\\'+c)
+        self._text = text
+    
+    def generate_rr(self, owner:str, zone: 'Zone') -> 'Any':
+        return zone.RR(owner, 'TXT', '"{0}"'.format(self._text))
+
+
+class DKIM(TXT): # helper class to treat DKIM more antively
+    class Version:
+        DKIM1 = "DKIM1"
+    
+    class Algorithm:
+        RSA = "rsa"
+    
+    def __init__(self, selector, version, alg, key):
+        self._selector = check_label(selector)
+        version = check_label(version)
+        alg = check_label(alg)
+        key = check_base64(key)
+        super().__init__("v={0}; k={1}; p={2}".format(version, alg, key))
+    
+    def generate_rr(self, owner, zone):
+        return super().generate_rr('{0}._domainkey.{1}'.format(self._selector, zone.abs_hostname(owner)), zone)
+
+
 class SRV:
     def __init__(self, protocol: str, service: str, name: str, port: int, prio: int, weight: int) -> None:
         self._service = check_label(service)
@@ -264,6 +300,8 @@ class Zone:
         return column_widths((self.abs_hostname(owner), time(self.getTTL(recordType)), recordType, data), (32, 8, 8))
     
     def abs_hostname(self, name):
+        if name == '':
+            raise Exception("Empty domain name is not valid")
         if name == '.' or name == '@':
             return self._name
         if name.endswith('.'):