measurement returns a state

This commit is contained in:
Daniel Tsvetkov 2020-03-30 15:51:17 +02:00
parent 908efa9b76
commit 93540bfcd3
2 changed files with 33 additions and 43 deletions

View File

@ -70,14 +70,14 @@ def test_krisi_measurement_3():
b_0 = State((1 / np.sqrt(2)) * (s("|100>") + s("|010>"))) b_0 = State((1 / np.sqrt(2)) * (s("|100>") + s("|010>")))
b_1 = State((1 / np.sqrt(2)) * (s("|011>") + s("|101>"))) b_1 = State((1 / np.sqrt(2)) * (s("|011>") + s("|101>")))
basis = [ basis = [
s("|000>"), s("|000>", name="B0"),
State(np.cos(beta) * b_0 + np.sin(beta) * s("|001>")), State(np.cos(beta) * b_0 + np.sin(beta) * s("|001>"), name="B1"),
State(-np.sin(beta) * b_0 + np.cos(beta) * s("|001>")), State(-np.sin(beta) * b_0 + np.cos(beta) * s("|001>"), name="B2"),
State((1 / np.sqrt(2)) * (s("|100>") - s("|010>"))), State((1 / np.sqrt(2)) * (s("|100>") - s("|010>")), name="B3"),
State((1 / np.sqrt(2)) * (s("|011>") - s("|101>"))), State((1 / np.sqrt(2)) * (s("|011>") - s("|101>")), name="B4"),
State(np.cos(beta) * b_1 + np.sin(beta) * s("|001>")), State(np.cos(beta) * b_1 + np.sin(beta) * s("|001>"), name="B5"),
State(-np.sin(beta) * b_1 + np.cos(beta) * s("|001>")), State(-np.sin(beta) * b_1 + np.cos(beta) * s("|001>"), name="B6"),
s("|111>"), s("|111>", name="B7"),
] ]
def perform_exp(case): def perform_exp(case):
@ -149,12 +149,12 @@ def test_krisi_measurement_3():
return meas return meas
def krisi_3_format_results(results): def krisi_3_format_results(results):
all_pos = generate_bins(8) all_pos = basis
print("Results:") print("Results:")
for case in case_choices: for case in case_choices:
rv = results.get(case) rv = results.get(case)
print("{}:".format(case)) print("{}:".format(case))
print(" raw : {}".format(sorted(rv.items()))) print(" raw : {}".format(rv.items()))
case_total = sum(rv.values()) case_total = sum(rv.values())
print(" total: {}".format(case_total)) print(" total: {}".format(case_total))
for pos in all_pos: for pos in all_pos:

52
lib.py
View File

@ -177,10 +177,15 @@ class State(Vector):
self.name = name self.name = name
self.measurement_result = None self.measurement_result = None
self.last_basis = None self.last_basis = None
# TODO: SHOULD WE NORMALIZE?
if not allow_unnormalized and not self._is_normalized(): if not allow_unnormalized and not self._is_normalized():
raise TypeError("Not a normalized state vector") raise TypeError("Not a normalized state vector")
def __hash__(self):
return hash(str(self.m))
def __le__(self, other):
return repr(self) < repr(other)
@staticmethod @staticmethod
def normalize(vector: Vector): def normalize(vector: Vector):
"""Normalize a state by dividing by the square root of sum of the """Normalize a state by dividing by the square root of sum of the
@ -205,9 +210,7 @@ class State(Vector):
def to_bloch_angles(self): def to_bloch_angles(self):
"""Returns the angles of this state on the Bloch sphere""" """Returns the angles of this state on the Bloch sphere"""
if not self.m.shape == (2, 1): if not self.m.shape == (2, 1):
raise Exception( raise Exception("State needs to describe only 1 qubit on the bloch sphere (2x1 matrix)")
"State needs to describe only 1 qubit on the bloch sphere ("
"2x1 matrix)")
m0, m1 = self.m[0][0], self.m[1][0] m0, m1 = self.m[0][0], self.m[1][0]
# theta is between 0 and pi # theta is between 0 and pi
@ -218,8 +221,7 @@ class State(Vector):
div = np.sin(theta / 2) div = np.sin(theta / 2)
if div == 0: if div == 0:
# here is doesn't matter what phi is as phase at the poles is # here is doesn't matter what phi is as phase at the poles is arbitrary
# arbitrary
phi = 0 phi = 0
else: else:
exp = m1 / div exp = m1 / div
@ -265,9 +267,6 @@ class State(Vector):
str(len(UNIVERSE_STATES))]) str(len(UNIVERSE_STATES))])
self.name = '{}{}'.format(REPR_GREEK_PSI, next_state_sub) self.name = '{}{}'.format(REPR_GREEK_PSI, next_state_sub)
UNIVERSE_STATES.append(self) UNIVERSE_STATES.append(self)
# matrix_rep = "{}".format(self.m).replace('[', '').replace(']',
# '').replace('\n', '|').strip()
# state_name = '|{}> = {}'.format(self.name, matrix_rep)
state_name = '|{}>'.format(self.name) state_name = '|{}>'.format(self.name)
return state_name return state_name
@ -358,11 +357,8 @@ class State(Vector):
weights = list(np.array(weights) / weights_sum) weights = list(np.array(weights) / weights_sum)
assert np.isclose(np.sum(weights), 1.0) assert np.isclose(np.sum(weights), 1.0)
format_str = self.get_fmt_of_element()
choices = empty_choices + [format_str.format(i) for i in
range(len(weights))]
weights = empty_weights + weights weights = empty_weights + weights
self.measurement_result = random.choices(choices, weights)[0] self.measurement_result = random.choices(empty_choices + basis, weights)[0]
self.last_basis = basis self.last_basis = basis
return self.measurement_result return self.measurement_result
@ -382,10 +378,7 @@ class State(Vector):
""" """
max_qubits = int(np.log2(len(self))) max_qubits = int(np.log2(len(self)))
if not (0 < qubit_n <= max_qubits): if not (0 < qubit_n <= max_qubits):
raise Exception( raise Exception("Partial measurement of qubit_n must be between 1 and {}".format(max_qubits))
"Partial measurement of qubit_n must be between 1 and {"
"}".format(
max_qubits))
format_str = self.get_fmt_of_element() format_str = self.get_fmt_of_element()
# e.g. for state |000>: # e.g. for state |000>:
# ['000', '001', '010', '011', '100', '101', '110', '111'] # ['000', '001', '010', '011', '100', '101', '110', '111']
@ -396,17 +389,14 @@ class State(Vector):
# [0, 1, 4, 5] # [0, 1, 4, 5]
weights, choices = defaultdict(list), defaultdict(list) weights, choices = defaultdict(list), defaultdict(list)
for result in [1, 0]: for result in [1, 0]:
indexes_for_p_0 = [i for i, index in indexes_for_p_0 = [i for i, index in enumerate(partial_measurement_of_qbit) if index == result]
enumerate(partial_measurement_of_qbit) if
index == result]
weights[result] = [self.get_prob(j) for j in indexes_for_p_0] weights[result] = [self.get_prob(j) for j in indexes_for_p_0]
choices[result] = [format_str.format(i) for i in indexes_for_p_0] choices[result] = [format_str.format(i) for i in indexes_for_p_0]
weights_01 = [sum(weights[0]), sum(weights[1])] weights_01 = [sum(weights[0]), sum(weights[1])]
measurement_result = random.choices([0, 1], weights_01)[0] measurement_result = random.choices([0, 1], weights_01)[0]
normalization_factor = np.sqrt( normalization_factor = np.sqrt(sum([np.abs(i) ** 2 for i in weights[measurement_result]]))
sum([np.abs(i) ** 2 for i in weights[measurement_result]]))
new_m = weights[measurement_result] / normalization_factor new_m = weights[measurement_result] / normalization_factor
return str(measurement_result), State(new_m.reshape((len(new_m), 1))) return State(s('|' + str(measurement_result) + '>')), State(new_m.reshape((len(new_m), 1)))
def pretty_print(self): def pretty_print(self):
format_str = self.get_fmt_of_element() + " | {}" format_str = self.get_fmt_of_element() + " | {}"
@ -1207,12 +1197,12 @@ def test():
assert np.isclose(_p.get_prob(0), 0.5) # Probability for |+> in 0 is 0.5 assert np.isclose(_p.get_prob(0), 0.5) # Probability for |+> in 0 is 0.5
assert np.isclose(_p.get_prob(1), 0.5) # Probability for |+> in 1 is 0.5 assert np.isclose(_p.get_prob(1), 0.5) # Probability for |+> in 1 is 0.5
assert _0.measure() == '0' assert _0.measure() == _0
assert _1.measure() == '1' assert _1.measure() == _1
assert s("|10>").measure() == '10' assert s("|10>").measure() == s("|10>")
assert s("|10>").measure_partial(1) == ("1", _0) assert s("|10>").measure_partial(1) == (_1, _0)
assert s("|10>").measure_partial(2) == ("0", _1) assert s("|10>").measure_partial(2) == (_0, _1)
# measure in arbitrary basis # measure in arbitrary basis
_0.measure(basis=[_p, _m]) _0.measure(basis=[_p, _m])
@ -1220,9 +1210,9 @@ def test():
# Maximally entangled # Maximally entangled
result, pms = b_phi_p.measure_partial(1) result, pms = b_phi_p.measure_partial(1)
if result == "0": if result == _0:
assert pms == _0 assert pms == _0
elif result == "1": elif result == _1:
assert pms == _1 assert pms == _1
# Test measurement operators # Test measurement operators
@ -1502,7 +1492,7 @@ class QuantumProcessor(object):
@staticmethod @staticmethod
def print_sample(rv): def print_sample(rv):
for k, v in sorted(rv.items(), key=lambda x: x[0]): for k, v in rv.items():
print("{}: {}".format(k, v)) print("{}: {}".format(k, v))