crsh_chker = check4particle_soft_crash(rthrsh0, pthrsh0, rthrsh,
                                                   pthrsh, crash_path)
        else:
            crsh_chker = check4particle_hard_crash(rthrsh, pthrsh, crash_path)
    else:
        crsh_chker = check4particle_crash_dummy(rthrsh0, pthrsh0, rthrsh,
                                                pthrsh, crash_path)

    linear_integrator_obj = linear_integrator(MD_parameters.integrator_method,
                                              crsh_chker)

    hamiltonian_obj = make_hamiltonian(hamiltonian_type, tau_long,
                                       ML_parameters)

    if hamiltonian_type != "noML":  # use prediction for ML
        chk_pt = checkpoint(hamiltonian_obj.get_netlist(
        ))  # opt = None, sch = None ; for test, don't need opt, sch
        if load_model_file is not None: chk_pt.load_checkpoint(load_model_file)
        hamiltonian_obj.eval()
        hamiltonian_obj.requires_grad_false()

    init_qpl, _, _ = data_io.read_trajectory_qpl(MC_init_config_filename)
    # init_qpl.shape = [nsamples, 3=(q, p, boxsize), 1, nparticle, DIM]

    init_q = torch.squeeze(init_qpl[:, 0, 0, :, :], dim=1)
    # init_q.shape = [nsamples, nparticle, DIM]

    init_p = torch.squeeze(init_qpl[:, 1, 0, :, :], dim=1)
    # init_p.shape = [nsamples, nparticle, DIM]

    boxsize = torch.squeeze(init_qpl[:, 2, 0, :, :], dim=1)
    # boxsize.shape = [nsamples, nparticle, DIM]
    def __init__(self,
                 linear_integrator_obj,
                 any_HNN_obj,
                 phase_space,
                 opt,
                 sch,
                 data_loader,
                 pothrsh,
                 qp_weight,
                 Lambda,
                 clip_value,
                 lr_thrsh,
                 loss_type,
                 system_logs,
                 load_model_file=None):
        '''
        Wire together the integrator, network, optimizer and data needed
        for training, optionally restoring state from a checkpoint file.

        Only one instance may exist; creating a second one raises.

        Parameters
        ----------
        linear_integrator_obj : use for integrator using large time step
        any_HNN_obj : pass any HNN object to this container
                    must provide get_netlist(), which is handed to the
                    checkpoint helper
        phase_space : contains q_list, p_list as input
                    q list shape is [nsamples, nparticle, DIM]
                    its set_boxsize() is called with the training box size
        opt         : create one optimizer from two models parameters
        sch         : lr decay 0.99 every 100 epochs
        data_loader : DataLoaders on Custom Datasets
                 two tensors contain train and valid data
                 each shape is [nsamples, 2, niter, nparticle, DIM] , here 2 is (q,p)
                 niter is 2 : the initial state plus one appended state
                 must expose data_set.train_set.data_tau_long and
                 data_set.train_set.data_boxsize
        pothrsh     : stored as self.pothrsh
                 presumably a potential-energy threshold -- TODO confirm
        qp_weight   : stored as self.w
                 presumably the relative weight of q and p terms in the
                 loss -- TODO confirm
        Lambda      : stored as self.Lambda ; meaning not visible here
        clip_value  : stored as self.clip_value
                 presumably a gradient clipping bound -- TODO confirm
        lr_thrsh    : stored as self.lr_thrsh
                 presumably a learning-rate threshold -- TODO confirm
        loss_type   : one of 'MSE_loss', 'MAE_loss', 'exp_loss' ;
                 selects the loss function kept in self._loss
        system_logs : stored as self.system_logs
        load_model_file : checkpoint file to load before training
                 default is None (nothing is loaded)

        Raises
        ------
        AssertionError
            if an MD_learner object already exists, or if loss_type is
            not one of the supported names
        '''

        MD_learner._obj_count += 1
        # Raised explicitly instead of using `assert` so that the guard is
        # still enforced when Python runs with -O (which strips asserts).
        # AssertionError is kept so existing callers see the same type.
        if MD_learner._obj_count != 1:
            raise AssertionError(
                type(self).__name__ + " has more than one object")

        self.linear_integrator = linear_integrator_obj
        self.any_HNN = any_HNN_obj
        self.data_loader = data_loader
        self.pothrsh = pothrsh
        self.Lambda = Lambda

        self._phase_space = phase_space
        self._opt = opt
        self._sch = sch
        # The checkpoint helper gets the networks together with the
        # optimizer and scheduler it was built for.
        self.chk_pt = checkpoint(self.any_HNN.get_netlist(), self._opt,
                                 self._sch)
        self.lr_thrsh = lr_thrsh

        self.system_logs = system_logs

        if load_model_file is not None:
            self.chk_pt.load_checkpoint(load_model_file)

        self.w = qp_weight

        if loss_type == 'MSE_loss':
            print('loss type : mse_loss')
            self._loss = qp_MSE_loss
        elif loss_type == 'MAE_loss':
            print('loss type : mae_loss')
            self._loss = qp_MAE_loss
        elif loss_type == 'exp_loss':
            print('loss type : exp_loss')
            self._loss = qp_exp_loss
        else:
            # Explicit raise (not `assert False`): with -O the assert would
            # vanish and self._loss would silently stay undefined.
            raise AssertionError('invalid loss type given')

        # Time step and box size are taken from the training set.
        self.tau_cur = self.data_loader.data_set.train_set.data_tau_long
        boxsize = self.data_loader.data_set.train_set.data_boxsize

        self.clip_value = clip_value

        self._phase_space.set_boxsize(boxsize)

        print('MD_learner initialized : tau_cur ', self.tau_cur, ' boxsize ',
              boxsize, 'pothrsh', pothrsh, 'Lambda', self.Lambda, 'clip value',
              self.clip_value)