#!/usr/bin/python

"""
Compatibility shim for lert.

This is not a full lert implementation. It performs only basic operations,
such as authenticating to the lert server and printing its response.
"""

import argparse
import hesiod
import krbV
import random
import socket
import sys

from struct import pack

# =========== lert constants ===========

# Protocol version: sent as the first byte of every request and compared
# against the first byte of the server's response.
LERT_VERSION = '2'
# Socket timeout, in seconds, applied to the UDP socket.
LERT_TIMEOUT = 1
# First component of the Kerberos service principal ("<service>/<host>").
LERT_SERVICE = 'daemon'
# UDP port the request is sent to.
LERT_PORT = 3717
# Maximum number of bytes read from the server's response.
LERT_MAX_PACKET_SIZE = 2048
# Path prefix of the message files; a suffix is appended to select a file.
LERT_MESSAGE_PATH_BASE = '/afs/athena/system/config/lert/lert'

# Status codes found in the second byte of the server's response.
# LERT_BAD: the server was unable to process the request.
LERT_BAD = '0'
# LERT_FREE: there is no message to show.
LERT_FREE = '1'
# LERT_MSG: a message is available; the rest of the response is a file suffix.
LERT_MSG = '2'
# LERT_SICK: the server is malfunctioning.
LERT_SICK  = '3'

# =========== Helper Athena functions ===========

class LertError(Exception):
    """Raised when the lert server cannot be located or gives an unusable reply."""

def locate_server():
    """Returns a randomly chosen lert server from the Hesiod "sloc" records.

    Raises LertError when the Hesiod lookup yields no results.
    """

    answer = hesiod.Lookup("lert", "sloc")
    candidates = answer.results if answer else None
    if candidates:
        # Pick one of the advertised servers at random.
        return random.choice(candidates)

    raise LertError("Unable to locate lert server through Hesiod")

# =========== Actual lert client ===========

def lert(server, acknowledge):
    """Sends lert server a message and waits for response.

    Arguments:
        server: host name of the lert server; it is used both to resolve the
            destination address and to build the Kerberos service principal.
        acknowledge: truthy to set the acknowledge byte of the request to 1,
            falsy to set it to 0.

    Returns the stripped contents of the message file selected by the server,
    or an empty string when the server reports that there is no message.

    Raises LertError when the response is truncated, carries an unexpected
    version, or reports a failure or an unknown status code. Socket, Kerberos
    and file errors propagate unchanged.
    """

    _, _, addrs = socket.gethostbyaddr(server)
    dest = (addrs[0], LERT_PORT)

    # Initialize Kerberos variables
    krb_ctx = krbV.default_context()
    krb_ccache = krb_ctx.default_ccache()
    krb_cprinc = krb_ccache.principal()
    krb_sprinc = krbV.Principal(name='%s/%s' % (LERT_SERVICE, server.lower()),
                                context=krb_ctx)
    krb_authctx = krbV.AuthContext(context=krb_ctx)
    krb_authctx.rcache = krb_ctx.default_rcache()

    # Initialize the socket
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    try:
        sock.connect(dest)
        sock.settimeout(LERT_TIMEOUT)

        # Get the AP_REQ
        krb_authctx.genaddrs(sock,
                             krbV.KRB5_AUTH_CONTEXT_GENERATE_LOCAL_ADDR |
                             krbV.KRB5_AUTH_CONTEXT_GENERATE_REMOTE_ADDR)
        _, ap_req = krb_ctx.mk_req(client=krb_cprinc, server=krb_sprinc,
                                   ccache=krb_ccache, auth_context=krb_authctx)

        # Send the request: version, acknowledge byte, two zero bytes, AP_REQ
        packet = LERT_VERSION + pack('B', int(acknowledge)) + b"\0\0"
        packet += ap_req
        sock.sendall(packet)

        # Receive the response
        raw_response = sock.recv(LERT_MAX_PACKET_SIZE)
        response = krb_authctx.rd_priv(raw_response)
    finally:
        # Always release the socket, including on timeouts and Kerberos errors.
        sock.close()

    # Parse the response: byte 0 is the version, byte 1 the status code
    if len(response) < 2:
        raise LertError("lert server returned a truncated response")

    server_version = response[0]
    response_code = response[1]
    if server_version != LERT_VERSION:
        raise LertError("Unrecognized lert server version")

    if response_code == LERT_BAD:
        raise LertError("lert server was unable to process the request")
    if response_code == LERT_SICK:
        raise LertError("lert server is malfunctioning")
    if response_code == LERT_MSG:
        # The remainder of the response selects which message file to show.
        file_suffix = response[2:]
        with open(LERT_MESSAGE_PATH_BASE + file_suffix, 'r') as f:
            return f.read().strip()
    if response_code == LERT_FREE:
        return ""
    raise LertError("lert server returned invalid status code")

# =========== Command-line interface ===========

class DeprecatedAction(argparse.Action):
    """argparse action that only warns on stderr that the flag is obsolete."""

    def __call__(self, parser, namespace, values, opt_str=None):
        warning = "WARNING: '%s' is obsolete and will be removed in future versions." % (opt_str)
        sys.stderr.write(warning + "\n")

def deprecate_flag(parser, name):
    """Registers `name` on `parser` as an obsolete flag that takes no value."""

    options = dict(nargs=0, action=DeprecatedAction, help="[obsolete]")
    parser.add_argument(name, **options)

def main(args):
    """Parses the command line, queries the lert server and prints any message.

    Arguments:
        args: list of command-line arguments, without the program name.

    When the server returns a message, a header file is printed first,
    followed by a blank line and the message itself. On any failure the
    process exits with status 1; the error is written to stderr unless
    --quiet was given.
    """

    argparser = argparse.ArgumentParser(description="lert deactivation message client")

    argparser.add_argument('-n', '--no', action="store_true",
        help="Inform server that the message was read and should not be shown again")
    argparser.add_argument('-q', '--quiet', action="store_true",
        help="Do not display errors if operation fails")
    argparser.add_argument('-s', '--server', help="lert server to use")

    deprecate_flag(argparser, '-z')
    deprecate_flag(argparser, '-m')

    args = argparser.parse_args(args)

    try:
        # Server discovery happens inside the try block so that a failed
        # Hesiod lookup is reported (or silenced by --quiet) like any other
        # error instead of escaping as a traceback.
        server = args.server if args.server else locate_server()

        message = lert(server, args.no)
        if message:
            # The header file differs depending on whether -n was given.
            header_path = LERT_MESSAGE_PATH_BASE + ('1' if args.no else '0')
            with open(header_path, 'r') as f:
                header = f.read().strip()
            sys.stdout.write(header + "\n")
            sys.stdout.write("\n")
            sys.stdout.write(message + "\n")
    except Exception as err:
        if not args.quiet:
            if isinstance(err, krbV.Krb5Error):
                # Krb5Error carries two values; only the second one is shown.
                _, err = err.args
            sys.stderr.write("ERROR: %s\n" % err)
        sys.exit(1)

# Run the client only when executed as a script, passing along the
# command-line arguments that follow the program name.
if __name__ == '__main__':
    main(sys.argv[1:])
