def generate(self, number_samples):
    load_path, load_iters = get_load_path(self.num_sampling_iters, self.num_argmax_iters,
                                          self.cp_save_dir)
    all_init_node_properties, all_init_edge_properties, all_node_masks, all_edge_masks = \
        self.get_all_init_variables(load_path, number_samples)
    if self.set_seed_at_load_iter is True:
        set_seed_if(load_iters)
    retrieve_train_graphs = self.retrieve_train_graphs
    for j in range(load_iters, self.num_iters):
        if j > 0:
            retrieve_train_graphs = False
            if self.generation_algorithm == 'Gibbs':
                self.train_data.do_not_corrupt = True
        loader = self.get_dataloader(all_init_node_properties, all_node_masks,
                                     all_init_edge_properties, number_samples,
                                     retrieve_train_graphs)
        # sample stochastically for the first num_sampling_iters iterations,
        # then switch to deterministic argmax refinement
        use_argmax = (j >= self.num_sampling_iters)
        all_init_node_properties, all_init_edge_properties, all_node_masks, \
            smiles_list = self.carry_out_iteration(loader, use_argmax)
    return smiles_list
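# The loop above flips from stochastic sampling to deterministic refinement once
# j reaches num_sampling_iters. A minimal sketch of how such a use_argmax flag is
# typically applied to per-component logits (illustrative only, not the repo's
# carry_out_iteration):
import torch
import torch.nn.functional as F

def decode_scores(scores, use_argmax):
    """scores: (num_items, num_classes) unnormalised logits for one property."""
    if use_argmax:
        # deterministic: take the most likely class for every component
        return torch.argmax(scores, dim=-1)
    # stochastic: sample each component from its predicted categorical distribution
    probs = F.softmax(scores, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)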
def main(params):
    model_cls = MODELS_DICT[params.model_name]
    model = model_cls(params)
    params, model, opt, scheduler, train_data, train_loader, val_dataset, val_loader, \
        perturbation_loader, generator, index_method, exp_path, training_smiles, pp, \
        logger, writer, best_loss, total_iter, grad_accum_iters = setup_data_and_model(params, model)

    for epoch in range(1, params.num_epochs + 1):
        print('Starting epoch {}'.format(epoch), flush=True)
        for train_batch in train_loader:
            if total_iter % 100 == 0:
                print(total_iter, flush=True)
            if total_iter == params.max_steps:
                logger.info('Done training')
                break
            model.train()
            if hasattr(model, 'seq_model'):
                model.seq_model.eval()

            # Training step.
            # init is what goes into the model; orig is the uncorrupted data (compared
            # with predictions in the loss calculation). Masks are 1 at positions
            # corresponding to nodes that exist, 0 at empty/padded positions.
            # target_types are 1 at positions to mask, 2 to replace, 3 to reconstruct,
            # and 0 at positions not to predict.
            batch_init_graph, batch_orig_graph, batch_target_type_graph, _, \
                graph_properties, binary_graph_properties = train_batch
            if params.local_cpu is False:
                batch_init_graph = batch_init_graph.to(torch.device('cuda:0'))
                batch_orig_graph = batch_orig_graph.to(torch.device('cuda:0'))
                dct_to_cuda_inplace(graph_properties)
                if len(binary_graph_properties) > 0:  # guard works for list and tensor alike
                    binary_graph_properties = binary_graph_properties.cuda()

            if grad_accum_iters % params.grad_accum_iters == 0:
                opt.zero_grad()

            _, batch_scores_graph, graph_property_scores = model(
                batch_init_graph, graph_properties, binary_graph_properties)

            # total number of prediction targets (nonzero target types) in the batch
            num_components = \
                sum((v != 0).sum().item() for v in batch_target_type_graph.ndata.values()) + \
                sum((v != 0).sum().item() for v in batch_target_type_graph.edata.values())

            # calculate losses
            losses = []
            results_dict = {}
            metrics_to_print = []
            if params.target_data_structs in ['nodes', 'both', 'random'] and \
                    batch_target_type_graph.ndata['node_type'].sum() > 0:
                node_losses = {}
                for name, target_type in batch_target_type_graph.ndata.items():
                    # move the boolean mask to the scored graph's device before indexing
                    target_mask = (target_type != 0).to(batch_orig_graph.device)
                    node_losses[name] = get_loss(
                        batch_orig_graph.ndata[name][target_mask],
                        batch_scores_graph.ndata[name][target_mask],
                        params.equalise, params.loss_normalisation_type,
                        num_components, params.local_cpu)
                losses.extend(node_losses.values())
                results_dict.update(node_losses)
                metrics_to_print.extend(node_losses.keys())
            if params.target_data_structs in ['edges', 'both', 'random'] and \
                    batch_target_type_graph.edata['edge_type'].sum() > 0:
                edge_losses = {}
                for name, target_type in batch_target_type_graph.edata.items():
                    target_mask = (target_type != 0).to(batch_orig_graph.device)
                    edge_losses[name] = get_loss(
                        batch_orig_graph.edata[name][target_mask],
                        batch_scores_graph.edata[name][target_mask],
                        params.equalise, params.loss_normalisation_type,
                        num_components, params.local_cpu)
                losses.extend(edge_losses.values())
                results_dict.update(edge_losses)
                metrics_to_print.extend(edge_losses.keys())
            if params.predict_graph_properties is True:
                graph_property_losses = {}
                for name, scores in graph_property_scores.items():
                    graph_property_loss = normalise_loss(
                        F.mse_loss(scores, graph_properties[name], reduction='sum'),
                        len(scores), num_components, params.loss_normalisation_type)
                    graph_property_losses[name] = graph_property_loss
                losses.extend(graph_property_losses.values())
                results_dict.update(graph_property_losses)
                metrics_to_print.extend(graph_property_losses.keys())

            loss = sum(losses)
            if params.no_update is False:
                (loss / params.grad_accum_iters).backward()
                grad_accum_iters += 1
                if grad_accum_iters % params.grad_accum_iters == 0:
                    # clip grad norm
                    if params.clip_grad_norm > -1:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       params.clip_grad_norm)
                    opt.step()
                    grad_accum_iters = 0
            if (total_iter + 1) <= params.warm_up_iters or \
                    (total_iter + 1) % params.lr_decay_interval == 0:
                scheduler.step(total_iter + 1)

            if params.suppress_train_log is False:
                log_string = ''
                for name, value in results_dict.items():
                    if name in metrics_to_print:
                        log_string += ', {} = {:.2f}'.format(name, value)
                log_string = 'total_iter = {0:d}, loss = {1:.2f}'.format(
                    total_iter, loss.cpu().item()) + log_string
                logger.info(log_string)

            if params.target_frac_inc_after is not None and total_iter > 0 and \
                    total_iter % params.target_frac_inc_after == 0:
                train_data.node_target_frac = min(
                    train_data.node_target_frac + params.target_frac_inc_amount,
                    params.max_target_frac)
                train_data.edge_target_frac = min(
                    train_data.edge_target_frac + params.target_frac_inc_amount,
                    params.max_target_frac)
                results_dict.update({'node_target_frac': train_data.node_target_frac})
                results_dict.update({'edge_target_frac': train_data.edge_target_frac})

            if params.tensorboard and total_iter % int(params.log_train_steps) == 0:
                results_dict.update({'loss': loss, 'lr': opt.param_groups[0]['lr']})
                write_tensorboard(writer, 'train', results_dict, total_iter)

            if total_iter > 0 and total_iter % params.val_after == 0:
                logger.info('Validating')
                val_loss, num_data_points = 0, 0
                node_property_losses = {name: 0 for name in train_data.node_property_names}
                edge_property_losses = {name: 0 for name in train_data.edge_property_names}
                node_property_num_components = {name: 0 for name in train_data.node_property_names}
                edge_property_num_components = {name: 0 for name in train_data.edge_property_names}
                graph_property_losses = {name: 0 for name in params.graph_property_names}
                model.eval()
                set_seed_if(params.seed)
                for batch_init_graph, batch_orig_graph, batch_target_type_graph, _, \
                        graph_properties, binary_graph_properties in val_loader:
                    if params.local_cpu is False:
                        batch_init_graph = batch_init_graph.to(torch.device('cuda:0'))
                        batch_orig_graph = batch_orig_graph.to(torch.device('cuda:0'))
                        dct_to_cuda_inplace(graph_properties)
                        if len(binary_graph_properties) > 0:
                            binary_graph_properties = binary_graph_properties.cuda()
                    with torch.no_grad():
                        _, batch_scores_graph, graph_property_scores = model(
                            batch_init_graph, graph_properties, binary_graph_properties)
                    num_data_points += float(batch_orig_graph.batch_size)
                    losses = []
                    if params.target_data_structs in ['nodes', 'both', 'random'] and \
                            batch_target_type_graph.ndata['node_type'].sum() > 0:
                        for name, target_type in batch_target_type_graph.ndata.items():
                            target_mask = (target_type != 0).to(batch_orig_graph.device)
                            iter_node_property_loss = F.cross_entropy(
                                batch_scores_graph.ndata[name][target_mask],
                                batch_orig_graph.ndata[name][target_mask],
                                reduction='sum').cpu().item()
                            node_property_losses[name] += iter_node_property_loss
                            losses.append(iter_node_property_loss)
                            node_property_num_components[name] += float((target_type != 0).sum())
                    if params.target_data_structs in ['edges', 'both', 'random'] and \
                            batch_target_type_graph.edata['edge_type'].sum() > 0:
                        for name, target_type in batch_target_type_graph.edata.items():
                            target_mask = (target_type != 0).to(batch_orig_graph.device)
                            iter_edge_property_loss = F.cross_entropy(
                                batch_scores_graph.edata[name][target_mask],
                                batch_orig_graph.edata[name][target_mask],
                                reduction='sum').cpu().item()
                            edge_property_losses[name] += iter_edge_property_loss
                            losses.append(iter_edge_property_loss)
                            edge_property_num_components[name] += float((target_type != 0).sum())
                    if params.predict_graph_properties is True:
                        for name, scores in graph_property_scores.items():
                            iter_graph_property_loss = F.mse_loss(
                                scores, graph_properties[name], reduction='sum').cpu().item()
                            graph_property_losses[name] += iter_graph_property_loss
                            losses.append(iter_graph_property_loss)
                    val_loss += sum(losses)

                avg_node_property_losses, avg_edge_property_losses, avg_graph_property_losses = {}, {}, {}
                if params.loss_normalisation_type == 'by_total':
                    total_num_components = float(sum(node_property_num_components.values()) +
                                                 sum(edge_property_num_components.values()))
                    avg_val_loss = val_loss / total_num_components
                    if params.target_data_structs in ['nodes', 'both', 'random']:
                        for name, loss in node_property_losses.items():
                            avg_node_property_losses[name] = loss / total_num_components
                    if params.target_data_structs in ['edges', 'both', 'random']:
                        for name, loss in edge_property_losses.items():
                            avg_edge_property_losses[name] = loss / total_num_components
                    if params.predict_graph_properties is True:
                        for name, loss in graph_property_losses.items():
                            avg_graph_property_losses[name] = loss / total_num_components
                elif params.loss_normalisation_type == 'by_component':
                    avg_val_loss = 0
                    if params.target_data_structs in ['nodes', 'both', 'random']:
                        for name, loss in node_property_losses.items():
                            avg_node_property_losses[name] = loss / node_property_num_components[name]
                        avg_val_loss += sum(avg_node_property_losses.values())
                    if params.target_data_structs in ['edges', 'both', 'random']:
                        for name, loss in edge_property_losses.items():
                            avg_edge_property_losses[name] = loss / edge_property_num_components[name]
                        avg_val_loss += sum(avg_edge_property_losses.values())
                    if params.predict_graph_properties is True:
                        for name, loss in graph_property_losses.items():
                            # graph properties are normalised per graph, not per component
                            avg_graph_property_losses[name] = loss / num_data_points
                        avg_val_loss += sum(avg_graph_property_losses.values())

                val_iter = total_iter // params.val_after
                results_dict = {'Validation_loss': avg_val_loss}
                for name, loss in avg_node_property_losses.items():
                    results_dict['{}_loss'.format(name)] = loss
                for name, loss in avg_edge_property_losses.items():
                    results_dict['{}_loss'.format(name)] = loss
                for name, loss in avg_graph_property_losses.items():
                    results_dict['{}_loss'.format(name)] = loss
                for name, loss in results_dict.items():
                    logger.info('{}: {:.2f}'.format(name, loss))
                if params.tensorboard:
                    write_tensorboard(writer, 'Dev', results_dict, val_iter)

                if params.gen_num_samples > 0:
                    calculate_gen_benchmarks(generator, params.gen_num_samples,
                                             training_smiles, logger)

                logger.info("----------------------------------")
                model_state_dict = model.module.state_dict() \
                    if torch.cuda.device_count() > 1 else model.state_dict()
                best_loss = save_checkpoints(total_iter, avg_val_loss, best_loss,
                                             model_state_dict, opt.state_dict(),
                                             exp_path, logger, params.no_save,
                                             params.save_all)
                # Reset random seed
                set_seed_if(params.seed)
                logger.info('Validation complete')

            total_iter += 1
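# Both the training and validation paths above normalise summed losses in one of
# two ways. A hedged sketch of that arithmetic, matching the call signature
# normalise_loss(F.mse_loss(..., reduction='sum'), len(scores), num_components,
# params.loss_normalisation_type); the repo's normalise_loss may differ in detail:
def normalise_loss(summed_loss, num_this_component, num_total_components,
                   loss_normalisation_type):
    if loss_normalisation_type == 'by_total':
        # one shared denominator: every predicted component in the batch
        return summed_loss / num_total_components
    elif loss_normalisation_type == 'by_component':
        # each property is averaged over its own targets only
        return summed_loss / num_this_component
    raise ValueError('Unknown normalisation type: {}'.format(loss_normalisation_type))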
def setup_data_and_model(params, model):
    # Variables that may not otherwise be assigned
    writer = perturbation_loader = generator = training_smiles = None

    # set up random seeds
    if params.val_seed is None:
        params.val_seed = params.seed
    set_seed_if(params.seed)

    exp_path = os.path.join(params.dump_path, params.exp_name)
    # create exp path if it doesn't exist
    if not os.path.exists(exp_path):
        os.makedirs(exp_path)

    # create logger
    logger = create_logger(os.path.join(exp_path, 'train.log'), 0)
    pp = pprint.PrettyPrinter()
    logger.info("============ Initialized logger ============")
    logger.info("Random seed is {}".format(params.seed))
    if params.suppress_params is False:
        logger.info("\n".join("%s: %s" % (k, str(v))
                              for k, v in sorted(dict(vars(params)).items())))
    logger.info("Running command: %s" % ('python ' + ' '.join(sys.argv)))
    logger.info("The experiment will be stored in %s\n" % exp_path)
    logger.info("")

    # load data
    train_data, val_dataset, train_loader, val_loader = load_graph_data(params)
    logger.info('train_loader len is {}'.format(len(train_loader)))
    logger.info('val_loader len is {}'.format(len(val_loader)))

    if params.num_binary_graph_properties > 0 and params.pretrained_property_embeddings_path:
        model.binary_graph_property_embedding_layer.weight.data = \
            torch.Tensor(np.load(params.pretrained_property_embeddings_path).T)

    if params.load_latest is True:
        load_prefix = 'latest'
    elif params.load_best is True:
        load_prefix = 'best'
    else:
        load_prefix = None
    if load_prefix is not None:
        if params.local_cpu is True:
            model.load_state_dict(torch.load(
                os.path.join(exp_path, '{}_model'.format(load_prefix)),
                map_location='cpu'))
        else:
            model.load_state_dict(torch.load(
                os.path.join(exp_path, '{}_model'.format(load_prefix))))
    if params.local_cpu is False:
        model = model.cuda()

    if params.gen_num_samples > 0:
        generator = GraphGenerator(train_data, model, params.gen_random_init,
                                   params.gen_num_iters,
                                   params.gen_predict_deterministically,
                                   params.local_cpu)
        with open(params.smiles_path) as f:
            smiles = f.read().split('\n')
        training_smiles = smiles[:int(params.smiles_train_split * len(smiles))]
        del smiles

    opt = get_optimizer(model.parameters(), params.optimizer)
    if load_prefix is not None:
        opt.load_state_dict(torch.load(
            os.path.join(exp_path, '{}_opt_sd'.format(load_prefix))))
    lr = opt.param_groups[0]['lr']
    lr_lambda = lambda iteration: lr_decay_multiplier(
        iteration, params.warm_up_iters, params.decay_start_iter,
        params.lr_decay_amount, params.lr_decay_frac, params.lr_decay_interval,
        params.min_lr, lr)
    scheduler = LambdaLR(opt, lr_lambda)

    index_method = get_index_method()

    best_loss = 9999
    if params.tensorboard:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(exp_path)

    total_iter, grad_accum_iters = params.first_iter, 0

    return (params, model, opt, scheduler, train_data, train_loader, val_dataset,
            val_loader, perturbation_loader, generator, index_method, exp_path,
            training_smiles, pp, logger, writer, best_loss, total_iter,
            grad_accum_iters)
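# The LambdaLR above multiplies the optimizer's base lr by lr_lambda(iteration).
# A minimal sketch of a compatible lr_decay_multiplier: linear warm-up followed by
# multiplicative decay every lr_decay_interval steps, floored at min_lr. The exact
# decay rule (and how lr_decay_amount enters) is an assumption, not the repo's code:
def lr_decay_multiplier(iteration, warm_up_iters, decay_start_iter, lr_decay_amount,
                        lr_decay_frac, lr_decay_interval, min_lr, base_lr):
    if iteration < warm_up_iters:
        return iteration / max(1.0, float(warm_up_iters))  # linear warm-up from 0
    num_decays = max(0, (iteration - decay_start_iter) // lr_decay_interval)
    lr = max(base_lr * (lr_decay_frac ** num_decays), min_lr)
    return lr / base_lr  # LambdaLR scales the initial lr by this factor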
def __getitem__(self, index):
    set_seed_if(self.seed)

    # *** Initialise nodes ***
    unpadded_node_properties, init_node_properties, orig_node_properties, \
        node_property_target_types = {}, {}, {}, {}
    for property_name, property_info in self.node_properties.items():
        unpadded_data = property_info['data'][index]
        if property_name == 'charge':
            unpadded_data = unpadded_data + abs(self.min_charge)
        unpadded_node_properties[property_name] = unpadded_data
        num_nodes = len(unpadded_data)
        max_nodes = self.max_nodes if self.pad is True else num_nodes
        init_node_properties[property_name], orig_node_properties[property_name] = \
            self.get_orig_and_init_node_property(unpadded_data, max_nodes,
                                                 len(unpadded_data),
                                                 property_info['empty_index'])
        node_property_target_types[property_name] = np.zeros(max_nodes)

    # Create masks with 0 where a node does not exist or an edge would connect a
    # non-existent node, 1 everywhere else
    node_mask = torch.zeros(max_nodes)
    node_mask[:num_nodes] = 1
    edge_mask = torch.zeros((max_nodes, max_nodes))
    edge_mask[:num_nodes, :num_nodes] = 1
    edge_mask[np.arange(num_nodes), np.arange(num_nodes)] = 0

    # *** Initialise edges ***
    edge_coords = [(i, j) for i in range(num_nodes) for j in range(i + 1, num_nodes)]
    num_edges = len(edge_coords)
    unpadded_edge_properties, unpadded_edge_property_inds, init_edge_properties, \
        orig_edge_properties, edge_property_target_types = {}, {}, {}, {}, {}
    for property_name, property_info in self.edge_properties.items():
        unpadded_data = property_info['data'][index]
        unpadded_edge_properties[property_name] = unpadded_data
        # make sure the bond matrix is symmetric
        assert check_symmetric(unpadded_data.astype(int))
        unpadded_edge_property_inds[property_name] = np.array(
            [unpadded_data[i, j] for (i, j) in edge_coords])
        init_edge_properties[property_name], orig_edge_properties[property_name] = \
            self.get_orig_and_init_edge_property(unpadded_data, max_nodes,
                                                 len(unpadded_data),
                                                 property_info['empty_index'])
        edge_property_target_types[property_name] = np.zeros(edge_mask.shape)
    edge_coords = np.array(edge_coords)

    if self.do_not_corrupt is False:
        init_node_properties, node_property_target_types, init_edge_properties, \
            edge_property_target_types = self.corrupt_graph(
                unpadded_node_properties, init_node_properties,
                node_property_target_types, num_nodes,
                unpadded_edge_property_inds, init_edge_properties,
                edge_property_target_types, edge_coords, num_edges)
    if self.no_edge_present_type == 'zeros':
        edge_mask[np.where(init_edge_properties['edge_type'] == 0)] = 0

    # Cast to suitable types
    for property_name in init_node_properties.keys():
        init_node_properties[property_name] = torch.LongTensor(
            init_node_properties[property_name])
        orig_node_properties[property_name] = torch.LongTensor(
            orig_node_properties[property_name])
        node_property_target_types[property_name] = torch.IntTensor(
            node_property_target_types[property_name])
    for property_name in init_edge_properties.keys():
        init_edge_properties[property_name] = torch.LongTensor(
            init_edge_properties[property_name])
        orig_edge_properties[property_name] = torch.LongTensor(
            orig_edge_properties[property_name])
        edge_property_target_types[property_name] = torch.IntTensor(
            edge_property_target_types[property_name])

    graph_properties = {}
    for k, v in self.graph_properties.items():
        if self.normalise_graph_properties is True:
            graph_properties[k] = torch.Tensor([
                (v[index] - self.graph_property_stats[k]['mean']) /
                self.graph_property_stats[k]['std']])
        else:
            graph_properties[k] = torch.Tensor([v[index]])

    if self.num_binary_graph_properties > 0:
        positive_binary_properties = self.graph2binary_properties[index]
        binary_graph_properties = torch.zeros(self.num_binary_graph_properties)
        binary_graph_properties[np.array(positive_binary_properties)] = 1
    else:
        binary_graph_properties = []

    ret_list = [init_node_properties, orig_node_properties,
                node_property_target_types, node_mask, init_edge_properties,
                orig_edge_properties, edge_property_target_types, edge_mask,
                graph_properties, binary_graph_properties]
    return ret_list
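# A hedged, minimal sketch of the mask/replace/reconstruct corruption that
# corrupt_graph applies to a property vector. The target-type codes (1 = mask,
# 2 = replace, 3 = reconstruct, 0 = not a target) come from the training code's
# comments; the fractions and index choices here are illustrative defaults, not
# the repo's:
import numpy as np

def corrupt_property(init_vals, num_nodes, mask_index, num_classes,
                     target_frac=0.2, frac_mask=0.8, frac_replace=0.1):
    target_types = np.zeros(len(init_vals))
    num_targets = max(1, int(target_frac * num_nodes))
    target_inds = np.random.choice(num_nodes, num_targets, replace=False)
    for t in target_inds:
        r = np.random.rand()
        if r < frac_mask:
            init_vals[t] = mask_index                      # mask
            target_types[t] = 1
        elif r < frac_mask + frac_replace:
            init_vals[t] = np.random.randint(num_classes)  # replace with a random class
            target_types[t] = 2
        else:
            target_types[t] = 3                            # reconstruct: input left intact
    return init_vals, target_types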
def generate_with_evaluation(self, num_samples_to_generate, smiles_dataset_path,
                             output_dir, num_samples_to_evaluate,
                             evaluate_connected_only=False):
    load_path, load_iters = get_load_path(self.num_sampling_iters, self.num_argmax_iters,
                                          self.cp_save_dir)
    all_init_node_properties, all_init_edge_properties, all_node_masks, all_edge_masks = \
        self.get_all_init_variables(load_path, num_samples_to_generate)

    if self.save_init is True and self.random_init is True and load_iters == 0:
        # Save smiles representations of initialised molecules
        smiles_list = []
        num_nodes = all_node_masks.sum(-1)
        for i in range(len(all_init_node_properties['node_type'])):
            mol = graph_to_mol(
                {k: v[i][:int(num_nodes[i])].astype(int)
                 for k, v in all_init_node_properties.items()},
                {k: v[i][:int(num_nodes[i]), :int(num_nodes[i])].astype(int)
                 for k, v in all_init_edge_properties.items()},
                min_charge=self.train_data.min_charge, symbol_list=self.symbol_list)
            smiles_list.append(Chem.MolToSmiles(mol))
        save_smiles_list(smiles_list, os.path.join(output_dir, 'smiles_0_0.txt'))
        del smiles_list, mol, num_nodes

    if self.set_seed_at_load_iter is True:
        set_seed_if(load_iters)

    retrieve_train_graphs = self.retrieve_train_graphs
    for j in tqdm(range(load_iters, self.num_iters)):
        if j > 0:
            retrieve_train_graphs = False
            if self.generation_algorithm == 'Gibbs':
                self.train_data.do_not_corrupt = True
        loader = self.get_dataloader(all_init_node_properties, all_node_masks,
                                     all_init_edge_properties, num_samples_to_generate,
                                     retrieve_train_graphs)
        use_argmax = (j >= self.num_sampling_iters)
        all_init_node_properties, all_init_edge_properties, all_node_masks, \
            smiles_list = self.carry_out_iteration(loader, use_argmax)

        sampling_iters_completed = min(j + 1, self.num_sampling_iters)
        argmax_iters_completed = max(0, j + 1 - self.num_sampling_iters)
        if (j + 1 - load_iters) % self.checkpointing_period == 0:
            self.save_checkpoints(all_init_node_properties, all_init_edge_properties,
                                  sampling_iters_completed, argmax_iters_completed)
        if (j + 1 - load_iters) % self.save_period == 0 or \
                (self.save_finegrained is True and (j + 1) <= 10):
            smiles_output_path = os.path.join(
                output_dir, 'smiles_{}_{}.txt'.format(sampling_iters_completed,
                                                      argmax_iters_completed))
            save_smiles_list(smiles_list, smiles_output_path)
        if (j + 1 - load_iters) % self.evaluation_period == 0 or \
                (self.evaluate_finegrained is True and (j + 1) <= 10):
            json_output_path = os.path.join(
                output_dir, 'distribution_results_{}_{}.json'.format(
                    sampling_iters_completed, argmax_iters_completed))
            evaluate_uncond_generation(
                MockGenerator(smiles_list, num_samples_to_generate),
                smiles_dataset_path, json_output_path, num_samples_to_evaluate,
                evaluate_connected_only)
            if self.cond_property_values:
                cond_json_output_path = os.path.join(
                    output_dir, 'cond_results_{}_{}.json'.format(
                        sampling_iters_completed, argmax_iters_completed))
                self.evaluate_cond_generation(smiles_list[:num_samples_to_evaluate],
                                              cond_json_output_path)
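# Output files above are keyed by progress, e.g. smiles_{sampling}_{argmax}.txt after
# sampling_iters_completed sampling steps and argmax_iters_completed argmax steps.
# A minimal sketch of the save_smiles_list helper assumed here (one SMILES string
# per line; the repo's helper may differ):
def save_smiles_list(smiles_list, output_path):
    with open(output_path, 'w') as f:
        f.write('\n'.join(smiles_list))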
def main(params):
    model_cls = MODELS_DICT[params.model_name]
    model = model_cls(params)
    params, model, opt, scheduler, train_data, train_loader, val_dataset, val_loader, \
        perturbation_loader, generator, index_method, exp_path, training_smiles, pp, \
        logger, writer, best_loss, total_iter, grad_accum_iters = setup_data_and_model(params, model)

    for epoch in range(1, params.num_epochs + 1):
        print('Starting epoch {}'.format(epoch), flush=True)
        for train_batch in train_loader:
            if total_iter % 100 == 0:
                print(total_iter, flush=True)
            if total_iter == params.max_steps:
                logger.info('Done training')
                break
            model.train()

            # Training step.
            # init is what goes into the model; original is the uncorrupted data
            # (compared with predictions in the loss calculation). Masks are 1 at
            # positions corresponding to nodes that exist, 0 at empty/padded
            # positions. target_types are 1 at positions to mask, 2 to replace,
            # 3 to reconstruct, and 0 at positions not to predict.
            # target_inds and target_coords are 1 at locations to be predicted and
            # 0 elsewhere; they are calculated here rather than in the dataloader.
            init_nodes, init_edges, original_node_inds, original_adj_mats, node_masks, \
                edge_masks, node_target_types, edge_target_types, init_hydrogens, \
                original_hydrogens, init_charge, orig_charge, init_is_in_ring, \
                orig_is_in_ring, init_is_aromatic, orig_is_aromatic, init_chirality, \
                orig_chirality, hydrogen_target_types, charge_target_types, \
                is_in_ring_target_types, is_aromatic_target_types, \
                chirality_target_types = train_batch
            node_target_inds_vector = getattr(node_target_types != 0, index_method)()
            edge_target_coords_matrix = getattr(edge_target_types != 0, index_method)()
            hydrogen_target_inds_vector = getattr(hydrogen_target_types != 0, index_method)()
            charge_target_inds_vector = getattr(charge_target_types != 0, index_method)()
            is_in_ring_target_inds_vector = getattr(is_in_ring_target_types != 0, index_method)()
            is_aromatic_target_inds_vector = getattr(is_aromatic_target_types != 0, index_method)()
            chirality_target_inds_vector = getattr(chirality_target_types != 0, index_method)()

            if params.local_cpu is False:
                init_nodes = init_nodes.cuda()
                init_edges = init_edges.cuda()
                original_node_inds = original_node_inds.cuda()
                original_adj_mats = original_adj_mats.cuda()
                node_masks = node_masks.cuda()
                edge_masks = edge_masks.cuda()
                node_target_types = node_target_types.cuda()
                edge_target_types = edge_target_types.cuda()
                init_hydrogens = init_hydrogens.cuda()
                original_hydrogens = original_hydrogens.cuda()
                init_charge = init_charge.cuda()
                orig_charge = orig_charge.cuda()
                init_is_in_ring = init_is_in_ring.cuda()
                orig_is_in_ring = orig_is_in_ring.cuda()
                init_is_aromatic = init_is_aromatic.cuda()
                orig_is_aromatic = orig_is_aromatic.cuda()
                init_chirality = init_chirality.cuda()
                orig_chirality = orig_chirality.cuda()
                hydrogen_target_types = hydrogen_target_types.cuda()
                charge_target_types = charge_target_types.cuda()
                is_in_ring_target_types = is_in_ring_target_types.cuda()
                is_aromatic_target_types = is_aromatic_target_types.cuda()
                chirality_target_types = chirality_target_types.cuda()
                if params.property_type is not None:
                    # NB: assumes the loader also yields `properties` when
                    # property_type is set; it is not unpacked above
                    properties = properties.cuda()

            if grad_accum_iters % params.grad_accum_iters == 0:
                opt.zero_grad()

            out = model(init_nodes, init_edges, node_masks, edge_masks, init_hydrogens,
                        init_charge, init_is_in_ring, init_is_aromatic, init_chirality)
            # take the first seven outputs explicitly; when property_type is set the
            # model is assumed to append property scores as an extra output
            node_scores, edge_scores, hydrogen_scores, charge_scores, is_in_ring_scores, \
                is_aromatic_scores, chirality_scores = out[:7]
            node_num_classes = node_scores.shape[-1]
            edge_num_classes = edge_scores.shape[-1]
            hydrogen_num_classes = hydrogen_scores.shape[-1]
            charge_num_classes = charge_scores.shape[-1]
            is_in_ring_num_classes = is_in_ring_scores.shape[-1]
            is_aromatic_num_classes = is_aromatic_scores.shape[-1]
            chirality_num_classes = chirality_scores.shape[-1]
            if model.property_type is not None:
                property_scores = out[-1]

            # slice out target data structures
            node_scores, target_nodes, node_target_types = get_only_target_info(
                node_scores, original_node_inds, node_target_inds_vector,
                node_num_classes, node_target_types)
            edge_scores, target_adj_mats, edge_target_types = get_only_target_info(
                edge_scores, original_adj_mats, edge_target_coords_matrix,
                edge_num_classes, edge_target_types)
            if params.embed_hs is True:
                hydrogen_scores, target_hydrogens, hydrogen_target_types = get_only_target_info(
                    hydrogen_scores, original_hydrogens, hydrogen_target_inds_vector,
                    hydrogen_num_classes, hydrogen_target_types)
                charge_scores, target_charge, charge_target_types = get_only_target_info(
                    charge_scores, orig_charge, charge_target_inds_vector,
                    charge_num_classes, charge_target_types)
                is_in_ring_scores, target_is_in_ring, is_in_ring_target_types = get_only_target_info(
                    is_in_ring_scores, orig_is_in_ring, is_in_ring_target_inds_vector,
                    is_in_ring_num_classes, is_in_ring_target_types)
                is_aromatic_scores, target_is_aromatic, is_aromatic_target_types = get_only_target_info(
                    is_aromatic_scores, orig_is_aromatic, is_aromatic_target_inds_vector,
                    is_aromatic_num_classes, is_aromatic_target_types)
                chirality_scores, target_chirality, chirality_target_types = get_only_target_info(
                    chirality_scores, orig_chirality, chirality_target_inds_vector,
                    chirality_num_classes, chirality_target_types)

            num_components = len(target_nodes) + len(target_adj_mats)
            if params.embed_hs is True:
                num_components += len(target_hydrogens)

            # calculate losses
            losses = []
            results_dict = {}
            metrics_to_print = []
            if params.target_data_structs in ['nodes', 'both', 'random'] and \
                    len(target_nodes) > 0:
                node_loss = get_loss(target_nodes, node_scores, params.equalise,
                                     params.loss_normalisation_type, num_components,
                                     params.local_cpu)
                losses.append(node_loss)
                node_preds = torch.argmax(F.softmax(node_scores, -1), dim=-1)
                nodes_acc, mask_node_acc, replace_node_acc, recon_node_acc = get_ds_stats(
                    node_scores, target_nodes, node_target_types)
                carbon_nodes_correct, noncarbon_nodes_correct, carbon_nodes_acc, \
                    noncarbon_nodes_acc = get_majority_and_minority_stats(
                        node_preds, target_nodes, 1)
                results_dict.update({'nodes_acc': nodes_acc,
                                     'carbon_nodes_acc': carbon_nodes_acc,
                                     'noncarbon_nodes_acc': noncarbon_nodes_acc,
                                     'mask_node_acc': mask_node_acc,
                                     'replace_node_acc': replace_node_acc,
                                     'recon_node_acc': recon_node_acc,
                                     'node_loss': node_loss})
                metrics_to_print.extend(['node_loss', 'nodes_acc', 'carbon_nodes_acc',
                                         'noncarbon_nodes_acc'])

            def node_property_computations(name, scores, targets, target_types,
                                           binary=False):
                loss = get_loss(targets, scores, params.equalise,
                                params.loss_normalisation_type, num_components,
                                params.local_cpu, binary=binary)
                losses.append(loss)
                acc, mask_acc, replace_acc, recon_acc = get_ds_stats(scores, targets,
                                                                     target_types)
                results_dict.update({'{}_acc'.format(name): acc,
                                     'mask_{}_acc'.format(name): mask_acc,
                                     'replace_{}_acc'.format(name): replace_acc,
                                     'recon_{}_acc'.format(name): recon_acc,
                                     '{}_loss'.format(name): loss})
                metrics_to_print.extend(['{}_loss'.format(name)])

            if params.embed_hs is True:
                node_property_computations('hydrogen', hydrogen_scores,
                                           target_hydrogens, hydrogen_target_types)
                node_property_computations('charge', charge_scores, target_charge,
                                           charge_target_types)
                node_property_computations('is_in_ring', is_in_ring_scores,
                                           target_is_in_ring, is_in_ring_target_types)
                node_property_computations('is_aromatic', is_aromatic_scores,
                                           target_is_aromatic, is_aromatic_target_types)
                node_property_computations('chirality', chirality_scores,
                                           target_chirality, chirality_target_types)

            if params.target_data_structs in ['edges', 'both', 'random'] and \
                    len(target_adj_mats) > 0:
                edge_loss = get_loss(target_adj_mats, edge_scores, params.equalise,
                                     params.loss_normalisation_type, num_components,
                                     params.local_cpu)
                losses.append(edge_loss)
                edge_preds = torch.argmax(F.softmax(edge_scores, -1), dim=-1)
                edges_acc, mask_edge_acc, replace_edge_acc, recon_edge_acc = get_ds_stats(
                    edge_scores, target_adj_mats, edge_target_types)
                no_edge_correct, edge_present_correct, no_edge_acc, edge_present_acc = \
                    get_majority_and_minority_stats(edge_preds, target_adj_mats, 0)
                results_dict.update({'edges_acc': edges_acc,
                                     'edge_present_acc': edge_present_acc,
                                     'no_edge_acc': no_edge_acc,
                                     'mask_edge_acc': mask_edge_acc,
                                     'replace_edge_acc': replace_edge_acc,
                                     'recon_edge_acc': recon_edge_acc,
                                     'edge_loss': edge_loss})
                metrics_to_print.extend(['edge_loss', 'edges_acc', 'edge_present_acc',
                                         'no_edge_acc'])

            if params.property_type is not None:
                property_loss = model.property_loss(property_scores,
                                                    properties) / params.batch_size
                losses.append(property_loss)
                results_dict.update({'property_loss': property_loss})
                metrics_to_print.extend(['property_loss'])

            loss = sum(losses)
            if params.no_update is False:
                (loss / params.grad_accum_iters).backward()
                grad_accum_iters += 1
                if grad_accum_iters % params.grad_accum_iters == 0:
                    # clip grad norm
                    if params.clip_grad_norm > -1:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       params.clip_grad_norm)
                    opt.step()
                    grad_accum_iters = 0
            if (total_iter + 1) <= params.warm_up_iters or \
                    (total_iter + 1) % params.lr_decay_interval == 0:
                scheduler.step(total_iter + 1)

            if params.check_pred_validity is True:
                if params.target_data_structs in ['nodes', 'both', 'random'] and \
                        len(target_nodes) > 0:
                    init_nodes[node_target_inds_vector] = node_preds
                if params.target_data_structs in ['edges', 'both', 'random'] and \
                        len(target_adj_mats) > 0:
                    init_edges[edge_target_coords_matrix] = edge_preds
                is_valid_list, is_connected_list = check_validity(init_nodes, init_edges)
                percent_valid = np.mean(is_valid_list) * 100
                valid_percent_connected = np.mean(is_connected_list) * 100
                results_dict.update({'percent_valid': percent_valid,
                                     'percent_connected': valid_percent_connected})

            if params.suppress_train_log is False:
                log_string = ''
                for name, value in results_dict.items():
                    if name in metrics_to_print:
                        log_string += ', {} = {:.2f}'.format(name, value)
                log_string = 'total_iter = {0:d}, loss = {1:.2f}'.format(
                    total_iter, loss.cpu().item()) + log_string
                logger.info(log_string)

            if params.target_frac_inc_after is not None and total_iter > 0 and \
                    total_iter % params.target_frac_inc_after == 0:
                train_data.node_target_frac = min(
                    train_data.node_target_frac + params.target_frac_inc_amount,
                    params.max_target_frac)
                train_data.edge_target_frac = min(
                    train_data.edge_target_frac + params.target_frac_inc_amount,
                    params.max_target_frac)
                results_dict.update({'node_target_frac': train_data.node_target_frac})
                results_dict.update({'edge_target_frac': train_data.edge_target_frac})

            if params.tensorboard and total_iter % int(params.log_train_steps) == 0:
                results_dict.update({'loss': loss, 'lr': opt.param_groups[0]['lr']})
                write_tensorboard(writer, 'train', results_dict, total_iter)

            dist_names = [['mask_edge_correct_dist', 'mask_edge_incorrect_pred_dist',
                           'mask_edge_incorrect_true_dist'],
                          ['replace_edge_correct_dist', 'replace_edge_incorrect_pred_dist',
                           'replace_edge_incorrect_true_dist'],
                          ['recon_edge_correct_dist', 'recon_edge_incorrect_pred_dist',
                           'recon_edge_incorrect_true_dist']]
            distributions = np.zeros((3, 3, params.num_edge_types))

            if total_iter > 0 and total_iter % params.val_after == 0:
                logger.info('Validating')
                val_loss, property_loss, node_loss, edge_loss, hydrogen_loss, \
                    num_data_points = 0, 0, 0, 0, 0, 0
                charge_loss, is_in_ring_loss, is_aromatic_loss, chirality_loss = 0, 0, 0, 0
                model.eval()
                set_seed_if(params.seed)
                nodes_correct, edges_correct, total_nodes, total_edges = 0, 0, 0, 0
                carbon_nodes_correct, noncarbon_nodes_correct, total_carbon_nodes, \
                    total_noncarbon_nodes = 0, 0, 0, 0
                mask_nodes_correct, replace_nodes_correct, recon_nodes_correct = 0, 0, 0
                total_mask_nodes, total_replace_nodes, total_recon_nodes = 0, 0, 0
                no_edge_correct, edge_present_correct, total_no_edges, \
                    total_edges_present = 0, 0, 0, 0
                mask_edges_correct, replace_edges_correct, recon_edges_correct = 0, 0, 0
                total_mask_edges, total_replace_edges, total_recon_edges = 0, 0, 0
                hydrogens_correct, mask_hydrogens_correct, replace_hydrogens_correct, \
                    recon_hydrogens_correct, total_mask_hydrogens, \
                    total_replace_hydrogens, total_recon_hydrogens, \
                    total_hydrogens = 0, 0, 0, 0, 0, 0, 0, 0
                is_valid_list, is_connected_list = [], []

                for init_nodes, init_edges, original_node_inds, original_adj_mats, \
                        node_masks, edge_masks, node_target_types, edge_target_types, \
                        init_hydrogens, original_hydrogens, init_charge, orig_charge, \
                        init_is_in_ring, orig_is_in_ring, init_is_aromatic, \
                        orig_is_aromatic, init_chirality, orig_chirality, \
                        hydrogen_target_types, charge_target_types, \
                        is_in_ring_target_types, is_aromatic_target_types, \
                        chirality_target_types in val_loader:
                    node_target_inds_vector = getattr(node_target_types != 0, index_method)()
                    edge_target_coords_matrix = getattr(edge_target_types != 0, index_method)()
                    hydrogen_target_inds_vector = getattr(hydrogen_target_types != 0, index_method)()
                    charge_target_inds_vector = getattr(charge_target_types != 0, index_method)()
                    is_in_ring_target_inds_vector = getattr(is_in_ring_target_types != 0, index_method)()
                    is_aromatic_target_inds_vector = getattr(is_aromatic_target_types != 0, index_method)()
                    chirality_target_inds_vector = getattr(chirality_target_types != 0, index_method)()
                    if params.local_cpu is False:
                        init_nodes = init_nodes.cuda()
                        init_edges = init_edges.cuda()
                        original_node_inds = original_node_inds.cuda()
                        original_adj_mats = original_adj_mats.cuda()
                        node_masks = node_masks.cuda()
                        edge_masks = edge_masks.cuda()
                        node_target_types = node_target_types.cuda()
                        edge_target_types = edge_target_types.cuda()
                        init_hydrogens = init_hydrogens.cuda()
                        original_hydrogens = original_hydrogens.cuda()
                        init_charge = init_charge.cuda()
                        orig_charge = orig_charge.cuda()
                        init_is_in_ring = init_is_in_ring.cuda()
                        orig_is_in_ring = orig_is_in_ring.cuda()
                        init_is_aromatic = init_is_aromatic.cuda()
                        orig_is_aromatic = orig_is_aromatic.cuda()
                        init_chirality = init_chirality.cuda()
                        orig_chirality = orig_chirality.cuda()
                        hydrogen_target_types = hydrogen_target_types.cuda()
                        charge_target_types = charge_target_types.cuda()
                        is_in_ring_target_types = is_in_ring_target_types.cuda()
                        is_aromatic_target_types = is_aromatic_target_types.cuda()
                        chirality_target_types = chirality_target_types.cuda()
                        if params.embed_hs is True:
                            original_hydrogens = original_hydrogens.cuda()
                        if params.property_type is not None:
                            properties = properties.cuda()
                    batch_size = init_nodes.shape[0]
                    with torch.no_grad():
                        out = model(init_nodes, init_edges, node_masks, edge_masks,
                                    init_hydrogens, init_charge, init_is_in_ring,
                                    init_is_aromatic, init_chirality)
                    node_scores, edge_scores, hydrogen_scores, charge_scores, \
                        is_in_ring_scores, is_aromatic_scores, chirality_scores = out[:7]
                    if model.property_type is not None:
                        property_scores = out[-1]
                    node_num_classes = node_scores.shape[-1]
                    edge_num_classes = edge_scores.shape[-1]
                    hydrogen_num_classes = hydrogen_scores.shape[-1]
                    charge_num_classes = charge_scores.shape[-1]
                    is_in_ring_num_classes = is_in_ring_scores.shape[-1]
                    is_aromatic_num_classes = is_aromatic_scores.shape[-1]
                    chirality_num_classes = chirality_scores.shape[-1]

                    node_scores, target_nodes, node_target_types = get_only_target_info(
                        node_scores, original_node_inds, node_target_inds_vector,
                        node_num_classes, node_target_types)
                    edge_scores, target_adj_mats, edge_target_types = get_only_target_info(
                        edge_scores, original_adj_mats, edge_target_coords_matrix,
                        edge_num_classes, edge_target_types)
                    if params.embed_hs is True:
                        hydrogen_scores, target_hydrogens, hydrogen_target_types = \
                            get_only_target_info(hydrogen_scores, original_hydrogens,
                                                 hydrogen_target_inds_vector,
                                                 hydrogen_num_classes,
                                                 hydrogen_target_types)
                        charge_scores, target_charge, charge_target_types = \
                            get_only_target_info(charge_scores, orig_charge,
                                                 charge_target_inds_vector,
                                                 charge_num_classes, charge_target_types)
                        is_in_ring_scores, target_is_in_ring, is_in_ring_target_types = \
                            get_only_target_info(is_in_ring_scores, orig_is_in_ring,
                                                 is_in_ring_target_inds_vector,
                                                 is_in_ring_num_classes,
                                                 is_in_ring_target_types)
                        is_aromatic_scores, target_is_aromatic, is_aromatic_target_types = \
                            get_only_target_info(is_aromatic_scores, orig_is_aromatic,
                                                 is_aromatic_target_inds_vector,
                                                 is_aromatic_num_classes,
                                                 is_aromatic_target_types)
                        chirality_scores, target_chirality, chirality_target_types = \
                            get_only_target_info(chirality_scores, orig_chirality,
                                                 chirality_target_inds_vector,
                                                 chirality_num_classes,
                                                 chirality_target_types)

                    num_data_points += batch_size
                    losses = []
                    if params.target_data_structs in ['nodes', 'both', 'random'] and \
                            len(target_nodes) > 0:
                        weight = get_loss_weights(target_nodes, node_scores,
                                                  params.equalise, params.local_cpu)
                        iter_node_loss = F.cross_entropy(node_scores, target_nodes,
                                                         weight=weight, reduction='sum')
                        losses.append(iter_node_loss)
                        node_loss += iter_node_loss
                        nodes_correct, mask_nodes_correct, replace_nodes_correct, \
                            recon_nodes_correct, total_mask_nodes, total_replace_nodes, \
                            total_recon_nodes, total_nodes = update_ds_stats(
                                node_scores, target_nodes, node_target_types,
                                nodes_correct, mask_nodes_correct,
                                replace_nodes_correct, recon_nodes_correct,
                                total_mask_nodes, total_replace_nodes,
                                total_recon_nodes, total_nodes)
                        total_noncarbon_nodes += (target_nodes != 1).sum()
                        total_carbon_nodes += (target_nodes == 1).sum()
                        node_preds = torch.argmax(F.softmax(node_scores, -1), dim=-1)
                        noncarbon_nodes_correct += torch.mul(
                            (node_preds == target_nodes), (target_nodes != 1)).sum()
                        carbon_nodes_correct += torch.mul(
                            (node_preds == target_nodes), (target_nodes == 1)).sum()
                        if params.check_pred_validity is True:
                            init_nodes[node_target_inds_vector] = node_preds

                    def val_node_property_loss_computation(targets, scores, loss,
                                                           binary=False):
                        weight = get_loss_weights(targets, scores, params.equalise,
                                                  params.local_cpu)
                        if binary is True:
                            iter_loss = F.binary_cross_entropy_with_logits(
                                scores, targets, weight=weight, reduction='sum')
                        else:
                            iter_loss = F.cross_entropy(scores, targets, weight=weight,
                                                        reduction='sum')
                        losses.append(iter_loss)
                        loss += iter_loss
                        return loss

                    if params.embed_hs is True:
                        hydrogen_loss = val_node_property_loss_computation(
                            target_hydrogens, hydrogen_scores, hydrogen_loss)
                        hydrogens_correct, mask_hydrogens_correct, \
                            replace_hydrogens_correct, recon_hydrogens_correct, \
                            total_mask_hydrogens, total_replace_hydrogens, \
                            total_recon_hydrogens, total_hydrogens = update_ds_stats(
                                hydrogen_scores, target_hydrogens,
                                hydrogen_target_types, hydrogens_correct,
                                mask_hydrogens_correct, replace_hydrogens_correct,
                                recon_hydrogens_correct, total_mask_hydrogens,
                                total_replace_hydrogens, total_recon_hydrogens,
                                total_hydrogens)
                        charge_loss = val_node_property_loss_computation(
                            target_charge, charge_scores, charge_loss)
                        is_in_ring_loss = val_node_property_loss_computation(
                            target_is_in_ring, is_in_ring_scores, is_in_ring_loss)
                        is_aromatic_loss = val_node_property_loss_computation(
                            target_is_aromatic, is_aromatic_scores, is_aromatic_loss)
                        chirality_loss = val_node_property_loss_computation(
                            target_chirality, chirality_scores, chirality_loss)

                    if params.target_data_structs in ['edges', 'both', 'random'] and \
                            len(target_adj_mats) > 0:
                        weight = get_loss_weights(target_adj_mats, edge_scores,
                                                  params.equalise, params.local_cpu)
                        iter_edge_loss = F.cross_entropy(edge_scores, target_adj_mats,
                                                         weight=weight, reduction='sum')
                        losses.append(iter_edge_loss)
                        edge_loss += iter_edge_loss
                        edges_correct, mask_edges_correct, replace_edges_correct, \
                            recon_edges_correct, total_mask_edges, total_replace_edges, \
                            total_recon_edges, total_edges = update_ds_stats(
                                edge_scores, target_adj_mats, edge_target_types,
                                edges_correct, mask_edges_correct,
                                replace_edges_correct, recon_edges_correct,
                                total_mask_edges, total_replace_edges,
                                total_recon_edges, total_edges)
                        total_edges_present += (target_adj_mats != 0).sum()
                        total_no_edges += (target_adj_mats == 0).sum()
                        edge_preds = torch.argmax(F.softmax(edge_scores, -1), dim=-1)
                        edge_present_correct += torch.mul(
                            (edge_preds == target_adj_mats), (target_adj_mats != 0)).sum()
                        no_edge_correct += torch.mul(
                            (edge_preds == target_adj_mats), (target_adj_mats == 0)).sum()
                        distributions += get_all_result_distributions(
                            edge_preds, target_adj_mats, edge_target_types, [1, 2, 3],
                            params.num_edge_types)
                        if params.check_pred_validity is True:
                            init_edges[edge_target_coords_matrix] = edge_preds

                    if params.property_type is not None:
                        iter_property_loss = model.property_loss(property_scores,
                                                                 properties)
                        losses.append(iter_property_loss)
                        property_loss += iter_property_loss

                    loss = sum(losses).cpu().item()
                    val_loss += loss
                    if params.check_pred_validity is True:
                        is_valid_list, is_connected_list = check_validity(
                            init_nodes, init_edges, is_valid_list, is_connected_list)

                if params.property_type is not None:
                    avg_property_loss = float(property_loss) / float(num_data_points)
                if params.loss_normalisation_type == 'by_total':
                    num_components = float(total_nodes) + float(total_edges)
                    if params.embed_hs is True:
                        # the five extra node properties each contribute one target
                        # per node; this must be added after the base count is set
                        num_components += (total_nodes * 5)
                    avg_val_loss = float(val_loss) / float(num_components)
                    if params.target_data_structs in ['nodes', 'both', 'random'] and \
                            total_nodes > 0:
                        avg_node_loss = float(node_loss) / float(num_components)
                        if params.embed_hs is True:
                            avg_hydrogen_loss = float(hydrogen_loss) / num_components
                            avg_charge_loss = float(charge_loss) / num_components
                            avg_is_in_ring_loss = float(is_in_ring_loss) / num_components
                            avg_is_aromatic_loss = float(is_aromatic_loss) / num_components
                            avg_chirality_loss = float(chirality_loss) / num_components
                    if params.target_data_structs in ['edges', 'both', 'random'] and \
                            total_edges > 0:
                        avg_edge_loss = float(edge_loss) / float(num_components)
                elif params.loss_normalisation_type == 'by_component':
                    avg_val_loss = 0
                    if params.target_data_structs in ['nodes', 'both', 'random'] and \
                            total_nodes > 0:
                        avg_node_loss = float(node_loss) / float(total_nodes)
                        avg_val_loss += avg_node_loss
                        if params.embed_hs is True:
                            avg_hydrogen_loss = float(hydrogen_loss) / float(total_hydrogens)
                            avg_charge_loss = float(charge_loss) / float(total_nodes)
                            avg_is_in_ring_loss = float(is_in_ring_loss) / float(total_nodes)
                            avg_is_aromatic_loss = float(is_aromatic_loss) / float(total_nodes)
                            avg_chirality_loss = float(chirality_loss) / float(total_nodes)
                            avg_val_loss += avg_hydrogen_loss + avg_charge_loss + \
                                avg_is_in_ring_loss + avg_is_aromatic_loss + \
                                avg_chirality_loss
                    if params.target_data_structs in ['edges', 'both', 'random'] and \
                            total_edges > 0:
                        avg_edge_loss = float(edge_loss) / float(total_edges)
                        avg_val_loss += avg_edge_loss
                if params.property_type is not None:
                    avg_val_loss += avg_property_loss
                logger.info('Average validation loss: {0:.2f}'.format(avg_val_loss))
                val_iter = total_iter // params.val_after

                if params.check_pred_validity is True:
                    percent_valid = np.mean(is_valid_list) * 100
                    valid_percent_connected = np.mean(is_connected_list) * 100
                    logger.info('Percent valid: {}%'.format(percent_valid))
                    logger.info('Percent of valid molecules connected: {}%'.format(
                        valid_percent_connected))

                results_dict = {'loss': avg_val_loss}
                if params.target_data_structs in ['nodes', 'both', 'random'] and \
                        total_nodes > 0:
                    nodes_acc, noncarbon_nodes_acc, carbon_nodes_acc, mask_node_acc, \
                        replace_node_acc, recon_node_acc = accuracies_from_totals(
                            {nodes_correct: total_nodes,
                             noncarbon_nodes_correct: total_noncarbon_nodes,
                             carbon_nodes_correct: total_carbon_nodes,
                             mask_nodes_correct: total_mask_nodes,
                             replace_nodes_correct: total_replace_nodes,
                             recon_nodes_correct: total_recon_nodes})
                    results_dict.update({'nodes_acc': nodes_acc,
                                         'carbon_nodes_acc': carbon_nodes_acc,
                                         'noncarbon_nodes_acc': noncarbon_nodes_acc,
                                         'mask_node_acc': mask_node_acc,
                                         'replace_node_acc': replace_node_acc,
                                         'recon_node_acc': recon_node_acc,
                                         'node_loss': avg_node_loss})
                    logger.info('Node loss: {0:.2f}'.format(avg_node_loss))
                    logger.info('Node accuracy: {0:.2f}%'.format(nodes_acc))
                    logger.info('Non-Carbon Node accuracy: {0:.2f}%'.format(
                        noncarbon_nodes_acc))
                    logger.info('Carbon Node accuracy: {0:.2f}%'.format(carbon_nodes_acc))
                    logger.info(
                        'mask_node_acc {:.2f}, replace_node_acc {:.2f}, recon_node_acc {:.2f}'
                        .format(mask_node_acc, replace_node_acc, recon_node_acc))
                    if params.embed_hs is True:
                        hydrogen_acc, mask_hydrogen_acc, replace_hydrogen_acc, \
                            recon_hydrogen_acc = accuracies_from_totals(
                                {hydrogens_correct: total_hydrogens,
                                 mask_hydrogens_correct: total_mask_hydrogens,
                                 replace_hydrogens_correct: total_replace_hydrogens,
                                 recon_hydrogens_correct: total_recon_hydrogens})
                        results_dict.update({'hydrogen_acc': hydrogen_acc,
                                             'mask_hydrogen_acc': mask_hydrogen_acc,
                                             'replace_hydrogen_acc': replace_hydrogen_acc,
                                             'recon_hydrogen_acc': recon_hydrogen_acc,
                                             'hydrogen_loss': avg_hydrogen_loss})
                        logger.info('Hydrogen loss: {0:.2f}'.format(avg_hydrogen_loss))
                        logger.info('Hydrogen accuracy: {0:.2f}%'.format(hydrogen_acc))
                        logger.info(
                            'mask_hydrogen_acc {:.2f}, replace_hydrogen_acc {:.2f}, recon_hydrogen_acc {:.2f}'
                            .format(mask_hydrogen_acc, replace_hydrogen_acc,
                                    recon_hydrogen_acc))
                        results_dict.update({'charge_loss': avg_charge_loss,
                                             'is_in_ring_loss': avg_is_in_ring_loss,
                                             'is_aromatic_loss': avg_is_aromatic_loss,
                                             'chirality_loss': avg_chirality_loss})
                        logger.info('Charge loss: {0:.2f}'.format(avg_charge_loss))
                        logger.info('Is in ring loss: {0:.2f}'.format(avg_is_in_ring_loss))
                        logger.info('Is aromatic loss: {0:.2f}'.format(avg_is_aromatic_loss))
                        logger.info('Chirality loss: {0:.2f}'.format(avg_chirality_loss))

                if params.target_data_structs in ['edges', 'both', 'random'] and \
                        total_edges > 0:
                    edges_acc, no_edge_acc, edge_present_acc, mask_edge_acc, \
                        replace_edge_acc, recon_edge_acc = accuracies_from_totals(
                            {edges_correct: total_edges,
                             no_edge_correct: total_no_edges,
                             edge_present_correct: total_edges_present,
                             mask_edges_correct: total_mask_edges,
                             replace_edges_correct: total_replace_edges,
                             recon_edges_correct: total_recon_edges})
                    results_dict.update({'edges_acc': edges_acc,
                                         'edge_present_acc': edge_present_acc,
                                         'no_edge_acc': no_edge_acc,
                                         'mask_edge_acc': mask_edge_acc,
                                         'replace_edge_acc': replace_edge_acc,
                                         'recon_edge_acc': recon_edge_acc,
                                         'edge_loss': avg_edge_loss})
                    logger.info('Edge loss: {0:.2f}'.format(avg_edge_loss))
                    logger.info('Edge accuracy: {0:.2f}%'.format(edges_acc))
                    logger.info('Edge present accuracy: {0:.2f}%'.format(edge_present_acc))
                    logger.info('No edge accuracy: {0:.2f}%'.format(no_edge_acc))
                    logger.info(
                        'mask_edge_acc {:.2f}, replace_edge_acc {:.2f}, recon_edge_acc {:.2f}'
                        .format(mask_edge_acc, replace_edge_acc, recon_edge_acc))
                    logger.info('\n')
                    for i in range(distributions.shape[0]):
                        for j in range(distributions.shape[1]):
                            logger.info(dist_names[i][j] + ':\t' +
                                        str(distributions[i, j, :]))
                    logger.info('\n')

                if params.property_type is not None:
                    results_dict.update({'property_loss': avg_property_loss})
                    logger.info('Property loss: {0:.2f}'.format(avg_property_loss))

                if params.run_perturbation_analysis is True:
                    preds = run_perturbations(perturbation_loader, model,
                                              params.embed_hs, params.max_hs,
                                              params.perturbation_batch_size,
                                              params.local_cpu)
                    stats, percentages = aggregate_stats(preds)
                    logger.info('Percentages: {}'.format(pp.pformat(percentages)))

                if params.tensorboard:
                    if params.run_perturbation_analysis is True:
                        writer.add_scalars(
                            'dev/perturbation_stability',
                            {str(key): val for key, val in percentages.items()},
                            val_iter)
                    write_tensorboard(writer, 'dev', results_dict, val_iter)

                if params.gen_num_samples > 0:
                    calculate_gen_benchmarks(generator, params.gen_num_samples,
                                             training_smiles, logger)

                logger.info("----------------------------------")
                model_state_dict = model.module.state_dict() \
                    if torch.cuda.device_count() > 1 else model.state_dict()
                best_loss = save_checkpoints(total_iter, avg_val_loss, best_loss,
                                             model_state_dict, opt.state_dict(),
                                             exp_path, logger, params.no_save,
                                             params.save_all)
                # Reset random seed
                set_seed_if(params.seed)
                logger.info('Validation complete')

            total_iter += 1
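# accuracies_from_totals is called above with a dict mapping each correct-count to
# its corresponding total. The counts are tensors, so equal values still hash as
# distinct keys; with plain ints, duplicate counts would collide. A hedged sketch
# of the helper (result ordering and zero-division handling are assumptions):
def accuracies_from_totals(correct_to_total):
    return tuple(
        100.0 * float(correct) / float(total) if float(total) > 0 else 0.0
        for correct, total in correct_to_total.items()
    )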
def __getitem__(self, index):
    set_seed_if(self.seed)

    # Create arrays with the node empty index where a node does not exist, and the
    # node index everywhere else
    unpadded_node_inds = self.mol_nodeinds[index]
    max_nodes = self.max_nodes if self.pad is True else len(unpadded_node_inds)
    if self.binary_classification is True:
        unpadded_node_inds[np.where(unpadded_node_inds != 1)] = 0
    node_inds = np.ones(max_nodes) * self.node_empty_index
    num_nodes = len(unpadded_node_inds)
    node_inds[:num_nodes] = unpadded_node_inds

    # Create matrices with the edge empty index where an edge would connect a
    # non-existent node, and the edge index everywhere else
    unpadded_adj_mat = self.adj_mats[index]
    if self.binary_classification is True:
        unpadded_adj_mat[np.where(unpadded_adj_mat != 0)] = 1
    # make sure the bond matrix is symmetric
    assert check_symmetric(unpadded_adj_mat.astype(int))
    adj_mat = np.ones((max_nodes, max_nodes)) * self.edge_empty_index
    adj_mat[:num_nodes, :num_nodes] = unpadded_adj_mat

    # Create masks with 0 where a node does not exist or an edge would connect a
    # non-existent node, 1 everywhere else
    node_mask = torch.zeros(max_nodes)
    node_mask[:num_nodes] = 1
    edge_mask = torch.zeros((max_nodes, max_nodes))
    edge_mask[:num_nodes, :num_nodes] = 1
    edge_mask[np.arange(num_nodes), np.arange(num_nodes)] = 0

    # *** Initialise nodes ***
    unpadded_num_hs = self.num_hs[index]
    unpadded_charge = self.charge[index] + abs(self.min_charge)
    unpadded_is_in_ring = self.is_in_ring[index]
    unpadded_is_aromatic = self.is_aromatic[index]
    unpadded_chirality = self.chirality[index]
    init_hydrogens, orig_hydrogens = self.get_orig_and_init(
        unpadded_num_hs, max_nodes, num_nodes, self.h_empty_index)
    init_charge, orig_charge = self.get_orig_and_init(
        unpadded_charge, max_nodes, num_nodes, abs(self.min_charge))
    init_is_in_ring, orig_is_in_ring = self.get_orig_and_init(
        unpadded_is_in_ring, max_nodes, num_nodes, 0)
    init_is_aromatic, orig_is_aromatic = self.get_orig_and_init(
        unpadded_is_aromatic, max_nodes, num_nodes, 0)
    init_chirality, orig_chirality = self.get_orig_and_init(
        unpadded_chirality, max_nodes, num_nodes, 0)

    init_nodes = np.copy(node_inds)
    node_target_types = np.zeros(node_mask.shape)
    hydrogen_target_types = np.zeros(node_mask.shape)
    charge_target_types = np.zeros(node_mask.shape)
    is_in_ring_target_types = np.zeros(node_mask.shape)
    is_aromatic_target_types = np.zeros(node_mask.shape)
    chirality_target_types = np.zeros(node_mask.shape)
    if self.mask_all_ring_properties is True:
        init_is_in_ring[:] = self.is_in_ring_mask_index
        init_is_aromatic[:] = self.is_aromatic_mask_index

    # *** Initialise edges ***
    # Get (row, column) coordinates of the upper triangular part of the adjacency
    # matrix, excluding the diagonal. These are potential target edges. Also get
    # the values of the matrix at these indices.
    init_edges = np.copy(adj_mat)
    edge_coords, edge_vals = [], []
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            edge_coords.append((i, j))
            edge_vals.append(unpadded_adj_mat[i, j])
    edge_coords = np.array(edge_coords)
    unpadded_edge_inds = np.array(edge_vals)
    num_edges = len(edge_coords)
    edge_target_types = np.zeros(edge_mask.shape)

    # *** return values ***
    if self.do_not_corrupt is False:
        init_nodes, node_target_types, num_nodes, init_hydrogens, \
            hydrogen_target_types, init_charge, charge_target_types, \
            init_is_in_ring, is_in_ring_target_types, init_is_aromatic, \
            is_aromatic_target_types, init_chirality, chirality_target_types, \
            init_edges, edge_target_types, num_edges = self.corrupt_graph(
                init_nodes, unpadded_node_inds, node_target_types, num_nodes,
                init_hydrogens, unpadded_num_hs, hydrogen_target_types,
                init_charge, unpadded_charge, charge_target_types,
                init_is_in_ring, unpadded_is_in_ring, is_in_ring_target_types,
                init_is_aromatic, unpadded_is_aromatic, is_aromatic_target_types,
                init_chirality, unpadded_chirality, chirality_target_types,
                init_edges, edge_coords, unpadded_edge_inds, edge_target_types,
                num_edges)
    if self.no_edge_present_type == 'zeros':
        edge_mask[np.where(init_edges == 0)] = 0

    ret_list = [torch.LongTensor(init_nodes), torch.LongTensor(init_edges),
                torch.LongTensor(node_inds), torch.LongTensor(adj_mat),
                node_mask, edge_mask,
                node_target_types.astype(np.int8),
                edge_target_types.astype(np.int8),
                torch.LongTensor(init_hydrogens), torch.LongTensor(orig_hydrogens),
                torch.LongTensor(init_charge), torch.LongTensor(orig_charge),
                torch.LongTensor(init_is_in_ring), torch.LongTensor(orig_is_in_ring),
                torch.LongTensor(init_is_aromatic), torch.LongTensor(orig_is_aromatic),
                torch.LongTensor(init_chirality), torch.LongTensor(orig_chirality),
                hydrogen_target_types.astype(np.int8),
                charge_target_types.astype(np.int8),
                is_in_ring_target_types.astype(np.int8),
                is_aromatic_target_types.astype(np.int8),
                chirality_target_types.astype(np.int8)]
    return ret_list
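# The nested loop above enumerates the strict upper triangle of the adjacency
# matrix. An equivalent vectorised form, shown as a sketch rather than a drop-in
# patch:
import numpy as np

def upper_triangle_edges(unpadded_adj_mat, num_nodes):
    rows, cols = np.triu_indices(num_nodes, k=1)  # k=1 excludes the diagonal
    edge_coords = np.stack([rows, cols], axis=1)  # shape (num_edges, 2)
    edge_vals = unpadded_adj_mat[rows, cols]
    return edge_coords, edge_vals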