Decoding the Surface Code

In this experiment, we’ll use mdopt to decode the surface code. Throughout, we assume an independent noise model and perfect syndrome measurements. We mostly follow the procedure of a similar PyMatching tutorial, the main difference being that decoding is done with mdopt.

[31]:
import numpy as np
from tqdm import tqdm
import qecstruct as qc
import qecsim.paulitools as pt
import matplotlib.pyplot as plt
from scipy.sparse import hstack, kron, eye, csc_matrix, block_diag

from mdopt.mps.utils import marginalise, create_custom_product_state
from mdopt.contractor.contractor import mps_mpo_contract
from mdopt.optimiser.utils import (
    SWAP,
    COPY_LEFT,
    XOR_BULK,
    XOR_LEFT,
    XOR_RIGHT,
)
from examples.decoding.decoding import (
    apply_constraints,
    apply_bitflip_bias,
)
from examples.decoding.decoding import (
    pauli_to_mps,
    toric_code_x_checks,
    toric_code_x_logicals,
)
from examples.decoding.decoding import (
    css_code_checks,
    css_code_logicals,
    css_code_logicals_sites,
    css_code_constraint_sites,
)
[32]:
# Build the repetition code here rather than reusing `rep_code`, which is only
# defined in a later cell: a fresh top-to-bottom run would otherwise raise NameError.
rep_code_4 = qc.repetition_code(4)
code_x = qc.hypergraph_product(rep_code_4, rep_code_4)  # full CSS code (X and Z stabilizers)
[33]:
# Display the support (qubit indices) of every X stabilizer of the hypergraph-product code.
code_x.x_stabs_binary()
[33]:
[0, 1, 16]
[1, 2, 17]
[2, 3, 18]
[4, 5, 16, 19]
[5, 6, 17, 20]
[6, 7, 18, 21]
[8, 9, 19, 22]
[9, 10, 20, 23]
[10, 11, 21, 24]
[12, 13, 22]
[13, 14, 23]
[14, 15, 24]
[34]:
# Display the support (qubit indices) of the X logical operator(s).
code_x.x_logicals_binary()
[34]:
[3, 7, 11, 15]
[35]:
SEED, LATTICE_SIZE = 123, 4

# Surface code of size LATTICE_SIZE, obtained as the hypergraph product of
# two length-LATTICE_SIZE repetition codes.
rep_code = qc.repetition_code(LATTICE_SIZE)
hgp_code = qc.hypergraph_product(rep_code, rep_code)

# Planned pipeline: parity-check matrix from qecsim -> qecstruct linear code
# -> qecstruct CSS code -> same decoding machinery as in the Shor example.
[36]:
# Display the X and Z stabilizers of the hypergraph-product code.
hgp_code
[36]:
X stabilizers:
[0, 1, 16]
[1, 2, 17]
[2, 3, 18]
[4, 5, 16, 19]
[5, 6, 17, 20]
[6, 7, 18, 21]
[8, 9, 19, 22]
[9, 10, 20, 23]
[10, 11, 21, 24]
[12, 13, 22]
[13, 14, 23]
[14, 15, 24]
Z stabilizers:
[0, 4, 16]
[1, 5, 16, 17]
[2, 6, 17, 18]
[3, 7, 18]
[4, 8, 19]
[5, 9, 19, 20]
[6, 10, 20, 21]
[7, 11, 21]
[8, 12, 22]
[9, 13, 22, 23]
[10, 14, 23, 24]
[11, 15, 24]
[37]:
def create_surface_code_matrices_dense(L):
    """Build dense X/Z parity-check matrices of the size-``L`` surface code.

    The code is the hypergraph product of two length-``L`` repetition codes.
    Their ``(L - 1) x L`` parity-check matrix ``H`` has ones at ``(j, j)`` and
    ``(j, j + 1)``. With ``n = L``, ``m = L - 1`` and ``kron`` the Kronecker
    product::

        H_x = [ kron(I_n, H) | kron(H.T, I_m) ]
        H_z = [ kron(H, I_n) | kron(I_m, H.T) ]

    Hence ``H_x @ H_z.T = 2 * kron(H.T, H) = 0 (mod 2)``, so every X check
    commutes with every Z check by construction. Qubits ``0 .. L**2 - 1``
    form the ``n x n`` block and the remaining ``(L - 1)**2`` qubits the
    ``m x m`` block. This is the same qubit/stabilizer ordering as
    ``qc.hypergraph_product(rep_code, rep_code)``.

    Args:
        L: Linear lattice size (code distance); must be at least 1.

    Returns:
        ``(H_x, H_z)``: integer numpy arrays, each of shape
        ``(L * (L - 1), L**2 + (L - 1)**2)``.

    Raises:
        ValueError: If ``L < 1``.
    """
    if L < 1:
        raise ValueError(f"L must be a positive integer, got {L!r}")

    n, m = L, L - 1

    # Parity-check matrix of the length-L repetition code: check j compares bits j and j + 1.
    H = np.zeros((m, n), dtype=int)
    checks = np.arange(m)
    H[checks, checks] = 1
    H[checks, checks + 1] = 1

    I_n = np.eye(n, dtype=int)
    I_m = np.eye(m, dtype=int)

    # Row (i, j) -> i * m + j; first block covers the n*n qubits, second the m*m qubits.
    H_x = np.hstack([np.kron(I_n, H), np.kron(H.T, I_m)])
    # Row (a, b) -> a * n + b; same column layout as H_x.
    H_z = np.hstack([np.kron(H, I_n), np.kron(I_m, H.T)])

    return H_x, H_z

L = 4
H_x_dense, H_z_dense = create_surface_code_matrices_dense(L)
[38]:
def create_surface_code_matrices_binary(L):
    """Build the size-``L`` surface-code X/Z checks as ``qc.BinaryMatrix`` objects.

    The construction is the hypergraph product of two length-``L`` repetition
    codes, with ``n = L`` and ``m = L - 1``:

    * X check ``(i, j)`` (``i < n``, ``j < m``; row ``i * m + j``) acts on qubits
      ``i*n + j`` and ``i*n + j + 1``. It also acts on ``n*n + k*m + j`` for
      ``k`` in ``{i - 1, i}`` when ``0 <= k < m``.
    * Z check ``(a, b)`` (``a < m``, ``b < n``; row ``a * n + b``) acts on qubits
      ``a*n + b`` and ``(a + 1)*n + b``. It also acts on ``n*n + a*m + k`` for
      ``k`` in ``{b - 1, b}`` when ``0 <= k < m``.

    Every X check overlaps every Z check on an even number of qubits, so the
    two codes are orthogonal. The rows reproduce the stabilizers of
    ``qc.hypergraph_product(rep_code, rep_code)``.

    Args:
        L: Linear lattice size (code distance); must be at least 1.

    Returns:
        ``(H_x, H_z)``: ``qc.BinaryMatrix`` instances with
        ``L**2 + (L - 1)**2`` columns and ``L * (L - 1)`` rows each. Every row
        lists its qubit indices in ascending order.

    Raises:
        ValueError: If ``L < 1``.
    """
    if L < 1:
        raise ValueError(f"L must be a positive integer, got {L!r}")

    n, m = L, L - 1
    num_qubits = n * n + m * m
    offset = n * n  # index of the first qubit in the m x m block

    rows_H_x = []
    for i in range(n):
        for j in range(m):
            # Indices from the n x n block precede (and are smaller than) the m x m block,
            # so each row is already sorted.
            support = [i * n + j, i * n + j + 1]
            support += [offset + k * m + j for k in (i - 1, i) if 0 <= k < m]
            rows_H_x.append(support)

    rows_H_z = []
    for a in range(m):
        for b in range(n):
            support = [a * n + b, (a + 1) * n + b]
            support += [offset + a * m + k for k in (b - 1, b) if 0 <= k < m]
            rows_H_z.append(support)

    H_x = qc.BinaryMatrix(num_columns=num_qubits, rows=rows_H_x)
    H_z = qc.BinaryMatrix(num_columns=num_qubits, rows=rows_H_z)

    return H_x, H_z

H_x, H_z = create_surface_code_matrices_binary(L)
[39]:
# Wrap the hand-built parity-check matrices as qecstruct linear codes.
x_code = qc.LinearCode(H_x)
z_code = qc.LinearCode(H_z)
# NOTE(review): this import is not used in any cell visible here; consider moving it
# into the import cell at the top of the notebook or dropping it.
from examples.decoding.decoding import linear_code_parity_matrix_dense
[40]:
def check_orthogonality(dense_H_x, dense_H_z):
    """Test whether every row of ``dense_H_x`` commutes with every row of ``dense_H_z`` over GF(2).

    Returns:
        ``(True, None)`` when ``dense_H_x @ dense_H_z.T`` vanishes mod 2.
        Otherwise ``(False, idx)``, where ``idx`` is the ``np.where`` index
        tuple ``(H_x rows, H_z rows)`` of the offending entries.
    """
    overlap_parity = np.dot(dense_H_x, dense_H_z.T) % 2
    offending = np.where(overlap_parity != 0)
    is_orthogonal = offending[0].size == 0
    return (True, None) if is_orthogonal else (False, offending)

# Check the dense matrices built above and report any commuting violations.
orthogonal, non_orthogonal_pairs = check_orthogonality(H_x_dense, H_z_dense)
if orthogonal:
    print("The codes are orthogonal.")
else:
    # np.where yields two parallel index arrays; zip them into the (H_x row, H_z row)
    # pairs that the message promises.
    row_pairs = [(int(x_row), int(z_row)) for x_row, z_row in zip(*non_orthogonal_pairs)]
    print("Non-orthogonal row pairs (H_x row index, H_z row index):", row_pairs)

Non-orthogonal row pairs (H_x row index, H_z row index): (array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8]), array([1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6, 8, 0, 7, 1, 8, 2]))
[41]:
# Combine the two linear codes into a CSS code. qc.CssCode raises ValueError when
# the codes are not orthogonal, as in the recorded traceback below for the
# hand-built matrices that failed the orthogonality check in the previous cell.
surface_code = qc.CssCode(
    x_code=x_code,
    z_code=z_code,
)

# Number of physical qubits of the resulting code.
print(len(surface_code))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[41], line 1
----> 1 surface_code = qc.CssCode(
      2     x_code=x_code,
      3     z_code=z_code,
      4 )
      6 print(len(surface_code))

ValueError: codes are not orthogonal
[ ]:
# Site count for the toric code on a lattice_size x lattice_size torus, which has
# 2 * lattice_size**2 qubits, plus one site per logical operator.
# NOTE(review): each qubit is counted once here, whereas the CSS cell further down
# uses 2 * len(code) sites (two per qubit); confirm which convention is intended.
lattice_size, num_logicals = 4, 2
num_sites = num_logicals + 2 * lattice_size * lattice_size
num_sites
[ ]:

[ ]:
# Display the X checks (qubit supports) of the toric code at the same lattice size, for comparison.
toric_code_x_checks(lattice_size=LATTICE_SIZE)
[[0, 4, 16, 19],
 [1, 5, 16, 17],
 [2, 6, 17, 18],
 [3, 7, 18, 19],
 [4, 8, 20, 23],
 [5, 9, 20, 21],
 [6, 10, 21, 22],
 [7, 11, 22, 23],
 [8, 12, 24, 27],
 [9, 13, 24, 25],
 [10, 14, 25, 26],
 [11, 15, 26, 27],
 [0, 12, 28, 31],
 [1, 13, 28, 29],
 [2, 14, 29, 30],
 [3, 15, 30, 31]]
[ ]:
# Display the X logical operators (qubit supports) of the toric code at the same lattice size.
toric_code_x_logicals(lattice_size=LATTICE_SIZE)
[[0, 1, 2, 3], [16, 20, 24, 28]]
[ ]:

[ ]:

[ ]:

[42]:
# Code actually decoded below: hypergraph product of two length-LATTICE_SIZE
# repetition codes (the same construction as `hgp_code` above).
# NOTE(review): SEED is not referenced by any visible cell.
SEED = 123
LATTICE_SIZE = 4
rep_code = qc.repetition_code(LATTICE_SIZE)
code = qc.hypergraph_product(rep_code, rep_code)
code
[42]:
X stabilizers:
[0, 1, 16]
[1, 2, 17]
[2, 3, 18]
[4, 5, 16, 19]
[5, 6, 17, 20]
[6, 7, 18, 21]
[8, 9, 19, 22]
[9, 10, 20, 23]
[10, 11, 21, 24]
[12, 13, 22]
[13, 14, 23]
[14, 15, 24]
Z stabilizers:
[0, 4, 16]
[1, 5, 16, 17]
[2, 6, 17, 18]
[3, 7, 18]
[4, 8, 19]
[5, 9, 19, 20]
[6, 10, 20, 21]
[7, 11, 21]
[8, 12, 22]
[9, 13, 22, 23]
[10, 14, 23, 24]
[11, 15, 24]
[43]:
# MPS layout: two sites per physical qubit plus one site per logical operator.
num_logicals = sum((code.num_x_logicals(), code.num_z_logicals()))
num_sites = num_logicals + 2 * len(code)
[44]:
# Initial product state: '+' on each logical site, followed by '0' on every remaining site.
logicals_state = "+" * num_logicals
error_state = "0" * (num_sites - num_logicals)
state_string = "".join((logicals_state, error_state))
error_mps = create_custom_product_state(string=state_string)
[45]:
# MPS site indices touched by each check. In the recorded output, qubit q of an
# X check maps to site 2q + 2 and qubit q of a Z check to site 2q + 3.
checks_x, checks_z = css_code_checks(code)
for heading, checks in (("X checks:", checks_x), ("Z checks:", checks_z)):
    print(heading)
    for check in checks:
        print(check)
X checks:
[2, 4, 34]
[4, 6, 36]
[6, 8, 38]
[10, 12, 34, 40]
[12, 14, 36, 42]
[14, 16, 38, 44]
[18, 20, 40, 46]
[20, 22, 42, 48]
[22, 24, 44, 50]
[26, 28, 46]
[28, 30, 48]
[30, 32, 50]
Z checks:
[3, 11, 35]
[5, 13, 35, 37]
[7, 15, 37, 39]
[9, 17, 39]
[11, 19, 41]
[13, 21, 41, 43]
[15, 23, 43, 45]
[17, 25, 45]
[19, 27, 47]
[21, 29, 47, 49]
[23, 31, 49, 51]
[25, 33, 51]
[46]:
# One tensor per site group of a constraint, positionally matched to the four groups
# printed by css_code_constraint_sites below:
# [first check site], [other check sites], [sites in between that are not in the check], [last check site].
# NOTE(review): group roles inferred from the printed site lists and tensor names; confirm against mdopt.
constraints_tensors = [XOR_LEFT, XOR_BULK, SWAP, XOR_RIGHT]
# Same layout for logical operators; their first group is a logical site, which gets COPY_LEFT.
logicals_tensors = [COPY_LEFT, XOR_BULK, SWAP, XOR_RIGHT]
[47]:
# constraints_sites[0] holds the X-check site groups, constraints_sites[1] the Z-check ones.
constraints_sites = css_code_constraint_sites(code)
headings = ("Full X-check lists of sites:", "Full Z-check lists of sites:")
for index, heading in enumerate(headings):
    print(heading)
    for site_groups in constraints_sites[index]:
        print(site_groups)
Full X-check lists of sites:
[[2], [4], [3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33], [34]]
[[4], [6], [5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], [36]]
[[6], [8], [7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37], [38]]
[[10], [12, 34], [11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 35, 36, 37, 38, 39], [40]]
[[12], [14, 36], [13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 40, 41], [42]]
[[14], [16, 38], [15, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 39, 40, 41, 42, 43], [44]]
[[18], [20, 40], [19, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 41, 42, 43, 44, 45], [46]]
[[20], [22, 42], [21, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 43, 44, 45, 46, 47], [48]]
[[22], [24, 44], [23, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49], [50]]
[[26], [28], [27, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45], [46]]
[[28], [30], [29, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], [48]]
[[30], [32], [31, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], [50]]
Full Z-check lists of sites:
[[3], [11], [4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], [35]]
[[5], [13, 35], [6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 36], [37]]
[[7], [15, 37], [8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 38], [39]]
[[9], [17], [10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38], [39]]
[[11], [19], [12, 13, 14, 15, 16, 17, 18, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40], [41]]
[[13], [21, 41], [14, 15, 16, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 42], [43]]
[[15], [23, 43], [16, 17, 18, 19, 20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 44], [45]]
[[17], [25], [18, 19, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44], [45]]
[[19], [27], [20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46], [47]]
[[21], [29, 47], [22, 23, 24, 25, 26, 27, 28, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48], [49]]
[[23], [31, 49], [24, 25, 26, 27, 28, 29, 30, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 50], [51]]
[[25], [33], [26, 27, 28, 29, 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50], [51]]
[48]:
# Qubit supports of the X and Z logical operators of the code being decoded.
print(code.x_logicals_binary())
print(code.z_logicals_binary())
[3, 7, 11, 15]

[4, 5, 6, 7]

[49]:
# Site groups of each logical operator, in the same four-group layout as the checks.
# Entry 0 comes from the X logical and entry 1 from the Z logical (compare the printed sites
# with code.x_logicals_binary() / code.z_logicals_binary()).
# Compute once and reuse instead of calling css_code_logicals_sites three times.
logicals_sites = css_code_logicals_sites(code)
print(logicals_sites[0])
print(logicals_sites[1])
[[0], [8, 16, 24], [1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31], [32]]
[[1], [11, 13, 15], [2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16], [17]]
[50]:
# Options reused by the constraint-application cells below.
renormalise = True
result_to_explicit = False
# Every site after the logical sites carries a qubit error bit; bias each of them.
# NOTE(review): no error probability is passed, so apply_bitflip_bias presumably uses its
# default; confirm it matches the intended noise rate.
sites_to_bias = list(range(num_logicals, num_sites))
error_mps = apply_bitflip_bias(
    mps=error_mps,
    sites_to_bias=sites_to_bias,
    renormalise=renormalise,
)
[51]:
# Apply the X-check constraints (constraints_sites[0]) to the MPS, one check per
# iteration (12 iterations in the recorded run).
# chi_max presumably caps the bond dimension kept during contraction.
error_mps = apply_constraints(
    error_mps,
    constraints_sites[0],
    constraints_tensors,
    chi_max=1024,
    renormalise=renormalise,
    result_to_explicit=result_to_explicit,
    strategy="Optimized",
)
100%|██████████| 12/12 [01:14<00:00,  6.19s/it]
[52]:
# Apply the Z-check constraints (constraints_sites[1]) with the same settings.
# This was by far the slowest step in the recorded run (~30 minutes).
error_mps = apply_constraints(
    error_mps,
    constraints_sites[1],
    constraints_tensors,
    chi_max=1024,
    renormalise=renormalise,
    result_to_explicit=result_to_explicit,
    strategy="Optimized",
)
100%|██████████| 12/12 [29:39<00:00, 148.29s/it]
[53]:
# Apply the logical-operator constraints, which use COPY_LEFT on their first
# (logical) site instead of XOR_LEFT.
error_mps = apply_constraints(
    error_mps,
    logicals_sites,
    logicals_tensors,
    chi_max=1024,
    renormalise=renormalise,
    result_to_explicit=result_to_explicit,
    strategy="Optimized",
)
100%|██████████| 2/2 [03:17<00:00, 98.82s/it]
[55]:
# Marginalise over every non-logical site (len(error_state) == num_sites - num_logicals),
# keeping only the logical sites. The result is turned into a dense vector renormalised
# to unit 1-norm. With 2 logical sites it has 4 entries, presumably indexed by the
# logical-site bit values.
sites_to_marginalise = list(range(num_logicals, len(error_state) + num_logicals))
logical = marginalise(mps=error_mps, sites_to_marginalise=sites_to_marginalise).dense(
    flatten=True, renormalise=True, norm=1
)
print(logical)
[0.83427324 0.1071154  0.0550003  0.00361105]
[ ]: