diff --git a/abp/fancy.py b/abp/fancy.py index 32ea3b2..19f0a2d 100644 --- a/abp/fancy.py +++ b/abp/fancy.py @@ -10,11 +10,7 @@ import util class GraphState(graphstate.GraphState, nx.Graph): def __init__(self, *args, **kwargs): - if args and type(args[0]) == nx.Graph: - graphstate.GraphState.__init__(self) - self.from_nx(args[0]) - else: - graphstate.GraphState.__init__(self, *args, **kwargs) + graphstate.GraphState.__init__(self, *args, **kwargs) self.connect_to_server() def connect_to_server(self, uri = "ws://localhost:5000"): @@ -25,15 +21,6 @@ class GraphState(graphstate.GraphState, nx.Graph): except: #TODO: bad practice self.ws = None - def from_nx(self, g): - """ Clone from a networkx graph. Hacky af """ - self.adj = g.adj.copy() - self.node = g.node.copy() - # TODO: hacky af - for key, value in self.node.items(): - self.node[key]["vop"] = clifford.by_name["identity"] - - def shutdown(self): """ Close the connection to the websocket """ if not self.ws: diff --git a/abp/graphstate.py b/abp/graphstate.py index 703f983..5b19447 100644 --- a/abp/graphstate.py +++ b/abp/graphstate.py @@ -16,10 +16,10 @@ class GraphState(object): Internally it uses the same dictionary-of-dictionaries data structure as ``networkx``. """ - def __init__(self, nodes=[], deterministic=False, vop="identity"): + def __init__(self, data=(), deterministic=False, vop="identity"): """ Construct a ``GraphState`` - :param nodes: An iterable of nodes used to construct the graph, or an integer -- the number of nodes. + :param data: An iterable of nodes used to construct the graph, or an integer -- the number of nodes, or a ``nx.Graph``. :param deterministic: If ``True``, the behaviour of the graph is deterministic up to but not including the choice of measurement outcome. This is slightly less efficient, but useful for testing. If ``False``, the specific graph representation will sometimes be random -- of course, all possible representations still map to the same state vector. :param vop: The default VOP for new qubits. Setting ``vop="identity"`` initializes qubits in :math:`|+\\rangle`. Setting ``vop="hadamard"`` initializes qubits in :math:`|0\\rangle`. """ @@ -27,11 +27,20 @@ class GraphState(object): self.deterministic = deterministic self.adj, self.node = {}, {} try: - for n in nodes: - self._add_node(n, vop=vop) - except TypeError: - for n in range(nodes): - self._add_node(n, vop=vop) + # Cloning from a networkx graph + self.adj = data.adj.copy() + self.node = data.node.copy() + for key, value in self.node.items(): + self.node[key]["vop"] = data.node[key].get("vop", clifford.by_name["identity"]) + except AttributeError: + try: + # Provided with a list of node names? + for n in data: + self._add_node(n, vop=vop) + except TypeError: + # Provided with an integer? + for n in range(data): + self._add_node(n, vop=vop) def _add_node(self, node, **kwargs): """ Add a node. By default, nodes are initialized with ``vop=``:math:`I`, i.e. they are in the :math:`|+\\rangle` state. diff --git a/tests/test_fancy.py b/tests/test_fancy.py index 636eecf..f1fba44 100644 --- a/tests/test_fancy.py +++ b/tests/test_fancy.py @@ -41,3 +41,5 @@ def test_from_nx(): assert psi.node[0]["vop"] == 0 assert len(psi.edges()) > 0 psi.measure(0, "px", detail=True) + + psi = fancy.GraphState(nx.Graph(((0, 1),))) diff --git a/tests/test_graphstate.py b/tests/test_graphstate.py index 0f6e3b5..889bfd2 100644 --- a/tests/test_graphstate.py +++ b/tests/test_graphstate.py @@ -3,6 +3,7 @@ import mock import random import numpy as np from tqdm import tqdm +import networkx as nx REPEATS = 100 DEPTH = 100 @@ -121,3 +122,12 @@ def test_stabilizer_state_multiqubit(n=6): b = mock.circuit_to_state(mock.CircuitModelWrapper, n, circuit) assert a.to_state_vector() == b + +def test_from_nx(): + """ Creating from a networkx graph """ + g = nx.random_geometric_graph(100, 2) + psi = GraphState(g) + assert len(psi.node) == 100 + + psi = GraphState(nx.Graph(((0, 1),))) +