#!/usr/bin/python3
# -*- mode: python; -*-

import sys
import subprocess
import argparse
try:
    import ipaddress
except Exception:
    sys.exit("This requires the python ipaddress module from https://pypi.python.org/pypi/ipaddress")


def _run_ip(args):
    """Run ``ip`` with the argument list *args* and return its stdout as text.

    The command is started without a shell, so address strings that came
    from the command line are never interpreted by ``/bin/sh``.  An empty
    string is returned when ``ip`` cannot be started at all; callers treat
    that exactly like output that lacks the information they look for.
    """
    try:
        proc = subprocess.run(
            ["ip"] + list(args),
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            universal_newlines=True,
        )
    except OSError:
        return ""
    return proc.stdout


def _route_source(dest):
    """Return the address following ``src`` in ``ip -o route get``, or None."""
    words = _run_ip(["-o", "route", "get", dest]).replace("\\", " ").split()
    # Match "src" as a whole token so that e.g. a device name containing
    # the letters "src" cannot be mistaken for the keyword.
    for keyword, value in zip(words, words[1:]):
        if keyword == "src":
            return value
    return None


def _outbound_policies():
    """Return the one-line ``ip -o xfrm policy`` entries for outbound policies.

    Lines mentioning "socket" are dropped and only lines containing
    "dir out" are kept, mirroring ``grep -v socket | grep "dir out"``.
    """
    lines = _run_ip(["-o", "xfrm", "policy"]).splitlines()
    return [ln for ln in lines if "socket" not in ln and "dir out" in ln]


def _parse_policy(line):
    """Extract ``(src, dst, reqid)`` strings from a one-line xfrm policy.

    ``src`` and ``dst`` are the values following the FIRST occurrence of
    each keyword on the line; ``reqid`` is the value following the last
    ``reqid`` keyword, or "0" when there is none.  Missing selectors are
    returned as empty strings.
    """
    src = ""
    dst = ""
    reqid = "0"
    # "ip -o" joins continuation lines with backslashes; treat them as blanks.
    words = line.replace("\\", " ").split()
    # Pairing each word with its successor means a keyword that happens to
    # be the last token is ignored instead of indexing past the end.
    for keyword, value in zip(words, words[1:]):
        if keyword == "src" and not src:
            src = value
        elif keyword == "dst" and not dst:
            dst = value
        elif keyword == "reqid":
            reqid = value
    return src, dst, reqid


def main():
    """Report whether traffic to a destination matches an outbound IPsec policy.

    Command line (parsed from ``sys.argv``):
      dest          optional destination IP address
      -s/--source   source address of the packet (requires ``dest``)
      -d/--debug    print debugging information
      -v/--version  print the version and exit

    With a destination, the first outbound policy with a non-zero reqid
    whose selectors contain both the source and the destination address is
    printed and the process exits with status 0; if none matches, a "would
    not be encrypted" message is printed and the exit status is 1.  When no
    source is given, it is taken from ``ip -o route get <dest>``.

    Without a destination, every outbound policy that has a source selector
    and a non-zero reqid is listed and the function returns normally.

    Raises:
        SystemExit: on ``--version``, on invalid or inconsistent arguments,
            and at the end of every destination lookup.
    """
    parser = argparse.ArgumentParser(description='check if destination traffic would get encrypted by IPsec')
    parser.add_argument('-v', '--version', action='store_true', help='show version and exit')
    parser.add_argument('-d', '--debug', action='store_true', help='show debugging')
    parser.add_argument('-s', '--source', action='store', help='source address of the packet')
    parser.add_argument('dest', nargs='?')
    args = parser.parse_args()

    if args.version:
        print("version: 0.1")
        sys.exit(0)

    # check if dest is IP address, or resolve?
    dest = args.dest

    if args.debug:
        print("checking destination %s" % dest)

    if args.source and not dest:
        sys.exit("if specifying a source, specifying a destination is required")

    source = None
    ipsrc = None
    ipdst = None
    if dest:
        try:
            ipdst = ipaddress.ip_address(dest)
        except ValueError:
            sys.exit("%s is not a valid destination IP address" % dest)

        source = args.source or _route_source(dest)
        if not source:
            sys.exit("failed to find source ip for destination %s" % dest)

    if args.debug:
        print("Need to find matching IPsec policy for %s/32 <=> %s/32" % (source, dest))

    if dest:
        # A user-supplied source may carry a prefix length; only the address
        # part is used.  (dest cannot contain "/": ip_address() above would
        # already have rejected it.)
        source = source.split("/")[0]
        try:
            ipsrc = ipaddress.ip_address(source)
        except ValueError:
            sys.exit("%s is not a valid source IP address" % source)
        if ipsrc.version != ipdst.version:
            print("woah IP family mismatch between %s and %s" % (source, dest))
            sys.exit(1)

    for line in _outbound_policies():
        # Selectors are parsed afresh for every line so that a line lacking
        # them can never be matched against the previous policy's networks.
        src, dst, reqid = _parse_policy(line)
        if reqid == "0" or not src:
            continue

        if not dest:
            # not dest, so we are listing all policies
            print("%s <=> %s using reqid %s" % (src, dst, reqid))
            continue

        if not dst:
            continue
        try:
            # strict=False: accept selectors written with host bits set
            # instead of aborting with ValueError.
            polsrc = ipaddress.ip_network(src, strict=False)
            poldst = ipaddress.ip_network(dst, strict=False)
        except ValueError:
            if args.debug:
                print("skipping unparsable policy selector %s <=> %s" % (src, dst))
            continue

        if ipsrc in polsrc and ipdst in poldst:
            print("%s <=> %s using reqid %s" % (src, dst, reqid))
            sys.exit(0)

    if dest:
        print("Packet with source address %s to destination %s would not be encrypted" % (source, dest))
        sys.exit(1)


# Run the check only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
