From af37037cf1dfcbfff8f86b8acabf021a0263f9f0 Mon Sep 17 00:00:00 2001 From: tzachid Date: Mon, 21 Nov 2005 12:23:11 +0000 Subject: [PATCH] Make sure that the driver only exists when all it's threads are dead. (Rev 334) git-svn-id: svn://openib.tc.cornell.edu/gen1@174 ad392aa1-c5ef-ae45-8dd8-e69d62a5ef86 --- trunk/ulp/sdp/kernel/Precompile.h | 2 +- trunk/ulp/sdp/kernel/SdpDriver.cpp | 95 +++++++++++++++++++++++++-- trunk/ulp/sdp/kernel/SdpDriver.h | 24 ++++++- trunk/ulp/sdp/kernel/SdpGenUtils.cpp | 5 ++ trunk/ulp/sdp/kernel/SdpGenUtils.h | 1 + trunk/ulp/sdp/kernel/SdpSocket.cpp | 98 ++++++++++++++++++---------- trunk/ulp/sdp/kernel/SdpSocket.h | 2 + trunk/ulp/sdp/todo | 2 - 8 files changed, 185 insertions(+), 44 deletions(-) diff --git a/trunk/ulp/sdp/kernel/Precompile.h b/trunk/ulp/sdp/kernel/Precompile.h index ffcdb34f..55ac814a 100644 --- a/trunk/ulp/sdp/kernel/Precompile.h +++ b/trunk/ulp/sdp/kernel/Precompile.h @@ -20,10 +20,10 @@ class SdpArp; #include "SdpTrace.h" #include "sdpLock.h" #include "RefCount.h" +#include "SdpBufferPool.h" #include "sdpdriver.h" #include "SdpShared.h" #include "SdpUserFile.h" -#include "SdpBufferPool.h" #include "SdpRecvPool.h" #include "SdpConnectionList.h" #include "SdpSocket.h" diff --git a/trunk/ulp/sdp/kernel/SdpDriver.cpp b/trunk/ulp/sdp/kernel/SdpDriver.cpp index 40e87840..cff0f6d5 100644 --- a/trunk/ulp/sdp/kernel/SdpDriver.cpp +++ b/trunk/ulp/sdp/kernel/SdpDriver.cpp @@ -12,6 +12,8 @@ VOID DriverUnload ( SDP_PRINT(SDP_TRACE, SDP_DRIVER, ("called pDriverObject = 0x%x\n", pDriverObject )); ib_api_status_t ib_status; + g_pSdpDriver->WaitForAllThreadsToDie(); + ib_status = ib_close_al(g_pSdpDriver->m_al_handle); g_pSdpDriver->m_al_handle = NULL; @@ -32,7 +34,7 @@ extern "C" NTSTATUS DriverEntry ( IN PDRIVER_OBJECT pDriverObject, IN PUNICODE_STRING pRegistryPath ) { - NTSTATUS rc; + NTSTATUS rc = STATUS_SUCCESS; ib_api_status_t ib_status; PDEVICE_OBJECT pDevObj; SdpDriver *pSdpDriver; @@ -74,7 +76,7 @@ extern "C" NTSTATUS DriverEntry ( } DeviceCreated = true; - pSdpDriver = (SdpDriver *) pDevObj->DeviceExtension; + pSdpDriver = new (pDevObj->DeviceExtension) SdpDriver; rc = pSdpDriver->Init(pDevObj); if (!NT_SUCCESS(rc)) { @@ -302,7 +304,8 @@ if ((InputBufferLength < sizeof (InStruct)) || goto Cleanup; \ } -NTSTATUS SdpDriver::Init(PDEVICE_OBJECT pDevObj) +NTSTATUS +SdpDriver::Init(PDEVICE_OBJECT pDevObj) { NTSTATUS rc = STATUS_SUCCESS; m_pDevObj = pDevObj; @@ -317,12 +320,21 @@ NTSTATUS SdpDriver::Init(PDEVICE_OBJECT pDevObj) SDP_PRINT(SDP_ERR, SDP_DRIVER, ("m_pSdpArp->Init failed rc = 0x%x\n", rc )); goto Cleanup; } -Cleanup: + + ExInitializeFastMutex(&m_ThreadsMutex); + +Cleanup: + if (!NT_SUCCESS(rc)) { + if (m_pSdpArp) { + delete m_pSdpArp; + } + } return rc; } -NTSTATUS SdpDriver::DispatchDeviceIoControl( +NTSTATUS +SdpDriver::DispatchDeviceIoControl( IN PFILE_OBJECT pDeviceObject, IN PIRP pIrp, IN PIO_STACK_LOCATION pIrpSp, @@ -578,4 +590,77 @@ Cleanup: return rc; } +VOID +SdpDriver::AddThread(ThreadHandle *pThreadHandle) +{ + SDP_PRINT(SDP_TRACE, SDP_DRIVER, ("this = 0x%x\n", this )); + // Check if there is any next thread that can be removed from the queue + LARGE_INTEGER WaitTime; + WaitTime.QuadPart = 0; // Don't wait for them to die + + ExAcquireFastMutex(&m_ThreadsMutex); + + WaitForThreadsToDie(&WaitTime); + + // Add me to the list of threads that should be removed + m_ShutDownThreads.InsertTailList(&pThreadHandle->m_List); + ExReleaseFastMutex(&m_ThreadsMutex); + +} + +VOID +SdpDriver::WaitForAllThreadsToDie() +{ + SDP_PRINT(SDP_TRACE, SDP_DRIVER, ("this = 0x%x\n", this )); + + ExAcquireFastMutex(&m_ThreadsMutex); + // Timeout of null will cause a wait forever + WaitForThreadsToDie(NULL); + ExReleaseFastMutex(&m_ThreadsMutex); +} + + +// This function has to be called with the mutex held +VOID +SdpDriver::WaitForThreadsToDie(LARGE_INTEGER *pWaitTime) +{ + SDP_PRINT(SDP_TRACE, SDP_DRIVER, ("this = 0x%x\n", this )); + // Check if there is any next thread that can be removed from the queue + NTSTATUS rc = STATUS_SUCCESS; + + LIST_ENTRY *pNextItem; + ThreadHandle *pNextThreadHandle; + while (m_ShutDownThreads.Size() > 0) { + pNextItem = m_ShutDownThreads.Head(); + pNextThreadHandle = CONTAINING_RECORD(pNextItem, ThreadHandle, m_List); + + rc = MyKeWaitForSingleObject( + pNextThreadHandle->ThreadObject, + Executive, + KernelMode, + FALSE, + pWaitTime + ); + ASSERT((rc == STATUS_SUCCESS) || + (rc == STATUS_TIMEOUT)); + + if (rc == STATUS_TIMEOUT) { + // Nothing that we should do, the thread is not ready yet + SDP_PRINT(SDP_TRACE, SDP_DRIVER, ("this = 0x%x Former thread is not dead yet\n", this )); + break; + } + // SUCESS means that the thread is dead, we can remove it + // from the list + SDP_PRINT(SDP_TRACE, SDP_DRIVER, ("this = 0x%x Former thread is already dead\n", this )); + + m_ShutDownThreads.RemoveHeadList(); + ObDereferenceObject(pNextThreadHandle->ThreadObject); + delete pNextThreadHandle; + + // We now continue and try to remove the next object + + } + +} + diff --git a/trunk/ulp/sdp/kernel/SdpDriver.h b/trunk/ulp/sdp/kernel/SdpDriver.h index f509fabf..057e2921 100644 --- a/trunk/ulp/sdp/kernel/SdpDriver.h +++ b/trunk/ulp/sdp/kernel/SdpDriver.h @@ -3,6 +3,16 @@ #ifndef H_SDP_DRIVER_H #define H_SDP_DRIVER_H +// This struct is being used to hold an object that we can wait on +// for threads to die. + +struct ThreadHandle { + // As this object has a simple life cycle I don't use refferance counting + // for it. This might have to change. + PVOID ThreadObject; + LIST_ENTRY m_List; +}; + class SdpDriver { public: @@ -35,8 +45,12 @@ public: IN ULONG IoControlCode, OUT ULONG &OutputDataSize ); - - + + // The following functions are being used so that the driver + // will wait for all the created threads to end + VOID AddThread(ThreadHandle *pThreadHandle); + + VOID WaitForAllThreadsToDie(); public: ib_al_handle_t m_al_handle ; @@ -44,8 +58,14 @@ public: private: + VOID WaitForThreadsToDie(LARGE_INTEGER *pWWaitTime); + + PDEVICE_OBJECT m_pDevObj; + LinkedList m_ShutDownThreads; + + FAST_MUTEX m_ThreadsMutex; }; diff --git a/trunk/ulp/sdp/kernel/SdpGenUtils.cpp b/trunk/ulp/sdp/kernel/SdpGenUtils.cpp index 34ccce21..399bcff5 100644 --- a/trunk/ulp/sdp/kernel/SdpGenUtils.cpp +++ b/trunk/ulp/sdp/kernel/SdpGenUtils.cpp @@ -207,3 +207,8 @@ void __cdecl operator delete(void* p) { ExFreePoolWithTag(p, GLOBAL_ALLOCATION_TAG); } +void* __cdecl operator new(size_t n, void *addr ) throw() { + return addr; +} + + diff --git a/trunk/ulp/sdp/kernel/SdpGenUtils.h b/trunk/ulp/sdp/kernel/SdpGenUtils.h index ebd95ff8..458f2895 100644 --- a/trunk/ulp/sdp/kernel/SdpGenUtils.h +++ b/trunk/ulp/sdp/kernel/SdpGenUtils.h @@ -76,6 +76,7 @@ LARGE_INTEGER TimeFromLong(ULONG HandredNanos); NTSTATUS Sleep(ULONG HandredNanos); +void* __cdecl operator new(size_t n, void *addr ) throw(); /* Convert an IBAL error to a Winsock error. */ int IbalToWsaError(const ib_api_status_t ib_status ); diff --git a/trunk/ulp/sdp/kernel/SdpSocket.cpp b/trunk/ulp/sdp/kernel/SdpSocket.cpp index 6bdccb31..d8f31d0d 100644 --- a/trunk/ulp/sdp/kernel/SdpSocket.cpp +++ b/trunk/ulp/sdp/kernel/SdpSocket.cpp @@ -47,10 +47,8 @@ cm_apr_callback( static void AL_API cm_dreq_callback(IN ib_cm_dreq_rec_t *p_cm_dreq_rec ) { - SDP_PRINT(SDP_TRACE, SDP_SOCKET, ("dispatch level = %d\n", KeGetCurrentIrql())); SdpSocket *pSocket = (SdpSocket *) p_cm_dreq_rec->qp_context; pSocket->CmDreqCallback(p_cm_dreq_rec); - } static void AL_API @@ -154,6 +152,16 @@ NTSTATUS SdpSocket::Init( m_ConnectionList.Init(this); + // We now allocate the needed structure for the close socket, so that + // we won't be in trouble after the thread was created + m_pCloseSocketThread = new ThreadHandle; + if (m_pCloseSocketThread == NULL) { + SDP_PRINT(SDP_ERR, SDP_SOCKET, ("Failed to allocate new SocketThread this = 0x%p \n",this)); + rc = STATUS_NO_MEMORY; + goto Cleanup; + } + +Cleanup: return rc; } @@ -886,8 +894,8 @@ SdpSocket::WSPAccept( } // I want to copy this data before releasing the lock - ULONG IP = m_DstIp; - USHORT Port = m_DstPort; + ULONG IP = pNewSocket->m_DstIp; + USHORT Port = pNewSocket->m_DstPort; ASSERT(Locked == true); rc = m_Lock.Unlock(); @@ -962,7 +970,7 @@ SdpSocket::WSPCloseSocket( ) { NTSTATUS rc = STATUS_SUCCESS; - SDP_PRINT(SDP_TRACE, SDP_SOCKET, ("this = 0x%p\n",this)); + SDP_PRINT(SDP_TRACE, SDP_SOCKET, ("this = 0x%p state = %s \n",this, SS2String(m_state))); OBJECT_ATTRIBUTES attr; if (!m_Lock.Lock()) { @@ -999,39 +1007,56 @@ SdpSocket::WSPCloseSocket( m_Lock.Unlock(); // Error ignored as this is already an error pass goto Cleanup; } - } - // We will now create a thread that will be resposible for the - // destruction of this socket - AddRef(); + // We will now create a thread that will be resposible for the + // destruction of this socket + AddRef(); - /* Create a new thread, storing both the handle and thread id. */ - InitializeObjectAttributes( &attr, NULL, OBJ_KERNEL_HANDLE, NULL, NULL ); - - HANDLE ThreadHandle; - rc = PsCreateSystemThread( - &ThreadHandle, - THREAD_ALL_ACCESS, - &attr, - NULL, - NULL, - ::CloseSocketThread, - this - ); + /* Create a new thread, storing both the handle and thread id. */ + InitializeObjectAttributes( &attr, NULL, OBJ_KERNEL_HANDLE, NULL, NULL ); + + HANDLE ThreadHandle; + rc = PsCreateSystemThread( + &ThreadHandle, + THREAD_ALL_ACCESS, + &attr, + NULL, + NULL, + ::CloseSocketThread, + this + ); + + if (!NT_SUCCESS(rc)) { + SDP_PRINT(SDP_ERR, SDP_SOCKET, ("PsCreateSystemThread failed rc = 0x%x\n", rc )); + m_Lock.Unlock(); // Error ignored as this is already an error pass + // The thread wasn't created so we should remove the refferance + Release(); + goto Cleanup; + } + + ASSERT(m_pCloseSocketThread != NULL); + // Convert the thread into a handle + rc = ObReferenceObjectByHandle( + ThreadHandle, + THREAD_ALL_ACCESS, + NULL, + KernelMode, + &m_pCloseSocketThread->ThreadObject, + NULL + ); + ASSERT(rc == STATUS_SUCCESS); // According to MSDN, if I set the params + // correctly I shouldn't get an error + + rc = ZwClose(ThreadHandle); + ASSERT(NT_SUCCESS(rc)); // Should always succeed + + g_pSdpDriver->AddThread(m_pCloseSocketThread); + m_pCloseSocketThread = NULL; // Will be delated when the callback thread is deleted - if (!NT_SUCCESS(rc)) { - SDP_PRINT(SDP_ERR, SDP_SOCKET, ("PsCreateSystemThread failed rc = 0x%x\n", rc )); - m_Lock.Unlock(); // Error ignored as this is already an error pass - // The thread wasn't created so we should remove the refferance - Release(); - goto Cleanup; } - // BUGBUG: Replace this with a mechanism that will allow - // to close the thered when the driver goes down - rc = ZwClose(ThreadHandle); - ASSERT(NT_SUCCESS(rc)); // Should succeed + rc = m_Lock.Unlock(); if (!NT_SUCCESS(rc)) { @@ -1615,7 +1640,7 @@ ErrorLocked: VOID SdpSocket::CmDreqCallback(IN ib_cm_dreq_rec_t *p_cm_dreq_rec) { - SDP_PRINT(SDP_TRACE, SDP_SOCKET, ("this = 0x%p\n", this)); + SDP_PRINT(SDP_TRACE, SDP_SOCKET, ("this = 0x%p, dispatch level = %d\n", this, KeGetCurrentIrql())); ASSERT(KeGetCurrentIrql() == PASSIVE_LEVEL); NTSTATUS rc = STATUS_SUCCESS; ib_cm_drep_t cm_drep; @@ -1646,7 +1671,7 @@ SdpSocket::CmDreqCallback(IN ib_cm_dreq_rec_t *p_cm_dreq_rec) goto ErrorLocked; } - // last step is to change our state + // last step is to change our state (this will affect close socket for example) m_state = SS_CONNECTED_DREP_SENT; // We should close the connection know ??????????/ @@ -2331,6 +2356,11 @@ VOID SdpSocket::Shutdown() m_pListeningSocket->Release(); m_pListeningSocket = NULL; } + + if (m_pCloseSocketThread != NULL) { + delete m_pCloseSocketThread; + m_pCloseSocketThread = NULL; + } // Now that all ibal operations have finished we can free the memory diff --git a/trunk/ulp/sdp/kernel/SdpSocket.h b/trunk/ulp/sdp/kernel/SdpSocket.h index 653c9153..b70706c0 100644 --- a/trunk/ulp/sdp/kernel/SdpSocket.h +++ b/trunk/ulp/sdp/kernel/SdpSocket.h @@ -119,6 +119,8 @@ private: bool m_ShutdownCalled; bool m_DisconnectConnectionRecieved; + ThreadHandle* m_pCloseSocketThread; + VOID SignalShutdown(); diff --git a/trunk/ulp/sdp/todo b/trunk/ulp/sdp/todo index 07c1f1b6..0f59b5d5 100644 --- a/trunk/ulp/sdp/todo +++ b/trunk/ulp/sdp/todo @@ -33,6 +33,4 @@ USER MODE: * check the way that errors are reported to the user mode. It seems that returning an error in rc means that the output buffer won't pass out. -* make sure that the "terminator thread" is being killed before we exit. - * Check why sometimes the QP and so are not valid when you come to kill them -- 2.41.0