Example #1
0
    def test_ctmrg_AKLT_4SITE(self):
        """Converge CTMRG for the state in ``args.instate`` and check observables.

        The state is read with a 2x2 unit-cell tiling. CTMRG convergence is
        judged by the summed 2-norm distance between the 2x1 and 1x2 reduced
        density matrices of consecutive iterations. Afterwards the energy,
        the ``m`` observable and the ``sz``, ``sp``, ``sm`` observables of
        every site are asserted against a tolerance of ``1.0e-12``.
        """
        cfg.configure(args)
        torch.set_num_threads(args.omp_cores)

        model = akltS2.AKLTS2()

        def lattice_to_site(coord):
            # map an arbitrary lattice vertex onto the 2x2 unit cell
            x, y = coord[0], coord[1]
            return ((x + abs(x) * 2) % 2, (y + abs(y) * 2) % 2)

        state = read_ipeps(args.instate, vertexToSite=lattice_to_site)

        def ctmrg_conv_f(state, env, history, ctm_args=cfg.ctm_args):
            # Convergence check handed to ctmrg.run: returns (converged, history).
            with torch.no_grad():
                if not history:
                    history = {"log": []}

                # reduced density matrices of the current environment,
                # two per site (2x1 first, then 1x2)
                current_rdms = []
                for coord in state.sites.keys():
                    current_rdms.append(rdm.rdm2x1(coord, state, env))
                    current_rdms.append(rdm.rdm1x2(coord, state, env))

                # a distance is only computed once the log holds more than
                # one entry; before that it is recorded as infinite
                if len(history["log"]) > 1:
                    dist = 0.
                    for idx, current in enumerate(current_rdms):
                        previous = history["rdm"][idx]
                        dist += torch.dist(current, previous, p=2).item()
                else:
                    dist = float('inf')

                history["rdm"] = current_rdms
                history["log"].append(dist)
                if dist < ctm_args.ctm_conv_tol:
                    log.info({
                        "history_length": len(history['log']),
                        "history": history['log']
                    })
                    return True, history
            return False, history

        env = ENV(args.chi, state)
        init_env(state, env)

        env, *ctm_log = ctmrg.run(state, env, conv_check=ctmrg_conv_f)

        energy = model.energy_2x1_1x2(state, env)
        obs_values, obs_labels = model.eval_obs(state, env)
        observables = dict(zip(obs_labels, obs_values))

        eps = 1.0e-12
        self.assertTrue(energy < eps)
        for coord in state.sites.keys():
            self.assertTrue(observables[f"m{coord}"] < eps, msg=f"m{coord}")
            for label in ("sz", "sp", "sm"):
                key = f"{label}{coord}"
                self.assertTrue(abs(observables[key]) < eps, msg=key)
Example #2
0
    def loss_fn(state, ctm_env_in, opt_context):
        """Return the ``energy_1x1`` loss of ``state`` after a CTMRG run.

        When ``cfg.opt_args.opt_ctm_reinit`` is set, ``ctm_env_in`` is
        re-initialized before CTMRG is run on it. ``opt_context`` is not
        read by this function.

        Returns:
            Tuple ``(loss, environment, *log)`` where ``environment`` is the
            first value returned by ``ctmrg.run`` and ``log`` the remaining
            ones.
        """
        reinit_requested = cfg.opt_args.opt_ctm_reinit
        if reinit_requested:
            init_env(state, ctm_env_in)

        # converge the environment with CTMRG
        converged_env, *run_log = ctmrg.run(
            state, ctm_env_in, conv_check=ctmrg_conv_energy)

        energy = model.energy_1x1(state, converged_env)
        return (energy, converged_env, *run_log)
Example #3
0
    def loss_fn(state, ctm_env_in, opt_context):
        """Return the ``energy_2x1_1x2`` loss of ``state`` after a CTMRG run.

        ``opt_context`` must provide ``"ctm_args"``, which is forwarded to
        ``ctmrg.run``, and ``"opt_args"``, whose ``opt_ctm_reinit`` flag
        decides whether ``ctm_env_in`` is re-initialized first.

        Returns:
            Tuple ``(loss, environment, *log)`` where ``environment`` is the
            first value returned by ``ctmrg.run`` and ``log`` the remaining
            ones.
        """
        run_args = opt_context["ctm_args"]
        optimizer_args = opt_context["opt_args"]

        if optimizer_args.opt_ctm_reinit:
            init_env(state, ctm_env_in)

        # converge the environment with CTMRG
        converged_env, *run_log = ctmrg.run(
            state,
            ctm_env_in,
            conv_check=ctmrg_conv_energy,
            ctm_args=run_args,
        )

        # the loss is evaluated on the converged environment
        energy = model.energy_2x1_1x2(state, converged_env)
        return (energy, converged_env, *run_log)
Example #4
0
    def test_ctmrg_Ladders_VBS1x2(self):
        """Converge CTMRG for the state in ``args.instate`` and check observables.

        CTMRG convergence is judged by the change of ``energy_2x1_1x2``
        between consecutive iterations. The test then asserts, within
        ``1.0e-12``, an energy of ``-0.375``, bounds on the ``m``, ``SS2x1``,
        ``sz``, ``sp`` and ``sm`` observables of every site, and a value of
        ``-0.75`` for ``SS1x2`` at site ``(0, 0)``.
        """
        cfg.configure(args)
        cfg.print_config()
        torch.set_num_threads(args.omp_cores)

        model = coupledLadders.COUPLEDLADDERS_D2_BIPARTITE(alpha=args.alpha)
        state = read_ipeps(args.instate)

        def ctmrg_conv_energy(state, env, history, ctm_args=cfg.ctm_args):
            # Convergence check handed to ctmrg.run: returns (converged, history).
            with torch.no_grad():
                if not history:
                    history = []
                energy = model.energy_2x1_1x2(state, env)
                history.append([energy.item()])

                # at least two recorded energies are needed for a difference
                if len(history) > 1:
                    delta = abs(history[-1][0] - history[-2][0])
                    if delta < ctm_args.ctm_conv_tol:
                        return True, history
            return False, history

        env = ENV(args.chi, state)
        init_env(state, env)

        env, *ctm_log = ctmrg.run(state, env, conv_check=ctmrg_conv_energy)

        energy = model.energy_2x1_1x2(state, env)
        obs_values, obs_labels = model.eval_obs(state, env)
        observables = dict(zip(obs_labels, obs_values))

        eps = 1.0e-12
        self.assertTrue(abs(energy - (-0.375)) < eps)
        for coord in state.sites.keys():
            self.assertTrue(observables[f"m{coord}"] < eps, msg=f"m{coord}")
            self.assertTrue(observables[f"SS2x1{coord}"] < eps,
                            msg=f"SS2x1{coord}")
            for label in ("sz", "sp", "sm"):
                key = f"{label}{coord}"
                self.assertTrue(abs(observables[key]) < eps, msg=key)

        # SS1x2 is only checked at site (0, 0)
        coord = (0, 0)
        self.assertTrue(abs(observables[f"SS1x2{coord}"] - (-0.75)) < eps,
                        msg=f"SS1x2{coord}")
    def loss_fn(state, ctm_env_in, opt_context):
        """Return the ``energy_2x1_1x2`` loss of the normalized state.

        The parent tensor of ``state`` is divided by its largest absolute
        element and wrapped in a new ``IPEPS_D2SYM``; that new state is used
        for the CTMRG run and the energy evaluation.

        ``opt_context`` must provide ``"ctm_args"``, which is forwarded to
        ``ctmrg.run``, and ``"opt_args"``, whose ``opt_ctm_reinit`` flag
        decides whether ``ctm_env_in`` is re-initialized first.

        Returns:
            Tuple ``(loss, environment, *log)`` where ``environment`` is the
            first value returned by ``ctmrg.run`` and ``log`` the remaining
            ones.
        """
        run_args = opt_context["ctm_args"]
        optimizer_args = opt_context["opt_args"]

        # Only normalization is applied here; the explicit symmetrization
        # step (make_d2_symm) that was once part of this function is disabled.
        normalized_site = state.parent_site / torch.max(
            torch.abs(state.parent_site))
        state = IPEPS_D2SYM(normalized_site)

        if optimizer_args.opt_ctm_reinit:
            init_env(state, ctm_env_in)

        # converge the environment with CTMRG
        converged_env, *run_log = ctmrg.run(
            state,
            ctm_env_in,
            conv_check=ctmrg_conv_energy,
            ctm_args=run_args,
        )

        # the loss is evaluated on the converged environment
        energy = model.energy_2x1_1x2(state, converged_env)
        return (energy, converged_env, *run_log)
Example #6
0
def main():
    """Optimize an iPEPS for the coupled-ladders model and print observables.

    Steps:
      1. configure globals from ``args`` and seed torch,
      2. build the trial state from ``args.instate``, from the checkpoint
         ``args.opt_resume``, or from random tensors on a 2x2 unit cell,
      3. converge an initial CTMRG environment and print energy/observables,
      4. run ``optimize_state`` with ``energy_2x1_1x2`` as the loss,
      5. reload ``<out_prefix>_state.json`` and print its energy/observables.

    Raises:
        ValueError: if neither an input state nor a checkpoint is given and
            ``args.ipeps_init_type`` is not ``'RANDOM'``.
    """
    cfg.configure(args)
    cfg.print_config()
    torch.set_num_threads(args.omp_cores)
    torch.manual_seed(args.seed)

    model = coupledLadders.COUPLEDLADDERS(alpha=args.alpha)

    # initialize an ipeps: from file, from a checkpoint, or at random
    if args.instate is not None:
        state = read_ipeps(args.instate)
        if args.bond_dim > max(state.get_aux_bond_dims()):
            # extend the auxiliary dimensions
            state = extend_bond_dim(state, args.bond_dim)
        state.add_noise(args.instate_noise)
    elif args.opt_resume is not None:
        state = IPEPS(dict(), lX=2, lY=2)
        state.load_checkpoint(args.opt_resume)
    elif args.ipeps_init_type == 'RANDOM':
        bond_dim = args.bond_dim
        shape = (model.phys_dim, bond_dim, bond_dim, bond_dim, bond_dim)

        # One random tensor per site of the 2x2 unit cell, each normalized by
        # its largest absolute element. The draw order is fixed so that a
        # given seed always produces the same initial state.
        sites = {}
        for coord in [(0, 0), (1, 0), (0, 1), (1, 1)]:
            tensor = torch.rand(shape,
                                dtype=cfg.global_args.torch_dtype,
                                device=cfg.global_args.device)
            sites[coord] = tensor / torch.max(torch.abs(tensor))
        state = IPEPS(sites, lX=2, lY=2)
    else:
        raise ValueError("Missing trial state: --instate=None and --ipeps_init_type= "\
            +str(args.ipeps_init_type)+" is not supported")

    if state.dtype != model.dtype:
        # adopt the dtype of the state globally and rebuild the model
        cfg.global_args.torch_dtype = state.dtype
        print(
            f"dtype of initial state {state.dtype} and model {model.dtype} do not match."
        )
        print(f"Setting default dtype to {cfg.global_args.torch_dtype} and reinitializing "\
        +" the model")
        model = coupledLadders.COUPLEDLADDERS(alpha=args.alpha)

    print(state)

    @torch.no_grad()
    def ctmrg_conv_energy(state, env, history, ctm_args=cfg.ctm_args):
        # Convergence check handed to ctmrg.run: returns (converged, history).
        # Stops when the energy change between consecutive iterations drops
        # below ctm_conv_tol, or once ctm_max_iter energies were recorded.
        if not history:
            history = []
        e_curr = model.energy_2x1_1x2(state, env)
        e_curr = e_curr.real if e_curr.is_complex() else e_curr
        history.append(e_curr.item())

        converged = len(history) > 1 \
            and abs(history[-1] - history[-2]) < ctm_args.ctm_conv_tol
        if converged or len(history) >= ctm_args.ctm_max_iter:
            log.info({"history_length": len(history), "history": history})
            return True, history
        return False, history

    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)

    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    loss0 = model.energy_2x1_1x2(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    print(", ".join(["epoch", "energy"] + obs_labels))
    print(", ".join([f"{-1}", f"{loss0}"] + [f"{v}" for v in obs_values]))

    def loss_fn(state, ctm_env_in, opt_context):
        # Loss for the optimizer: returns (loss, environment, *log).
        ctm_args = opt_context["ctm_args"]
        opt_args = opt_context["opt_args"]

        # possibly re-initialize the environment
        if opt_args.opt_ctm_reinit:
            init_env(state, ctm_env_in)

        # 1) compute environment by CTMRG
        ctm_env_out, *ctm_log = ctmrg.run(state, ctm_env_in, \
            conv_check=ctmrg_conv_energy, ctm_args=ctm_args)

        # 2) evaluate loss with converged environment
        loss = model.energy_2x1_1x2(state, ctm_env_out)

        return (loss, ctm_env_out, *ctm_log)

    def _to_json(l):
        # split an (n, 2) spectrum into lists of real and imaginary parts
        n_rows = l.size()[0]
        real_parts = [l[i, 0].item() for i in range(n_rows)]
        imag_parts = [l[i, 1].item() for i in range(n_rows)]
        return {"re": real_parts, "im": imag_parts}

    @torch.no_grad()
    def obs_fn(state, ctm_env, opt_context):
        # nothing is reported for evaluations flagged as line search
        if "line_search" in opt_context and opt_context["line_search"]:
            return

        epoch = len(opt_context["loss_history"]["loss"])
        loss = opt_context["loss_history"]["loss"][-1]
        obs_values, obs_labels = model.eval_obs(state, ctm_env)
        print(", ".join([f"{epoch}", f"{loss}"] +
                        [f"{v}" for v in obs_values]))

        # every top_freq epochs also print the transfer operator spectra;
        # gradients are already disabled by the decorator
        if args.top_freq > 0 and epoch % args.top_freq == 0:
            coord_dir_pairs = [((0, 0), (1, 0)), ((0, 0), (0, 1)),
                               ((1, 1), (1, 0)), ((1, 1), (0, 1))]
            for c, d in coord_dir_pairs:
                # transfer operator spectrum
                print(f"TOP spectrum(T)[{c},{d}] ", end="")
                l = transferops.get_Top_spec(args.top_n, c, d, state,
                                             ctm_env)
                print("TOP " + json.dumps(_to_json(l)))

    # optimize
    optimize_state(state, ctm_env, loss_fn, obs_fn=obs_fn)

    # compute final observables for the best variational state
    outputstatefile = args.out_prefix + "_state.json"
    state = read_ipeps(outputstatefile)
    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)
    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    loss0 = model.energy_2x1_1x2(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    print(", ".join([f"{args.opt_max_iter}", f"{loss0}"] +
                    [f"{v}" for v in obs_values]))
Example #7
0
def main():
    """Converge CTMRG for an iPEPS of the J-Q model and print diagnostics.

    Steps:
      1. configure globals from ``args`` and seed torch,
      2. build the state from ``args.instate`` or from random tensors on a
         2x2 unit cell,
      3. run CTMRG with an energy-based convergence check that also prints
         the observables at every iteration,
      4. print the SS, DD (horizontal) and DD_V (vertical) correlation
         functions, the singular values of the corner tensors and the
         transfer operator spectra.

    Raises:
        ValueError: if no input state is given and ``args.ipeps_init_type``
            is not ``'RANDOM'``.
    """
    cfg.configure(args)
    cfg.print_config()
    torch.set_num_threads(args.omp_cores)
    torch.manual_seed(args.seed)

    model = jq.JQ(j1=args.j1, q=args.q)

    # initialize an ipeps: from file or at random
    if args.instate is not None:
        state = read_ipeps(args.instate)
        if args.bond_dim > max(state.get_aux_bond_dims()):
            # extend the auxiliary dimensions
            state = extend_bond_dim(state, args.bond_dim)
        state.add_noise(args.instate_noise)
    elif args.ipeps_init_type == 'RANDOM':
        bond_dim = args.bond_dim
        shape = (model.phys_dim, bond_dim, bond_dim, bond_dim, bond_dim)

        # One random tensor per site of the 2x2 unit cell, each normalized by
        # its largest absolute element. The draw order is fixed so that a
        # given seed always produces the same initial state.
        sites = {}
        for coord in [(0, 0), (1, 0), (0, 1), (1, 1)]:
            tensor = torch.rand(shape,
                                dtype=cfg.global_args.dtype,
                                device=cfg.global_args.device)
            sites[coord] = tensor / torch.max(torch.abs(tensor))
        state = IPEPS(sites, lX=2, lY=2)
    else:
        raise ValueError("Missing trial state: -instate=None and -ipeps_init_type= "\
            +str(args.ipeps_init_type)+" is not supported")

    print(state)

    def ctmrg_conv_energy(state, env, history, ctm_args=cfg.ctm_args):
        # Convergence check handed to ctmrg.run: returns (converged, history).
        # Each history entry is [energy, *observables]; only the energy
        # (element 0) enters the convergence criterion.
        with torch.no_grad():
            if not history:
                history = []
            e_curr = model.energy_2x2_4site(state, env)
            obs_values, obs_labels = model.eval_obs(state, env)
            history.append([e_curr.item()] + obs_values)
            print(", ".join([f"{len(history)}", f"{e_curr}"] +
                            [f"{v}" for v in obs_values]))

            if len(history) > 1 and abs(
                    history[-1][0] - history[-2][0]) < ctm_args.ctm_conv_tol:
                return True, history
        return False, history

    ctm_env_init = ENV(args.chi, state)
    init_env(state, ctm_env_init)
    print(ctm_env_init)

    # energy and observables of the freshly initialized environment
    e_curr0 = model.energy_2x2_4site(state, ctm_env_init)
    obs_values0, obs_labels = model.eval_obs(state, ctm_env_init)

    print(", ".join(["epoch", "energy"] + obs_labels))
    print(", ".join([f"{-1}", f"{e_curr0}"] + [f"{v}" for v in obs_values0]))

    ctm_env_init, *ctm_log = ctmrg.run(state,
                                       ctm_env_init,
                                       conv_check=ctmrg_conv_energy)

    site_dir_list = [((0, 0), (1, 0)), ((0, 0), (0, 1)), ((1, 1), (1, 0)),
                     ((1, 1), (0, 1))]

    # Correlation functions, printed one family after another:
    #   SS   : S(0).S(r)
    #   DD   : (S(0).S(x))(S(rx).S(rx+x))
    #   DD_V : (S(0).S(y))(S(rx).S(rx+y))
    corr_families = [
        ("SS", model.eval_corrf_SS),
        ("DD", model.eval_corrf_DD_H),
        ("DD_V", model.eval_corrf_DD_V),
    ]
    for tag, eval_corrf in corr_families:
        for sdp in site_dir_list:
            corrf = eval_corrf(*sdp, state, ctm_env_init, args.corrf_r)
            print(f"\n\n{tag}[{sdp[0]},{sdp[1]}] r " +
                  " ".join(corrf.keys()))
            for i in range(args.corrf_r):
                print(f"{i} " +
                      " ".join(f"{corrf[label][i]}" for label in corrf.keys()))

    # environment diagnostics
    for c_loc, c_ten in ctm_env_init.C.items():
        _, s, _ = torch.svd(c_ten, compute_uv=False)
        print(f"\n\nspectrum C[{c_loc}]")
        # print at most chi values; slicing avoids an IndexError when the
        # corner tensor has fewer singular values than chi
        for i, sing_val in enumerate(s[:args.chi]):
            print(f"{i} {sing_val}")

    # transfer operator spectrum
    for sdp in site_dir_list:
        print(f"\n\nspectrum(T)[{sdp[0]},{sdp[1]}]")
        l = transferops.get_Top_spec(args.top_n, *sdp, state, ctm_env_init)
        for i in range(l.size()[0]):
            print(f"{i} {l[i,0]} {l[i,1]}")
Example #8
0
def main():
    """Converge CTMRG for an iPEPS of the J1-J2 model and print diagnostics.

    Steps:
      1. configure globals from ``args`` and seed torch,
      2. pick the lattice tiling (``BIPARTITE``, ``2SITE``, ``4SITE`` or
         ``8SITE``) and build the state from ``args.instate`` or from random
         tensors, one per site of the chosen unit cell,
      3. run CTMRG with an energy-based convergence check that also prints
         the observables at every iteration,
      4. print final energy/observables, timings, SS correlation functions,
         the singular values of the corner tensors and the transfer operator
         spectra.

    Raises:
        ValueError: if ``args.tiling`` is not supported, or if no input state
            is given and ``args.ipeps_init_type`` is not ``'RANDOM'``.
    """
    cfg.configure(args)
    cfg.print_config()
    torch.set_num_threads(args.omp_cores)
    torch.manual_seed(args.seed)

    model = j1j2.J1J2(j1=args.j1, j2=args.j2)

    # initialize an ipeps
    # 1) define lattice-tiling function, that maps arbitrary vertex of square lattice
    # coord into one of coordinates within unit-cell of iPEPS ansatz
    if args.tiling == "BIPARTITE":

        def lattice_to_site(coord):
            vx = (coord[0] + abs(coord[0]) * 2) % 2
            vy = abs(coord[1])
            return ((vx + vy) % 2, 0)
    elif args.tiling == "2SITE":

        def lattice_to_site(coord):
            vx = (coord[0] + abs(coord[0]) * 2) % 2
            vy = (coord[1] + abs(coord[1]) * 1) % 1
            return (vx, vy)
    elif args.tiling == "4SITE":

        def lattice_to_site(coord):
            vx = (coord[0] + abs(coord[0]) * 2) % 2
            vy = (coord[1] + abs(coord[1]) * 2) % 2
            return (vx, vy)
    elif args.tiling == "8SITE":

        def lattice_to_site(coord):
            # 4x2 unit cell; every second pair of rows is shifted by 2 in x
            shift_x = coord[0] + 2 * (coord[1] // 2)
            vx = shift_x % 4
            vy = coord[1] % 2
            return (vx, vy)
    else:
        raise ValueError("Invalid tiling: "+str(args.tiling)+" Supported options: "\
            +"BIPARTITE, 2SITE, 4SITE, 8SITE")

    # initialize an ipeps: from file or at random
    if args.instate is not None:
        state = read_ipeps(args.instate, vertexToSite=lattice_to_site)
        if args.bond_dim > max(state.get_aux_bond_dims()):
            # extend the auxiliary dimensions
            state = extend_bond_dim(state, args.bond_dim)
        state.add_noise(args.instate_noise)
    elif args.ipeps_init_type == 'RANDOM':
        bond_dim = args.bond_dim
        shape = (model.phys_dim, bond_dim, bond_dim, bond_dim, bond_dim)

        # Sites of the unit cell for the chosen tiling. (0, 1) and (1, 1)
        # belong to both the 4SITE and the 8SITE cell: the 8SITE tiling maps
        # vertices onto x in 0..3 and y in 0..1, so it needs all eight sites.
        coords = [(0, 0), (1, 0)]
        if args.tiling in ("4SITE", "8SITE"):
            coords += [(0, 1), (1, 1)]
        if args.tiling == "8SITE":
            coords += [(2, 0), (3, 0), (2, 1), (3, 1)]

        # one random tensor per site, normalized by its largest absolute
        # element; the draw order follows ``coords``
        sites = {}
        for coord in coords:
            tensor = torch.rand(shape,
                                dtype=cfg.global_args.dtype,
                                device=cfg.global_args.device)
            sites[coord] = tensor / torch.max(torch.abs(tensor))

        state = IPEPS(sites, vertexToSite=lattice_to_site)
    else:
        raise ValueError("Missing trial state: -instate=None and -ipeps_init_type= "\
            +str(args.ipeps_init_type)+" is not supported")

    print(state)

    # 2) select the "energy" function
    if args.tiling == "BIPARTITE" or args.tiling == "2SITE":
        energy_f = model.energy_2x2_2site
    elif args.tiling == "4SITE":
        energy_f = model.energy_2x2_4site
    elif args.tiling == "8SITE":
        energy_f = model.energy_2x2_8site
    else:
        raise ValueError("Invalid tiling: "+str(args.tiling)+" Supported options: "\
            +"BIPARTITE, 2SITE, 4SITE, 8SITE")

    def ctmrg_conv_energy(state, env, history, ctm_args=cfg.ctm_args):
        # Convergence check handed to ctmrg.run: returns (converged, history).
        # Each history entry is [energy, *observables]; only the energy
        # (element 0) enters the convergence criterion.
        with torch.no_grad():
            if not history:
                history = []
            e_curr = energy_f(state, env)
            obs_values, obs_labels = model.eval_obs(state, env)
            history.append([e_curr.item()] + obs_values)
            print(", ".join([f"{len(history)}", f"{e_curr}"] +
                            [f"{v}" for v in obs_values]))

            if len(history) > 1 and abs(
                    history[-1][0] - history[-2][0]) < ctm_args.ctm_conv_tol:
                return True, history
        return False, history

    ctm_env_init = ENV(args.chi, state)
    init_env(state, ctm_env_init)
    print(ctm_env_init)

    # energy and observables of the freshly initialized environment
    e_curr0 = energy_f(state, ctm_env_init)
    obs_values0, obs_labels = model.eval_obs(state, ctm_env_init)

    print(", ".join(["epoch", "energy"] + obs_labels))
    print(", ".join([f"{-1}", f"{e_curr0}"] + [f"{v}" for v in obs_values0]))

    ctm_env_init, *ctm_log = ctmrg.run(state,
                                       ctm_env_init,
                                       conv_check=ctmrg_conv_energy)

    # 6) compute final observables
    e_curr0 = energy_f(state, ctm_env_init)
    obs_values0, obs_labels = model.eval_obs(state, ctm_env_init)
    history, t_ctm, t_obs = ctm_log
    print("\n")
    print(", ".join(["epoch", "energy"] + obs_labels))
    print("FINAL " + ", ".join([f"{e_curr0}"] + [f"{v}" for v in obs_values0]))
    print(f"TIMINGS ctm: {t_ctm} conv_check: {t_obs}")

    # 7) ----- additional observables ---------------------------------------------
    # SS correlations from site (0,0) along both lattice directions; the
    # headers are kept as fixed strings to preserve the output format
    ss_requests = [
        ("\n\nSS[(0,0),(1,0)] r ", (0, 0), (1, 0)),
        ("\n\nSS[(0,0),(0,1)] r ", (0, 0), (0, 1)),
    ]
    for header, site, direction in ss_requests:
        corrSS = model.eval_corrf_SS(site, direction, state, ctm_env_init,
                                     args.corrf_r)
        print(header + " ".join(corrSS.keys()))
        for i in range(args.corrf_r):
            print(f"{i} " +
                  " ".join(f"{corrSS[label][i]}" for label in corrSS.keys()))

    # environment diagnostics
    print("\n")
    for c_loc, c_ten in ctm_env_init.C.items():
        _, s, _ = torch.svd(c_ten, compute_uv=False)
        print(f"spectrum C[{c_loc}]")
        # print at most chi values; slicing avoids an IndexError when the
        # corner tensor has fewer singular values than chi
        for i, sing_val in enumerate(s[:args.chi]):
            print(f"{i} {sing_val}")

    # transfer operator spectrum
    site_dir_list = [((0, 0), (1, 0)), ((0, 0), (0, 1))]
    for sdp in site_dir_list:
        print(f"\n\nspectrum(T)[{sdp[0]},{sdp[1]}]")
        l = transferops.get_Top_spec(args.top_n, *sdp, state, ctm_env_init)
        for i in range(l.size()[0]):
            print(f"{i} {l[i,0]} {l[i,1]}")
Example #9
0
def main():
    """Optimize a single-site iPEPS for the Ising model and print observables.

    Steps:
      1. configure globals from ``args`` and seed torch,
      2. build the trial state from ``args.instate``, from the checkpoint
         ``args.opt_resume``, or from one random tensor,
      3. converge an initial CTMRG environment and print energy/observables,
      4. run ``optimize_state`` with ``energy_1x1`` as the loss,
      5. reload ``<out_prefix>_state.json`` and print its energy/observables.

    Raises:
        ValueError: if neither an input state nor a checkpoint is given and
            ``args.ipeps_init_type`` is not ``'RANDOM'``.
    """
    cfg.configure(args)
    cfg.print_config()
    torch.set_num_threads(args.omp_cores)
    torch.manual_seed(args.seed)

    model = ising.ISING(hx=args.hx, q=args.q)

    # initialize an ipeps: from file, from a checkpoint, or at random
    if args.instate is not None:
        state = read_ipeps(args.instate, vertexToSite=None)
        if args.bond_dim > max(state.get_aux_bond_dims()):
            # extend the auxiliary dimensions
            state = extend_bond_dim(state, args.bond_dim)
        state.add_noise(args.instate_noise)
    elif args.opt_resume is not None:
        state = IPEPS(dict(), lX=1, lY=1)
        state.load_checkpoint(args.opt_resume)
    elif args.ipeps_init_type == 'RANDOM':
        bond_dim = args.bond_dim
        A = torch.rand((model.phys_dim, bond_dim, bond_dim, bond_dim, bond_dim),\
            dtype=cfg.global_args.torch_dtype,device=cfg.global_args.device)
        # normalization of initial random tensors
        A = A / torch.max(torch.abs(A))
        sites = {(0, 0): A}
        state = IPEPS(sites)
    else:
        raise ValueError("Missing trial state: -instate=None and -ipeps_init_type= "\
            +str(args.ipeps_init_type)+" is not supported")

    print(state)

    @torch.no_grad()
    def ctmrg_conv_energy(state, env, history, ctm_args=cfg.ctm_args):
        # Convergence check handed to ctmrg.run: returns (converged, history).
        # Stops when the energy change between consecutive iterations drops
        # below ctm_conv_tol, or once ctm_max_iter energies were recorded.
        if not history:
            history = []
        e_curr = model.energy_1x1(state, env)
        history.append(e_curr.item())

        converged = len(history) > 1 \
            and abs(history[-1] - history[-2]) < ctm_args.ctm_conv_tol
        if converged or len(history) >= ctm_args.ctm_max_iter:
            log.info({"history_length": len(history), "history": history})
            return True, history
        return False, history

    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)

    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    loss0 = model.energy_1x1(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    print(", ".join(["epoch", "energy"] + obs_labels))
    print(", ".join([f"{-1}", f"{loss0}"] + [f"{v}" for v in obs_values]))

    def loss_fn(state, ctm_env_in, opt_context):
        # Loss for the optimizer: returns (loss, environment, *log).
        # NOTE(review): reads cfg.opt_args and lets ctmrg.run use its default
        # ctm_args; opt_context is not consulted here — confirm intended.

        # possibly re-initialize the environment
        if cfg.opt_args.opt_ctm_reinit:
            init_env(state, ctm_env_in)

        # 1) compute environment by CTMRG
        ctm_env_out, *ctm_log = ctmrg.run(state,
                                          ctm_env_in,
                                          conv_check=ctmrg_conv_energy)
        # 2) evaluate loss with converged environment
        loss = model.energy_1x1(state, ctm_env_out)

        return (loss, ctm_env_out, *ctm_log)

    @torch.no_grad()
    def obs_fn(state, ctm_env, opt_context):
        # print the epoch number, latest loss and current observables
        epoch = len(opt_context["loss_history"]["loss"])
        loss = opt_context["loss_history"]["loss"][-1]
        obs_values, obs_labels = model.eval_obs(state, ctm_env)
        print(", ".join([f"{epoch}", f"{loss}"] + [f"{v}"
                                                   for v in obs_values]))

    # optimize
    optimize_state(state, ctm_env, loss_fn, obs_fn=obs_fn)

    # compute final observables for the best variational state
    outputstatefile = args.out_prefix + "_state.json"
    state = read_ipeps(outputstatefile)
    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)
    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    loss0 = model.energy_1x1(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    print(", ".join([f"{args.opt_max_iter}", f"{loss0}"] +
                    [f"{v}" for v in obs_values]))
Example #10
0
def main():
    """Optimize an iPEPS for the J1-J2 model, driven by the module-level ``args``.

    The routine configures the library, builds a trial state (read from
    ``args.instate``, resumed from ``args.opt_resume``, or drawn at random),
    converges an initial CTMRG environment, hands state and environment to
    ``optimize_state``, and finally re-evaluates energy and observables for
    the state read back from ``args.out_prefix + "_state.json"``.

    Progress is printed as comma-separated rows: a header, a row for the
    initial state (epoch ``-1``), a row each time ``obs_fn`` is invoked outside
    of a line search, and a last row labelled with ``args.opt_max_iter``.

    Raises:
        ValueError: if ``args.tiling`` is not a supported tiling, or if no
            trial state can be constructed from the given arguments.
    """
    cfg.configure(args)
    cfg.print_config()
    torch.set_num_threads(args.omp_cores)
    torch.manual_seed(args.seed)

    model = j1j2.J1J2(j1=args.j1, j2=args.j2)

    # initialize an ipeps
    # 1) define lattice-tiling function, that maps arbitrary vertex of square
    # lattice coord into one of coordinates within unit-cell of iPEPS ansatz
    if args.tiling == "BIPARTITE":

        def lattice_to_site(coord):
            vx = (coord[0] + abs(coord[0]) * 2) % 2
            vy = abs(coord[1])
            return ((vx + vy) % 2, 0)
    elif args.tiling == "1SITE":

        def lattice_to_site(coord):
            return (0, 0)
    elif args.tiling == "2SITE":

        def lattice_to_site(coord):
            vx = (coord[0] + abs(coord[0]) * 2) % 2
            vy = (coord[1] + abs(coord[1]) * 1) % 1
            return (vx, vy)
    elif args.tiling == "4SITE":

        def lattice_to_site(coord):
            vx = (coord[0] + abs(coord[0]) * 2) % 2
            vy = (coord[1] + abs(coord[1]) * 2) % 2
            return (vx, vy)
    elif args.tiling == "8SITE":

        def lattice_to_site(coord):
            # every second row is shifted by two sites along x
            shift_x = coord[0] + 2 * (coord[1] // 2)
            vx = shift_x % 4
            vy = coord[1] % 2
            return (vx, vy)
    else:
        raise ValueError("Invalid tiling: " + str(args.tiling) +
                         " Supported options: " +
                         "BIPARTITE, 1SITE, 2SITE, 4SITE, 8SITE")

    if args.instate is not None:
        state = read_ipeps(args.instate, vertexToSite=lattice_to_site)
        if args.bond_dim > max(state.get_aux_bond_dims()):
            # extend the auxiliary dimensions
            state = extend_bond_dim(state, args.bond_dim)
        state.add_noise(args.instate_noise)
    elif args.opt_resume is not None:
        # (lX, lY) extent of the empty state the checkpoint is loaded into;
        # args.tiling has already been validated above
        unit_cell = {
            "BIPARTITE": (2, 1),
            "2SITE": (2, 1),
            "1SITE": (1, 1),
            "4SITE": (2, 2),
            "8SITE": (4, 2),
        }
        lX, lY = unit_cell[args.tiling]
        state = IPEPS(dict(), lX=lX, lY=lY)
        state.load_checkpoint(args.opt_resume)
    elif args.ipeps_init_type == 'RANDOM':
        bond_dim = args.bond_dim

        def random_site():
            """Return a random on-site tensor scaled to unit max-abs element."""
            t = torch.rand(
                (model.phys_dim, bond_dim, bond_dim, bond_dim, bond_dim),
                dtype=cfg.global_args.torch_dtype,
                device=cfg.global_args.device)
            return t / torch.max(torch.abs(t))

        # The order of coordinates fixes the order in which random tensors
        # are drawn, hence the result for a given seed.
        coords = [(0, 0)]
        if args.tiling in ["BIPARTITE", "2SITE", "4SITE", "8SITE"]:
            coords.append((1, 0))
        if args.tiling in ["4SITE", "8SITE"]:
            coords += [(0, 1), (1, 1)]
        if args.tiling == "8SITE":
            coords += [(2, 0), (3, 0), (2, 1), (3, 1)]
        sites = {c: random_site() for c in coords}
        state = IPEPS(sites, vertexToSite=lattice_to_site)
    else:
        raise ValueError(
            "Missing trial state: -instate=None and -ipeps_init_type= " +
            str(args.ipeps_init_type) + " is not supported")

    if state.dtype != model.dtype:
        cfg.global_args.torch_dtype = state.dtype
        print(
            f"dtype of initial state {state.dtype} and model {model.dtype} do not match."
        )
        print(f"Setting default dtype to {cfg.global_args.torch_dtype} and reinitializing "\
        +" the model")
        # rebuild the model with the same couplings it was created with above
        model = j1j2.J1J2(j1=args.j1, j2=args.j2)

    print(state)

    # 2) select the "energy" function
    if args.tiling == "BIPARTITE" or args.tiling == "2SITE":
        energy_f = model.energy_2x2_2site
        eval_obs_f = model.eval_obs
    elif args.tiling == "1SITE":
        energy_f = model.energy_2x2_1site_BP
        # TODO include eval_obs with rotation on B-sublattice
        eval_obs_f = model.eval_obs
    elif args.tiling == "4SITE":
        energy_f = model.energy_2x2_4site
        eval_obs_f = model.eval_obs
    elif args.tiling == "8SITE":
        energy_f = model.energy_2x2_8site
        eval_obs_f = model.eval_obs
    else:
        raise ValueError("Invalid tiling: " + str(args.tiling) +
                         " Supported options: " +
                         "BIPARTITE, 1SITE, 2SITE, 4SITE, 8SITE")

    @torch.no_grad()
    def ctmrg_conv_energy(state, env, history, ctm_args=cfg.ctm_args):
        """Energy-based CTMRG convergence check.

        Appends the current energy to ``history`` and returns
        ``(converged, history)``. ``converged`` is True once two consecutive
        energies differ by less than ``ctm_args.ctm_conv_tol`` or once
        ``ctm_args.ctm_max_iter`` energies have been recorded.
        """
        if not history:
            history = []
        e_curr = energy_f(state, env)
        history.append(e_curr.item())

        energy_settled = (len(history) > 1 and
                          abs(history[-1] - history[-2]) < ctm_args.ctm_conv_tol)
        if energy_settled or len(history) >= ctm_args.ctm_max_iter:
            log.info({"history_length": len(history), "history": history})
            return True, history
        return False, history

    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)

    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    loss0 = energy_f(state, ctm_env)
    obs_values, obs_labels = eval_obs_f(state, ctm_env)
    print(", ".join(["epoch", "energy"] + obs_labels))
    print(", ".join([f"{-1}", f"{loss0}"] + [f"{v}" for v in obs_values]))

    def loss_fn(state, ctm_env_in, opt_context):
        """Return ``(loss, ctm_env_out, *ctm_log)`` for the optimizer."""
        ctm_args = opt_context["ctm_args"]
        opt_args = opt_context["opt_args"]

        # possibly re-initialize the environment
        if opt_args.opt_ctm_reinit:
            init_env(state, ctm_env_in)

        # 1) compute environment by CTMRG
        ctm_env_out, *ctm_log = ctmrg.run(state,
                                          ctm_env_in,
                                          conv_check=ctmrg_conv_energy,
                                          ctm_args=ctm_args)

        # 2) evaluate loss with the converged environment
        loss = energy_f(state, ctm_env_out)

        return (loss, ctm_env_out, *ctm_log)

    def _to_json(l):
        """Split a two-column tensor into lists stored under "re" and "im"."""
        n = l.size()[0]
        re = [l[i, 0].item() for i in range(n)]
        im = [l[i, 1].item() for i in range(n)]
        return dict({"re": re, "im": im})

    @torch.no_grad()
    def obs_fn(state, ctm_env, opt_context):
        """Print observables for the current epoch; silent during line search."""
        # a missing "line_search" entry counts as "not in line search"
        if opt_context.get("line_search", False):
            return

        epoch = len(opt_context["loss_history"]["loss"])
        loss = opt_context["loss_history"]["loss"][-1]
        obs_values, obs_labels = eval_obs_f(state, ctm_env)
        print(", ".join([f"{epoch}", f"{loss}"] +
                        [f"{v}" for v in obs_values]))
        log.info("Norm(sites): " +
                 ", ".join([f"{t.norm()}" for t in state.sites.values()]))

        # gradients are already disabled by the decorator
        if args.top_freq > 0 and epoch % args.top_freq == 0:
            coord_dir_pairs = [((0, 0), (1, 0)), ((0, 0), (0, 1)),
                               ((1, 1), (1, 0)), ((1, 1), (0, 1))]
            for c, d in coord_dir_pairs:
                # transfer operator spectrum
                print(f"TOP spectrum(T)[{c},{d}] ", end="")
                l = transferops.get_Top_spec(args.top_n, c, d, state, ctm_env)
                print("TOP " + json.dumps(_to_json(l)))

    # optimize
    optimize_state(state, ctm_env, loss_fn, obs_fn=obs_fn)

    # compute final observables for the best variational state
    # NOTE(review): the file is presumably written by optimize_state - confirm
    outputstatefile = args.out_prefix + "_state.json"
    state = read_ipeps(outputstatefile, vertexToSite=state.vertexToSite)
    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)
    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    loss0 = energy_f(state, ctm_env)
    obs_values, obs_labels = eval_obs_f(state, ctm_env)
    print(", ".join([f"{args.opt_max_iter}", f"{loss0}"] +
                    [f"{v}" for v in obs_values]))
Example #11
0
def main():
    """Optimize an iPEPS for the triangle model, driven by module-level ``args``.

    The routine configures the library, builds a trial state (read from
    ``args.instate``, resumed from ``args.opt_resume``, or drawn at random),
    converges an initial CTMRG environment, runs ``optimize_state``, and
    finally re-evaluates energy and observables for the state read back from
    ``args.out_prefix + "_state.json"``.

    Raises:
        ValueError: if ``args.tiling`` is neither ``"3x3"`` nor ``"9SITE"``,
            or if no trial state can be constructed from the given arguments.
    """
    cfg.configure(args)
    cfg.print_config()
    torch.set_num_threads(args.omp_cores)
    torch.manual_seed(args.seed)

    model = triangle.triangle(j1=args.j1, j2=args.j2)

    # initialize an ipeps
    # 1) define lattice-tiling function, that maps arbitrary vertex of square
    # lattice coord into one of coordinates within unit-cell of iPEPS ansatz
    if args.tiling == "3x3":

        def lattice_to_site(coord):
            # three distinct tensors, all stored in row 0 of the unit cell
            vx = (-coord[0] + abs(coord[0]) * 3) % 3
            vy = (-coord[1] + abs(coord[1]) * 3) % 3
            return ((vx + vy) % 3, 0)

    elif args.tiling == "9SITE":

        def lattice_to_site(coord):
            vx = (coord[0] + abs(coord[0]) * 3) % 3
            vy = (coord[1] + abs(coord[1]) * 3) % 3
            return (vx, vy)

    else:
        raise ValueError("Invalid tiling: " + str(args.tiling) +
                         " Supported options: " + "3x3, 9SITE")

    if args.instate is not None:
        state = read_ipeps(args.instate, vertexToSite=lattice_to_site)
        if args.bond_dim > max(state.get_aux_bond_dims()):
            # extend the auxiliary dimensions
            state = extend_bond_dim(state, args.bond_dim)
        state.add_noise(args.instate_noise)
    elif args.opt_resume is not None:
        if args.tiling == "3x3":
            state = IPEPS(dict(), lX=3, lY=1)
        else:
            # "9SITE": previously no state was created here, so the
            # load_checkpoint call below failed with a NameError.
            # NOTE(review): lX=3, lY=3 is chosen to match the nine sites
            # created for this tiling in the RANDOM branch - confirm against
            # the checkpoint format.
            state = IPEPS(dict(), lX=3, lY=3)
        state.load_checkpoint(args.opt_resume)
    elif args.ipeps_init_type == 'RANDOM':
        bond_dim = args.bond_dim

        def random_site():
            """Return a random on-site tensor scaled to unit max-abs element."""
            t = torch.rand(
                (model.phys_dim, bond_dim, bond_dim, bond_dim, bond_dim),
                dtype=cfg.global_args.dtype,
                device=cfg.global_args.device)
            return t / torch.max(torch.abs(t))

        # The order of coordinates fixes the order in which random tensors
        # are drawn, hence the result for a given seed.
        if args.tiling == "3x3":
            coords = [(0, 0), (1, 0), (2, 0)]
        else:
            # "9SITE": x runs fastest, i.e. (0,0), (1,0), (2,0), (0,1), ...
            coords = [(x, y) for y in range(3) for x in range(3)]
        sites = {c: random_site() for c in coords}
        state = IPEPS(sites, vertexToSite=lattice_to_site)
    else:
        raise ValueError(
            "Missing trial state: -instate=None and -ipeps_init_type= " +
            str(args.ipeps_init_type) + " is not supported")

    # 2) select the "energy" function; both supported tilings use the same one
    energy_f = model.energy_2x2_9site

    @torch.no_grad()
    def ctmrg_conv_energy(state, env, history, ctm_args=cfg.ctm_args):
        """Energy-based CTMRG convergence check.

        Appends the current energy to ``history`` and returns
        ``(converged, history)``. ``converged`` is True once two consecutive
        energies differ by less than ``ctm_args.ctm_conv_tol`` or once
        ``ctm_args.ctm_max_iter`` energies have been recorded.
        """
        if not history:
            history = []
        e_curr = energy_f(state, env)
        history.append(e_curr.item())

        energy_settled = (len(history) > 1 and
                          abs(history[-1] - history[-2]) < ctm_args.ctm_conv_tol)
        if energy_settled or len(history) >= ctm_args.ctm_max_iter:
            log.info({"history_length": len(history), "history": history})
            return True, history
        return False, history

    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)

    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    loss0 = energy_f(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    print(", ".join(["epoch", "energy"] + obs_labels))
    print(", ".join([f"{-1}", f"{loss0}"] + [f"{v}" for v in obs_values]))

    def loss_fn(state, ctm_env_in, opt_context):
        """Return ``(loss, ctm_env_out, *ctm_log)`` for the optimizer."""
        ctm_args = opt_context["ctm_args"]
        opt_args = opt_context["opt_args"]

        # possibly re-initialize the environment
        if opt_args.opt_ctm_reinit:
            init_env(state, ctm_env_in)

        # 1) compute environment by CTMRG
        ctm_env_out, *ctm_log = ctmrg.run(state,
                                          ctm_env_in,
                                          conv_check=ctmrg_conv_energy,
                                          ctm_args=ctm_args)

        # 2) evaluate loss with the converged environment
        loss = energy_f(state, ctm_env_out)

        return (loss, ctm_env_out, *ctm_log)

    @torch.no_grad()
    def obs_fn(state, ctm_env, opt_context):
        """Print epoch and loss; silent during line search."""
        # a missing "line_search" entry counts as "not in line search"
        if opt_context.get("line_search", False):
            return

        epoch = len(opt_context["loss_history"]["loss"])
        loss = opt_context["loss_history"]["loss"][-1]
        # observables are not evaluated per epoch here, only epoch and loss
        print(", ".join([f"{epoch}", f"{loss}"]))
        log.info("Norm(sites): " +
                 ", ".join([f"{t.norm()}" for t in state.sites.values()]))

    # optimize
    optimize_state(state, ctm_env, loss_fn, obs_fn=obs_fn)

    # compute final observables for the best variational state
    # NOTE(review): the file is presumably written by optimize_state - confirm
    outputstatefile = args.out_prefix + "_state.json"
    state = read_ipeps(outputstatefile, vertexToSite=state.vertexToSite)
    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)
    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    opt_energy = energy_f(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    # report the final values in the same row format as the initial state
    print(", ".join([f"{args.opt_max_iter}", f"{opt_energy}"] +
                    [f"{v}" for v in obs_values]))
Example #12
0
def main():
    """Optimize a D2-symmetric iPEPS for the coupled-ladders model.

    Driven by the module-level ``args``: builds a trial state (read from
    ``args.instate``, resumed from ``args.opt_resume``, or drawn at random),
    converges an initial CTMRG environment, runs ``optimize_state``, and
    finally re-evaluates energy and observables for the state read back from
    ``args.out_prefix + "_state.json"``.

    Raises:
        ValueError: if no trial state can be constructed from the arguments.
    """
    cfg.configure(args)
    cfg.print_config()
    torch.set_num_threads(args.omp_cores)
    torch.manual_seed(args.seed)

    model = coupledLadders.COUPLEDLADDERS_D2_BIPARTITE(alpha=args.alpha)

    # initialize an ipeps
    if args.instate is not None:
        state = read_ipeps_d2(args.instate)
        if args.bond_dim > max(state.get_aux_bond_dims()):
            # extend the auxiliary dimensions
            state = extend_bond_dim(state, args.bond_dim)
        state.add_noise(args.instate_noise)
    elif args.opt_resume is not None:
        state = IPEPS_D2SYM()
        state.load_checkpoint(args.opt_resume)
    elif args.ipeps_init_type == 'RANDOM':
        bond_dim = args.bond_dim
        A = torch.rand(
            (model.phys_dim, bond_dim, bond_dim, bond_dim, bond_dim),
            dtype=cfg.global_args.torch_dtype,
            device=cfg.global_args.device)
        # scale to unit max-abs element; no explicit D2 symmetrization here
        A = A / torch.max(torch.abs(A))
        state = IPEPS_D2SYM(A)
    else:
        raise ValueError(
            "Missing trial state: -instate=None and -ipeps_init_type= " +
            str(args.ipeps_init_type) + " is not supported")

    state.sites = state.build_onsite_tensors()
    print(state)

    @torch.no_grad()
    def ctmrg_conv_energy(state, env, history, ctm_args=cfg.ctm_args):
        """Energy-based CTMRG convergence check.

        Appends the current energy to ``history`` and returns
        ``(converged, history)``. ``converged`` is True once two consecutive
        energies differ by less than ``ctm_args.ctm_conv_tol`` or once
        ``ctm_args.ctm_max_iter`` energies have been recorded.
        """
        if not history:
            history = []
        e_curr = model.energy_2x1_1x2(state, env)
        history.append(e_curr.item())

        energy_settled = (len(history) > 1 and
                          abs(history[-1] - history[-2]) < ctm_args.ctm_conv_tol)
        if energy_settled or len(history) >= ctm_args.ctm_max_iter:
            log.info({"history_length": len(history), "history": history})
            return True, history
        return False, history

    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)

    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    loss0 = model.energy_2x1_1x2(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    print(", ".join(["epoch", "energy"] + obs_labels))
    print(", ".join([f"{-1}", f"{loss0}"] + [f"{v}" for v in obs_values]))

    def loss_fn(state, ctm_env_in, opt_context):
        """Return ``(loss, ctm_env_out, *ctm_log)`` for the optimizer.

        The loss is evaluated on a fresh ``IPEPS_D2SYM`` built from the
        max-abs normalized parent tensor of ``state``.
        """
        ctm_args = opt_context["ctm_args"]
        opt_args = opt_context["opt_args"]

        # normalize the parent tensor and rebuild the state from it
        symm_site = state.parent_site / torch.max(torch.abs(state.parent_site))
        state = IPEPS_D2SYM(symm_site)

        # possibly re-initialize the environment
        if opt_args.opt_ctm_reinit:
            init_env(state, ctm_env_in)

        # 1) compute environment by CTMRG
        ctm_env_out, *ctm_log = ctmrg.run(state,
                                          ctm_env_in,
                                          conv_check=ctmrg_conv_energy,
                                          ctm_args=ctm_args)

        # 2) evaluate loss with converged environment
        loss = model.energy_2x1_1x2(state, ctm_env_out)

        return (loss, ctm_env_out, *ctm_log)

    @torch.no_grad()
    def obs_fn(state, ctm_env, opt_context):
        """Print observables for the current epoch; silent during line search."""
        # a missing "line_search" entry counts as "not in line search"
        if opt_context.get("line_search", False):
            return

        epoch = len(opt_context["loss_history"]["loss"])
        loss = opt_context["loss_history"]["loss"][-1]
        obs_values, obs_labels = model.eval_obs(state, ctm_env)
        print(", ".join([f"{epoch}", f"{loss}"] +
                        [f"{v}" for v in obs_values]))

    # optimize
    optimize_state(state, ctm_env, loss_fn, obs_fn=obs_fn)

    # compute final observables for the best variational state
    # NOTE(review): the file is presumably written by optimize_state - confirm
    outputstatefile = args.out_prefix + "_state.json"
    state = read_ipeps_d2(outputstatefile)
    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)
    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    loss0 = model.energy_2x1_1x2(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    print(", ".join([f"{args.opt_max_iter}", f"{loss0}"] +
                    [f"{v}" for v in obs_values]))
Example #13
0
def main():
    """Run CTMRG for an iPEPS of the AKLT S=2 model and print diagnostics.

    Driven by the module-level ``args``: builds the state (read from
    ``args.instate`` or drawn at random), prints energy and observables for
    the freshly initialized environment (epoch ``-1``) and after every CTMRG
    convergence check, and finally prints the singular values of each corner
    tensor of the resulting environment.

    Raises:
        ValueError: if ``args.tiling`` is neither ``"BIPARTITE"`` nor
            ``"4SITE"``, or if no state can be constructed from the arguments.
    """
    cfg.configure(args)
    cfg.print_config()
    torch.set_num_threads(args.omp_cores)
    torch.manual_seed(args.seed)

    model = akltS2.AKLTS2()

    # initialize an ipeps
    # 1) define lattice-tiling function, that maps arbitrary vertex of square
    # lattice coord into one of coordinates within unit-cell of iPEPS ansatz
    if args.tiling == "BIPARTITE":

        def lattice_to_site(coord):
            vx = (coord[0] + abs(coord[0]) * 2) % 2
            vy = abs(coord[1])
            return ((vx + vy) % 2, 0)
    elif args.tiling == "4SITE":

        def lattice_to_site(coord):
            vx = (coord[0] + abs(coord[0]) * 2) % 2
            vy = (coord[1] + abs(coord[1]) * 2) % 2
            return (vx, vy)
    else:
        # only the two tilings handled above are supported here
        raise ValueError("Invalid tiling: " + str(args.tiling) +
                         " Supported options: " + "BIPARTITE, 4SITE")

    # 2) build the state
    if args.instate is not None:
        state = read_ipeps(args.instate, vertexToSite=lattice_to_site)
        if args.bond_dim > max(state.get_aux_bond_dims()):
            # extend the auxiliary dimensions
            state = extend_bond_dim(state, args.bond_dim)
        state.add_noise(args.instate_noise)
    elif args.ipeps_init_type == 'RANDOM':
        bond_dim = args.bond_dim

        def random_site():
            """Return a random on-site tensor scaled to unit max-abs element."""
            t = torch.rand(
                (model.phys_dim, bond_dim, bond_dim, bond_dim, bond_dim),
                dtype=cfg.global_args.dtype,
                device=cfg.global_args.device)
            return t / torch.max(torch.abs(t))

        # The order of coordinates fixes the order in which random tensors
        # are drawn, hence the result for a given seed.
        coords = [(0, 0), (1, 0)]
        if args.tiling == "4SITE":
            coords += [(0, 1), (1, 1)]
        sites = {c: random_site() for c in coords}

        state = IPEPS(sites, vertexToSite=lattice_to_site)
    else:
        raise ValueError(
            "Missing trial state: -instate=None and -ipeps_init_type= " +
            str(args.ipeps_init_type) + " is not supported")

    print(state)

    def ctmrg_conv_f(state, env, history, ctm_args=cfg.ctm_args):
        """Convergence check based on two-site reduced density matrices.

        On every call the current energy and observables are printed. The
        summed 2-norm distance between the current reduced density matrices
        and those stored by the previous call is appended to
        ``history["log"]``; the distance is recorded as ``inf`` while fewer
        than two entries have been logged. Returns ``(converged, history)``
        with ``converged`` True once the distance drops below
        ``ctm_args.ctm_conv_tol``.
        """
        with torch.no_grad():
            if not history:
                history = dict({"log": []})
            dist = float('inf')
            list_rdm = []
            for coord in state.sites.keys():
                list_rdm.extend([
                    rdm.rdm2x1(coord, state, env),
                    rdm.rdm1x2(coord, state, env),
                ])

            # compute observables
            e_curr = model.energy_2x1_1x2(state, env)
            obs_values, obs_labels = model.eval_obs(state, env)
            print(", ".join([f"{len(history['log'])}", f"{e_curr}"] +
                            [f"{v}" for v in obs_values]))

            if len(history["log"]) > 1:
                dist = sum((torch.dist(curr, prev, p=2).item()
                            for curr, prev in zip(list_rdm, history["rdm"])),
                           0.)
            history["rdm"] = list_rdm
            history["log"].append(dist)
            if dist < ctm_args.ctm_conv_tol:
                log.info({
                    "history_length": len(history['log']),
                    "history": history['log']
                })
                return True, history
        return False, history

    ctm_env_init = ENV(args.chi, state)
    init_env(state, ctm_env_init)
    print(ctm_env_init)

    e_curr0 = model.energy_2x1_1x2(state, ctm_env_init)
    obs_values0, obs_labels = model.eval_obs(state, ctm_env_init)

    print(", ".join(["epoch", "energy"] + obs_labels))
    print(", ".join([f"{-1}", f"{e_curr0}"] + [f"{v}" for v in obs_values0]))

    ctm_env_init, *ctm_log = ctmrg.run(state,
                                       ctm_env_init,
                                       conv_check=ctmrg_conv_f)

    # environment diagnostics
    for c_loc, c_ten in ctm_env_init.C.items():
        _, s, _ = torch.svd(c_ten, compute_uv=False)
        print(f"\n\nspectrum C[{c_loc}]")
        # never index past the available singular values, even if a corner
        # carries fewer than args.chi of them
        for i in range(min(args.chi, len(s))):
            print(f"{i} {s[i]}")
Example #14
0
    def loss_fn(state, pess, ctm_env_in, opt_context):
        """Evaluate the loss for the iPESS tensors in ``pess``.

        For tilings ``"1x1"`` and ``"2x2"`` the tensors are grouped into cells
        of five (A, B, C, R_up, R_down), written to ``args.file`` as an IPESS
        and combined into the on-site tensors of a new IPEPS, which replaces
        ``state``. For any other tiling ``state`` is used as passed in.

        Returns:
            ``(loss, ctm_env_out, *ctm_log)`` where ``loss`` is ``energy_f``
            evaluated with the environment returned by ``ctmrg.run``.
        """
        ctm_args = opt_context["ctm_args"]
        opt_args = opt_context["opt_args"]

        # possibly re-initialize the environment
        if opt_args.opt_ctm_reinit:
            init_env(state, ctm_env_in)

        # 0) combine the iPESS tensors into an iPEPS
        if args.tiling == "1x1":
            A, B, C, R_up, R_down = pess
            # each tensor is scaled to unit max-abs element (1x1 tiling only)
            cell = [t / torch.max(torch.abs(t)) for t in (A, B, C, R_up, R_down)]

            snapshot = IPESS(OrderedDict({(0, 0): cell}),
                             vertexToSite=lattice_to_site)
            snapshot.write_to_file(args.file, normalize=True)

            state = IPEPS({(0, 0): combine_ipess_into_ipeps(*cell)},
                          vertexToSite=lattice_to_site)

        elif args.tiling == "2x2":
            (A1, B1, C1, R1_up, R1_down,
             A2, B2, C2, R2_up, R2_down,
             A3, B3, C3, R3_up, R3_down,
             A4, B4, C4, R4_up, R4_down) = pess

            cells = [
                [A1, B1, C1, R1_up, R1_down],
                [A2, B2, C2, R2_up, R2_down],
                [A3, B3, C3, R3_up, R3_down],
                [A4, B4, C4, R4_up, R4_down],
            ]

            # NOTE(review): cells 2 and 3 sit at (0, 1) / (1, 0) in the IPESS
            # file but at (1, 0) / (0, 1) in the IPEPS below - confirm that
            # this difference in ordering is intended.
            ipess_coords = [(0, 0), (0, 1), (1, 0), (1, 1)]
            snapshot = IPESS(OrderedDict(zip(ipess_coords, cells)),
                             vertexToSite=lattice_to_site)
            snapshot.write_to_file(args.file, normalize=True)

            combined = [combine_ipess_into_ipeps(*cell) for cell in cells]
            ipeps_coords = [(0, 0), (1, 0), (0, 1), (1, 1)]
            state = IPEPS(dict(zip(ipeps_coords, combined)),
                          vertexToSite=lattice_to_site)

        # 1) compute environment by CTMRG
        ctm_env_out, *ctm_log = ctmrg.run(state,
                                          ctm_env_in,
                                          conv_check=ctmrg_conv_energy,
                                          ctm_args=ctm_args)

        # 2) evaluate loss with the converged environment
        loss = energy_f(state, ctm_env_out)

        return (loss, ctm_env_out, *ctm_log)
Example #15
0
def main():
    """Optimize an iPESS ansatz for the ``kagomej1.KagomeJ1`` model.

    Configuration is read from the module-level ``args``. The routine

    1. builds the tiling function for the requested unit cell
       (``1x1``, ``2x2`` or ``3x3``),
    2. obtains a trial state from ``-instate``, ``-opt_resume`` or random
       initialization,
    3. runs CTMRG and prints the initial energy and observables,
    4. calls ``optimize_state`` with the iPESS tensors as parameters, and
    5. re-reads the state written to ``<out_prefix>_state.json`` and prints
       its energy.

    Raises:
        ValueError: if the tiling is unknown, if the tiling is not supported
            by the selected initialization path, if no trial state can be
            built, or if no iPESS parameters are available to optimize.
    """
    cfg.configure(args)
    cfg.print_config()
    torch.set_num_threads(args.omp_cores)
    torch.manual_seed(args.seed)
    # NOTE(review): anomaly detection is a debugging aid and slows autograd;
    # confirm it is meant to stay enabled for production runs.
    torch.autograd.set_detect_anomaly(True)

    model = kagomej1.KagomeJ1(j1=args.j1, j2=args.j2)

    # initialize an ipeps
    # 1) define lattice-tiling function, that maps arbitrary vertex of square lattice
    # coord into one of coordinates within unit-cell of iPEPS ansatz
    if args.tiling == "2x2":

        def lattice_to_site(coord):
            vx = (coord[0] + abs(coord[0]) * 2) % 2
            vy = (coord[1] + abs(coord[1]) * 2) % 2
            return (vx, vy)

    elif args.tiling == "1x1":

        def lattice_to_site(coord):
            return (0, 0)

    elif args.tiling == "3x3":

        def lattice_to_site(coord):
            vx = (coord[0] + abs(coord[0]) * 3) % 3
            vy = (coord[1] + abs(coord[1]) * 3) % 3
            return (vx, vy)

    else:
        raise ValueError("Invalid tiling: " + str(args.tiling)
                         + " Supported options: 1x1, 2x2 or 3x3")

    # iPESS tensors handed to the optimizer; stays None on paths that do not
    # define them (resuming from a checkpoint).
    pess = None

    if args.instate is not None:

        test = read_ipess(args.instate, vertexToSite=lattice_to_site)
        test.add_noise(args.instate_noise)
        # flatten the per-site tensor groups into a single list
        unit = [par for t in test.sites.values() for par in t]

        if args.tiling == "2x2":

            A1, B1, C1, R1_up, R1_down,\
            A2, B2, C2, R2_up, R2_down,\
            A3, B3, C3, R3_up, R3_down,\
            A4, B4, C4, R4_up, R4_down = unit

            for par in unit:
                par.requires_grad_(True)

            # five tensors per site, keyed 0..4 within each site
            pess = OrderedDict()
            pess[(0, 0)] = OrderedDict(enumerate(unit[0:5])).values()
            pess[(1, 0)] = OrderedDict(enumerate(unit[5:10])).values()
            pess[(0, 1)] = OrderedDict(enumerate(unit[10:15])).values()
            pess[(1, 1)] = OrderedDict(enumerate(unit[15:20])).values()

            pess = OrderedDict(pess).values()

            T1 = combine_ipess_into_ipeps(A1, B1, C1, R1_up, R1_down)
            T2 = combine_ipess_into_ipeps(A2, B2, C2, R2_up, R2_down)
            T3 = combine_ipess_into_ipeps(A3, B3, C3, R3_up, R3_down)
            T4 = combine_ipess_into_ipeps(A4, B4, C4, R4_up, R4_down)

            sites = {(0, 0): T1}
            sites[(1, 0)] = T2
            sites[(0, 1)] = T3
            sites[(1, 1)] = T4

            state = IPEPS(sites, vertexToSite=lattice_to_site)

        elif args.tiling == "1x1":
            A, B, C, R_up, R_down = unit
            pess = OrderedDict()
            pess[(0, 0)] = OrderedDict(enumerate(unit)).values()
            pess = OrderedDict(pess).values()

            # contract the five iPESS tensors into a single on-site tensor
            CR_u = torch.tensordot(C.clone(), R_up.clone(), ([0], [1]))
            BR_d = torch.tensordot(B.clone(), R_down.clone(), ([2], [1]))
            ABR_d = torch.tensordot(A.clone(), BR_d.clone(), ([2], [2]))

            T1 = torch.tensordot(CR_u, ABR_d, ([1], [4]))
            T1 = T1.permute(4, 6, 0, 3, 1, 2, 5)
            # fuse the first three indices into one
            T1 = T1.contiguous().view(
                T1.size()[0] * T1.size()[1] * T1.size()[2],
                T1.size()[3],
                T1.size()[4],
                T1.size()[5],
                T1.size()[6])
            T1 = T1 / torch.max(torch.abs(T1))

            sites = {(0, 0): T1}
            state = IPEPS(sites, vertexToSite=lattice_to_site)

        else:
            # previously fell through and failed later on an unbound `state`
            raise ValueError("Invalid tiling: " + str(args.tiling)
                             + " Supported options with -instate: 1x1 or 2x2")

    elif args.opt_resume is not None:
        if args.tiling == "2x2":
            state = IPEPS(dict(), lX=2, lY=2)
        elif args.tiling == "1x1":
            state = IPEPS(dict(), lX=1, lY=1)
        else:
            # previously fell through and failed on an unbound `state`
            raise ValueError("Invalid tiling: " + str(args.tiling)
                             + " Supported options with -opt_resume: 1x1 or 2x2")

        state.load_checkpoint(args.opt_resume)
    elif args.ipeps_init_type == 'RANDOM':
        bond_dim = args.bond_dim

        # Unit-cell coordinates beyond (0, 0), listed in the order in which
        # their random tensors are drawn.
        if args.tiling == "2x2":
            extra_coords = [(1, 0), (0, 1), (1, 1)]
        elif args.tiling == "3x3":
            extra_coords = [(0, 1), (0, 2),
                            (1, 0), (1, 1), (1, 2),
                            (2, 0), (2, 1), (2, 2)]
        else:
            extra_coords = []

        pess = OrderedDict()

        T1, pess1 = initial_ipess(model, bond_dim)
        pess[(0, 0)] = pess1
        sites = {(0, 0): T1}

        for coord in extra_coords:
            T_c, pess_c = initial_ipess(model, bond_dim)
            sites[coord] = T_c
            pess[coord] = pess_c

        pess = pess.values()
        state = IPEPS(sites, vertexToSite=lattice_to_site)

    else:
        raise ValueError("Missing trial state: -instate=None and -ipeps_init_type= "\
            +str(args.ipeps_init_type)+" is not supported")

    print(state)

    # 2) select the "energy" function
    if args.tiling == "1x1":
        energy_f = model.energy_2x2_1site

    elif args.tiling == "2x2":
        energy_f = model.energy_2x2_4site

    elif args.tiling == "3x3":
        energy_f = model.energy_2x2_9site

    else:
        raise ValueError("Invalid tiling: " + str(args.tiling)
                         + " Supported options: 1x1, 2x2 or 3x3")

    @torch.no_grad()
    def ctmrg_conv_energy(state, env, history, ctm_args=cfg.ctm_args):
        """CTMRG convergence check based on the change of the energy.

        Returns ``(converged, history)``, where ``history`` is the list of
        energies seen so far. Convergence is reported when two consecutive
        energies differ by less than ``ctm_args.ctm_conv_tol`` or when
        ``ctm_args.ctm_max_iter`` entries have been recorded.
        """
        if not history:
            history = []
        e_curr = energy_f(state, env)
        history.append(e_curr.item())

        if (len(history) > 1 and abs(history[-1]-history[-2]) < ctm_args.ctm_conv_tol)\
            or len(history) >= ctm_args.ctm_max_iter:
            log.info({"history_length": len(history), "history": history})
            return True, history
        return False, history

    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)

    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    loss0 = energy_f(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    print(obs_labels)
    print([f"{v}" for v in obs_values])
    print(", ".join(["epoch", "energy"]))
    print(", ".join([f"{-1}", f"{loss0}"]))

    def loss_fn(state, pess, ctm_env_in, opt_context):
        """Rebuild the iPEPS from the iPESS tensors and return its energy.

        Returns ``(loss, ctm_env_out, *ctm_log)``. As a side effect the
        current iPESS tensors are written to ``args.file`` on every call
        (1x1 and 2x2 tilings only).
        """
        ctm_args = opt_context["ctm_args"]
        opt_args = opt_context["opt_args"]

        # possibly re-initialize the environment
        if opt_args.opt_ctm_reinit:
            init_env(state, ctm_env_in)

        # 0) combining ipess into the ipeps
        # NOTE(review): for any tiling other than 1x1 / 2x2 (e.g. 3x3) the
        # state is not rebuilt from `pess` below, so the loss is evaluated on
        # the `state` argument as passed in — confirm this is intended.
        if args.tiling == "1x1":
            A, B, C, R_up, R_down = pess
            # normalize each tensor by its largest absolute element
            A = A / torch.max(torch.abs(A))
            B = B / torch.max(torch.abs(B))
            C = C / torch.max(torch.abs(C))
            R_up = R_up / torch.max(torch.abs(R_up))
            R_down = R_down / torch.max(torch.abs(R_down))

            unit = OrderedDict({(0, 0): [A, B, C, R_up, R_down]})
            test = IPESS(unit, vertexToSite=lattice_to_site)
            test.write_to_file(args.file, normalize=True)

            T1 = combine_ipess_into_ipeps(A, B, C, R_up, R_down)

            sites = {(0, 0): T1}
            state = IPEPS(sites, vertexToSite=lattice_to_site)

        if args.tiling == "2x2":
            A1, B1, C1, R1_up, R1_down,\
            A2, B2, C2, R2_up, R2_down,\
            A3, B3, C3, R3_up, R3_down,\
            A4, B4, C4, R4_up, R4_down = pess

            unit1 = [A1, B1, C1, R1_up, R1_down]
            unit2 = [A2, B2, C2, R2_up, R2_down]
            unit3 = [A3, B3, C3, R3_up, R3_down]
            unit4 = [A4, B4, C4, R4_up, R4_down]

            # NOTE(review): the IPESS written to file stores unit2 under
            # (0, 1) and unit3 under (1, 0), whereas the IPEPS below places
            # T2 at (1, 0) and T3 at (0, 1) — confirm which labelling is
            # intended.
            unit = OrderedDict([
                ((0, 0), unit1),
                ((0, 1), unit2),
                ((1, 0), unit3),
                ((1, 1), unit4),
            ])
            test = IPESS(unit, vertexToSite=lattice_to_site)
            test.write_to_file(args.file, normalize=True)

            T1 = combine_ipess_into_ipeps(A1, B1, C1, R1_up, R1_down)
            T2 = combine_ipess_into_ipeps(A2, B2, C2, R2_up, R2_down)
            T3 = combine_ipess_into_ipeps(A3, B3, C3, R3_up, R3_down)
            T4 = combine_ipess_into_ipeps(A4, B4, C4, R4_up, R4_down)

            sites = {(0, 0): T1}
            sites[(1, 0)] = T2
            sites[(0, 1)] = T3
            sites[(1, 1)] = T4
            state = IPEPS(sites, vertexToSite=lattice_to_site)

        # 1) compute environment by CTMRG
        ctm_env_out, *ctm_log= ctmrg.run(state, ctm_env_in, \
            conv_check=ctmrg_conv_energy, ctm_args=ctm_args)

        # 2) evaluate loss with the converged environment
        loss = energy_f(state, ctm_env_out)

        return (loss, ctm_env_out, *ctm_log)

    @torch.no_grad()
    def obs_fn(state, ctm_env, opt_context):
        """Print epoch and loss, and log site norms, outside of line search."""
        # report unless the context explicitly marks a line-search evaluation
        if "line_search" not in opt_context or not opt_context["line_search"]:
            epoch = len(opt_context["loss_history"]["loss"])
            loss = opt_context["loss_history"]["loss"][-1]
            print(", ".join([f"{epoch}", f"{loss}"]))
            log.info("Norm(sites): " +
                     ", ".join([f"{t.norm()}"
                                for t in state.sites.values()]))

    # optimize
    if pess is None:
        # previously surfaced as an UnboundLocalError on `pess`
        raise ValueError("No iPESS parameters available for optimization: "
                         "resuming via -opt_resume does not define them")
    print("Start optimization")
    optimize_state(state, ctm_env, loss_fn, obs_fn=obs_fn, parameters=pess)

    # compute final observables for the best variational state
    outputstatefile = args.out_prefix + "_state.json"
    state = read_ipeps(outputstatefile, vertexToSite=state.vertexToSite)
    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)
    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    opt_energy = energy_f(state, ctm_env)
    print("Enegy", opt_energy)
Example #16
0
def main():
    """Optimize a single-site, C4v-symmetrized iPEPS for the ``j1j2.J1J2`` model.

    Configuration is read from the module-level ``args``. The trial state is
    taken from ``-instate``, ``-opt_resume`` or random initialization. Every
    energy and observable evaluation is done on the symmetrized state
    produced by ``make_c4v_symm_A1``, while the optimizer itself is handed
    the unsymmetrized ``state``. After optimization the state stored in
    ``<out_prefix>_state.json`` is re-read and its energy and observables are
    printed.

    Raises:
        ValueError: if no trial state can be constructed from the arguments.
    """
    cfg.configure(args)
    cfg.print_config()
    torch.set_num_threads(args.omp_cores)
    torch.manual_seed(args.seed)

    model = j1j2.J1J2(j1=args.j1, j2=args.j2)

    # initialize an ipeps
    # 1) define lattice-tiling function, that maps arbitrary vertex of square lattice
    # coord into one of coordinates within unit-cell of iPEPS ansatz
    def lattice_to_site(coord):
        return (0, 0)

    if args.instate is not None:
        state = read_ipeps(args.instate, vertexToSite=lattice_to_site)
        if args.bond_dim > max(state.get_aux_bond_dims()):
            # extend the auxiliary dimensions
            state = extend_bond_dim(state, args.bond_dim)
        state.add_noise(args.instate_noise)
    elif args.opt_resume is not None:
        state = IPEPS(dict(), lX=1, lY=1, vertexToSite=lattice_to_site)
        state.load_checkpoint(args.opt_resume)
    elif args.ipeps_init_type == 'RANDOM':
        bond_dim = args.bond_dim

        A = torch.rand((model.phys_dim, bond_dim, bond_dim, bond_dim, bond_dim),
                       dtype=cfg.global_args.dtype, device=cfg.global_args.device)
        # normalization of initial random tensors
        A = A / torch.max(torch.abs(A))
        sites = {(0, 0): A}
        state = IPEPS(sites, vertexToSite=lattice_to_site)
    else:
        raise ValueError("Missing trial state: -instate=None and -ipeps_init_type= "\
            +str(args.ipeps_init_type)+" is not supported")

    print(state)

    # 2) select the "energy" function
    energy_f = model.energy_2x2_1site_BP

    @torch.no_grad()
    def ctmrg_conv_energy(state, env, history, ctm_args=cfg.ctm_args):
        """CTMRG convergence check based on the change of the energy.

        Returns ``(converged, history)``, where ``history`` is the list of
        energies seen so far. Convergence is reported when two consecutive
        energies differ by less than ``ctm_args.ctm_conv_tol`` or when
        ``ctm_args.ctm_max_iter`` entries have been recorded.
        """
        if not history:
            history = []
        e_curr = energy_f(state, env)
        history.append(e_curr.item())

        if (len(history) > 1 and abs(history[-1]-history[-2]) < ctm_args.ctm_conv_tol)\
            or len(history) >= ctm_args.ctm_max_iter:
            log.info({"history_length": len(history), "history": history})
            return True, history
        return False, history

    # 3) choose C4v irrep (or their mix)
    def symmetrize(state):
        """Return a new single-site IPEPS built from ``make_c4v_symm_A1`` of the site tensor."""
        A = state.site((0, 0))
        A_symm = make_c4v_symm_A1(A)
        symm_state = IPEPS({(0, 0): A_symm}, vertexToSite=state.vertexToSite)
        return symm_state

    symm_state = symmetrize(state)
    ctm_env = ENV(args.chi, symm_state)
    init_env(symm_state, ctm_env)

    ctm_env, *ctm_log = ctmrg.run(symm_state, ctm_env, conv_check=ctmrg_conv_energy)
    loss0 = energy_f(symm_state, ctm_env)
    obs_values, obs_labels = model.eval_obs(symm_state, ctm_env)
    print(", ".join(["epoch", "energy"] + obs_labels))
    print(", ".join([f"{-1}", f"{loss0}"] + [f"{v}" for v in obs_values]))

    def loss_fn(state, ctm_env_in, opt_context):
        """Return ``(loss, ctm_env_out, *ctm_log)`` for the symmetrized state."""
        ctm_args = opt_context["ctm_args"]
        opt_args = opt_context["opt_args"]

        # symmetrize inside the loss so the energy depends on `state` only
        # through its symmetrized tensor
        symm_state = symmetrize(state)

        # possibly re-initialize the environment
        if opt_args.opt_ctm_reinit:
            init_env(symm_state, ctm_env_in)

        # 1) compute environment by CTMRG
        ctm_env_out, *ctm_log = ctmrg.run(symm_state, ctm_env_in,
                                          conv_check=ctmrg_conv_energy,
                                          ctm_args=ctm_args)

        # 2) evaluate loss with the converged environment
        loss = energy_f(symm_state, ctm_env_out)

        return (loss, ctm_env_out, *ctm_log)

    @torch.no_grad()
    def obs_fn(state, ctm_env, opt_context):
        """Print epoch, loss, observables and max |A_symm| outside of line search."""
        # report unless the context explicitly marks a line-search evaluation
        if "line_search" not in opt_context or not opt_context["line_search"]:
            symm_state = symmetrize(state)
            epoch = len(opt_context["loss_history"]["loss"])
            loss = opt_context["loss_history"]["loss"][-1]
            obs_values, obs_labels = model.eval_obs(symm_state, ctm_env)
            print(", ".join([f"{epoch}", f"{loss}"] + [f"{v}" for v in obs_values]
                            + [f"{torch.max(torch.abs(symm_state.site((0,0))))}"]))

    # optimize
    optimize_state(state, ctm_env, loss_fn, obs_fn=obs_fn)

    # compute final observables for the best variational state
    outputstatefile = args.out_prefix + "_state.json"
    state = read_ipeps(outputstatefile, vertexToSite=state.vertexToSite)
    symm_state = symmetrize(state)
    ctm_env = ENV(args.chi, symm_state)
    init_env(symm_state, ctm_env)
    ctm_env, *ctm_log = ctmrg.run(symm_state, ctm_env, conv_check=ctmrg_conv_energy)
    opt_energy = energy_f(symm_state, ctm_env)
    obs_values, obs_labels = model.eval_obs(symm_state, ctm_env)
    print(", ".join([f"{args.opt_max_iter}", f"{opt_energy}"] + [f"{v}" for v in obs_values]))