瀏覽代碼

Improve clifford.py, add measurement lookup table

master
Pete Shadbolt 8 年之前
父節點
當前提交
092afb82f6
共有 3 個檔案被更改,包括 49 行新增55 行删除
  1. +45
    -43
      abp/clifford.py
  2. +2
    -6
      tests/test_against_anders_and_briegel.py
  3. +2
    -6
      tests/test_measurement.py

+ 45
- 43
abp/clifford.py 查看文件

@@ -17,15 +17,12 @@ decompositions = ("xxxx", "xx", "zzxx", "zz", "zxx", "z", "zzz", "xxz",


def conjugate(vop, transform): def conjugate(vop, transform):
""" Returns transform * vop * transform^dagger and a phase in {+1, -1} """ """ Returns transform * vop * transform^dagger and a phase in {+1, -1} """
assert vop in set(xrange(4))
op = times_table[transform, vop]
op = times_table[op, conjugation_table[transform]]
is_id_or_vop = (transform % 4 == 0) or (transform % 4 == vop)
is_non_pauli = (transform >= 4) and (transform <= 15)
phase = ((-1, 1), (1, -1))[is_id_or_vop][is_non_pauli]
if vop == 0:
phase = 1
return op, phase
return measurement_table[vop, transform]

def use_old_cz():
""" Use the CZ table from A&B's code """
global cz_table
from anders_cz import cz_table


def get_name(i): def get_name(i):
@@ -109,7 +106,8 @@ def get_measurement_table():
Compute a table of transform * operation * transform^dagger Compute a table of transform * operation * transform^dagger
This is pretty unintelligible right now, we should probably compute the phase from unitaries instead This is pretty unintelligible right now, we should probably compute the phase from unitaries instead
""" """
for operation_index, transform_index in itertools.product(range(4), range(24)):
measurement_table = np.zeros((4, 24, 2), dtype=complex)
for vop, transform in it.product(range(4), range(24)):
assert vop in set(xrange(4)) assert vop in set(xrange(4))
op = times_table[transform, vop] op = times_table[transform, vop]
op = times_table[op, conjugation_table[transform]] op = times_table[op, conjugation_table[transform]]
@@ -118,9 +116,8 @@ def get_measurement_table():
phase = ((-1, 1), (1, -1))[is_id_or_vop][is_non_pauli] phase = ((-1, 1), (1, -1))[is_id_or_vop][is_non_pauli]
if vop == 0: if vop == 0:
phase = 1 phase = 1
return op, phase

measurement_table[vop, transform] = [op, phase]
return measurement_table




def get_commuters(unitaries): def get_commuters(unitaries):
@@ -166,47 +163,52 @@ def write_javascript_tables():
.format(json.dumps(by_name))) .format(json.dumps(by_name)))
f.write("};"); f.write("};");


def temp(filename):
""" Get a temporary path """
tempdir = tempfile.gettempdir()
return os.path.join(tempdir, filename)


# First try to load tables from cache. If that fails, build them from
# scratch and store in /tmp/
tempdir = tempfile.gettempdir()
temp = lambda filename: os.path.join(tempdir, filename)
try:
if __name__ == "__main__":
raise IOError

# Parse command line args
# parser = argparse.ArgumentParser()
# parser.add_argument("-l", "--legacy", help="Use legacy CZ table", action="store_true", default=False)
# args = parser.parse_args()
legacy_cz = False

unitaries = np.load(temp("unitaries.npy"))
conjugation_table = np.load(temp("conjugation_table.npy"))
times_table = np.load(temp("times_table.npy"))
if legacy_cz:
import anders_cz
cz_table = anders_cz.cz_table
else:
cz_table = np.load(temp("cz_table.npy"))

with open(temp("by_name.json")) as f:
by_name = json.load(f)

except IOError:
# Spend time building the tables
def compute_everything():
""" Compute all lookup tables """
global unitaries, by_name, conjugation_table, times_table, cz_table, measurement_table
unitaries = get_unitaries() unitaries = get_unitaries()
by_name = get_by_name(unitaries) by_name = get_by_name(unitaries)
conjugation_table = get_conjugation_table(unitaries) conjugation_table = get_conjugation_table(unitaries)
times_table = get_times_table(unitaries) times_table = get_times_table(unitaries)
cz_table = get_cz_table(unitaries) cz_table = get_cz_table(unitaries)
measurement_table = get_measurement_table()


# Write it all to disk
def save_to_disk():
""" Save all tables to disk """
global unitaries, by_name, conjugation_table, times_table, cz_table, measurement_table
np.save(temp("unitaries.npy"), unitaries) np.save(temp("unitaries.npy"), unitaries)
np.save(temp("conjugation_table.npy"), conjugation_table) np.save(temp("conjugation_table.npy"), conjugation_table)
np.save(temp("times_table.npy"), times_table) np.save(temp("times_table.npy"), times_table)
np.save(temp("cz_table.npy"), cz_table) np.save(temp("cz_table.npy"), cz_table)
np.save(temp("measurement_table.npy"), measurement_table)
write_javascript_tables() write_javascript_tables()

with open(temp("by_name.json"), "wb") as f: with open(temp("by_name.json"), "wb") as f:
json.dump(by_name, f) json.dump(by_name, f)

def load_from_disk():
""" Load all the tables from disk """
global unitaries, by_name, conjugation_table, times_table, cz_table, measurement_table
unitaries = np.load(temp("unitaries.npy"))
conjugation_table = np.load(temp("conjugation_table.npy"))
times_table = np.load(temp("times_table.npy"))
measurement_table = np.load(temp("measurement_table.npy"))
cz_table = np.load(temp("cz_table.npy"))

with open(temp("by_name.json")) as f:
by_name = json.load(f)


if __name__ == "__main__":
compute_everything()
save_to_disk()
else:
try:
load_from_disk()
except IOError:
compute_everything()
save_to_disk()

+ 2
- 6
tests/test_against_anders_and_briegel.py 查看文件

@@ -48,9 +48,7 @@ def test_local_rotations():
def test_cz_table(N=10): def test_cz_table(N=10):
""" Test the CZ table """ """ Test the CZ table """


# Don't test if we are using Pete's CZ table - doesn't make sense
if not clifford.legacy_cz:
return
clifford.use_old_cz()


for j in range(24): for j in range(24):
a = graphsim.GraphRegister(2) a = graphsim.GraphRegister(2)
@@ -92,9 +90,7 @@ def test_with_cphase_gates_hadamard_only(N=10):
def test_all(N=10): def test_all(N=10):
""" Test all gates at random """ """ Test all gates at random """


# Don't test if we are using Pete's CZ table - doesn't make sense
if not clifford.legacy_cz:
return
clifford.use_old_cz()


a = graphsim.GraphRegister(N) a = graphsim.GraphRegister(N)
b = GraphState() b = GraphState()


+ 2
- 6
tests/test_measurement.py 查看文件

@@ -1,14 +1,10 @@
from abp import GraphState from abp import GraphState


def test_z_measurement(): def test_z_measurement():
g = GraphState(0)
g = GraphState([0])
assert g.measure_z(0, 0) == 0 assert g.measure_z(0, 0) == 0
assert g.measure_z(0, 1) == 1 assert g.measure_z(0, 1) == 1
assert not all(g.measure_z(0) == 0 for i in range(100))

g.act_hadamard(0)
print g
assert all(g.measure_z(0) == 1 for i in range(100))
# TODO







Loading…
取消
儲存