/* Origin: www.pudn.com > nat.rar > ip_nat_udp.cpp */



#include  
#include  
 
#include "rawping.h" 
#include "list.h" 
#include "ip_nat_tuple.h" 
#include "ip_conntrack.h" 
#include "nat.h" 
#include "ip_nat_proto.h"
#include "list.h"


/*
 * Address/port pair handed to checksumadjust() as a single 6-byte run
 * (4-byte IPv4 address immediately followed by a 2-byte port) so the UDP
 * checksum can be patched incrementally when NAT rewrites them.
 *
 * The members must therefore be exactly 4 and 2 bytes wide with no gap
 * between them.  `unsigned long` is 8 bytes on LP64 platforms, which would
 * push `port` to offset 8, outside the 6 bytes that are checksummed, so the
 * port change would silently be left out of the checksum update.
 * `unsigned int` is 32 bits on the mainstream ILP32, LP64 and LLP64 ABIs.
 */
struct udp_chsum{
	unsigned int ip;
	unsigned short port;
};

/* Lifetime passed to ip_ct_refresh() (see udp_packet) for a UDP entry that
 * has not yet seen a reply.  The original note said this corresponds to
 * 30*HZ, usually 30 seconds; the actual unit depends on ip_ct_refresh()
 * — TODO confirm. */
#define UDP_TIMEOUT 30
/* Lifetime passed to ip_ct_refresh() once IPS_SEEN_REPLY is set; same unit
 * as UDP_TIMEOUT. */
#define UDP_STREAM_TIMEOUT 180


/*
 * Copy the UDP source/destination ports found at `datah` into `tuple`,
 * unchanged (network byte order).
 *
 * `datalen` is not consulted: presumably the caller guarantees a complete
 * UDP header is present — TODO confirm.  Always returns 1.
 */
static int udp_pkt_to_tuple(const void *datah, unsigned int datalen,
			    struct ip_conntrack_tuple *tuple)
{
	const struct udp_hdr *udph = (const struct udp_hdr *)datah;

	(void)datalen; /* not consulted; see note above */

	tuple->src.u.udp.port = udph->uh_sport;
	tuple->dst.u.udp.port = udph->uh_dport;
	return 1;
}

/*
 * Store into `tuple` the UDP port pair of `orig` with source and
 * destination swapped.  Always returns 1.
 */
static int udp_invert_tuple(struct ip_conntrack_tuple *tuple,
			    const struct ip_conntrack_tuple *orig)
{
	const struct ip_conntrack_tuple *from = orig;
	struct ip_conntrack_tuple *to = tuple;

	to->src.u.udp.port = from->dst.u.udp.port;
	to->dst.u.udp.port = from->src.u.udp.port;

	return 1;
}

/*
 * Report whether the port chosen by `maniptype` (source port for
 * IP_NAT_MANIP_SRC, destination port otherwise) lies in the inclusive
 * range [min->udp.port, max->udp.port].  Ports are held in network byte
 * order and compared after conversion to host order.  Returns 1 or 0.
 */
static int
udp_in_range(const struct ip_conntrack_tuple *tuple,
	     enum ip_nat_manip_type maniptype,
	     const union ip_conntrack_manip_proto *min,
	     const union ip_conntrack_manip_proto *max)
{
	unsigned int host_port, lo, hi;

	host_port = (maniptype == IP_NAT_MANIP_SRC)
		? ntohs(tuple->src.u.udp.port)
		: ntohs(tuple->dst.u.udp.port);
	lo = ntohs(min->udp.port);
	hi = ntohs(max->udp.port);

	return host_port >= lo && host_port <= hi;
}

/*
 * Pick a port for the side of `tuple` selected by `maniptype` so that the
 * resulting tuple is not already in use, as reported by ip_nat_used_tuple().
 *
 * When `range` carries IP_NAT_RANGE_PROTO_SPECIFIED, candidates come from
 * [range->min.udp.port, range->max.udp.port]; an inverted range is treated
 * as the same range with its bounds swapped.  Otherwise only source
 * manipulation is supported, and the candidate range is chosen from the
 * tuple's current port:
 *   - below 512   -> 1..511
 *   - 512..1023   -> 600..1023
 *   - 1024 and up -> 1024..65535
 *
 * The search starts at a rotating offset (`rover`, kept across calls) so
 * successive allocations do not all begin at the bottom of the range.
 *
 * Returns 1 with the chosen port (network byte order) stored in `tuple`.
 * Returns 0 when no free port exists or a destination manipulation without
 * a range was requested; the tuple's port may have been overwritten with
 * the last candidate tried.
 */
static int
udp_unique_tuple(struct ip_conntrack_tuple *tuple,
		 const struct ip_nat_range *range,
		 enum ip_nat_manip_type maniptype,
		 const struct ip_conntrack *conntrack)
{
	/* Only the rotating start offset needs static storage.  The port
	 * pointer refers into the caller's tuple, so it must be a plain local.
	 * Note that `rover` is shared across calls with no lock visible here. */
	static unsigned short rover = 0;
	unsigned short *portptr;
	unsigned int range_size, min, i;

	if (maniptype == IP_NAT_MANIP_SRC)
		portptr = &tuple->src.u.udp.port;
	else
		portptr = &tuple->dst.u.udp.port;

	if (!(range->flags & IP_NAT_RANGE_PROTO_SPECIFIED)) {
		/* Without an explicit range, only source ports are remapped. */
		if (maniptype == IP_NAT_MANIP_DST)
			return 0;

		/* Choose a replacement range of the same class as the
		 * current port. */
		if (ntohs(*portptr) < 512) {
			min = 1;
			range_size = 511 - min + 1;
		} else if (ntohs(*portptr) < 1024) {
			min = 600;
			range_size = 1023 - min + 1;
		} else {
			min = 1024;
			range_size = 65535 - min + 1;
		}
	} else {
		unsigned int lo = ntohs(range->min.udp.port);
		unsigned int hi = ntohs(range->max.udp.port);

		/* With hi < lo, hi - lo + 1 wraps to a huge unsigned value and
		 * the search loop would run ~4 billion times; normalise it. */
		if (hi < lo) {
			unsigned int tmp = lo;
			lo = hi;
			hi = tmp;
		}
		min = lo;
		range_size = hi - lo + 1;
	}

	for (i = 0; i < range_size; i++, rover++) {
		/* min + (rover % range_size) never exceeds the range's upper
		 * bound, which is at most 65535. */
		*portptr = htons((unsigned short)(min + rover % range_size));
		if (!ip_nat_used_tuple(tuple, conntrack))
			return 1;
	}
	return 0;
}

/*
 * Rewrite the UDP port selected by `maniptype` to manip->u.udp.port:
 * the source port for IP_NAT_MANIP_SRC, the destination port otherwise.
 *
 * The UDP checksum, when present, is patched incrementally for both the
 * port change and the matching IP address change.  The old address is
 * read from `iph` and the new one taken from manip->ip.  The IP header's
 * address field itself is not modified here.
 *
 * `len` is not consulted; presumably the caller guarantees a complete UDP
 * header follows the IP header — TODO confirm.
 */
static void
udp_manip_pkt(struct ip_hdr *iph, unsigned int len,
	      const struct ip_conntrack_manip *manip,
	      enum ip_nat_manip_type maniptype)
{
	/* The IP header length field counts 32-bit words.  Step in bytes
	 * rather than in `unsigned long`s: the latter is 8 bytes on LP64
	 * targets and would place the UDP header at twice its real offset. */
	struct udp_hdr *hdr =
		(struct udp_hdr *)((unsigned char *)iph + iph->hl * 4);
	unsigned long oldip;
	unsigned short *portptr;
	struct udp_chsum new1, old;

	(void)len;

	if (maniptype == IP_NAT_MANIP_SRC) {
		oldip = iph->source_ip;
		portptr = &hdr->uh_sport;
	} else {
		oldip = iph->dest_ip;
		portptr = &hdr->uh_dport;
	}

	/* A zero UDP checksum means the sender did not compute one
	 * (RFC 768); leave it untouched in that case. */
	if (hdr->uh_sum) {
		old.ip = oldip;
		old.port = *portptr;

		new1.ip = manip->ip;
		new1.port = manip->u.udp.port;

		/* The first 6 bytes of struct udp_chsum (address, then port)
		 * are fed to the incremental checksum update. */
		checksumadjust((unsigned char *)&hdr->uh_sum,
			       (unsigned char *)&new1, 6,
			       (unsigned char *)&old, 6);

		/* A computed checksum of zero must be sent as all ones,
		 * because zero on the wire means "no checksum" (RFC 768). */
		if (!hdr->uh_sum)
			hdr->uh_sum = 0xFFFF;
	}

	*portptr = manip->u.udp.port;
}

/*
 * Append "srcpt=N " and/or "dstpt=N " to `buffer`, once for each port that
 * is non-zero in `mask`.  The value printed is the corresponding port of
 * `match`, converted to host byte order.
 *
 * Returns the number of characters written (sprintf's count, excluding the
 * terminating NUL).  No bound is checked: `buffer` must be large enough.
 */
static unsigned int
udp_print(char *buffer,
	  const struct ip_conntrack_tuple *match,
	  const struct ip_conntrack_tuple *mask)
{
	char *cursor = buffer;

	if (mask->src.u.udp.port)
		cursor += sprintf(cursor, "srcpt=%u ",
				  ntohs(match->src.u.udp.port));
	if (mask->dst.u.udp.port)
		cursor += sprintf(cursor, "dstpt=%u ",
				  ntohs(match->dst.u.udp.port));

	return (unsigned int)(cursor - buffer);
}

/*
 * Describe the UDP port bounds of `range` into `buffer`.
 *
 * Writes "port N " when min equals max and "ports A-B " otherwise.  A range
 * of exactly 0..0xFFFF is not printed and yields 0.  Returns the number of
 * characters written.  No bound is checked on `buffer`.
 */
static unsigned int
udp_print_range(char *buffer, const struct ip_nat_range *range)
{
	/* 0 and 0xFFFF read the same in either byte order, so the raw
	 * (network-order) values can be compared directly. */
	if (range->min.udp.port == 0 && range->max.udp.port == 0xFFFF)
		return 0;

	if (range->min.udp.port != range->max.udp.port)
		return sprintf(buffer, "ports %u-%u ",
			       ntohs(range->min.udp.port),
			       ntohs(range->max.udp.port));

	return sprintf(buffer, "port %u ", ntohs(range->min.udp.port));
}

/*
 * Per-packet conntrack bookkeeping for UDP.
 *
 * Once a reply has been seen (IPS_SEEN_REPLY), the entry is refreshed with
 * UDP_STREAM_TIMEOUT and then flagged IPS_ASSURED.  Before that, it is
 * refreshed with the shorter UDP_TIMEOUT.  The packet (`iph`, `len`) and
 * `maniptype` are not inspected.  Always returns 1.
 */
static int udp_packet(struct ip_conntrack *conntrack,
		      struct ip_hdr *iph, unsigned int len,
		       enum ip_nat_manip_type maniptype)
{
	int replied = (conntrack->status & IPS_SEEN_REPLY) != 0;

	(void)iph;
	(void)len;
	(void)maniptype;

	if (!replied) {
		ip_ct_refresh(conntrack, UDP_TIMEOUT);
		return 1;
	}

	ip_ct_refresh(conntrack, UDP_STREAM_TIMEOUT);
	conntrack->status |= IPS_ASSURED;
	return 1;
}



/*
 * UDP protocol handler descriptor (external linkage, so visible to other
 * translation units).
 *
 * It is initialised positionally, so the order below must match the member
 * order of struct ip_nat_protocol.  That struct is declared elsewhere,
 * presumably ip_nat_proto.h — TODO confirm.  The order is: a leading
 * two-field member zeroed here (presumably list linkage — TODO confirm),
 * the IP protocol number, then the callbacks defined above.
 */
struct ip_nat_protocol ip_nat_protocol_udp
= { { 0, 0 },  IPPROTO_UDP,
    udp_pkt_to_tuple,
    udp_invert_tuple,
    udp_manip_pkt,
    udp_in_range,
    udp_unique_tuple,
    udp_print,
    udp_print_range,
    udp_packet
};