Esempio n. 1
0
def main():
    """Optimize a single-site C4v-symmetric iPEPS for ``akltS2.AKLTS2_C4V_BIPARTITE``.

    The initial state is read from ``args.instate`` (extended to
    ``args.bond_dim`` if needed, perturbed by ``args.instate_noise`` and
    normalized), restored from the checkpoint ``args.opt_resume``, or drawn at
    random and C4v-symmetrized. Energy and observables are printed before
    optimization, after every reported epoch, and for the state re-read from
    ``<out_prefix>_state.json`` at the end.

    All settings come from the module-level ``args`` and ``cfg``.

    Raises:
        ValueError: if no trial state is given and ``args.ipeps_init_type`` is
            not ``'RANDOM'``.
    """
    cfg.configure(args)
    cfg.print_config()
    torch.set_num_threads(args.omp_cores)
    torch.manual_seed(args.seed)

    model = akltS2.AKLTS2_C4V_BIPARTITE()

    # initialize an ipeps
    if args.instate is not None:
        state = read_ipeps_c4v(args.instate)
        if args.bond_dim > max(state.get_aux_bond_dims()):
            # extend the auxiliary dimensions
            state = extend_bond_dim(state, args.bond_dim)
        state.add_noise(args.instate_noise)
        state.sites[(0, 0)] = state.sites[(0, 0)] / torch.max(
            torch.abs(state.sites[(0, 0)]))
    elif args.opt_resume is not None:
        state = IPEPS_C4V()
        state.load_checkpoint(args.opt_resume)
    elif args.ipeps_init_type == 'RANDOM':
        bond_dim = args.bond_dim
        A = torch.rand((model.phys_dim, bond_dim, bond_dim, bond_dim, bond_dim),
                       dtype=cfg.global_args.torch_dtype, device=cfg.global_args.device)
        A = make_c4v_symm(A)
        A = A / torch.max(torch.abs(A))
        state = IPEPS_C4V(A)
    else:
        raise ValueError("Missing trial state: -instate=None and -ipeps_init_type= "\
            +str(args.ipeps_init_type)+" is not supported")

    print(state)

    @torch.no_grad()
    def ctmrg_conv_rho2x1dist(state, env, history, ctm_args=cfg.ctm_args):
        """CTMRG convergence test on the 2x1 reduced density matrix.

        Converged when the L2 distance between consecutive ``rdm2x1_sl``
        results drops below ``ctm_conv_tol``, or after ``ctm_max_iter`` checks.
        Returns ``(converged, history)``.
        """
        if not history:
            history = dict({"log": []})
        rdm2x1 = rdm2x1_sl(state, env, force_cpu=ctm_args.conv_check_cpu)
        dist = float('inf')
        if len(history["log"]) > 0:
            dist = torch.dist(rdm2x1, history["rdm"], p=2).item()
        history["rdm"] = rdm2x1
        history["log"].append(dist)
        if dist < ctm_args.ctm_conv_tol or len(
                history["log"]) >= ctm_args.ctm_max_iter:
            log.info({
                "history_length": len(history['log']),
                "history": history['log']
            })
            return True, history
        return False, history

    ctm_env = ENV_C4V(args.chi, state)
    init_env(state, ctm_env)
    ctm_env, *ctm_log = ctmrg_c4v.run(state,
                                      ctm_env,
                                      conv_check=ctmrg_conv_rho2x1dist)

    loss = model.energy_1x1(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    print(", ".join(["epoch", "energy"] + obs_labels))
    print(", ".join([f"{-1}", f"{loss}"] + [f"{v}" for v in obs_values]))

    def loss_fn(state, ctm_env_in, opt_context):
        """Symmetrize/normalize the on-site tensor, run CTMRG, return the energy.

        Returns ``(loss, ctm_env_out, *ctm_log)``.
        """
        # symmetrize on-site tensor in a fresh IPEPS_C4V wrapper
        state_sym = IPEPS_C4V(state.sites[(0, 0)])
        state_sym.sites[(0, 0)] = make_c4v_symm(state_sym.sites[(0, 0)])
        # Normalize by the largest |entry|, as done at initialization. A plain
        # torch.max can be negative (flipping the tensor's sign) or close to
        # zero (blowing the entries up).
        state_sym.sites[(0, 0)] = state_sym.sites[(0, 0)] / torch.max(
            torch.abs(state_sym.sites[(0, 0)]))

        # possibly re-initialize the environment
        if cfg.opt_args.opt_ctm_reinit:
            init_env(state_sym, ctm_env_in)

        # 1) compute environment by CTMRG
        ctm_env_out, *ctm_log = ctmrg_c4v.run(state_sym,
                                              ctm_env_in,
                                              conv_check=ctmrg_conv_rho2x1dist)
        loss = model.energy_1x1(state_sym, ctm_env_out)

        return (loss, ctm_env_out, *ctm_log)

    @torch.no_grad()
    def obs_fn(state, ctm_env, opt_context):
        """Print epoch, last loss and observables as one CSV-style line."""
        epoch = len(opt_context["loss_history"]["loss"])
        loss = opt_context["loss_history"]["loss"][-1]
        obs_values, obs_labels = model.eval_obs(state, ctm_env)
        print(", ".join([f"{epoch}", f"{loss}"] + [f"{v}"
                                                   for v in obs_values]))

    # optimize
    optimize_state(state, ctm_env, loss_fn, obs_fn=obs_fn)

    # compute final observables for the best variational state
    outputstatefile = args.out_prefix + "_state.json"
    state = read_ipeps_c4v(outputstatefile)
    ctm_env = ENV_C4V(args.chi, state)
    init_env(state, ctm_env)
    ctm_env, *ctm_log = ctmrg_c4v.run(state,
                                      ctm_env,
                                      conv_check=ctmrg_conv_rho2x1dist)
    opt_energy = model.energy_1x1(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    print(", ".join([f"{args.opt_max_iter}", f"{opt_energy}"] +
                    [f"{v}" for v in obs_values]))
Esempio n. 2
0
def main():
    """Optimize an iPEPS for the ``hb.HB`` model on the unit cell given by ``args.tiling``.

    Supported tilings are BIPARTITE, 2SITE, 4SITE and 8SITE. The initial state
    is read from ``args.instate``, restored from the checkpoint
    ``args.opt_resume``, or drawn at random. Energy and observables are printed
    before optimization, after every reported epoch, and for the state re-read
    from ``<out_prefix>_state.json`` at the end.

    All settings come from the module-level ``args`` and ``cfg``.

    Raises:
        ValueError: for an unsupported tiling, or when no trial state is given
            and ``args.ipeps_init_type`` is not ``'RANDOM'``.
    """
    cfg.configure(args)
    cfg.print_config()
    torch.set_num_threads(args.omp_cores)
    torch.manual_seed(args.seed)

    model = hb.HB(spin_s=args.spinS, j1=args.j1, k1=args.k1)

    # initialize an ipeps
    # 1) define lattice-tiling function, that maps arbitrary vertex of square lattice
    # coord into one of coordinates within unit-cell of iPEPS ansatz
    if args.tiling == "BIPARTITE":

        def lattice_to_site(coord):
            # checkerboard: the site is selected by the parity of x + y
            vx = (coord[0] + abs(coord[0]) * 2) % 2
            vy = abs(coord[1])
            return ((vx + vy) % 2, 0)
    elif args.tiling == "2SITE":

        def lattice_to_site(coord):
            vx = (coord[0] + abs(coord[0]) * 2) % 2
            vy = (coord[1] + abs(coord[1]) * 1) % 1
            return (vx, vy)
    elif args.tiling == "4SITE":

        def lattice_to_site(coord):
            vx = (coord[0] + abs(coord[0]) * 2) % 2
            vy = (coord[1] + abs(coord[1]) * 2) % 2
            return (vx, vy)
    elif args.tiling == "8SITE":

        def lattice_to_site(coord):
            # 4x2 cell whose x-origin shifts by 2 every second row
            shift_x = coord[0] + 2 * (coord[1] // 2)
            vx = shift_x % 4
            vy = coord[1] % 2
            return (vx, vy)
    else:
        raise ValueError("Invalid tiling: "+str(args.tiling)+" Supported options: "\
            +"BIPARTITE, 2SITE, 4SITE, 8SITE")

    if args.instate is not None:
        state = read_ipeps(args.instate, vertexToSite=lattice_to_site)
        if args.bond_dim > max(state.get_aux_bond_dims()):
            # extend the auxiliary dimensions
            state = extend_bond_dim(state, args.bond_dim)
        state.add_noise(args.instate_noise)
    elif args.opt_resume is not None:
        # (lX, lY) of the unit cell for each (already validated) tiling
        unit_cell = {
            "BIPARTITE": (2, 1),
            "2SITE": (2, 1),
            "4SITE": (2, 2),
            "8SITE": (4, 2),
        }
        lX, lY = unit_cell[args.tiling]
        # Pass the tiling's vertex map explicitly, exactly like the -instate and
        # RANDOM branches do. BIPARTITE and 8SITE are not plain periodic lX x lY
        # tilings, so a resumed state must carry the same map.
        state = IPEPS(dict(), vertexToSite=lattice_to_site, lX=lX, lY=lY)
        state.load_checkpoint(args.opt_resume)
    elif args.ipeps_init_type == 'RANDOM':
        bond_dim = args.bond_dim

        def random_site():
            # random on-site tensor, normalized so that its largest |entry| is 1
            t = torch.rand((model.phys_dim, bond_dim, bond_dim, bond_dim, bond_dim),
                           dtype=cfg.global_args.dtype, device=cfg.global_args.device)
            return t / torch.max(torch.abs(t))

        # Tensors are drawn in the fixed order (0,0), (1,0), (0,1), (1,1),
        # (2,0), (3,0), (2,1), (3,1), so a given seed yields reproducible states.
        sites = {(0, 0): random_site(), (1, 0): random_site()}

        if args.tiling in ("4SITE", "8SITE"):
            sites[(0, 1)] = random_site()
            sites[(1, 1)] = random_site()

        if args.tiling == "8SITE":
            for coord in [(2, 0), (3, 0), (2, 1), (3, 1)]:
                sites[coord] = random_site()

        state = IPEPS(sites, vertexToSite=lattice_to_site)
    else:
        raise ValueError("Missing trial state: -instate=None and -ipeps_init_type= "\
            +str(args.ipeps_init_type)+" is not supported")

    print(state)

    # 2) select the "energy" function: every supported tiling (validated above)
    # uses model.energy_2x1_1x2
    energy_f = model.energy_2x1_1x2

    @torch.no_grad()
    def ctmrg_conv_energy(state, env, history, ctm_args=cfg.ctm_args):
        """CTMRG convergence test on the energy.

        Converged when two consecutive energies differ by less than
        ``ctm_conv_tol``, or after ``ctm_max_iter`` checks.
        Returns ``(converged, history)``.
        """
        if not history:
            history = []
        e_curr = energy_f(state, env)
        history.append(e_curr.item())

        if (len(history) > 1 and abs(history[-1]-history[-2]) < ctm_args.ctm_conv_tol)\
            or len(history) >= ctm_args.ctm_max_iter:
            log.info({"history_length": len(history), "history": history})
            return True, history
        return False, history

    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)

    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    loss0 = energy_f(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    print(", ".join(["epoch", "energy"] + obs_labels))
    print(", ".join([f"{-1}", f"{loss0}"] + [f"{v}" for v in obs_values]))

    def loss_fn(state, ctm_env_in, opt_context):
        """Run CTMRG for ``state`` and return ``(energy, ctm_env_out, *ctm_log)``."""
        # possibly re-initialize the environment
        if cfg.opt_args.opt_ctm_reinit:
            init_env(state, ctm_env_in)

        # 1) compute environment by CTMRG
        ctm_env_out, *ctm_log = ctmrg.run(state,
                                          ctm_env_in,
                                          conv_check=ctmrg_conv_energy)
        loss = energy_f(state, ctm_env_out)

        return (loss, ctm_env_out, *ctm_log)

    @torch.no_grad()
    def obs_fn(state, ctm_env, opt_context):
        """Print epoch, last loss and observables as one CSV-style line."""
        epoch = len(opt_context["loss_history"]["loss"])
        loss = opt_context["loss_history"]["loss"][-1]
        obs_values, obs_labels = model.eval_obs(state, ctm_env)
        print(", ".join([f"{epoch}", f"{loss}"] + [f"{v}"
                                                   for v in obs_values]))

    # optimize
    optimize_state(state, ctm_env, loss_fn, obs_fn=obs_fn)

    # compute final observables for the best variational state
    outputstatefile = args.out_prefix + "_state.json"
    state = read_ipeps(outputstatefile, vertexToSite=state.vertexToSite)
    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)
    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    opt_energy = energy_f(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    print(", ".join([f"{args.opt_max_iter}", f"{opt_energy}"] +
                    [f"{v}" for v in obs_values]))
Esempio n. 3
0
def main():
    """Optimize an iPEPS for ``ising.ISING`` (parameters ``hx`` and ``q``).

    The initial state is read from ``args.instate`` (extended to
    ``args.bond_dim`` if needed and perturbed by ``args.instate_noise``),
    restored from the 1x1 checkpoint ``args.opt_resume``, or drawn at random
    as a single normalized site. Energy and observables are printed before
    optimization, after every reported epoch, and for the state re-read from
    ``<out_prefix>_state.json`` at the end.

    All settings come from the module-level ``args`` and ``cfg``.

    Raises:
        ValueError: if no trial state is given and ``args.ipeps_init_type`` is
            not ``'RANDOM'``.
    """
    cfg.configure(args)
    cfg.print_config()
    torch.set_num_threads(args.omp_cores)
    torch.manual_seed(args.seed)

    model = ising.ISING(hx=args.hx, q=args.q)

    # initialize an ipeps; no custom lattice-tiling function is used here
    # (read_ipeps gets vertexToSite=None, IPEPS gets none at all)
    if args.instate is not None:
        state = read_ipeps(args.instate, vertexToSite=None)
        if args.bond_dim > max(state.get_aux_bond_dims()):
            # extend the auxiliary dimensions
            state = extend_bond_dim(state, args.bond_dim)
        state.add_noise(args.instate_noise)
    elif args.opt_resume is not None:
        state = IPEPS(dict(), lX=1, lY=1)
        state.load_checkpoint(args.opt_resume)
    elif args.ipeps_init_type == 'RANDOM':
        bond_dim = args.bond_dim
        A = torch.rand((model.phys_dim, bond_dim, bond_dim, bond_dim, bond_dim),
                       dtype=cfg.global_args.dtype, device=cfg.global_args.device)
        # normalization of initial random tensors
        A = A / torch.max(torch.abs(A))
        sites = {(0, 0): A}
        state = IPEPS(sites)
    else:
        raise ValueError("Missing trial state: -instate=None and -ipeps_init_type= "\
            +str(args.ipeps_init_type)+" is not supported")

    print(state)

    @torch.no_grad()
    def ctmrg_conv_energy(state, env, history, ctm_args=cfg.ctm_args):
        """CTMRG convergence test on the single-site energy.

        Converged when two consecutive energies differ by less than
        ``ctm_conv_tol``, or after ``ctm_max_iter`` checks.
        Returns ``(converged, history)``.
        """
        if not history:
            history = []
        e_curr = model.energy_1x1(state, env)
        history.append(e_curr.item())

        if (len(history) > 1 and abs(history[-1]-history[-2]) < ctm_args.ctm_conv_tol)\
            or len(history) >= ctm_args.ctm_max_iter:
            log.info({"history_length": len(history), "history": history})
            return True, history
        return False, history

    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)

    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    loss = model.energy_1x1(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    print(", ".join(["epoch", "energy"] + obs_labels))
    print(", ".join([f"{-1}", f"{loss}"] + [f"{v}" for v in obs_values]))

    def loss_fn(state, ctm_env_in, opt_context):
        """Run CTMRG for ``state`` and return ``(energy, ctm_env_out, *ctm_log)``."""
        # possibly re-initialize the environment
        if cfg.opt_args.opt_ctm_reinit:
            init_env(state, ctm_env_in)

        # 1) compute environment by CTMRG
        ctm_env_out, *ctm_log = ctmrg.run(state,
                                          ctm_env_in,
                                          conv_check=ctmrg_conv_energy)
        loss = model.energy_1x1(state, ctm_env_out)

        return (loss, ctm_env_out, *ctm_log)

    @torch.no_grad()
    def obs_fn(state, ctm_env, opt_context):
        """Print epoch, last loss and observables as one CSV-style line."""
        epoch = len(opt_context["loss_history"]["loss"])
        loss = opt_context["loss_history"]["loss"][-1]
        obs_values, obs_labels = model.eval_obs(state, ctm_env)
        print(", ".join([f"{epoch}", f"{loss}"] + [f"{v}"
                                                   for v in obs_values]))

    # optimize
    optimize_state(state, ctm_env, loss_fn, obs_fn=obs_fn)

    # compute final observables for the best variational state
    outputstatefile = args.out_prefix + "_state.json"
    state = read_ipeps(outputstatefile)
    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)
    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    opt_energy = model.energy_1x1(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    print(", ".join([f"{args.opt_max_iter}", f"{opt_energy}"] +
                    [f"{v}" for v in obs_values]))
Esempio n. 4
0
def main():
    """Optimize an iPESS ansatz for the ``kagomej1.KagomeJ1`` model (``j1``, ``j2``).

    Every iPEPS site tensor is assembled from five iPESS tensors
    (A, B, C, R_up, R_down). The optimizer varies those iPESS tensors
    (``parameters=pess``). On every call, ``loss_fn`` rebuilds the iPEPS from
    them, writes the current iPESS to ``args.file`` and returns the CTMRG
    energy. Supported tilings are 1x1, 2x2 and 3x3.

    All settings come from the module-level ``args`` and ``cfg``.

    Raises:
        ValueError: for an unsupported tiling, or when no trial state is given
            and ``args.ipeps_init_type`` is not ``'RANDOM'``.
        NotImplementedError: for ``-instate`` with the 3x3 tiling, and for
            ``-opt_resume``. The resume path never defines the iPESS
            parameters that the optimizer needs.
    """
    cfg.configure(args)
    cfg.print_config()
    torch.set_num_threads(args.omp_cores)
    torch.manual_seed(args.seed)
    # Debugging aid: autograd anomaly detection reports the forward op behind a
    # failing backward pass (e.g. NaN gradients), at a noticeable runtime cost.
    torch.autograd.set_detect_anomaly(True)

    model = kagomej1.KagomeJ1(j1=args.j1, j2=args.j2)

    # initialize an ipeps
    # 1) define lattice-tiling function, that maps arbitrary vertex of square lattice
    # coord into one of coordinates within unit-cell of iPEPS ansatz
    if args.tiling == "2x2":

        def lattice_to_site(coord):
            vx = (coord[0] + abs(coord[0]) * 2) % 2
            vy = (coord[1] + abs(coord[1]) * 2) % 2
            return (vx, vy)

    elif args.tiling == "1x1":

        def lattice_to_site(coord):
            return (0, 0)

    elif args.tiling == "3x3":

        def lattice_to_site(coord):
            vx = (coord[0] + abs(coord[0]) * 3) % 3
            vy = (coord[1] + abs(coord[1]) * 3) % 3
            return (vx, vy)

    else:
        raise ValueError("Invalid tiling: "+str(args.tiling)+" Supported options: "\
            +"1x1, 2x2, 3x3")

    if args.instate is not None:
        if args.tiling == "3x3":
            # only the 1x1 and 2x2 iPESS layouts are unpacked below; otherwise
            # `state` would be left undefined
            raise NotImplementedError("-instate is not supported for tiling 3x3")

        test = read_ipess(args.instate, vertexToSite=lattice_to_site)
        test.add_noise(args.instate_noise)
        # flatten the per-site iPESS tensor lists in the order of test.sites
        unit = [par for t in test.sites.values() for par in t]

        if args.tiling == "2x2":

            A1, B1, C1, R1_up, R1_down,\
            A2, B2, C2, R2_up, R2_down,\
            A3, B3, C3, R3_up, R3_down,\
            A4, B4, C4, R4_up, R4_down= unit

            pess = OrderedDict()
            pess1 = OrderedDict()
            pess2 = OrderedDict()
            pess3 = OrderedDict()
            pess4 = OrderedDict()

            # make every loaded tensor trainable and group them five per site
            for i, par in enumerate(unit):
                par.requires_grad_(True)
                if i < 5:
                    pess1[i] = par
                elif i < 10:
                    pess2[i - 5] = par
                elif i < 15:
                    pess3[i - 10] = par
                else:
                    pess4[i - 15] = par

            pess[(0, 0)] = pess1.values()
            pess[(1, 0)] = pess2.values()
            pess[(0, 1)] = pess3.values()
            pess[(1, 1)] = pess4.values()
            pess = pess.values()

            T1 = combine_ipess_into_ipeps(A1, B1, C1, R1_up, R1_down)
            T2 = combine_ipess_into_ipeps(A2, B2, C2, R2_up, R2_down)
            T3 = combine_ipess_into_ipeps(A3, B3, C3, R3_up, R3_down)
            T4 = combine_ipess_into_ipeps(A4, B4, C4, R4_up, R4_down)

            sites = {(0, 0): T1, (1, 0): T2, (0, 1): T3, (1, 1): T4}
            state = IPEPS(sites, vertexToSite=lattice_to_site)

        if args.tiling == "1x1":
            # NOTE(review): unlike the 2x2 branch, requires_grad is not set on
            # the loaded tensors here -- confirm optimize_state enables it.
            A, B, C, R_up, R_down = unit
            pess = OrderedDict()
            pess[(0, 0)] = OrderedDict(
                {0: A, 1: B, 2: C, 3: R_up, 4: R_down}).values()
            pess = pess.values()

            # assemble the single site tensor by hand from clones of the
            # iPESS tensors
            CR_u = torch.tensordot(C.clone(), R_up.clone(), ([0], [1]))
            BR_d = torch.tensordot(B.clone(), R_down.clone(), ([2], [1]))
            ABR_d = torch.tensordot(A.clone(), BR_d.clone(), ([2], [2]))

            T1 = torch.tensordot(CR_u, ABR_d, ([1], [4]))
            T1 = T1.permute(4, 6, 0, 3, 1, 2, 5)
            # merge the first three indices into a single one
            T1 = T1.contiguous().view(
                T1.size()[0] * T1.size()[1] * T1.size()[2],
                T1.size()[3],
                T1.size()[4],
                T1.size()[5],
                T1.size()[6])
            T1 = T1 / torch.max(torch.abs(T1))

            sites = {(0, 0): T1}
            state = IPEPS(sites, vertexToSite=lattice_to_site)

        # NOTE(review): args.bond_dim is not applied to a loaded iPESS state
        # (no extend_bond_dim call on this path).
    elif args.opt_resume is not None:
        # This path would leave the iPESS parameters `pess` undefined, and
        # optimize_state(..., parameters=pess) below needs them.
        raise NotImplementedError(
            "-opt_resume is not supported for the iPESS optimization")
    elif args.ipeps_init_type == 'RANDOM':
        bond_dim = args.bond_dim

        T1, pess1 = initial_ipess(model, bond_dim)
        pess = OrderedDict({(0, 0): pess1})
        sites = {(0, 0): T1}

        # further sites, in the order in which they are initialized
        if args.tiling == "2x2":
            extra_coords = [(1, 0), (0, 1), (1, 1)]
        elif args.tiling == "3x3":
            extra_coords = [(x, y) for x in range(3) for y in range(3)][1:]
        else:
            extra_coords = []

        for coord in extra_coords:
            sites[coord], pess[coord] = initial_ipess(model, bond_dim)

        pess = pess.values()
        state = IPEPS(sites, vertexToSite=lattice_to_site)

    else:
        raise ValueError("Missing trial state: -instate=None and -ipeps_init_type= "\
            +str(args.ipeps_init_type)+" is not supported")

    print(state)

    # 2) select the "energy" function (only the selected attribute is accessed)
    if args.tiling == "1x1":
        energy_f = model.energy_2x2_1site
    elif args.tiling == "2x2":
        energy_f = model.energy_2x2_4site
    elif args.tiling == "3x3":
        energy_f = model.energy_2x2_9site
    else:
        raise ValueError("Invalid tiling: "+str(args.tiling)+" Supported options: "\
            +"1x1, 2x2, 3x3")

    @torch.no_grad()
    def ctmrg_conv_energy(state, env, history, ctm_args=cfg.ctm_args):
        """CTMRG convergence test on the energy.

        Converged when two consecutive energies differ by less than
        ``ctm_conv_tol``, or after ``ctm_max_iter`` checks.
        Returns ``(converged, history)``.
        """
        if not history:
            history = []
        e_curr = energy_f(state, env)
        history.append(e_curr.item())

        if (len(history) > 1 and abs(history[-1]-history[-2]) < ctm_args.ctm_conv_tol)\
            or len(history) >= ctm_args.ctm_max_iter:
            log.info({"history_length": len(history), "history": history})
            return True, history
        return False, history

    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)

    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    loss0 = energy_f(state, ctm_env)
    obs_values, obs_labels = model.eval_obs(state, ctm_env)
    print(obs_labels)
    print([f"{v}" for v in obs_values])
    print(", ".join(["epoch", "energy"]))
    print(", ".join([f"{-1}", f"{loss0}"]))

    def loss_fn(state, pess, ctm_env_in, opt_context):
        """Rebuild the iPEPS from the iPESS tensors ``pess``, run CTMRG, return the energy.

        ``pess`` is unpacked as a flat sequence of 5 tensors per site. The
        current iPESS is written to ``args.file`` on every call. Returns
        ``(loss, ctm_env_out, *ctm_log)``.
        """
        ctm_args = opt_context["ctm_args"]
        opt_args = opt_context["opt_args"]

        # possibly re-initialize the environment
        if opt_args.opt_ctm_reinit:
            init_env(state, ctm_env_in)

        # 0) combining ipess into the ipeps
        if args.tiling == "1x1":
            A, B, C, R_up, R_down = pess
            A = A / torch.max(torch.abs(A))
            B = B / torch.max(torch.abs(B))
            C = C / torch.max(torch.abs(C))
            R_up = R_up / torch.max(torch.abs(R_up))
            R_down = R_down / torch.max(torch.abs(R_down))

            test = IPESS(OrderedDict({(0, 0): [A, B, C, R_up, R_down]}),
                         vertexToSite=lattice_to_site)
            test.write_to_file(args.file, normalize=True)

            T1 = combine_ipess_into_ipeps(A, B, C, R_up, R_down)

            sites = {(0, 0): T1}
            state = IPEPS(sites, vertexToSite=lattice_to_site)

        if args.tiling == "2x2":
            A1, B1, C1, R1_up, R1_down,\
            A2, B2, C2, R2_up, R2_down,\
            A3, B3, C3, R3_up, R3_down,\
            A4, B4, C4, R4_up, R4_down = pess

            # NOTE(review): the iPESS file labels unit 2 as (0, 1) and unit 3 as
            # (1, 0), while T2 is placed at (1, 0) and T3 at (0, 1) below. The
            # -instate reader assigns units by order, not by label. Round trips
            # therefore work as long as read_ipess preserves the written
            # order; confirm before relying on these labels.
            unit = OrderedDict()
            unit[(0, 0)] = [A1, B1, C1, R1_up, R1_down]
            unit[(0, 1)] = [A2, B2, C2, R2_up, R2_down]
            unit[(1, 0)] = [A3, B3, C3, R3_up, R3_down]
            unit[(1, 1)] = [A4, B4, C4, R4_up, R4_down]
            test = IPESS(unit, vertexToSite=lattice_to_site)
            test.write_to_file(args.file, normalize=True)

            T1 = combine_ipess_into_ipeps(A1, B1, C1, R1_up, R1_down)
            T2 = combine_ipess_into_ipeps(A2, B2, C2, R2_up, R2_down)
            T3 = combine_ipess_into_ipeps(A3, B3, C3, R3_up, R3_down)
            T4 = combine_ipess_into_ipeps(A4, B4, C4, R4_up, R4_down)

            sites = {(0, 0): T1, (1, 0): T2, (0, 1): T3, (1, 1): T4}
            state = IPEPS(sites, vertexToSite=lattice_to_site)

        if args.tiling == "3x3":
            # Without this branch the loss would not depend on `pess` at all.
            # Assumes `pess` is the flat sequence of 9 x 5 iPESS tensors in the
            # insertion order of the RANDOM initialization: (0,0), (0,1),
            # (0,2), (1,0), ..., (2,2) -- TODO confirm.
            coords = [(x, y) for x in range(3) for y in range(3)]
            params = list(pess)
            units = [params[5 * k:5 * (k + 1)] for k in range(len(coords))]

            test = IPESS(OrderedDict(zip(coords, units)),
                         vertexToSite=lattice_to_site)
            test.write_to_file(args.file, normalize=True)

            sites = {c: combine_ipess_into_ipeps(*u)
                     for c, u in zip(coords, units)}
            state = IPEPS(sites, vertexToSite=lattice_to_site)

        # 1) compute environment by CTMRG
        ctm_env_out, *ctm_log = ctmrg.run(state, ctm_env_in,
                                          conv_check=ctmrg_conv_energy,
                                          ctm_args=ctm_args)

        # 2) evaluate loss with the converged environment
        loss = energy_f(state, ctm_env_out)

        return (loss, ctm_env_out, *ctm_log)

    @torch.no_grad()
    def obs_fn(state, ctm_env, opt_context):
        """Print epoch and last loss, and log site norms; skipped during line search."""
        if opt_context.get("line_search"):
            return
        epoch = len(opt_context["loss_history"]["loss"])
        loss = opt_context["loss_history"]["loss"][-1]
        print(", ".join([f"{epoch}", f"{loss}"]))
        log.info("Norm(sites): " +
                 ", ".join([f"{t.norm()}" for t in state.sites.values()]))

    # optimize
    print("Start optimization")
    optimize_state(state, ctm_env, loss_fn, obs_fn=obs_fn, parameters=pess)

    # compute final observables for the best variational state
    outputstatefile = args.out_prefix + "_state.json"
    state = read_ipeps(outputstatefile, vertexToSite=state.vertexToSite)
    ctm_env = ENV(args.chi, state)
    init_env(state, ctm_env)
    ctm_env, *ctm_log = ctmrg.run(state, ctm_env, conv_check=ctmrg_conv_energy)
    opt_energy = energy_f(state, ctm_env)
    print("Energy", opt_energy)