#!/usr/bin/env python3
import subprocess, sys, argparse, time, re
from collections import OrderedDict, namedtuple
from enum import Enum

# Progress bar and other console output helpers.

# Width (in characters) of the label column: the progress-bar state label is
# truncated/padded to this width, and cipher names in the final report are
# padded to it as well.
STATE_WIDTH = 30

def terminal_size():
    """Return the ``(width, height)`` of the terminal attached to stdout.

    Returns ``(0, 0)`` when stdout (file descriptor 1) is not a terminal, when
    the terminal reports a zero size, or when the platform has no
    ``fcntl``/``termios`` modules. Callers treat a zero width as "do not draw".
    """
    try:
        import fcntl, termios, struct
    except ImportError:
        # no POSIX terminal ioctls available on this platform
        return 0, 0
    try:
        # TIOCGWINSZ fills a struct winsize: rows, columns, x-pixels, y-pixels
        result = fcntl.ioctl(1, termios.TIOCGWINSZ, struct.pack('HHHH', 0, 0, 0, 0))
    except OSError:
        # this is not a terminal
        return 0, 0
    h, w, _, _ = struct.unpack('HHHH', result)
    if w == 0 or h == 0:
        # Some pseudo-terminals report a 0x0 size; treat that like "not a
        # terminal" instead of aborting the whole scan.
        return 0, 0
    return w, h

def compute_frac(fracs):
    """Collapse nested progress counters into one overall fraction in [0, 1].

    ``fracs`` is a sequence of ``(complete, total)`` pairs ordered from the
    outermost loop to the innermost one. Each level subdivides a single step
    of the level above it, so its contribution is scaled by ``1/total`` of
    every enclosing level.
    """
    overall = 0.0
    scale = 1.0
    for done, count in fracs:
        overall += done * scale / count
        # the next (inner) level only spans one step of this level
        scale *= 1 / count
    return overall

def print_progress(state, fracs):
    """Redraw a single-line progress bar on stdout.

    ``state`` is a short label, truncated/padded to STATE_WIDTH, shown before
    the bar. ``fracs`` is handed to compute_frac to decide how full the bar
    is. Nothing is drawn when stdout is not a terminal or is too narrow.
    """
    width, _ = terminal_size()
    if width < STATE_WIDTH + 10:
        return  # not a (wide enough) terminal
    # leave room for the separating space and the two brackets
    bar_width = width - STATE_WIDTH - 3
    filled = int(bar_width * compute_frac(fracs))
    label = state[:STATE_WIDTH].ljust(STATE_WIDTH)
    line = '\r{0} [{1}{2}]'.format(label, '#' * filled, ' ' * (bar_width - filled))
    sys.stdout.write(line)
    sys.stdout.flush()
def finish_progress():
    """Erase the progress-bar line and return the cursor to column 0.

    Does nothing when stdout is not a terminal (terminal_size() reports width
    0), so redirected output is not polluted with stray carriage returns.
    """
    w, _ = terminal_size()
    if w == 0:
        # not a terminal: print_progress never drew anything to erase
        return
    sys.stdout.write('\r' + (' ' * w) + '\r')
    sys.stdout.flush()

class ConsoleFormat:
    """ANSI escape-sequence helpers for coloured console output."""

    # ANSI colour indices; substituted as the last digit of the "3x"
    # foreground colour code in COLOR_SEQ
    BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
    RESET_SEQ = "\033[0m"      # reset all attributes
    COLOR_SEQ = "\033[1;3%dm"  # bold + foreground colour (index filled in)
    BOLD_SEQ  = "\033[1m"      # bold only

    @staticmethod
    def color(text, color):
        """Return ``text`` wrapped in a bold colour escape and a reset."""
        prefix = ConsoleFormat.COLOR_SEQ % color
        return ''.join((prefix, text, ConsoleFormat.RESET_SEQ))


# cipher check
def list_ciphers(spec="ALL:COMPLEMENTOFALL"):
    """Return the cipher names that OpenSSL expands the cipher string ``spec`` to.

    Runs ``openssl ciphers SPEC`` and splits its colon-separated output.
    Raises subprocess.CalledProcessError if openssl exits with an error.
    """
    command = ["openssl", "ciphers", spec]
    raw = subprocess.check_output(command)
    text = raw.decode('UTF-8').strip()
    return text.split(':')

def test_cipher(host, port, protocol, cipher=None, wait_time=0, options=None, timeout=None):
    """Attempt one TLS handshake with ``openssl s_client``; return True on success.

    :param host: server host name; also sent as SNI via ``-servername``
    :param port: server port (converted with ``str()``)
    :param protocol: s_client protocol flag without the dash, e.g. ``tls1_2``
    :param cipher: OpenSSL cipher string to offer; ``None`` leaves s_client's
        default cipher list in place
    :param wait_time: delay in milliseconds before connecting (throttling)
    :param options: extra s_client arguments, e.g. ``['-starttls', 'smtp']``
    :param timeout: seconds to wait for s_client before giving up and
        reporting failure; ``None`` (the default) waits indefinitely
    :return: True if s_client exited with status 0, False otherwise
    """
    # throttle
    time.sleep(wait_time / 1000)
    command = ["openssl", "s_client", "-" + protocol,
               "-connect", host + ":" + str(port), "-servername", host]
    if cipher is not None:
        command += ["-cipher", cipher]
    if options:
        command += options
    try:
        subprocess.check_call(command,
                              stdin=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
                              stderr=subprocess.DEVNULL, timeout=timeout)
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
        # non-zero exit (handshake/connection failed) or no answer in time
        return False
    return True

def test_protocol(host, port, protocol, ciphers, base_frac, wait_time=0, options=None):
    """Check whether ``protocol`` is usable and, if so, which ciphers it accepts.

    First probes the protocol once. If that handshake fails, the protocol is
    treated as unsupported and ``None`` is returned. Otherwise each cipher in
    ``ciphers`` is tried individually, and an OrderedDict mapping cipher name
    to bool (handshake succeeded) is returned in the order of ``ciphers``.

    ``base_frac`` is the progress of the enclosing loop as a list of
    ``(complete, total)`` pairs (see compute_frac). The per-cipher progress is
    appended to it for the progress bar.
    """
    if options is None:
        options = []
    # Offer every candidate cipher in the probe. s_client's default cipher
    # list is narrower than ALL:COMPLEMENTOFALL (it excludes e.g. eNULL
    # suites), so a server supporting only such suites would otherwise be
    # reported as not speaking the protocol at all.
    probe = ':'.join(ciphers) if ciphers else None
    if not test_cipher(host, port, protocol, cipher=probe, wait_time=wait_time, options=options):
        # it is not supported
        return None
    # the protocol is supported: test each cipher on its own
    results = OrderedDict()
    total = len(ciphers)
    for i, cipher in enumerate(ciphers):
        print_progress(protocol + " " + cipher, base_frac + [(i, total)])
        results[cipher] = test_cipher(host, port, protocol, cipher=cipher,
                                      wait_time=wait_time, options=options)
    return results

def test_host(host, port, wait_time=0, options=None):
    """Scan ``host:port`` for every protocol version and cipher OpenSSL knows.

    Returns an OrderedDict mapping each s_client protocol flag (``ssl2`` ...
    ``tls1_2``) to the result of test_protocol for it. That result is
    ``None`` when the protocol is not supported, otherwise a cipher -> bool
    mapping.
    """
    if options is None:
        options = []
    ciphers = list_ciphers()
    results = OrderedDict()
    protocols = ('ssl2', 'ssl3', 'tls1', 'tls1_1', 'tls1_2')
    try:
        for i, protocol in enumerate(protocols):
            print_progress(protocol, [(i, len(protocols))])
            results[protocol] = test_protocol(host, port, protocol, ciphers,
                                              [(i, len(protocols))], wait_time, options)
    finally:
        # clear the progress bar even if the scan is interrupted (e.g. Ctrl-C)
        finish_progress()
    return results

# cipher classification
class CipherStrength(Enum):
    """Security classification assigned to a cipher by CipherPropsProvider."""
    unknown = -1  # cipher is not in OpenSSL's HIGH group
    high = 3      # cipher is in OpenSSL's HIGH group

    def colorName(self):
        """Return the member's name, coloured green for ``high`` and yellow otherwise."""
        shade = ConsoleFormat.GREEN if self is CipherStrength.high else ConsoleFormat.YELLOW
        return ConsoleFormat.color(self.name, shade)

# Properties of a cipher as returned by CipherPropsProvider.getProps:
#   bits     - effective symmetric key size in bits
#   strength - a CipherStrength member
#   isPfs    - True if the key exchange provides forward secrecy
CipherProps = namedtuple(typename='CipherProps', field_names='bits, strength, isPfs')

class CipherPropsProvider:
    """Look up key size, strength and forward secrecy of ciphers via ``openssl ciphers -v``."""

    # "Enc=<algorithm>(<bits>)"; algorithm names may contain '/', '-' or '_'
    # (e.g. "CHACHA20/POLY1305(256)")
    _ENC_RE = re.compile(r'^Enc=([0-9A-Za-z/_-]+)\(([0-9]+)\)$')
    # "Kx=<key exchange>", e.g. "ECDH", "DH(512)", "ECDH/RSA"
    _KX_RE = re.compile(r'^Kx=([0-9A-Z/()]+)$')
    # key-exchange values treated as forward secret
    _PFS_KX = ('DH', 'DH(512)', 'ECDH')

    def __init__(self):
        # names of all ciphers in OpenSSL's HIGH group
        self.high = set(list_ciphers("HIGH"))
        # cache: (protocol without sub-version, cipher) -> CipherProps
        self.props = {}

    def getProps(self, protocol, cipher):
        """Return the CipherProps of ``cipher`` as used with ``protocol``.

        ``protocol`` is an s_client protocol flag without the dash (e.g.
        ``tls1_2``). Only the part before the first ``_`` is passed to
        ``openssl ciphers``. Results are cached per (protocol, cipher).

        Raises Exception if OpenSSL does not list the cipher or its output
        cannot be parsed.
        """
        # strip the sub-version-number from the protocol
        protocol = protocol.partition('_')[0]
        key = (protocol, cipher)
        if key in self.props:
            return self.props[key]
        # ask OpenSSL about this cipher
        cipherInfo = subprocess.check_output(["openssl", "ciphers", "-v", "-"+protocol, cipher]).decode('UTF-8').strip()
        cipherInfoFields = self._findCipherLine(cipherInfo, cipher)
        # columns: name, protocol version, Kx=, Au=, Enc=, Mac=, ...
        if cipherInfoFields is None or len(cipherInfoFields) < 5:
            raise Exception("Cannot determine cipher properties of {0} (protocol: {1})".format(cipher, protocol))
        bits = self._parseBits(cipherInfoFields[4], cipherInfo)
        isPfs = self._parseIsPfs(cipherInfoFields[2], cipherInfo)
        # determine security level
        if cipher in self.high:
            strength = CipherStrength.high
        else:
            strength = CipherStrength.unknown
        props = CipherProps(bits=bits, strength=strength, isPfs=isPfs)
        self.props[key] = props
        return props

    @staticmethod
    def _findCipherLine(cipherInfo, cipher):
        """Return the whitespace-split fields of the output line for ``cipher``, or None."""
        for line in cipherInfo.split('\n'):
            fields = line.split()
            # skip blank lines instead of failing with IndexError
            if fields and fields[0] == cipher:
                return fields
        return None

    @staticmethod
    def _parseBits(encField, cipherInfo):
        """Return the effective key size in bits from an ``Enc=...`` field.

        ``cipherInfo`` is the complete openssl output, used in error messages.
        """
        if encField == 'Enc=None':
            # NULL (eNULL) ciphers do not encrypt at all
            return 0
        encMatch = CipherPropsProvider._ENC_RE.match(encField)
        if encMatch is None:
            raise Exception("Unexpected OpenSSL output: Cannot determine encryption strength from {1}\nComplete output: {0}".format(cipherInfo, encField))
        encCipher = encMatch.group(1)
        bits = int(encMatch.group(2))
        if encCipher == '3DES':
            # OpenSSL gives the key size, which however for 3DES is a totally bad estimate
            bits = int(bits*2/3)
        return bits

    @staticmethod
    def _parseIsPfs(kxField, cipherInfo):
        """Return True if the ``Kx=...`` field names an ephemeral (PFS) key exchange.

        ``cipherInfo`` is the complete openssl output, used in error messages.
        """
        kxMatch = CipherPropsProvider._KX_RE.match(kxField)
        if kxMatch is None:
            raise Exception("Unexpected OpenSSL output: Cannot determine key-exchange method from {1}\nComplete output: {0}".format(cipherInfo, kxField))
        return kxMatch.group(1) in CipherPropsProvider._PFS_KX

# main program
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Check TLS ciphers supported by a host')
    parser.add_argument("--starttls", dest="starttls",
                        help="Use a STARTTLS variant to establish the TLS connection. Possible values include smtp, imap.")
    # type=float is also applied to the string default, so wait_time is always a float
    parser.add_argument("--wait-time", "-t", dest="wait_time", default="10", type=float,
                        help="Time (in ms) to wait between two connections to the server. Default is 10ms.")
    parser.add_argument("host", metavar='HOST[:PORT]',
                        help="The host to check")
    args = parser.parse_args()

    # get host, port: split on the *last* colon so a bracketed IPv6 address
    # such as "[::1]:443" works; "[::1]" alone has no port
    target = args.host
    if ':' in target and not target.endswith(']'):
        host, port = target.rsplit(':', 1)
        if not host or not port:
            parser.error("invalid HOST[:PORT]: {0!r}".format(target))
    else:
        host = target
        port = 443

    # get options and other stuff
    wait_time = args.wait_time
    if wait_time < 0:
        # time.sleep() rejects negative durations
        parser.error("--wait-time must not be negative")
    options = []
    if args.starttls is not None:
        options += ['-starttls', args.starttls]

    # run the test
    results = test_host(host, port, wait_time, options)

    # print the results
    propsProvider = CipherPropsProvider()
    for protocol, ciphers in results.items():
        print(protocol+":")
        if ciphers is None:
            print("    Is not supported by client or server")
        else:
            for cipher, supported in ciphers.items():
                if not supported:
                    continue
                cipherProps = propsProvider.getProps(protocol, cipher)
                fsText = ConsoleFormat.color("FS", ConsoleFormat.GREEN) if cipherProps.isPfs else ConsoleFormat.color("no FS", ConsoleFormat.RED)
                # >= 128 bits: good, >= 100 bits (e.g. 3DES at 112): marginal, else weak
                bitColor = ConsoleFormat.GREEN if cipherProps.bits >= 128 else (ConsoleFormat.YELLOW if cipherProps.bits >= 100 else ConsoleFormat.RED)
                print("    {0} ({1}, {2}, {3})".format(cipher.ljust(STATE_WIDTH), cipherProps.strength.colorName(), ConsoleFormat.color(str(cipherProps.bits)+" bits", bitColor), fsText))
        print()