Explorar el Código

Add tests, guard against dtype!=complex

master
Pete Shadbolt hace 8 años
padre
commit
3e4fc54afb
Se han modificado 4 ficheros con 31 adiciones y 47 borrados
  1. +0
    -42
      run-tests.py
  2. +0
    -5
      run-tests.sh
  3. +4
    -0
      src/permanent.c
  4. +27
    -0
      tests/test.py

+ 0
- 42
run-tests.py Ver fichero

@@ -1,42 +0,0 @@
import os, sys
import numpy as np
import time
from matplotlib import pyplot as plt
from permanent import permanent
import itertools as it


def permanent(a):
""" Slow way to compute the permanent """
r = range(len(a))
return sum([np.prod(a[r, p]) for p in it.permutations(r)])


if __name__ == '__main__':
maxtime=1
dimensions=range(1,11)

for (function, label) in zip((permanent, perm_ryser), ("C", "Python")):
counts=[]
for dimension in dimensions:
print dimension
real=np.random.uniform(-1, 1, dimension*dimension).reshape((dimension, dimension))
imag=np.random.uniform(-1, 1, dimension*dimension).reshape((dimension, dimension))
submatrix=real+1j*imag

t=time.clock()
n=0
while time.clock()-t < maxtime:
for i in range(5):
function(submatrix)
n+=5
counts.append(n)

plt.plot(dimensions, counts, ".-", label=label)

plt.ylabel("Number of permanents per second")
plt.xlabel("Dimension")
plt.xlim(min(dimensions), max(dimensions))
plt.legend()
plt.semilogy()
plt.savefig("out.pdf")

+ 0
- 5
run-tests.sh Ver fichero

@@ -1,5 +0,0 @@
#!/bin/bash

rm permanent/*.so
python ./setup.py build_ext --inplace &&
python ./run-tests.py

+ 4
- 0
src/permanent.c Ver fichero

@@ -48,6 +48,10 @@ static PyObject *permanent(PyObject *self, PyObject *args) {
// Parse the input
PyArrayObject *submatrix;
if (!PyArg_ParseTuple(args, "O!", &PyArray_Type, &submatrix)) {return NULL;}
if (!PyArray_ISCOMPLEX(submatrix)) {
PyErr_SetString(PyExc_TypeError, "Array dtype must be `complex`.");
return NULL;
}

// Compute the permanent
npy_complex128 p = ryser(submatrix);


+ 27
- 0
tests/test.py Ver fichero

@@ -0,0 +1,27 @@
import numpy as np
from permanent import permanent
from nose.tools import raises

def test_permanent():
""" Dumb tests """
m = np.eye(10, dtype=complex)
assert permanent(m) == 1
m = np.zeros((10, 10), dtype=complex)
assert permanent(m) == 0


def test_floaty():
""" More tests using a precomputed permanent """
np.random.seed(1234)
m = np.random.uniform(0, 1, 16) + 1j*np.random.uniform(0, 1, 16)
m = m.reshape(4,4)
p = permanent(m)
assert np.allclose(p, -8.766131870776363+1.072095650303524j)

@raises(TypeError)
def test_error():
""" Should raise a TypeError as we are using the wrong dtype """
m = np.eye(10, dtype=float)
permanent(m)

Cargando…
Cancelar
Guardar