diff --git a/Doc/library/math.rst b/Doc/library/math.rst index 7129525c788767..ce92e04614ed53 100644 --- a/Doc/library/math.rst +++ b/Doc/library/math.rst @@ -218,6 +218,22 @@ Number-theoretic and representation functions :meth:`x.__trunc__() `. +.. function:: comb(n, k) + + Return the binomial coefficient indexed by the pair of integers n >= k >= 0. + + It is the coefficient of kth term in polynomial expansion of the expression + (1 + x)^n. It is also known as the number of ways to choose an unordered + subset of k elements from a fixed set of n elements, usually called + *n choose k*. + + Raises :exc:`TypeError` if argument(s) are non-integer, :exc:`ValueError` + if argument(s) are negative or k > n, and :exc:`OverflowError` if k and (n-k) + are too large (beyond `LLONG_MAX`). + + .. versionadded:: 3.8 + + Note that :func:`frexp` and :func:`modf` have a different call/return pattern than their C equivalents: they take a single argument and return a pair of values, rather than returning their second return value through an 'output diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py index cb05dee0e0fd3c..812ad4e5699608 100644 --- a/Lib/test/test_math.py +++ b/Lib/test/test_math.py @@ -523,6 +523,52 @@ def testFactorialHugeInputs(self): self.assertRaises(OverflowError, math.factorial, 10**100) self.assertRaises(OverflowError, math.factorial, 1e100) + def testCombFactorial(self): + """Test (n choose k) = n! / (k! (n-k)!) when 0 <= k <= n.""" + for n in range(100): + for k in range(n+1): + self.assertEqual(math.comb(n, k), + math.factorial(n) // math.factorial(k) // math.factorial(n-k)) + + def testCombTriangle(self): + """Test (n+1 choose k+1) = (n choose k) + (n choose k+1)""" + for n in range(100): + for k in range(n): + self.assertEqual(math.comb(n + 1, k + 1), math.comb(n, k) + math.comb(n, k + 1)) + + def testCombZero(self): + """(n choose k) raises ValueError when k>n""" + for k in range(100): + for n in range(k): + self.assertRaises(ValueError, math.comb, n, k) + + def testCombOne(self): + """Test (n choose 0) = (n choose n) = 1""" + for n in range(100): + self.assertEqual(1, math.comb(n, 0)) + self.assertEqual(1, math.comb(n, n)) + + def testCombValueErrors(self): + """Test that math.comb raises ValueError on negative inputs or k>n.""" + for neg in [-1, -10**100]: + self.assertRaises(ValueError, math.comb, 0, neg) + self.assertRaises(ValueError, math.comb, neg, 0) + + for n in range(100): + for k in range(n+1, 100): + self.assertRaises(ValueError, math.comb, n, k) + + def testCombOverflow(self): + """math.comb raises OverflowError on inputs too large for C longs.""" + # min(k, n - k) > LLONG_MAX) + self.assertRaises(OverflowError, math.comb, 10**400, 10**200) + + def testCombTypeErrors(self): + """Test math.comb raises TypeError on non-int inputs.""" + for non_int in [-1e100, -1.0, math.pi, decimal.Decimal(5.2), "5"]: + self.assertRaises(TypeError, math.comb, 0, non_int) + self.assertRaises(TypeError, math.comb, non_int, 0) + def testFloor(self): self.assertRaises(TypeError, math.floor) self.assertEqual(int, type(math.floor(0.5))) diff --git a/Misc/NEWS.d/next/Library/2019-01-02-19-48-23.bpo-35431.FhG6QA.rst b/Misc/NEWS.d/next/Library/2019-01-02-19-48-23.bpo-35431.FhG6QA.rst new file mode 100644 index 00000000000000..3896e7d75fc12f --- /dev/null +++ b/Misc/NEWS.d/next/Library/2019-01-02-19-48-23.bpo-35431.FhG6QA.rst @@ -0,0 +1,3 @@ +Implement :func:`math.comb` that returns binomial coefficient, that is the +coefficient of kth term in expansion of polynomial (1 + x)^n. +Patch by Keller Fuchs and Yash Aggarwal. diff --git a/Modules/clinic/mathmodule.c.h b/Modules/clinic/mathmodule.c.h index 1806a01588c5ab..6c103a35e7166a 100644 --- a/Modules/clinic/mathmodule.c.h +++ b/Modules/clinic/mathmodule.c.h @@ -628,4 +628,42 @@ math_prod(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *k exit: return return_value; } -/*[clinic end generated code: output=96e71135dce41c48 input=a9049054013a1b77]*/ + +PyDoc_STRVAR(math_comb__doc__, +"comb($module, n, k, /)\n" +"--\n" +"\n" +"Return the binomial coefficient indexed by the pair of integers n >= k >= 0.\n" +"\n" +"It is the coefficient of kth term in polynomial expansion of the expression\n" +"(1 + x)^n. It is also known as the number of ways to choose an unordered\n" +"subset of k elements from a fixed set of n elements, usually called\n" +"*n choose k*.\n" +"\n" +"Raises a TypeError if argument(s) are non-integer and ValueError\n" +"if argument(s) are negative or k > n."); + +#define MATH_COMB_METHODDEF \ + {"comb", (PyCFunction)(void(*)(void))math_comb, METH_FASTCALL, math_comb__doc__}, + +static PyObject * +math_comb_impl(PyObject *module, PyObject *n, PyObject *k); + +static PyObject * +math_comb(PyObject *module, PyObject *const *args, Py_ssize_t nargs) +{ + PyObject *return_value = NULL; + PyObject *n; + PyObject *k; + + if (!_PyArg_CheckPositional("comb", nargs, 2, 2)) { + goto exit; + } + n = args[0]; + k = args[1]; + return_value = math_comb_impl(module, n, k); + +exit: + return return_value; +} +/*[clinic end generated code: output=333afdbd248d74d1 input=a9049054013a1b77]*/ diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index ba8423211c2b53..16d200831aa3dc 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -2706,6 +2706,151 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start) } +/*[clinic input] +math.comb + + n: object + + k: object + / + +Return the binomial coefficient indexed by the pair of integers n >= k >= 0. + +It is the coefficient of kth term in polynomial expansion of the expression +(1 + x)^n. It is also known as the number of ways to choose an unordered +subset of k elements from a fixed set of n elements, usually called +*n choose k*. + +Raises a TypeError if argument(s) are non-integer and ValueError +if argument(s) are negative or k > n. + +[clinic start generated code]*/ + +static PyObject * +math_comb_impl(PyObject *module, PyObject *n, PyObject *k) +/*[clinic end generated code: output=bd2cec8d854f3493 input=75e1a19623bae7dc]*/ +{ + if (!(PyLong_Check(n) && PyLong_Check(k))) { + PyErr_SetString(PyExc_TypeError, + "comb() only accepts integer arguments"); + return NULL; + } + n = PyNumber_Long(n); + if (n == NULL) { + return NULL; + } + k = PyNumber_Long(k); + if (k == NULL) { + Py_DECREF(n); + return NULL; + } + + PyObject *val = NULL, + *temp_obj1 = NULL, + *temp_obj2 = NULL, + *dump_var = NULL; + int overflow, cmp; + long long i, terms; + + cmp = PyObject_RichCompareBool(n, k, Py_LT); + if (cmp == 1) { + PyErr_Format(PyExc_ValueError, + "n must be an integer greater or equal to k"); + goto fail_comb; + } + else if (cmp == -1) { + goto fail_comb; + } + + /* b = min(b, a - b) */ + dump_var = PyNumber_Subtract(n, k); + if (dump_var == NULL) { + goto fail_comb; + } + cmp = PyObject_RichCompareBool(k, dump_var, Py_GT); + if (cmp == 1) { + Py_DECREF(k); + k = dump_var; + dump_var = NULL; + } + else if (cmp == -1) { + goto fail_comb; + } + else { + Py_DECREF(dump_var); + dump_var = NULL; + } + + terms = PyLong_AsLongLongAndOverflow(k, &overflow); + if (terms == -1 && PyErr_Occurred()) { + goto fail_comb; + } + else if (overflow == 1) { + PyErr_Format(PyExc_OverflowError, + "Either (n - k) or k must not exceed %lld", + LLONG_MAX); + goto fail_comb; + } + else if (overflow == -1 || terms < 0) { + PyErr_Format(PyExc_ValueError, + "k must be a positive integer"); + goto fail_comb; + } + + if (terms == 0) { + Py_DECREF(n); + Py_DECREF(k); + return PyNumber_Long(_PyLong_One); + } + + val = PyNumber_Long(n); + for (i = 1; i < terms; ++i) { + temp_obj1 = PyLong_FromSsize_t(i); + if (temp_obj1 == NULL) { + goto fail_comb; + } + temp_obj2 = PyNumber_Subtract(n, temp_obj1); + if (temp_obj2 == NULL) { + goto fail_comb; + } + dump_var = val; + val = PyNumber_Multiply(val, temp_obj2); + if (val == NULL) { + goto fail_comb; + } + Py_DECREF(dump_var); + dump_var = NULL; + Py_DECREF(temp_obj2); + temp_obj2 = PyLong_FromUnsignedLongLong((unsigned long long)(i + 1)); + if (temp_obj2 == NULL) { + goto fail_comb; + } + dump_var = val; + val = PyNumber_FloorDivide(val, temp_obj2); + if (val == NULL) { + goto fail_comb; + } + Py_DECREF(dump_var); + Py_DECREF(temp_obj1); + Py_DECREF(temp_obj2); + } + Py_DECREF(n); + Py_DECREF(k); + + return val; + +fail_comb: + Py_XDECREF(n); + Py_XDECREF(k); + Py_XDECREF(val); + Py_XDECREF(dump_var); + Py_XDECREF(temp_obj1); + Py_XDECREF(temp_obj2); + + return NULL; +} + + static PyMethodDef math_methods[] = { {"acos", math_acos, METH_O, math_acos_doc}, {"acosh", math_acosh, METH_O, math_acosh_doc}, @@ -2754,6 +2899,7 @@ static PyMethodDef math_methods[] = { {"tanh", math_tanh, METH_O, math_tanh_doc}, MATH_TRUNC_METHODDEF MATH_PROD_METHODDEF + MATH_COMB_METHODDEF {NULL, NULL} /* sentinel */ };