From acad05cb13a64ec55cfa676bb0b58b2fc1af5542 Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Sat, 20 May 2017 20:19:26 -0400 Subject: [PATCH] Fix /= operator under Python 3 The Python method for /= was set as `__idiv__`, which should be `__itruediv__` under Python 3. This wasn't totally broken in that without it defined, Python constructs a new object by calling __truediv__. The operator tests, however, didn't actually test the /= operator: when I added it, I saw an extra construction, leading to the problem. This commit also includes tests for the previously untested *= operator, and adds some element-wise vector multiplication and division operators. --- include/pybind11/operators.h | 6 +++++- tests/test_operator_overloading.cpp | 8 ++++++++ tests/test_operator_overloading.py | 17 +++++++++++++++-- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/include/pybind11/operators.h b/include/pybind11/operators.h index 2e78c01a3..7d40fb565 100644 --- a/include/pybind11/operators.h +++ b/include/pybind11/operators.h @@ -25,7 +25,7 @@ enum op_id : int { op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le, op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift, op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero, - op_repr, op_truediv + op_repr, op_truediv, op_itruediv }; enum op_type : int { @@ -129,7 +129,11 @@ PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r) PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r) PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r) PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r) +#if PY_MAJOR_VERSION >= 3 +PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r) +#else PYBIND11_INPLACE_OPERATOR(idiv, operator/=, l /= r) +#endif PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r) PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r) PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r) diff --git a/tests/test_operator_overloading.cpp b/tests/test_operator_overloading.cpp index 93aea8010..4e868d939 100644 --- a/tests/test_operator_overloading.cpp +++ b/tests/test_operator_overloading.cpp @@ -39,10 +39,14 @@ public: Vector2 operator+(float value) const { return Vector2(x + value, y + value); } Vector2 operator*(float value) const { return Vector2(x * value, y * value); } Vector2 operator/(float value) const { return Vector2(x / value, y / value); } + Vector2 operator*(const Vector2 &v) const { return Vector2(x * v.x, y * v.y); } + Vector2 operator/(const Vector2 &v) const { return Vector2(x / v.x, y / v.y); } Vector2& operator+=(const Vector2 &v) { x += v.x; y += v.y; return *this; } Vector2& operator-=(const Vector2 &v) { x -= v.x; y -= v.y; return *this; } Vector2& operator*=(float v) { x *= v; y *= v; return *this; } Vector2& operator/=(float v) { x /= v; y /= v; return *this; } + Vector2& operator*=(const Vector2 &v) { x *= v.x; y *= v.y; return *this; } + Vector2& operator/=(const Vector2 &v) { x /= v.x; y /= v.y; return *this; } friend Vector2 operator+(float f, const Vector2 &v) { return Vector2(f + v.x, f + v.y); } friend Vector2 operator-(float f, const Vector2 &v) { return Vector2(f - v.x, f - v.y); } @@ -61,10 +65,14 @@ test_initializer operator_overloading([](py::module &m) { .def(py::self - float()) .def(py::self * float()) .def(py::self / float()) + .def(py::self * py::self) + .def(py::self / py::self) .def(py::self += py::self) .def(py::self -= py::self) .def(py::self *= float()) .def(py::self /= float()) + .def(py::self *= py::self) + .def(py::self /= py::self) .def(float() + py::self) .def(float() - py::self) .def(float() * py::self) diff --git a/tests/test_operator_overloading.py b/tests/test_operator_overloading.py index 02ccb9633..dd37c3497 100644 --- a/tests/test_operator_overloading.py +++ b/tests/test_operator_overloading.py @@ -16,10 +16,21 @@ def test_operator_overloading(): assert str(8 + v1) == "[9.000000, 10.000000]" assert str(8 * v1) == "[8.000000, 16.000000]" assert str(8 / v1) == "[8.000000, 4.000000]" + assert str(v1 * v2) == "[3.000000, -2.000000]" + assert str(v2 / v1) == "[3.000000, -0.500000]" - v1 += v2 + v1 += 2 * v2 + assert str(v1) == "[7.000000, 0.000000]" + v1 -= v2 + assert str(v1) == "[4.000000, 1.000000]" v1 *= 2 assert str(v1) == "[8.000000, 2.000000]" + v1 /= 16 + assert str(v1) == "[0.500000, 0.125000]" + v1 *= v2 + assert str(v1) == "[1.500000, -0.125000]" + v2 /= v1 + assert str(v2) == "[2.000000, 8.000000]" cstats = ConstructorStats.get(Vector2) assert cstats.alive() == 2 @@ -32,7 +43,9 @@ def test_operator_overloading(): '[-7.000000, -6.000000]', '[9.000000, 10.000000]', '[8.000000, 16.000000]', '[0.125000, 0.250000]', '[7.000000, 6.000000]', '[9.000000, 10.000000]', - '[8.000000, 16.000000]', '[8.000000, 4.000000]'] + '[8.000000, 16.000000]', '[8.000000, 4.000000]', + '[3.000000, -2.000000]', '[3.000000, -0.500000]', + '[6.000000, -2.000000]'] assert cstats.default_constructions == 0 assert cstats.copy_constructions == 0 assert cstats.move_constructions >= 10