From 4e3d9fea74ed50a042d98f68fa35a3133482289b Mon Sep 17 00:00:00 2001 From: Eric Cousineau Date: Fri, 22 May 2020 00:43:01 -0400 Subject: [PATCH] operators: Explicitly expose `py::hash(py::self)` Add warnings about extending STL --- include/pybind11/operators.h | 5 +++++ tests/test_operator_overloading.cpp | 25 ++++++++++++++++++++++++- tests/test_operator_overloading.py | 14 ++++++++++++-- 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/include/pybind11/operators.h b/include/pybind11/operators.h index b3dd62c3b..293d5abd2 100644 --- a/include/pybind11/operators.h +++ b/include/pybind11/operators.h @@ -147,6 +147,9 @@ PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r) PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r) PYBIND11_UNARY_OPERATOR(neg, operator-, -l) PYBIND11_UNARY_OPERATOR(pos, operator+, +l) +// WARNING: This usage of `abs` should only be done for existing STL overloads. +// Adding overloads directly in to the `std::` namespace is advised against: +// https://en.cppreference.com/w/cpp/language/extending_std PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l)) PYBIND11_UNARY_OPERATOR(hash, hash, std::hash()(l)) PYBIND11_UNARY_OPERATOR(invert, operator~, (~l)) @@ -160,6 +163,8 @@ PYBIND11_UNARY_OPERATOR(float, float_, (double) l) NAMESPACE_END(detail) using detail::self; +// Add named operators so that they are accessible via `py::`. +using detail::hash; NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/tests/test_operator_overloading.cpp b/tests/test_operator_overloading.cpp index 7b111704b..52fcd3383 100644 --- a/tests/test_operator_overloading.cpp +++ b/tests/test_operator_overloading.cpp @@ -43,6 +43,13 @@ public: friend Vector2 operator-(float f, const Vector2 &v) { return Vector2(f - v.x, f - v.y); } friend Vector2 operator*(float f, const Vector2 &v) { return Vector2(f * v.x, f * v.y); } friend Vector2 operator/(float f, const Vector2 &v) { return Vector2(f / v.x, f / v.y); } + + bool operator==(const Vector2 &v) const { + return x == v.x && y == v.y; + } + bool operator!=(const Vector2 &v) const { + return x != v.x || y != v.y; + } private: float x, y; }; @@ -55,6 +62,11 @@ int operator+(const C2 &, const C2 &) { return 22; } int operator+(const C2 &, const C1 &) { return 21; } int operator+(const C1 &, const C2 &) { return 12; } +// Note: Specializing explicit within `namespace std { ... }` is done due to a +// bug in GCC<7. If you are supporting compilers later than this, consider +// specializing `using template<> struct std::hash<...>` in the global +// namespace instead, per this recommendation: +// https://en.cppreference.com/w/cpp/language/extending_std#Adding_template_specializations namespace std { template<> struct hash { @@ -63,6 +75,11 @@ namespace std { }; } +// Not a good abs function, but easy to test. +std::string abs(const Vector2&) { + return "abs(Vector2)"; +} + // MSVC warns about unknown pragmas, and warnings are errors. #ifndef _MSC_VER #pragma GCC diagnostic push @@ -107,7 +124,13 @@ TEST_SUBMODULE(operators, m) { .def(float() / py::self) .def(-py::self) .def("__str__", &Vector2::toString) - .def(hash(py::self)) + .def("__repr__", &Vector2::toString) + .def(py::self == py::self) + .def(py::self != py::self) + .def(py::hash(py::self)) + // N.B. See warning about usage of `py::detail::abs(py::self)` in + // `operators.h`. + .def("__abs__", [](const Vector2& v) { return abs(v); }) ; m.attr("Vector") = m.attr("Vector2"); diff --git a/tests/test_operator_overloading.py b/tests/test_operator_overloading.py index f283f5b3a..1cee29889 100644 --- a/tests/test_operator_overloading.py +++ b/tests/test_operator_overloading.py @@ -6,6 +6,9 @@ from pybind11_tests import ConstructorStats def test_operator_overloading(): v1 = m.Vector2(1, 2) v2 = m.Vector(3, -1) + v3 = m.Vector2(1, 2) # Same value as v1, but different instance. + assert v1 is not v3 + assert str(v1) == "[1.000000, 2.000000]" assert str(v2) == "[3.000000, -1.000000]" @@ -24,7 +27,11 @@ def test_operator_overloading(): assert str(v1 * v2) == "[3.000000, -2.000000]" assert str(v2 / v1) == "[3.000000, -0.500000]" + assert v1 == v3 + assert v1 != v2 assert hash(v1) == 4 + # TODO(eric.cousineau): Make this work. + # assert abs(v1) == "abs(Vector2)" v1 += 2 * v2 assert str(v1) == "[7.000000, 0.000000]" @@ -40,14 +47,17 @@ def test_operator_overloading(): assert str(v2) == "[2.000000, 8.000000]" cstats = ConstructorStats.get(m.Vector2) - assert cstats.alive() == 2 + assert cstats.alive() == 3 del v1 - assert cstats.alive() == 1 + assert cstats.alive() == 2 del v2 + assert cstats.alive() == 1 + del v3 assert cstats.alive() == 0 assert cstats.values() == [ '[1.000000, 2.000000]', '[3.000000, -1.000000]', + '[1.000000, 2.000000]', '[-3.000000, 1.000000]', '[4.000000, 1.000000]', '[-2.000000, 3.000000]',