diff --git a/lib.py b/lib.py index be4862e..6decf7f 100644 --- a/lib.py +++ b/lib.py @@ -602,8 +602,8 @@ class HermitianMatrix(SquareMatrix): class UnitaryOperator(LinearTransformation, UnitaryMatrix): - def __init__(self, m: ListOrNdarray, name: str = '', partials=None, *args, - **kwargs): + def __init__(self, m: ListOrNdarray, name: str = '', partials=None, + decomposition=None, *args, **kwargs): """UnitaryOperator inherits from both LinearTransformation and a Unitary matrix It is used to act on a State vector by defining the operator to be @@ -614,13 +614,15 @@ class UnitaryOperator(LinearTransformation, UnitaryMatrix): the list is the Nth partial that is used, i.e. for the first - |0><0|, for the second - |1><1| """ - if np.shape(m) != (2, 2) and partials is None: - raise Exception("Please define partials in a non-single operator") + if np.shape(m) != (2, 2) and partials is None and decomposition is None: + raise Exception("Please define partials or decomposition in " + "a non-single operator") UnitaryMatrix.__init__(self, m=m, *args, **kwargs) LinearTransformation.__init__(self, m=m, func=self.operator_func, *args, **kwargs) self.name = name self.partials = partials + self.decomposition = decomposition def operator_func(self, other, which_qbit=None): this_cols, other_rows = np.shape(self.m)[1], np.shape(other.m)[0] @@ -631,6 +633,8 @@ class UnitaryOperator(LinearTransformation, UnitaryMatrix): "Please specify which_qubit to operate on".format( this_cols, other_rows)) total_qbits = int(np.log2(other_rows)) + if type(which_qbit) is list and len(which_qbit) == 1: + which_qbit = which_qbit[0] if type(which_qbit) is int: # single qubit-gate assert this_cols == 2 @@ -639,39 +643,40 @@ class UnitaryOperator(LinearTransformation, UnitaryMatrix): range(total_qbits)] new_m = np.prod(extended_m).m elif type(which_qbit) is list: - # single or multiple qubit-gate - assert 1 <= len(which_qbit) <= total_qbits - assert len(which_qbit) == len(self.partials) - assert all([q < total_qbits for q in which_qbit]) - this_gate_len = 2 ** (len(self.partials) - 1) - bins = generate_bins(this_gate_len) - extended_m, next_control = [[] for _ in bins], 0 - partial_mapping = dict(zip(which_qbit, self.partials)) - for qbit in range(total_qbits): - if qbit not in partial_mapping: - for i in range(this_gate_len): - extended_m[i].append(I) - else: - this_partial = partial_mapping[qbit] - # TODO: This works only with C_partial :((( - if this_partial == C_partial: + if self.decomposition: + return self.decomposition(other, *which_qbit) + else: + # single or multiple qubit-gate + assert 1 <= len(which_qbit) <= total_qbits + assert len(which_qbit) == len(self.partials) + assert all([q < total_qbits for q in which_qbit]) + this_gate_len = 2 ** (len(self.partials) - 1) + bins = generate_bins(this_gate_len) + extended_m, next_control = [[] for _ in bins], 0 + partial_mapping = dict(zip(which_qbit, self.partials)) + for qbit in range(total_qbits): + if qbit not in partial_mapping: for i in range(this_gate_len): - bin_dig = bins[i][next_control] - extended_m[i].append( - s("|" + bin_dig + "><" + bin_dig + "|")) - next_control += 1 - else: - for i in range(this_gate_len - 1): extended_m[i].append(I) - extended_m[-1].append(this_partial) - new_m = sum([np.prod(e).m for e in extended_m]) + else: + this_partial = partial_mapping[qbit] + # TODO: This works only with C_partial :((( + if this_partial == C_partial: + for i in range(this_gate_len): + bin_dig = bins[i][next_control] + extended_m[i].append( + s("|" + bin_dig + "><" + bin_dig + "|")) + next_control += 1 + else: + for i in range(this_gate_len - 1): + extended_m[i].append(I) + extended_m[-1].append(this_partial) + new_m = sum([np.prod(e).m for e in extended_m]) else: raise Exception( "which_qubit needs to be either an int of N-th qubit or list") - extended_op = UnitaryOperator(new_m, name=self.name, - partials=self.partials) - return State(np.dot(extended_op.m, other.m)) + return State(np.dot(new_m, other.m)) def __repr__(self): if self.name: @@ -836,8 +841,7 @@ T = lambda phi: Gate([[1, 0], # # Decomposed CNOT : # reverse engineered from -# https://quantumcomputing.stackexchange.com/questions/4252/how-to-derive-the -# -cnot-matrix-for-a-3-qbit-system-where-the-control-target-qbi +# https://quantumcomputing.stackexchange.com/questions/4252/how-to-derive-the-cnot-matrix-for-a-3-qbit-system-where-the-control-target-qbi # # CNOT(q1, I, q2): # |0><0| x I_2 x I_2 + |1><1| x I_2 x X @@ -874,13 +878,17 @@ CZ = Gate([ [0, 0, 0, -1], ], name="CZ", partials=[C_partial, z_partial]) -# TODO: These are not the correct partials +# Can't do partials for SWAP, but can be decomposed +# SWAP is decomposed into three CNOTs where the second one is flipped SWAP = Gate([ [1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1], -], name="SWAP", partials=[C_partial, C_partial]) +], name="SWAP", + decomposition=lambda state, x, y: + CNOT.on(CNOT.on(CNOT.on(state, [x, y]), [y, x]), [x, y]) +) TOFF = Gate([ @@ -1012,6 +1020,10 @@ def test_eigenstuff(): def test_partials(): + # test single qbit state + assert X.on(s("|000>"), which_qbit=1) == s("|010>") + assert X.on(s("|000>"), which_qbit=[1]) == s("|010>") + # normal 2 qbit state assert CNOT.on(s("|00>")) == s("|00>") assert CNOT.on(s("|10>")) == s("|11>") @@ -1024,9 +1036,12 @@ def test_partials(): assert CNOT.on(s("|01>"), which_qbit=[1, 0]) == s("|11>") assert CNOT.on(s("|001>"), which_qbit=[2, 0]) == s("|101>") - # Test SWAP via 3 successive CNOTs, the second is flipped + # Test SWAP via 3 successive CNOTs, the second CNOT is flipped assert CNOT.on(CNOT.on(CNOT.on(s("|10>")), which_qbit=[1, 0])) == s("|01>") + # Test SWAP via 3 successive CNOTs, the 0<->2 are swapped + assert CNOT.on(CNOT.on(CNOT.on(s("|100>"), [0, 2]), [2, 0]), [0, 2]) == s("|001>") + # apply on 0, 1 of 3qbit state assert CNOT.on(s("|000>"), which_qbit=[0, 1]) == s("|000>") assert CNOT.on(s("|100>"), which_qbit=[0, 1]) == s("|110>") @@ -1048,8 +1063,8 @@ def test_partials(): # test SWAP gate assert SWAP.on(s("|10>")) == s("|01>") assert SWAP.on(s("|01>")) == s("|10>") - # TODO: - # assert SWAP.on(s("|001>"), which_qbit=[0, 2]) == s("|100>") + # Tests SWAP on far-away gates + assert SWAP.on(s("|001>"), which_qbit=[0, 2]) == s("|100>") # test Toffoli gate assert TOFF.on(s("|000>")) == s("|000>")