rename: ssl-check -> tls-check
[tls-check.git] / tls-check
diff --git a/tls-check b/tls-check
new file mode 100755 (executable)
index 0000000..9d84f42
--- /dev/null
+++ b/tls-check
@@ -0,0 +1,111 @@
+#!/usr/bin/python3
+import subprocess, sys, argparse
+from collections import OrderedDict
+from enum import Enum
+
+# progress bar
+def terminal_size():
+    import fcntl, termios, struct
+    try:
+        result = fcntl.ioctl(1, termios.TIOCGWINSZ, struct.pack('HHHH', 0, 0, 0, 0))
+    except OSError:
+        # this is not a terminal
+        return 0, 0
+    h, w, hp, wp = struct.unpack('HHHH', result)
+    assert w > 0 and h > 0, "Empty terminal...??"
+    return w, h
+
+def compute_frac(fracs):
+    frac = 0.0
+    last_frac = 1.0
+    for complete, total in fracs:
+        frac += (complete*last_frac/total)
+        last_frac *= 1/total
+    return frac
+
+def print_progress(state, fracs):
+    STATE_WIDTH = 30
+    w, h = terminal_size()
+    if w < STATE_WIDTH+10: return # not a (wide enough) terminal
+    bar_width = w-STATE_WIDTH-3
+    hashes = int(bar_width*compute_frac(fracs))
+    sys.stdout.write('\r{0} [{1}{2}]'.format(state[:STATE_WIDTH].ljust(STATE_WIDTH), '#'*hashes, ' '*(bar_width-hashes)))
+    sys.stdout.flush()
+def finish_progress():
+    w, h = terminal_size()
+    sys.stdout.write('\r'+(' '*w)+'\r')
+    sys.stdout.flush()
+
+# cipher check
+def list_ciphers():
+    ciphers = subprocess.check_output(["openssl", "ciphers", "ALL:COMPLEMENTOFALL"]).decode('UTF-8').strip()
+    return ciphers.split(':')
+
+def test_cipher(host, port, protocol, cipher = None, options=[]):
+    try:
+        if cipher is not None:
+            options = ["-cipher", cipher]+options
+        subprocess.check_call(["openssl", "s_client", "-"+protocol, "-connect", host+":"+str(port)]+options,
+                              stdin=subprocess.DEVNULL, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+    except subprocess.CalledProcessError:
+        return False
+    else:
+        return True
+
+def test_protocol(host, port, protocol, ciphers, base_frac, options=[]):
+    if test_cipher(host, port, protocol, options=options):
+        # the protocol is supported
+        results = OrderedDict()
+        for i in range(len(ciphers)):
+            cipher = ciphers[i]
+            print_progress(protocol+" "+cipher, base_frac+[(i, len(ciphers))])
+            results[cipher] = test_cipher(host, port, protocol, cipher, options)
+        return results
+    else:
+        # it is not supported
+        return None
+
+def test_host(host, port, options=[]):
+    ciphers = list_ciphers()
+    results = OrderedDict()
+    protocols = ('ssl2', 'ssl3', 'tls1', 'tls1_1', 'tls1_2')
+    for i in range(len(protocols)):
+        protocol = protocols[i]
+        print_progress(protocol, [(i, len(protocols))])
+        results[protocol] = test_protocol(host, port, protocol, ciphers, [(i, len(protocols))], options)
+    finish_progress()
+    return results
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Check TLS ciphers supported by a host')
+    parser.add_argument("--starttls", dest="starttls",
+                        help="Use a STARTTLS variant to establish the TLS connection. Possible values include smpt, imap, xmpp.")
+    parser.add_argument("host", metavar='HOST[:PORT]',
+                        help="The host to check")
+    args = parser.parse_args()
+    
+    # get host, port
+    if ':' in args.host:
+        host, port = args.host.split(':')
+    else:
+        host = args.host
+        port = 443
+    
+    # get options
+    options = []
+    if args.starttls is not None:
+        options += ['-starttls', args.starttls]
+    
+    # run the test
+    results = test_host(host, port, options)
+    
+    # print the results
+    for protocol, ciphers in results.items():
+        print(protocol+":")
+        if ciphers is None:
+            print("    Is not supported by client or server")
+        else:
+            for cipher, supported in ciphers.items():
+                if supported:
+                    print("    "+cipher)
+        print()