Decoding Quantum Hypergraph Product Codes

[ ]:
import qecstruct as qc

from mdopt.optimiser.dephasing_dmrg import DephasingDMRG
from mdopt.mps.utils import (
    create_custom_product_state,
    create_simple_product_state,
    inner_product,
    marginalise,
)
from mdopt.optimiser.utils import (
    XOR_BULK,
    XOR_LEFT,
    XOR_RIGHT,
    COPY_LEFT,
    SWAP,
)
from examples.decoding.decoding import (
    css_code_constraint_sites,
    css_code_logicals_sites,
    apply_bitflip_bias,
    apply_constraints,
)
[2]:
SEED = 123
NUM_BITS, NUM_CHECKS = 10, 6
CHECK_DEGREE, BIT_DEGREE = 5, 3
# Both sides of the Tanner graph must account for the same number of edges:
# NUM_BITS * BIT_DEGREE == NUM_CHECKS * CHECK_DEGREE.  Comparing integer edge
# counts (instead of float ratios) is exact and cannot divide by zero.
if NUM_BITS * BIT_DEGREE != NUM_CHECKS * CHECK_DEGREE:
    raise ValueError(
        "Inconsistent regular code parameters: NUM_BITS * BIT_DEGREE must equal "
        "NUM_CHECKS * CHECK_DEGREE (the number of Tanner-graph edges)."
    )
code = qc.random_regular_code(
    NUM_BITS, NUM_CHECKS, BIT_DEGREE, CHECK_DEGREE, qc.Rng(SEED)
)
hgpc = qc.hypergraph_product(code, code)

# MPS layout: the logical sites come first, followed by 2 * hgpc.length()
# physical sites (presumably an X and a Z site per qubit — confirm against
# css_code_constraint_sites).
num_logicals = hgpc.num_x_logicals() + hgpc.num_z_logicals()
num_sites = 2 * hgpc.length() + num_logicals

# All-zero (trivial) error string with one character per MPS site, so that
# len(error) == num_sites.
error = "0" * num_sites
# Initial product state with exactly num_sites sites: logical sites in |+>,
# physical sites in |0>.  Prepending the "+" prefix to the full-length `error`
# string would add num_logicals extra trailing sites that are never biased or
# marginalised (both site ranges stop at num_sites).
string_state = "+" * num_logicals + error[num_logicals:]
error_mps = create_custom_product_state(string=string_state, form="Right-canonical")

# Tensor sequences handed to apply_constraints for the parity-check constraints
# and for the logical operators, respectively.
constraints_tensors = [XOR_LEFT, XOR_BULK, SWAP, XOR_RIGHT]
logicals_tensors = [COPY_LEFT, XOR_BULK, SWAP, XOR_RIGHT]
# Bond-dimension cap passed to apply_constraints and DephasingDMRG
# (presumably the truncation limit — confirm in mdopt).
chi_max = 64
[3]:
# Site indices for the code's parity-check constraints and for its logical
# operators; both are consumed by apply_constraints in a later cell.
constraints_sites, logicals_sites = (
    css_code_constraint_sites(hgpc),
    css_code_logicals_sites(hgpc),
)
[4]:
# Options shared by the bias step here and the constraint steps that follow.
renormalise, result_to_explicit = True, False

# Sites num_logicals .. num_sites - 1 (the physical sites after the logical
# prefix) receive the bit-flip bias.
sites_to_bias = [*range(num_logicals, num_sites)]
error_mps = apply_bitflip_bias(
    mps=error_mps,
    sites_to_bias=sites_to_bias,
    result_to_explicit=result_to_explicit,
    renormalise=renormalise,
)
[5]:
# Apply the constraints to the error MPS one stage at a time: the two groups of
# parity-check sites (presumably the X- and Z-type checks — confirm), then the
# logical-operator sites, each with its own tensor sequence.
constraint_stages = (
    (constraints_sites[0], constraints_tensors),
    (constraints_sites[1], constraints_tensors),
    (logicals_sites, logicals_tensors),
)
for stage_sites, stage_tensors in constraint_stages:
    error_mps = apply_constraints(
        error_mps,
        stage_sites,
        stage_tensors,
        chi_max=chi_max,
        renormalise=renormalise,
        result_to_explicit=result_to_explicit,
    )
100%|██████████| 60/60 [00:43<00:00,  1.36it/s]
100%|██████████| 60/60 [01:13<00:00,  1.22s/it]
100%|██████████| 2/2 [00:01<00:00,  1.73it/s]
[6]:
# Marginalise the physical sites num_logicals .. num_sites - 1, the same range
# that received the bit-flip bias.  Bounding by num_sites (rather than
# len(error)) keeps this correct however the `error` string is sized.
sites_to_marginalise = list(range(num_logicals, num_sites))
[7]:
# Marginalise the physical sites out of the constrained error MPS; the result
# is the MPS over the remaining sites that the DMRG step operates on.
logicals = marginalise(
    sites_to_marginalise=sites_to_marginalise,
    mps=error_mps,
)
[8]:
# The DMRG problem lives on the marginalised MPS, one site per remaining site.
num_dmrg_sites = len(logicals)
# Initial guess: the uniform |+...+> product state.
mps_dmrg_start = create_simple_product_state(num_dmrg_sites, which="+")
# Reference answer for the overlap check: the all-|0> logical state, presumably
# the expected outcome for the trivial all-zero error prepared earlier —
# confirm when decoding other error strings.
mps_dmrg_target = create_simple_product_state(num_dmrg_sites, which="0")
# The optimiser's target must be the marginalised decoding MPS `logicals`.
# Passing the reference product state here would mean the DMRG never sees the
# decoding result, and the final overlap with that same reference state would
# be (presumably) trivially 1.0.
engine = DephasingDMRG(
    mps=mps_dmrg_start,
    mps_target=logicals,
    chi_max=chi_max,
    mode="LA",
    silent=False,
)
[9]:
# Run the optimiser for a single iteration and keep the resulting MPS.
num_dmrg_iterations = 1
engine.run(num_iter=num_dmrg_iterations)
mps_dmrg_final = engine.mps
100%|██████████| 1/1 [00:00<00:00,  2.48it/s]
[10]:
# Magnitude of the inner product between the optimised MPS and the reference
# state; printed as the decoding check.
raw_overlap = inner_product(mps_dmrg_final, mps_dmrg_target)
overlap = abs(raw_overlap)
print(overlap)
1.0
[ ]: