diff --git a/lib.py b/lib.py index 8409c66..be4862e 100644 --- a/lib.py +++ b/lib.py @@ -624,7 +624,7 @@ class UnitaryOperator(LinearTransformation, UnitaryMatrix): def operator_func(self, other, which_qbit=None): this_cols, other_rows = np.shape(self.m)[1], np.shape(other.m)[0] - if this_cols == other_rows: + if this_cols == other_rows and which_qbit is None: return State(np.dot(self.m, other.m)) if which_qbit is None: raise Exception("Operating dim-{} operator on a dim-{} state. " @@ -640,29 +640,30 @@ class UnitaryOperator(LinearTransformation, UnitaryMatrix): new_m = np.prod(extended_m).m elif type(which_qbit) is list: # single or multiple qubit-gate - assert 1 <= len(which_qbit) < total_qbits + assert 1 <= len(which_qbit) <= total_qbits assert len(which_qbit) == len(self.partials) assert all([q < total_qbits for q in which_qbit]) this_gate_len = 2 ** (len(self.partials) - 1) bins = generate_bins(this_gate_len) - extended_m, next_partial = [[] for _ in bins], 0 + extended_m, next_control = [[] for _ in bins], 0 + partial_mapping = dict(zip(which_qbit, self.partials)) for qbit in range(total_qbits): - if qbit not in which_qbit: + if qbit not in partial_mapping: for i in range(this_gate_len): extended_m[i].append(I) else: - this_partial = self.partials[next_partial] + this_partial = partial_mapping[qbit] # TODO: This works only with C_partial :((( if this_partial == C_partial: for i in range(this_gate_len): - bin_dig = bins[i][next_partial] + bin_dig = bins[i][next_control] extended_m[i].append( s("|" + bin_dig + "><" + bin_dig + "|")) + next_control += 1 else: for i in range(this_gate_len - 1): extended_m[i].append(I) extended_m[-1].append(this_partial) - next_partial += 1 new_m = sum([np.prod(e).m for e in extended_m]) else: raise Exception( @@ -1019,6 +1020,13 @@ def test_partials(): "which_qubit to operate on", CNOT.on, s("|100>")) + # Test flipped CNOT + assert CNOT.on(s("|01>"), which_qbit=[1, 0]) == s("|11>") + assert CNOT.on(s("|001>"), which_qbit=[2, 0]) == s("|101>") + + # Test SWAP via 3 successive CNOTs, the second is flipped + assert CNOT.on(CNOT.on(CNOT.on(s("|10>")), which_qbit=[1, 0])) == s("|01>") + # apply on 0, 1 of 3qbit state assert CNOT.on(s("|000>"), which_qbit=[0, 1]) == s("|000>") assert CNOT.on(s("|100>"), which_qbit=[0, 1]) == s("|110>")