def __init__(self, hparams, data_path=None):
        """Build the encoders, decoders and ODE module for this model.

        Args:
            hparams: hyper-parameter namespace; ``hparams.T_pred`` is read
                and cached as ``self.T_pred``.
            data_path: optional value stored unchanged as ``self.data_path``.
        """
        super(Model, self).__init__()
        # NOTE(review): assigning ``self.hparams`` directly only works on
        # older PyTorch Lightning releases; newer ones expect
        # ``save_hyperparameters`` — confirm against the pinned version.
        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        # Per-element squared error (reduction='none'); callers reduce it.
        self.loss_fn = torch.nn.MSELoss(reduction='none')

        # presumably a flattened 64x64 frame — confirm with the data loader
        n_pixels = 64 * 64

        # Two recognition networks (n_pixels -> 3) and two observation
        # networks (1 -> n_pixels). Assignment order is kept as-is because it
        # fixes submodule registration order and parameter construction order.
        self.recog_net_1 = MLP_Encoder(n_pixels, 300, 3, nonlinearity='elu')
        self.recog_net_2 = MLP_Encoder(n_pixels, 300, 3, nonlinearity='elu')
        self.obs_net_1 = MLP_Decoder(1, 100, n_pixels, nonlinearity='elu')
        self.obs_net_2 = MLP_Decoder(1, 100, n_pixels, nonlinearity='elu')

        # Keyword arguments are evaluated left to right, so the sub-networks
        # are still constructed in V, M, g order.
        self.ode = Lag_Net(
            q_dim=2,
            u_dim=2,
            V_net=MLP(4, 100, 1),
            M_net=PSD(4, 300, 2),
            g_net=MatrixNet(4, 100, 4, shape=(2, 2)),
        )

        # Learnable scalar, initialised to 0.0 with the module's dtype.
        self.link1_para = torch.nn.Parameter(
            torch.tensor(0.0, dtype=self.dtype))

        # Left empty here; presumably populated by a data-preparation hook.
        self.train_dataset = None
        self.non_ctrl_ind = 1
    def __init__(self, hparams, data_path=None):
        """Build the encoders, the shared decoder and the ODE module.

        Args:
            hparams: hyper-parameter namespace; ``hparams.T_pred`` is read
                and cached as ``self.T_pred``.
            data_path: optional value stored unchanged as ``self.data_path``.
        """
        super(Model, self).__init__()
        # NOTE(review): direct ``self.hparams`` assignment depends on the
        # PyTorch Lightning version in use — confirm it is supported.
        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        # Per-element squared error (reduction='none'); callers reduce it.
        self.loss_fn = torch.nn.MSELoss(reduction='none')

        # presumably a flattened 64x64 frame — confirm with the data loader
        n_pixels = 64 * 64

        # Encoders map n_pixels -> 2 and n_pixels -> 3; a single decoder maps
        # a 3-dim input back to 3 * n_pixels values.
        self.recog_net_1 = MLP_Encoder(n_pixels, 300, 2, nonlinearity='elu')
        self.recog_net_2 = MLP_Encoder(n_pixels, 300, 3, nonlinearity='elu')
        self.obs_net = MLP_Decoder(3, 200, 3 * n_pixels, nonlinearity='elu')

        # Sub-networks are created inline; left-to-right keyword evaluation
        # keeps the original V, M, g construction order.
        self.ode = Lag_Net_R1_T1(
            V_net=MLP(3, 100, 1),
            M_net=PSD(3, 300, 2),
            g_net=MatrixNet(3, 100, 4, shape=(2, 2)),
        )

        # Left empty here; presumably populated by a data-preparation hook.
        self.train_dataset = None
        self.non_ctrl_ind = 1
    def __init__(self, hparams, data_path=None):
        """Build the encoders, decoders and the g-baseline ODE module.

        Args:
            hparams: hyper-parameter namespace; ``hparams.T_pred`` is read
                and cached as ``self.T_pred``.
            data_path: optional value stored unchanged as ``self.data_path``.
        """
        super(Model, self).__init__()
        # NOTE(review): direct ``self.hparams`` assignment depends on the
        # PyTorch Lightning version in use — confirm it is supported.
        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        # Per-element squared error (reduction='none'); callers reduce it.
        self.loss_fn = torch.nn.MSELoss(reduction='none')

        # presumably a flattened 64x64 frame — confirm with the data loader
        n_pixels = 64 * 64

        # Encoders map n_pixels -> 2 and n_pixels -> 3; two decoders each map
        # a 1-dim input back to n_pixels values. Order kept as in the original.
        self.recog_net_1 = MLP_Encoder(n_pixels, 300, 2, nonlinearity='elu')
        self.recog_net_2 = MLP_Encoder(n_pixels, 300, 3, nonlinearity='elu')
        self.obs_net_1 = MLP_Decoder(1, 100, n_pixels, nonlinearity='elu')
        self.obs_net_2 = MLP_Decoder(1, 100, n_pixels, nonlinearity='elu')

        # ODE module in 'g_baseline' mode: only g_net and the baseline MLP are
        # supplied (no M_net / V_net). The MatrixNet is still built first.
        self.ode = Lag_Net_R1_T1(
            g_net=MatrixNet(3, 100, 4, shape=(2, 2)),
            g_baseline=MLP(5, 400, 2),
            dyna_model='g_baseline',
        )

        # Left empty here; presumably populated by a data-preparation hook.
        self.train_dataset = None
        self.non_ctrl_ind = 1