1313#include "pycore_pystate.h" // _PyThreadState_GET()
1414#include "pycore_runtime_init.h" // _PyRuntimeState_INIT
1515#include "pycore_sysmodule.h"
16+ #include "pyatomic.h"
1617
1718/* --------------------------------------------------------------------------
1819CAUTION
@@ -240,6 +241,30 @@ _PyRuntimeState_ReInitThreads(_PyRuntimeState *runtime)
240241static void _PyGILState_NoteThreadState (
241242 struct _gilstate_runtime_state * gilstate , PyThreadState * tstate );
242243
244+ int
245+ _PyThreadState_GetStatus (PyThreadState * tstate )
246+ {
247+ return _Py_atomic_load_int_relaxed (& tstate -> status );
248+ }
249+
250+ static int
251+ _PyThreadState_Attach (PyThreadState * tstate )
252+ {
253+ if (_Py_atomic_compare_exchange_int (
254+ & tstate -> status ,
255+ _Py_THREAD_DETACHED ,
256+ _Py_THREAD_ATTACHED )) {
257+ return 1 ;
258+ }
259+ return 0 ;
260+ }
261+
262+ static void
263+ _PyThreadState_Detach (PyThreadState * tstate )
264+ {
265+ _Py_atomic_store_int (& tstate -> status , _Py_THREAD_DETACHED );
266+ }
267+
243268PyStatus
244269_PyInterpreterState_Enable (_PyRuntimeState * runtime )
245270{
@@ -517,13 +542,14 @@ PyInterpreterState_Delete(PyInterpreterState *interp)
517542{
518543 _PyRuntimeState * runtime = interp -> runtime ;
519544 struct pyinterpreters * interpreters = & runtime -> interpreters ;
520- zapthreads (interp , 0 );
521-
522- _PyEval_FiniState (& interp -> ceval );
523545
524546 /* Delete current thread. After this, many C API calls become crashy. */
525547 _PyThreadState_Swap (& runtime -> gilstate , NULL );
526548
549+ zapthreads (interp , 0 );
550+
551+ _PyEval_FiniState (& interp -> ceval );
552+
527553 HEAD_LOCK (runtime );
528554 PyInterpreterState * * p ;
529555 for (p = & interpreters -> head ; ; p = & (* p )-> next ) {
@@ -910,6 +936,7 @@ _PyThreadState_Init(PyThreadState *tstate)
910936void
911937_PyThreadState_SetCurrent (PyThreadState * tstate )
912938{
939+ tstate -> fast_thread_id = _Py_ThreadId ();
913940 _PyGILState_NoteThreadState (& tstate -> interp -> runtime -> gilstate , tstate );
914941}
915942
@@ -1094,15 +1121,25 @@ PyThreadState_Clear(PyThreadState *tstate)
10941121/* Common code for PyThreadState_Delete() and PyThreadState_DeleteCurrent() */
10951122static void
10961123tstate_delete_common (PyThreadState * tstate ,
1097- struct _gilstate_runtime_state * gilstate )
1124+ struct _gilstate_runtime_state * gilstate ,
1125+ int is_current )
10981126{
1127+ assert (is_current ? tstate -> status == _Py_THREAD_ATTACHED
1128+ : tstate -> status != _Py_THREAD_ATTACHED );
1129+
10991130 _Py_EnsureTstateNotNULL (tstate );
11001131 PyInterpreterState * interp = tstate -> interp ;
11011132 if (interp == NULL ) {
11021133 Py_FatalError ("NULL interpreter" );
11031134 }
1104- _PyRuntimeState * runtime = interp -> runtime ;
11051135
1136+ if (gilstate -> autoInterpreterState &&
1137+ PyThread_tss_get (& gilstate -> autoTSSkey ) == tstate )
1138+ {
1139+ PyThread_tss_set (& gilstate -> autoTSSkey , NULL );
1140+ }
1141+
1142+ _PyRuntimeState * runtime = interp -> runtime ;
11061143 HEAD_LOCK (runtime );
11071144 if (tstate -> prev ) {
11081145 tstate -> prev -> next = tstate -> next ;
@@ -1115,10 +1152,8 @@ tstate_delete_common(PyThreadState *tstate,
11151152 }
11161153 HEAD_UNLOCK (runtime );
11171154
1118- if (gilstate -> autoInterpreterState &&
1119- PyThread_tss_get (& gilstate -> autoTSSkey ) == tstate )
1120- {
1121- PyThread_tss_set (& gilstate -> autoTSSkey , NULL );
1155+ if (is_current ) {
1156+ _PyThreadState_SET (NULL );
11221157 }
11231158 _PyStackChunk * chunk = tstate -> datastack_chunk ;
11241159 tstate -> datastack_chunk = NULL ;
@@ -1138,7 +1173,7 @@ _PyThreadState_Delete(PyThreadState *tstate, int check_current)
11381173 _Py_FatalErrorFormat (__func__ , "tstate %p is still current" , tstate );
11391174 }
11401175 }
1141- tstate_delete_common (tstate , gilstate );
1176+ tstate_delete_common (tstate , gilstate , 0 );
11421177 free_threadstate (tstate );
11431178}
11441179
@@ -1155,7 +1190,7 @@ _PyThreadState_DeleteCurrent(PyThreadState *tstate)
11551190{
11561191 _Py_EnsureTstateNotNULL (tstate );
11571192 struct _gilstate_runtime_state * gilstate = & tstate -> interp -> runtime -> gilstate ;
1158- tstate_delete_common (tstate , gilstate );
1193+ tstate_delete_common (tstate , gilstate , 1 );
11591194 _PyRuntimeGILState_SetThreadState (gilstate , NULL );
11601195 _PyEval_ReleaseLock (tstate );
11611196 free_threadstate (tstate );
@@ -1230,9 +1265,36 @@ PyThreadState_Get(void)
12301265PyThreadState *
12311266_PyThreadState_Swap (struct _gilstate_runtime_state * gilstate , PyThreadState * newts )
12321267{
1233- PyThreadState * oldts = _PyRuntimeGILState_GetThreadState (gilstate );
1268+ PyThreadState * oldts = _Py_current_tstate ;
1269+
1270+ #if defined(Py_DEBUG )
1271+ // The new thread-state should correspond to the current native thread
1272+ // XXX: breaks subinterpreter tests
1273+ if (newts && newts -> fast_thread_id != _Py_ThreadId ()) {
1274+ Py_FatalError ("Invalid thread state for this thread" );
1275+ }
1276+ #endif
1277+
1278+ if (oldts != NULL ) {
1279+ int status = _Py_atomic_load_int (& oldts -> status );
1280+ assert (status == _Py_THREAD_ATTACHED || status == _Py_THREAD_GC );
1281+
1282+ if (status == _Py_THREAD_ATTACHED ) {
1283+ _PyThreadState_Detach (oldts );
1284+ }
1285+ }
1286+
1287+ _Py_current_tstate = newts ;
1288+
1289+ if (newts ) {
1290+ int attached = _PyThreadState_Attach (newts );
1291+ if (!attached ) {
1292+ // _PyThreadState_GC_Park(newts);
1293+ }
1294+
1295+ assert (_Py_atomic_load_int (& newts -> status ) == _Py_THREAD_ATTACHED );
1296+ }
12341297
1235- _PyRuntimeGILState_SetThreadState (gilstate , newts );
12361298 /* It should not be possible for more than one thread state
12371299 to be used for a thread. Check this the best we can in debug
12381300 builds.
@@ -1243,8 +1305,7 @@ _PyThreadState_Swap(struct _gilstate_runtime_state *gilstate, PyThreadState *new
12431305 to it, we need to ensure errno doesn't change.
12441306 */
12451307 int err = errno ;
1246- PyThreadState * check = _PyGILState_GetThisThreadState (gilstate );
1247- if (check && check -> interp == newts -> interp && check != newts )
1308+ if (oldts && oldts -> interp == newts -> interp && oldts != newts )
12481309 Py_FatalError ("Invalid thread state for this thread" );
12491310 errno = err ;
12501311 }
0 commit comments