Decoding Surface Code
In this experiment, we'll use `mdopt` to decode the surface code. Throughout, we assume an independent noise model as well as perfect syndrome measurements. We mostly follow the procedure described in a similar tutorial from PyMatching, the main difference being that the decoding is done with `mdopt`.
[31]:
import numpy as np
from tqdm import tqdm
import qecstruct as qc
import qecsim.paulitools as pt
import matplotlib.pyplot as plt
from scipy.sparse import hstack, kron, eye, csc_matrix, block_diag
from mdopt.mps.utils import marginalise, create_custom_product_state
from mdopt.contractor.contractor import mps_mpo_contract
from mdopt.optimiser.utils import (
SWAP,
COPY_LEFT,
XOR_BULK,
XOR_LEFT,
XOR_RIGHT,
)
from examples.decoding.decoding import (
apply_constraints,
apply_bitflip_bias,
)
from examples.decoding.decoding import (
pauli_to_mps,
toric_code_x_checks,
toric_code_x_logicals,
)
from examples.decoding.decoding import (
css_code_checks,
css_code_logicals,
css_code_logicals_sites,
css_code_constraint_sites,
)
[32]:
# Build the repetition code in this cell so that it runs on a fresh kernel
# (`rep_code` is otherwise only defined further down). Length 4 matches the
# LATTICE_SIZE used later in the notebook.
rep_code = qc.repetition_code(4)
code_x = qc.hypergraph_product(rep_code, rep_code)  # H_x, CSS code
[33]:
# Display the X-type stabiliser generators of `code_x`.
code_x.x_stabs_binary()
[33]:
[0, 1, 16]
[1, 2, 17]
[2, 3, 18]
[4, 5, 16, 19]
[5, 6, 17, 20]
[6, 7, 18, 21]
[8, 9, 19, 22]
[9, 10, 20, 23]
[10, 11, 21, 24]
[12, 13, 22]
[13, 14, 23]
[14, 15, 24]
[34]:
# Display the X-type logical operator(s) of `code_x`.
code_x.x_logicals_binary()
[34]:
[3, 7, 11, 15]
[35]:
# Notebook-wide settings. NOTE(review): SEED is not used in any cell shown here.
SEED = 123
LATTICE_SIZE = 4
# Hypergraph product of a length-LATTICE_SIZE repetition code with itself.
rep_code = qc.repetition_code(LATTICE_SIZE)
hgp_code = qc.hypergraph_product(rep_code, rep_code)
# TODO(review): leftover draft notes; `toric_code` is not defined in this notebook.
# parity check matrix from qecsim -> qecstruct linear code -> qecstruct css code -> same machinery as in shor
#num_sites = 2 * toric_code.length() + toric_code.num_x_logicals() + toric_code.num_z_logicals()
#num_logicals = toric_code.num_x_logicals() + toric_code.num_z_logicals()
[36]:
# Display the X and Z stabiliser generators of the hypergraph-product code.
hgp_code
[36]:
X stabilizers:
[0, 1, 16]
[1, 2, 17]
[2, 3, 18]
[4, 5, 16, 19]
[5, 6, 17, 20]
[6, 7, 18, 21]
[8, 9, 19, 22]
[9, 10, 20, 23]
[10, 11, 21, 24]
[12, 13, 22]
[13, 14, 23]
[14, 15, 24]
Z stabilizers:
[0, 4, 16]
[1, 5, 16, 17]
[2, 6, 17, 18]
[3, 7, 18]
[4, 8, 19]
[5, 9, 19, 20]
[6, 10, 20, 21]
[7, 11, 21]
[8, 12, 22]
[9, 13, 22, 23]
[10, 14, 23, 24]
[11, 15, 24]
[37]:
def create_surface_code_matrices_dense(L):
    """Build dense X- and Z-check matrices of the planar surface code.

    The layout reproduces the stabiliser lists that
    ``qc.hypergraph_product(rep_code, rep_code)`` prints in this notebook for
    a length-``L`` repetition code:

    * qubits ``0 .. L*L - 1`` form an ``L x L`` grid (index ``row * L + col``);
    * qubits ``L*L .. L*L + (L-1)**2 - 1`` form an ``(L-1) x (L-1)`` grid
      (index ``L*L + row * (L-1) + col``);
    * each matrix has ``L * (L-1)`` rows; checks on the boundary act on three
      qubits, all others on four.

    With this layout every X check overlaps every Z check on an even number of
    qubits, i.e. ``H_x @ H_z.T`` is zero modulo 2, which ``qc.CssCode``
    requires of its two codes.

    Parameters
    ----------
    L : int
        Linear size of the lattice (length of the underlying repetition code).
        Must be at least 1.

    Returns
    -------
    tuple of numpy.ndarray
        ``(H_x, H_z)``, integer 0/1 arrays of shape
        ``(L * (L - 1), L * L + (L - 1) ** 2)``.

    Raises
    ------
    ValueError
        If ``L`` is smaller than 1.
    """
    if L < 1:
        raise ValueError("L must be a positive integer")

    side = L - 1  # linear size of the second qubit grid
    first_block = L * L  # offset of the second qubit grid
    num_qubits = first_block + side * side
    num_checks = L * side

    H_x = np.zeros((num_checks, num_qubits), dtype=int)
    H_z = np.zeros((num_checks, num_qubits), dtype=int)

    # X checks: one per (row a, gap b) of the first grid. Each check couples two
    # horizontally adjacent first-grid qubits with the second-grid qubits
    # directly above and below the gap (where they exist).
    for a in range(L):
        for b in range(side):
            row_idx = a * side + b
            H_x[row_idx, a * L + b] = 1
            H_x[row_idx, a * L + b + 1] = 1
            if a >= 1:
                H_x[row_idx, first_block + (a - 1) * side + b] = 1
            if a <= L - 2:
                H_x[row_idx, first_block + a * side + b] = 1

    # Z checks: one per (gap i, column a) of the first grid. Each check couples
    # two vertically adjacent first-grid qubits with the second-grid qubits to
    # the left and right of the gap (where they exist).
    for i in range(side):
        for a in range(L):
            row_idx = i * L + a
            H_z[row_idx, i * L + a] = 1
            H_z[row_idx, (i + 1) * L + a] = 1
            if a >= 1:
                H_z[row_idx, first_block + i * side + a - 1] = 1
            if a <= L - 2:
                H_z[row_idx, first_block + i * side + a] = 1

    return H_x, H_z
# Build the dense check matrices for a lattice of linear size 4.
L = 4
H_x_dense, H_z_dense = create_surface_code_matrices_dense(L)
[38]:
def create_surface_code_matrices_binary(L):
    """Build the X- and Z-check matrices of the planar surface code as
    ``qc.BinaryMatrix`` objects.

    The layout reproduces the stabiliser lists that
    ``qc.hypergraph_product(rep_code, rep_code)`` prints in this notebook for
    a length-``L`` repetition code:

    * qubits ``0 .. L*L - 1`` form an ``L x L`` grid (index ``row * L + col``);
    * qubits ``L*L .. L*L + (L-1)**2 - 1`` form an ``(L-1) x (L-1)`` grid
      (index ``L*L + row * (L-1) + col``);
    * each matrix has ``L * (L-1)`` rows; checks on the boundary act on three
      qubits, all others on four.

    With this layout every X check shares an even number of qubits with every
    Z check, which ``qc.CssCode`` requires ("codes are not orthogonal"
    otherwise).

    Parameters
    ----------
    L : int
        Linear size of the lattice (length of the underlying repetition code).
        Must be at least 1.

    Returns
    -------
    tuple
        ``(H_x, H_z)`` as ``qc.BinaryMatrix`` instances with
        ``L * L + (L - 1) ** 2`` columns each.

    Raises
    ------
    ValueError
        If ``L`` is smaller than 1.
    """
    if L < 1:
        raise ValueError("L must be a positive integer")

    side = L - 1  # linear size of the second qubit grid
    first_block = L * L  # offset of the second qubit grid
    num_qubits = first_block + side * side

    # Every support list below is assembled in ascending order, so no extra
    # sorting is needed before handing the rows to BinaryMatrix.

    # X checks: two horizontally adjacent first-grid qubits plus the
    # second-grid qubits above and below the gap (where they exist).
    rows_H_x = []
    for a in range(L):
        for b in range(side):
            support = [a * L + b, a * L + b + 1]
            if a >= 1:
                support.append(first_block + (a - 1) * side + b)
            if a <= L - 2:
                support.append(first_block + a * side + b)
            rows_H_x.append(support)

    # Z checks: two vertically adjacent first-grid qubits plus the
    # second-grid qubits left and right of the gap (where they exist).
    rows_H_z = []
    for i in range(side):
        for a in range(L):
            support = [i * L + a, (i + 1) * L + a]
            if a >= 1:
                support.append(first_block + i * side + a - 1)
            if a <= L - 2:
                support.append(first_block + i * side + a)
            rows_H_z.append(support)

    # Create BinaryMatrix instances
    H_x = qc.BinaryMatrix(num_columns=num_qubits, rows=rows_H_x)
    H_z = qc.BinaryMatrix(num_columns=num_qubits, rows=rows_H_z)
    return H_x, H_z
# Build the BinaryMatrix check matrices; `L` is set in the previous cell.
H_x, H_z = create_surface_code_matrices_binary(L)
[39]:
# Wrap each check matrix in a qecstruct LinearCode; both are passed to qc.CssCode below.
x_code = qc.LinearCode(H_x)
z_code = qc.LinearCode(H_z)
from examples.decoding.decoding import linear_code_parity_matrix_dense
[40]:
def check_orthogonality(dense_H_x, dense_H_z):
    """Test whether every row of ``dense_H_x`` is orthogonal (mod 2) to every
    row of ``dense_H_z``.

    Returns ``(True, None)`` when all pairwise overlaps are even. Otherwise
    returns ``(False, pairs)``, where ``pairs`` is a tuple of two index arrays:
    the offending row indices of ``dense_H_x`` and the matching row indices of
    ``dense_H_z``.
    """
    # Entry (i, j) is the parity of the overlap between H_x row i and H_z row j.
    overlap_parity = (dense_H_x @ dense_H_z.T) % 2
    offending = np.nonzero(overlap_parity)
    if offending[0].size == 0:
        return True, None
    return False, offending
# Run the orthogonality check on the dense check matrices and report the outcome
orthogonal, non_orthogonal_pairs = check_orthogonality(H_x_dense, H_z_dense)
if orthogonal:
    print("The codes are orthogonal.")
else:
    print("Non-orthogonal row pairs (H_x row index, H_z row index):", non_orthogonal_pairs)
Non-orthogonal row pairs (H_x row index, H_z row index): (array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8]), array([1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6, 8, 0, 7, 1, 8, 2]))
[41]:
# Combine the two linear codes into a CSS code.
# qc.CssCode raises ValueError("codes are not orthogonal") unless the two check
# matrices are orthogonal, i.e. H_x @ H_z.T is zero modulo 2.
surface_code = qc.CssCode(
x_code=x_code,
z_code=z_code,
)
# Presumably the number of physical qubits; confirm against the qecstruct docs.
print(len(surface_code))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[41], line 1
----> 1 surface_code = qc.CssCode(
2 x_code=x_code,
3 z_code=z_code,
4 )
6 print(len(surface_code))
ValueError: codes are not orthogonal
[ ]:
# NOTE(review): `lattice_size` duplicates LATTICE_SIZE, and `num_logicals` / `num_sites`
# are reassigned in a later cell.
lattice_size = 4
num_logicals = 2
# 2 * lattice_size**2 = 32 matches the qubit indices 0..31 appearing in the
# toric_code_x_checks output below; the logical sites are added on top.
num_sites = 2 * (lattice_size ** 2) + num_logicals
num_sites
[ ]:
[ ]:
# Display the toric-code X checks from the local helper (one list of qubit indices per check).
toric_code_x_checks(lattice_size=LATTICE_SIZE)
[[0, 4, 16, 19],
[1, 5, 16, 17],
[2, 6, 17, 18],
[3, 7, 18, 19],
[4, 8, 20, 23],
[5, 9, 20, 21],
[6, 10, 21, 22],
[7, 11, 22, 23],
[8, 12, 24, 27],
[9, 13, 24, 25],
[10, 14, 25, 26],
[11, 15, 26, 27],
[0, 12, 28, 31],
[1, 13, 28, 29],
[2, 14, 29, 30],
[3, 15, 30, 31]]
[ ]:
# Display the toric-code X logical operators from the local helper.
toric_code_x_logicals(lattice_size=LATTICE_SIZE)
[[0, 1, 2, 3], [16, 20, 24, 28]]
[ ]:
[ ]:
[ ]:
[42]:
# Build the code that is decoded in the rest of the notebook.
# NOTE(review): repeats the earlier configuration cell, but binds the result to `code`.
SEED = 123
LATTICE_SIZE = 4
rep_code = qc.repetition_code(LATTICE_SIZE)
code = qc.hypergraph_product(rep_code, rep_code)
# Display the X and Z stabiliser generators.
code
[42]:
X stabilizers:
[0, 1, 16]
[1, 2, 17]
[2, 3, 18]
[4, 5, 16, 19]
[5, 6, 17, 20]
[6, 7, 18, 21]
[8, 9, 19, 22]
[9, 10, 20, 23]
[10, 11, 21, 24]
[12, 13, 22]
[13, 14, 23]
[14, 15, 24]
Z stabilizers:
[0, 4, 16]
[1, 5, 16, 17]
[2, 6, 17, 18]
[3, 7, 18]
[4, 8, 19]
[5, 9, 19, 20]
[6, 10, 20, 21]
[7, 11, 21]
[8, 12, 22]
[9, 13, 22, 23]
[10, 14, 23, 24]
[11, 15, 24]
[43]:
# One MPS site per logical operator (X and Z logicals counted together).
num_logicals = code.num_x_logicals() + code.num_z_logicals()
# Two MPS sites per unit of len(code) (presumably one unit per physical qubit) plus the
# logical sites. In the css_code_checks output below, X checks sit on even sites and
# Z checks on odd sites.
num_sites = 2 * len(code) + num_logicals
[44]:
# Initial product state: the logical sites come first and start in "+",
# all remaining (error) sites start in "0".
error_state = "0" * (num_sites - num_logicals)
logicals_state = "+" * num_logicals
state_string = logicals_state + error_state
error_mps = create_custom_product_state(string=state_string)
[45]:
# Fetch the X- and Z-check site lists of `code` and print them, one check per line.
checks_x, checks_z = css_code_checks(code)
for heading, check_group in (("X checks:", checks_x), ("Z checks:", checks_z)):
    print(heading)
    for check in check_group:
        print(check)
X checks:
[2, 4, 34]
[4, 6, 36]
[6, 8, 38]
[10, 12, 34, 40]
[12, 14, 36, 42]
[14, 16, 38, 44]
[18, 20, 40, 46]
[20, 22, 42, 48]
[22, 24, 44, 50]
[26, 28, 46]
[28, 30, 48]
[30, 32, 50]
Z checks:
[3, 11, 35]
[5, 13, 35, 37]
[7, 15, 37, 39]
[9, 17, 39]
[11, 19, 41]
[13, 21, 41, 43]
[15, 23, 43, 45]
[17, 25, 45]
[19, 27, 47]
[21, 29, 47, 49]
[23, 31, 49, 51]
[25, 33, 51]
[46]:
# Tensor sequences handed to apply_constraints: the first is used for the X and Z
# parity checks, the second for the logical operators.
# Presumably each position pairs with one of the four site groups per entry in the
# site lists printed below; confirm against the mdopt docs.
constraints_tensors = [XOR_LEFT, XOR_BULK, SWAP, XOR_RIGHT]
logicals_tensors = [COPY_LEFT, XOR_BULK, SWAP, XOR_RIGHT]
[47]:
# Site lists for every check: index 0 holds the X checks, index 1 the Z checks.
constraints_sites = css_code_constraint_sites(code)
for heading, site_strings in (
    ("Full X-check lists of sites:", constraints_sites[0]),
    ("Full Z-check lists of sites:", constraints_sites[1]),
):
    print(heading)
    for string in site_strings:
        print(string)
Full X-check lists of sites:
[[2], [4], [3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33], [34]]
[[4], [6], [5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], [36]]
[[6], [8], [7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37], [38]]
[[10], [12, 34], [11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 35, 36, 37, 38, 39], [40]]
[[12], [14, 36], [13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 40, 41], [42]]
[[14], [16, 38], [15, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 39, 40, 41, 42, 43], [44]]
[[18], [20, 40], [19, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 41, 42, 43, 44, 45], [46]]
[[20], [22, 42], [21, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 43, 44, 45, 46, 47], [48]]
[[22], [24, 44], [23, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49], [50]]
[[26], [28], [27, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45], [46]]
[[28], [30], [29, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], [48]]
[[30], [32], [31, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], [50]]
Full Z-check lists of sites:
[[3], [11], [4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], [35]]
[[5], [13, 35], [6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 36], [37]]
[[7], [15, 37], [8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 38], [39]]
[[9], [17], [10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38], [39]]
[[11], [19], [12, 13, 14, 15, 16, 17, 18, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40], [41]]
[[13], [21, 41], [14, 15, 16, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 42], [43]]
[[15], [23, 43], [16, 17, 18, 19, 20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 44], [45]]
[[17], [25], [18, 19, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44], [45]]
[[19], [27], [20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46], [47]]
[[21], [29, 47], [22, 23, 24, 25, 26, 27, 28, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48], [49]]
[[23], [31, 49], [24, 25, 26, 27, 28, 29, 30, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 50], [51]]
[[25], [33], [26, 27, 28, 29, 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50], [51]]
[48]:
# Print the X and Z logical operators of `code`.
print(code.x_logicals_binary())
print(code.z_logicals_binary())
[3, 7, 11, 15]
[4, 5, 6, 7]
[49]:
# Compute the logical-operator site lists once and reuse the result for display
# (index 0: X logical, index 1: Z logical).
logicals_sites = css_code_logicals_sites(code)
print(logicals_sites[0])
print(logicals_sites[1])
[[0], [8, 16, 24], [1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31], [32]]
[[1], [11, 13, 15], [2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16], [17]]
[50]:
# Flags shared by the cells below: `renormalise` is passed to apply_bitflip_bias and
# apply_constraints, `result_to_explicit` to apply_constraints only.
renormalise = True
result_to_explicit = False
# Bias every non-logical site; the logical sites occupy the first `num_logicals` positions.
sites_to_bias = list(range(num_logicals, num_sites))
# NOTE(review): no bias probability is passed here, so the helper's default applies;
# confirm that this is the intended noise level.
error_mps = apply_bitflip_bias(
mps=error_mps,
sites_to_bias=sites_to_bias,
renormalise=renormalise,
)
[51]:
# Apply the X-check constraints (constraints_sites[0]) to the MPS.
# chi_max is presumably the bond-dimension cap; confirm against the mdopt docs.
# The recorded run took about 1 min 14 s for the 12 checks.
error_mps = apply_constraints(
error_mps,
constraints_sites[0],
constraints_tensors,
chi_max=1024,
renormalise=renormalise,
result_to_explicit=result_to_explicit,
strategy="Optimized",
)
100%|██████████| 12/12 [01:14<00:00, 6.19s/it]
[52]:
# Apply the Z-check constraints (constraints_sites[1]) to the MPS, with the same settings
# as for the X checks.
# This is the slowest step: the recorded run took about 29 min 39 s for the 12 checks.
error_mps = apply_constraints(
error_mps,
constraints_sites[1],
constraints_tensors,
chi_max=1024,
renormalise=renormalise,
result_to_explicit=result_to_explicit,
strategy="Optimized",
)
100%|██████████| 12/12 [29:39<00:00, 148.29s/it]
[53]:
# Apply the logical-operator constraints, using `logicals_tensors` instead of the
# parity-check tensors.
# The recorded run took about 3 min 17 s for the 2 logicals.
error_mps = apply_constraints(
error_mps,
logicals_sites,
logicals_tensors,
chi_max=1024,
renormalise=renormalise,
result_to_explicit=result_to_explicit,
strategy="Optimized",
)
100%|██████████| 2/2 [03:17<00:00, 98.82s/it]
[55]:
# Marginalise every non-logical site: indices num_logicals .. num_logicals + len(error_state) - 1,
# so that only the logical sites at the front of the MPS remain.
sites_to_marginalise = list(range(num_logicals, len(error_state) + num_logicals))
# Convert what is left to a flat dense vector, renormalised with norm=1
# (presumably so that the entries sum to one; confirm against the mdopt docs).
logical = marginalise(mps=error_mps, sites_to_marginalise=sites_to_marginalise).dense(
flatten=True, renormalise=True, norm=1
)
print(logical)
[0.83427324 0.1071154 0.0550003 0.00361105]
[ ]: