Ejemplo n.º 1
0
    def __init__(self, dataset='paired_mnist', map_location='cuda:1'):
        """Build the VCCA encoders, decoders and loss weights for ``dataset``.

        Args:
            dataset: one of ``'paired_mnist'``, ``'lincs_rnn'`` or
                ``'lincs_rnn_reverse'``. The ``'_reverse'`` variant swaps
                which side (x or y) uses the fine-tuned pretrained RNN and
                which uses the ExprDiff encoder/decoder.
            map_location: forwarded to ``torch.load`` for the pretrained RNN
                checkpoints (LINCS datasets only). The default keeps the
                original ``'cuda:1'``; pass e.g. ``'cpu'`` on machines
                without a second GPU. Checkpoint paths are relative to the
                current working directory.

        Raises:
            ValueError: if ``dataset`` is not one of the supported names
                (otherwise no encoders/decoders would be built at all).
        """
        super(VCCA, self).__init__()
        self.dataset = dataset

        rnn_enc_ckpt = '../saved_models/rnn_enc.ckpt'
        rnn_dec_ckpt = '../saved_models/rnn_dec.ckpt'

        def _load_pretrained(module, path):
            # Replace the module's freshly initialised weights with the
            # checkpoint at `path`; strict loading fails on any mismatch.
            state = torch.load(path, map_location=map_location)
            module.load_state_dict(state)
            return module

        # NOTE: enc_x emits 2 * (z_dim - joint_dim) values while enc_y emits
        # 2 * z_dim. The factor 2 presumably holds a mean and a log-variance
        # per latent unit — TODO confirm against forward().
        if self.dataset == 'paired_mnist':
            self.z_dim = 16
            self.joint_dim = 4

            self.enc_x = MnistCNNEncoder(out_dim=2 *
                                         (self.z_dim - self.joint_dim))
            self.enc_y = MnistCNNEncoder(out_dim=2 * self.z_dim)

            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10

            self.beta = 0.01

            self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
            self.dec_y = MnistCNNDecoder(in_dim=self.z_dim)
        elif self.dataset == 'lincs_rnn':
            self.z_dim = 20
            self.joint_dim = 10

            # x side: fine-tuned pretrained RNN; y side: ExprDiff modules.
            rnn_enc = _load_pretrained(RNNEncoder(out_dim=88), rnn_enc_ckpt)
            self.enc_x = FinetunedEncoder(rnn_enc,
                                          out_dim=2 *
                                          (self.z_dim - self.joint_dim))

            self.enc_y = ExprDiffEncoder(out_dim=2 * self.z_dim)

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1

            self.beta = 1

            rnn_dec = _load_pretrained(RNNDecoder(in_dim=44), rnn_dec_ckpt)
            self.dec_x = FinetunedDecoder(rnn_dec, in_dim=self.z_dim)
            self.dec_y = ExprDiffDecoder(in_dim=self.z_dim)
        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 20
            self.joint_dim = 10

            # Mirror of 'lincs_rnn': the RNN is on the y side here.
            rnn_enc = _load_pretrained(RNNEncoder(out_dim=88), rnn_enc_ckpt)
            self.enc_y = FinetunedEncoder(rnn_enc, out_dim=2 * self.z_dim)

            self.enc_x = ExprDiffEncoder(out_dim=2 *
                                         (self.z_dim - self.joint_dim))

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2

            self.beta = 1

            rnn_dec = _load_pretrained(RNNDecoder(in_dim=44), rnn_dec_ckpt)
            self.dec_y = FinetunedDecoder(rnn_dec, in_dim=self.z_dim)
            self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)
        else:
            raise ValueError(
                "Unknown dataset {!r}; expected 'paired_mnist', 'lincs_rnn' "
                "or 'lincs_rnn_reverse'".format(dataset))
Ejemplo n.º 2
0
    def __init__(self, dataset='paired_mnist', map_location='cuda:1'):
        """Build the BiAAE encoders, decoders, discriminators and loss weights.

        Args:
            dataset: one of ``'paired_mnist'``, ``'lincs_rnn'`` or
                ``'lincs_rnn_reverse'``. The ``'_reverse'`` variant swaps
                which side (x or y) uses the fine-tuned pretrained RNN and
                which uses the ExprDiff encoder/decoder.
            map_location: forwarded to ``torch.load`` for the pretrained RNN
                checkpoints (LINCS datasets only). The default keeps the
                original ``'cuda:1'``; pass e.g. ``'cpu'`` on machines
                without a second GPU. Checkpoint paths are relative to the
                current working directory.

        Raises:
            ValueError: if ``dataset`` is not one of the supported names
                (otherwise no networks would be built at all).
        """
        super(BiAAE, self).__init__()
        self.dataset = dataset

        rnn_enc_ckpt = '../saved_models/rnn_enc.ckpt'
        rnn_dec_ckpt = '../saved_models/rnn_dec.ckpt'

        def _load_pretrained(module, path):
            # Replace the module's freshly initialised weights with the
            # checkpoint at `path`; strict loading fails on any mismatch.
            state = torch.load(path, map_location=map_location)
            module.load_state_dict(state)
            return module

        if self.dataset == 'paired_mnist':
            self.z_dim = 16
            self.joint_dim = 4

            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10

            self.loss_normal_lambda = 0.3
            self.loss_indep_lambda_x = 1
            self.loss_indep_lambda_y = 1
            self.loss_mse_lambda = 0.1

            self.discr_steps = 1
            self.gen_steps = 1

            self.enc_x = MnistCNNEncoder(out_dim=self.z_dim)
            self.enc_y = MnistCNNEncoder(out_dim=self.z_dim)

            self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
            self.dec_y = MnistCNNDecoder(in_dim=self.z_dim)
        elif self.dataset == 'lincs_rnn':
            self.z_dim = 20
            self.joint_dim = 10

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1

            self.loss_normal_lambda = 0.5
            self.loss_indep_lambda_x = 0.5
            self.loss_indep_lambda_y = 0.5
            self.loss_mse_lambda = 0.5

            self.discr_steps = 3
            self.gen_steps = 1

            # x side: fine-tuned pretrained RNN; y side: ExprDiff modules.
            rnn_enc = _load_pretrained(RNNEncoder(out_dim=88), rnn_enc_ckpt)
            self.enc_x = FinetunedEncoder(rnn_enc, out_dim=self.z_dim)
            self.enc_y = ExprDiffEncoder(out_dim=self.z_dim)

            rnn_dec = _load_pretrained(RNNDecoder(in_dim=44), rnn_dec_ckpt)
            self.dec_x = FinetunedDecoder(rnn_dec, in_dim=self.z_dim)

            self.dec_y = ExprDiffDecoder(in_dim=self.z_dim)
        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 20
            self.joint_dim = 10

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2

            self.loss_normal_lambda = 0.5
            self.loss_indep_lambda_x = 0.5
            self.loss_indep_lambda_y = 0.5
            self.loss_mse_lambda = 0.5

            self.discr_steps = 3
            self.gen_steps = 1

            # Mirror of 'lincs_rnn': the RNN is on the y side here.
            rnn_enc = _load_pretrained(RNNEncoder(out_dim=88), rnn_enc_ckpt)
            self.enc_y = FinetunedEncoder(rnn_enc, out_dim=self.z_dim)
            self.enc_x = ExprDiffEncoder(out_dim=self.z_dim)

            rnn_dec = _load_pretrained(RNNDecoder(in_dim=44), rnn_dec_ckpt)
            self.dec_y = FinetunedDecoder(rnn_dec, in_dim=self.z_dim)

            self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)
        else:
            raise ValueError(
                "Unknown dataset {!r}; expected 'paired_mnist', 'lincs_rnn' "
                "or 'lincs_rnn_reverse'".format(dataset))

        # The discriminators are identical for every dataset and were always
        # built last, so building them here keeps the construction order.
        # Input size 2 * z_dim - joint_dim presumably corresponds to both
        # latents with the shared part counted once — TODO confirm.
        discr_in_dim = 2 * self.z_dim - self.joint_dim
        self.discr = FCDiscriminator(in_dim=discr_in_dim, use_sigmoid=False)
        self.discr_indep_x = FCDiscriminator(in_dim=discr_in_dim,
                                             use_sigmoid=False)
        self.discr_indep_y = FCDiscriminator(in_dim=discr_in_dim,
                                             use_sigmoid=False)
Ejemplo n.º 3
0
    def __init__(self, dataset='paired_mnist', map_location='cuda:1'):
        """Build the CVAE encoders, conditioned decoder and loss weights.

        Args:
            dataset: one of ``'paired_mnist'``, ``'lincs_rnn'`` or
                ``'lincs_rnn_reverse'``. The ``'_reverse'`` variant swaps
                which side (x or y) is modelled with the fine-tuned
                pretrained RNN and which with the ExprDiff modules.
            map_location: forwarded to ``torch.load`` for the pretrained RNN
                checkpoints (LINCS datasets only). The default keeps the
                original ``'cuda:1'``; pass e.g. ``'cpu'`` on machines
                without a second GPU. Checkpoint paths are relative to the
                current working directory.

        Raises:
            ValueError: if ``dataset`` is not one of the supported names
                (otherwise no networks would be built at all).
        """
        super(CVAE, self).__init__()
        self.dataset = dataset

        rnn_enc_ckpt = '../saved_models/rnn_enc.ckpt'
        rnn_dec_ckpt = '../saved_models/rnn_dec.ckpt'

        def _load_pretrained(module, path):
            # Replace the module's freshly initialised weights with the
            # checkpoint at `path`; strict loading fails on any mismatch.
            state = torch.load(path, map_location=map_location)
            module.load_state_dict(state)
            return module

        # NOTE: the conditioned decoder's inner decoder takes 2 * z_dim
        # inputs while the condition encoder emits z_dim; presumably the
        # latent and the encoded condition are concatenated — TODO confirm.
        if self.dataset == 'paired_mnist':
            self.z_dim = 8

            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10

            self.beta = 0.01

            self.enc_y = MnistCNNEncoder(out_dim=2 * self.z_dim)

            self.enc_xy = JointEncoder(MnistCNNEncoder(out_dim=2 * self.z_dim,
                                                       short_tail=True),
                                       MnistCNNEncoder(out_dim=2 * self.z_dim,
                                                       short_tail=True),
                                       out_dim=2 * self.z_dim)

            self.dec_x_cond = ConditionedDecoder(
                dec=MnistCNNDecoder(in_dim=self.z_dim + self.z_dim),
                cond=MnistCNNEncoder(out_dim=self.z_dim, short_tail=True))
        elif self.dataset == 'lincs_rnn':
            self.z_dim = 10

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1

            self.beta = 1

            # x side: fine-tuned pretrained RNN; y side: ExprDiff modules.
            rnn_enc = _load_pretrained(RNNEncoder(out_dim=88), rnn_enc_ckpt)
            enc_x = FinetunedEncoder(rnn_enc, out_dim=2 * self.z_dim)

            self.enc_xy = JointEncoder(enc_x,
                                       ExprDiffEncoder(out_dim=2 * self.z_dim),
                                       out_dim=2 * self.z_dim)

            self.enc_y = ExprDiffEncoder(out_dim=2 * self.z_dim)

            rnn_dec = _load_pretrained(RNNDecoder(in_dim=44), rnn_dec_ckpt)
            dec_x = FinetunedDecoder(rnn_dec, in_dim=2 * self.z_dim)
            self.dec_x_cond = ConditionedDecoder(
                dec=dec_x, cond=ExprDiffEncoder(out_dim=self.z_dim))

        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 10

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2

            self.beta = 1

            # Mirror of 'lincs_rnn': the RNN is on the y side here.
            rnn_enc_y = _load_pretrained(RNNEncoder(out_dim=88), rnn_enc_ckpt)
            self.enc_y = FinetunedEncoder(rnn_enc_y, out_dim=2 * self.z_dim)

            # Each FinetunedEncoder gets its own copy of the pretrained RNN so
            # enc_y and the joint encoder's y branch do not share weights.
            # This matches 'lincs_rnn', where those two are separate
            # ExprDiffEncoder instances, and the condition encoder below.
            rnn_enc_joint = _load_pretrained(RNNEncoder(out_dim=88),
                                             rnn_enc_ckpt)
            self.enc_xy = JointEncoder(ExprDiffEncoder(out_dim=2 * self.z_dim),
                                       FinetunedEncoder(rnn_enc_joint,
                                                        out_dim=2 *
                                                        self.z_dim),
                                       out_dim=2 * self.z_dim)

            rnn_enc_cond = _load_pretrained(RNNEncoder(out_dim=88),
                                            rnn_enc_ckpt)

            self.dec_x_cond = ConditionedDecoder(
                dec=ExprDiffDecoder(in_dim=2 * self.z_dim),
                cond=FinetunedEncoder(rnn_enc_cond, out_dim=self.z_dim))
        else:
            raise ValueError(
                "Unknown dataset {!r}; expected 'paired_mnist', 'lincs_rnn' "
                "or 'lincs_rnn_reverse'".format(dataset))