// Source: www.pudn.com > NETINFO.rar > connect.cpp


#include "connect.h" 
 
#include "tcpip.h" 
#include "acl.h" 
#include  
 
// Idle period, in seconds (compared against time() values), after which
// CConnection::IsTimeOut reports a connection as timed out.
#define GENERAL_CONNECTION_TIMEOUT_SECONDS 600 
 
// Convert the 16-bit fields of an IP header from network to host byte order,
// in place: total length, identification, the raw word at byte offset 6
// (flags + fragment offset; accessed as a plain WORD, presumably because the
// IP struct declares them as bit-fields) and the header checksum.
// Addresses and single-byte fields are left untouched.
void IpNtoh(IP *ip)
{
	WORD *flagsAndOffset = (WORD*)( (BYTE*)ip + 6 );

	ip->len          = ntohs( ip->len );
	ip->id           = ntohs( ip->id );
	*flagsAndOffset  = ntohs( *flagsAndOffset );
	ip->checksum     = ntohs( ip->checksum );
}
 
// Convert a TCP header from network to host byte order, in place.
// Ports, window, checksum and urgent pointer are 16-bit; sequence and
// acknowledgement numbers are 32-bit. The word at byte offset 12 holds the
// header-length/flag bits and is swapped through a raw WORD view.
void TcpNtoh(TCP *tcp)
{
	WORD *lenAndFlags = (WORD*)( (BYTE*)tcp + 12 );

	tcp->srcPort   = ntohs( tcp->srcPort );
	tcp->dstPort   = ntohs( tcp->dstPort );
	tcp->seqNum    = ntohl( tcp->seqNum );
	tcp->ackNum    = ntohl( tcp->ackNum );
	*lenAndFlags   = ntohs( *lenAndFlags );
	tcp->window    = ntohs( tcp->window );
	tcp->checksum  = ntohs( tcp->checksum );
	tcp->urgPoint  = ntohs( tcp->urgPoint );
}
 
// Convert all four 16-bit UDP header fields from network to host byte
// order, in place.
void UdpNtoh(UDP* udp)
{
	UDP &hdr = *udp;

	hdr.srcPort  = ntohs( hdr.srcPort );
	hdr.dstPort  = ntohs( hdr.dstPort );
	hdr.length   = ntohs( hdr.length );
	hdr.checksum = ntohs( hdr.checksum );
}
 
// Convert an IP header's 16-bit fields from host to network byte order, in
// place (the inverse of IpNtoh): total length, identification, the raw word
// at byte offset 6 (flags + fragment offset) and the header checksum.
void IpHton(IP *ip)
{
	IP &hdr = *ip;
	WORD &flagsAndOffset = *(WORD*)( (BYTE*)ip + 6 );

	hdr.len        = htons( hdr.len );
	hdr.id         = htons( hdr.id );
	flagsAndOffset = htons( flagsAndOffset );
	hdr.checksum   = htons( hdr.checksum );
}
 
// Convert a TCP header from host to network byte order, in place (the
// inverse of TcpNtoh). The word at byte offset 12 (header length + flags)
// is swapped through a raw WORD reference.
void TcpHton(TCP *tcp)
{
	TCP &hdr = *tcp;
	WORD &lenAndFlags = *(WORD*)( (BYTE*)tcp + 12 );

	hdr.srcPort  = htons( hdr.srcPort );
	hdr.dstPort  = htons( hdr.dstPort );
	hdr.seqNum   = htonl( hdr.seqNum );
	hdr.ackNum   = htonl( hdr.ackNum );
	lenAndFlags  = htons( lenAndFlags );
	hdr.window   = htons( hdr.window );
	hdr.checksum = htons( hdr.checksum );
	hdr.urgPoint = htons( hdr.urgPoint );
}
 
// Convert all four 16-bit UDP header fields from host to network byte
// order, in place (the inverse of UdpNtoh).
void UdpHton(UDP* udp)
{
	UDP &hdr = *udp;

	hdr.srcPort  = htons( hdr.srcPort );
	hdr.dstPort  = htons( hdr.dstPort );
	hdr.length   = htons( hdr.length );
	hdr.checksum = htons( hdr.checksum );
}
 
// Fill the TCP pseudo-header that GetTcpCheckSum prepends to the checksum
// input: source/destination addresses copied from the IP header, a zero
// byte, the TCP protocol number, and tcpLen (given in host order, stored in
// network order).
void GetPseudoTcpHeader(PSEUDO_TCP_HEADER* pseudo, const IP* ip, WORD tcpLen)
{
	PSEUDO_TCP_HEADER &ph = *pseudo;

	ph.tcpLen   = htons( tcpLen );
	ph.protocol = IP_PROTOCOL_TCP;
	ph.zero     = 0;
	ph.dst      = ip->dst;
	ph.src      = ip->src;
}
 
// Fill the UDP pseudo-header that GetUdpCheckSum prepends to the checksum
// input: addresses from the IP header, a zero byte, the UDP protocol number,
// and udpLen (given in host order, stored in network order).
void GetPseudoUdpHeader(PSEUDO_UDP_HEADER* pseudo, const IP* ip, WORD udpLen)
{
	PSEUDO_UDP_HEADER &ph = *pseudo;

	ph.udpLen   = htons( udpLen );
	ph.protocol = IP_PROTOCOL_UDP;
	ph.zero     = 0;
	ph.dst      = ip->dst;
	ph.src      = ip->src;
}
 
// Add up `size` bytes at buffer as 16-bit words (read in host order) into a
// 32-bit accumulator, without folding carries; GetCheckSumFromSum finishes
// the job. A trailing odd byte is added on its own, as the low-order part.
// size is a byte count. NOTE(review): a negative size still reads one byte,
// because only a non-zero remainder is tested — callers must pass size >= 0.
DWORD GetSum(const WORD *buffer, int size)
{
	DWORD total = 0;
	const WORD *cursor = buffer;
	int remaining = size;

	for( ; remaining > 1; remaining -= sizeof(WORD) )
		total += *cursor++;

	if( remaining != 0 )
		total += *(const UCHAR*)cursor;

	return total;
}
 
// Turn a 32-bit running sum (from GetSum) into the final 16-bit Internet
// checksum: fold the high half into the low half twice to absorb all
// carries, then take the one's complement.
WORD GetCheckSumFromSum(DWORD cksum)
{
	DWORD folded = (cksum & 0xffff) + (cksum >> 16);	// at most 0x1fffe
	folded += folded >> 16;								// absorb the last carry
	return (WORD)~folded;
}
 
// Internet checksum of `size` bytes starting at buffer.
WORD GetCheckSum(const WORD *buffer, int size)
{
	const DWORD rawSum = GetSum(buffer, size);
	return GetCheckSumFromSum(rawSum);
}
 
// TCP checksum over the pseudo-header (built from ip and tcpLen) followed by
// tcpLen bytes of TCP header + payload at tcp. tcpLen is in host order.
// NOTE(review): presumably the checksum field inside *tcp must be zeroed
// before generating a checksum; over a packet with a valid checksum in
// place the result is 0 — confirm how callers use it.
WORD GetTcpCheckSum(const TCP* tcp, const IP* ip, WORD tcpLen)
{
	PSEUDO_TCP_HEADER pseudo;
	GetPseudoTcpHeader(&pseudo, ip, tcpLen);

	DWORD sum = GetSum((const WORD*)tcp, tcpLen);
	sum += GetSum((const WORD*)&pseudo, sizeof(pseudo));
	return GetCheckSumFromSum(sum);
}
 
// UDP checksum over the pseudo-header (built from ip and udpLen) followed by
// udpLen bytes of UDP header + payload at udp. udpLen is in host order.
// NOTE(review): a computed checksum of 0 is returned as 0; when generating a
// checksum for transmission, UDP expects 0 to be sent as 0xFFFF — confirm
// callers handle that if they build packets.
WORD GetUdpCheckSum(const UDP* udp, const IP* ip, WORD udpLen)
{
	PSEUDO_UDP_HEADER pseudo;
	GetPseudoUdpHeader(&pseudo, ip, udpLen);

	DWORD sum = GetSum((const WORD*)udp, udpLen);
	sum += GetSum((const WORD*)&pseudo, sizeof(pseudo));
	return GetCheckSumFromSum(sum);
}
 
// NOTE(review): this region is damaged and will not compile as-is.
// When the file was extracted from a web page, all text between a '<' and the
// next '>' appears to have been stripped. As a result:
//   - the for-loop of MyGetIpCheckSum is cut off right after "i" (presumably
//     "i<number; ..."), and the rest of MyGetIpCheckSum is lost;
//   - the signature and first statement of the next function are lost.
// From its call in IsMyPacket, that next function is presumably
//   BOOL TcpIp2TcpAddr(const IP* ip, const TCP* tcp, TCP_ADDR tcpAddr[2], DWORD* pdwSrc)
// and its first surviving statement was presumably "tcpAddr[0].ip = ip->src;".
// Restore this block from an intact copy of the source.
WORD MyGetIpCheckSum(const WORD*pBuffer, DWORD number) 
{ 
	WORD result = 0; 
	for(DWORD i=0; isrc; 
	tcpAddr[0].port = tcp->srcPort; 
	tcpAddr[1].ip = ip->dst; 
	tcpAddr[1].port = tcp->dstPort; 
	// Surviving tail of TcpIp2TcpAddr: slot 0 holds the packet's source endpoint
	// and slot 1 its destination. The pair is then put into canonical order by
	// swapping when tcpAddr[0] > tcpAddr[1] bytewise. *pdwSrc (when non-NULL)
	// reports 0 if the source stayed in slot 0, or 1 if it was swapped into
	// slot 1. Always returns TRUE.
	if( pdwSrc ) *pdwSrc = 0; 
	if( memcmp(tcpAddr+0, tcpAddr+1, sizeof(tcpAddr[0]))>0 ) 
	{ 
		TCP_ADDR temp; 
		temp = tcpAddr[0]; 
		tcpAddr[0] = tcpAddr[1]; 
		tcpAddr[1] = temp; 
		if( pdwSrc ) *pdwSrc = 1; 
	} 
	return TRUE; 
} 
 
// Build a temporary file name for a connection's dump.
// The name is made of:
//   - the local time, down to the millisecond (YYYYMMDD_hhmmss_mmmm);
//   - a rand() value;
//   - the four address bytes and the port of the connecter;
//   - the four address bytes and the port of the listener.
// All fields are small integers (WORD/BYTE-sized values and rand()), so the
// result stays well below MAX_PATH.
//
// connecter, listener : endpoints to encode in the name.
// fileName            : output buffer of at least MAX_PATH chars.
void GetConnectionTempFileName(const TCP_ADDR* connecter, const TCP_ADDR* listener, 
							   char fileName[MAX_PATH])
{
	SYSTEMTIME lt;
	GetLocalTime(&lt);	// GetLocalTime fills the struct through a pointer
	sprintf(fileName, "%04d%02d%02d_%02d%02d%02d_%04d_%d__%d_%d_%d_%d_%d__%d_%d_%d_%d_%d.txt", 
		lt.wYear, lt.wMonth, lt.wDay, lt.wHour, lt.wMinute, lt.wSecond, lt.wMilliseconds, rand(),
		connecter->ip.b1, connecter->ip.b2, connecter->ip.b3, connecter->ip.b4, connecter->port,
		listener->ip.b1, listener->ip.b2, listener->ip.b3, listener->ip.b4, listener->port);
}
 
 
// Number of CConnection objects currently alive: the constructor increments
// it and the destructor decrements it.
DWORD CConnection::m_dwConnectCount(0);
 
// Create the tracking state of a TCP connection from its (SYN,ACK) packet.
//
// ether : Ethernet header of that frame; its first 2*ETHERNET_ADDR_LEN bytes
//         are copied into m_ether.
// ip    : IP header of the SYN,ACK. It must not be a fragment, and the TCP
//         header must carry both SYN and ACK (both asserted).
// pNi, pFuncs, dwAttachData : stored for later callbacks.
//
// The SYN,ACK's source is recorded as the listener and its destination as
// the connecter. Each m_data[i].seq starts as the next sequence number of
// the peer opposite m_addr[i].
// If m_addr[0] > m_addr[1] bytewise, both slots are swapped, together with
// m_data, m_ether, m_listener and m_connecter. This gives every connection
// the same canonical orientation that TcpIp2TcpAddr produces, which
// IsMyPacket relies on.
// NOTE(review): if ETHERNET is laid out destination-then-source as on the
// wire, m_ether[i] holds the MAC of the peer opposite m_addr[i] — confirm
// against MakeRstTcpPacketV2's parameter order.
CConnection::CConnection(const ETHERNET *ether, const IP*ip, NetInfo* pNi, const NETINFO_CALLBACKS *pFuncs, DWORD dwAttachData)
:m_pNi(pNi), m_pFuncs(pFuncs), m_dwAttachData(dwAttachData)
{
	++CConnection::m_dwConnectCount;

	assert( ip->mf == 0 && ip->offset == 0 );
	const TCP *synAck = (const TCP*)( (const BYTE*)ip + 4 * ip->hl );
	assert( synAck->ack && synAck->syn );

	m_lastTime = time(NULL);
	memset(m_window, 0, sizeof(m_window));
	memcpy(m_ether, ether, ETHERNET_ADDR_LEN*2);

	// slot 0: sender of the SYN,ACK (listener); slot 1: its receiver (connecter)
	m_listener  = 0;
	m_connecter = 1;

	m_addr[0].ip   = ip->src;
	m_addr[0].port = synAck->srcPort;
	m_addr[1].ip   = ip->dst;
	m_addr[1].port = synAck->dstPort;

	m_data[0].seq    = synAck->ackNum;
	m_data[0].closed = FALSE;
	m_data[1].seq    = synAck->seqNum + 1;
	m_data[1].closed = FALSE;

	// Already in canonical order: the bytewise-smaller endpoint is in slot 0.
	if( memcmp( &m_addr[0], &m_addr[1], sizeof(TCP_ADDR) ) <= 0 )
		return;

	TCP_ADDR addrTemp = m_addr[0];
	m_addr[0] = m_addr[1];
	m_addr[1] = addrTemp;

	SOCKET_DATA dataTemp = m_data[0];
	m_data[0] = m_data[1];
	m_data[1] = dataTemp;

	BYTE etherTemp[ETHERNET_ADDR_LEN];
	memcpy(etherTemp, m_ether[0], ETHERNET_ADDR_LEN);
	memcpy(m_ether[0], m_ether[1], ETHERNET_ADDR_LEN);
	memcpy(m_ether[1], etherTemp, ETHERNET_ADDR_LEN);

	m_listener  = 1;
	m_connecter = 0;
}
 
// Decrement the live-connection counter. If the client registered an
// OnCloseConnect callback, call it with this connection.
CConnection::~CConnection()
{
	--CConnection::m_dwConnectCount;

	if( !m_pFuncs || !m_pFuncs->OnCloseConnect )
		return;

	m_pFuncs->OnCloseConnect(m_pNi, this);
}
 
// Report whether the connection should be dropped as idle.
// Returns TRUE when the current time is more than
// GENERAL_CONNECTION_TIMEOUT_SECONDS past m_lastTime. It also returns TRUE
// when the clock reads earlier than m_lastTime (e.g. the system time was set
// back). Returns FALSE otherwise.
BOOL CConnection::IsTimeOut()
{
	const time_t now = time(NULL);

	if( now < m_lastTime || now > m_lastTime + GENERAL_CONNECTION_TIMEOUT_SECONDS )
		return TRUE;

	return FALSE;
}
 
// Feed one TCP/IP packet of this connection into the reassembly windows.
//
// Header fields are used directly in arithmetic, so the packet is presumably
// already in host byte order (IpNtoh/TcpNtoh) — TODO confirm with the caller.
//
// Payload bytes are stored in m_window[receiver]:
//   - each byte goes at its offset from m_data[receiver].seq, the next
//     expected sequence number for that direction;
//   - the contiguous run of bytes starting at offset 0 is then delivered
//     through OnData and shifted out of the window.
//
// Returns FALSE when this connection should be closed: OnData asked to stop,
// or CanKeepContinue reports both sides closed. Returns TRUE when the
// connection should be kept.
BOOL CConnection::OnTcpIpPacket(const IP *ip)
{
	m_lastTime = time(NULL);
	const TCP *tcp = (const TCP*)( (const BYTE*)ip + 4 * ip->hl );
	DWORD headLen = ip->hl * 4 + tcp->hl * 4;
	// A malformed packet may claim a total length shorter than its own
	// headers. Without this guard the unsigned subtraction wraps, and the
	// resulting huge length drives the memcpy/memset below out of bounds.
	DWORD dataLen = ( ip->len > headLen ) ? ( ip->len - headLen ) : 0;
	const BYTE *pData = (const BYTE*)ip + headLen;
	assert( ip->mf == 0 && ip->offset == 0 ); // IP fragments are not supported

	// Direction: slot 0 sent this packet iff the source ip/port match m_addr[0].
	int sender, receiver;
	if( memcmp(&ip->src, &m_addr[0].ip, sizeof(IP_ADDR) ) == 0
		&& tcp->srcPort == m_addr[0].port )
	{
		sender = 0;
		receiver = 1;
	}
	else
	{
		sender = 1;
		receiver = 0;
	}

	if( dataLen )
	{
		// Offset one past the last payload byte, relative to the next expected
		// sequence number. Modulo-2^32 arithmetic handles sequence wrap-around.
		DWORD endOffset = tcp->seqNum + dataLen - m_data[receiver].seq;
		if( endOffset == 0 )
			return CanKeepContinue(ip, tcp, sender); // all of it was already delivered
		if( endOffset >= MAX_TCP_WINDOW_SIZE )
			return TRUE; // outside the window (presumably old or far-ahead data): ignore it

		DWORD dstBegin, srcBegin, copyLen;
		if( tcp->seqNum - m_data[receiver].seq < MAX_TCP_WINDOW_SIZE )
		{
			// packet starts inside the window: copy all of it
			srcBegin = 0;
			copyLen = dataLen;
			dstBegin = tcp->seqNum - m_data[receiver].seq;
		}
		else
		{
			// packet starts before the expected seq: skip the already-delivered prefix
			srcBegin = m_data[receiver].seq - tcp->seqNum;
			copyLen = dataLen - srcBegin;
			dstBegin = 0;
		}
		memcpy( m_window[receiver].buffer + dstBegin,  pData + srcBegin,  copyLen );
		memset( m_window[receiver].bitmap + dstBegin,  1, copyLen );
		if( dstBegin + copyLen > m_window[receiver].lastByte )
			m_window[receiver].lastByte = dstBegin + copyLen;

		// Length of the contiguous run of received bytes at the window start.
		// Declared outside the loop because its final value is used below.
		DWORD ready = 0;
		while( ready <= m_window[receiver].lastByte && m_window[receiver].bitmap[ready] != 0 )
			ready++;
		if( ready == 0 )
			return TRUE; // hole at the front: wait for the missing bytes
		if( !OnData( sender, receiver, m_window[receiver].buffer, ready ) )
			return FALSE;

		// Slide the undelivered remainder to the front and clear the vacated tail.
		DWORD lastByte = m_window[receiver].lastByte;
		memmove( m_window[receiver].buffer, m_window[receiver].buffer + ready, lastByte + 1 - ready );
		memmove( m_window[receiver].bitmap, m_window[receiver].bitmap + ready, lastByte + 1 - ready );
		memset( m_window[receiver].bitmap + lastByte - ready, 0, ready + 1 );
		m_data[receiver].seq += ready;
		m_window[receiver].lastByte = lastByte - ready;
	}

	return CanKeepContinue(ip, tcp, sender);
}
 
// Deliver a contiguous run of reassembled payload to the client's OnTcpData
// callback, if one is registered. The callback's direction argument is true
// when the data was sent by the connecter. `receiver` is not used here.
// Always returns TRUE (keep the connection).
BOOL CConnection::OnData(int sender, int receiver, const BYTE* pData, DWORD length)
{
	if( !m_pFuncs || !m_pFuncs->OnTcpData )
		return TRUE;

	m_pFuncs->OnTcpData(m_pNi, this, sender==m_connecter, pData, length);
	return TRUE;
}
 
// Update the per-side closed flags from this packet's FIN/RST bits.
// A RST marks both sides closed; a FIN marks only the sender's side.
// Returns FALSE once both sides are closed (drop the connection), else TRUE.
// `ip` is not used.
BOOL CConnection::CanKeepContinue(const IP*ip, const TCP*tcp, int sender)
{
	if( tcp->rst )
	{
		m_data[0].closed = TRUE;
		m_data[1].closed = TRUE;
	}
	if( tcp->fin )
		m_data[sender].closed = TRUE;

	return ( m_data[0].closed && m_data[1].closed ) ? FALSE : TRUE;
}
 
// Return TRUE if the packet belongs to this connection. Its endpoint pair,
// put into canonical order by TcpIp2TcpAddr, must equal m_addr byte for byte.
BOOL CConnection::IsMyPacket(const IP *ip)
{
	const TCP *tcpHeader = (const TCP*)( (const BYTE*)ip + 4 * ip->hl );
	TCP_ADDR endpoints[2];
	TcpIp2TcpAddr(ip, tcpHeader, endpoints, NULL);
	return memcmp(endpoints, m_addr, sizeof(endpoints)) == 0 ? TRUE : FALSE;
}
 
// Forcibly tear down this connection and block it for a while.
//
// 1. Unless the capture source is a file (m_pNi->m_bInputIsFile), two RST
//    segments are injected through m_pNi->m_pInputPcap: one from the
//    connecter to the listener, and one in the opposite direction. Each is
//    stamped with the sequence/acknowledgement numbers tracked in m_data
//    (presumably so the peers accept it). The results of
//    MakeRstTcpPacketV2/pcap_sendpacket only gate or are ignored: sending is
//    best effort.
// 2. A temporary ACL_DENY entry for exactly this connecter -> listener
//    address/port pair is added. It is valid for CONNECT_BREAKDOWN_DURATION
//    seconds from now.
void CConnection::BreakDownConnection()
{
	if( !m_pNi->m_bInputIsFile )
	{
		BYTE rstBuffer[MAX_TCP_PACKET];
		DWORD length;

		// RST from connecter to listener
		if( MakeRstTcpPacketV2(m_ether[m_connecter], m_ether[m_listener],
							&m_addr[m_connecter].ip, &m_addr[m_listener].ip,
							m_addr[m_connecter].port, m_addr[m_listener].port,
							m_data[m_listener].seq, m_data[m_connecter].seq,
							rstBuffer, sizeof(rstBuffer), &length) )
			pcap_sendpacket(m_pNi->m_pInputPcap, rstBuffer, length);

		// RST from listener to connecter
		if( MakeRstTcpPacketV2(m_ether[m_listener], m_ether[m_connecter],
							&m_addr[m_listener].ip, &m_addr[m_connecter].ip,
							m_addr[m_listener].port, m_addr[m_connecter].port,
							m_data[m_connecter].seq, m_data[m_listener].seq,
							rstBuffer, sizeof(rstBuffer), &length) )
			pcap_sendpacket(m_pNi->m_pInputPcap, rstBuffer, length);
	}

	// Add this connection to the deny list.
	// NOTE(review): only the fields below are set; if ACL_ITEM has other
	// members they are left uninitialized — confirm AddTempAclItem ignores them.
	ACL_ITEM item;
	item.permit = ACL_DENY;
	item.validUntil = time(NULL) + CONNECT_BREAKDOWN_DURATION;
	item.src.addrLow = item.src.addrHigh = m_addr[m_connecter].ip;
	item.src.portLow = item.src.portHigh = m_addr[m_connecter].port;
	item.dst.addrLow = item.dst.addrHigh = m_addr[m_listener].ip;
	item.dst.portLow = item.dst.portHigh = m_addr[m_listener].port;
	AddTempAclItem(m_pNi, &item);
}
 
 
// Return the client-defined value attached to the connection.
DWORD GetConnectAttachData(HTCPCONNECT hConnect)
{
	const DWORD dwData = hConnect->m_dwAttachData;
	return dwData;
}
 
// Replace the client-defined value attached to the connection.
void SetConnectAttachData(HTCPCONNECT hConnect, DWORD dwAttachData)
{
	DWORD &slot = hConnect->m_dwAttachData;
	slot = dwAttachData;
}
 
// C-style accessor: forward to CConnection::GetConnectType.
enum CONNECT_TYPE GetConnectTypeEx(HTCPCONNECT hConnect)
{
	const enum CONNECT_TYPE type = hConnect->GetConnectType();
	return type;
}
 
// Copy the connection's endpoints out: the connecter's address/port into
// *connecter and the listener's into *listener. Both outputs must be valid.
void GetConnectAddress(HTCPCONNECT hConnect, TCP_ADDR* connecter, TCP_ADDR* listener)
{
	const int iConnecter = hConnect->m_connecter;
	const int iListener  = hConnect->m_listener;

	*connecter = hConnect->m_addr[iConnecter];
	*listener  = hConnect->m_addr[iListener];
}
 
// C-style entry point: forcibly tear down the connection by forwarding to
// CConnection::BreakDownConnection. That method injects RSTs (unless the
// input is a file) and adds a temporary deny ACL entry.
void BreakDownConnect(HTCPCONNECT hConnect)
{
	hConnect->BreakDownConnection();
}