// Source: www.pudn.com > NetPaw.rar > threadpool.h
/*
threadpool.h: header-only implementation of the CThreadPool2 class template.
This class is adapted (largely copied) from ATL's CThreadPool in atlutil.h;
it fixes several bugs and adds a new member function, IsThreadInPool().
Version 1.01 -- Last Modified 6/11/2004
1.01
- Made the class compile cleanly at level 4
- Added threadpool COM support. Simply #define THREADPOOL_USES_COM to
turn it on.
1.0
- Fixed a bug in the ThreadProc() function which would cause the thread
to exit if an I/O event fails.
- Added IsThreadInPool() to assist implementers in determining whether or
not they are being called from within a pooled thread.
*/
//////////////////////////////////////////////////////////////////////
#pragma once
#include
// Compile-time defaults. Each one is only defined here if it has not already
// been defined, so a project can override any of them before including this header.

// NOTE(review): ATL_POOL_NUM_THREADS / ATL_POOL_STACK_SIZE are not referenced by
// CThreadPool2 below; presumably kept for compatibility with ATL-style code — confirm.
#if !defined(ATL_POOL_NUM_THREADS)
#define ATL_POOL_NUM_THREADS 0
#endif

#if !defined(ATL_POOL_STACK_SIZE)
#define ATL_POOL_STACK_SIZE 0
#endif

// Threads created per processor when the pool is sized with 0.
#if !defined(ATLS_DEFAULT_THREADSPERPROC)
#define ATLS_DEFAULT_THREADSPERPROC 2
#endif

// Default wait, in milliseconds, for worker threads to start or stop.
#if !defined(ATLS_DEFAULT_THREADPOOLSHUTDOWNTIMEOUT)
#define ATLS_DEFAULT_THREADPOOLSHUTDOWNTIMEOUT 36000
#endif
//
// CThreadPool2
// This class is a simple IO completion port based thread pool
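// TWorker:
// is a class that is responsible for handling requests
// queued on the thread pool.
// It must have a typedef for RequestType, the datatype queued on the pool.
// RequestType must be castable to and from ULONG_PTR, because a queued
// request travels as the I/O completion key.
// Pool shutdown is signalled through the OVERLAPPED pointer
// (ATLS_POOL_SHUTDOWN), so no RequestType value is reserved.
// TWorker must also be default-constructible and provide:
//   BOOL Initialize(void *pvParam)  -- returning FALSE stops the thread
//   void Execute(RequestType request, void *pvParam, OVERLAPPED *pOverlapped)
//   void Terminate(void *pvParam)   -- called when the thread stops serving requests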
// ThreadTraits:
// is a class that implements a static CreateThread function
// This allows for overriding how the threads are created
//
// Sentinel OVERLAPPED pointer (address -1) posted to the completion port to ask
// one pool thread to leave its request loop; a real OVERLAPPED never has this address.
// NOTE(review): presumably mirrors ATL's own definition; keep the token sequence
// unchanged so a second definition does not trigger a macro-redefinition warning — confirm.
#ifndef ATLS_POOL_SHUTDOWN
#define ATLS_POOL_SHUTDOWN ((OVERLAPPED*) ((__int64) -1))
#endif
// Modification of ATL's CThreadPool, primarily to fix the ThreadProc so it *properly* handles IOCP events
template
class CThreadPool2
#ifdef __ATLUTIL_H__
: public IThreadPoolConfig
#endif
{
protected:
CSimpleMap m_threadMap;
DWORD m_dwThreadEventId;
CComCriticalSection m_critSec;
DWORD m_dwStackSize;
DWORD m_dwMaxWait;
void *m_pvWorkerParam;
LONG m_bShutdown;
HANDLE m_hThreadEvent;
HANDLE m_hRequestQueue;
public:
CThreadPool2() throw() :
m_hRequestQueue(NULL),
m_pvWorkerParam(NULL),
m_dwMaxWait(ATLS_DEFAULT_THREADPOOLSHUTDOWNTIMEOUT),
m_bShutdown(FALSE),
m_dwThreadEventId(0),
m_dwStackSize(0)
{
}
~CThreadPool2() throw()
{
Shutdown();
}
// Initialize the thread pool
// if nNumThreads > 0, then it specifies the number of threads
// if nNumThreads < 0, then it specifies the number of threads per proc (-)
// if nNumThreads == 0, then it defaults to two threads per proc
// hCompletion is a handle of a file to associate with the completion port
// pvWorkerParam is a parameter that will be passed to TWorker::Execute
// dwStackSize:
// The stack size to use when creating the threads
HRESULT Initialize( void *pvWorkerParam = NULL, int nNumThreads = 0, DWORD dwStackSize = 0, HANDLE hCompletion = INVALID_HANDLE_VALUE ) throw()
{
ATLASSERT( m_hRequestQueue == NULL );
if (m_hRequestQueue) // Already initialized
return AtlHresultFromWin32(ERROR_ALREADY_INITIALIZED);
if (S_OK != m_critSec.Init())
return E_FAIL;
m_hThreadEvent = CreateEvent(NULL, FALSE, FALSE, NULL);
if (!m_hThreadEvent)
{
m_critSec.Term();
return AtlHresultFromLastError();
}
// Create IO completion port to queue the requests
m_hRequestQueue = CreateIoCompletionPort(hCompletion, NULL, 0, nNumThreads);
if (m_hRequestQueue == NULL)
{
// failed creating the Io completion port
m_critSec.Term();
CloseHandle(m_hThreadEvent);
return AtlHresultFromLastError();
}
m_pvWorkerParam = pvWorkerParam;
m_dwStackSize = dwStackSize;
HRESULT hr = SetSize(nNumThreads);
if (hr != S_OK)
{
// Close the request queue handle
CloseHandle(m_hRequestQueue);
// Clear the queue handle
m_hRequestQueue = NULL;
// Uninitialize the critical sections
m_critSec.Term();
CloseHandle(m_hThreadEvent);
return hr;
}
return S_OK;
}
// Checks to see if the current thread is one of the thread pool
// threads.
BOOL IsThreadInPool()
{
DWORD dwThreadId = GetCurrentThreadId();
CComCritSecLock lock(m_critSec, false);
if (FAILED(lock.Lock()))
{
// out of memory
ATLASSERT( FALSE );
return FALSE;
}
for (int i = m_threadMap.GetSize() - 1; i >= 0; i--)
{
if (m_threadMap.GetKeyAt(i) == dwThreadId) return TRUE;
}
return FALSE;
}
// Shutdown the thread pool
// This function posts the shutdown request to all the threads in the pool
// It will wait for the threads to shutdown a maximum of dwMaxWait MS.
// If the timeout expires it just returns without terminating the threads.
void Shutdown( DWORD dwMaxWait = 0 ) throw()
{
if (!m_hRequestQueue) // Not initialized
return;
CComCritSecLock lock(m_critSec, false);
if (FAILED(lock.Lock()))
{
// out of memory
ATLASSERT( FALSE );
return;
}
if (dwMaxWait == 0)
dwMaxWait = m_dwMaxWait;
HRESULT hr = InternalResizePool(0, dwMaxWait);
if (hr != S_OK)
ATLTRACE(atlTraceUtil, 0, _T("Thread pool not shutting down cleanly : %08x"), hr);
// If the threads have not returned, then something is wrong
for (int i = m_threadMap.GetSize() - 1; i >= 0; i--)
{
HANDLE hThread = m_threadMap.GetValueAt(i);
DWORD dwExitCode;
GetExitCodeThread(hThread, &dwExitCode);
if (dwExitCode == STILL_ACTIVE)
{
ATLTRACE(atlTraceUtil, 0, _T("Terminating thread"));
TerminateThread(hThread, 0);
}
CloseHandle(hThread);
}
// Close the request queue handle
CloseHandle(m_hRequestQueue);
// Clear the queue handle
m_hRequestQueue = NULL;
ATLASSERT(m_threadMap.GetSize() == 0);
// Uninitialize the critical sections
lock.Unlock();
m_critSec.Term();
CloseHandle(m_hThreadEvent);
}
// IThreadPoolConfig methods
HRESULT STDMETHODCALLTYPE SetSize(int nNumThreads) throw()
{
if (nNumThreads == 0)
nNumThreads = -ATLS_DEFAULT_THREADSPERPROC;
if (nNumThreads < 0)
{
SYSTEM_INFO si;
GetSystemInfo(&si);
nNumThreads = (int) (-nNumThreads) * si.dwNumberOfProcessors;
}
return InternalResizePool(nNumThreads, m_dwMaxWait);
}
HRESULT STDMETHODCALLTYPE GetSize(int *pnNumThreads) throw()
{
if (!pnNumThreads)
return E_POINTER;
*pnNumThreads = GetNumThreads();
return S_OK;
}
HRESULT STDMETHODCALLTYPE SetTimeout(DWORD dwMaxWait) throw()
{
m_dwMaxWait = dwMaxWait;
return S_OK;
}
HRESULT STDMETHODCALLTYPE GetTimeout(DWORD *pdwMaxWait) throw()
{
if (!pdwMaxWait)
return E_POINTER;
*pdwMaxWait = m_dwMaxWait;
return S_OK;
}
// IUnknown methods
HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void **ppv) throw()
{
if (!ppv)
return E_POINTER;
*ppv = NULL;
if (InlineIsEqualGUID(riid, __uuidof(IUnknown)) ||
InlineIsEqualGUID(riid, __uuidof(IThreadPoolConfig)))
{
*ppv = static_cast(this);
AddRef();
return S_OK;
}
return E_NOINTERFACE;
}
ULONG STDMETHODCALLTYPE AddRef() throw()
{
return 1;
}
ULONG STDMETHODCALLTYPE Release() throw()
{
return 1;
}
HANDLE GetQueueHandle() throw()
{
return m_hRequestQueue;
}
int GetNumThreads() throw()
{
return m_threadMap.GetSize();
}
// QueueRequest adds a request to the thread pool
// it will be picked up by one of the threads and dispatched to the worker
// in WorkerThreadProc
BOOL QueueRequest(typename TWorker::RequestType request) throw()
{
if (!PostQueuedCompletionStatus(m_hRequestQueue, 0, (ULONG_PTR) request, NULL))
return FALSE;
return TRUE;
}
protected:
DWORD ThreadProc() throw()
{
DWORD dwBytesTransfered;
ULONG_PTR dwCompletionKey;
OVERLAPPED* pOverlapped;
// this block is to ensure theWorker gets destructed before the
// thread handle is closed
{
// We instantiate an instance of the worker class on the stack
// for the life time of the thread.
TWorker theWorker;
if (theWorker.Initialize(m_pvWorkerParam) == FALSE)
{
return 1;
}
#if (!defined(_ATL_NO_COM_SUPPORT) && defined(THREADPOOL_USES_COM))
#if ((_WIN32_WINNT >= 0x0400 ) || defined(_WIN32_DCOM)) && defined(_ATL_FREE_THREADED)
return SUCCEEDED(CoInitializeEx(NULL, COINIT_MULTITHREADED));
#else
return SUCCEEDED(CoInitialize(NULL));
#endif
#endif // !_ATL_NO_COM_SUPPORT && THREADPOOL_USES_COM
SetEvent(m_hThreadEvent);
// Get the request from the IO completion port
while (TRUE)
{
GetQueuedCompletionStatus(m_hRequestQueue, &dwBytesTransfered, &dwCompletionKey, &pOverlapped, INFINITE);
if (pOverlapped)
{
if (pOverlapped == ATLS_POOL_SHUTDOWN) // Shut down
{
LONG bResult = InterlockedExchange(&m_bShutdown, FALSE);
if (bResult) // Shutdown has not been cancelled
break;
// else, shutdown has been cancelled -- continue as before
}
else // Do work
{
TWorker::RequestType request = (TWorker::RequestType) dwCompletionKey;
// Process the request. Notice the following:
// (1) It is the worker's responsibility to free any memory associated
// with the request if the request is complete
// (2) If the request still requires some more processing
// the worker should queue the request again for dispatching
theWorker.Execute(request, m_pvWorkerParam, pOverlapped);
}
}
}
theWorker.Terminate(m_pvWorkerParam);
}
#if (!defined(_ATL_NO_COM_SUPPORT) && defined(THREADPOOL_USES_COM))
CoUninitialize();
#endif // !_ATL_NO_COM_SUPPORT && THREADPOOL_USES_COM
m_dwThreadEventId = GetCurrentThreadId();
SetEvent(m_hThreadEvent);
return 0;
}
static DWORD WINAPI WorkerThreadProc(LPVOID pv) throw()
{
CThreadPool2 * pThis =
reinterpret_cast< CThreadPool2 * >(pv);
return pThis->ThreadProc();
}
HRESULT InternalResizePool(int nNumThreads, int dwMaxWait) throw()
{
if (!m_hRequestQueue) // Not initialized
return E_FAIL;
CComCritSecLock lock(m_critSec, false);
if (FAILED(lock.Lock()))
{
// out of memory
ATLASSERT( FALSE );
return E_FAIL;
}
int nCurThreads = m_threadMap.GetSize();
if (nNumThreads <= nCurThreads)
{
int nNumShutdownThreads = nCurThreads - nNumThreads;
for (int nThreadIndex = 0; nThreadIndex < nNumShutdownThreads; nThreadIndex++)
{
ResetEvent(m_hThreadEvent);
InterlockedExchange(&m_bShutdown, TRUE);
PostQueuedCompletionStatus(m_hRequestQueue, 0, 0, ATLS_POOL_SHUTDOWN);
DWORD dwRet = WaitForSingleObject(m_hThreadEvent, dwMaxWait);
if (dwRet == WAIT_TIMEOUT)
{
LONG bResult = InterlockedExchange(&m_bShutdown, FALSE);
if (bResult) // Nobody picked up the shutdown message
{
return AtlHresultFromWin32(WAIT_TIMEOUT);
}
}
else if (dwRet != WAIT_OBJECT_0)
{
return AtlHresultFromLastError();
}
int nIndex = m_threadMap.FindKey(m_dwThreadEventId);
if (nIndex != -1)
{
HANDLE hThread = m_threadMap.GetValueAt(nIndex);
// Wait for the thread to shutdown
if (WaitForSingleObject(hThread, 60000) == WAIT_OBJECT_0)
{
CloseHandle(hThread);
m_threadMap.RemoveAt(nIndex);
}
else
{
// Thread failed to terminate
return E_FAIL;
}
}
}
}
else
{
int nNumNewThreads = nNumThreads - nCurThreads;
// Create and initialize worker threads
for (int nThreadIndex = 0; nThreadIndex < nNumNewThreads; nThreadIndex++)
{
DWORD dwThreadID;
ResetEvent(m_hThreadEvent);
CHandle hdlThread( ThreadTraits::CreateThread(NULL, m_dwStackSize, WorkerThreadProc, (LPVOID)this, 0, &dwThreadID) );
if (!hdlThread)
{
HRESULT hr = AtlHresultFromLastError();
ATLASSERT(hr != S_OK);
return hr;
}
DWORD dwRet = WaitForSingleObject(m_hThreadEvent, dwMaxWait);
if (dwRet != WAIT_OBJECT_0)
{
if (dwRet == WAIT_TIMEOUT)
{
return HRESULT_FROM_WIN32(WAIT_TIMEOUT);
}
else
{
return AtlHresultFromLastError();
}
}
if (m_threadMap.Add(dwThreadID, hdlThread) != FALSE)
{
hdlThread.Detach();
}
}
}
return S_OK;
}
}; // class CThreadPool2