/* Origin: www.pudn.com > nat.rar > ip_nat_tcp.cpp */


#include  
#include  
 
#include "rawping.h" 
#include "list.h" 
#include "ip_nat_tuple.h" 
#include "ip_conntrack.h" 
#include "nat.h" 
#include "ip_nat_proto.h"



/*
 * Scratch record holding one (IP address, TCP port) pair.  The manip
 * functions in this file fill one for the old values and one for the new
 * values and hand their raw bytes to checksumadjust() with a length of 6.
 *
 * NOTE(review): a 6-byte length only covers both fields when unsigned long
 * is 4 bytes and there is no padding between `ip` and `port` (true on
 * Win32/Win64).  Where unsigned long is 8 bytes (LP64), the first 6 bytes
 * do not include `port` at all.
 */
struct tcp_cksum{
	unsigned long ip;	/* address, copied as-is from the IP header / manip */
	unsigned short port;	/* TCP port */
};


/*
 * Postfix unit multipliers for the timeout table: "30 MINS" expands to
 * "30 * 60" and "5 DAYS" to "5 * 24 * 60 * 60", so the entries are counts
 * of the base unit with MINS meaning 60 of them.
 * NOTE(review): presumably seconds, i.e. whatever unit ip_ct_refresh()
 * expects — confirm against its definition.
 */
#define MINS * 60 
#define HOURS * 60 MINS
#define DAYS * 24 HOURS

/*
 * Idle timeout for each TCP conntrack state, indexed by enum tcp_conntrack
 * (tcp_packet() passes tcp_timeouts[newconntrack] to ip_ct_refresh()).
 * The initializer is positional: its order must match the enum declaration,
 * which is outside this file.  The comment on each line names the state the
 * entry is meant for.
 */
static unsigned long tcp_timeouts[]
= { 30 MINS, 	/*	TCP_CONNTRACK_NONE,	*/
    5 DAYS,	/*	TCP_CONNTRACK_ESTABLISHED,	*/
    2 MINS,	/*	TCP_CONNTRACK_SYN_SENT,	*/
    60 ,	/*	TCP_CONNTRACK_SYN_RECV,	*/
    2 MINS,	/*	TCP_CONNTRACK_FIN_WAIT,	*/
    2 MINS,	/*	TCP_CONNTRACK_TIME_WAIT,	*/
    10 ,	/*	TCP_CONNTRACK_CLOSE,	*/
    60 ,	/*	TCP_CONNTRACK_CLOSE_WAIT,	*/
    30 ,	/*	TCP_CONNTRACK_LAST_ACK,	*/
    2 MINS,	/*	TCP_CONNTRACK_LISTEN,	*/
};

/*
 * Three-letter aliases for the TCP conntrack states, used to keep the
 * transition table below readable.  sIV ("invalid") is TCP_CONNTRACK_MAX,
 * which tcp_packet() treats as "reject this packet".
 */
#define sNO TCP_CONNTRACK_NONE
#define sES TCP_CONNTRACK_ESTABLISHED
#define sSS TCP_CONNTRACK_SYN_SENT
#define sSR TCP_CONNTRACK_SYN_RECV
#define sFW TCP_CONNTRACK_FIN_WAIT
#define sTW TCP_CONNTRACK_TIME_WAIT
#define sCL TCP_CONNTRACK_CLOSE
#define sCW TCP_CONNTRACK_CLOSE_WAIT
#define sLA TCP_CONNTRACK_LAST_ACK
#define sLI TCP_CONNTRACK_LISTEN
#define sIV TCP_CONNTRACK_MAX

/*
 * TCP state-transition table, looked up in tcp_packet() as
 * tcp_conntracks[dir][flag][old] -> new state, where
 *   dir  - 0 for the ORIGINAL block, 1 for the REPLY block
 *          (tcp_packet() indexes it with its maniptype argument);
 *   flag - the row chosen by get_conntrack_index():
 *          0 syn, 1 fin, 2 ack, 3 rst, 4 none of these;
 *   old  - the current state, columns in the order of the header rows.
 * The "none" rows are all sIV, so packets with none of SYN/FIN/ACK/RST
 * are rejected.
 */
static enum tcp_conntrack tcp_conntracks[2][5][TCP_CONNTRACK_MAX] = {
	{
/*	ORIGINAL */
/* 	  sNO, sES, sSS, sSR, sFW, sTW, sCL, sCW, sLA, sLI 	*/
/*syn*/	{sSS, sES, sSS, sSR, sSS, sSS, sSS, sSS, sSS, sLI },
/*fin*/	{sTW, sFW, sSS, sTW, sFW, sTW, sCL, sTW, sLA, sLI },
/*ack*/	{sES, sES, sSS, sES, sFW, sTW, sCL, sCW, sLA, sES },
/*rst*/ {sCL, sCL, sSS, sCL, sCL, sTW, sCL, sCL, sCL, sCL },
/*none*/{sIV, sIV, sIV, sIV, sIV, sIV, sIV, sIV, sIV, sIV }
	},
	{
/*	REPLY */
/* 	  sNO, sES, sSS, sSR, sFW, sTW, sCL, sCW, sLA, sLI 	*/
/*syn*/	{sSR, sES, sSR, sSR, sSR, sSR, sSR, sSR, sSR, sSR },
/*fin*/	{sCL, sCW, sSS, sTW, sTW, sTW, sCL, sCW, sLA, sLI },
/*ack*/	{sCL, sES, sSS, sSR, sFW, sTW, sCL, sCW, sCL, sLI },
/*rst*/ {sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sLA, sLI },
/*none*/{sIV, sIV, sIV, sIV, sIV, sIV, sIV, sIV, sIV, sIV }
	}
};
/*
 * Fill the TCP part of a conntrack tuple from a raw TCP header at datah.
 * Both ports are copied unchanged, i.e. kept in the packet's (network)
 * byte order, which is what the ntohs() calls elsewhere in this file
 * expect.  datalen is not consulted: the caller is assumed to provide at
 * least the two port fields.  Always returns 1.
 */
static int tcp_pkt_to_tuple(const void *datah, unsigned int datalen,
			    struct ip_conntrack_tuple *tuple)
{
	const struct tcp_hdr *th = (const struct tcp_hdr *)datah;

	(void)datalen;	/* unused; see comment above */

	tuple->dst.u.tcp.port = th->th_dport;
	tuple->src.u.tcp.port = th->th_sport;
	return 1;
}

/*
 * Produce the ports of the reverse-direction tuple: the reply travels from
 * the original destination port back to the original source port.
 * The two fields are written in this order (source, then destination).
 * Always returns 1.
 */
static int tcp_invert_tuple(struct ip_conntrack_tuple *tuple,
			    const struct ip_conntrack_tuple *orig)
{
	/* reply source <- original destination */
	tuple->src.u.tcp.port = orig->dst.u.tcp.port;
	/* reply destination <- original source */
	tuple->dst.u.tcp.port = orig->src.u.tcp.port;

	return 1;
}


/*
 * Check whether the port subject to manipulation lies inside the inclusive
 * range [min->tcp.port, max->tcp.port].  The source port is checked for
 * IP_NAT_MANIP_SRC, the destination port otherwise.  All ports are stored
 * in network byte order and are converted with ntohs() before comparing.
 * Returns 1 when inside the range, 0 otherwise (max is only read when the
 * lower bound is satisfied).
 */
static int
tcp_in_range(const struct ip_conntrack_tuple *tuple,
	     enum ip_nat_manip_type maniptype,
	     const union ip_conntrack_manip_proto *min,
	     const union ip_conntrack_manip_proto *max)
{
	const unsigned short port = maniptype == IP_NAT_MANIP_SRC
				    ? tuple->src.u.tcp.port
				    : tuple->dst.u.tcp.port;
	const unsigned short host_port = ntohs(port);

	if (host_port < ntohs(min->tcp.port))
		return 0;

	return host_port <= ntohs(max->tcp.port);
}

/*
 * Choose a port for the manipulated side of `tuple` (the source port for
 * IP_NAT_MANIP_SRC, the destination port otherwise) such that
 * ip_nat_used_tuple() reports the resulting tuple as unused.
 *
 * Candidate range:
 *   - IP_NAT_RANGE_PROTO_SPECIFIED set in range->flags: the inclusive range
 *     [range->min.tcp.port, range->max.tcp.port]; an inverted range
 *     (max < min) yields no candidates;
 *   - otherwise, for IP_NAT_MANIP_DST: nothing is tried (returns 0);
 *   - otherwise the current port's class is kept: ports below 512 map to
 *     1-511, 512-1023 map to 600-1023, and the rest map to 1024-65535.
 *
 * Returns 1 with the chosen port stored in `tuple` (network byte order), or
 * 0 if none was found.  On failure the port field may have been overwritten
 * with the last candidate tried.
 */
static int
tcp_unique_tuple(struct ip_conntrack_tuple *tuple,
		 const struct ip_nat_range *range,
		 enum ip_nat_manip_type maniptype,
		 const struct ip_conntrack *conntrack)
{
	/* Rotating start offset, deliberately kept across calls so successive
	   searches do not all begin at the bottom of the range.
	   NOTE(review): shared, unsynchronized state. */
	static unsigned short port = 0;
	/* Port field being chosen.  This is per-call state; it used to be
	   declared static together with `port`, needlessly sharing it. */
	unsigned short *portptr;
	unsigned int range_size, min, max, i;

	if (maniptype == IP_NAT_MANIP_SRC)
		portptr = &tuple->src.u.tcp.port;
	else
		portptr = &tuple->dst.u.tcp.port;

	if (!(range->flags & IP_NAT_RANGE_PROTO_SPECIFIED)) {
		/* No port range given: a destination port is left alone. */
		if (maniptype == IP_NAT_MANIP_DST)
			return 0;

		/* Stay within the class of the current port. */
		if (ntohs(*portptr) < 1024) {
			if (ntohs(*portptr) < 512) {
				min = 1;
				range_size = 511 - min + 1;
			} else {
				min = 600;
				range_size = 1023 - min + 1;
			}
		} else {
			min = 1024;
			range_size = 65535 - 1024 + 1;
		}
	} else {
		min = ntohs(range->min.tcp.port);
		max = ntohs(range->max.tcp.port);
		/* With max < min, "max - min + 1" wraps to about 4 billion in
		   unsigned arithmetic: the loop below would hand out ports outside
		   the requested range or spin for a very long time. */
		if (max < min)
			return 0;
		range_size = max - min + 1;
	}

	/* Probe range_size candidates, starting from the rotating offset. */
	for (i = 0; i < range_size; i++, port++) {
		*portptr = htons(min + port % range_size);
		if (!ip_nat_used_tuple(tuple, conntrack))
			return 1;
	}
	return 0;
}


void
tcp_fw_manip_pkt(struct ip_hdr *iph, unsigned int len,
	      const struct ip_conntrack_manip *manip,
	      enum ip_nat_manip_type maniptype)
{
	struct tcp_hdr *hdr = (struct tcp_hdr *)((unsigned long *)iph + iph->hl);
	unsigned long oldip;
	unsigned short *portptr;
	struct tcp_cksum new1,old;
	
	if (maniptype == IP_NAT_MANIP_SRC) 
	{
		oldip = iph->source_ip;
		portptr = &hdr->th_sport;
	} else 
	{
		oldip = iph->dest_ip;
		portptr = &hdr->th_dport;
	}
	

	//if(((void *)&hdr->th_sum + sizeof(hdr->th_sum) - (void *)iph) <= len) 
	{
		old.ip = oldip;
		old.port = htons(*portptr);

		new1.ip = manip->ip;
		new1.port = htons(manip->u.tcp.port);
		checksumadjust((unsigned char *)&hdr->th_sum,(unsigned char *)&new1, 6,
					(unsigned char *)&old, 6);			 
	}
	
	*portptr = manip->u.tcp.port;
}

/*
 * NAT-rewrite one endpoint of a TCP segment.
 *
 * The source port (IP_NAT_MANIP_SRC) or destination port (otherwise) is set
 * to manip->u.tcp.port.  The TCP checksum is incrementally adjusted for the
 * change of (address, port) from the packet's current values to
 * (manip->ip, manip->u.tcp.port).  The IP address itself is not written
 * here.
 *
 * iph       - IP header; the TCP header follows it at iph->hl 32-bit words.
 * len       - bytes available starting at iph (the same meaning tcp_packet()
 *             gives its len).  The checksum is only patched when th_sum lies
 *             entirely within them; the port is always written.
 * manip     - new address/port; ports are in network byte order.
 * maniptype - which endpoint to rewrite.
 */
static void
tcp_manip_pkt(struct ip_hdr *iph, unsigned int len,
	      const struct ip_conntrack_manip *manip,
	      enum ip_nat_manip_type maniptype)
{
	/* hl counts 32-bit words.  Step in bytes so the offset does not depend
	   on sizeof(unsigned long), which is 8 on LP64 targets. */
	struct tcp_hdr *hdr =
		(struct tcp_hdr *)((unsigned char *)iph + iph->hl * 4);
	unsigned long oldip;
	unsigned short *portptr;
	struct tcp_cksum new1, old;
	unsigned long sum_end;

	if (maniptype == IP_NAT_MANIP_SRC) {
		oldip = iph->source_ip;
		portptr = &hdr->th_sport;
	} else {
		oldip = iph->dest_ip;
		portptr = &hdr->th_dport;
	}

	/* Offset just past th_sum.  The header can be truncated (for example a
	   header quoted inside an ICMP error), and th_sum must then not be
	   written.  Byte pointers are used because arithmetic on void * is
	   not valid C/C++. */
	sum_end = (unsigned long)((unsigned char *)&hdr->th_sum
				  + sizeof(hdr->th_sum)
				  - (unsigned char *)iph);
	if (sum_end <= len) {
		/* Both buffers must hold the fields exactly as they appear in the
		   packet.  The ports are therefore copied raw, like the addresses.
		   Passing them through htons() would byte-reverse only the port
		   half of each buffer on little-endian hosts.
		   NOTE(review): assumes manip->ip is in network byte order like
		   the header addresses — confirm. */
		old.ip = oldip;
		old.port = *portptr;

		new1.ip = manip->ip;
		new1.port = manip->u.tcp.port;

		/* NOTE(review): checksumadjust() is defined elsewhere.  The
		   reference version in RFC 3022 takes the OLD data before the NEW
		   data, whereas this call passes new1 first — confirm the
		   parameter order against the project's definition. */
		checksumadjust((unsigned char *)&hdr->th_sum,
			       (unsigned char *)&new1, 6,
			       (unsigned char *)&old, 6);
	}

	*portptr = manip->u.tcp.port;
}

/*
 * Append a textual form of the tuple's TCP ports to buffer:
 * "srcpt=N " when the mask's source port is non-zero, then "dstpt=N " when
 * the mask's destination port is non-zero.  Ports are printed after
 * ntohs().  Returns the number of characters written (excluding the NUL).
 * No bound is enforced: buffer must be large enough for both fields.
 */
static unsigned int
tcp_print(char *buffer,
	  const struct ip_conntrack_tuple *match,
	  const struct ip_conntrack_tuple *mask)
{
	char *cursor = buffer;

	if (mask->src.u.tcp.port != 0) {
		cursor += sprintf(cursor, "srcpt=%u ",
				  (unsigned int)ntohs(match->src.u.tcp.port));
	}
	if (mask->dst.u.tcp.port != 0) {
		cursor += sprintf(cursor, "dstpt=%u ",
				  (unsigned int)ntohs(match->dst.u.tcp.port));
	}

	return (unsigned int)(cursor - buffer);
}

/*
 * Describe a NAT port range in buffer.  Nothing is written (returns 0) when
 * the range is the full one (min 0, max 0xFFFF).  Otherwise a single port
 * prints as "port N " and a span as "ports LO-HI ", with values after
 * ntohs().  Returns the number of characters written.
 */
static unsigned int
tcp_print_range(char *buffer, const struct ip_nat_range *range)
{
	const int is_full_range = range->min.tcp.port == 0
				  && range->max.tcp.port == 0xFFFF;

	if (is_full_range)
		return 0;

	if (range->min.tcp.port != range->max.tcp.port) {
		return sprintf(buffer, "ports %u-%u ",
			       (unsigned int)ntohs(range->min.tcp.port),
			       (unsigned int)ntohs(range->max.tcp.port));
	}

	return sprintf(buffer, "port %u ",
		       (unsigned int)ntohs(range->min.tcp.port));
}

/*
 * Map a TCP header's flags to the row of tcp_conntracks to use.
 * Precedence is RST over SYN over FIN over ACK:
 *   RST -> 3, SYN -> 0, FIN -> 1, ACK -> 2, none of them -> 4.
 */
static unsigned int get_conntrack_index(const struct tcp_hdr *tcph)
{
	const unsigned int flags = tcph->th_flags;

	if (flags & TH_RST)
		return 3;
	if (flags & TH_SYN)
		return 0;
	if (flags & TH_FIN)
		return 1;
	if (flags & TH_ACK)
		return 2;
	return 4;
}

/*
 * Advance a connection's TCP conntrack state for one packet.
 *
 * The new state is tcp_conntracks[maniptype][flag row][current state],
 * where the flag row comes from get_conntrack_index().  After storing it:
 *   - a SYN+ACK seen in the REPLY direction while in SYN_SENT records
 *     handshake_ack = its sequence number + 1 (network byte order);
 *   - if no reply has been seen yet (IPS_SEEN_REPLY clear) and the packet
 *     carries RST, the conntrack's timeout handler is called directly,
 *     provided del_timer() reports the timer was still pending;
 *   - otherwise IPS_ASSURED is set when, coming from SYN_RECV, an
 *     ORIGINAL-direction ACK (without SYN) acknowledges handshake_ack, and
 *     the timeout is refreshed from tcp_timeouts[new state].
 *
 * iph - IP header; the TCP header follows it at iph->hl 32-bit words.
 * len - bytes available starting at iph.
 * maniptype - used as the direction (0 ORIGINAL, 1 REPLY) index and
 *             compared against IP_CT_DIR_ORIGINAL / IP_CT_DIR_REPLY.
 *             NOTE(review): this relies on the two enums sharing the
 *             values 0/1 — confirm.
 *
 * Returns 1 when the packet was processed.  Returns -1 if the header is
 * truncated, has a data offset below the 5-word minimum, or leads to an
 * invalid transition.
 */
static int tcp_packet(struct ip_conntrack *conntrack,
		      struct ip_hdr *iph, unsigned int len,
		      enum ip_nat_manip_type maniptype)
{
	/* Length of a TCP header without options, in bytes. */
	const unsigned int tcp_base_len = 20;
	/* hl counts 32-bit words.  Step in bytes so the offset does not depend
	   on sizeof(unsigned long), which is 8 on LP64 targets. */
	const unsigned int ip_hdr_len = iph->hl * 4;
	struct tcp_hdr *tcph =
		(struct tcp_hdr *)((unsigned char *)iph + ip_hdr_len);
	enum tcp_conntrack newconntrack, oldtcpstate;
	unsigned int flag_row;

	/* th_off and th_flags live in the fixed part of the header, so it must
	   be present before either is read. */
	if (len < ip_hdr_len + tcp_base_len) {
		DEBUGP("ip_conntrack_tcp: Truncated packet.\n");
		return -1;
	}

	/* A data offset under 5 words cannot even cover the fixed header. */
	if (tcph->th_off < 5) {
		DEBUGP("ip_conntrack_tcp: Bad data offset.\n");
		return -1;
	}

	/* The options announced by th_off must be present as well. */
	if (len < (iph->hl + tcph->th_off) * 4) {
		DEBUGP("ip_conntrack_tcp: Truncated packet.\n");
		return -1;
	}

	flag_row = get_conntrack_index(tcph);
	oldtcpstate = conntrack->proto.tcp.state;
	newconntrack = tcp_conntracks[maniptype][flag_row][oldtcpstate];

	if (newconntrack == TCP_CONNTRACK_MAX) {
		DEBUGP("ip_conntrack_tcp: Invalid dir=%i index=%u conntrack=%u\n",
		       maniptype, flag_row,
		       conntrack->proto.tcp.state);
		return -1;
	}

	conntrack->proto.tcp.state = newconntrack;

	/* SYN+ACK answering the original SYN: remember the acknowledgment
	   number the originator's final handshake ACK should carry. */
	if (oldtcpstate == TCP_CONNTRACK_SYN_SENT
	    && maniptype == IP_CT_DIR_REPLY
	    && (tcph->th_flags & TH_SYN) && (tcph->th_flags & TH_ACK))
		conntrack->proto.tcp.handshake_ack
			= htonl(ntohl(tcph->th_seq) + 1);

	if (!(conntrack->status & IPS_SEEN_REPLY) && (tcph->th_flags & TH_RST)) {
		/* RST before any reply: run the timeout handler now rather than
		   waiting for the timer. */
		if (del_timer(&conntrack->timeout))
			conntrack->timeout.function((unsigned long)conntrack);
	} else {
		/* An ORIGINAL-direction ACK matching the recorded handshake_ack
		   after SYN_RECV completes the handshake. */
		if (oldtcpstate == TCP_CONNTRACK_SYN_RECV
		    && maniptype == IP_CT_DIR_ORIGINAL
		    && (tcph->th_flags & TH_ACK) && !(tcph->th_flags & TH_SYN)
		    && tcph->th_ack == conntrack->proto.tcp.handshake_ack)
			conntrack->status |= IPS_ASSURED;

		ip_ct_refresh(conntrack, tcp_timeouts[newconntrack]);
	}

	return 1;
}


/*
 * TCP protocol handler exported to the NAT/conntrack core.  The initializer
 * is positional, so its order must match the field order of
 * struct ip_nat_protocol, which is declared outside this file.
 * NOTE(review): the leading { 0, 0 } is presumably an empty list head
 * (cf. the "list.h" include) — confirm against the struct definition.
 * The callbacks are the static functions defined above.  tcp_manip_pkt is
 * registered here, not the externally visible tcp_fw_manip_pkt.
 */
struct ip_nat_protocol ip_nat_protocol_tcp
= { { 0, 0 }, IPPROTO_TCP,
    tcp_pkt_to_tuple,
    tcp_invert_tuple,
    tcp_manip_pkt,
    tcp_in_range,
    tcp_unique_tuple,
    tcp_print,
    tcp_print_range,
    tcp_packet
};