Example #1
    def __init__(self, config):
        """Initialize configurations."""

        self.config = config
        # Data loader.
        self.data = SparseMolecularDataset()
        self.data.load(config.mol_data_dir)

        # Model configurations.
        self.z_dim = config.z_dim
        self.m_dim = self.data.atom_num_types
        self.b_dim = self.data.bond_num_types
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.post_method = config.post_method

        self.metric = 'validity,sas'

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.dropout = config.dropout
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()
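
The snippet above only shows the configuration step. Below is a minimal driver sketch, assuming config is a plain namespace exposing exactly the attributes read in __init__; every concrete value is a placeholder, not a default from the original project.

# Hypothetical driver for the Solver excerpt above; all values are illustrative only.
from types import SimpleNamespace

config = SimpleNamespace(
    mol_data_dir='data/qm9_5k.sparsedataset',            # assumed dataset path
    z_dim=8, g_conv_dim=[128, 256, 512], d_conv_dim=[[128, 64], 128, [128, 64]],
    g_repeat_num=6, d_repeat_num=6,
    lambda_cls=1.0, lambda_rec=10.0, lambda_gp=10.0, post_method='softmax',
    batch_size=16, num_iters=200000, num_iters_decay=100000,
    g_lr=1e-4, d_lr=1e-4, dropout=0.0, n_critic=5, beta1=0.5, beta2=0.999,
    resume_iters=None, test_iters=200000, use_tensorboard=False,
    log_dir='output/logs', sample_dir='output/samples',
    model_save_dir='output/models', result_dir='output/results',
    log_step=10, sample_step=1000, model_save_step=10000, lr_update_step=1000,
)
solver = Solver(config)  # loads the dataset and builds G/D/V (and tensorboard if enabled)
solver.train()           # assuming the full class defines train() as in the later examples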
Example #2
    def __init__(self, config, log=None):
        """Initialize configurations."""

        # Log
        self.log = log

        # Data loader.
        self.data = SparseMolecularDataset()
        self.data.load(config.mol_data_dir)

        # Model configurations.
        self.z_dim = config.z_dim
        self.m_dim = self.data.atom_num_types
        self.b_dim = self.data.bond_num_types
        self.f_dim = self.data.features
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.lambda_wgan = config.lambda_wgan
        self.lambda_rec = config.lambda_rec
        self.post_method = config.post_method

        self.metric = 'validity,qed'

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_epochs = config.num_epochs
        self.num_steps = (len(self.data) // self.batch_size)
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.dropout_rate = config.dropout
        self.n_critic = config.n_critic
        self.resume_epoch = config.resume_epoch

        # Training or testing.
        self.mode = config.mode

        # Miscellaneous.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        print('Device: ', self.device)

        # Directories.
        self.log_dir_path = config.log_dir_path
        self.model_dir_path = config.model_dir_path
        self.img_dir_path = config.img_dir_path

        # Step size.
        self.model_save_step = config.model_save_step

        # VAE KL weight.
        self.kl_la = 1.

        # Build the model.
        self.build_model()
class SparseMolecular(data.Dataset):
    """Dataset class for the CelebA dataset."""
    def __init__(self, data_dir):
        """Initialize and preprocess the CelebA dataset."""
        self.data = SparseMolecularDataset()
        self.data.load(data_dir)

    def __getitem__(self, index):
        """Return one image and its corresponding attribute label."""

        return index, self.data.data[index], self.data.smiles[index], \
               self.data.data_S[index], self.data.data_A[index], \
               self.data.data_X[index], self.data.data_D[index], \
               self.data.data_F[index], self.data.data_Le[index], \
               self.data.data_Lv[index]

    def __len__(self):
        """Return the number of images."""
        return len(self.data)
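
Since SparseMolecular implements __getitem__ and __len__, it plugs into the usual torch.utils.data machinery. A minimal sketch, assuming the class above and a processed dataset file are available; the path is an assumption, and note that the raw RDKit Mol and SMILES fields would likely need a custom collate_fn before a DataLoader can batch the samples.

# Hypothetical usage of the Dataset wrapper defined above; the path is illustrative.
dataset = SparseMolecular('data/qm9_5k.sparsedataset')
print(len(dataset))

idx, mol, smiles, S, A, X, D, F, Le, Lv = dataset[0]
# A (data_A) holds bond-type indices and X (data_X) holds atom-type indices,
# matching the 'a' and 'x' arrays consumed by the training loops below.
print(smiles, A.shape, X.shape)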
Example #4
class Solver(object):
    """Solver for training and testing StarGAN."""
    def __init__(self, config):
        """Initialize configurations."""

        # Data loader.
        self.data = SparseMolecularDataset()
        self.data.load(config.mol_data_dir)

        # Model configurations.
        self.z_dim = config.qubits
        self.m_dim = self.data.atom_num_types
        self.b_dim = self.data.bond_num_types
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.post_method = config.post_method

        self.metric = 'validity,sas'

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.dropout = config.dropout
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create a generator and a discriminator."""
        self.G = Generator(self.g_conv_dim, self.z_dim, self.data.vertexes,
                           self.data.bond_num_types, self.data.atom_num_types,
                           self.dropout)
        self.D = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim,
                               self.dropout)
        self.V = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim,
                               self.dropout)

        self.g_optimizer = torch.optim.Adam(
            list(self.G.parameters()) + list(self.V.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        self.G.to(self.device)
        self.D.to(self.device)
        self.V.to(self.device)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator."""
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir,
                              '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir,
                              '{}-D.ckpt'.format(resume_iters))
        V_path = os.path.join(self.model_save_dir,
                              '{}-V.ckpt'.format(resume_iters))
        self.G.load_state_dict(
            torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(
            torch.load(D_path, map_location=lambda storage, loc: storage))
        self.V.load_state_dict(
            torch.load(V_path, map_location=lambda storage, loc: storage))

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        from logger import Logger  # project-local logger providing scalar_summary
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        out = torch.zeros(list(labels.size()) + [dim]).to(self.device)
        out.scatter_(len(out.size()) - 1, labels.unsqueeze(-1), 1.)
        return out

    def classification_loss(self, logit, target, dataset='CelebA'):
        """Compute binary or softmax cross entropy loss."""
        if dataset == 'CelebA':
            return F.binary_cross_entropy_with_logits(
                logit, target, reduction='sum') / logit.size(0)
        elif dataset == 'RaFD':
            return F.cross_entropy(logit, target)

    def sample_z(self, batch_size):
        return np.random.normal(0, 1, size=(batch_size, self.z_dim))

    def postprocess(self, inputs, method, temperature=1.):
        def listify(x):
            return x if type(x) == list or type(x) == tuple else [x]

        def delistify(x):
            return x if len(x) > 1 else x[0]

        if method == 'soft_gumbel':
            softmax = [
                F.gumbel_softmax(
                    e_logits.contiguous().view(-1, e_logits.size(-1)) /
                    temperature,
                    hard=False).view(e_logits.size())
                for e_logits in listify(inputs)
            ]
        elif method == 'hard_gumbel':
            softmax = [
                F.gumbel_softmax(
                    e_logits.contiguous().view(-1, e_logits.size(-1)) /
                    temperature,
                    hard=True).view(e_logits.size())
                for e_logits in listify(inputs)
            ]
        else:
            softmax = [
                F.softmax(e_logits / temperature, -1)
                for e_logits in listify(inputs)
            ]

        return [delistify(e) for e in (softmax)]

    def reward(self, mols):
        rr = 1.
        for m in ('logp,sas,qed,unique'
                  if self.metric == 'all' else self.metric).split(','):

            if m == 'np':
                rr *= MolecularMetrics.natural_product_scores(mols, norm=True)
            elif m == 'logp':
                rr *= MolecularMetrics.water_octanol_partition_coefficient_scores(
                    mols, norm=True)
            elif m == 'sas':
                rr *= MolecularMetrics.synthetic_accessibility_score_scores(
                    mols, norm=True)
            elif m == 'qed':
                rr *= MolecularMetrics.quantitative_estimation_druglikeness_scores(
                    mols, norm=True)
            elif m == 'novelty':
                rr *= MolecularMetrics.novel_scores(mols, self.data)
            elif m == 'dc':
                rr *= MolecularMetrics.drugcandidate_scores(mols, self.data)
            elif m == 'unique':
                rr *= MolecularMetrics.unique_scores(mols)
            elif m == 'diversity':
                rr *= MolecularMetrics.diversity_scores(mols, self.data)
            elif m == 'validity':
                rr *= MolecularMetrics.valid_scores(mols)
            else:
                raise RuntimeError('{} is not defined as a metric'.format(m))

        return rr.reshape(-1, 1)

    def train(self):

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):
            if (i + 1) % self.log_step == 0:
                mols, _, _, a, x, _, _, _, _ = self.data.next_validation_batch(
                )
                z = self.sample_z(a.shape[0])
                print('[Valid]', '')
            else:
                mols, _, _, a, x, _, _, _, _ = self.data.next_train_batch(
                    self.batch_size)
                z = self.sample_z(self.batch_size)

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            a = torch.from_numpy(a).to(self.device).long()  # Adjacency.
            x = torch.from_numpy(x).to(self.device).long()  # Nodes.
            a_tensor = self.label2onehot(a, self.b_dim)
            x_tensor = self.label2onehot(x, self.m_dim)
            z = torch.from_numpy(z).to(self.device).float()

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # Compute loss with real images.
            logits_real, features_real = self.D(a_tensor, None, x_tensor)
            d_loss_real = -torch.mean(logits_real)

            # Compute loss with fake images.
            edges_logits, nodes_logits = self.G(z)
            # Postprocess with Gumbel softmax
            (edges_hat, nodes_hat) = self.postprocess(
                (edges_logits, nodes_logits), self.post_method)
            logits_fake, features_fake = self.D(edges_hat, None, nodes_hat)
            d_loss_fake = torch.mean(logits_fake)

            # Compute loss for gradient penalty.
            eps = torch.rand(logits_real.size(0), 1, 1, 1).to(self.device)
            x_int0 = (eps * a_tensor +
                      (1. - eps) * edges_hat).requires_grad_(True)
            x_int1 = (eps.squeeze(-1) * x_tensor +
                      (1. - eps.squeeze(-1)) * nodes_hat).requires_grad_(True)
            grad0, grad1 = self.D(x_int0, None, x_int1)
            d_loss_gp = self.gradient_penalty(
                grad0, x_int0) + self.gradient_penalty(grad1, x_int1)

            # Backward and optimize.
            d_loss = d_loss_fake + d_loss_real + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_gp'] = d_loss_gp.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #

            if (i + 1) % self.n_critic == 0:
                # Z-to-target
                edges_logits, nodes_logits = self.G(z)
                # Postprocess with Gumbel softmax
                (edges_hat, nodes_hat) = self.postprocess(
                    (edges_logits, nodes_logits), self.post_method)
                logits_fake, features_fake = self.D(edges_hat, None, nodes_hat)
                g_loss_fake = -torch.mean(logits_fake)

                # Real Reward
                rewardR = torch.from_numpy(self.reward(mols)).to(self.device)
                # Fake Reward
                (edges_hard, nodes_hard) = self.postprocess(
                    (edges_logits, nodes_logits), 'hard_gumbel')
                edges_hard, nodes_hard = torch.max(edges_hard,
                                                   -1)[1], torch.max(
                                                       nodes_hard, -1)[1]
                mols = [
                    self.data.matrices2mol(n_.data.cpu().numpy(),
                                           e_.data.cpu().numpy(),
                                           strict=True)
                    for e_, n_ in zip(edges_hard, nodes_hard)
                ]
                rewardF = torch.from_numpy(self.reward(mols)).to(self.device)

                # Value loss
                value_logit_real, _ = self.V(a_tensor, None, x_tensor,
                                             torch.sigmoid)
                value_logit_fake, _ = self.V(edges_hat, None, nodes_hat,
                                             torch.sigmoid)
                g_loss_value = torch.mean((value_logit_real - rewardR)**2 +
                                          (value_logit_fake - rewardF)**2)
                #rl_loss= -value_logit_fake
                #f_loss = (torch.mean(features_real, 0) - torch.mean(features_fake, 0)) ** 2

                # Backward and optimize.
                g_loss = g_loss_fake + g_loss_value
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_value'] = g_loss_value.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)

                # Log update
                m0, m1 = all_scores(
                    mols, self.data,
                    norm=True)  # 'mols' is output of Fake Reward
                m0 = {
                    k: np.array(v)[np.nonzero(v)].mean()
                    for k, v in m0.items()
                }
                m0.update(m1)
                loss.update(m0)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir,
                                      '{}-G.ckpt'.format(i + 1))
                D_path = os.path.join(self.model_save_dir,
                                      '{}-D.ckpt'.format(i + 1))
                V_path = os.path.join(self.model_save_dir,
                                      '{}-V.ckpt'.format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                torch.save(self.V.state_dict(), V_path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

            # Decay learning rates.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))

    def test(self):
        # Load the trained generator.
        self.restore_model(self.test_iters)

        with torch.no_grad():
            mols, _, _, a, x, _, _, _, _ = self.data.next_test_batch()
            z = self.sample_z(a.shape[0])

            # Z-to-target
            edges_logits, nodes_logits = self.G(z)
            # Postprocess with Gumbel softmax
            (edges_hat, nodes_hat) = self.postprocess(
                (edges_logits, nodes_logits), self.post_method)
            logits_fake, features_fake = self.D(edges_hat, None, nodes_hat)
            g_loss_fake = -torch.mean(logits_fake)

            # Fake Reward
            (edges_hard, nodes_hard) = self.postprocess(
                (edges_logits, nodes_logits), 'hard_gumbel')
            edges_hard, nodes_hard = torch.max(edges_hard, -1)[1], torch.max(
                nodes_hard, -1)[1]
            mols = [
                self.data.matrices2mol(n_.data.cpu().numpy(),
                                       e_.data.cpu().numpy(),
                                       strict=True)
                for e_, n_ in zip(edges_hard, nodes_hard)
            ]

            # Log update
            m0, m1 = all_scores(mols, self.data,
                                norm=True)  # 'mols' is output of Fake Reward
            m0 = {k: np.array(v)[np.nonzero(v)].mean() for k, v in m0.items()}
            m0.update(m1)
            log = "Test results"
            for tag, value in m0.items():
                log += ", {}: {:.4f}".format(tag, value)
            print(log)
Example #6
class Solver(object):
    """Solver for training and testing StarGAN."""

    def __init__(self, config, log=None):
        """Initialize configurations."""

        # Log
        self.log = log

        # Data loader.
        self.data = SparseMolecularDataset()
        self.data.load(config.mol_data_dir)

        # Model configurations.
        self.z_dim = config.z_dim
        self.m_dim = self.data.atom_num_types
        self.b_dim = self.data.bond_num_types
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.la = config.lambda_wgan
        self.lambda_rec = config.lambda_rec
        self.la_gp = config.lambda_gp
        self.post_method = config.post_method

        self.metric = 'validity,qed'

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_epochs = config.num_epochs
        self.num_steps = (len(self.data) // self.batch_size)
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.dropout = config.dropout
        if self.la > 0:
            self.n_critic = config.n_critic
        else:
            self.n_critic = 1
        self.resume_epoch = config.resume_epoch

        # Training or testing.
        self.mode = config.mode

        # Miscellaneous.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print('Device: ', self.device)

        # Directories.
        self.log_dir_path = config.log_dir_path
        self.model_dir_path = config.model_dir_path
        self.img_dir_path = config.img_dir_path

        # Step size.
        self.model_save_step = config.model_save_step

        # Build the model.
        self.build_model()

    def build_model(self):
        """Create a generator and a discriminator."""
        self.G = Generator(self.g_conv_dim, self.z_dim,
                           self.data.vertexes,
                           self.data.bond_num_types,
                           self.data.atom_num_types,
                           self.dropout)
        self.D = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim - 1, self.dropout)
        self.V = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim - 1, self.dropout)

        self.g_optimizer = torch.optim.RMSprop(self.G.parameters(), self.g_lr)
        self.d_optimizer = torch.optim.RMSprop(self.D.parameters(), self.d_lr)
        self.v_optimizer = torch.optim.RMSprop(self.V.parameters(), self.g_lr)
        self.print_network(self.G, 'G', self.log)
        self.print_network(self.D, 'D', self.log)
        self.print_network(self.V, 'V', self.log)

        self.G.to(self.device)
        self.D.to(self.device)
        self.V.to(self.device)

    @staticmethod
    def print_network(model, name, log=None):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))
        if log is not None:
            log.info(model)
            log.info(name)
            log.info("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator."""
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_dir_path, '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_dir_path, '{}-D.ckpt'.format(resume_iters))
        V_path = os.path.join(self.model_dir_path, '{}-V.ckpt'.format(resume_iters))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
        self.V.load_state_dict(torch.load(V_path, map_location=lambda storage, loc: storage))

    def update_lr(self, g_lr, d_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.v_optimizer.zero_grad()

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        out = torch.zeros(list(labels.size()) + [dim]).to(self.device)
        out.scatter_(len(out.size()) - 1, labels.unsqueeze(-1), 1.)
        return out

    def sample_z(self, batch_size):
        return np.random.normal(0, 1, size=(batch_size, self.z_dim))

    @staticmethod
    def postprocess(inputs, method, temperature=1.):
        def listify(x):
            return x if type(x) == list or type(x) == tuple else [x]

        def delistify(x):
            return x if len(x) > 1 else x[0]

        if method == 'soft_gumbel':
            softmax = [F.gumbel_softmax(e_logits.contiguous().view(-1, e_logits.size(-1))
                                        / temperature, hard=False).view(e_logits.size())
                       for e_logits in listify(inputs)]
        elif method == 'hard_gumbel':
            softmax = [F.gumbel_softmax(e_logits.contiguous().view(-1, e_logits.size(-1))
                                        / temperature, hard=True).view(e_logits.size())
                       for e_logits in listify(inputs)]
        else:
            softmax = [F.softmax(e_logits / temperature, -1)
                       for e_logits in listify(inputs)]

        return [delistify(e) for e in (softmax)]

    def reward(self, mols):
        rr = 1.
        for m in ('logp,sas,qed,unique' if self.metric == 'all' else self.metric).split(','):

            if m == 'np':
                rr *= MolecularMetrics.natural_product_scores(mols, norm=True)
            elif m == 'logp':
                rr *= MolecularMetrics.water_octanol_partition_coefficient_scores(mols, norm=True)
            elif m == 'sas':
                rr *= MolecularMetrics.synthetic_accessibility_score_scores(mols, norm=True)
            elif m == 'qed':
                rr *= MolecularMetrics.quantitative_estimation_druglikeness_scores(mols, norm=True)
            elif m == 'novelty':
                rr *= MolecularMetrics.novel_scores(mols, self.data)
            elif m == 'dc':
                rr *= MolecularMetrics.drugcandidate_scores(mols, self.data)
            elif m == 'unique':
                rr *= MolecularMetrics.unique_scores(mols)
            elif m == 'diversity':
                rr *= MolecularMetrics.diversity_scores(mols, self.data)
            elif m == 'validity':
                rr *= MolecularMetrics.valid_scores(mols)
            else:
                raise RuntimeError('{} is not defined as a metric'.format(m))

        return rr.reshape(-1, 1)

    def train_and_validate(self):
        self.start_time = time.time()

        # Start training from scratch or resume training.
        start_epoch = 0
        if self.resume_epoch is not None:
            start_epoch = self.resume_epoch
            self.restore_model(self.resume_epoch)

        # Start training.
        if self.mode == 'train':
            print('Start training...')
            for i in range(start_epoch, self.num_epochs):
                self.train_or_valid(epoch_i=i, train_val_test='train')
                self.train_or_valid(epoch_i=i, train_val_test='val')
        elif self.mode == 'test':
            assert self.resume_epoch is not None
            self.train_or_valid(epoch_i=start_epoch, train_val_test='val')
        else:
            raise NotImplementedError

    def get_gen_mols(self, n_hat, e_hat, method):
        (edges_hard, nodes_hard) = self.postprocess((e_hat, n_hat), method)
        edges_hard, nodes_hard = torch.max(edges_hard, -1)[1], torch.max(nodes_hard, -1)[1]
        mols = [self.data.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True)
                for e_, n_ in zip(edges_hard, nodes_hard)]
        return mols

    def get_reward(self, n_hat, e_hat, method):
        (edges_hard, nodes_hard) = self.postprocess((e_hat, n_hat), method)
        edges_hard, nodes_hard = torch.max(edges_hard, -1)[1], torch.max(nodes_hard, -1)[1]
        mols = [self.data.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True)
                for e_, n_ in zip(edges_hard, nodes_hard)]
        reward = torch.from_numpy(self.reward(mols)).to(self.device)
        return reward

    def save_checkpoints(self, epoch_i):
        G_path = os.path.join(self.model_dir_path, '{}-G.ckpt'.format(epoch_i + 1))
        D_path = os.path.join(self.model_dir_path, '{}-D.ckpt'.format(epoch_i + 1))
        V_path = os.path.join(self.model_dir_path, '{}-V.ckpt'.format(epoch_i + 1))
        torch.save(self.G.state_dict(), G_path)
        torch.save(self.D.state_dict(), D_path)
        torch.save(self.V.state_dict(), V_path)
        print('Saved model checkpoints into {}...'.format(self.model_dir_path))
        if self.log is not None:
            self.log.info('Saved model checkpoints into {}...'.format(self.model_dir_path))

    def train_or_valid(self, epoch_i, train_val_test='val'):
        # Use RL only for the first several epochs to pursue stability (not used).
        if epoch_i < 0:
            cur_la = 0
        else:
            cur_la = self.la

        # Recordings
        losses = defaultdict(list)
        scores = defaultdict(list)

        # Iterations
        the_step = self.num_steps
        if train_val_test == 'val':
            if self.mode == 'train':
                the_step = 1
            print('[Validating]')

        for a_step in range(the_step):
            if train_val_test == 'val':
                mols, _, _, a, x, _, _, _, _ = self.data.next_validation_batch()
                z = self.sample_z(a.shape[0])
            elif train_val_test == 'train':
                mols, _, _, a, x, _, _, _, _ = self.data.next_train_batch(self.batch_size)
                z = self.sample_z(self.batch_size)
            else:
                raise NotImplementedError

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            a = torch.from_numpy(a).to(self.device).long()  # Adjacency.
            x = torch.from_numpy(x).to(self.device).long()  # Nodes.
            a_tensor = self.label2onehot(a, self.b_dim)
            x_tensor = self.label2onehot(x, self.m_dim)
            z = torch.from_numpy(z).to(self.device).float()

            # Current steps
            cur_step = self.num_steps * epoch_i + a_step
            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # Compute losses with real inputs.
            logits_real, features_real = self.D(a_tensor, None, x_tensor)

            # Z-to-target
            edges_logits, nodes_logits = self.G(z)
            # Postprocess with Gumbel softmax
            (edges_hat, nodes_hat) = self.postprocess((edges_logits, nodes_logits), self.post_method)
            logits_fake, features_fake = self.D(edges_hat, None, nodes_hat)

            # Compute losses for gradient penalty.
            eps = torch.rand(logits_real.size(0), 1, 1, 1).to(self.device)
            x_int0 = (eps * a_tensor + (1. - eps) * edges_hat).requires_grad_(True)
            x_int1 = (eps.squeeze(-1) * x_tensor + (1. - eps.squeeze(-1)) * nodes_hat).requires_grad_(True)
            grad0, grad1 = self.D(x_int0, None, x_int1)
            grad_penalty = self.gradient_penalty(grad0, x_int0) + self.gradient_penalty(grad1, x_int1)

            d_loss_real = torch.mean(logits_real)
            d_loss_fake = torch.mean(logits_fake)
            loss_D = -d_loss_real + d_loss_fake + self.la_gp * grad_penalty

            if cur_la > 0:
                losses['l_D/R'].append(d_loss_real.item())
                losses['l_D/F'].append(d_loss_fake.item())
                losses['l_D'].append(loss_D.item())

            # Optimise discriminator.
            if train_val_test == 'train' and cur_step % self.n_critic != 0 and cur_la > 0:
                self.reset_grad()
                loss_D.backward()
                self.d_optimizer.step()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #

            # Z-to-target
            edges_logits, nodes_logits = self.G(z)
            # Postprocess with Gumbel softmax
            (edges_hat, nodes_hat) = self.postprocess((edges_logits, nodes_logits), self.post_method)
            logits_fake, features_fake = self.D(edges_hat, None, nodes_hat)

            # Value losses
            value_logit_real, _ = self.V(a_tensor, None, x_tensor, torch.sigmoid)
            value_logit_fake, _ = self.V(edges_hat, None, nodes_hat, torch.sigmoid)

            # Feature matching loss. Not used anywhere in the PyTorch version;
            # kept here for consistency with the TF code.
            f_loss = (torch.mean(features_real, 0) - torch.mean(features_fake, 0)) ** 2

            # Real Reward
            reward_r = torch.from_numpy(self.reward(mols)).to(self.device)
            # Fake Reward
            reward_f = self.get_reward(nodes_hat, edges_hat, self.post_method)

            # Losses Update
            loss_G = -logits_fake
            # Original TF loss_V. Here we use absolute values instead of the squared ones.
            # loss_V = (value_logit_real - reward_r) ** 2 + (value_logit_fake - reward_f) ** 2
            loss_V = torch.abs(value_logit_real - reward_r) + torch.abs(value_logit_fake - reward_f)
            loss_RL = -value_logit_fake

            loss_G = torch.mean(loss_G)
            loss_V = torch.mean(loss_V)
            loss_RL = torch.mean(loss_RL)
            losses['l_G'].append(loss_G.item())
            losses['l_RL'].append(loss_RL.item())
            losses['l_V'].append(loss_V.item())

            alpha = torch.abs(loss_G.detach() / loss_RL.detach()).detach()
            train_step_G = cur_la * loss_G + (1 - cur_la) * alpha * loss_RL

            train_step_V = loss_V

            if train_val_test == 'train':
                self.reset_grad()

                # Optimise generator.
                if cur_step % self.n_critic == 0:
                    train_step_G.backward(retain_graph=True)
                    self.g_optimizer.step()

                # Optimise value network.
                if cur_step % self.n_critic == 0:
                    train_step_V.backward()
                    self.v_optimizer.step()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Get scores.
            if train_val_test == 'val':
                mols = self.get_gen_mols(nodes_logits, edges_logits, self.post_method)
                m0, m1 = all_scores(mols, self.data, norm=True)  # 'mols' is output of Fake Reward
                for k, v in m1.items():
                    scores[k].append(v)
                for k, v in m0.items():
                    scores[k].append(np.array(v)[np.nonzero(v)].mean())

                # Save checkpoints.
                if self.mode == 'train':
                    if (epoch_i + 1) % self.model_save_step == 0:
                        self.save_checkpoints(epoch_i=epoch_i)

                # Saving molecule images.
                mol_f_name = os.path.join(self.img_dir_path, 'mol-{}.png'.format(epoch_i))
                save_mol_img(mols, mol_f_name, is_test=self.mode == 'test')

                # Print out training information.
                et = time.time() - self.start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]:".format(et, epoch_i + 1, self.num_epochs)

                is_first = True
                for tag, value in losses.items():
                    if is_first:
                        log += "\n{}: {:.2f}".format(tag, np.mean(value))
                        is_first = False
                    else:
                        log += ", {}: {:.2f}".format(tag, np.mean(value))
                is_first = True
                for tag, value in scores.items():
                    if is_first:
                        log += "\n{}: {:.2f}".format(tag, np.mean(value))
                        is_first = False
                    else:
                        log += ", {}: {:.2f}".format(tag, np.mean(value))
                print(log)

                if self.log is not None:
                    self.log.info(log)
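
The postprocess helper above leans on F.gumbel_softmax to turn the generator's edge and node logits into relaxed ('soft_gumbel') or discrete ('hard_gumbel') categorical samples before they are fed to the discriminator or decoded into molecules. A standalone sketch on random logits (all shapes are illustrative):

# Standalone illustration of the soft/hard Gumbel-Softmax post-processing.
import torch
import torch.nn.functional as F

edges_logits = torch.randn(2, 9, 9, 5)  # (batch, vertexes, vertexes, bond types)
flat = edges_logits.contiguous().view(-1, edges_logits.size(-1))

soft = F.gumbel_softmax(flat, hard=False).view(edges_logits.size())
hard = F.gumbel_softmax(flat, hard=True).view(edges_logits.size())

print(soft[0, 0, 0].sum())  # rows of 'soft' sum to 1 (a relaxed one-hot)
print(hard[0, 0, 0])        # rows of 'hard' are exact one-hot vectors
# torch.max(hard, -1)[1] then recovers integer bond indices for matrices2mol.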
Example #7
class Solver(object):
    """Solver for training and testing StarGAN."""
    def __init__(self, config, log=None):
        """Initialize configurations."""

        # Log
        self.log = log

        # Data loader.
        self.data = SparseMolecularDataset()
        self.data.load(config.mol_data_dir)

        # Model configurations.
        self.z_dim = config.z_dim
        self.m_dim = self.data.atom_num_types
        self.b_dim = self.data.bond_num_types
        self.f_dim = self.data.features
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.lambda_wgan = config.lambda_wgan
        self.lambda_rec = config.lambda_rec
        self.post_method = config.post_method

        self.metric = 'validity,qed'

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_epochs = config.num_epochs
        self.num_steps = (len(self.data) // self.batch_size)
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.dropout_rate = config.dropout
        self.n_critic = config.n_critic
        self.resume_epoch = config.resume_epoch

        # Training or testing.
        self.mode = config.mode

        # Miscellaneous.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        print('Device: ', self.device)

        # Directories.
        self.log_dir_path = config.log_dir_path
        self.model_dir_path = config.model_dir_path
        self.img_dir_path = config.img_dir_path

        # Step size.
        self.model_save_step = config.model_save_step

        # VAE KL weight.
        self.kl_la = 1.

        # Build the model.
        self.build_model()

    def build_model(self):
        """Create an encoder and a decoder."""
        self.encoder = EncoderVAE(self.d_conv_dim,
                                  self.m_dim,
                                  self.b_dim - 1,
                                  self.z_dim,
                                  with_features=True,
                                  f_dim=self.f_dim,
                                  dropout_rate=self.dropout_rate).to(
                                      self.device)
        self.decoder = Generator(self.g_conv_dim, self.z_dim,
                                 self.data.vertexes, self.data.bond_num_types,
                                 self.data.atom_num_types,
                                 self.dropout_rate).to(self.device)
        self.V = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim - 1,
                               self.dropout_rate).to(self.device)

        self.vae_optimizer = torch.optim.RMSprop(
            list(self.encoder.parameters()) + list(self.decoder.parameters()),
            self.g_lr)
        self.v_optimizer = torch.optim.RMSprop(self.V.parameters(), self.d_lr)

        self.print_network(self.encoder, 'Encoder', self.log)
        self.print_network(self.decoder, 'Decoder', self.log)
        self.print_network(self.V, 'Value', self.log)

    @staticmethod
    def print_network(model, name, log=None):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))
        if log is not None:
            log.info(model)
            log.info(name)
            log.info("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator."""
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        enc_path = os.path.join(self.model_dir_path,
                                '{}-encoder.ckpt'.format(resume_iters))
        dec_path = os.path.join(self.model_dir_path,
                                '{}-decoder.ckpt'.format(resume_iters))
        V_path = os.path.join(self.model_dir_path,
                              '{}-V.ckpt'.format(resume_iters))
        self.encoder.load_state_dict(
            torch.load(enc_path, map_location=lambda storage, loc: storage))
        self.decoder.load_state_dict(
            torch.load(dec_path, map_location=lambda storage, loc: storage))
        self.V.load_state_dict(
            torch.load(V_path, map_location=lambda storage, loc: storage))

    def update_lr(self, g_lr, d_lr):
        """Decay the learning rates of the VAE and value-network optimizers."""
        for param_group in self.vae_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.v_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.vae_optimizer.zero_grad()
        self.v_optimizer.zero_grad()

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        out = torch.zeros(list(labels.size()) + [dim]).to(self.device)
        out.scatter_(len(out.size()) - 1, labels.unsqueeze(-1), 1.)
        return out

    def sample_z(self, batch_size):
        return np.random.normal(0, 1, size=(batch_size, self.z_dim))

    @staticmethod
    def postprocess_logits(inputs, method, temperature=1.):
        def listify(x):
            return x if type(x) == list or type(x) == tuple else [x]

        def delistify(x):
            return x if len(x) > 1 else x[0]

        if method == 'soft_gumbel':
            softmax = [
                F.gumbel_softmax(
                    e_logits.contiguous().view(-1, e_logits.size(-1)) /
                    temperature,
                    hard=False).view(e_logits.size())
                for e_logits in listify(inputs)
            ]
        elif method == 'hard_gumbel':
            softmax = [
                F.gumbel_softmax(
                    e_logits.contiguous().view(-1, e_logits.size(-1)) /
                    temperature,
                    hard=True).view(e_logits.size())
                for e_logits in listify(inputs)
            ]
        else:
            softmax = [
                F.softmax(e_logits / temperature, -1)
                for e_logits in listify(inputs)
            ]

        return [delistify(e) for e in (softmax)]

    def reward(self, mols):
        rr = 1.
        for m in ('logp,sas,qed,unique'
                  if self.metric == 'all' else self.metric).split(','):

            if m == 'np':
                rr *= MolecularMetrics.natural_product_scores(mols, norm=True)
            elif m == 'logp':
                rr *= MolecularMetrics.water_octanol_partition_coefficient_scores(
                    mols, norm=True)
            elif m == 'sas':
                rr *= MolecularMetrics.synthetic_accessibility_score_scores(
                    mols, norm=True)
            elif m == 'qed':
                rr *= MolecularMetrics.quantitative_estimation_druglikeness_scores(
                    mols, norm=True)
            elif m == 'novelty':
                rr *= MolecularMetrics.novel_scores(mols, self.data)
            elif m == 'dc':
                rr *= MolecularMetrics.drugcandidate_scores(mols, self.data)
            elif m == 'unique':
                rr *= MolecularMetrics.unique_scores(mols)
            elif m == 'diversity':
                rr *= MolecularMetrics.diversity_scores(mols, self.data)
            elif m == 'validity':
                rr *= MolecularMetrics.valid_scores(mols)
            else:
                raise RuntimeError('{} is not defined as a metric'.format(m))

        return rr.reshape(-1, 1)

    def train_and_validate(self):
        self.start_time = time.time()

        # Start training from scratch or resume training.
        start_epoch = 0
        if self.resume_epoch:
            start_epoch = self.resume_epoch
            self.restore_model(self.resume_epoch)

        # Start training.
        if self.mode == 'train':
            print('Start training...')
            for i in range(start_epoch, self.num_epochs):
                self.train_or_valid(epoch_i=i, train_val_test='train')
                self.train_or_valid(epoch_i=i, train_val_test='val')
                self.train_or_valid(epoch_i=i, train_val_test='sample')
        elif self.mode == 'test':
            assert self.resume_epoch is not None
            self.train_or_valid(epoch_i=start_epoch, train_val_test='sample')
            self.train_or_valid(epoch_i=start_epoch, train_val_test='val')
        else:
            raise NotImplementedError

    def get_reconstruction_loss(self, n_hat, n, e_hat, e):
        # This (commented-out) loss accounts for the imbalance between nodes and edges.
        # In practice, however, it did not work well.
        # n_loss = torch.nn.CrossEntropyLoss(reduction='none')(n_hat.view(-1, self.m_dim), n.view(-1))
        # n_loss_ = n_loss.view(n.shape)
        # e_loss = torch.nn.CrossEntropyLoss(reduction='none')(e_hat.reshape((-1, self.b_dim)), e.view(-1))
        # e_loss_ = e_loss.view(e.shape)
        # loss_ = e_loss_ + n_loss_.unsqueeze(-1)
        # reconstruction_loss = torch.mean(loss_)
        # return reconstruction_loss

        n_loss = torch.nn.CrossEntropyLoss(reduction='mean')(n_hat.view(
            -1, self.m_dim), n.view(-1))
        e_loss = torch.nn.CrossEntropyLoss(reduction='mean')(e_hat.reshape(
            (-1, self.b_dim)), e.view(-1))
        reconstruction_loss = n_loss + e_loss
        return reconstruction_loss

    @staticmethod
    def get_kl_loss(mu, logvar):
        kld_loss = torch.mean(
            -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp(), dim=1), dim=0)
        return kld_loss

    def get_gen_mols(self, n_hat, e_hat, method):
        (edges_hard, nodes_hard) = self.postprocess_logits((e_hat, n_hat),
                                                           method)
        edges_hard, nodes_hard = torch.max(edges_hard,
                                           -1)[1], torch.max(nodes_hard, -1)[1]
        mols = [
            self.data.matrices2mol(n_.data.cpu().numpy(),
                                   e_.data.cpu().numpy(),
                                   strict=True)
            for e_, n_ in zip(edges_hard, nodes_hard)
        ]
        return mols

    def get_reward(self, n_hat, e_hat, method):
        (edges_hard, nodes_hard) = self.postprocess_logits((e_hat, n_hat),
                                                           method)
        edges_hard, nodes_hard = torch.max(edges_hard,
                                           -1)[1], torch.max(nodes_hard, -1)[1]
        mols = [
            self.data.matrices2mol(n_.data.cpu().numpy(),
                                   e_.data.cpu().numpy(),
                                   strict=True)
            for e_, n_ in zip(edges_hard, nodes_hard)
        ]
        reward = torch.from_numpy(self.reward(mols)).to(self.device)
        return reward

    def save_checkpoints(self, epoch_i):
        enc_path = os.path.join(self.model_dir_path,
                                '{}-encoder.ckpt'.format(epoch_i + 1))
        dec_path = os.path.join(self.model_dir_path,
                                '{}-decoder.ckpt'.format(epoch_i + 1))
        V_path = os.path.join(self.model_dir_path,
                              '{}-V.ckpt'.format(epoch_i + 1))
        torch.save(self.encoder.state_dict(), enc_path)
        torch.save(self.decoder.state_dict(), dec_path)
        torch.save(self.V.state_dict(), V_path)
        print('Saved model checkpoints into {}...'.format(self.model_dir_path))
        if self.log is not None:
            self.log.info('Saved model checkpoints into {}...'.format(
                self.model_dir_path))

    def get_scores(self, mols, to_print=False):
        scores = defaultdict(list)
        m0, m1 = all_scores(mols, self.data,
                            norm=True)  # 'mols' is output of Fake Reward
        for k, v in m1.items():
            scores[k].append(v)
        for k, v in m0.items():
            scores[k].append(np.array(v)[np.nonzero(v)].mean())

        if to_print:
            log = ""
            is_first = True
            for tag, value in scores.items():
                if is_first:
                    log += "{}: {:.2f}".format(tag, np.mean(value))
                    is_first = False
                else:
                    log += ", {}: {:.2f}".format(tag, np.mean(value))
            print(log)
            return scores, log

        return scores

    def train_or_valid(self, epoch_i, train_val_test='val'):
        # Recordings
        losses = defaultdict(list)

        the_step = self.num_steps
        if train_val_test == 'val':
            if self.mode == 'train':
                the_step = 1
            print('[Validating]')

        if train_val_test == 'sample':
            if self.mode == 'train':
                the_step = 1
            print('[Sampling]')

        for a_step in range(the_step):
            z = None
            if train_val_test == 'val':
                mols, _, _, a, x, _, f, _, _ = self.data.next_validation_batch(
                )
            elif train_val_test == 'train':
                mols, _, _, a, x, _, f, _, _ = self.data.next_train_batch(
                    self.batch_size)
            elif train_val_test == 'sample':
                z = self.sample_z(self.batch_size)
                z = torch.from_numpy(z).to(self.device).float()
            else:
                raise NotImplementedError

            if train_val_test == 'train' or train_val_test == 'val':
                a = torch.from_numpy(a).to(self.device).long()  # Adjacency.
                x = torch.from_numpy(x).to(self.device).long()  # Nodes.
                a_tensor = self.label2onehot(a, self.b_dim)
                x_tensor = self.label2onehot(x, self.m_dim)
                f = torch.from_numpy(f).to(self.device).float()

            if train_val_test == 'train' or train_val_test == 'val':
                z, z_mu, z_logvar = self.encoder(a_tensor, f, x_tensor)
            edges_logits, nodes_logits = self.decoder(z)
            (edges_hat, nodes_hat) = self.postprocess_logits(
                (edges_logits, nodes_logits), self.post_method)

            if train_val_test == 'train' or train_val_test == 'val':
                recon_loss = self.get_reconstruction_loss(
                    nodes_logits, x, edges_logits, a)
                kl_loss = self.get_kl_loss(z_mu, z_logvar)
                loss_vae = recon_loss + self.kl_la * kl_loss

                # Real Reward
                reward_r = torch.from_numpy(self.reward(mols)).to(self.device)
                # Fake Reward
                reward_f = self.get_reward(nodes_logits, edges_logits,
                                           'hard_gumbel')

                # Value loss
                value_logit_real, _ = self.V(a_tensor, None, x_tensor,
                                             torch.sigmoid)
                value_logit_fake, _ = self.V(edges_hat, None, nodes_hat,
                                             torch.sigmoid)
                loss_v = torch.mean((value_logit_real - reward_r)**2 +
                                    (value_logit_fake - reward_f)**2)
                loss_rl = torch.mean(-value_logit_fake)
                alpha = torch.abs(loss_vae.detach() / loss_rl.detach())
                loss_rl *= alpha

                vae_loss_train = self.lambda_wgan * loss_vae + (
                    1 - self.lambda_wgan) * loss_rl
                # vae_loss_train = loss_vae
                losses['l_Rec'].append(recon_loss.item())
                losses['l_KL'].append(kl_loss.item())
                losses['l_VAE'].append(loss_vae.item())
                losses['l_RL'].append(loss_rl.item())
                losses['l_V'].append(loss_v.item())

                if train_val_test == 'train':
                    self.reset_grad()
                    vae_loss_train.backward(retain_graph=True)
                    loss_v.backward()
                    self.vae_optimizer.step()
                    self.v_optimizer.step()

            if train_val_test == 'sample':
                mols = self.get_gen_mols(nodes_logits, edges_logits,
                                         'hard_gumbel')
                scores, mol_log = self.get_scores(mols, to_print=True)

                # Saving molecule images.
                mol_f_name = os.path.join(self.img_dir_path,
                                          'sample-mol-{}.png'.format(epoch_i))
                save_mol_img(mols, mol_f_name, is_test=self.mode == 'test')

                if self.log is not None:
                    self.log.info(mol_log)

            if train_val_test == 'val':
                mols = self.get_gen_mols(nodes_logits, edges_logits,
                                         'hard_gumbel')
                scores = self.get_scores(mols)

                # Save checkpoints.
                if self.mode == 'train':
                    if (epoch_i + 1) % self.model_save_step == 0:
                        self.save_checkpoints(epoch_i=epoch_i)

                # Saving molecule images.
                mol_f_name = os.path.join(self.img_dir_path,
                                          'mol-{}.png'.format(epoch_i))
                save_mol_img(mols, mol_f_name, is_test=self.mode == 'test')

                # Print out training information.
                et = time.time() - self.start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]:".format(
                    et, epoch_i + 1, self.num_epochs)

                is_first = True
                for tag, value in losses.items():
                    if is_first:
                        log += "\n{}: {:.2f}".format(tag, np.mean(value))
                        is_first = False
                    else:
                        log += ", {}: {:.2f}".format(tag, np.mean(value))
                is_first = True
                for tag, value in scores.items():
                    if is_first:
                        log += "\n{}: {:.2f}".format(tag, np.mean(value))
                        is_first = False
                    else:
                        log += ", {}: {:.2f}".format(tag, np.mean(value))
                print(log)

                if self.log is not None:
                    self.log.info(log)
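
get_kl_loss above is the closed-form KL divergence between the encoder's diagonal Gaussian q(z|x) = N(mu, diag(exp(logvar))) and the standard normal prior. A small self-contained check against torch.distributions (shapes and values are arbitrary):

# Numerical check of the closed-form KL term used in get_kl_loss above.
import torch
from torch.distributions import Normal, kl_divergence

def get_kl_loss(mu, logvar):
    return torch.mean(
        -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1), dim=0)

mu, logvar = torch.randn(4, 8), torch.randn(4, 8)     # (batch, z_dim)
closed_form = get_kl_loss(mu, logvar)
reference = kl_divergence(
    Normal(mu, (0.5 * logvar).exp()),                 # sigma = exp(logvar / 2)
    Normal(torch.zeros_like(mu), torch.ones_like(mu))
).sum(dim=1).mean()
print(closed_form.item(), reference.item())           # the two values agree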