class SparseMolecular(data.Dataset):
    """Map-style ``Dataset`` adapter around a ``SparseMolecularDataset``."""

    def __init__(self, data_dir):
        """Load the molecular dataset found at ``data_dir``.

        Args:
            data_dir: location passed unchanged to ``SparseMolecularDataset.load``.
        """
        self.data = SparseMolecularDataset()
        self.data.load(data_dir)

    def __getitem__(self, index):
        """Return every stored encoding of the sample at ``index``.

        Returns:
            A 10-tuple ``(index, data, smiles, data_S, data_A, data_X, data_D,
            data_F, data_Le, data_Lv)``; each element after ``index`` is the
            ``index``-th entry of the same-named attribute of the wrapped
            dataset.
        """
        ds = self.data
        return (index, ds.data[index], ds.smiles[index], ds.data_S[index],
                ds.data_A[index], ds.data_X[index], ds.data_D[index],
                ds.data_F[index], ds.data_Le[index], ds.data_Lv[index])

    def __len__(self):
        """Return ``len()`` of the wrapped dataset."""
        return len(self.data)
class Solver(object):
    """Solver for training and testing a molecular GAN.

    Owns a generator ``G``, a discriminator ``D`` and a value (reward)
    network ``V``. ``G`` and ``V`` are optimized together by ``g_optimizer``;
    ``D`` by ``d_optimizer`` (both Adam).
    """

    def __init__(self, config):
        """Read all settings from ``config``, load the data and build the models.

        Args:
            config: configuration object; every attribute read below must exist.
        """
        # Data loader.
        self.data = SparseMolecularDataset()
        self.data.load(config.mol_data_dir)

        # Model configurations.
        self.z_dim = config.qubits  # latent size is taken from ``config.qubits``
        self.m_dim = self.data.atom_num_types
        self.b_dim = self.data.bond_num_types
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.post_method = config.post_method

        self.metric = 'validity,sas'

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.dropout = config.dropout
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create G, D, V and their optimizers, then move the networks to ``device``."""
        self.G = Generator(self.g_conv_dim, self.z_dim, self.data.vertexes,
                           self.data.bond_num_types, self.data.atom_num_types,
                           self.dropout)
        self.D = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim, self.dropout)
        self.V = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim, self.dropout)

        # V shares the generator optimizer: both are trained on ``g_loss``.
        self.g_optimizer = torch.optim.Adam(
            list(self.G.parameters()) + list(self.V.parameters()),
            self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        self.G.to(self.device)
        self.D.to(self.device)
        self.V.to(self.device)

    def print_network(self, model, name):
        """Print ``model``, its ``name`` and its total parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Load the G/D/V checkpoints saved at iteration ``resume_iters``."""
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
        V_path = os.path.join(self.model_save_dir, '{}-V.ckpt'.format(resume_iters))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
        self.V.load_state_dict(torch.load(V_path, map_location=lambda storage, loc: storage))

    def build_tensorboard(self):
        """Build a TensorBoard ``SummaryWriter`` that writes into ``log_dir``."""
        # ``torch.utils.tensorboard`` has no ``logger`` module; ``SummaryWriter``
        # is its public entry point. Imported lazily so the tensorboard package
        # is only needed when tensorboard logging is enabled.
        from torch.utils.tensorboard import SummaryWriter
        self.logger = SummaryWriter(self.log_dir)

    def update_lr(self, g_lr, d_lr):
        """Set the learning rate of every param group of both optimizers."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers of both optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def denorm(self, x):
        """Map values from [-1, 1] to [0, 1], clamping anything outside."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute the gradient penalty mean((||dy/dx||_2 - 1) ** 2) over the batch."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=weight,
                                   retain_graph=True, create_graph=True,
                                   only_inputs=True)[0]
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)

    def label2onehot(self, labels, dim):
        """Convert integer label indices to one-hot vectors along a new last axis."""
        out = torch.zeros(list(labels.size()) + [dim]).to(self.device)
        out.scatter_(len(out.size()) - 1, labels.unsqueeze(-1), 1.)
        return out

    def classification_loss(self, logit, target, dataset='CelebA'):
        """Compute binary (``'CelebA'``) or softmax (``'RaFD'``) cross entropy.

        Returns ``None`` for any other ``dataset`` value.
        """
        if dataset == 'CelebA':
            # ``reduction='sum'`` replaces the deprecated ``size_average=False``.
            return F.binary_cross_entropy_with_logits(
                logit, target, reduction='sum') / logit.size(0)
        elif dataset == 'RaFD':
            return F.cross_entropy(logit, target)

    def sample_z(self, batch_size):
        """Draw a ``(batch_size, z_dim)`` numpy array of standard-normal noise."""
        return np.random.normal(0, 1, size=(batch_size, self.z_dim))

    def postprocess(self, inputs, method, temperature=1.):
        """Turn logits into (relaxed) one-hot distributions along the last axis.

        Args:
            inputs: a tensor, or a list/tuple of tensors, of logits.
            method: ``'soft_gumbel'`` or ``'hard_gumbel'``; any other value
                uses a plain softmax.
            temperature: divisor applied to the logits.

        Returns:
            A list with one tensor per input, each with that input's shape.
        """
        tensors = inputs if isinstance(inputs, (list, tuple)) else [inputs]

        def gumbel(e_logits, hard):
            flat = e_logits.contiguous().view(-1, e_logits.size(-1)) / temperature
            return F.gumbel_softmax(flat, hard=hard).view(e_logits.size())

        # Outputs are returned as-is so the batch dimension is kept even when
        # the batch has a single element.
        if method == 'soft_gumbel':
            return [gumbel(e_logits, hard=False) for e_logits in tensors]
        if method == 'hard_gumbel':
            return [gumbel(e_logits, hard=True) for e_logits in tensors]
        return [F.softmax(e_logits / temperature, -1) for e_logits in tensors]

    def _logits_to_mols(self, edges_logits, nodes_logits):
        """Decode generator logits into molecules through a hard Gumbel sample."""
        edges_hard, nodes_hard = self.postprocess((edges_logits, nodes_logits), 'hard_gumbel')
        edges_hard = torch.max(edges_hard, -1)[1]
        nodes_hard = torch.max(nodes_hard, -1)[1]
        return [self.data.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True)
                for e_, n_ in zip(edges_hard, nodes_hard)]

    def reward(self, mols):
        """Return the product of the ``self.metric`` scores of ``mols`` as a column array.

        Raises:
            RuntimeError: if ``self.metric`` names an unknown metric.
        """
        rr = 1.
        for m in ('logp,sas,qed,unique' if self.metric == 'all' else self.metric).split(','):
            if m == 'np':
                rr *= MolecularMetrics.natural_product_scores(mols, norm=True)
            elif m == 'logp':
                rr *= MolecularMetrics.water_octanol_partition_coefficient_scores(mols, norm=True)
            elif m == 'sas':
                rr *= MolecularMetrics.synthetic_accessibility_score_scores(mols, norm=True)
            elif m == 'qed':
                rr *= MolecularMetrics.quantitative_estimation_druglikeness_scores(mols, norm=True)
            elif m == 'novelty':
                rr *= MolecularMetrics.novel_scores(mols, self.data)
            elif m == 'dc':
                rr *= MolecularMetrics.drugcandidate_scores(mols, self.data)
            elif m == 'unique':
                rr *= MolecularMetrics.unique_scores(mols)
            elif m == 'diversity':
                rr *= MolecularMetrics.diversity_scores(mols, self.data)
            elif m == 'validity':
                rr *= MolecularMetrics.valid_scores(mols)
            else:
                raise RuntimeError('{} is not defined as a metric'.format(m))

        return rr.reshape(-1, 1)

    def train(self):
        """Run the training loop from ``resume_iters`` (or 0) up to ``num_iters``."""
        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):
            # NOTE(review): on logging iterations the validation batch is also
            # used for the D/G updates below; confirm this is intended.
            if (i + 1) % self.log_step == 0:
                mols, _, _, a, x, _, _, _, _ = self.data.next_validation_batch()
                z = self.sample_z(a.shape[0])
                print('[Valid]', '')
            else:
                mols, _, _, a, x, _, _, _, _ = self.data.next_train_batch(self.batch_size)
                z = self.sample_z(self.batch_size)

            # =========================================================== #
            #                 1. Preprocess input data                    #
            # =========================================================== #
            a = torch.from_numpy(a).to(self.device).long()  # Adjacency.
            x = torch.from_numpy(x).to(self.device).long()  # Nodes.
            a_tensor = self.label2onehot(a, self.b_dim)
            x_tensor = self.label2onehot(x, self.m_dim)
            z = torch.from_numpy(z).to(self.device).float()

            # =========================================================== #
            #                 2. Train the discriminator                  #
            # =========================================================== #
            # Compute loss with real inputs.
            logits_real, _ = self.D(a_tensor, None, x_tensor)
            d_loss_real = -torch.mean(logits_real)

            # Compute loss with fake inputs.
            edges_logits, nodes_logits = self.G(z)
            (edges_hat, nodes_hat) = self.postprocess(
                (edges_logits, nodes_logits), self.post_method)
            logits_fake, _ = self.D(edges_hat, None, nodes_hat)
            d_loss_fake = torch.mean(logits_fake)

            # Compute loss for gradient penalty on real/fake interpolations.
            eps = torch.rand(logits_real.size(0), 1, 1, 1).to(self.device)
            x_int0 = (eps * a_tensor + (1. - eps) * edges_hat).requires_grad_(True)
            x_int1 = (eps.squeeze(-1) * x_tensor
                      + (1. - eps.squeeze(-1)) * nodes_hat).requires_grad_(True)
            grad0, grad1 = self.D(x_int0, None, x_int1)
            d_loss_gp = (self.gradient_penalty(grad0, x_int0)
                         + self.gradient_penalty(grad1, x_int1))

            # Backward and optimize.
            d_loss = d_loss_fake + d_loss_real + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_gp'] = d_loss_gp.item()

            # =========================================================== #
            #                   3. Train the generator                    #
            # =========================================================== #
            if (i + 1) % self.n_critic == 0:
                # Z-to-target
                edges_logits, nodes_logits = self.G(z)
                # Postprocess with Gumbel softmax
                (edges_hat, nodes_hat) = self.postprocess(
                    (edges_logits, nodes_logits), self.post_method)
                logits_fake, _ = self.D(edges_hat, None, nodes_hat)
                g_loss_fake = -torch.mean(logits_fake)

                # Real Reward
                rewardR = torch.from_numpy(self.reward(mols)).to(self.device)
                # Fake Reward (``mols`` now holds generated molecules)
                mols = self._logits_to_mols(edges_logits, nodes_logits)
                rewardF = torch.from_numpy(self.reward(mols)).to(self.device)

                # Value loss
                value_logit_real, _ = self.V(a_tensor, None, x_tensor, torch.sigmoid)
                value_logit_fake, _ = self.V(edges_hat, None, nodes_hat, torch.sigmoid)
                g_loss_value = torch.mean((value_logit_real - rewardR) ** 2
                                          + (value_logit_fake - rewardF) ** 2)

                # Backward and optimize.
                g_loss = g_loss_fake + g_loss_value
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_value'] = g_loss_value.item()

            # =========================================================== #
            #                      4. Miscellaneous                       #
            # =========================================================== #
            # Print out training information.
            if (i + 1) % self.log_step == 0:
                # int() keeps whole seconds; slicing the str() would cut real
                # digits whenever the microsecond part happens to be zero.
                et = str(datetime.timedelta(seconds=int(time.time() - start_time)))
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.num_iters)

                # Log update. ``mols`` holds generated molecules only if the
                # generator step ran this iteration; otherwise the real batch.
                m0, m1 = all_scores(mols, self.data, norm=True)
                m0 = {k: np.array(v)[np.nonzero(v)].mean() for k, v in m0.items()}
                m0.update(m1)
                loss.update(m0)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.add_scalar(tag, value, i + 1)

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i + 1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i + 1))
                V_path = os.path.join(self.model_save_dir, '{}-V.ckpt'.format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                torch.save(self.V.state_dict(), V_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates linearly over the last ``num_iters_decay`` iterations.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    def test(self):
        """Score one batch of molecules generated from the ``test_iters`` checkpoint.

        Returns:
            dict mapping score names to values; the same values are printed.
        """
        # Load the trained generator.
        self.restore_model(self.test_iters)

        with torch.no_grad():
            _, _, _, a, _, _, _, _, _ = self.data.next_test_batch()
            # Convert the noise to a float tensor on the model device, as train() does.
            z = torch.from_numpy(self.sample_z(a.shape[0])).to(self.device).float()
            edges_logits, nodes_logits = self.G(z)
            mols = self._logits_to_mols(edges_logits, nodes_logits)

        m0, m1 = all_scores(mols, self.data, norm=True)
        m0 = {k: np.array(v)[np.nonzero(v)].mean() for k, v in m0.items()}
        m0.update(m1)

        log = "[Test] Iteration [{}]".format(self.test_iters)
        for tag, value in m0.items():
            log += ", {}: {:.4f}".format(tag, value)
        print(log)
        return m0
class Solver(object):
    """Solver for training, validating and testing a molecular GAN.

    Owns a generator ``G``, a discriminator ``D`` and a value (reward)
    network ``V``, each with its own RMSprop optimizer.
    """

    def __init__(self, config, log=None):
        """Read all settings from ``config``, load the data and build the models.

        Args:
            config: configuration object; every attribute read below must exist.
            log: optional logger with an ``info`` method; messages are mirrored to it.
        """
        # Log
        self.log = log

        # Data loader.
        self.data = SparseMolecularDataset()
        self.data.load(config.mol_data_dir)

        # Model configurations.
        self.z_dim = config.z_dim
        self.m_dim = self.data.atom_num_types
        self.b_dim = self.data.bond_num_types
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        # Weight of the WGAN loss vs. the RL loss in the generator objective.
        self.la = config.lambda_wgan
        self.lambda_rec = config.lambda_rec
        self.la_gp = config.lambda_gp
        self.post_method = config.post_method

        self.metric = 'validity,qed'

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_epochs = config.num_epochs
        self.num_steps = (len(self.data) // self.batch_size)
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.dropout = config.dropout
        # Without the WGAN term the discriminator is never trained, so every
        # step can be a generator step.
        if self.la > 0:
            self.n_critic = config.n_critic
        else:
            self.n_critic = 1
        self.resume_epoch = config.resume_epoch

        # Training or testing.
        self.mode = config.mode

        # Miscellaneous.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print('Device: ', self.device)

        # Directories.
        self.log_dir_path = config.log_dir_path
        self.model_dir_path = config.model_dir_path
        self.img_dir_path = config.img_dir_path

        # Step size.
        self.model_save_step = config.model_save_step

        # Build the model.
        self.build_model()

    def build_model(self):
        """Create G, D, V and their RMSprop optimizers, then move them to ``device``."""
        self.G = Generator(self.g_conv_dim, self.z_dim, self.data.vertexes,
                           self.data.bond_num_types, self.data.atom_num_types,
                           self.dropout)
        self.D = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim - 1, self.dropout)
        self.V = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim - 1, self.dropout)

        self.g_optimizer = torch.optim.RMSprop(self.G.parameters(), self.g_lr)
        self.d_optimizer = torch.optim.RMSprop(self.D.parameters(), self.d_lr)
        self.v_optimizer = torch.optim.RMSprop(self.V.parameters(), self.g_lr)
        self.print_network(self.G, 'G', self.log)
        self.print_network(self.D, 'D', self.log)
        self.print_network(self.V, 'V', self.log)

        self.G.to(self.device)
        self.D.to(self.device)
        self.V.to(self.device)

    @staticmethod
    def print_network(model, name, log=None):
        """Print (and optionally log) ``model``, its ``name`` and its parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))
        if log is not None:
            log.info(model)
            log.info(name)
            log.info("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Load the G/D/V checkpoints saved under the index ``resume_iters``."""
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_dir_path, '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_dir_path, '{}-D.ckpt'.format(resume_iters))
        V_path = os.path.join(self.model_dir_path, '{}-V.ckpt'.format(resume_iters))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
        self.V.load_state_dict(torch.load(V_path, map_location=lambda storage, loc: storage))

    def update_lr(self, g_lr, d_lr):
        """Set the learning rates of the generator and discriminator optimizers."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers of all three optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.v_optimizer.zero_grad()

    def gradient_penalty(self, y, x):
        """Compute the gradient penalty mean((||dy/dx||_2 - 1) ** 2) over the batch."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=weight,
                                   retain_graph=True, create_graph=True,
                                   only_inputs=True)[0]
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)

    def label2onehot(self, labels, dim):
        """Convert integer label indices to one-hot vectors along a new last axis."""
        out = torch.zeros(list(labels.size()) + [dim]).to(self.device)
        out.scatter_(len(out.size()) - 1, labels.unsqueeze(-1), 1.)
        return out

    def sample_z(self, batch_size):
        """Draw a ``(batch_size, z_dim)`` numpy array of standard-normal noise."""
        return np.random.normal(0, 1, size=(batch_size, self.z_dim))

    @staticmethod
    def postprocess(inputs, method, temperature=1.):
        """Turn logits into (relaxed) one-hot distributions along the last axis.

        Args:
            inputs: a tensor, or a list/tuple of tensors, of logits.
            method: ``'soft_gumbel'`` or ``'hard_gumbel'``; any other value
                uses a plain softmax.
            temperature: divisor applied to the logits.

        Returns:
            A list with one tensor per input, each with that input's shape.
        """
        tensors = inputs if isinstance(inputs, (list, tuple)) else [inputs]

        def gumbel(e_logits, hard):
            flat = e_logits.contiguous().view(-1, e_logits.size(-1)) / temperature
            return F.gumbel_softmax(flat, hard=hard).view(e_logits.size())

        # Outputs are returned as-is so the batch dimension is kept even when
        # the batch has a single element.
        if method == 'soft_gumbel':
            return [gumbel(e_logits, hard=False) for e_logits in tensors]
        if method == 'hard_gumbel':
            return [gumbel(e_logits, hard=True) for e_logits in tensors]
        return [F.softmax(e_logits / temperature, -1) for e_logits in tensors]

    def reward(self, mols):
        """Return the product of the ``self.metric`` scores of ``mols`` as a column array.

        Raises:
            RuntimeError: if ``self.metric`` names an unknown metric.
        """
        rr = 1.
        for m in ('logp,sas,qed,unique' if self.metric == 'all' else self.metric).split(','):
            if m == 'np':
                rr *= MolecularMetrics.natural_product_scores(mols, norm=True)
            elif m == 'logp':
                rr *= MolecularMetrics.water_octanol_partition_coefficient_scores(mols, norm=True)
            elif m == 'sas':
                rr *= MolecularMetrics.synthetic_accessibility_score_scores(mols, norm=True)
            elif m == 'qed':
                rr *= MolecularMetrics.quantitative_estimation_druglikeness_scores(mols, norm=True)
            elif m == 'novelty':
                rr *= MolecularMetrics.novel_scores(mols, self.data)
            elif m == 'dc':
                rr *= MolecularMetrics.drugcandidate_scores(mols, self.data)
            elif m == 'unique':
                rr *= MolecularMetrics.unique_scores(mols)
            elif m == 'diversity':
                rr *= MolecularMetrics.diversity_scores(mols, self.data)
            elif m == 'validity':
                rr *= MolecularMetrics.valid_scores(mols)
            else:
                raise RuntimeError('{} is not defined as a metric'.format(m))

        return rr.reshape(-1, 1)

    def train_and_validate(self):
        """Train with per-epoch validation, or run a test pass, depending on ``self.mode``.

        Raises:
            NotImplementedError: if ``self.mode`` is neither ``'train'`` nor ``'test'``.
        """
        self.start_time = time.time()

        # Start training from scratch or resume training.
        start_epoch = 0
        if self.resume_epoch is not None:
            start_epoch = self.resume_epoch
            self.restore_model(self.resume_epoch)

        # Start training.
        if self.mode == 'train':
            print('Start training...')
            for i in range(start_epoch, self.num_epochs):
                self.train_or_valid(epoch_i=i, train_val_test='train')
                self.train_or_valid(epoch_i=i, train_val_test='val')
        elif self.mode == 'test':
            assert self.resume_epoch is not None
            self.train_or_valid(epoch_i=start_epoch, train_val_test='val')
        else:
            raise NotImplementedError

    def get_gen_mols(self, n_hat, e_hat, method):
        """Decode node/edge tensors into molecules (post-process, then argmax)."""
        (edges_hard, nodes_hard) = self.postprocess((e_hat, n_hat), method)
        edges_hard, nodes_hard = torch.max(edges_hard, -1)[1], torch.max(nodes_hard, -1)[1]
        mols = [self.data.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True)
                for e_, n_ in zip(edges_hard, nodes_hard)]
        return mols

    def get_reward(self, n_hat, e_hat, method):
        """Return the reward of the molecules decoded by ``get_gen_mols`` as a tensor."""
        mols = self.get_gen_mols(n_hat, e_hat, method)
        return torch.from_numpy(self.reward(mols)).to(self.device)

    def save_checkpoints(self, epoch_i):
        """Save G/D/V state dicts under the index ``epoch_i + 1``."""
        G_path = os.path.join(self.model_dir_path, '{}-G.ckpt'.format(epoch_i + 1))
        D_path = os.path.join(self.model_dir_path, '{}-D.ckpt'.format(epoch_i + 1))
        V_path = os.path.join(self.model_dir_path, '{}-V.ckpt'.format(epoch_i + 1))
        torch.save(self.G.state_dict(), G_path)
        torch.save(self.D.state_dict(), D_path)
        torch.save(self.V.state_dict(), V_path)
        print('Saved model checkpoints into {}...'.format(self.model_dir_path))
        if self.log is not None:
            self.log.info('Saved model checkpoints into {}...'.format(self.model_dir_path))

    def _is_critic_step(self, cur_step):
        """Return True if the discriminator should be updated at ``cur_step``.

        Steps with ``cur_step % n_critic == 0`` are generator steps. With
        ``n_critic <= 1`` that is every step, so the discriminator is then
        updated on every step instead of never.
        """
        return self.n_critic <= 1 or cur_step % self.n_critic != 0

    def _update_generator_and_value(self, loss_g, loss_v):
        """Apply one generator step on ``loss_g`` and one value-network step on ``loss_v``.

        All gradients are computed before any optimizer step, because stepping
        ``G`` in place and then back-propagating through ``G``'s graph trips
        autograd's in-place version check. ``V`` is updated with the gradient
        of ``loss_v`` only, so the RL term of ``loss_g`` (which depends on
        ``V``) does not leak into the value network.
        """
        self.reset_grad()
        v_params = [p for p in self.V.parameters() if p.requires_grad]
        v_grads = torch.autograd.grad(loss_v, v_params, retain_graph=True, allow_unused=True)
        loss_g.backward()
        self.g_optimizer.step()
        # Replace whatever loss_g deposited on V with loss_v's own gradient.
        for p, g in zip(v_params, v_grads):
            p.grad = g
        self.v_optimizer.step()

    def train_or_valid(self, epoch_i, train_val_test='val'):
        """Run one training epoch, or one validation/test pass.

        Args:
            epoch_i: zero-based epoch index (used for step counting and for
                checkpoint / image names).
            train_val_test: ``'train'`` iterates ``num_steps`` training batches
                and updates the networks; ``'val'`` evaluates validation
                batches (one batch in train mode, ``num_steps`` in test mode),
                scores generated molecules, saves an image and prints/logs.

        Raises:
            NotImplementedError: for any other ``train_val_test`` value.
        """
        # The first several epochs using RL to pursue stability (not used).
        cur_la = 0 if epoch_i < 0 else self.la

        # Recordings
        losses = defaultdict(list)
        scores = defaultdict(list)

        # Iterations
        the_step = self.num_steps
        if train_val_test == 'val':
            if self.mode == 'train':
                the_step = 1
            print('[Validating]')

        for a_step in range(the_step):
            if train_val_test == 'val':
                mols, _, _, a, x, _, _, _, _ = self.data.next_validation_batch()
                z = self.sample_z(a.shape[0])
            elif train_val_test == 'train':
                mols, _, _, a, x, _, _, _, _ = self.data.next_train_batch(self.batch_size)
                z = self.sample_z(self.batch_size)
            else:
                raise NotImplementedError

            # =========================================================== #
            #                 1. Preprocess input data                    #
            # =========================================================== #
            a = torch.from_numpy(a).to(self.device).long()  # Adjacency.
            x = torch.from_numpy(x).to(self.device).long()  # Nodes.
            a_tensor = self.label2onehot(a, self.b_dim)
            x_tensor = self.label2onehot(x, self.m_dim)
            z = torch.from_numpy(z).to(self.device).float()

            # Current steps
            cur_step = self.num_steps * epoch_i + a_step

            # =========================================================== #
            #                 2. Train the discriminator                  #
            # =========================================================== #
            # Compute losses with real inputs.
            logits_real, _ = self.D(a_tensor, None, x_tensor)

            # Z-to-target
            edges_logits, nodes_logits = self.G(z)
            # Postprocess with Gumbel softmax
            (edges_hat, nodes_hat) = self.postprocess(
                (edges_logits, nodes_logits), self.post_method)
            logits_fake, _ = self.D(edges_hat, None, nodes_hat)

            # Compute losses for gradient penalty on real/fake interpolations.
            eps = torch.rand(logits_real.size(0), 1, 1, 1).to(self.device)
            x_int0 = (eps * a_tensor + (1. - eps) * edges_hat).requires_grad_(True)
            x_int1 = (eps.squeeze(-1) * x_tensor
                      + (1. - eps.squeeze(-1)) * nodes_hat).requires_grad_(True)
            grad0, grad1 = self.D(x_int0, None, x_int1)
            grad_penalty = (self.gradient_penalty(grad0, x_int0)
                            + self.gradient_penalty(grad1, x_int1))

            d_loss_real = torch.mean(logits_real)
            d_loss_fake = torch.mean(logits_fake)
            loss_D = -d_loss_real + d_loss_fake + self.la_gp * grad_penalty

            if cur_la > 0:
                losses['l_D/R'].append(d_loss_real.item())
                losses['l_D/F'].append(d_loss_fake.item())
                losses['l_D'].append(loss_D.item())

            # Optimise discriminator.
            if train_val_test == 'train' and cur_la > 0 and self._is_critic_step(cur_step):
                self.reset_grad()
                loss_D.backward()
                self.d_optimizer.step()

            # =========================================================== #
            #                   3. Train the generator                    #
            # =========================================================== #
            # Z-to-target
            edges_logits, nodes_logits = self.G(z)
            # Postprocess with Gumbel softmax
            (edges_hat, nodes_hat) = self.postprocess(
                (edges_logits, nodes_logits), self.post_method)
            logits_fake, _ = self.D(edges_hat, None, nodes_hat)

            # Value losses
            value_logit_real, _ = self.V(a_tensor, None, x_tensor, torch.sigmoid)
            value_logit_fake, _ = self.V(edges_hat, None, nodes_hat, torch.sigmoid)

            # Real Reward
            reward_r = torch.from_numpy(self.reward(mols)).to(self.device)
            # Fake Reward: decode from the raw logits (as get_gen_mols does for
            # scoring) rather than post-processing edges_hat/nodes_hat twice.
            reward_f = self.get_reward(nodes_logits, edges_logits, self.post_method)

            # Losses Update. The original TF loss_V squares the residuals;
            # absolute values are used here instead.
            loss_G = torch.mean(-logits_fake)
            loss_V = torch.mean(torch.abs(value_logit_real - reward_r)
                                + torch.abs(value_logit_fake - reward_f))
            loss_RL = torch.mean(-value_logit_fake)

            losses['l_G'].append(loss_G.item())
            losses['l_RL'].append(loss_RL.item())
            losses['l_V'].append(loss_V.item())

            # Rescale the RL term to the magnitude of the WGAN term.
            alpha = torch.abs(loss_G.detach() / loss_RL.detach())
            train_step_G = cur_la * loss_G + (1 - cur_la) * alpha * loss_RL
            train_step_V = loss_V

            # Optimise generator and value network.
            if train_val_test == 'train' and cur_step % self.n_critic == 0:
                self._update_generator_and_value(train_step_G, train_step_V)

            # =========================================================== #
            #                      4. Miscellaneous                       #
            # =========================================================== #
            if train_val_test == 'val':
                # Get scores.
                mols = self.get_gen_mols(nodes_logits, edges_logits, self.post_method)
                m0, m1 = all_scores(mols, self.data, norm=True)  # scores of generated mols
                for k, v in m1.items():
                    scores[k].append(v)
                for k, v in m0.items():
                    scores[k].append(np.array(v)[np.nonzero(v)].mean())

                # Save checkpoints.
                if self.mode == 'train' and (epoch_i + 1) % self.model_save_step == 0:
                    self.save_checkpoints(epoch_i=epoch_i)

                # Saving molecule images.
                mol_f_name = os.path.join(self.img_dir_path, 'mol-{}.png'.format(epoch_i))
                save_mol_img(mols, mol_f_name, is_test=self.mode == 'test')

                # Print out training information. int() keeps whole seconds;
                # slicing the str() would cut real digits when microseconds == 0.
                et = str(datetime.timedelta(seconds=int(time.time() - self.start_time)))
                log = "Elapsed [{}], Iteration [{}/{}]:".format(et, epoch_i + 1, self.num_epochs)
                # Losses on one line, scores on the next.
                for group in (losses, scores):
                    parts = ["{}: {:.2f}".format(tag, np.mean(value)) for tag, value in group.items()]
                    if parts:
                        log += "\n" + ", ".join(parts)
                print(log)
                if self.log is not None:
                    self.log.info(log)
class Solver(object):
    """Train, validate and sample a molecular-graph VAE with an RL value term.

    The encoder (``EncoderVAE``) maps a molecular graph to a latent ``z``; the
    decoder (``Generator``) maps ``z`` back to edge/node logits. A value
    network ``V`` (a ``Discriminator``) is regressed onto metric-based rewards
    and its output on decoded graphs is used as an RL signal for the VAE.
    """

    def __init__(self, config, log=None):
        """Initialize configurations, load the dataset and build the networks.

        Args:
            config: Options object; every ``config.<name>`` attribute read
                below must be present.
            log: Optional logger. When given, progress messages are also sent
                to ``log.info``.
        """
        # Log
        self.log = log

        # Data loader.
        self.data = SparseMolecularDataset()
        self.data.load(config.mol_data_dir)

        # Model configurations.
        self.z_dim = config.z_dim
        self.m_dim = self.data.atom_num_types
        self.b_dim = self.data.bond_num_types
        self.f_dim = self.data.features
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        # Mixing weight between the VAE loss and the RL loss in train_or_valid.
        self.lambda_wgan = config.lambda_wgan
        self.lambda_rec = config.lambda_rec
        self.post_method = config.post_method
        # Comma-separated metrics whose scores are multiplied in reward().
        self.metric = 'validity,qed'

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_epochs = config.num_epochs
        self.num_steps = (len(self.data) // self.batch_size)
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.dropout_rate = config.dropout
        self.n_critic = config.n_critic
        self.resume_epoch = config.resume_epoch

        # Training or testing.
        self.mode = config.mode

        # Miscellaneous.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        print('Device: ', self.device)

        # Directories.
        self.log_dir_path = config.log_dir_path
        self.model_dir_path = config.model_dir_path
        self.img_dir_path = config.img_dir_path

        # Step size.
        self.model_save_step = config.model_save_step

        # VAE KL weight.
        self.kl_la = 1.

        # Build the model.
        self.build_model()

    def build_model(self):
        """Create the encoder, decoder, value network and their optimizers."""
        self.encoder = EncoderVAE(self.d_conv_dim, self.m_dim, self.b_dim - 1,
                                  self.z_dim, with_features=True,
                                  f_dim=self.f_dim,
                                  dropout_rate=self.dropout_rate).to(self.device)
        self.decoder = Generator(self.g_conv_dim, self.z_dim,
                                 self.data.vertexes,
                                 self.data.bond_num_types,
                                 self.data.atom_num_types,
                                 self.dropout_rate).to(self.device)
        self.V = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim - 1,
                               self.dropout_rate).to(self.device)

        # Encoder and decoder share one optimizer; V has its own.
        self.vae_optimizer = torch.optim.RMSprop(
            list(self.encoder.parameters()) + list(self.decoder.parameters()),
            self.g_lr)
        self.v_optimizer = torch.optim.RMSprop(self.V.parameters(), self.d_lr)

        self.print_network(self.encoder, 'Encoder', self.log)
        self.print_network(self.decoder, 'Decoder', self.log)
        self.print_network(self.V, 'Value', self.log)

    @staticmethod
    def print_network(model, name, log=None):
        """Print (and optionally log) a model, its name and its parameter count."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))
        if log is not None:
            log.info(model)
            log.info(name)
            log.info("The number of parameters: {}".format(num_params))

    def _checkpoint_paths(self, tag):
        """Return the (encoder, decoder, V) checkpoint paths for ``tag``."""
        return tuple(
            os.path.join(self.model_dir_path, '{}-{}.ckpt'.format(tag, name))
            for name in ('encoder', 'decoder', 'V'))

    def restore_model(self, resume_iters):
        """Load encoder, decoder and V weights saved under ``resume_iters``."""
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        enc_path, dec_path, V_path = self._checkpoint_paths(resume_iters)
        for net, path in ((self.encoder, enc_path), (self.decoder, dec_path),
                          (self.V, V_path)):
            # Load onto CPU storage first so GPU-saved checkpoints also load
            # on CPU-only machines.
            net.load_state_dict(
                torch.load(path, map_location=lambda storage, loc: storage))

    def update_lr(self, g_lr, d_lr):
        """Set the VAE optimizer's lr to ``g_lr`` and V's optimizer's lr to ``d_lr``."""
        # This solver owns vae_optimizer / v_optimizer; there is no
        # g_optimizer / d_optimizer here.
        for param_group in self.vae_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.v_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.vae_optimizer.zero_grad()
        self.v_optimizer.zero_grad()

    def label2onehot(self, labels, dim):
        """Convert integer label indices to one-hot vectors along a new last axis."""
        out = torch.zeros(*labels.size(), dim, device=self.device)
        out.scatter_(out.dim() - 1, labels.unsqueeze(-1), 1.)
        return out

    def sample_z(self, batch_size):
        """Draw a ``(batch_size, z_dim)`` standard-normal NumPy array."""
        return np.random.normal(0, 1, size=(batch_size, self.z_dim))

    @staticmethod
    def postprocess_logits(inputs, method, temperature=1.):
        """Turn logits into (relaxed) one-hot distributions over the last axis.

        Args:
            inputs: A tensor or a list/tuple of tensors of logits.
            method: ``'soft_gumbel'``, ``'hard_gumbel'``; anything else means
                a plain softmax.
            temperature: Divisor applied to the logits first.

        Returns:
            A list with one tensor per input, each with the input's shape.
        """
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        if method in ('soft_gumbel', 'hard_gumbel'):
            hard = method == 'hard_gumbel'
            # gumbel_softmax is applied on a 2-D view, then reshaped back.
            softmax = [
                F.gumbel_softmax(
                    e_logits.contiguous().view(-1, e_logits.size(-1)) /
                    temperature,
                    hard=hard).view(e_logits.size()) for e_logits in inputs
            ]
        else:
            softmax = [
                F.softmax(e_logits / temperature, -1) for e_logits in inputs
            ]

        # Return the tensors untouched: unwrapping an element whose first
        # (batch) dimension has size 1 would silently drop the batch axis.
        return softmax

    def reward(self, mols):
        """Return the product of the selected metric scores, shaped ``(-1, 1)``.

        Raises:
            RuntimeError: if ``self.metric`` names an unknown metric.
        """
        rr = 1.
        metrics = 'logp,sas,qed,unique' if self.metric == 'all' else self.metric
        for m in metrics.split(','):
            if m == 'np':
                rr *= MolecularMetrics.natural_product_scores(mols, norm=True)
            elif m == 'logp':
                rr *= MolecularMetrics.water_octanol_partition_coefficient_scores(
                    mols, norm=True)
            elif m == 'sas':
                rr *= MolecularMetrics.synthetic_accessibility_score_scores(
                    mols, norm=True)
            elif m == 'qed':
                rr *= MolecularMetrics.quantitative_estimation_druglikeness_scores(
                    mols, norm=True)
            elif m == 'novelty':
                rr *= MolecularMetrics.novel_scores(mols, self.data)
            elif m == 'dc':
                rr *= MolecularMetrics.drugcandidate_scores(mols, self.data)
            elif m == 'unique':
                rr *= MolecularMetrics.unique_scores(mols)
            elif m == 'diversity':
                rr *= MolecularMetrics.diversity_scores(mols, self.data)
            elif m == 'validity':
                rr *= MolecularMetrics.valid_scores(mols)
            else:
                raise RuntimeError('{} is not defined as a metric'.format(m))

        return rr.reshape(-1, 1)

    def train_and_validate(self):
        """Run the train/val/sample loop ('train' mode) or one evaluation ('test' mode).

        Raises:
            ValueError: in 'test' mode when ``resume_epoch`` is None.
            NotImplementedError: for any other mode.
        """
        self.start_time = time.time()

        # Start training from scratch or resume training.
        start_epoch = 0
        if self.resume_epoch:
            start_epoch = self.resume_epoch
            self.restore_model(self.resume_epoch)

        if self.mode == 'train':
            print('Start training...')
            for i in range(start_epoch, self.num_epochs):
                self.train_or_valid(epoch_i=i, train_val_test='train')
                self.train_or_valid(epoch_i=i, train_val_test='val')
                self.train_or_valid(epoch_i=i, train_val_test='sample')
        elif self.mode == 'test':
            if self.resume_epoch is None:
                raise ValueError(
                    "mode 'test' requires resume_epoch to select a checkpoint")
            self.train_or_valid(epoch_i=start_epoch, train_val_test='sample')
            self.train_or_valid(epoch_i=start_epoch, train_val_test='val')
        else:
            raise NotImplementedError

    def get_reconstruction_loss(self, n_hat, n, e_hat, e):
        """Sum of mean cross-entropies for node logits and edge logits.

        A per-element variant that broadcast node losses over edges (to
        account for the node/edge imbalance) was tried and did not work well
        in practice, so the two mean-reduced losses are simply added.
        """
        n_loss = torch.nn.CrossEntropyLoss(reduction='mean')(
            n_hat.view(-1, self.m_dim), n.view(-1))
        e_loss = torch.nn.CrossEntropyLoss(reduction='mean')(
            e_hat.reshape((-1, self.b_dim)), e.view(-1))
        return n_loss + e_loss

    @staticmethod
    def get_kl_loss(mu, logvar):
        """KL(N(mu, exp(logvar)) || N(0, I)), summed over dim 1, averaged over dim 0."""
        kld_loss = torch.mean(
            -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp(), dim=1), dim=0)
        return kld_loss

    def get_gen_mols(self, n_hat, e_hat, method):
        """Decode node/edge logits into molecules via ``self.data.matrices2mol``."""
        edges_hard, nodes_hard = self.postprocess_logits((e_hat, n_hat), method)
        # Argmax over the category axis gives integer bond/atom type matrices.
        edges_hard = torch.max(edges_hard, -1)[1]
        nodes_hard = torch.max(nodes_hard, -1)[1]
        return [
            self.data.matrices2mol(n_.data.cpu().numpy(),
                                   e_.data.cpu().numpy(),
                                   strict=True)
            for e_, n_ in zip(edges_hard, nodes_hard)
        ]

    def get_reward(self, n_hat, e_hat, method):
        """Reward tensor (on ``self.device``) for the molecules decoded from the logits."""
        mols = self.get_gen_mols(n_hat, e_hat, method)
        return torch.from_numpy(self.reward(mols)).to(self.device)

    def save_checkpoints(self, epoch_i):
        """Save encoder, decoder and V weights tagged with ``epoch_i + 1``."""
        enc_path, dec_path, V_path = self._checkpoint_paths(epoch_i + 1)
        torch.save(self.encoder.state_dict(), enc_path)
        torch.save(self.decoder.state_dict(), dec_path)
        torch.save(self.V.state_dict(), V_path)
        print('Saved model checkpoints into {}...'.format(self.model_dir_path))
        if self.log is not None:
            self.log.info('Saved model checkpoints into {}...'.format(
                self.model_dir_path))

    @staticmethod
    def _format_metrics(metrics):
        """Render ``{tag: values}`` as ``'tag: mean, tag: mean'`` (2 decimals)."""
        return ", ".join("{}: {:.2f}".format(tag, np.mean(value))
                         for tag, value in metrics.items())

    def get_scores(self, mols, to_print=False):
        """Score molecules with ``all_scores``.

        Returns:
            ``scores`` (a ``defaultdict(list)``), or ``(scores, log_line)``
            when ``to_print`` is true (the line is also printed).
        """
        scores = defaultdict(list)
        m0, m1 = all_scores(mols, self.data, norm=True)
        for k, v in m1.items():
            scores[k].append(v)
        for k, v in m0.items():
            # Average only the non-zero per-molecule entries.
            scores[k].append(np.array(v)[np.nonzero(v)].mean())

        if to_print:
            log = self._format_metrics(scores)
            print(log)
            return scores, log
        return scores

    def train_or_valid(self, epoch_i, train_val_test='val'):
        """Run one pass of training, validation or sampling.

        Args:
            epoch_i: Zero-based epoch index, used in checkpoint/image names
                and log lines.
            train_val_test: 'train' optimises on training batches; 'val'
                evaluates the validation batch, saves checkpoints/images and
                prints losses and scores; 'sample' decodes random ``z`` and
                scores the generated molecules.

        Raises:
            NotImplementedError: for any other ``train_val_test`` value.
        """
        # Recordings
        losses = defaultdict(list)

        the_step = self.num_steps
        if train_val_test == 'val':
            if self.mode == 'train':
                the_step = 1
            print('[Validating]')

        if train_val_test == 'sample':
            if self.mode == 'train':
                the_step = 1
            print('[Sampling]')

        is_train = train_val_test == 'train'
        uses_data = train_val_test in ('train', 'val')

        for a_step in range(the_step):
            # Only the training pass needs an autograd graph.
            with torch.set_grad_enabled(is_train):
                z = None
                if train_val_test == 'val':
                    mols, _, _, a, x, _, f, _, _ = \
                        self.data.next_validation_batch()
                elif train_val_test == 'train':
                    mols, _, _, a, x, _, f, _, _ = \
                        self.data.next_train_batch(self.batch_size)
                elif train_val_test == 'sample':
                    z = self.sample_z(self.batch_size)
                    z = torch.from_numpy(z).to(self.device).float()
                else:
                    raise NotImplementedError

                if uses_data:
                    a = torch.from_numpy(a).to(self.device).long()  # Adjacency.
                    x = torch.from_numpy(x).to(self.device).long()  # Nodes.
                    a_tensor = self.label2onehot(a, self.b_dim)
                    x_tensor = self.label2onehot(x, self.m_dim)
                    f = torch.from_numpy(f).to(self.device).float()
                    z, z_mu, z_logvar = self.encoder(a_tensor, f, x_tensor)

                edges_logits, nodes_logits = self.decoder(z)
                (edges_hat, nodes_hat) = self.postprocess_logits(
                    (edges_logits, nodes_logits), self.post_method)

                if uses_data:
                    recon_loss = self.get_reconstruction_loss(
                        nodes_logits, x, edges_logits, a)
                    kl_loss = self.get_kl_loss(z_mu, z_logvar)
                    loss_vae = recon_loss + self.kl_la * kl_loss

                    # Real Reward
                    reward_r = torch.from_numpy(self.reward(mols)).to(
                        self.device)
                    # Fake Reward
                    reward_f = self.get_reward(nodes_logits, edges_logits,
                                               'hard_gumbel')

                    # RL loss for the VAE: gradients flow through the decoded
                    # graph into the encoder/decoder. It is rescaled to the
                    # magnitude of the VAE loss.
                    value_logit_fake, _ = self.V(edges_hat, None, nodes_hat,
                                                 torch.sigmoid)
                    loss_rl = torch.mean(-value_logit_fake)
                    alpha = torch.abs(loss_vae.detach() / loss_rl.detach())
                    loss_rl = alpha * loss_rl

                    # Value loss: regress V onto the rewards. The decoded graph
                    # is detached so this loss trains V only, not the VAE.
                    value_logit_real, _ = self.V(a_tensor, None, x_tensor,
                                                 torch.sigmoid)
                    value_logit_fake_v, _ = self.V(edges_hat.detach(), None,
                                                   nodes_hat.detach(),
                                                   torch.sigmoid)
                    loss_v = torch.mean((value_logit_real - reward_r)**2 +
                                        (value_logit_fake_v - reward_f)**2)

                    vae_loss_train = self.lambda_wgan * loss_vae + (
                        1 - self.lambda_wgan) * loss_rl
                    losses['l_Rec'].append(recon_loss.item())
                    losses['l_KL'].append(kl_loss.item())
                    losses['l_VAE'].append(loss_vae.item())
                    losses['l_RL'].append(loss_rl.item())
                    losses['l_V'].append(loss_v.item())

                    if is_train:
                        self.reset_grad()
                        vae_loss_train.backward()
                        # loss_rl also reaches V's parameters; discard those
                        # gradients so V is updated from loss_v alone.
                        self.v_optimizer.zero_grad()
                        loss_v.backward()
                        self.vae_optimizer.step()
                        self.v_optimizer.step()

                if train_val_test == 'sample':
                    mols = self.get_gen_mols(nodes_logits, edges_logits,
                                             'hard_gumbel')
                    scores, mol_log = self.get_scores(mols, to_print=True)

                    # Saving molecule images.
                    mol_f_name = os.path.join(self.img_dir_path,
                                              'sample-mol-{}.png'.format(epoch_i))
                    save_mol_img(mols, mol_f_name, is_test=self.mode == 'test')

                    if self.log is not None:
                        self.log.info(mol_log)

                if train_val_test == 'val':
                    mols = self.get_gen_mols(nodes_logits, edges_logits,
                                             'hard_gumbel')
                    scores = self.get_scores(mols)

                    # Save checkpoints.
                    if self.mode == 'train':
                        if (epoch_i + 1) % self.model_save_step == 0:
                            self.save_checkpoints(epoch_i=epoch_i)

                    # Saving molecule images.
                    mol_f_name = os.path.join(self.img_dir_path,
                                              'mol-{}.png'.format(epoch_i))
                    save_mol_img(mols, mol_f_name, is_test=self.mode == 'test')

                    # Print out training information. Whole seconds only:
                    # slicing str(timedelta) breaks when microseconds == 0.
                    et = time.time() - self.start_time
                    et = str(datetime.timedelta(seconds=int(et)))
                    log = "Elapsed [{}], Iteration [{}/{}]:".format(
                        et, epoch_i + 1, self.num_epochs)
                    for group in (losses, scores):
                        if group:
                            log += "\n" + self._format_metrics(group)
                    print(log)
                    if self.log is not None:
                        self.log.info(log)