Browse Source

Improve clifford.py, add measurement lookup table

master
Pete Shadbolt 8 years ago
parent
commit
092afb82f6
3 changed files with 49 additions and 55 deletions
  1. +45
    -43
      abp/clifford.py
  2. +2
    -6
      tests/test_against_anders_and_briegel.py
  3. +2
    -6
      tests/test_measurement.py

+ 45
- 43
abp/clifford.py View File

@@ -17,15 +17,12 @@ decompositions = ("xxxx", "xx", "zzxx", "zz", "zxx", "z", "zzz", "xxz",


def conjugate(vop, transform): def conjugate(vop, transform):
""" Returns transform * vop * transform^dagger and a phase in {+1, -1} """ """ Returns transform * vop * transform^dagger and a phase in {+1, -1} """
assert vop in set(xrange(4))
op = times_table[transform, vop]
op = times_table[op, conjugation_table[transform]]
is_id_or_vop = (transform % 4 == 0) or (transform % 4 == vop)
is_non_pauli = (transform >= 4) and (transform <= 15)
phase = ((-1, 1), (1, -1))[is_id_or_vop][is_non_pauli]
if vop == 0:
phase = 1
return op, phase
return measurement_table[vop, transform]

def use_old_cz():
""" Use the CZ table from A&B's code """
global cz_table
from anders_cz import cz_table


def get_name(i): def get_name(i):
@@ -109,7 +106,8 @@ def get_measurement_table():
Compute a table of transform * operation * transform^dagger Compute a table of transform * operation * transform^dagger
This is pretty unintelligible right now, we should probably compute the phase from unitaries instead This is pretty unintelligible right now, we should probably compute the phase from unitaries instead
""" """
for operation_index, transform_index in itertools.product(range(4), range(24)):
measurement_table = np.zeros((4, 24, 2), dtype=complex)
for vop, transform in it.product(range(4), range(24)):
assert vop in set(xrange(4)) assert vop in set(xrange(4))
op = times_table[transform, vop] op = times_table[transform, vop]
op = times_table[op, conjugation_table[transform]] op = times_table[op, conjugation_table[transform]]
@@ -118,9 +116,8 @@ def get_measurement_table():
phase = ((-1, 1), (1, -1))[is_id_or_vop][is_non_pauli] phase = ((-1, 1), (1, -1))[is_id_or_vop][is_non_pauli]
if vop == 0: if vop == 0:
phase = 1 phase = 1
return op, phase

measurement_table[vop, transform] = [op, phase]
return measurement_table




def get_commuters(unitaries): def get_commuters(unitaries):
@@ -166,47 +163,52 @@ def write_javascript_tables():
.format(json.dumps(by_name))) .format(json.dumps(by_name)))
f.write("};"); f.write("};");


def temp(filename):
""" Get a temporary path """
tempdir = tempfile.gettempdir()
return os.path.join(tempdir, filename)


# First try to load tables from cache. If that fails, build them from
# scratch and store in /tmp/
tempdir = tempfile.gettempdir()
temp = lambda filename: os.path.join(tempdir, filename)
try:
if __name__ == "__main__":
raise IOError

# Parse command line args
# parser = argparse.ArgumentParser()
# parser.add_argument("-l", "--legacy", help="Use legacy CZ table", action="store_true", default=False)
# args = parser.parse_args()
legacy_cz = False

unitaries = np.load(temp("unitaries.npy"))
conjugation_table = np.load(temp("conjugation_table.npy"))
times_table = np.load(temp("times_table.npy"))
if legacy_cz:
import anders_cz
cz_table = anders_cz.cz_table
else:
cz_table = np.load(temp("cz_table.npy"))

with open(temp("by_name.json")) as f:
by_name = json.load(f)

except IOError:
# Spend time building the tables
def compute_everything():
""" Compute all lookup tables """
global unitaries, by_name, conjugation_table, times_table, cz_table, measurement_table
unitaries = get_unitaries() unitaries = get_unitaries()
by_name = get_by_name(unitaries) by_name = get_by_name(unitaries)
conjugation_table = get_conjugation_table(unitaries) conjugation_table = get_conjugation_table(unitaries)
times_table = get_times_table(unitaries) times_table = get_times_table(unitaries)
cz_table = get_cz_table(unitaries) cz_table = get_cz_table(unitaries)
measurement_table = get_measurement_table()


# Write it all to disk
def save_to_disk():
""" Save all tables to disk """
global unitaries, by_name, conjugation_table, times_table, cz_table, measurement_table
np.save(temp("unitaries.npy"), unitaries) np.save(temp("unitaries.npy"), unitaries)
np.save(temp("conjugation_table.npy"), conjugation_table) np.save(temp("conjugation_table.npy"), conjugation_table)
np.save(temp("times_table.npy"), times_table) np.save(temp("times_table.npy"), times_table)
np.save(temp("cz_table.npy"), cz_table) np.save(temp("cz_table.npy"), cz_table)
np.save(temp("measurement_table.npy"), measurement_table)
write_javascript_tables() write_javascript_tables()

with open(temp("by_name.json"), "wb") as f: with open(temp("by_name.json"), "wb") as f:
json.dump(by_name, f) json.dump(by_name, f)

def load_from_disk():
""" Load all the tables from disk """
global unitaries, by_name, conjugation_table, times_table, cz_table, measurement_table
unitaries = np.load(temp("unitaries.npy"))
conjugation_table = np.load(temp("conjugation_table.npy"))
times_table = np.load(temp("times_table.npy"))
measurement_table = np.load(temp("measurement_table.npy"))
cz_table = np.load(temp("cz_table.npy"))

with open(temp("by_name.json")) as f:
by_name = json.load(f)


if __name__ == "__main__":
compute_everything()
save_to_disk()
else:
try:
load_from_disk()
except IOError:
compute_everything()
save_to_disk()

+ 2
- 6
tests/test_against_anders_and_briegel.py View File

@@ -48,9 +48,7 @@ def test_local_rotations():
def test_cz_table(N=10): def test_cz_table(N=10):
""" Test the CZ table """ """ Test the CZ table """


# Don't test if we are using Pete's CZ table - doesn't make sense
if not clifford.legacy_cz:
return
clifford.use_old_cz()


for j in range(24): for j in range(24):
a = graphsim.GraphRegister(2) a = graphsim.GraphRegister(2)
@@ -92,9 +90,7 @@ def test_with_cphase_gates_hadamard_only(N=10):
def test_all(N=10): def test_all(N=10):
""" Test all gates at random """ """ Test all gates at random """


# Don't test if we are using Pete's CZ table - doesn't make sense
if not clifford.legacy_cz:
return
clifford.use_old_cz()


a = graphsim.GraphRegister(N) a = graphsim.GraphRegister(N)
b = GraphState() b = GraphState()


+ 2
- 6
tests/test_measurement.py View File

@@ -1,14 +1,10 @@
from abp import GraphState from abp import GraphState


def test_z_measurement(): def test_z_measurement():
g = GraphState(0)
g = GraphState([0])
assert g.measure_z(0, 0) == 0 assert g.measure_z(0, 0) == 0
assert g.measure_z(0, 1) == 1 assert g.measure_z(0, 1) == 1
assert not all(g.measure_z(0) == 0 for i in range(100))

g.act_hadamard(0)
print g
assert all(g.measure_z(0) == 1 for i in range(100))
# TODO







Loading…
Cancel
Save