make script work again
[tls-check.git] / tls-check
index 27651d53855b9a36b970c52434143898eb4a3126..0445084ffafe5194be7418681250682a0a1d7853 100755 (executable)
--- a/tls-check
+++ b/tls-check
@@ -1,9 +1,11 @@
-#!/usr/bin/python3
+#!/usr/bin/env python3
 import subprocess, sys, argparse, time, re
 from collections import OrderedDict, namedtuple
 from enum import Enum
 
 import subprocess, sys, argparse, time, re
 from collections import OrderedDict, namedtuple
 from enum import Enum
 
-# progress bar
+# progress bar and other console output fun
+STATE_WIDTH = 30
+
 def terminal_size():
     import fcntl, termios, struct
     try:
 def terminal_size():
     import fcntl, termios, struct
     try:
@@ -24,7 +26,6 @@ def compute_frac(fracs):
     return frac
 
 def print_progress(state, fracs):
     return frac
 
 def print_progress(state, fracs):
-    STATE_WIDTH = 30
     w, h = terminal_size()
     if w < STATE_WIDTH+10: return # not a (wide enough) terminal
     bar_width = w-STATE_WIDTH-3
     w, h = terminal_size()
     if w < STATE_WIDTH+10: return # not a (wide enough) terminal
     bar_width = w-STATE_WIDTH-3
@@ -35,18 +36,30 @@ def finish_progress():
     w, h = terminal_size()
     sys.stdout.write('\r'+(' '*w)+'\r')
     sys.stdout.flush()
     w, h = terminal_size()
     sys.stdout.write('\r'+(' '*w)+'\r')
     sys.stdout.flush()
+
+class ConsoleFormat:
+    BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
+    RESET_SEQ = "\033[0m"
+    COLOR_SEQ = "\033[1;3%dm"
+    BOLD_SEQ  = "\033[1m"
     
     
+    @staticmethod
+    def color(text, color):
+        return (ConsoleFormat.COLOR_SEQ % color) + text + ConsoleFormat.RESET_SEQ
+
 
 # cipher check
 def list_ciphers(spec="ALL:COMPLEMENTOFALL"):
     ciphers = subprocess.check_output(["openssl", "ciphers", spec]).decode('UTF-8').strip()
     return ciphers.split(':')
 
 
 # cipher check
 def list_ciphers(spec="ALL:COMPLEMENTOFALL"):
     ciphers = subprocess.check_output(["openssl", "ciphers", spec]).decode('UTF-8').strip()
     return ciphers.split(':')
 
-def test_cipher(host, port, protocol, cipher = None, options=[]):
+def test_cipher(host, port, protocol, cipher = None, wait_time=0, options=[]):
+    # throttle
+    time.sleep(wait_time/1000)
     try:
         if cipher is not None:
             options = ["-cipher", cipher]+options
     try:
         if cipher is not None:
             options = ["-cipher", cipher]+options
-        subprocess.check_call(["openssl", "s_client", "-"+protocol, "-connect", host+":"+str(port)]+options,
+        subprocess.check_call(["openssl", "s_client", "-"+protocol, "-connect", host+":"+str(port), "-servername", host]+options,
                               stdin=subprocess.DEVNULL, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
     except subprocess.CalledProcessError:
         return False
                               stdin=subprocess.DEVNULL, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
     except subprocess.CalledProcessError:
         return False
@@ -54,15 +67,13 @@ def test_cipher(host, port, protocol, cipher = None, options=[]):
         return True
 
 def test_protocol(host, port, protocol, ciphers, base_frac, wait_time=0, options=[]):
         return True
 
 def test_protocol(host, port, protocol, ciphers, base_frac, wait_time=0, options=[]):
-    if test_cipher(host, port, protocol, options=options):
+    if test_cipher(host, port, protocol, wait_time=wait_time, options=options):
         # the protocol is supported
         results = OrderedDict()
         for i in range(len(ciphers)):
             cipher = ciphers[i]
             print_progress(protocol+" "+cipher, base_frac+[(i, len(ciphers))])
         # the protocol is supported
         results = OrderedDict()
         for i in range(len(ciphers)):
             cipher = ciphers[i]
             print_progress(protocol+" "+cipher, base_frac+[(i, len(ciphers))])
-            results[cipher] = test_cipher(host, port, protocol, cipher, options)
-            # throttle
-            time.sleep(wait_time/1000)
+            results[cipher] = test_cipher(host, port, protocol, cipher=cipher, wait_time=wait_time, options=options)
         return results
     else:
         # it is not supported
         return results
     else:
         # it is not supported
@@ -82,67 +93,64 @@ def test_host(host, port, wait_time=0, options=[]):
 # cipher classification
 class CipherStrength(Enum):
     unknown = -1
 # cipher classification
 class CipherStrength(Enum):
     unknown = -1
-    exp = 0
-    low = 1
-    medium = 2
     high = 3
     high = 3
+    
+    def colorName(self):
+        if self.value == CipherStrength.high.value:
+            return ConsoleFormat.color(self.name, ConsoleFormat.GREEN)
+        else:
+            return ConsoleFormat.color(self.name, ConsoleFormat.YELLOW)
+
 CipherProps = namedtuple('CipherProps', 'bits, strength, isPfs')
 
 class CipherPropsProvider:
     def __init__(self):
 CipherProps = namedtuple('CipherProps', 'bits, strength, isPfs')
 
 class CipherPropsProvider:
     def __init__(self):
-        self.exp = set(list_ciphers("EXP"))
-        self.low = set(list_ciphers("LOW"))
-        self.medium = set(list_ciphers("MEDIUM"))
         self.high = set(list_ciphers("HIGH"))
         self.props = {}
     
         self.high = set(list_ciphers("HIGH"))
         self.props = {}
     
-    def __getProps(self, cipher):
+    def getProps(self, protocol, cipher):
+        # strip the sub-version-number from the protocol
+        pos = protocol.find('_')
+        if pos >= 0:
+            protocol = protocol[:pos]
         # as OpenSSL about this cipher
         # as OpenSSL about this cipher
-        cipherInfo = subprocess.check_output(["openssl", "ciphers", "-v", cipher]).decode('UTF-8').strip()
-        assert '\n' not in cipherInfo
-        cipherInfoFields = cipherInfo.split()
+        cipherInfo = subprocess.check_output(["openssl", "ciphers", "-v", "-"+protocol, cipher]).decode('UTF-8').strip()
+        cipherInfoFields = None
+        for line in cipherInfo.split('\n'):
+            line = line.split()
+            if line[0] == cipher:
+                cipherInfoFields = line
+                break
+        if cipherInfoFields is None:
+            raise Exception("Cannot determine cipher properties of {0} (protocol: {1})".format(cipher, protocol))
         # get # of bits
         # get # of bits
-        bitMatch = re.match(r'^Enc=[0-9A-Za-z]+\(([0-9]+)\)$', cipherInfoFields[4])
-        if bitMatch is None:
+        encMatch = re.match(r'^Enc=([0-9A-Za-z]+)\(([0-9]+)\)$', cipherInfoFields[4])
+        if encMatch is None:
             raise Exception("Unexpected OpenSSL output: Cannot determine encryption strength from {1}\nComplete output: {0}".format(cipherInfo, cipherInfoFields[4]))
             raise Exception("Unexpected OpenSSL output: Cannot determine encryption strength from {1}\nComplete output: {0}".format(cipherInfo, cipherInfoFields[4]))
-        bits = int(bitMatch.group(1))
+        encCipher = encMatch.group(1)
+        bits = int(encMatch.group(2))
+        if encCipher == '3DES':
+            # OpenSSL gives the key size, which however for 3DES is a totally bad estimate
+            bits = int(bits*2/3)
         # figure out whether the cipher is pfs
         # figure out whether the cipher is pfs
-        kxMatch = re.match(r'^Kx=([0-9A-Z/]+)$', cipherInfoFields[2])
+        kxMatch = re.match(r'^Kx=([0-9A-Z/()]+)$', cipherInfoFields[2])
         if kxMatch is None:
         if kxMatch is None:
-            raise Exception("Unexpected OpenSSL output: Cannot determine key-exchange method from {1}\nComplete output: {0}".format(cipherInfo), cipherInfoFields[2])
+            raise Exception("Unexpected OpenSSL output: Cannot determine key-exchange method from {1}\nComplete output: {0}".format(cipherInfo, cipherInfoFields[2]))
         kx = kxMatch.group(1)
         kx = kxMatch.group(1)
-        isPfs = kx in ('DH', 'ECDH')
+        isPfs = kx in ('DH', 'DH(512)', 'ECDH')
         # determine security level
         # determine security level
-        isExp = cipher in self.exp
-        isLow = cipher in self.low
-        isMedium = cipher in self.medium
-        isHigh = cipher in self.high
-        assert isExp+isLow+isMedium+isHigh <= 1, "Cipher is more than one from EXP, LOW, MEDIUM, HIGH"
-        if isExp:
-            strength = CipherStrength.exp
-        elif isLow:
-            strength = CipherStrength.low
-        elif isMedium:
-            strength = CipherStrength.medium
-        elif isHigh:
+        if cipher in self.high:
             strength = CipherStrength.high
         else:
             strength = CipherStrength.unknown
         # done!
         return CipherProps(bits=bits, strength=strength, isPfs=isPfs)
             strength = CipherStrength.high
         else:
             strength = CipherStrength.unknown
         # done!
         return CipherProps(bits=bits, strength=strength, isPfs=isPfs)
-    
-    def getProps(self, cipher):
-        if cipher in self.props:
-            return self.props[cipher]
-        props = self.__getProps(cipher)
-        self.props[cipher] = props
-        return props
 
 # main program
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Check TLS ciphers supported by a host')
     parser.add_argument("--starttls", dest="starttls",
 
 # main program
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Check TLS ciphers supported by a host')
     parser.add_argument("--starttls", dest="starttls",
-                        help="Use a STARTTLS variant to establish the TLS connection. Possible values include smpt, imap, xmpp.")
+                        help="Use a STARTTLS variant to establish the TLS connection. Possible values include smtp, imap.")
     parser.add_argument("--wait-time", "-t", dest="wait_time", default="10",
                         help="Time (in ms) to wait between two connections to the server. Default is 10ms.")
     parser.add_argument("host", metavar='HOST[:PORT]',
     parser.add_argument("--wait-time", "-t", dest="wait_time", default="10",
                         help="Time (in ms) to wait between two connections to the server. Default is 10ms.")
     parser.add_argument("host", metavar='HOST[:PORT]',
@@ -174,6 +182,8 @@ if __name__ == "__main__":
         else:
             for cipher, supported in ciphers.items():
                 if supported:
         else:
             for cipher, supported in ciphers.items():
                 if supported:
-                    cipherProps = propsProvider.getProps(cipher)
-                    print("    {0} ({1}, {2} bits, {3})".format(cipher, cipherProps.strength.name, cipherProps.bits, "FS" if cipherProps.isPfs else "not FS"))
+                    cipherProps = propsProvider.getProps(protocol, cipher)
+                    fsText = ConsoleFormat.color("FS", ConsoleFormat.GREEN) if cipherProps.isPfs else ConsoleFormat.color("no FS", ConsoleFormat.RED)
+                    bitColor = ConsoleFormat.GREEN if cipherProps.bits >= 128 else (ConsoleFormat.YELLOW if cipherProps.bits >= 100 else ConsoleFormat.RED)
+                    print("    {0} ({1}, {2}, {3})".format(cipher.ljust(STATE_WIDTH), cipherProps.strength.colorName(), ConsoleFormat.color(str(cipherProps.bits)+" bits", bitColor), fsText))
         print()
         print()