/* Retrieved from: www.pudn.com > tdi_fw.zip > filter.c */
// -*- mode: C++; tab-width: 4; indent-tabs-mode: nil -*- (for GNU Emacs) // // $Id: filter.c,v 1.5 2002/12/03 13:32:17 dev Exp $ /* * Filtering related routines */ #include#include #include "sock.h" #include "filter.h" #include "memtrack.h" #include "pid_pname.h" #include "tdi_fw.h" // size of cyclic queue for logging #define REQUEST_QUEUE_SIZE 1024 /* rules chains (main (first entry) and process-related) */ static struct { struct { struct flt_rule *head; struct flt_rule *tail; char *pname; // name of process } chain[MAX_CHAINS_COUNT]; KSPIN_LOCK guard; } g_rules; /* logging request queue */ static struct { struct flt_request *data; KSPIN_LOCK guard; ULONG head; /* write to head */ ULONG tail; /* read from tail */ HANDLE event_handle; PKEVENT event; } g_queue; // init NTSTATUS filter_init(void) { NTSTATUS status; int i; pid_pname_init(); /* rules chain */ KeInitializeSpinLock(&g_rules.guard); for (i = 0; i < MAX_CHAINS_COUNT; i++) { g_rules.chain[i].head = g_rules.chain[i].tail = NULL; g_rules.chain[i].pname = NULL; } /* request queue */ KeInitializeSpinLock(&g_queue.guard); g_queue.data = (struct flt_request *)malloc_np(sizeof(struct flt_request) * REQUEST_QUEUE_SIZE); if (g_queue.data == NULL) { KdPrint(("[tdi_fw] filter_init: malloc_np!\n")); return STATUS_INSUFFICIENT_RESOURCES; } memset(g_queue.data, 0, sizeof(struct flt_request) * REQUEST_QUEUE_SIZE); g_queue.head = g_queue.tail = 0; return STATUS_SUCCESS; } // init for user part starting NTSTATUS filter_init_2(void) { NTSTATUS status; if (g_queue.event_handle == NULL) { UNICODE_STRING str; OBJECT_ATTRIBUTES oa; RtlInitUnicodeString(&str, L"\\BaseNamedObjects\\tdi_fw_request"); InitializeObjectAttributes(&oa, &str, 0, NULL, NULL); status = ZwCreateEvent(&g_queue.event_handle, EVENT_ALL_ACCESS, &oa, SynchronizationEvent, FALSE); if (status != STATUS_SUCCESS) { KdPrint(("[tdi_fw] filter_init_2: ZwCreateEvent: 0x%x\n", status)); return status; } } if (g_queue.event == NULL) { status = 
ObReferenceObjectByHandle(g_queue.event_handle, EVENT_ALL_ACCESS, NULL, KernelMode, &g_queue.event, NULL); if (status != STATUS_SUCCESS) { KdPrint(("[tdi_fw] filter_init_2: ObReferenceObjectByHandle: 0x%x\n", status)); return status; } } return STATUS_SUCCESS; } // cleanup for user part void filter_free_2(void) { if (&g_queue.event != NULL) { ObDereferenceObject(&g_queue.event); g_queue.event = NULL; } if (g_queue.event_handle != NULL) { ZwClose(g_queue.event_handle); g_queue.event_handle = NULL; } } // free void filter_free(void) { KIRQL irql; struct plist_entry *ple; int i; // clear all chains for (i = 0; i < MAX_CHAINS_COUNT; i++) clear_flt_chain(i); /* clear request queue */ KeAcquireSpinLock(&g_queue.guard, &irql); for (i = 0; i < REQUEST_QUEUE_SIZE; i++) if (g_queue.data[i].pname != NULL) free(g_queue.data[i].pname); free(g_queue.data); KeReleaseSpinLock(&g_queue.guard, irql); pid_pname_free(); } // quick filter int quick_filter(struct flt_request *request, struct flt_rule *rule) { const struct sockaddr_in *from, *to; struct flt_rule *r; struct plist_entry *ple; KIRQL irql; int result; // not IP if (request->addr.len != sizeof(struct sockaddr_in) || request->addr.from.sa_family != AF_INET || request->addr.to.sa_family != AF_INET) { KdPrint(("[tdi_fw] quick_filter: not ip addr!\n")); return FILTER_DENY; } from = (const struct sockaddr_in *)&request->addr.from; to = (const struct sockaddr_in *)&request->addr.to; // default behavior result = FILTER_ALLOW; if (rule != NULL) { memset(rule, 0, sizeof(*rule)); rule->result = result; } // quick filter KeAcquireSpinLock(&g_rules.guard, &irql); // go through rules for (r = g_rules.chain[pid_pname_get_context(request->pid)].head; r != NULL; r = r->next) // Can anybody understand it? 
if (r->proto == request->proto && r->direction == request->direction && (r->addr_from & r->mask_from) == (from->sin_addr.s_addr & r->mask_from) && (r->addr_to & r->mask_to) == (to->sin_addr.s_addr & r->mask_to) && (r->port_from == 0 || ((r->port2_from == 0) ? (r->port_from == from->sin_port) : (ntohs(from->sin_port) >= ntohs(r->port_from) && ntohs(from->sin_port) <= ntohs(r->port2_from)))) && (r->port_to == 0 || ((r->port2_to == 0) ? (r->port_to == to->sin_port) : (ntohs(to->sin_port) >= ntohs(r->port_to) && ntohs(to->sin_port) <= ntohs(r->port2_to))))) { result = r->result; KdPrint(("[tdi_fw] quick_filter: found rule with result: %d\n", result)); if (rule != NULL) { memcpy(rule, r, sizeof(*rule)); rule->next = NULL; // useless field } break; } KeReleaseSpinLock(&g_rules.guard, irql); request->result = result; return result; } // write request to request queue BOOLEAN log_request(struct flt_request *request) { KIRQL irql, irql2; ULONG next_head; char pname_buf[256], *pname; struct plist_entry *ple; if (g_got_control == 0) // don't log - no control app return FALSE; KeAcquireSpinLock(&g_queue.guard, &irql); next_head = (g_queue.head + 1) % REQUEST_QUEUE_SIZE; if (next_head == g_queue.tail) { // queue overflow: reject one entry from tail KdPrint(("[tdi_fw] log_request: queue overflow!\n")); request->log_skipped = g_queue.data[g_queue.tail].log_skipped + 1; g_queue.tail = (g_queue.tail + 1) % REQUEST_QUEUE_SIZE; } else request->log_skipped = 0; memcpy(&g_queue.data[g_queue.head], request, sizeof(struct flt_request)); // try to get process name pname = NULL; if (pid_pname_resolve(request->pid, pname_buf, sizeof(pname_buf)) ) { KdPrint(("[tdi_fw] log_request: pid:%u; pname:%s\n", request->pid, pname_buf)); // ala strdup() pname = (char *)malloc_np(strlen(pname_buf) + 1); if (pname != NULL) strcpy(pname, pname_buf); else KdPrint(("[tdi_fw] log_request: malloc_np!\n")); } g_queue.data[g_queue.head].pname = pname; g_queue.head = next_head; KeReleaseSpinLock(&g_queue.guard, 
irql); // signal to user app if (g_queue.event != NULL) KeSetEvent(g_queue.event, IO_NO_INCREMENT, FALSE); return TRUE; } // read requests from log queue to buffer ULONG get_request(char *buf, ULONG buf_size) { ULONG result = 0; KIRQL irql; // sanity check if (buf_size < sizeof(struct flt_request)) return 0; KeAcquireSpinLock(&g_queue.guard, &irql); while (g_queue.head != g_queue.tail) { int pname_size; if (g_queue.data[g_queue.tail].pname != NULL) pname_size = strlen(g_queue.data[g_queue.tail].pname) + 1; else pname_size = 0; if (buf_size < sizeof(struct flt_request) + pname_size) return result; memcpy(buf, &g_queue.data[g_queue.tail], sizeof(struct flt_request)); if (g_queue.data[g_queue.tail].pname != NULL) { ((struct flt_request *)buf)->struct_size += pname_size; strcpy(buf + sizeof(struct flt_request), g_queue.data[g_queue.tail].pname); free(g_queue.data[g_queue.tail].pname); g_queue.data[g_queue.tail].pname = NULL; } result += sizeof(struct flt_request) + pname_size; buf += sizeof(struct flt_request) + pname_size; buf_size -= sizeof(struct flt_request) + pname_size; g_queue.tail = (g_queue.tail + 1) % REQUEST_QUEUE_SIZE; } KdPrint(("[tdi_fw] get_request: copied %u bytes\n", result)); KeReleaseSpinLock(&g_queue.guard, irql); return result; } // add rule to rules chain NTSTATUS add_flt_rule(int chain, const struct flt_rule *rule) { NTSTATUS status; struct flt_rule *new_rule; KIRQL irql; // sanity check if (chain < 0 && chain >= MAX_CHAINS_COUNT) return STATUS_INVALID_PARAMETER_1; KeAcquireSpinLock(&g_rules.guard, &irql); new_rule = (struct flt_rule *)malloc_np(sizeof(struct flt_rule)); if (new_rule == NULL) { KdPrint(("[tdi_fw] add_flt_rule: malloc_np\n")); status = STATUS_INSUFFICIENT_RESOURCES; goto done; } memcpy(new_rule, rule, sizeof(*new_rule)); // append new_rule->next = NULL; if (g_rules.chain[chain].tail == NULL) { g_rules.chain[chain].head = new_rule; g_rules.chain[chain].tail = new_rule; } else { g_rules.chain[chain].tail->next = new_rule; 
g_rules.chain[chain].tail = new_rule; } status = STATUS_SUCCESS; done: KeReleaseSpinLock(&g_rules.guard, irql); return status; } // clear rules chain NTSTATUS clear_flt_chain(int chain) { struct flt_rule *rule; KIRQL irql; // sanity check if (chain < 0 && chain >= MAX_CHAINS_COUNT) return STATUS_INVALID_PARAMETER_1; /* rules chain */ KeAcquireSpinLock(&g_rules.guard, &irql); for (rule = g_rules.chain[chain].head; rule != NULL;) { struct flt_rule *rule2 = rule->next; free(rule); rule = rule2; } g_rules.chain[chain].head = NULL; g_rules.chain[chain].tail = NULL; KeReleaseSpinLock(&g_rules.guard, irql); return STATUS_SUCCESS; } // set process name for chain NTSTATUS set_chain_pname(int chain, char *pname) { KIRQL irql; NTSTATUS status; // sanity check if (chain < 0 || chain >= MAX_CHAINS_COUNT) return STATUS_INVALID_PARAMETER_1; KdPrint(("[tdi_fw] set_chain_pname: setting name %s for chain %d\n", pname, chain)); KeAcquireSpinLock(&g_rules.guard, &irql); if (g_rules.chain[chain].pname != NULL) free(g_rules.chain[chain].pname); g_rules.chain[chain].pname = (char *)malloc_np(strlen(pname) + 1); if (g_rules.chain[chain].pname != NULL) { // copy pname strcpy(g_rules.chain[chain].pname, pname); status = STATUS_SUCCESS; } else status = STATUS_INSUFFICIENT_RESOURCES; KeReleaseSpinLock(&g_rules.guard, irql); return status; } // set result of process name by pid resolving NTSTATUS set_pid_pname(ULONG pid, char *pname) { KIRQL irql; int i, chain = 0; KdPrint(("[tdi_fw] set_pid_pname: setting pname %s for pid %u\n", pname, pid)); KeAcquireSpinLock(&g_rules.guard, &irql); for (i = 0; i < MAX_CHAINS_COUNT; i++) if (g_rules.chain[i].pname != NULL && _stricmp(pname, g_rules.chain[i].pname) == 0) { KdPrint(("[tdi_fw] set_pid_pname: found chain %d\n", i)); chain = i; break; } KeReleaseSpinLock(&g_rules.guard, irql); return pid_pname_set(pid, pname, chain); }