diff --git a/CMakeLists.txt b/CMakeLists.txt index a1a6f4f4d..ba5f665a2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -205,6 +205,7 @@ set(PYBIND11_HEADERS include/pybind11/conduit/pybind11_conduit_v1.h include/pybind11/conduit/pybind11_platform_abi_id.h include/pybind11/conduit/wrap_include_python_h.h + include/pybind11/critical_section.h include/pybind11/options.h include/pybind11/eigen.h include/pybind11/eigen/common.h diff --git a/docs/advanced/misc.rst b/docs/advanced/misc.rst index b8cb1923e..42cfc69f8 100644 --- a/docs/advanced/misc.rst +++ b/docs/advanced/misc.rst @@ -295,22 +295,29 @@ This module is sub-interpreter safe, for both ``shared_gil`` ("legacy") and function concurrently from different threads. This is safe because each sub-interpreter's GIL protects it's own Python objects from concurrent access. -However, the module is no longer free-threading safe, for the same reason as before, because the -calculation is not synchronized. We can synchronize it using a Python critical section. +However, the module is no longer free-threading safe, for the same reason as +before, because the calculation is not synchronized. We can synchronize it +using a Python critical section. This will do nothing if not in free-threaded +Python. You can have it lock one or two Python objects. You cannot nest it. +(Note: In Python 3.13t, Python re-locks if you enter a critical section again, +which happens in various places. This was optimized away in 3.14+. Use a +``std::mutex`` instead if this is a problem). .. code-block:: cpp - :emphasize-lines: 1,5,10 + :emphasize-lines: 1,4,8 + + #include + // ... PYBIND11_MODULE(example, m, py::multiple_interpreters::per_interpreter_gil(), py::mod_gil_not_used()) { m.def("calc_next", []() { size_t old; py::dict g = py::globals(); - Py_BEGIN_CRITICAL_SECTION(g); + py::scoped_critical_section guard(g); if (!g.contains("myseed")) g["myseed"] = 0; old = g["myseed"]; g["myseed"] = (old + 1) * 10; - Py_END_CRITICAL_SECTION(); return old; }); } diff --git a/include/pybind11/critical_section.h b/include/pybind11/critical_section.h new file mode 100644 index 000000000..e94ca765c --- /dev/null +++ b/include/pybind11/critical_section.h @@ -0,0 +1,50 @@ +// Copyright (c) 2016-2025 The Pybind Development Team. +// All rights reserved. Use of this source code is governed by a +// BSD-style license that can be found in the LICENSE file. + +#pragma once + +#include "pytypes.h" + +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +/// This does not do anything if there's a GIL. On free-threaded Python, +/// it locks an object. This uses the CPython API, which has limits +class scoped_critical_section { +public: +#ifdef Py_GIL_DISABLED + explicit scoped_critical_section(handle obj) : has2(false) { + PyCriticalSection_Begin(§ion, obj.ptr()); + } + + scoped_critical_section(handle obj1, handle obj2) : has2(true) { + PyCriticalSection2_Begin(§ion2, obj1.ptr(), obj2.ptr()); + } + + ~scoped_critical_section() { + if (has2) { + PyCriticalSection2_End(§ion2); + } else { + PyCriticalSection_End(§ion); + } + } +#else + explicit scoped_critical_section(handle) {}; + scoped_critical_section(handle, handle) {}; + ~scoped_critical_section() = default; +#endif + + scoped_critical_section(const scoped_critical_section &) = delete; + scoped_critical_section &operator=(const scoped_critical_section &) = delete; + +private: +#ifdef Py_GIL_DISABLED + bool has2; + union { + PyCriticalSection section; + PyCriticalSection2 section2; + }; +#endif +}; + +PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/tests/extra_python_package/test_files.py b/tests/extra_python_package/test_files.py index 154c31bc5..63e59f65a 100644 --- a/tests/extra_python_package/test_files.py +++ b/tests/extra_python_package/test_files.py @@ -44,6 +44,7 @@ main_headers = { "include/pybind11/chrono.h", "include/pybind11/common.h", "include/pybind11/complex.h", + "include/pybind11/critical_section.h", "include/pybind11/eigen.h", "include/pybind11/embed.h", "include/pybind11/eval.h", diff --git a/tests/test_embed/test_interpreter.cpp b/tests/test_embed/test_interpreter.cpp index e555c0d70..3654708d2 100644 --- a/tests/test_embed/test_interpreter.cpp +++ b/tests/test_embed/test_interpreter.cpp @@ -1,3 +1,4 @@ +#include #include // Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to @@ -365,15 +366,11 @@ TEST_CASE("Threads") { #ifdef Py_GIL_DISABLED # if PY_VERSION_HEX < 0x030E0000 std::lock_guard lock(mutex); - locals["count"] = locals["count"].cast() + 1; # else - Py_BEGIN_CRITICAL_SECTION(locals.ptr()); - locals["count"] = locals["count"].cast() + 1; - Py_END_CRITICAL_SECTION(); + py::scoped_critical_section lock(locals); # endif -#else - locals["count"] = locals["count"].cast() + 1; #endif + locals["count"] = locals["count"].cast() + 1; }); }