Decoding Quantum Hypergraph Product Codes
[ ]:
import qecstruct as qc
from mdopt.optimiser.dephasing_dmrg import DephasingDMRG
from mdopt.mps.utils import (
create_custom_product_state,
create_simple_product_state,
inner_product,
marginalise,
)
from mdopt.optimiser.utils import (
XOR_BULK,
XOR_LEFT,
XOR_RIGHT,
COPY_LEFT,
SWAP,
)
from examples.decoding.decoding import (
css_code_constraint_sites,
css_code_logicals_sites,
apply_bitflip_bias,
apply_constraints,
)
[2]:
SEED = 123
NUM_BITS, NUM_CHECKS = 10, 6
CHECK_DEGREE, BIT_DEGREE = 5, 3
# Every Tanner-graph edge joins one bit to one check, so a graph where each bit
# has BIT_DEGREE edges and each check has CHECK_DEGREE edges can only exist if
# both sides count the same number of edges. Integer arithmetic keeps this exact
# (no float ratios).
if NUM_BITS * BIT_DEGREE != NUM_CHECKS * CHECK_DEGREE:
    raise ValueError(
        "Inconsistent Tanner graph parameters: NUM_BITS * BIT_DEGREE must equal "
        "NUM_CHECKS * CHECK_DEGREE (both count the edges of the graph)."
    )
code = qc.random_regular_code(
    NUM_BITS, NUM_CHECKS, BIT_DEGREE, CHECK_DEGREE, qc.Rng(SEED)
)
hgpc = qc.hypergraph_product(code, code)

# Chain layout: the logical sites come first, then two sites per code qubit.
num_logicals = hgpc.num_x_logicals() + hgpc.num_z_logicals()
num_sites = 2 * hgpc.length() + num_logicals

# NOTE(review): `num_sites` already counts the logical sites, so the string below
# has `num_sites + num_logicals` characters; its trailing `num_logicals` sites lie
# outside `range(num_logicals, num_sites)` (the range biased later) and look
# unused. `"0" * (num_sites - num_logicals)` is probably what was intended — if
# that is changed, also update any later code that derives a site range from
# `len(error)`.
error = "0" * num_sites
string_state = "+" * num_logicals + error
error_mps = create_custom_product_state(string=string_state, form="Right-canonical")

# Tensor sets handed to `apply_constraints` for the check constraints and for the
# logical-operator constraints, respectively.
constraints_tensors = [XOR_LEFT, XOR_BULK, SWAP, XOR_RIGHT]
logicals_tensors = [COPY_LEFT, XOR_BULK, SWAP, XOR_RIGHT]
# Bond-dimension cap passed to the constraint application and to the DMRG below.
chi_max = 64
[3]:
# Site groups for the code's check constraints (indexed [0] and [1] later) and
# for the logical-operator constraints, both derived from the product code.
constraints_sites, logicals_sites = (
    css_code_constraint_sites(hgpc),
    css_code_logicals_sites(hgpc),
)
[4]:
# Options reused by the bias step here and by the constraint steps that follow.
renormalise, result_to_explicit = True, False

# Bias only the sites after the first `num_logicals` ones, i.e. the part of the
# chain that came from `error` in the product-state string.
sites_to_bias = [site for site in range(num_logicals, num_sites)]
bias_options = {"renormalise": renormalise, "result_to_explicit": result_to_explicit}
error_mps = apply_bitflip_bias(
    mps=error_mps, sites_to_bias=sites_to_bias, **bias_options
)
[5]:
# Apply the two groups of check constraints, then the logical constraints, in
# exactly that order, with the same truncation/normalisation options each time.
constraint_passes = [
    (constraints_sites[0], constraints_tensors),
    (constraints_sites[1], constraints_tensors),
    (logicals_sites, logicals_tensors),
]
for pass_sites, pass_tensors in constraint_passes:
    error_mps = apply_constraints(
        error_mps,
        pass_sites,
        pass_tensors,
        chi_max=chi_max,
        renormalise=renormalise,
        result_to_explicit=result_to_explicit,
    )
100%|██████████| 60/60 [00:43<00:00, 1.36it/s]
100%|██████████| 60/60 [01:13<00:00, 1.22s/it]
100%|██████████| 2/2 [00:01<00:00, 1.73it/s]
[6]:
# Marginalise every site after the first `num_logicals` ones, through to the end of
# the chain, so that only the logical sites remain. `error_mps` was presumably built
# with one site per character of `string_state` (and the bias/constraint steps are
# assumed not to change the site count), so `len(string_state)` is the chain length.
# `len(error)` would stop `num_logicals` sites short, because `string_state`
# prepends `num_logicals` characters to `error`.
sites_to_marginalise = list(range(num_logicals, len(string_state)))
[7]:
# Marginalise out the listed sites; the result keeps only the remaining sites.
logicals = marginalise(
    mps=error_mps,
    sites_to_marginalise=sites_to_marginalise,
)
[8]:
# Dephasing DMRG over the marginalised logical sites, starting from |+...+>.
num_dmrg_sites = len(logicals)
mps_dmrg_start = create_simple_product_state(num_dmrg_sites, which="+")
# All-|0> reference state; used afterwards to measure the overlap of the result.
mps_dmrg_target = create_simple_product_state(num_dmrg_sites, which="0")
engine = DephasingDMRG(
    mps=mps_dmrg_start,
    # Target the marginal `logicals` computed above. Passing `mps_dmrg_target`
    # here would leave `logicals` used only for its length, so the optimisation
    # would ignore the decoding result entirely.
    mps_target=logicals,
    chi_max=chi_max,
    mode="LA",  # presumably "largest algebraic" eigenvalue — confirm in mdopt docs
    silent=False,
)
[9]:
# Run the dephasing DMRG and keep the optimised MPS held by the engine.
num_dmrg_iterations = 1
engine.run(num_iter=num_dmrg_iterations)
mps_dmrg_final = engine.mps
100%|██████████| 1/1 [00:00<00:00, 2.48it/s]
[10]:
# Magnitude of the overlap between the DMRG result and the all-|0> reference state.
raw_overlap = inner_product(mps_dmrg_final, mps_dmrg_target)
overlap = abs(raw_overlap)
print(overlap)
1.0
[ ]: