Anders and Briegel in Python
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

125 lines
3.9KB

  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. """
  4. Exposes a few basic QI operators
  5. And a circuit-model simulator
  6. """
  7. import numpy as np
  8. import itertools as it
  9. def hermitian_conjugate(u):
  10. """ Shortcut to the Hermitian conjugate """
  11. return np.conjugate(np.transpose(u))
# Constants
# 1/sqrt(2): normalisation factor used by several gates below.
ir2 = 1 / np.sqrt(2)

# Operators
# All single-qubit operators below are 2x2 complex arrays.
# NOTE(review): `id` shadows the Python builtin `id()` inside this module;
# renaming it would break any code that imports it under this name.
id = np.array(np.eye(2, dtype=complex))
# Pauli X, Y and Z.
px = np.array([[0, 1], [1, 0]], dtype=complex)
py = np.array([[0, -1j], [1j, 0]], dtype=complex)
pz = np.array([[1, 0], [0, -1]], dtype=complex)
# Hadamard, bound to both a short and a long name.
ha = hadamard = np.array([[1, 1], [1, -1]], dtype=complex) * ir2
# Phase gate diag(1, i); note ph @ ph == pz.
ph = np.array([[1, 0], [0, 1j]], dtype=complex)
# T gate diag(1, e^{i*pi/4}); not included in `common_us` below.
t = np.array([[1, 0], [0, np.exp(1j * np.pi / 4)]], dtype=complex)
# Square roots of +/-iP for each Pauli P:
#   sqP  = (I + iP) / sqrt(2), so sqP  @ sqP  == +iP
#   msqP = (I - iP) / sqrt(2), so msqP @ msqP == -iP
# The literal forms (including signed zeros) are kept exactly as written.
sqx = np.array(
    [[1. + 0.j, -0. + 1.j], [-0. + 1.j, 1. - 0.j]], dtype=complex) * ir2
msqx = np.array(
    [[1. + 0.j, 0. - 1.j], [0. - 1.j, 1. - 0.j]], dtype=complex) * ir2
sqy = np.array(
    [[1. + 0.j, 1. + 0.j], [-1. - 0.j, 1. - 0.j]], dtype=complex) * ir2
msqy = np.array(
    [[1. + 0.j, -1. - 0.j], [1. + 0.j, 1. - 0.j]], dtype=complex) * ir2
sqz = np.array(
    [[1. + 1.j, 0. + 0.j], [0. + 0.j, 1. - 1.j]], dtype=complex) * ir2
msqz = np.array(
    [[1. - 1.j, 0. + 0.j], [0. + 0.j, 1. + 1.j]], dtype=complex) * ir2

# CZ gate
# 4x4 diag(1, 1, 1, -1): only the |11> amplitude changes sign.
cz = np.array(np.eye(4), dtype=complex)
cz[3, 3] = -1

# States
# Single-qubit column vectors (shape (2, 1)).
zero = np.array([[1], [0]], dtype=complex)
one = np.array([[0], [1]], dtype=complex)
plus = np.array([[1], [1]], dtype=complex) / np.sqrt(2)
# Two-qubit column vectors (shape (4, 1)): CZ|+>|+> (bonded) and |+>|+>.
bond = cz.dot(np.kron(plus, plus))
nobond = np.kron(plus, plus)

# Labelling stuff
# `common_us` and `names` are parallel tuples: the i-th name labels the i-th
# matrix, so their orders must be kept in step.
common_us = id, px, py, pz, ha, ph, sqz, msqz, sqy, msqy, sqx, msqx
names = "identity", "px", "py", "pz", "hadamard", "phase", "sqz", "msqz", "sqy", "msqy", "sqx", "msqx"
# Lookup from label (e.g. "hadamard") to matrix.
by_name = dict(zip(names, common_us))
paulis = px, py, pz
operators = id, px, py, pz
  49. def normalize_global_phase(m):
  50. """ Normalize the global phase of a matrix """
  51. v = (x for x in m.flatten() if np.abs(x) > 0.001).next()
  52. phase = np.arctan2(v.imag, v.real) % np.pi
  53. rot = np.exp(-1j * phase)
  54. return rot * m if rot * v > 0 else -rot * m
  55. class CircuitModel(object):
  56. def __init__(self, nqubits):
  57. self.nqubits = nqubits
  58. self.d = 2 ** nqubits
  59. self.state = np.zeros((self.d, 1), dtype=complex)
  60. self.state[0, 0] = 1
  61. def act_cz(self, control, target):
  62. """ Act a CU somewhere """
  63. control = 1 << control
  64. target = 1 << target
  65. for i in xrange(self.d):
  66. if (i & control) and (i & target):
  67. self.state[i, 0] *= -1
  68. def act_cnot(self, control, target):
  69. """ Act a CNOT """
  70. self.act_hadamard(target)
  71. self.act_cz(control, target)
  72. self.act_hadamard(target)
  73. def act_hadamard(self, qubit):
  74. """ Act a hadamard somewhere """
  75. where = 1 << qubit
  76. output = np.zeros((self.d, 1), dtype=complex)
  77. for i, v in enumerate(self.state):
  78. q = int(i & where > 0)
  79. output[i] += v * ha[q, q]
  80. output[i ^ where] += v * ha[int(not q), q]
  81. self.state = output
  82. def act_local_rotation(self, qubit, u):
  83. """ Act a local unitary somwhere """
  84. where = 1 << qubit
  85. output = np.zeros((self.d, 1), dtype=complex)
  86. for i, v in enumerate(self.state):
  87. q = int(i & where > 0)
  88. output[i] += v * u[q, q] # TODO this is probably wrong
  89. output[i ^ where] += v * u[int(not q), q]
  90. self.state = output
  91. def __eq__(self, other):
  92. """ Check whether two quantum states are the same or not
  93. UP TO A GLOBAL PHASE """
  94. a = normalize_global_phase(self.state)
  95. b = normalize_global_phase(other.state)
  96. return np.allclose(a, b)
  97. def __str__(self):
  98. s = ""
  99. for i in range(self.d):
  100. label = bin(i)[2:].rjust(self.nqubits, "0")
  101. if abs(self.state[i, 0]) > 0.00001:
  102. s += "|{}>: {:.2f}\n".format(label, self.state[i, 0])
  103. return s