/*
 * mrmap - map a DVMRP multicast backbone.
 *
 * mrmap [-n] [-a] [-h hops] [-t thresh] [-r retries] [-T timeout] router
 *
 * Written Sat Jul 3 1993 by Van Jacobson
 *
 * Copyright (c) 1993 Regents of the University of California.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. All advertising materials mentioning features or use of this software
 *    must display the following acknowledgement:
 *	This product includes software developed by the Computer Systems
 *	Engineering Group at Lawrence Berkeley Laboratory.
 * 4. Neither the name of the University nor of the Laboratory may be used
 *    to endorse or promote products derived from this software without
 *    specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

#ifndef lint
/* RCS identification string; compiled out when lint is defined. */
static char rcsid[] =
    "@(#) $Header: mrmap.c,v 1.5 93/10/30 02:59:31 van Exp $ (LBL)";
#endif

#include <sys/param.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#include <sys/time.h>
#include <net/if.h>
#include <netinet/in.h>
#include <netinet/in_systm.h>
#include <netinet/ip.h>
#include <netinet/igmp.h>
#include <netinet/ip_mroute.h>
#include <netdb.h>
#include <arpa/inet.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include "dvmrp.h"

/*
 * One entry per router address seen during the walk.  Every entry lives in
 * mrhash[]; entries still to be queried are also chained, in the order they
 * were discovered, on the todo list through 'next'.
 */
struct mr {
	u_long addr;		/* router address, network byte order */
	u_long level;		/* level from a DVMRP_NEIGHBORS reply's group
				 * field; nonzero => query with ASK_NEIGHBORS2 */
	int triesleft;		/* queries left; 0 = finished or never queried */
	int hops;		/* hops from the starting router; -1 = unknown */
	struct mr* next;	/* next entry on the todo list */
};

/*
 * Router table: HASHSIZE slots keyed by router address.  MRHASH picks the
 * starting slot by XOR-ing the bits above HASHBITS into the low HASHBITS
 * bits.  Every use of the argument is parenthesized so expressions such as
 * MRHASH(a | b) group correctly; the argument is still evaluated twice, so
 * it must not have side effects.
 */
#define HASHBITS 12
#define HASHSIZE (1 << HASHBITS)
#define MRHASH(s) ((((s) >> HASHBITS) ^ (s)) & (HASHSIZE - 1))

struct mr* mrhash[HASHSIZE];	/* router table, probed from MRHASH(addr) */
struct mr* todo;		/* head of the list of routers still to query */
struct mr* todo_tail;		/* tail of that list (0 when it is empty) */

u_long our_addr;	/* source address for our probes (network order) */
int nretries = 10;	/* -r: query budget given to each newly found router */
int timeout = 1;	/* -T: seconds select() waits after each query */
int printname = 1;	/* cleared by -n: print numeric addresses only */
int doall;		/* -a: also queue neighbors flagged down/disabled */
int maxhops;		/* -h: don't query routers beyond this many hops (0 = no limit) */
int maxthresh;		/* -t: don't query neighbors seen over a threshold >= this (0 = no limit) */

int igmp_socket;	/* raw IGMP socket used for both sending and receiving */

/*
 * Look up the table entry for router 'addr', creating it if necessary.
 *
 * addr   - router address, network byte order
 * hops   - hop count at which addr was just seen, or -1 if unknown
 * thresh - threshold of the link over which addr was seen
 *
 * An existing entry's hop count is replaced by 'hops' when that is a known
 * count and the entry's count is unknown or larger.  A new entry is appended
 * to the todo list unless it lies beyond the -h limit, was reached over a
 * threshold at or above the -t limit, or is a multicast address; those are
 * created with triesleft = 0 and are never queried.
 *
 * Exits the program if memory runs out or the table is full.
 */
struct mr*
mr_find(u_long addr, int hops, int thresh)
{
	struct mr* mr;
	u_long h, start;

	start = h = MRHASH(addr);
	while ((mr = mrhash[h]) != 0) {
		if (mr->addr == addr) {
			/* -1 means "unknown", so any real count is better */
			if (hops >= 0 && (mr->hops < 0 || mr->hops > hops))
				mr->hops = hops;
			return (mr);
		}
		/* linear probe, wrapping at the end of the table */
		h = (h + 1) & (HASHSIZE - 1);
		if (h == start) {
			fprintf(stderr,
			    "mrmap: router table full (%d entries)\n", HASHSIZE);
			exit(1);
		}
	}
	if ((mr = malloc(sizeof(*mr))) == 0) {
		perror("mrmap: malloc");
		exit(1);
	}
	mrhash[h] = mr;
	mr->addr = addr;
	mr->level = 0;
	mr->hops = hops;
	if ((maxhops && hops > maxhops) ||
	    (maxthresh && thresh >= maxthresh) ||
	    IN_MULTICAST(ntohl(addr))) {
		mr->triesleft = 0;
	} else {
		mr->triesleft = nretries;
		if (todo_tail)
			todo_tail->next = mr;
		todo_tail = mr;
		if (todo == 0)
			todo = mr;
	}
	mr->next = 0;
	return (mr);
}

/*
 * Mark 'mr' finished and unlink it from the todo list.
 *
 * A zero triesleft is taken to mean the entry is already finished (or was
 * never queued), so there is nothing to unlink.  Otherwise the entry is
 * assumed to be on the list: it is removed, and todo_tail is moved back
 * when it was the last element.
 */
void
mr_done(struct mr *mr)
{
	struct mr *prev;

	if (mr->triesleft != 0) {
		mr->triesleft = 0;
		if (todo == mr) {
			todo = mr->next;
			if (todo == 0)
				todo_tail = 0;
		} else {
			prev = todo;
			while (prev->next != mr)
				prev = prev->next;
			prev->next = mr->next;
			if (prev->next == 0)
				todo_tail = prev;
		}
	}
}

/*
 * Reverse-resolve an IPv4 address (network byte order).
 *
 * Returns "local" for address 0, the host name on success, or a null
 * pointer if the lookup fails.  The name is owned by the resolver and may
 * be overwritten by later lookups.
 */
char*
inet_name(unsigned long addr)
{
	struct hostent *e;
	struct in_addr ina;

	if (addr == 0)
		return ("local");
	/*
	 * Hand the resolver exactly a 4-byte struct in_addr.  Passing the
	 * unsigned long itself would give an 8-byte length on LP64 hosts, and
	 * on big-endian ones would point at the zero high-order word.
	 */
	ina.s_addr = addr;
	e = gethostbyaddr((char *)&ina, sizeof(ina), AF_INET);
	return (e ? e->h_name : 0);
}

/*
 * Format an address (network byte order) for output.
 *
 * With -n, returns just the dotted quad.  Otherwise returns
 * "name(dotted.quad)", using the dotted quad in place of the name when
 * reverse lookup fails.  The result lives in static storage (ours or
 * inet_ntoa()'s) that the next call overwrites, so callers print one result
 * per printf.
 */
char*
fmt_addr(u_long addr)
{
	static char namestr[1024];
	struct in_addr ina;
	char* numeric;
	char* nm;

	ina.s_addr = addr;
	numeric = inet_ntoa(ina);
	if (!printname)
		return (numeric);

	nm = inet_name(addr);
	/* bounded: a long resolver name must not overrun namestr */
	snprintf(namestr, sizeof(namestr), "%s(%s)", nm ? nm : numeric, numeric);
	return (namestr);
}

/*
 * Internet (ones'-complement) checksum of 'len' bytes at 'addr'.
 *
 * Adapted from ping.c by Mike Muuss, U. S. Army Ballistic Research
 * Laboratory, December 1983; modified at UC Berkeley.
 *
 * The buffer is summed as 16-bit words.  An odd trailing byte is summed as
 * if padded with a zero byte.  Carries are folded back into the low 16 bits,
 * and the complement of the sum is returned as an int.
 */
int
inet_cksum(u_short* addr, u_int len)
{
	u_short *cursor = addr;
	int remaining = (int)len;
	int total = 0;
	u_short tail = 0;

	for (; remaining > 1; remaining -= 2)
		total += *cursor++;

	if (remaining == 1) {
		/* place the last byte in the first byte of a zeroed word */
		*(u_char *)&tail = *(u_char *)cursor;
		total += tail;
	}

	/* fold high 16 bits into low 16; the second add absorbs its carry */
	total = (total & 0xffff) + (total >> 16);
	total += total >> 16;
	return ((u_short)~total);
}

/*
 * Send a DVMRP control message to a router.
 *
 * dst  - destination router, network byte order
 * code - DVMRP message code placed in igmp_code (e.g. DVMRP_ASK_NEIGHBORS)
 *
 * The whole IP header is built here.  ip_len and ip_off are in host byte
 * order.  NOTE(review): that is the BSD IP_HDRINCL convention; confirm it
 * for other kernels.  A failed sendto() is reported on stderr but is not
 * fatal.
 */
void send_igmp(u_long dst, int code)
{
	union {
		u_char bytes[sizeof(struct ip) + sizeof(struct igmp)];
		struct ip hdr;		/* gives the buffer struct ip alignment */
	} pkt;
	struct sockaddr_in sdst;
	struct ip *ip = &pkt.hdr;
	struct igmp *igmp = (struct igmp *)(pkt.bytes + sizeof(struct ip));

	/* Zero everything first, so ip_id, ip_sum and padding are defined. */
	memset(&pkt, 0, sizeof(pkt));
	ip->ip_v = IPVERSION;
	ip->ip_hl = sizeof(struct ip) >> 2;
	ip->ip_tos = 0;
	ip->ip_len = sizeof(pkt.bytes);
	ip->ip_off = 0;
	ip->ip_p   = IPPROTO_IGMP;
	ip->ip_ttl = 255;
	ip->ip_src.s_addr = our_addr;
	ip->ip_dst.s_addr = dst;

	igmp->igmp_type = IGMP_DVMRP;
	igmp->igmp_code = code;
	igmp->igmp_group.s_addr = 0;
	igmp->igmp_cksum = 0;
	igmp->igmp_cksum = inet_cksum((u_short *)igmp, sizeof(struct igmp));

	/* Clear the sockaddr itself (the old code passed it by value to bzero). */
	memset(&sdst, 0, sizeof(sdst));
	sdst.sin_family = AF_INET;
	sdst.sin_addr.s_addr = dst;
	if (sendto(igmp_socket, pkt.bytes, sizeof(pkt.bytes), 0,
			(struct sockaddr *)&sdst, sizeof(sdst)) < 0) {
		perror("mrmap: send");
	}
}

/*
 * Ask 'dst' for its neighbor list with the original DVMRP request.
 * The main loop uses this while the router's level is still 0.
 */
void
ask(u_long dst)
{
	int code = DVMRP_ASK_NEIGHBORS;

	send_igmp(dst, code);
}

/*
 * Ask 'dst' for its neighbor list with DVMRP_ASK_NEIGHBORS2, whose reply
 * is parsed by accept_neighbors2() and carries per-interface flags.
 */
void
ask2(u_long dst)
{
	int code = DVMRP_ASK_NEIGHBORS2;

	send_igmp(dst, code);
}

/*
 * Process an incoming DVMRP_NEIGHBORS (version 1) neighbor-list message.
 *
 * src     - router that sent the report, network byte order
 * dst     - destination of the report (unused)
 * p       - report body, just past the IGMP header
 * datalen - length of the body in bytes
 *
 * The body is a sequence of records.  Each record is a 4-byte local
 * address, a metric byte, a threshold byte and a neighbor-count byte,
 * followed by that many 4-byte neighbor addresses.  The sender is marked
 * done.  Each local address other than src is recorded and printed as an
 * "=" line.  Each neighbor is looked up one hop further out and printed as
 * a ">" line.  A record that would run past the end of the body stops
 * parsing, with a warning on stderr, instead of reading beyond the data.
 */
void 
accept_neighbors(u_long src, u_long dst, u_char *p, int datalen)
{
	struct mr* mr = mr_find(src, -1, 0);
	u_char *ep = p + datalen;
/* Read 4 bytes at p, most significant first, into 'a'; advance p past them. */
#define GET_ADDR(a) (a = ((u_long)*p++ << 24), a += ((u_long)*p++ << 16),\
		     a += ((u_long)*p++ << 8), a += *p++)

	mr_done(mr);
	while (p < ep) {
		u_long laddr;
		u_char metric;
		u_char thresh;
		int ncount, hops;

		/* fixed part: address(4) metric(1) thresh(1) count(1) */
		if (ep - p < 7)
			goto truncated;
		GET_ADDR(laddr);
		laddr = htonl(laddr);
		if (laddr != src) {
			struct mr* lmr = mr_find(laddr, mr->hops, 0);
			if (mr->hops != lmr->hops)
				mr->hops = lmr->hops;
			mr_done(lmr);
			printf("= %s", fmt_addr(src));
			printf(" %s\n", fmt_addr(laddr));
		}
		metric = *p++;
		thresh = *p++;
		ncount = *p++;
		if (ep - p < 4 * ncount)
			goto truncated;
		if (mr->hops < 0)
			hops = -1;
		else
			hops = mr->hops + 1;
		while (--ncount >= 0) {
			u_long neighbor;

			GET_ADDR(neighbor);
			neighbor = htonl(neighbor);
			if (neighbor)
				mr_find(neighbor, hops, thresh);
			printf("> %s ", fmt_addr(laddr));
			printf("%s [%d/%d/tunnel/srcrt]\n", fmt_addr(neighbor),
			       metric, thresh);
		}
	}
	return;

truncated:
	fprintf(stderr, "mrmap: truncated neighbor report from %s\n",
		fmt_addr(src));
#undef GET_ADDR
}

/*
 * Process an incoming DVMRP_NEIGHBORS2 neighbor-list message.
 *
 * src     - router that sent the report, network byte order
 * dst     - destination of the report (unused)
 * p       - report body, just past the IGMP header
 * datalen - length of the body in bytes
 * vers    - level from the IGMP group field (unused here)
 *
 * The body is a sequence of records.  Each record is a 4-byte local
 * address, then metric, threshold, flags and neighbor-count bytes, followed
 * by that many 4-byte neighbor addresses.  Neighbors on interfaces flagged
 * DVMRP_NF_DOWN or DVMRP_NF_DISABLED are printed but not queued unless -a
 * was given.  A record that would run past the end of the body stops
 * parsing, with a warning on stderr.
 */
void 
accept_neighbors2(u_long src, u_long dst, u_char *p, int datalen, u_long vers)
{
	struct mr* mr = mr_find(src, -1, 0);
	u_char *ep = p + datalen;

	mr_done(mr);
	while (p < ep) {
		u_char metric;
		u_char thresh;
		u_char flags;
		int ncount, hops;
		u_long laddr;
		struct in_addr ina;

		/* fixed part: address(4) metric thresh flags count (1 each) */
		if (ep - p < 8)
			goto truncated;
		/*
		 * Copy exactly 4 bytes: p need not be aligned, and u_long is
		 * 8 bytes on LP64 hosts, so *(u_long *)p would read past the
		 * address field.
		 */
		memcpy(&ina, p, sizeof(ina));
		laddr = ina.s_addr;
		if (laddr != src) {
			struct mr* lmr = mr_find(laddr, mr->hops, 0);
			if (mr->hops != lmr->hops)
				mr->hops = lmr->hops;
			mr_done(lmr);
			printf("= %s", fmt_addr(src));
			printf(" %s\n", fmt_addr(laddr));
		}
		p += 4;
		metric = *p++;
		thresh = *p++;
		flags = *p++;
		ncount = *p++;
		if (ep - p < 4 * ncount)
			goto truncated;
		if (mr->hops < 0)
			hops = -1;
		else
			hops = mr->hops + 1;
		while (--ncount >= 0) {
			u_long neighbor;

			memcpy(&ina, p, sizeof(ina));
			neighbor = ina.s_addr;
			p += 4;
			if (neighbor && (doall ||
			    (flags & (DVMRP_NF_DOWN|DVMRP_NF_DISABLED)) == 0))
				mr_find(neighbor, hops, thresh);
			printf("> %s ", fmt_addr(laddr));
			printf("%s [%d/%d", fmt_addr(neighbor),
			       metric, thresh);
			if (flags & DVMRP_NF_TUNNEL)
				printf("/tunnel");
			if (flags & DVMRP_NF_SRCRT)
				printf("/srcrt");
			if (flags & DVMRP_NF_QUERIER)
				printf("/querier");
			if (flags & DVMRP_NF_DISABLED)
				printf("/disabled");
			if (flags & DVMRP_NF_DOWN)
				printf("/down");
			printf("]\n");
		}
	}
	return;

truncated:
	fprintf(stderr, "mrmap: truncated neighbor report from %s\n",
		fmt_addr(src));
}

/*
 * Translate a router name or dotted-quad string into an IPv4 address.
 *
 * Returns the address in network byte order.  Returns 0 if 'name' is
 * neither a dotted quad nor a host name that resolves to an IPv4 address.
 */
unsigned long
host_addr(char *name)
{
	struct hostent *e;
	struct in_addr ina;

	/* Numeric addresses need no resolver lookup. */
	ina.s_addr = inet_addr(name);
	if (ina.s_addr != INADDR_NONE)
		return ((unsigned long)ina.s_addr);

	e = gethostbyname(name);
	if (e != 0 && e->h_addrtype == AF_INET &&
	    e->h_length == (int)sizeof(ina.s_addr) && e->h_addr_list[0] != 0) {
		/*
		 * Keep the value in the unsigned 32-bit s_addr field.  A signed
		 * int would be sign-extended when widened to unsigned long on
		 * LP64 hosts whenever the high bit is set.
		 */
		memcpy(&ina.s_addr, e->h_addr_list[0], sizeof(ina.s_addr));
		return ((unsigned long)ina.s_addr);
	}
	return (0);
}


main(int argc, char **argv)
{
	struct mr *mr;
	u_long target_addr = 0;
	int sockbuf = 24*1024;

	setlinebuf(stderr);
	setlinebuf(stdout);

	if (geteuid() != 0) {
		fprintf(stderr, "must be root\n");
		exit(1);
	}
	argv++, argc--;
	while (argc > 0 && argv[0][0] == '-') {
		switch (argv[0][1]) {
		case 'h':
			++argv; --argc;
			if (*argv == 0 || (maxhops = atoi(*argv)) <= 0)
				goto usage;
			break;
		case 'r':
			++argv; --argc;
			if (*argv == 0 || (nretries = atoi(*argv)) <= 0)
				goto usage;
			break;
		case 't':
			++argv; --argc;
			if (*argv == 0 || (maxthresh = atoi(*argv)) <= 0)
				goto usage;
			break;
		case 'T':
			++argv; --argc;
			if (*argv == 0 || (timeout = atoi(*argv)) <= 0)
				goto usage;
			break;
		case 'n':
			printname = 0;
			break;
		case 'a':
			doall = 1;
			break;
		default:
			goto usage;
		}
		argv++, argc--;
	}

	if (argc > 1 || (argc == 1 && !(target_addr = host_addr(argv[0])))) {
    usage:
		fprintf(stderr,
"usage: mrmap [-n] [-a] [-h hops] [-t thresh] [-r retries] [-T timeout] router\n");
		exit(1);
	}
	if (target_addr == 0)
		goto usage;

	if ((igmp_socket = socket(AF_INET, SOCK_RAW, IPPROTO_IGMP)) < 0) {
		perror("mrmap: socket");
		exit(1);
	}
	if (setsockopt(igmp_socket, SOL_SOCKET, SO_RCVBUF, (char*)&sockbuf,
			sizeof(sockbuf)) < 0) {
		perror("mrmap: SO_RCVBUF");
		exit(1);
	}
	{			/* Find a good local address for us. */
		int     udp;
		struct sockaddr_in addr;
		int     addrlen = sizeof(addr);

		addr.sin_family = AF_INET;
		addr.sin_addr.s_addr = target_addr;
		addr.sin_port = htons(2000);	/* any port over 1024 will
						 * do... */
		if ((udp = socket(AF_INET, SOCK_DGRAM, 0)) < 0
		|| connect(udp, (struct sockaddr *) & addr, sizeof(addr)) < 0
		    || getsockname(udp, (struct sockaddr *) & addr, &addrlen) < 0) {
			perror("Determining local address");
			exit(-1);
		}
		close(udp);
		our_addr = addr.sin_addr.s_addr;
	}

	mr_find(target_addr, 0, 0);

	/* Main receive loop */
	while (todo) {
		fd_set fds;
		struct timeval tv;
		int count, recvlen, dummy = 0;
		register u_long src, dst, group;
		struct ip *ip;
		struct igmp *igmp;
		int ipdatalen, iphdrlen, igmpdatalen;
		u_char buf[16384];

		if (--todo->triesleft < 0) {
			printf("? %s\n", fmt_addr(todo->addr));
			mr_done(todo);
			continue;
		}
		if (todo->level == 0)
			ask(todo->addr);
		else
			ask2(todo->addr);

		FD_ZERO(&fds);
		FD_SET(igmp_socket, &fds);

		tv.tv_sec = timeout;
		tv.tv_usec = 0;

		count = select(igmp_socket + 1, &fds, 0, 0, &tv);

		if (count < 0) {
			if (errno != EINTR)
				perror("select");
			continue;
		} else if (count == 0) {
			continue;
		}
		recvlen = recvfrom(igmp_socket, buf, sizeof(buf),
				   0, NULL, &dummy);
		if (recvlen < sizeof(struct ip)) {
			if (recvlen < 0 && errno != EINTR)
				perror("recvfrom");
			++todo->triesleft;
			continue;
		}
		ip = (struct ip *) buf;
		src = ip->ip_src.s_addr;
		dst = ip->ip_dst.s_addr;
		iphdrlen = ip->ip_hl << 2;
		ipdatalen = ip->ip_len;
		if (iphdrlen + ipdatalen != recvlen)
			continue;

		igmp = (struct igmp *)(ip + 1);
		group = ntohl(igmp->igmp_group.s_addr);
		igmpdatalen = ipdatalen - IGMP_MINLEN;
		if (igmpdatalen < 0)
			continue;

		if (igmp->igmp_type != IGMP_DVMRP)
			continue;

		if (igmp->igmp_code != DVMRP_NEIGHBORS &&
		    igmp->igmp_code != DVMRP_NEIGHBORS2)
			continue;

		mr = mr_find(src, -1, 0);
		if (mr->triesleft <= 0)
			continue;

		switch (igmp->igmp_code) {

		case DVMRP_NEIGHBORS:
			if (group) {
				/* knows about DVMRP_NEIGHBORS2 msg */
				mr->level = group;
			} else {
				accept_neighbors(src, dst, (char *)(igmp + 1),
						 igmpdatalen);
			}
			break;

		case DVMRP_NEIGHBORS2:
			accept_neighbors2(src, dst, (char *)(igmp + 1),
					  igmpdatalen, group);
			break;
		}
	}
	exit(0);
}
