コード例 #1
0
def classify(dataset, labels, n_repeats=10):
    """Score sampled node embeddings with a logistic-regression classifier.

    Loads the saved embedding parameters for ``dataset``, builds a
    ``VonMisesFisher`` from them, and for each of ``n_repeats`` draws fits a
    logistic regression on a fixed train/test split and collects test
    metrics.

    Args:
        dataset: Dataset name used to locate
            ``result/{dataset}.node.z.mean.npy`` and
            ``result/{dataset}.node.z.var.npy``.
        labels: Label matrix; class ids are taken as ``argmax`` over axis 1
            (presumably one-hot rows -- confirm with the caller).
        n_repeats: Number of embedding samples to average over. Defaults to
            10, the value that used to be hard-coded.

    Returns:
        Tuple ``(macro_f1, micro_f1, accuracy)`` averaged over the repeats.

    Raises:
        ValueError: If ``n_repeats`` is smaller than 1.
    """
    if n_repeats < 1:
        raise ValueError(
            "n_repeats must be a positive integer, got {!r}".format(n_repeats))
    print('classification using lr :dataset:', dataset)

    # Class ids from the label matrix, plus the saved embedding parameters.
    y = np.argmax(labels, axis=1)
    node_z_mean = np.load("result/{}.node.z.mean.npy".format(dataset))
    node_z_var = np.load("result/{}.node.z.var.npy".format(dataset))

    q_z = VonMisesFisher(torch.tensor(node_z_mean), torch.tensor(node_z_var))

    # Running sums of the per-repeat metrics; averaged on return.
    macro_f1_sum = 0
    micro_f1_sum = 0
    acc_sum = 0

    for _ in range(n_repeats):
        # Fresh reparameterised sample of every node embedding.
        node_embedding = q_z.rsample().numpy()

        # Fixed random_state: every repeat uses the same train/test
        # partition, so differences between repeats come from the sampled
        # embeddings.
        X_train, X_test, y_train, y_test = train_test_split(node_embedding,
                                                            y,
                                                            train_size=0.2,
                                                            test_size=1000,
                                                            random_state=2019)

        # NOTE: `multi_class` is deprecated since scikit-learn 1.5; it is
        # kept here so results match the library version this was written
        # against.
        clf = LogisticRegression(solver="lbfgs",
                                 multi_class="multinomial",
                                 random_state=0).fit(X_train, y_train)

        y_pred = clf.predict(X_test)

        macro_f1_sum += f1_score(y_test, y_pred, average="macro")
        micro_f1_sum += f1_score(y_test, y_pred, average="micro")
        acc_sum += accuracy_score(y_test, y_pred, normalize=True)

    return (macro_f1_sum / n_repeats,
            micro_f1_sum / n_repeats,
            acc_sum / n_repeats)
コード例 #2
0
class Model(pl.LightningModule):
    """Latent-ODE model that renders a translated part and a rotated part.

    Encoders read a translation ``r`` from image channel 0 and an angle
    ``phi`` (as a 2-D vector) from channel 1. Initial velocities are
    finite-differenced from the first two frames. The latent state is
    integrated with ``odeint`` through ``Lag_Net_R1_T1``. Every predicted
    state is rendered by warping two learned templates with affine grids.
    The attribute names ``cart`` / ``pole`` in ``forward`` suggest a
    cart-pole system (NOTE(review): inferred from naming only).
    """
    def __init__(self, hparams, data_path=None):
        """Create the encoders, template decoders and ODE dynamics.

        Args:
            hparams: Hyper-parameter namespace; this class reads ``T_pred``,
                ``homo_u``, ``batch_size``, ``solver`` and ``learning_rate``.
            data_path: Passed to the dataset constructors in
                ``train_dataloader``.
        """
        super(Model, self).__init__()
        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        # Per-element squared error; reduced manually in training_step.
        self.loss_fn = torch.nn.MSELoss(reduction='none')

        # Encoder for r: 2 outputs, split into mean and r_logv in encode().
        self.recog_net_1 = MLP_Encoder(64 * 64, 300, 2, nonlinearity='elu')
        # Encoder for phi: 3 outputs, split into a 2-D mean and one value
        # that encode() turns into the vMF concentration.
        self.recog_net_2 = MLP_Encoder(64 * 64, 300, 3, nonlinearity='elu')
        # Template decoders: forward() feeds them a constant 1, so each one
        # produces a single state-independent 64x64 image that is warped.
        self.obs_net_1 = MLP_Decoder(1, 100, 64 * 64, nonlinearity='elu')
        self.obs_net_2 = MLP_Decoder(1, 100, 64 * 64, nonlinearity='elu')

        g_net = MatrixNet(3, 100, 4, shape=(2, 2))
        g_baseline_net = MLP(5, 400, 2)

        self.ode = Lag_Net_R1_T1(g_net=g_net,
                                 g_baseline=g_baseline_net,
                                 dyna_model='g_baseline')

        # Created lazily (then reused) by train_dataloader in homo_u mode.
        self.train_dataset = None
        # Next control-dataset index (cycles 1..8), used on odd epochs.
        self.non_ctrl_ind = 1

    def train_dataloader(self):
        """Build the training DataLoader and set ``self.t_eval``.

        With ``hparams.homo_u`` one ``HomoImageDataset`` is created on the
        first call, and its ``u_idx`` is switched every epoch. During the
        first 1000 epochs, even epochs use index 0 (zero control) and odd
        epochs cycle through 1..8. Afterwards ``u_idx = current_epoch % 9``.
        Otherwise a new ``ImageDataset`` with ``ctrl=True`` is built on each
        call.
        """
        if self.hparams.homo_u:
            # must set trainer flag reload_dataloaders_every_epoch=True
            # (u_idx below changes from epoch to epoch)
            if self.train_dataset is None:
                self.train_dataset = HomoImageDataset(self.data_path,
                                                      self.hparams.T_pred)
            if self.current_epoch < 1000:
                # feed zero ctrl dataset and ctrl dataset in turns
                if self.current_epoch % 2 == 0:
                    u_idx = 0
                else:
                    u_idx = self.non_ctrl_ind
                    self.non_ctrl_ind += 1
                    if self.non_ctrl_ind == 9:
                        self.non_ctrl_ind = 1
            else:
                u_idx = self.current_epoch % 9
            self.train_dataset.u_idx = u_idx
            self.t_eval = torch.from_numpy(self.train_dataset.t_eval)
            return DataLoader(self.train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)
        else:
            train_dataset = ImageDataset(self.data_path,
                                         self.hparams.T_pred,
                                         ctrl=True)
            self.t_eval = torch.from_numpy(train_dataset.t_eval)
            return DataLoader(train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)

    def angle_vel_est(self, q0_m_n, q1_m_n, delta_t):
        """Finite-difference angular velocity from two 2-D angle vectors.

        Computes ``(-d_cos * sin0 + d_sin * cos0) / delta_t``, where
        ``d_cos`` / ``d_sin`` are the changes in column 0 / column 1. For
        ``q = (cos(theta), sin(theta))`` this approximates ``d theta / dt``.
        Assumes column 0 holds cos and column 1 holds sin.

        Returns:
            Tensor of shape ``(bs, 1)``.
        """
        delta_cos = q1_m_n[:, 0:1] - q0_m_n[:, 0:1]
        delta_sin = q1_m_n[:, 1:2] - q0_m_n[:, 1:2]
        q_dot0 = -delta_cos * q0_m_n[:, 1:
                                     2] / delta_t + delta_sin * q0_m_n[:, 0:
                                                                       1] / delta_t
        return q_dot0

    def encode(self, batch_image):
        """Encode one frame into translation and angle posterior parameters.

        Relies on ``self.bs`` and ``self.d`` having been set by ``forward``.

        Args:
            batch_image: Frame batch indexed ``(bs, channel, d, d)``.
                Channel 0 feeds ``recog_net_1``. Channel 1 is sampled
                through a window shifted by the estimated ``r`` and feeds
                ``recog_net_2``.

        Returns:
            ``(r_m, r_v, phi_m, phi_v, phi_m_n)``:
            - the tanh-squashed mean and the positive scale of ``r``;
            - the raw 2-D angle mean;
            - the vMF concentration (softplus + 1);
            - the unit-normalised angle mean.
        """
        r_m_logv = self.recog_net_1(batch_image[:, 0].reshape(
            self.bs, self.d * self.d))
        r_m, r_logv = r_m_logv.split([1, 1], dim=1)
        # Keep r inside (-1, 1), the normalised coordinate range of affine_grid.
        r_m = torch.tanh(r_m)
        # Small floor keeps the Normal scale strictly positive.
        r_v = torch.exp(r_logv) + 0.0001
        # No rotation, x-translation by r_m: a window following r.
        theta = self.get_theta(1, 0, r_m[:, 0], 0)
        grid = F.affine_grid(theta, torch.Size((self.bs, 1, self.d, self.d)))
        pole_att_win = F.grid_sample(batch_image[:, 1:2], grid)
        phi_m_logv = self.recog_net_2(
            pole_att_win.reshape(self.bs, self.d * self.d))
        phi_m, phi_logv = phi_m_logv.split([2, 1], dim=1)
        # Project the angle mean onto the unit circle.
        phi_m_n = phi_m / phi_m.norm(dim=-1, keepdim=True)
        # Concentration >= 1 for the VonMisesFisher posterior.
        phi_v = F.softplus(phi_logv) + 1
        return r_m, r_v, phi_m, phi_v, phi_m_n

    def get_theta(self, cos, sin, x, y, bs=None):
        """Batch of 2x3 affine matrices for ``F.affine_grid``.

        Each matrix is ``[[cos, sin, x], [-sin, cos, y]]``. Arguments may be
        Python scalars or ``(bs,)`` tensors; they are added into a zero
        tensor of shape ``(bs, 2, 3)``. ``bs`` defaults to ``self.bs``.
        """
        # x, y should have shape (bs, )
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += sin
        theta[:, 0, 2] += x
        theta[:, 1, 0] += -sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += y
        return theta

    def get_theta_inv(self, cos, sin, x, y, bs=None):
        """Inverse of ``get_theta`` for the same arguments.

        Each matrix is
        ``[[cos, -sin, -x*cos + y*sin], [sin, cos, -x*sin - y*cos]]``, i.e.
        the transposed rotation with the correspondingly rotated, negated
        translation. ``bs`` defaults to ``self.bs``.
        """
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += -sin
        theta[:, 0, 2] += -x * cos + y * sin
        theta[:, 1, 0] += sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += -x * sin - y * cos
        return theta

    def forward(self, X, u):
        """Encode the first two frames, integrate the latent ODE and render.

        Results are stored on ``self`` and read by ``training_step``: the
        posteriors ``Q_r0`` / ``Q_phi0``, the priors ``P_normal`` /
        ``P_hyper_uni``, ``phi0_m`` and the reconstruction ``Xrec``.
        Requires ``self.t_eval`` (set in ``train_dataloader``).

        Args:
            X: Image sequence unpacked as ``(time, bs, c, d, d)``; only
                ``X[0]`` and ``X[1]`` are encoded.
            u: Control input appended to the initial state. The
                ``split([3, 2, 2])`` below implies it has 2 columns.

        Returns:
            None.
        """
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # encode
        self.r0_m, self.r0_v, self.phi0_m, self.phi0_v, self.phi0_m_n = self.encode(
            X[0])
        self.r1_m, self.r1_v, self.phi1_m, self.phi1_v, self.phi1_m_n = self.encode(
            X[1])

        # reparametrize
        # Posterior over r and its standard-Normal prior.
        self.Q_r0 = Normal(self.r0_m, self.r0_v)
        self.P_normal = Normal(torch.zeros_like(self.r0_m),
                               torch.ones_like(self.r0_v))
        self.r0 = self.Q_r0.rsample()

        # Posterior over phi and the HypersphericalUniform(1) prior.
        self.Q_phi0 = VonMisesFisher(self.phi0_m_n, self.phi0_v)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi0 = self.Q_phi0.rsample()
        # Redraw until no NaN; this never terminates if every draw is NaN.
        while torch.isnan(self.phi0).any():
            self.phi0 = self.Q_phi0.rsample()

        # estimate velocity
        # Finite differences of the posterior means between frames 0 and 1.
        self.r_dot0 = (self.r1_m - self.r0_m) / (self.t_eval[1] -
                                                 self.t_eval[0])
        self.phi_dot0 = self.angle_vel_est(self.phi0_m_n, self.phi1_m_n,
                                           self.t_eval[1] - self.t_eval[0])

        # predict
        # State layout: [r (1), phi (2), r_dot (1), phi_dot (1), u (2)].
        z0_u = torch.cat([self.r0, self.phi0, self.r_dot0, self.phi_dot0, u],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 7
        # qT = (r, cos, sin); q_dotT = (r_dot, phi_dot); the u columns are dropped.
        self.qT, self.q_dotT, _ = zT_u.split([3, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 3)

        # decode
        # Constant input: both decoders output state-independent templates.
        ones = torch.ones_like(self.qT[:, 0:1])
        self.cart = self.obs_net_1(ones)
        self.pole = self.obs_net_2(ones)

        # Warp the templates by the predicted pose: the cart by r only, the
        # pole by the predicted (cos, sin) and r.
        theta1 = self.get_theta_inv(1, 0, self.qT[:, 0], 0, bs=T * self.bs)
        theta2 = self.get_theta_inv(self.qT[:, 1],
                                    self.qT[:, 2],
                                    self.qT[:, 0],
                                    0,
                                    bs=T * self.bs)

        grid1 = F.affine_grid(theta1,
                              torch.Size((T * self.bs, 1, self.d, self.d)))
        grid2 = F.affine_grid(theta2,
                              torch.Size((T * self.bs, 1, self.d, self.d)))

        transf_cart = F.grid_sample(
            self.cart.view(T * self.bs, 1, self.d, self.d), grid1)
        transf_pole = F.grid_sample(
            self.pole.view(T * self.bs, 1, self.d, self.d), grid2)
        # Reconstructed channels: [cart, pole, zeros].
        self.Xrec = torch.cat(
            [transf_cart, transf_pole,
             torch.zeros_like(transf_cart)], dim=1)
        self.Xrec = self.Xrec.view(T, self.bs, 3, self.d, self.d)
        return None

    def training_step(self, train_batch, batch_idx):
        """Negative ELBO plus a small penalty keeping ``|phi0_m|`` near 1.

        The reconstruction term is the squared error summed over time,
        channels and pixels, then averaged over the batch. The KL terms
        compare ``Q_r0`` with a standard Normal and ``Q_phi0`` with
        ``HypersphericalUniform(1)``. The norm penalty has weight 1/100.
        """
        X, u = train_batch
        self.forward(X, u)

        lhood = -self.loss_fn(self.Xrec, X)
        lhood = lhood.sum([0, 2, 3, 4]).mean()
        kl_q = torch.distributions.kl.kl_divergence(self.Q_r0, self.P_normal).mean() + \
                torch.distributions.kl.kl_divergence(self.Q_phi0, self.P_hyper_uni).mean()
        # Encourages the raw (unnormalised) angle mean to stay near unit length.
        norm_penalty = (self.phi0_m.norm(dim=-1).mean() - 1)**2

        loss = -lhood + kl_q + 1 / 100 * norm_penalty

        logs = {'recon_loss': -lhood, 'kl_q_loss': kl_q, 'train_loss': loss}
        return {'loss': loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        """Adam over all parameters with ``hparams.learning_rate``."""
        return torch.optim.Adam(self.parameters(), self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule.

        Returns a parser extending ``parent_parser`` with
        ``--learning_rate`` (default 1e-4) and ``--batch_size``
        (default 1024).
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', default=1e-4, type=float)
        parser.add_argument('--batch_size', default=1024, type=int)

        return parser
コード例 #3
0
class Model(pl.LightningModule):
    """Latent-ODE model over a single angle, rendered by rotating a template.

    ``recog_q_net`` encodes the angle of a flattened ``d x d`` frame as a
    2-D vector. Its velocity is finite-differenced from the first two
    frames. ``Lag_Net`` (``q_dim=1``, ``u_dim=1``) is integrated with
    ``odeint``. A learned template from ``obs_net`` is rotated by each
    predicted angle to reconstruct the frames.
    """
    def __init__(self, hparams, data_path=None):
        """Create the encoder, template network and ODE dynamics.

        Args:
            hparams: Hyper-parameter namespace; this class reads ``T_pred``,
                ``homo_u``, ``batch_size``, ``solver``, ``annealing`` and
                ``learning_rate``.
            data_path: Passed to the dataset constructors in
                ``train_dataloader``.
        """
        super(Model, self).__init__()
        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        # Per-element squared error; reduced manually in training_step.
        self.loss_fn = torch.nn.MSELoss(reduction='none')

        # Angle encoder: 3 outputs, split into a 2-D mean and one value that
        # encode() turns into the vMF concentration.
        self.recog_q_net = MLP_Encoder(32 * 32, 300, 3, nonlinearity='elu')
        # Template network (an MLP_Encoder used as decoder): forward() feeds
        # it a constant 1, so it yields one state-independent 32x32 image.
        self.obs_net = MLP_Encoder(1, 100, 32 * 32, nonlinearity='elu')
        V_net = MLP(2, 50, 1)
        g_net = MLP(2, 50, 1)
        M_net = PSD(2, 50, 1)
        self.ode = Lag_Net(q_dim=1,
                           u_dim=1,
                           g_net=g_net,
                           M_net=M_net,
                           V_net=V_net)

        # Created lazily (then reused) by train_dataloader in homo_u mode.
        self.train_dataset = None
        # Next control-dataset index (cycles 1..8), used on odd epochs.
        self.non_ctrl_ind = 1

    def train_dataloader(self):
        """Build the training DataLoader and set ``self.t_eval``.

        With ``hparams.homo_u`` one ``HomoImageDataset`` is created on the
        first call, and its ``u_idx`` is switched every epoch. During the
        first 1000 epochs, even epochs use index 0 (zero control) and odd
        epochs cycle through 1..8. Afterwards ``u_idx = current_epoch % 9``.
        Otherwise a new ``ImageDataset`` is built on each call.
        """
        if self.hparams.homo_u:
            # must set trainer flag reload_dataloaders_every_epoch=True
            # (u_idx below changes from epoch to epoch)
            if self.train_dataset is None:
                self.train_dataset = HomoImageDataset(self.data_path,
                                                      self.hparams.T_pred)
            if self.current_epoch < 1000:
                # feed zero ctrl dataset and ctrl dataset in turns
                if self.current_epoch % 2 == 0:
                    u_idx = 0
                else:
                    u_idx = self.non_ctrl_ind
                    self.non_ctrl_ind += 1
                    if self.non_ctrl_ind == 9:
                        self.non_ctrl_ind = 1
            else:
                u_idx = self.current_epoch % 9
            self.train_dataset.u_idx = u_idx
            self.t_eval = torch.from_numpy(self.train_dataset.t_eval)
            return DataLoader(self.train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)
        else:
            train_dataset = ImageDataset(self.data_path, self.hparams.T_pred)
            self.t_eval = torch.from_numpy(train_dataset.t_eval)
            return DataLoader(train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)

    def angle_vel_est(self, q0_m_n, q1_m_n, delta_t):
        """Finite-difference angular velocity from two 2-D angle vectors.

        Computes ``(-d_cos * sin0 + d_sin * cos0) / delta_t``, where
        ``d_cos`` / ``d_sin`` are the changes in column 0 / column 1. For
        ``q = (cos(theta), sin(theta))`` this approximates ``d theta / dt``.
        Assumes column 0 holds cos and column 1 holds sin.

        Returns:
            Tensor of shape ``(bs, 1)``.
        """
        delta_cos = q1_m_n[:, 0:1] - q0_m_n[:, 0:1]
        delta_sin = q1_m_n[:, 1:2] - q0_m_n[:, 1:2]
        q_dot0 = -delta_cos * q0_m_n[:, 1:
                                     2] / delta_t + delta_sin * q0_m_n[:, 0:
                                                                       1] / delta_t
        return q_dot0

    def encode(self, batch_image):
        """Encode flattened frames into angle posterior parameters.

        Args:
            batch_image: Frames already flattened to ``(bs, d * d)`` (see
                ``forward``).

        Returns:
            ``(q_m, q_v, q_m_n)``: the raw 2-D angle mean, the vMF
            concentration (softplus + 1) and the unit-normalised mean.
        """
        q_m_logv = self.recog_q_net(batch_image)
        q_m, q_logv = q_m_logv.split([2, 1], dim=1)
        # Project the angle mean onto the unit circle.
        q_m_n = q_m / q_m.norm(dim=-1, keepdim=True)
        # Concentration >= 1 for the VonMisesFisher posterior.
        q_v = F.softplus(q_logv) + 1
        return q_m, q_v, q_m_n

    def get_theta_inv(self, cos, sin, x, y, bs=None):
        """Batch of 2x3 affine matrices for ``F.affine_grid``.

        Each matrix is
        ``[[cos, -sin, -x*cos + y*sin], [sin, cos, -x*sin - y*cos]]``.
        Arguments may be scalars or ``(bs,)`` tensors; ``bs`` defaults to
        ``self.bs``.
        """
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += -sin
        theta[:, 0, 2] += -x * cos + y * sin
        theta[:, 1, 0] += sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += -x * sin - y * cos
        return theta

    def forward(self, X, u):
        """Encode the first two frames, integrate the latent ODE and render.

        Results are stored on ``self`` and read by ``training_step``: the
        posterior ``Q_q``, the prior ``P_q``, ``q0_m`` and the
        reconstruction ``Xrec``. Requires ``self.t_eval`` (set in
        ``train_dataloader``).

        Args:
            X: Image sequence unpacked as ``(time, bs, d, d)``; only ``X[0]``
                and ``X[1]`` are encoded.
            u: Control input appended to the initial state. The
                ``split([2, 1, 1])`` below implies it has 1 column.

        Returns:
            None.
        """
        [_, self.bs, d, d] = X.shape
        T = len(self.t_eval)
        # encode
        self.q0_m, self.q0_v, self.q0_m_n = self.encode(X[0].reshape(
            self.bs, d * d))
        self.q1_m, self.q1_v, self.q1_m_n = self.encode(X[1].reshape(
            self.bs, d * d))

        # reparametrize
        # Posterior over the angle and the HypersphericalUniform(1) prior.
        self.Q_q = VonMisesFisher(self.q0_m_n, self.q0_v)
        self.P_q = HypersphericalUniform(1, device=self.device)
        self.q0 = self.Q_q.rsample()  # bs, 2
        # Redraw until no NaN; this never terminates if every draw is NaN.
        while torch.isnan(self.q0).any():
            self.q0 = self.Q_q.rsample()  # a bad way to avoid nan

        # estimate velocity
        # Finite difference of the posterior means between frames 0 and 1.
        self.q_dot0 = self.angle_vel_est(self.q0_m_n, self.q1_m_n,
                                         self.t_eval[1] - self.t_eval[0])

        # predict
        # State layout: [q (2), q_dot (1), u (1)].
        z0_u = torch.cat((self.q0, self.q_dot0, u), dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 4
        # qT = (cos, sin); q_dotT = angular velocity; the u column is dropped.
        self.qT, self.q_dotT, _ = zT_u.split([2, 1, 1], dim=-1)
        self.qT = self.qT.view(T * self.bs, 2)

        # decode
        # Constant input: obs_net outputs a state-independent template.
        ones = torch.ones_like(self.qT[:, 0:1])
        self.content = self.obs_net(ones)

        # Pure rotation (zero translation) by the predicted angle vector.
        theta = self.get_theta_inv(self.qT[:, 0],
                                   self.qT[:, 1],
                                   0,
                                   0,
                                   bs=T * self.bs)  # cos , sin

        grid = F.affine_grid(theta, torch.Size((T * self.bs, 1, d, d)))
        self.Xrec = F.grid_sample(self.content.view(T * self.bs, 1, d, d),
                                  grid)
        self.Xrec = self.Xrec.view([T, self.bs, d, d])
        return None

    def training_step(self, train_batch, batch_idx):
        """Negative ELBO plus a weighted penalty keeping ``|q0_m|`` near 1.

        The reconstruction term is the squared error summed over time and
        pixels, then averaged over the batch. The KL term compares ``Q_q``
        with ``HypersphericalUniform(1)``. The penalty weight is
        ``current_epoch / 8000`` when ``hparams.annealing`` is set, else
        1/100. The logged ``'monitor'`` value excludes the penalty.
        """
        X, u = train_batch
        self.forward(X, u)

        lhood = -self.loss_fn(self.Xrec, X)
        lhood = lhood.sum([0, 2, 3]).mean()
        kl_q = torch.distributions.kl.kl_divergence(self.Q_q, self.P_q).mean()
        # Encourages the raw (unnormalised) angle mean to stay near unit length.
        norm_penalty = (self.q0_m.norm(dim=-1).mean() - 1)**2

        lambda_ = self.current_epoch / 8000 if self.hparams.annealing else 1 / 100
        loss = -lhood + kl_q + lambda_ * norm_penalty

        logs = {
            'recon_loss': -lhood,
            'kl_q_loss': kl_q,
            'train_loss': loss,
            'monitor': -lhood + kl_q
        }
        return {'loss': loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        """Adam over all parameters with ``hparams.learning_rate``."""
        return torch.optim.Adam(self.parameters(), self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule.

        Returns a parser extending ``parent_parser`` with
        ``--learning_rate`` (default 1e-3) and ``--batch_size``
        (default 512).
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', default=1e-3, type=float)
        parser.add_argument('--batch_size', default=512, type=int)

        return parser
コード例 #4
0
class Model(pl.LightningModule):
    """Latent-ODE model over two angles, decoded directly from the state.

    ``recog_net_1`` / ``recog_net_2`` encode one angle each, from image
    channel 0 and channel 1, as 2-D vectors. Velocities are
    finite-differenced from the first two frames. ``Lag_Net``
    (``q_dim=2``, ``u_dim=2``) is integrated with ``odeint``. ``obs_net``
    maps each predicted 4-D configuration to a 3-channel 64x64 image.
    """
    def __init__(self, hparams, data_path=None):
        """Create the encoders, decoder and ODE dynamics.

        Args:
            hparams: Hyper-parameter namespace; this class reads ``T_pred``,
                ``homo_u``, ``batch_size``, ``solver`` and ``learning_rate``.
            data_path: Passed to the dataset constructors in
                ``train_dataloader``.
        """
        super(Model, self).__init__()
        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        # Per-element squared error; reduced manually in training_step.
        self.loss_fn = torch.nn.MSELoss(reduction='none')

        # One encoder per angle: 3 outputs each, split into a 2-D mean and
        # one value that encode() turns into the vMF concentration.
        self.recog_net_1 = MLP_Encoder(64 * 64, 300, 3, nonlinearity='elu')
        self.recog_net_2 = MLP_Encoder(64 * 64, 300, 3, nonlinearity='elu')
        # Decoder from the 4 angle-vector components to a 3x64x64 image.
        self.obs_net = MLP_Decoder(4, 100, 3 * 64 * 64, nonlinearity='elu')

        V_net = MLP(4, 100, 1)
        M_net = PSD(4, 300, 2)
        g_net = MatrixNet(4, 100, 4, shape=(2, 2))

        self.ode = Lag_Net(q_dim=2,
                           u_dim=2,
                           g_net=g_net,
                           M_net=M_net,
                           V_net=V_net)

        # Learnable scalar; forward() stores sigmoid(link1_para) in
        # self.link1_l, which nothing else in this class reads.
        # NOTE(review): possibly a leftover -- confirm before removing.
        self.link1_para = torch.nn.Parameter(
            torch.tensor(0.0, dtype=self.dtype))

        # Created lazily (then reused) by train_dataloader in homo_u mode.
        self.train_dataset = None
        # Next control-dataset index (cycles 1..8), used on odd epochs.
        self.non_ctrl_ind = 1

    def train_dataloader(self):
        """Build the training DataLoader and set ``self.t_eval``.

        With ``hparams.homo_u`` one ``HomoImageDataset`` is created on the
        first call, and its ``u_idx`` is switched every epoch. During the
        first 1000 epochs, even epochs use index 0 (zero control) and odd
        epochs cycle through 1..8. Afterwards ``u_idx = current_epoch % 9``.
        Otherwise a new ``ImageDataset`` with ``ctrl=True`` is built on each
        call.
        """
        if self.hparams.homo_u:
            # must set trainer flag reload_dataloaders_every_epoch=True
            # (u_idx below changes from epoch to epoch)
            if self.train_dataset is None:
                self.train_dataset = HomoImageDataset(self.data_path,
                                                      self.hparams.T_pred)
            if self.current_epoch < 1000:
                # feed zero ctrl dataset and ctrl dataset in turns
                if self.current_epoch % 2 == 0:
                    u_idx = 0
                else:
                    u_idx = self.non_ctrl_ind
                    self.non_ctrl_ind += 1
                    if self.non_ctrl_ind == 9:
                        self.non_ctrl_ind = 1
            else:
                u_idx = self.current_epoch % 9
            self.train_dataset.u_idx = u_idx
            self.t_eval = torch.from_numpy(self.train_dataset.t_eval)
            return DataLoader(self.train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)
        else:
            train_dataset = ImageDataset(self.data_path,
                                         self.hparams.T_pred,
                                         ctrl=True)
            self.t_eval = torch.from_numpy(train_dataset.t_eval)
            return DataLoader(train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)

    def angle_vel_est(self, q0_m_n, q1_m_n, delta_t):
        """Finite-difference angular velocity from two 2-D angle vectors.

        Computes ``(-d_cos * sin0 + d_sin * cos0) / delta_t``, where
        ``d_cos`` / ``d_sin`` are the changes in column 0 / column 1. For
        ``q = (cos(theta), sin(theta))`` this approximates ``d theta / dt``.
        Assumes column 0 holds cos and column 1 holds sin.

        Returns:
            Tensor of shape ``(bs, 1)``.
        """
        delta_cos = q1_m_n[:, 0:1] - q0_m_n[:, 0:1]
        delta_sin = q1_m_n[:, 1:2] - q0_m_n[:, 1:2]
        q_dot0 = -delta_cos * q0_m_n[:, 1:
                                     2] / delta_t + delta_sin * q0_m_n[:, 0:
                                                                       1] / delta_t
        return q_dot0

    def encode(self, batch_image):
        """Encode one frame into the posterior parameters of both angles.

        Relies on ``self.bs`` and ``self.d`` having been set by ``forward``.

        Args:
            batch_image: Frame batch indexed ``(bs, channel, d, d)``.
                Channel 0 feeds ``recog_net_1`` (angle 1) and channel 1
                feeds ``recog_net_2`` (angle 2).

        Returns:
            ``(phi1_m, phi1_v, phi1_m_n, phi2_m, phi2_v, phi2_m_n)``. For
            each angle: the raw 2-D mean, the vMF concentration
            (softplus + 1) and the unit-normalised mean.
        """
        phi1_m_logv = self.recog_net_1(batch_image[:, 0:1].reshape(
            self.bs, self.d * self.d))
        phi1_m, phi1_logv = phi1_m_logv.split([2, 1], dim=1)
        phi1_m_n = phi1_m / phi1_m.norm(dim=-1, keepdim=True)
        phi1_v = F.softplus(phi1_logv) + 1

        phi2_m_logv = self.recog_net_2(batch_image[:, 1:2].reshape(
            self.bs, self.d * self.d))
        phi2_m, phi2_logv = phi2_m_logv.split([2, 1], dim=1)
        phi2_m_n = phi2_m / phi2_m.norm(dim=-1, keepdim=True)
        phi2_v = F.softplus(phi2_logv) + 1
        return phi1_m, phi1_v, phi1_m_n, phi2_m, phi2_v, phi2_m_n

    def get_theta(self, cos, sin, x, y, bs=None):
        """Batch of 2x3 affine matrices ``[[cos, sin, x], [-sin, cos, y]]``.

        Arguments may be scalars or ``(bs,)`` tensors; ``bs`` defaults to
        ``self.bs``. Not called by this class's own methods.
        """
        # x, y should have shape (bs, )
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += sin
        theta[:, 0, 2] += x
        theta[:, 1, 0] += -sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += y
        return theta

    def get_theta_inv(self, cos, sin, x, y, bs=None):
        """Inverse of ``get_theta`` for the same arguments.

        Each matrix is
        ``[[cos, -sin, -x*cos + y*sin], [sin, cos, -x*sin - y*cos]]``.
        Not called by this class's own methods.
        """
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += -sin
        theta[:, 0, 2] += -x * cos + y * sin
        theta[:, 1, 0] += sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += -x * sin - y * cos
        return theta

    def forward(self, X, u):
        """Encode the first two frames, integrate the latent ODE and decode.

        Results are stored on ``self`` and read by ``training_step``: the
        posteriors ``Q_phi1`` / ``Q_phi2``, the prior ``P_hyper_uni``,
        ``phi1_m_t0`` / ``phi2_m_t0`` and the reconstruction ``Xrec``.
        Requires ``self.t_eval`` (set in ``train_dataloader``).

        Args:
            X: Image sequence unpacked as ``(time, bs, c, d, d)``; only
                ``X[0]`` and ``X[1]`` are encoded.
            u: Control input appended to the initial state. The
                ``split([4, 2, 2])`` below implies it has 2 columns.

        Returns:
            None.
        """
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # Not read elsewhere in this class (see link1_para in __init__).
        self.link1_l = torch.sigmoid(self.link1_para)
        # encode
        self.phi1_m_t0, self.phi1_v_t0, self.phi1_m_n_t0, self.phi2_m_t0, self.phi2_v_t0, self.phi2_m_n_t0 = self.encode(
            X[0])
        self.phi1_m_t1, self.phi1_v_t1, self.phi1_m_n_t1, self.phi2_m_t1, self.phi2_v_t1, self.phi2_m_n_t1 = self.encode(
            X[1])
        # reparametrize
        # vMF posteriors for both angles, sharing the HypersphericalUniform(1) prior.
        self.Q_phi1 = VonMisesFisher(self.phi1_m_n_t0, self.phi1_v_t0)
        self.Q_phi2 = VonMisesFisher(self.phi2_m_n_t0, self.phi2_v_t0)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        # Redraw until no NaN; these loops never terminate if every draw is NaN.
        self.phi1_t0 = self.Q_phi1.rsample()
        while torch.isnan(self.phi1_t0).any():
            self.phi1_t0 = self.Q_phi1.rsample()
        self.phi2_t0 = self.Q_phi2.rsample()
        while torch.isnan(self.phi2_t0).any():
            self.phi2_t0 = self.Q_phi2.rsample()

        # estimate velocity
        # Finite differences of the posterior means between frames 0 and 1.
        self.phi1_dot_t0 = self.angle_vel_est(self.phi1_m_n_t0,
                                              self.phi1_m_n_t1,
                                              self.t_eval[1] - self.t_eval[0])
        self.phi2_dot_t0 = self.angle_vel_est(self.phi2_m_n_t0,
                                              self.phi2_m_n_t1,
                                              self.t_eval[1] - self.t_eval[0])

        # predict
        # State layout: [phi1[:, 0], phi2[:, 0], phi1[:, 1], phi2[:, 1],
        #                phi1_dot, phi2_dot, u (2)].
        z0_u = torch.cat([
            self.phi1_t0[:, 0:1], self.phi2_t0[:, 0:1], self.phi1_t0[:, 1:2],
            self.phi2_t0[:, 1:2], self.phi1_dot_t0, self.phi2_dot_t0, u
        ],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 8
        # qT = 4 angle-vector components; q_dotT = 2 angular velocities;
        # the u columns are dropped.
        self.qT, self.q_dotT, _ = zT_u.split([4, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 4)

        # decode
        # obs_net maps each 4-D configuration straight to a 3-channel image.
        self.Xrec = self.obs_net(self.qT).view(T, self.bs, 3, self.d, self.d)
        return None

    def training_step(self, train_batch, batch_idx):
        """Negative ELBO plus a small penalty keeping both angle means near unit length.

        The reconstruction term is the squared error summed over time,
        channels and pixels, then averaged over the batch. The KL terms
        compare ``Q_phi1`` and ``Q_phi2`` with ``HypersphericalUniform(1)``.
        The norm penalty has weight 1/100.
        """
        X, u = train_batch
        self.forward(X, u)

        lhood = -self.loss_fn(self.Xrec, X)
        lhood = lhood.sum([0, 2, 3, 4]).mean()
        kl_q = torch.distributions.kl.kl_divergence(self.Q_phi1, self.P_hyper_uni).mean() + \
               torch.distributions.kl.kl_divergence(self.Q_phi2, self.P_hyper_uni).mean()
        norm_penalty = (self.phi1_m_t0.norm(dim=-1).mean() - 1) ** 2 + \
                       (self.phi2_m_t0.norm(dim=-1).mean() - 1) ** 2

        loss = -lhood + kl_q + 1 / 100 * norm_penalty

        logs = {'recon_loss': -lhood, 'kl_q_loss': kl_q, 'train_loss': loss}
        return {'loss': loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        """Adam over all parameters with ``hparams.learning_rate``."""
        return torch.optim.Adam(self.parameters(), self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule.

        Returns a parser extending ``parent_parser`` with
        ``--learning_rate`` (default 1e-3) and ``--batch_size``
        (default 1024).
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', default=1e-3, type=float)
        parser.add_argument('--batch_size', default=1024, type=int)

        return parser