Browse Source

Add tests, guard against dtype!=complex

Pete Shadbolt 2 years ago
parent
commit
3e4fc54afb
4 changed files with 31 additions and 47 deletions
  1. 0
    42
      run-tests.py
  2. 0
    5
      run-tests.sh
  3. 4
    0
      src/permanent.c
  4. 27
    0
      tests/test.py

+ 0
- 42
run-tests.py View File

@@ -1,42 +0,0 @@
1
-import os, sys
2
-import numpy as np
3
-import time
4
-from matplotlib import pyplot as plt
5
-from permanent import permanent
6
-import itertools as it
7
-
8
-
9
-def permanent(a):
10
-    """ Slow way to compute the permanent """
11
-    r = range(len(a))
12
-    return sum([np.prod(a[r, p]) for p in it.permutations(r)])
13
-
14
-
15
-if __name__ == '__main__':
16
-    maxtime=1
17
-    dimensions=range(1,11)
18
-
19
-    for (function, label) in zip((permanent, perm_ryser), ("C", "Python")):
20
-        counts=[]
21
-        for dimension in dimensions:
22
-            print dimension
23
-            real=np.random.uniform(-1, 1, dimension*dimension).reshape((dimension, dimension))
24
-            imag=np.random.uniform(-1, 1, dimension*dimension).reshape((dimension, dimension))
25
-            submatrix=real+1j*imag
26
-
27
-            t=time.clock()
28
-            n=0
29
-            while time.clock()-t < maxtime:
30
-                for i in range(5):
31
-                    function(submatrix)
32
-                n+=5
33
-            counts.append(n)
34
-
35
-        plt.plot(dimensions, counts, ".-", label=label)
36
-
37
-    plt.ylabel("Number of permanents per second")
38
-    plt.xlabel("Dimension")
39
-    plt.xlim(min(dimensions), max(dimensions))
40
-    plt.legend()
41
-    plt.semilogy()
42
-    plt.savefig("out.pdf")

+ 0
- 5
run-tests.sh View File

@@ -1,5 +0,0 @@
1
-#!/bin/bash
2
-
3
-rm permanent/*.so
4
-python ./setup.py build_ext --inplace &&
5
-python ./run-tests.py

+ 4
- 0
src/permanent.c View File

@@ -48,6 +48,10 @@ static PyObject *permanent(PyObject *self, PyObject *args) {
48 48
   // Parse the input
49 49
   PyArrayObject *submatrix;
50 50
   if (!PyArg_ParseTuple(args, "O!", &PyArray_Type, &submatrix)) {return NULL;}
51
+  if (!PyArray_ISCOMPLEX(submatrix)) {
52
+      PyErr_SetString(PyExc_TypeError, "Array dtype must be `complex`.");
53
+      return NULL;
54
+  }
51 55
 
52 56
   // Compute the permanent
53 57
   npy_complex128 p = ryser(submatrix);

+ 27
- 0
tests/test.py View File

@@ -0,0 +1,27 @@
1
+import numpy as np
2
+from permanent import permanent
3
+from nose.tools import raises
4
+
5
+def test_permanent():
6
+    """ Dumb tests """
7
+    m = np.eye(10, dtype=complex)
8
+    assert permanent(m) == 1
9
+    m = np.zeros((10, 10), dtype=complex)
10
+    assert permanent(m) == 0
11
+
12
+
13
+def test_floaty():
14
+    """ More tests using a precomputed permanent """
15
+    np.random.seed(1234)
16
+    m = np.random.uniform(0, 1, 16) + 1j*np.random.uniform(0, 1, 16)
17
+    m = m.reshape(4,4)
18
+    p = permanent(m)
19
+    assert np.allclose(p, -8.766131870776363+1.072095650303524j)
20
+
21
+    
22
+@raises(TypeError)
23
+def test_error():
24
+    """ Should raise a TypeError as we are using the wrong dtype """
25
+    m = np.eye(10, dtype=float)
26
+    permanent(m)
27
+