// Source listing: www.pudn.com > NetPaw.rar > threadpool.h


/* 
 ThreadPool.h: declaration and inline implementation of the CThreadPool2 class template 
 
 This class was stolen shamelessly from atlutil.h, fixes several bugs 
 and introduces a new function call. 
 
 Version 1.01 -- Last Modified 6/11/2004 
 
 1.01 
   - Made the class compile cleanly at level 4 
   - Added threadpool COM support.  Simply #define THREADPOOL_USES_COM to 
     turn it on. 
 
 1.0 
  - Fixed a bug in the ThreadProc() function which would cause the thread 
    to exit if an I/O event fails. 
 
  - Added IsThreadInPool() to assist implementers in determining whether or 
    not they are being called from within a pooled thread. 
 
*/ 
////////////////////////////////////////////////////////////////////// 
 
#pragma once 
 
#include <atlbase.h>	// NOTE(review): the header name was lost when this file was extracted; atlbase.h is assumed because it supplies CSimpleMap, CComCriticalSection, CComCritSecLock and CHandle -- confirm
 
// Build-time defaults. Each one can be overridden by defining the macro
// before this header is included.

// Not referenced by the code shown in this header; presumably kept for
// parity with ATL's thread pool -- TODO confirm.
#ifndef ATL_POOL_NUM_THREADS 
	#define ATL_POOL_NUM_THREADS 0 
#endif 
 
// Not referenced by the code shown in this header; presumably kept for
// parity with ATL's thread pool -- TODO confirm.
#ifndef ATL_POOL_STACK_SIZE 
	#define ATL_POOL_STACK_SIZE 0 
#endif 
 
// Threads created per processor when SetSize() is called with 0.
#ifndef ATLS_DEFAULT_THREADSPERPROC 
	#define ATLS_DEFAULT_THREADSPERPROC 2 
#endif 
 
// Initial value of m_dwMaxWait. It is handed to WaitForSingleObject as the
// timeout (milliseconds) while a worker is started or stopped.
#ifndef ATLS_DEFAULT_THREADPOOLSHUTDOWNTIMEOUT 
	#define ATLS_DEFAULT_THREADPOOLSHUTDOWNTIMEOUT 36000 
#endif 
 
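// 
// CThreadPool2 
// This class is a simple IO completion port based thread pool 
//	Worker: 
//		is a class that is responsible for handling requests 
//		queued on the thread pool. 
//		It must have a typedef for RequestType, where request type 
//		is the datatype to be queued on the pool 
//		RequestType must be castable to and from ULONG_PTR, because a 
//		request travels through the port as the completion key 
//		The OVERLAPPED pointer value ATLS_POOL_SHUTDOWN (-1) is reserved 
//		for shutdown of the pool 
//		Worker must also have a void Execute(RequestType request, void *pvParam, OVERLAPPED *pOverlapped) function, 
//		an Initialize(void *pvParam) function whose result is compared with FALSE 
//		when a pool thread starts, and a Terminate(void *pvParam) function that is 
//		called when a pool thread leaves its request loop 
//	ThreadTraits: 
//		is a class that implements a static CreateThread function 
//		This allows for overriding how the threads are created 
// 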
// Sentinel OVERLAPPED pointer value (-1) that is posted to the completion
// port to ask one worker thread to leave its loop. The pool only compares
// dequeued OVERLAPPED pointers against it; it does not dereference it.
#ifndef ATLS_POOL_SHUTDOWN 
#define ATLS_POOL_SHUTDOWN ((OVERLAPPED*) ((__int64) -1)) 
#endif 
 
// Modification of ATL's CThreadPool, primarily to fix the ThreadProc so it *properly* handles IOCP events 
template  
	class CThreadPool2 
#ifdef __ATLUTIL_H__ 
  : public IThreadPoolConfig 
#endif 
{ 
protected: 
	CSimpleMap m_threadMap; 
 
	DWORD m_dwThreadEventId; 
 
	CComCriticalSection m_critSec; 
	DWORD m_dwStackSize; 
	DWORD m_dwMaxWait; 
 
	void *m_pvWorkerParam; 
	LONG m_bShutdown; 
 
	HANDLE m_hThreadEvent; 
	HANDLE m_hRequestQueue; 
 
public: 
 
	CThreadPool2() throw() : 
		m_hRequestQueue(NULL), 
		m_pvWorkerParam(NULL), 
		m_dwMaxWait(ATLS_DEFAULT_THREADPOOLSHUTDOWNTIMEOUT), 
		m_bShutdown(FALSE), 
		m_dwThreadEventId(0), 
		m_dwStackSize(0) 
	{ 
	} 
 
	~CThreadPool2() throw() 
	{ 
		Shutdown(); 
	} 
 
	// Initialize the thread pool 
	// if nNumThreads > 0, then it specifies the number of threads 
	// if nNumThreads < 0, then it specifies the number of threads per proc (-) 
	// if nNumThreads == 0, then it defaults to two threads per proc 
	// hCompletion is a handle of a file to associate with the completion port 
	// pvWorkerParam is a parameter that will be passed to TWorker::Execute 
	//	dwStackSize: 
	//		The stack size to use when creating the threads 
	HRESULT Initialize( void *pvWorkerParam = NULL, int nNumThreads = 0, DWORD dwStackSize = 0, HANDLE hCompletion = INVALID_HANDLE_VALUE ) throw() 
	{ 
		ATLASSERT( m_hRequestQueue == NULL ); 
 
		if (m_hRequestQueue)   // Already initialized 
			return AtlHresultFromWin32(ERROR_ALREADY_INITIALIZED); 
 
		if (S_OK != m_critSec.Init()) 
			return E_FAIL; 
 
		m_hThreadEvent = CreateEvent(NULL, FALSE, FALSE, NULL); 
		if (!m_hThreadEvent) 
		{ 
			m_critSec.Term(); 
			return AtlHresultFromLastError(); 
		} 
 
		// Create IO completion port to queue the requests 
		m_hRequestQueue = CreateIoCompletionPort(hCompletion, NULL, 0, nNumThreads); 
		if (m_hRequestQueue == NULL) 
		{ 
			// failed creating the Io completion port 
			m_critSec.Term(); 
			CloseHandle(m_hThreadEvent); 
			return AtlHresultFromLastError();		 
		} 
		m_pvWorkerParam = pvWorkerParam; 
		m_dwStackSize = dwStackSize; 
 
		HRESULT hr = SetSize(nNumThreads); 
		if (hr != S_OK) 
		{ 
			// Close the request queue handle 
			CloseHandle(m_hRequestQueue); 
 
			// Clear the queue handle 
			m_hRequestQueue = NULL; 
 
			// Uninitialize the critical sections 
			m_critSec.Term(); 
			CloseHandle(m_hThreadEvent); 
 
			return hr; 
		} 
 
		return S_OK; 
	} 
 
	// Checks to see if the current thread is one of the thread pool 
	// threads. 
	BOOL IsThreadInPool() 
	{ 
		DWORD dwThreadId = GetCurrentThreadId(); 
 
		CComCritSecLock lock(m_critSec, false); 
		if (FAILED(lock.Lock())) 
		{ 
			// out of memory 
			ATLASSERT( FALSE ); 
			return FALSE; 
		} 
 
		for (int i = m_threadMap.GetSize() - 1; i >= 0; i--) 
		{ 
			if (m_threadMap.GetKeyAt(i) == dwThreadId) return TRUE; 
		} 
		return FALSE; 
	} 
 
	// Shutdown the thread pool 
	// This function posts the shutdown request to all the threads in the pool 
	// It will wait for the threads to shutdown a maximum of dwMaxWait MS. 
	// If the timeout expires it just returns without terminating the threads. 
	void Shutdown( DWORD dwMaxWait = 0 ) throw() 
	{ 
		if (!m_hRequestQueue)   // Not initialized 
			return; 
 
		CComCritSecLock lock(m_critSec, false); 
		if (FAILED(lock.Lock())) 
		{ 
			// out of memory 
			ATLASSERT( FALSE ); 
			return; 
		} 
 
		if (dwMaxWait == 0) 
			dwMaxWait = m_dwMaxWait; 
 
		HRESULT hr = InternalResizePool(0, dwMaxWait); 
		if (hr != S_OK) 
			ATLTRACE(atlTraceUtil, 0, _T("Thread pool not shutting down cleanly : %08x"), hr); 
			// If the threads have not returned, then something is wrong 
 
		for (int i = m_threadMap.GetSize() - 1; i >= 0; i--) 
		{ 
			HANDLE hThread = m_threadMap.GetValueAt(i); 
			DWORD dwExitCode; 
			GetExitCodeThread(hThread, &dwExitCode); 
			if (dwExitCode == STILL_ACTIVE) 
			{ 
				ATLTRACE(atlTraceUtil, 0, _T("Terminating thread")); 
				TerminateThread(hThread, 0); 
			} 
			CloseHandle(hThread); 
		} 
 
		// Close the request queue handle 
		CloseHandle(m_hRequestQueue); 
 
		// Clear the queue handle 
		m_hRequestQueue = NULL; 
 
		ATLASSERT(m_threadMap.GetSize() == 0); 
 
		// Uninitialize the critical sections 
		lock.Unlock(); 
		m_critSec.Term(); 
		CloseHandle(m_hThreadEvent); 
	} 
 
	// IThreadPoolConfig methods 
	HRESULT STDMETHODCALLTYPE SetSize(int nNumThreads) throw() 
	{ 
		if (nNumThreads == 0) 
			nNumThreads = -ATLS_DEFAULT_THREADSPERPROC; 
 
		if (nNumThreads < 0) 
		{ 
			SYSTEM_INFO si; 
			GetSystemInfo(&si); 
			nNumThreads = (int) (-nNumThreads) * si.dwNumberOfProcessors; 
		} 
 
		return InternalResizePool(nNumThreads, m_dwMaxWait); 
	} 
 
	HRESULT STDMETHODCALLTYPE GetSize(int *pnNumThreads) throw() 
	{ 
		if (!pnNumThreads) 
			return E_POINTER; 
 
		*pnNumThreads = GetNumThreads(); 
		return S_OK; 
	} 
 
	HRESULT STDMETHODCALLTYPE SetTimeout(DWORD dwMaxWait) throw() 
	{ 
		m_dwMaxWait = dwMaxWait; 
 
		return S_OK; 
	} 
 
	HRESULT STDMETHODCALLTYPE GetTimeout(DWORD *pdwMaxWait) throw() 
	{ 
		if (!pdwMaxWait) 
			return E_POINTER; 
 
		*pdwMaxWait = m_dwMaxWait; 
 
		return S_OK; 
	} 
 
	// IUnknown methods 
	HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void **ppv) throw() 
	{ 
		if (!ppv) 
			return E_POINTER; 
 
		*ppv = NULL; 
 
		if (InlineIsEqualGUID(riid, __uuidof(IUnknown)) || 
			InlineIsEqualGUID(riid, __uuidof(IThreadPoolConfig))) 
		{ 
			*ppv = static_cast(this); 
			AddRef(); 
			return S_OK; 
		} 
		return E_NOINTERFACE; 
	} 
 
	ULONG STDMETHODCALLTYPE AddRef() throw() 
	{ 
		return 1; 
	} 
 
	ULONG STDMETHODCALLTYPE Release() throw() 
	{ 
		return 1; 
	} 
 
 
	HANDLE GetQueueHandle() throw() 
	{ 
		return m_hRequestQueue; 
	} 
 
	int GetNumThreads() throw() 
	{ 
		return m_threadMap.GetSize(); 
	} 
 
	// QueueRequest adds a request to the thread pool 
	// it will be picked up by one of the threads and dispatched to the worker 
	// in WorkerThreadProc 
	BOOL QueueRequest(typename TWorker::RequestType request) throw() 
	{ 
		if (!PostQueuedCompletionStatus(m_hRequestQueue, 0, (ULONG_PTR) request, NULL)) 
			return FALSE; 
		return TRUE; 
	} 
 
 
protected: 
 
	DWORD ThreadProc() throw() 
	{ 
		DWORD dwBytesTransfered; 
		ULONG_PTR dwCompletionKey; 
		OVERLAPPED* pOverlapped; 
 
		// this block is to ensure theWorker gets destructed before the  
		// thread handle is closed 
		{ 
			// We instantiate an instance of the worker class on the stack 
			// for the life time of the thread. 
			TWorker theWorker; 
 
			if (theWorker.Initialize(m_pvWorkerParam) == FALSE) 
			{ 
				return 1; 
			} 
 
#if (!defined(_ATL_NO_COM_SUPPORT) && defined(THREADPOOL_USES_COM)) 
 
#if ((_WIN32_WINNT >= 0x0400 ) || defined(_WIN32_DCOM)) && defined(_ATL_FREE_THREADED) 
			return SUCCEEDED(CoInitializeEx(NULL, COINIT_MULTITHREADED)); 
#else 
			return SUCCEEDED(CoInitialize(NULL)); 
#endif 
#endif // !_ATL_NO_COM_SUPPORT && THREADPOOL_USES_COM 
 
			SetEvent(m_hThreadEvent); 
			// Get the request from the IO completion port 
			while (TRUE) 
			{ 
				GetQueuedCompletionStatus(m_hRequestQueue, &dwBytesTransfered, &dwCompletionKey, &pOverlapped, INFINITE); 
				if (pOverlapped) 
				{ 
					if (pOverlapped == ATLS_POOL_SHUTDOWN) // Shut down 
					{ 
						LONG bResult = InterlockedExchange(&m_bShutdown, FALSE); 
						if (bResult) // Shutdown has not been cancelled 
							break; 
 
						// else, shutdown has been cancelled -- continue as before 
					} 
					else										// Do work 
					{ 
						TWorker::RequestType request = (TWorker::RequestType) dwCompletionKey; 
 
						// Process the request.  Notice the following: 
						// (1) It is the worker's responsibility to free any memory associated 
						// with the request if the request is complete 
						// (2) If the request still requires some more processing 
						// the worker should queue the request again for dispatching 
						theWorker.Execute(request, m_pvWorkerParam, pOverlapped); 
					} 
				} 
			} 
 
			theWorker.Terminate(m_pvWorkerParam); 
		} 
 
#if (!defined(_ATL_NO_COM_SUPPORT) && defined(THREADPOOL_USES_COM)) 
		CoUninitialize(); 
#endif // !_ATL_NO_COM_SUPPORT && THREADPOOL_USES_COM 
 
		m_dwThreadEventId = GetCurrentThreadId(); 
		SetEvent(m_hThreadEvent); 
 
		return 0;  
	} 
 
	static DWORD WINAPI WorkerThreadProc(LPVOID pv) throw() 
	{ 
		CThreadPool2 * pThis =  
			reinterpret_cast< CThreadPool2 * >(pv);  
 
		return pThis->ThreadProc(); 
	}  
 
	HRESULT InternalResizePool(int nNumThreads, int dwMaxWait) throw() 
	{ 
		if (!m_hRequestQueue)   // Not initialized 
			return E_FAIL; 
 
		CComCritSecLock lock(m_critSec, false); 
		if (FAILED(lock.Lock())) 
		{ 
			// out of memory 
			ATLASSERT( FALSE ); 
			return E_FAIL; 
		} 
 
		int nCurThreads = m_threadMap.GetSize(); 
		if (nNumThreads <= nCurThreads) 
		{ 
			int nNumShutdownThreads = nCurThreads - nNumThreads; 
			for (int nThreadIndex = 0; nThreadIndex < nNumShutdownThreads; nThreadIndex++) 
			{     
				ResetEvent(m_hThreadEvent); 
 
				InterlockedExchange(&m_bShutdown, TRUE); 
				PostQueuedCompletionStatus(m_hRequestQueue, 0, 0, ATLS_POOL_SHUTDOWN); 
				DWORD dwRet = WaitForSingleObject(m_hThreadEvent, dwMaxWait); 
 
				if (dwRet == WAIT_TIMEOUT) 
				{ 
					LONG bResult = InterlockedExchange(&m_bShutdown, FALSE); 
					if (bResult) // Nobody picked up the shutdown message 
					{ 
						return AtlHresultFromWin32(WAIT_TIMEOUT); 
					} 
				} 
				else if (dwRet != WAIT_OBJECT_0) 
				{ 
					return AtlHresultFromLastError(); 
				} 
 
				int nIndex = m_threadMap.FindKey(m_dwThreadEventId); 
				if (nIndex != -1) 
				{ 
					HANDLE hThread = m_threadMap.GetValueAt(nIndex); 
					// Wait for the thread to shutdown 
					if (WaitForSingleObject(hThread, 60000) == WAIT_OBJECT_0)  
					{ 
						CloseHandle(hThread); 
						m_threadMap.RemoveAt(nIndex); 
					} 
					else 
					{ 
						// Thread failed to terminate 
						return E_FAIL; 
					} 
				} 
			} 
		} 
		else 
		{ 
			int nNumNewThreads = nNumThreads - nCurThreads; 
 
			// Create and initialize worker threads 
			for (int nThreadIndex = 0; nThreadIndex < nNumNewThreads; nThreadIndex++) 
			{ 
				DWORD dwThreadID; 
				ResetEvent(m_hThreadEvent); 
				CHandle hdlThread( ThreadTraits::CreateThread(NULL, m_dwStackSize, WorkerThreadProc, (LPVOID)this, 0, &dwThreadID) ); 
 
				if (!hdlThread) 
				{ 
					HRESULT hr = AtlHresultFromLastError(); 
					ATLASSERT(hr != S_OK); 
					return hr; 
				} 
 
				DWORD dwRet = WaitForSingleObject(m_hThreadEvent, dwMaxWait); 
				if (dwRet != WAIT_OBJECT_0) 
				{ 
					if (dwRet == WAIT_TIMEOUT) 
					{ 
						return HRESULT_FROM_WIN32(WAIT_TIMEOUT); 
					} 
					else 
					{ 
						return AtlHresultFromLastError(); 
					} 
				} 
 
				if (m_threadMap.Add(dwThreadID, hdlThread) != FALSE) 
				{ 
					hdlThread.Detach(); 
				} 
			} 
		} 
 
		return S_OK; 
	} 
 
}; // class CThreadPool2