def optimize_state(state, ctm_env_init, loss_fn, obs_fn=None, post_proc=None,
    main_args=cfg.main_args, opt_args=cfg.opt_args, ctm_args=cfg.ctm_args,
    global_args=cfg.global_args):
    r"""
    :param state: initial wavefunction
    :param ctm_env_init: initial environment corresponding to ``state``
    :param loss_fn: loss function returning the loss, the converged environment,
                    the CTM convergence history, and timings
    :param obs_fn: optional function evaluating observables, called after every
                   loss evaluation
    :param post_proc: optional function post-processing ``state`` and environment
                      after every optimization epoch
    :param main_args: parsed command line arguments
    :param opt_args: optimization configuration
    :param ctm_args: CTM algorithm configuration
    :param global_args: global configuration
    :type state: IPEPS
    :type ctm_env_init: ENV
    :type loss_fn: function(IPEPS, ENV, dict)->(torch.Tensor, ENV, ...)
    :type obs_fn: function(IPEPS, ENV, dict)->None
    :type post_proc: function(IPEPS, ENV, dict)->None
    :type main_args: argparse.Namespace
    :type opt_args: OPTARGS
    :type ctm_args: CTMARGS
    :type global_args: GLOBALARGS

    Optimizes the initial wavefunction ``state`` with respect to ``loss_fn`` using the LBFGS
    optimizer. The gradient is approximated by forward finite differences with step
    ``opt_args.fd_eps`` (see the inner function ``grad_fd``).
    The main parameters influencing the optimization process are given in :py:class:`config.OPTARGS`.
    """
    verbosity = opt_args.verbosity_opt_epoch
    checkpoint_file = main_args.out_prefix+"_checkpoint.p"   
    outputstatefile= main_args.out_prefix+"_state.json"
    t_data = dict({"loss": [], "min_loss": 1.0e+16, "loss_ls": [], "min_loss_ls": 1.0e+16})
    current_env=[ctm_env_init]
    context= dict({"ctm_args":ctm_args, "opt_args":opt_args, "loss_history": t_data})
    epoch= 0

    parameters= state.get_parameters()
    for A in parameters: A.requires_grad_(True)

    optimizer = lbfgs_modified.LBFGS_MOD(parameters, max_iter=opt_args.max_iter_per_epoch, lr=opt_args.lr, \
        tolerance_grad=opt_args.tolerance_grad, tolerance_change=opt_args.tolerance_change, \
        history_size=opt_args.history_size, line_search_fn=opt_args.line_search, \
        line_search_eps=opt_args.line_search_tol)
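    # NOTE: LBFGS_MOD appears to extend torch.optim.LBFGS with an explicit,
    # derivative-free line search; optimizer.step_2c below consumes two closures,
    # one for the gradient step and one for the line search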

    # load and/or modify optimizer state from checkpoint
    if main_args.opt_resume is not None:
        print(f"INFO: resuming from check point. resume = {main_args.opt_resume}")
        checkpoint = torch.load(main_args.opt_resume)
        epoch0 = checkpoint["epoch"]
        loss0 = checkpoint["loss"]
        cp_state_dict= checkpoint["optimizer_state_dict"]
        cp_opt_params= cp_state_dict["param_groups"][0]
        cp_opt_history= cp_state_dict["state"][cp_opt_params["params"][0]]
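        # in PyTorch's LBFGS state, old_dirs/old_stps hold the curvature pairs
        # (gradient differences y and steps s), ro the reciprocals 1/(y.s), and
        # al the scratch coefficients of the two-loop recursion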
        if main_args.opt_resume_override_params:
            cp_opt_params["lr"] = opt_args.lr
            cp_opt_params["max_iter"] = opt_args.max_iter_per_epoch
            cp_opt_params["tolerance_grad"] = opt_args.tolerance_grad
            cp_opt_params["tolerance_change"] = opt_args.tolerance_change
            # resize stored old_dirs, old_stps, ro, al to new history size
            cp_history_size= cp_opt_params["history_size"]
            cp_opt_params["history_size"] = opt_args.history_size
            if opt_args.history_size < cp_history_size:
                if len(cp_opt_history["old_dirs"]) > opt_args.history_size: 
                    cp_opt_history["old_dirs"]= cp_opt_history["old_dirs"][-opt_args.history_size:]
                    cp_opt_history["old_stps"]= cp_opt_history["old_stps"][-opt_args.history_size:]
            cp_ro_filtered= list(filter(None,cp_opt_history["ro"]))
            cp_al_filtered= list(filter(None,cp_opt_history["al"]))
            if len(cp_ro_filtered) > opt_args.history_size:
                cp_opt_history["ro"]= cp_ro_filtered[-opt_args.history_size:]
                cp_opt_history["al"]= cp_al_filtered[-opt_args.history_size:]
            else:
                cp_opt_history["ro"]= cp_ro_filtered + [None for i in range(opt_args.history_size-len(cp_ro_filtered))]
                cp_opt_history["al"]= cp_al_filtered + [None for i in range(opt_args.history_size-len(cp_ro_filtered))]
        cp_state_dict["param_groups"][0]= cp_opt_params
        cp_state_dict["state"][cp_opt_params["params"][0]]= cp_opt_history
        optimizer.load_state_dict(cp_state_dict)
        print(f"checkpoint.loss = {loss0}")

    def grad_fd(loss0):
        loc_opt_args= copy.deepcopy(opt_args)
        loc_opt_args.opt_ctm_reinit= opt_args.fd_ctm_reinit
        loc_ctm_args= copy.deepcopy(ctm_args)
        # TODO check if we are optimizing C4v symmetric ansatz
        if opt_args.line_search_svd_method != 'DEFAULT':
            loc_ctm_args.projector_svd_method= opt_args.line_search_svd_method
        loc_context= dict({"ctm_args":loc_ctm_args, "opt_args":loc_opt_args, \
            "loss_history": t_data, "line_search": False})

        # compute components of the grad
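        # forward finite difference along each basis direction e_i:
        #   grad_i = (loss(coeffs + fd_eps*e_i) - loss(coeffs)) / fd_eps
        # where loss(coeffs) is the reference value ``loss0`` passed in by ``closure``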
        fd_grad=dict()
        with torch.no_grad():
            for k in state.coeffs.keys():
                A_orig= state.coeffs[k].clone()
                assert len(A_orig.size())==1, "coefficient tensor is not 1D"
                fd_grad[k]= torch.zeros(A_orig.size(),dtype=A_orig.dtype,device=A_orig.device)
                for i in range(state.coeffs[k].size()[0]):                  
                    e_i= torch.zeros(A_orig.size()[0],dtype=A_orig.dtype,device=A_orig.device)
                    e_i[i]= opt_args.fd_eps
                    state.coeffs[k]+=e_i
                    loc_env= current_env[0].clone()
                    loss1, ctm_env, history, timings = loss_fn(state, loc_env,\
                        loc_context)
                    fd_grad[k][i]=(float(loss1-loss0)/opt_args.fd_eps)
                    log.info(f"FD_GRAD {i} loss1 {loss1} grad_i {fd_grad[k][i]}"\
                        +f" timings {timings}")
                    state.coeffs[k].data.copy_(A_orig)
        log.info(f"FD_GRAD grad {fd_grad}")

        return fd_grad

    #@profile
    def closure(linesearching=False):
        context["line_search"]=linesearching

        # 0) evaluate loss
        optimizer.zero_grad()
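        # evaluate the loss without building an autograd graph; the gradient is
        # assembled from finite differences in step 4) below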
        with torch.no_grad():
            loss, ctm_env, history, timings= loss_fn(state, current_env[0], context)

        # 1) record loss and store current state if the loss improves
        if linesearching:
            t_data["loss_ls"].append(loss.item())
            if t_data["min_loss_ls"] > t_data["loss_ls"][-1]:
                t_data["min_loss_ls"]= t_data["loss_ls"][-1]
        else:  
            t_data["loss"].append(loss.item())
            if t_data["min_loss"] > t_data["loss"][-1]:
                t_data["min_loss"]= t_data["loss"][-1]
                state.write_to_file(outputstatefile, normalize=True)

        # 2) log CTM metrics for debugging
        if opt_args.opt_logging:
            log_entry=dict({"id": epoch, "loss": t_data["loss"][-1], "timings": timings})
            if linesearching:
                log_entry["LS"]=len(t_data["loss_ls"])
                log_entry["loss"]=t_data["loss_ls"]
            log.info(json.dumps(log_entry))

        # 3) compute desired observables
        if obs_fn is not None:
            obs_fn(state, ctm_env, context)

        # 4) evaluate gradient
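        #    the finite-difference gradient is written directly into .grad of each
        #    coefficient tensor, playing the role of loss.backward()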
        t_grad0= time.perf_counter()
        grad= grad_fd(loss)
        for k in state.coeffs.keys():
            state.coeffs[k].grad= grad[k]
        t_grad1= time.perf_counter()

        # 5) log grad metrics
        if opt_args.opt_logging:
            log_entry=dict({"id": epoch, "t_grad": t_grad1-t_grad0 })
            if linesearching: log_entry["LS"]=len(t_data["loss_ls"])
            log.info(json.dumps(log_entry))

        # 6) detach current environment from autograd graph
        current_env[0] = ctm_env.detach().clone()

        return loss
    
    # closure for the derivative-free line search; the @torch.no_grad() decorator
    # guarantees no autograd graph is built during its evaluation
    @torch.no_grad()
    def closure_linesearch(linesearching):
        context["line_search"]=linesearching

        # 1) evaluate loss
        loc_opt_args= copy.deepcopy(opt_args)
        loc_opt_args.opt_ctm_reinit= opt_args.line_search_ctm_reinit
        loc_ctm_args= copy.deepcopy(ctm_args)
        # TODO check if we are optimizing C4v symmetric ansatz
        if opt_args.line_search_svd_method != 'DEFAULT':
            loc_ctm_args.projector_svd_method= opt_args.line_search_svd_method
        loc_context= dict({"ctm_args":loc_ctm_args, "opt_args":loc_opt_args, "loss_history": t_data,
            "line_search": True})
        loss, ctm_env, history, timings = loss_fn(state, current_env[0],\
            loc_context)

        # 2) store current state if the loss improves
        t_data["loss_ls"].append(loss.item())
        if t_data["min_loss_ls"] > t_data["loss_ls"][-1]:
            t_data["min_loss_ls"]= t_data["loss_ls"][-1]

        # 3) log CTM metrics for debugging
        if opt_args.opt_logging:
            log_entry=dict({"id": epoch, "LS": len(t_data["loss_ls"]), \
                "loss": t_data["loss_ls"], "timings": timings})
            log.info(json.dumps(log_entry))

        # 4) compute desired observables
        if obs_fn is not None:
            obs_fn(state, ctm_env, context)

        current_env[0]= ctm_env
        return loss

    for epoch in range(main_args.opt_max_iter):
        # checkpoint the optimizer
        # checkpointing before the step guarantees the correspondence between the
        # wavefunction and the last computed value of the loss, t_data["loss"][-1]
        if epoch>0:
            store_checkpoint(checkpoint_file, state, optimizer, epoch, t_data["loss"][-1])

        # After the closure executes, ``current_env`` no longer corresponds to ``state``,
        # since the on-site tensors of ``state`` have been modified by the gradient step.
        optimizer.step_2c(closure, closure_linesearch)
        
        # reset line search history
        t_data["loss_ls"]=[]
        t_data["min_loss_ls"]=1.0e+16

        if post_proc is not None:
            post_proc(state, current_env[0], context)

    # optimization is over, store the last checkpoint
    store_checkpoint(checkpoint_file, state, optimizer, \
        main_args.opt_max_iter, t_data["loss"][-1])
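
# Automatic-differentiation variant of the driver above: the overall structure is
# identical, but ``closure`` obtains the gradient via loss.backward() instead of
# finite differences, and ``loss_fn`` additionally reports t_ctm and t_check timings.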
def optimize_state(state,
                   ctm_env_init,
                   loss_fn,
                   obs_fn=None,
                   post_proc=None,
                   main_args=cfg.main_args,
                   opt_args=cfg.opt_args,
                   ctm_args=cfg.ctm_args,
                   global_args=cfg.global_args):
    r"""
    :param state: initial wavefunction
    :param ctm_env_init: initial environment corresponding to ``state``
    :param loss_fn: loss function returning the loss, the converged environment,
                    the CTM convergence history, and the timings t_ctm and t_check
    :param obs_fn: optional function evaluating observables, called after every
                   loss evaluation
    :param post_proc: optional function post-processing ``state`` and environment
                      after every optimization epoch
    :param main_args: parsed command line arguments
    :param opt_args: optimization configuration
    :param ctm_args: CTM algorithm configuration
    :param global_args: global configuration
    :type state: IPEPS
    :type ctm_env_init: ENV
    :type loss_fn: function(IPEPS, ENV, dict)->(torch.Tensor, ENV, ...)
    :type obs_fn: function(IPEPS, ENV, dict)->None
    :type post_proc: function(IPEPS, ENV, dict)->None
    :type main_args: argparse.Namespace
    :type opt_args: OPTARGS
    :type ctm_args: CTMARGS
    :type global_args: GLOBALARGS

    Optimizes the initial wavefunction ``state`` with respect to ``loss_fn`` using the
    :class:`optim.lbfgs_modified.LBFGS_MOD` optimizer, with gradients obtained by
    automatic differentiation (``loss.backward()``).
    The main parameters influencing the optimization process are given in :class:`config.OPTARGS`.
    Calls to the functions ``loss_fn``, ``obs_fn``, and ``post_proc`` pass the current
    configuration as a context dictionary containing at least ``{"ctm_args": ctm_args, "opt_args": opt_args}``.
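
    A minimal usage sketch (``read_ipeps``, ``ENV``, ``init_env``, ``ctmrg.run`` and
    ``energy_f`` are assumptions about the surrounding codebase, not fixed API)::

        state = read_ipeps("state.json")
        env = ENV(ctm_args.chi, state)
        init_env(state, env)

        def loss_fn(state, env, context):
            # re-converge the environment for the current tensors, then measure energy
            env, history, t_ctm, t_check = ctmrg.run(state, env,
                ctm_args=context["ctm_args"])
            return energy_f(state, env), env, history, t_ctm, t_check

        optimize_state(state, env, loss_fn)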
    """
    verbosity = opt_args.verbosity_opt_epoch
    checkpoint_file = main_args.out_prefix + "_checkpoint.p"
    outputstatefile = main_args.out_prefix + "_state.json"
    t_data = dict({
        "loss": [],
        "min_loss": 1.0e+16,
        "loss_ls": [],
        "min_loss_ls": 1.0e+16
    })
    current_env = [ctm_env_init]
    context = dict({
        "ctm_args": ctm_args,
        "opt_args": opt_args,
        "loss_history": t_data
    })
    epoch = 0

    parameters = state.get_parameters()
    for A in parameters:
        A.requires_grad_(True)
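    # gradient tracking must be enabled so that loss.backward() inside ``closure``
    # populates A.grad for each parameter tensor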

    optimizer = lbfgs_modified.LBFGS_MOD(parameters, max_iter=opt_args.max_iter_per_epoch, lr=opt_args.lr, \
        tolerance_grad=opt_args.tolerance_grad, tolerance_change=opt_args.tolerance_change, \
        history_size=opt_args.history_size, line_search_fn=opt_args.line_search, \
        line_search_eps=opt_args.line_search_tol)

    # load and/or modify optimizer state from checkpoint
    if main_args.opt_resume is not None:
        print(
            f"INFO: resuming from checkpoint. resume = {main_args.opt_resume}"
        )
        checkpoint = torch.load(main_args.opt_resume)
        epoch0 = checkpoint["epoch"]
        loss0 = checkpoint["loss"]
        cp_state_dict = checkpoint["optimizer_state_dict"]
        cp_opt_params = cp_state_dict["param_groups"][0]
        cp_opt_history = cp_state_dict["state"][cp_opt_params["params"][0]]
        if main_args.opt_resume_override_params:
            cp_opt_params["lr"] = opt_args.lr
            cp_opt_params["max_iter"] = opt_args.max_iter_per_epoch
            cp_opt_params["tolerance_grad"] = opt_args.tolerance_grad
            cp_opt_params["tolerance_change"] = opt_args.tolerance_change
            # resize stored old_dirs, old_stps, ro, al to new history size
            cp_history_size = cp_opt_params["history_size"]
            cp_opt_params["history_size"] = opt_args.history_size
            if opt_args.history_size < cp_history_size:
                if len(cp_opt_history["old_dirs"]) > opt_args.history_size:
                    cp_opt_history["old_dirs"] = cp_opt_history["old_dirs"][
                        -opt_args.history_size:]
                    cp_opt_history["old_stps"] = cp_opt_history["old_stps"][
                        -opt_args.history_size:]
            cp_ro_filtered = list(filter(None, cp_opt_history["ro"]))
            cp_al_filtered = list(filter(None, cp_opt_history["al"]))
            if len(cp_ro_filtered) > opt_args.history_size:
                cp_opt_history["ro"] = cp_ro_filtered[-opt_args.history_size:]
                cp_opt_history["al"] = cp_al_filtered[-opt_args.history_size:]
            else:
                cp_opt_history["ro"] = cp_ro_filtered + [
                    None
                    for i in range(opt_args.history_size - len(cp_ro_filtered))
                ]
                cp_opt_history["al"] = cp_al_filtered + [
                    None
                    for i in range(opt_args.history_size - len(cp_al_filtered))
                ]
        cp_state_dict["param_groups"][0] = cp_opt_params
        cp_state_dict["state"][cp_opt_params["params"][0]] = cp_opt_history
        optimizer.load_state_dict(cp_state_dict)
        print(f"checkpoint.loss = {loss0}")

    #@profile
    def closure(linesearching=False):
        context["line_search"] = linesearching
        optimizer.zero_grad()

        # 0) evaluate loss
        loss, ctm_env, history, t_ctm, t_check = loss_fn(
            state, current_env[0], context)

        # 1) record loss and store current state if the loss improves
        if linesearching:
            t_data["loss_ls"].append(loss.item())
            if t_data["min_loss_ls"] > t_data["loss_ls"][-1]:
                t_data["min_loss_ls"] = t_data["loss_ls"][-1]
        else:
            t_data["loss"].append(loss.item())
            if t_data["min_loss"] > t_data["loss"][-1]:
                t_data["min_loss"] = t_data["loss"][-1]
                state.write_to_file(outputstatefile, normalize=True)

        # 2) log CTM metrics for debugging
        if opt_args.opt_logging:
            log_entry=dict({"id": epoch, "loss": t_data["loss"][-1], "t_ctm": t_ctm, \
                    "t_check": t_check})
            if linesearching:
                log_entry["LS"] = len(t_data["loss_ls"])
                log_entry["loss"] = t_data["loss_ls"]
            log.info(json.dumps(log_entry))

        # 3) compute desired observables
        if obs_fn is not None:
            obs_fn(state, ctm_env, context)

        # 4) evaluate gradient
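        #    reverse-mode autodiff back-propagates through the CTMRG iterations
        #    recorded in the graph of ``loss`` and accumulates into p.grad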
        t_grad0 = time.perf_counter()
        loss.backward()
        t_grad1 = time.perf_counter()

        # 5) log grad metrics
        if opt_args.opt_logging:
            log_entry = dict({"id": epoch})
            if linesearching: log_entry["LS"] = len(t_data["loss_ls"])
            else:
                log_entry["t_grad"] = t_grad1 - t_grad0
                if opt_args.opt_log_grad:
                    log_entry["grad"] = [p.grad.tolist() for p in parameters]
            log.info(json.dumps(log_entry))

        # 6) detach current environment from autograd graph
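        #    detaching the C and T tensors in place frees the autograd graph while
        #    the converged environment is kept as the initial guess for the next
        #    loss evaluation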
        lst_C = list(ctm_env.C.values())
        lst_T = list(ctm_env.T.values())
        current_env[0] = ctm_env
        for el in lst_T + lst_C:
            el.detach_()

        return loss

    # closure for the derivative-free line search; the @torch.no_grad() decorator
    # guarantees no autograd graph is built during its evaluation
    @torch.no_grad()
    def closure_linesearch(linesearching):
        context["line_search"] = linesearching

        # 1) evaluate loss
        loc_opt_args = copy.deepcopy(opt_args)
        loc_opt_args.opt_ctm_reinit = opt_args.line_search_ctm_reinit
        loc_ctm_args = copy.deepcopy(ctm_args)

        if opt_args.line_search_svd_method != 'DEFAULT':
            loc_ctm_args.projector_svd_method = opt_args.line_search_svd_method
        ls_context = dict({
            "ctm_args": loc_ctm_args,
            "opt_args": loc_opt_args,
            "loss_history": t_data,
            "line_search": linesearching
        })
        loss, ctm_env, history, t_ctm, t_check = loss_fn(state, current_env[0],\
            ls_context)

        # 2) store current state if the loss improves
        t_data["loss_ls"].append(loss.item())
        if t_data["min_loss_ls"] > t_data["loss_ls"][-1]:
            t_data["min_loss_ls"] = t_data["loss_ls"][-1]

        # 3) log metrics for debugging
        if opt_args.opt_logging:
            log_entry=dict({"id": epoch, "LS": len(t_data["loss_ls"]), \
                "loss": t_data["loss_ls"], "t_ctm": t_ctm, "t_check": t_check})
            log.info(json.dumps(log_entry))

        # 4) compute desired observables
        if obs_fn is not None:
            obs_fn(state, ctm_env, context)

        current_env[0] = ctm_env
        return loss

    for epoch in range(main_args.opt_max_iter):
        # checkpoint the optimizer
        # checkpointing before the step guarantees the correspondence between the
        # wavefunction and the last computed value of the loss, t_data["loss"][-1]
        if epoch > 0:
            store_checkpoint(checkpoint_file, state, optimizer, epoch,
                             t_data["loss"][-1])

        # After the closure executes, ``current_env`` no longer corresponds to ``state``,
        # since the on-site tensors of ``state`` have been modified by the gradient step.
        optimizer.step_2c(closure, closure_linesearch)

        # reset line search history
        t_data["loss_ls"] = []
        t_data["min_loss_ls"] = 1.0e+16

        if post_proc is not None:
            post_proc(state, current_env[0], context)

    # optimization is over, store the last checkpoint
    store_checkpoint(checkpoint_file, state, optimizer, \
        main_args.opt_max_iter, t_data["loss"][-1])