Example #1
    def generate(self, number_samples):
        load_path, load_iters = get_load_path(self.num_sampling_iters,
                                              self.num_argmax_iters,
                                              self.cp_save_dir)
        all_init_node_properties, all_init_edge_properties, all_node_masks, all_edge_masks = \
            self.get_all_init_variables(load_path, number_samples)

        if self.set_seed_at_load_iter is True:
            set_seed_if(load_iters)

        retrieve_train_graphs = self.retrieve_train_graphs
        for j in range(load_iters, self.num_iters):
            if j > 0:
                retrieve_train_graphs = False
                if self.generation_algorithm == 'Gibbs':
                    self.train_data.do_not_corrupt = True
            loader = self.get_dataloader(all_init_node_properties,
                                         all_node_masks,
                                         all_init_edge_properties,
                                         number_samples, retrieve_train_graphs)

            use_argmax = (j >= self.num_sampling_iters)
            all_init_node_properties, all_init_edge_properties, all_node_masks, \
                smiles_list = self.carry_out_iteration(loader, use_argmax)

        return smiles_list
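The use_argmax flag above switches decoding from stochastic sampling to greedy selection once the sampling iterations are exhausted. A minimal sketch of that switch, assuming per-item categorical logits (the actual decoding lives in carry_out_iteration):

import torch
import torch.nn.functional as F

def decode_scores(scores, use_argmax):
    """Pick one class per row of `scores`, a (num_items, num_classes) logit tensor."""
    probs = F.softmax(scores, dim=-1)
    if use_argmax:
        return probs.argmax(dim=-1)  # deterministic argmax iterations
    return torch.multinomial(probs, num_samples=1).squeeze(-1)  # sampling iterations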
Example #2
def main(params):
    model_cls = MODELS_DICT[params.model_name]
    model = model_cls(params)
    params, model, opt, scheduler, train_data, train_loader, val_dataset, val_loader, perturbation_loader, generator,\
    index_method, exp_path, training_smiles, pp, logger, writer, best_loss,\
    total_iter, grad_accum_iters = setup_data_and_model(params, model)

    for epoch in range(1, params.num_epochs+1):
        print('Starting epoch {}'.format(epoch), flush=True)
        for train_batch in train_loader:
            if total_iter % 100 == 0: print(total_iter, flush=True)
            if total_iter == params.max_steps:
                logger.info('Done training')
                break
            model.train()
            if hasattr(model, 'seq_model'): model.seq_model.eval()

            # Training step
            batch_init_graph, batch_orig_graph, batch_target_type_graph, _, \
            graph_properties, binary_graph_properties = train_batch
            # init is what goes into model
            # original are uncorrupted data (used for comparison with prediction in loss calculation)
            # masks are 1 in places corresponding to nodes that exist, 0 in other places (which are empty/padded)
            # target_types are 1 in places to mask, 2 in places to replace, 3 in places to reconstruct, 0 in places not to predict
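            # e.g. for a 5-node graph where nodes 1 and 3 were masked and node 4
            # replaced, ndata['node_type'] target types would be
            # tensor([0, 1, 0, 1, 2]); only nonzero positions enter the loss.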

            if params.local_cpu is False:
                batch_init_graph = batch_init_graph.to(torch.device('cuda:0'))
                batch_orig_graph = batch_orig_graph.to(torch.device('cuda:0'))
                dct_to_cuda_inplace(graph_properties)
                if len(binary_graph_properties) > 0: binary_graph_properties = binary_graph_properties.cuda()

            if grad_accum_iters % params.grad_accum_iters == 0:
                opt.zero_grad()

            _, batch_scores_graph, graph_property_scores = model(batch_init_graph, graph_properties,
                binary_graph_properties)

            num_components = sum([(v != 0).sum().numpy() for _, v in batch_target_type_graph.ndata.items()]) + \
                             sum([(v != 0).sum().numpy() for _, v in batch_target_type_graph.edata.items()])

            # calculate score
            losses = []
            results_dict = {}
            metrics_to_print = []
            if params.target_data_structs in ['nodes', 'both', 'random'] and \
                    batch_target_type_graph.ndata['node_type'].sum() > 0:
                node_losses = {}
                for name, target_type in batch_target_type_graph.ndata.items():
                    node_losses[name] = get_loss(
                        batch_orig_graph.ndata[name][target_type.numpy() != 0],
                        batch_scores_graph.ndata[name][target_type.numpy() != 0],
                        params.equalise, params.loss_normalisation_type, num_components, params.local_cpu)
                losses.extend(node_losses.values())
                results_dict.update(node_losses)
                metrics_to_print.extend(node_losses.keys())

            if params.target_data_structs in ['edges', 'both', 'random'] and \
                    batch_target_type_graph.edata['edge_type'].sum() > 0:
                edge_losses = {}
                for name, target_type in batch_target_type_graph.edata.items():
                    edge_losses[name] = get_loss(
                        batch_orig_graph.edata[name][target_type.numpy() != 0],
                        batch_scores_graph.edata[name][target_type.numpy() != 0],
                        params.equalise, params.loss_normalisation_type, num_components, params.local_cpu)
                losses.extend(edge_losses.values())
                results_dict.update(edge_losses)
                metrics_to_print.extend(edge_losses.keys())

            if params.predict_graph_properties is True:
                graph_property_losses = {}
                for name, scores in graph_property_scores.items():
                    graph_property_loss = normalise_loss(F.mse_loss(scores, graph_properties[name], reduction='sum'),
                                                         len(scores), num_components, params.loss_normalisation_type)
                    graph_property_losses[name] = graph_property_loss
                losses.extend(graph_property_losses.values())
                results_dict.update(graph_property_losses)
                metrics_to_print.extend(graph_property_losses.keys())

            loss = sum(losses)
            if params.no_update is False:
                (loss/params.grad_accum_iters).backward()
                grad_accum_iters += 1
                if grad_accum_iters % params.grad_accum_iters == 0:
                    # clip grad norm
                    if params.clip_grad_norm > -1:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), params.clip_grad_norm)
                    opt.step()
                    grad_accum_iters = 0
                if (total_iter+1) <= params.warm_up_iters or (total_iter+1) % params.lr_decay_interval == 0:
                    scheduler.step(total_iter+1)

            if params.suppress_train_log is False:
                log_string = ''
                for name, value in results_dict.items():
                    if name in metrics_to_print:
                        log_string += ', {} = {:.2f}'.format(name, value)
                log_string = 'total_iter = {0:d}, loss = {1:.2f}'.format(total_iter, loss.cpu().item()) + log_string
                logger.info(log_string)

            if params.target_frac_inc_after is not None and total_iter > 0 and total_iter % params.target_frac_inc_after == 0:
                train_data.node_target_frac = min(train_data.node_target_frac + params.target_frac_inc_amount,
                                                  params.max_target_frac)
                train_data.edge_target_frac = min(train_data.edge_target_frac + params.target_frac_inc_amount,
                                                  params.max_target_frac)
            results_dict.update({'node_target_frac': train_data.node_target_frac})
            results_dict.update({'edge_target_frac': train_data.edge_target_frac})

            if params.tensorboard and total_iter % int(params.log_train_steps) == 0:
                results_dict.update({'loss': loss, 'lr': opt.param_groups[0]['lr']})
                write_tensorboard(writer, 'train', results_dict, total_iter)

            if total_iter > 0 and total_iter % params.val_after == 0:
                logger.info('Validating')
                val_loss, num_data_points = 0, 0
                node_property_losses = {name: 0 for name in train_data.node_property_names}
                edge_property_losses = {name: 0 for name in train_data.edge_property_names}
                node_property_num_components = {name: 0 for name in train_data.node_property_names}
                edge_property_num_components = {name: 0 for name in train_data.edge_property_names}
                graph_property_losses = {name: 0 for name in params.graph_property_names}
                model.eval()
                set_seed_if(params.seed)
                for batch_init_graph, batch_orig_graph, batch_target_type_graph, _, \
                    graph_properties, binary_graph_properties in val_loader:

                    if params.local_cpu is False:
                        batch_init_graph = batch_init_graph.to(torch.device('cuda:0'))
                        batch_orig_graph = batch_orig_graph.to(torch.device('cuda:0'))
                        dct_to_cuda_inplace(graph_properties)
                        if len(binary_graph_properties) > 0: binary_graph_properties = binary_graph_properties.cuda()

                    with torch.no_grad():
                        _, batch_scores_graph, graph_property_scores = model(batch_init_graph, graph_properties,
                                                                             binary_graph_properties)

                    num_data_points += float(batch_orig_graph.batch_size)
                    losses = []
                    if params.target_data_structs in ['nodes', 'both', 'random'] and \
                            batch_target_type_graph.ndata['node_type'].sum() > 0:
                        for name, target_type in batch_target_type_graph.ndata.items():
                            iter_node_property_loss = F.cross_entropy(
                                batch_scores_graph.ndata[name][target_type.numpy() != 0],
                                batch_orig_graph.ndata[name][target_type.numpy() != 0], reduction='sum').cpu().item()
                            node_property_losses[name] += iter_node_property_loss
                            losses.append(iter_node_property_loss)
                            node_property_num_components[name] += float((target_type != 0).sum())

                    if params.target_data_structs in ['edges', 'both', 'random'] and \
                            batch_target_type_graph.edata['edge_type'].sum() > 0:
                        for name, target_type in batch_target_type_graph.edata.items():
                            iter_edge_property_loss = F.cross_entropy(
                                batch_scores_graph.edata[name][target_type.numpy() != 0],
                                batch_orig_graph.edata[name][target_type.numpy() != 0], reduction='sum').cpu().item()
                            edge_property_losses[name] += iter_edge_property_loss
                            losses.append(iter_edge_property_loss)
                            edge_property_num_components[name] += float((target_type != 0).sum())

                    if params.predict_graph_properties is True:
                        for name, scores in graph_property_scores.items():
                            iter_graph_property_loss = F.mse_loss(
                                scores, graph_properties[name], reduction='sum').cpu().item()
                            graph_property_losses[name] += iter_graph_property_loss
                            losses.append(iter_graph_property_loss)

                    val_loss += sum(losses)

                avg_node_property_losses, avg_edge_property_losses, avg_graph_property_losses = {}, {}, {}
                if params.loss_normalisation_type == 'by_total':
                    total_num_components = float(sum(node_property_num_components.values()) +
                                                 sum(edge_property_num_components.values()))
                    avg_val_loss = val_loss/total_num_components
                    if params.target_data_structs in ['nodes', 'both', 'random']:
                        for name, loss in node_property_losses.items():
                            avg_node_property_losses[name] = loss/total_num_components
                    if params.target_data_structs in ['edges', 'both', 'random']:
                        for name, loss in edge_property_losses.items():
                            avg_edge_property_losses[name] = loss/total_num_components
                    if params.predict_graph_properties is True:
                        for name, loss in graph_property_losses.items():
                            avg_graph_property_losses[name] = loss/total_num_components
                elif params.loss_normalisation_type == 'by_component':
                    avg_val_loss = 0
                    if params.target_data_structs in ['nodes', 'both', 'random']:
                        for name, loss in node_property_losses.items():
                            avg_node_property_losses[name] = loss/node_property_num_components[name]
                        avg_val_loss += sum(avg_node_property_losses.values())
                    if params.target_data_structs in ['edges', 'both', 'random']:
                        for name, loss in edge_property_losses.items():
                            avg_edge_property_losses[name] = loss/edge_property_num_components[name]
                        avg_val_loss += sum(avg_edge_property_losses.values())
                    if params.predict_graph_properties is True:
                        for name, loss in graph_property_losses.items():
                            avg_graph_property_losses[name] = loss/num_data_points
                        avg_val_loss += sum(avg_graph_property_losses.values())
                val_iter = total_iter // params.val_after

                results_dict = {'Validation_loss': avg_val_loss}
                for name, loss in avg_node_property_losses.items():
                    results_dict['{}_loss'.format(name)] = loss
                for name, loss in avg_edge_property_losses.items():
                    results_dict['{}_loss'.format(name)] = loss
                for name, loss in avg_graph_property_losses.items():
                    results_dict['{}_loss'.format(name)] = loss

                for name, loss in results_dict.items():
                    logger.info('{}: {:.2f}'.format(name, loss))

                if params.tensorboard:
                    write_tensorboard(writer, 'Dev', results_dict, val_iter)
                if params.gen_num_samples > 0:
                    calculate_gen_benchmarks(generator, params.gen_num_samples, training_smiles, logger)

                logger.info("----------------------------------")
                model_state_dict = model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict()
                best_loss = save_checkpoints(total_iter, avg_val_loss, best_loss, model_state_dict, opt.state_dict(),
                                             exp_path, logger, params.no_save, params.save_all)
                # Reset random seed
                set_seed_if(params.seed)
                logger.info('Validation complete')
            total_iter += 1
        if total_iter == params.max_steps:
            break  # propagate the max_steps stop out of the epoch loop as well
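The grad_accum_iters bookkeeping above implements standard gradient accumulation: gradients are zeroed at the start of each window, each micro-batch loss is scaled by the window size before backward, and the optimizer steps once per window. A self-contained sketch of the same pattern, with model, opt, loader and loss_fn as hypothetical stand-ins:

def train_with_accumulation(model, opt, loader, loss_fn, accum_steps):
    grad_accum_iters = 0
    for inputs, targets in loader:
        if grad_accum_iters % accum_steps == 0:
            opt.zero_grad()  # start a fresh accumulation window
        loss = loss_fn(model(inputs), targets)
        (loss / accum_steps).backward()  # scale so the summed gradients average
        grad_accum_iters += 1
        if grad_accum_iters % accum_steps == 0:
            opt.step()  # one parameter update per accum_steps micro-batches
            grad_accum_iters = 0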
Example #3
def setup_data_and_model(params, model):
    # Variables that may not otherwise be assigned
    writer = perturbation_loader = generator = training_smiles = None

    # setup random seeds
    if params.val_seed is None: params.val_seed = params.seed
    set_seed_if(params.seed)

    exp_path = os.path.join(params.dump_path, params.exp_name)
    # create exp path if it doesn't exist
    if not os.path.exists(exp_path):
        os.makedirs(exp_path)
    # create logger
    logger = create_logger(os.path.join(exp_path, 'train.log'), 0)
    pp = pprint.PrettyPrinter()
    logger.info("============ Initialized logger ============")
    logger.info("Random seed is {}".format(params.seed))
    if params.suppress_params is False:
        logger.info("\n".join("%s: %s" % (k, str(v))
                          for k, v in sorted(dict(vars(params)).items())))
        logger.info("Running command: %s" % 'python ' + ' '.join(sys.argv))
    logger.info("The experiment will be stored in %s\n" % exp_path)
    logger.info("")
    # load data
    train_data, val_dataset, train_loader, val_loader = load_graph_data(params)

    logger.info('train_loader len is {}'.format(len(train_loader)))
    logger.info('val_loader len is {}'.format(len(val_loader)))

    if params.num_binary_graph_properties > 0 and params.pretrained_property_embeddings_path:
        model.binary_graph_property_embedding_layer.weight.data = \
            torch.Tensor(np.load(params.pretrained_property_embeddings_path).T)
    if params.load_latest is True:
        load_prefix = 'latest'
    elif params.load_best is True:
        load_prefix = 'best'
    else:
        load_prefix = None

    if load_prefix is not None:
        if params.local_cpu is True:
            model.load_state_dict(torch.load(os.path.join(exp_path, '{}_model'.format(load_prefix)), map_location='cpu'))
        else:
            model.load_state_dict(torch.load(os.path.join(exp_path, '{}_model'.format(load_prefix))))
    if params.local_cpu is False:
        model = model.cuda()
    if params.gen_num_samples > 0:
        generator = GraphGenerator(train_data, model, params.gen_random_init, params.gen_num_iters, params.gen_predict_deterministically, params.local_cpu)
        with open(params.smiles_path) as f:
            smiles = f.read().split('\n')
            training_smiles = smiles[:int(params.smiles_train_split * len(smiles))]
            del smiles
    opt = get_optimizer(model.parameters(), params.optimizer)
    if load_prefix is not None:
        opt.load_state_dict(torch.load(os.path.join(exp_path, '{}_opt_sd'.format(load_prefix))))

    lr = opt.param_groups[0]['lr']
    lr_lambda = lambda iteration: lr_decay_multiplier(iteration, params.warm_up_iters, params.decay_start_iter,
                                                      params.lr_decay_amount, params.lr_decay_frac,
                                                      params.lr_decay_interval, params.min_lr, lr)
    scheduler = LambdaLR(opt, lr_lambda)
    index_method = get_index_method()

    best_loss = float('inf')  # sentinel so the first validation always improves
    if params.tensorboard:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(exp_path)

    total_iter, grad_accum_iters = params.first_iter, 0

    return params, model, opt, scheduler, train_data, train_loader, val_dataset, val_loader, perturbation_loader,\
           generator, index_method, exp_path, training_smiles, pp, logger, writer, best_loss, total_iter,\
           grad_accum_iters
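Because the schedule is driven through LambdaLR, lr_decay_multiplier must return a multiplier on the base learning rate. A hypothetical stand-in consistent with the arguments passed above (linear warm-up, then multiplicative decay every lr_decay_interval iterations, floored at min_lr); the repository's real implementation may differ in detail:

def lr_decay_multiplier(iteration, warm_up_iters, decay_start_iter,
                        lr_decay_amount, lr_decay_frac, lr_decay_interval,
                        min_lr, base_lr):
    # lr_decay_amount is unused in this sketch; it presumably selects an
    # additive decay variant in the real implementation.
    if iteration < warm_up_iters:
        return iteration / max(1, warm_up_iters)  # linear warm-up from 0
    if iteration < decay_start_iter:
        return 1.0  # hold at the base learning rate
    num_decays = (iteration - decay_start_iter) // lr_decay_interval + 1
    decayed = base_lr * (lr_decay_frac ** num_decays)
    return max(min_lr, decayed) / base_lr  # LambdaLR rescales by base_lr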
Example #4
    def __getitem__(self, index):
        set_seed_if(self.seed)

        # *** Initialise nodes ***
        unpadded_node_properties, init_node_properties, orig_node_properties, node_property_target_types = {}, {}, {}, {}
        for property_name, property_info in self.node_properties.items():
            unpadded_data = property_info['data'][index]
            if property_name == 'charge':
                unpadded_data = unpadded_data + abs(self.min_charge)
            unpadded_node_properties[property_name] = unpadded_data
            num_nodes = len(unpadded_data)
            max_nodes = self.max_nodes if self.pad is True else num_nodes
            init_node_properties[property_name], orig_node_properties[property_name] = \
                self.get_orig_and_init_node_property(
                    unpadded_data, max_nodes, len(unpadded_data), property_info['empty_index'])
            node_property_target_types[property_name] = np.zeros(max_nodes)

        # Create masks with 0 where node does not exist or edge would connect non-existent node, 1 everywhere else
        node_mask = torch.zeros(max_nodes)
        node_mask[:num_nodes] = 1
        edge_mask = torch.zeros((max_nodes, max_nodes))
        edge_mask[:num_nodes, :num_nodes] = 1
        edge_mask[np.arange(num_nodes), np.arange(num_nodes)] = 0
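        # e.g. with num_nodes == 2 and max_nodes == 3:
        # node_mask == [1, 1, 0]; edge_mask is 1 only at (0, 1) and (1, 0)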

        # *** Initialise edges ***
        edge_coords = [(i, j) for i in range(num_nodes)
                       for j in range(i + 1, num_nodes)]
        num_edges = len(edge_coords)
        unpadded_edge_properties, unpadded_edge_property_inds, init_edge_properties, orig_edge_properties,\
            edge_property_target_types = {}, {}, {}, {}, {}
        for property_name, property_info in self.edge_properties.items():
            unpadded_data = property_info['data'][index]
            unpadded_edge_properties[property_name] = unpadded_data
            assert check_symmetric(unpadded_data.astype(int))  # bond matrix must be symmetric
            unpadded_edge_property_inds[property_name] = np.array(
                [unpadded_data[i, j] for (i, j) in edge_coords])
            init_edge_properties[property_name], orig_edge_properties[property_name] =\
                self.get_orig_and_init_edge_property(
                unpadded_data, max_nodes, len(unpadded_data), property_info['empty_index'])
            edge_property_target_types[property_name] = np.zeros(
                edge_mask.shape)
        edge_coords = np.array(edge_coords)

        if self.do_not_corrupt is False:
            init_node_properties, node_property_target_types, init_edge_properties, edge_property_target_types = \
                self.corrupt_graph(unpadded_node_properties, init_node_properties, node_property_target_types,
                                   num_nodes, unpadded_edge_property_inds, init_edge_properties,
                                   edge_property_target_types, edge_coords, num_edges)

        if self.no_edge_present_type == 'zeros':
            edge_mask[np.where(init_edge_properties['edge_type'] == 0)] = 0

        # Cast to suitable type

        for property_name in init_node_properties.keys():
            init_node_properties[property_name] = torch.LongTensor(
                init_node_properties[property_name])
            orig_node_properties[property_name] = torch.LongTensor(
                orig_node_properties[property_name])
            node_property_target_types[property_name] = torch.IntTensor(
                node_property_target_types[property_name])

        for property_name in init_edge_properties.keys():
            init_edge_properties[property_name] = torch.LongTensor(
                init_edge_properties[property_name])
            orig_edge_properties[property_name] = torch.LongTensor(
                orig_edge_properties[property_name])
            edge_property_target_types[property_name] = torch.IntTensor(
                edge_property_target_types[property_name])

        graph_properties = {}
        for k, v in self.graph_properties.items():
            if self.normalise_graph_properties is True:
                graph_properties[k] = torch.Tensor(
                    [(v[index] - self.graph_property_stats[k]['mean']) /
                     self.graph_property_stats[k]['std']])
            else:
                graph_properties[k] = torch.Tensor([v[index]])

        if self.num_binary_graph_properties > 0:
            positive_binary_properties = self.graph2binary_properties[index]
            binary_graph_properties = torch.zeros(
                self.num_binary_graph_properties)
            binary_graph_properties[np.array(positive_binary_properties)] = 1
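            # e.g. num_binary_graph_properties == 5 with positive properties
            # [0, 3] gives binary_graph_properties == tensor([1., 0., 0., 1., 0.])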
        else:
            binary_graph_properties = []

        ret_list = [
            init_node_properties, orig_node_properties,
            node_property_target_types, node_mask, init_edge_properties,
            orig_edge_properties, edge_property_target_types, edge_mask,
            graph_properties, binary_graph_properties
        ]

        return ret_list
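The assert in __getitem__ relies on check_symmetric to validate the bond matrix before edges are read from its upper triangle. A minimal sketch of such a check, assuming a square NumPy array:

import numpy as np

def check_symmetric(a):
    """True iff the bond/adjacency matrix equals its transpose."""
    return np.array_equal(a, a.T)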
Example #5
    def generate_with_evaluation(self,
                                 num_samples_to_generate,
                                 smiles_dataset_path,
                                 output_dir,
                                 num_samples_to_evaluate,
                                 evaluate_connected_only=False):

        load_path, load_iters = get_load_path(self.num_sampling_iters,
                                              self.num_argmax_iters,
                                              self.cp_save_dir)
        all_init_node_properties, all_init_edge_properties, all_node_masks, all_edge_masks = \
            self.get_all_init_variables(load_path, num_samples_to_generate)

        if self.save_init is True and self.random_init is True and load_iters == 0:
            # Save smiles representations of initialised molecules
            smiles_list = []
            num_nodes = all_node_masks.sum(-1)
            for i in range(len(all_init_node_properties['node_type'])):
                mol = graph_to_mol({k: v[i][:int(num_nodes[i])].astype(int) \
                                    for k, v in all_init_node_properties.items()},
                                   {k: v[i][:int(num_nodes[i]), :int(num_nodes[i])].astype(int) \
                                    for k, v in all_init_edge_properties.items()},
                                   min_charge=self.train_data.min_charge, symbol_list=self.symbol_list)
                smiles_list.append(Chem.MolToSmiles(mol))
            save_smiles_list(smiles_list,
                             os.path.join(output_dir, 'smiles_0_0.txt'))
            del smiles_list, mol, num_nodes

        if self.set_seed_at_load_iter is True:
            set_seed_if(load_iters)

        retrieve_train_graphs = self.retrieve_train_graphs
        for j in tqdm(range(load_iters, self.num_iters)):
            if j > 0:
                retrieve_train_graphs = False
                if self.generation_algorithm == 'Gibbs':
                    self.train_data.do_not_corrupt = True
            loader = self.get_dataloader(all_init_node_properties,
                                         all_node_masks,
                                         all_init_edge_properties,
                                         num_samples_to_generate,
                                         retrieve_train_graphs)

            use_argmax = (j >= self.num_sampling_iters)
            all_init_node_properties, all_init_edge_properties, all_node_masks,\
                smiles_list = self.carry_out_iteration(loader, use_argmax)

            sampling_iters_completed = min(j + 1, self.num_sampling_iters)
            argmax_iters_completed = max(0, j + 1 - self.num_sampling_iters)
            if (j + 1 - load_iters) % self.checkpointing_period == 0:
                self.save_checkpoints(all_init_node_properties,
                                      all_init_edge_properties,
                                      sampling_iters_completed,
                                      argmax_iters_completed)

            if (j + 1 - load_iters) % self.save_period == 0 or (
                    self.save_finegrained is True and (j + 1) <= 10):
                smiles_output_path = os.path.join(
                    output_dir,
                    'smiles_{}_{}.txt'.format(sampling_iters_completed,
                                              argmax_iters_completed))
                save_smiles_list(smiles_list, smiles_output_path)

            if (j + 1 - load_iters) % self.evaluation_period == 0 or \
                (self.evaluate_finegrained is True and (j + 1) <= 10):
                json_output_path = os.path.join(
                    output_dir, 'distribution_results_{}_{}.json'.format(
                        sampling_iters_completed, argmax_iters_completed))
                evaluate_uncond_generation(
                    MockGenerator(smiles_list, num_samples_to_generate),
                    smiles_dataset_path, json_output_path,
                    num_samples_to_evaluate, evaluate_connected_only)
                if self.cond_property_values:
                    cond_json_output_path = os.path.join(
                        output_dir, 'cond_results_{}_{}.json'.format(
                            sampling_iters_completed, argmax_iters_completed))
                    self.evaluate_cond_generation(
                        smiles_list[:num_samples_to_evaluate],
                        cond_json_output_path)
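evaluate_uncond_generation is handed a MockGenerator wrapping SMILES that were already produced, so precomputed samples can stand in for a live generator. A plausible minimal version, assuming the evaluator only asks it for n samples (the real interface is defined elsewhere in the repository):

class MockGenerator:
    def __init__(self, smiles_list, num_samples):
        self.smiles_list = smiles_list
        self.num_samples = num_samples

    def generate(self, number_samples):
        # Hand back precomputed SMILES instead of running a model.
        return self.smiles_list[:number_samples]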
Example #6
def main(params):
    model_cls = MODELS_DICT[params.model_name]
    model = model_cls(params)
    params, model, opt, scheduler, train_data, train_loader, val_dataset, val_loader, perturbation_loader, generator,\
    index_method, exp_path, training_smiles, pp, logger, writer, best_loss,\
    total_iter, grad_accum_iters = setup_data_and_model(params, model)

    for epoch in range(1, params.num_epochs + 1):
        print('Starting epoch {}'.format(epoch), flush=True)
        for train_batch in train_loader:
            if total_iter % 100 == 0: print(total_iter, flush=True)
            if total_iter == params.max_steps:
                logger.info('Done training')
                break
            model.train()

            # Training step
            init_nodes, init_edges, original_node_inds, original_adj_mats, node_masks, edge_masks,\
            node_target_types, edge_target_types, init_hydrogens, original_hydrogens, init_charge,\
            orig_charge, init_is_in_ring, orig_is_in_ring, init_is_aromatic, orig_is_aromatic, init_chirality,\
            orig_chirality, hydrogen_target_types, charge_target_types, is_in_ring_target_types,\
            is_aromatic_target_types, chirality_target_types = train_batch
            # init is what goes into model
            # target_inds and target_coords are 1 at locations to be predicted, 0 elsewhere
            # target_inds and target_coords are now calculated here rather than in dataloader
            # original are uncorrupted data (used for comparison with prediction in loss calculation)
            # masks are 1 in places corresponding to nodes that exist, 0 in other places (which are empty/padded)
            # target_types are 1 in places to mask, 2 in places to replace, 3 in places to reconstruct, 0 in places not to predict

            node_target_inds_vector = getattr(node_target_types != 0,
                                              index_method)()
            edge_target_coords_matrix = getattr(edge_target_types != 0,
                                                index_method)()
            hydrogen_target_inds_vector = getattr(hydrogen_target_types != 0,
                                                  index_method)()
            charge_target_inds_vector = getattr(charge_target_types != 0,
                                                index_method)()
            is_in_ring_target_inds_vector = getattr(
                is_in_ring_target_types != 0, index_method)()
            is_aromatic_target_inds_vector = getattr(
                is_aromatic_target_types != 0, index_method)()
            chirality_target_inds_vector = getattr(chirality_target_types != 0,
                                                   index_method)()

            if params.local_cpu is False:
                init_nodes = init_nodes.cuda()
                init_edges = init_edges.cuda()
                original_node_inds = original_node_inds.cuda()
                original_adj_mats = original_adj_mats.cuda()
                node_masks = node_masks.cuda()
                edge_masks = edge_masks.cuda()
                node_target_types = node_target_types.cuda()
                edge_target_types = edge_target_types.cuda()
                init_hydrogens = init_hydrogens.cuda()
                original_hydrogens = original_hydrogens.cuda()
                init_charge = init_charge.cuda()
                orig_charge = orig_charge.cuda()
                init_is_in_ring = init_is_in_ring.cuda()
                orig_is_in_ring = orig_is_in_ring.cuda()
                init_is_aromatic = init_is_aromatic.cuda()
                orig_is_aromatic = orig_is_aromatic.cuda()
                init_chirality = init_chirality.cuda()
                orig_chirality = orig_chirality.cuda()
                hydrogen_target_types = hydrogen_target_types.cuda()
                charge_target_types = charge_target_types.cuda()
                is_in_ring_target_types = is_in_ring_target_types.cuda()
                is_aromatic_target_types = is_aromatic_target_types.cuda()
                chirality_target_types = chirality_target_types.cuda()
                if params.property_type is not None:
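                    # NB: `properties` is not unpacked from train_batch above;
                    # this branch assumes the loader also yields property targets
                    # when params.property_type is not None.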
                    properties = properties.cuda()

            if grad_accum_iters % params.grad_accum_iters == 0:
                opt.zero_grad()

            out = model(init_nodes, init_edges, node_masks, edge_masks,
                        init_hydrogens, init_charge, init_is_in_ring,
                        init_is_aromatic, init_chirality)
            node_scores, edge_scores, hydrogen_scores, charge_scores, is_in_ring_scores, is_aromatic_scores,\
                chirality_scores = out
            node_num_classes = node_scores.shape[-1]
            edge_num_classes = edge_scores.shape[-1]
            hydrogen_num_classes = hydrogen_scores.shape[-1]
            charge_num_classes = charge_scores.shape[-1]
            is_in_ring_num_classes = is_in_ring_scores.shape[-1]
            is_aromatic_num_classes = is_aromatic_scores.shape[-1]
            chirality_num_classes = chirality_scores.shape[-1]

            if model.property_type is not None:
                property_scores = out[-1]

            # slice out target data structures
            node_scores, target_nodes, node_target_types = get_only_target_info(
                node_scores, original_node_inds, node_target_inds_vector,
                node_num_classes, node_target_types)
            edge_scores, target_adj_mats, edge_target_types = get_only_target_info(
                edge_scores, original_adj_mats, edge_target_coords_matrix,
                edge_num_classes, edge_target_types)
            if params.embed_hs is True:
                hydrogen_scores, target_hydrogens, hydrogen_target_types = get_only_target_info(
                    hydrogen_scores, original_hydrogens,
                    hydrogen_target_inds_vector, hydrogen_num_classes,
                    hydrogen_target_types)
                charge_scores, target_charge, charge_target_types = get_only_target_info(
                    charge_scores, orig_charge, charge_target_inds_vector,
                    charge_num_classes, charge_target_types)
                is_in_ring_scores, target_is_in_ring, is_in_ring_target_types = get_only_target_info(
                    is_in_ring_scores, orig_is_in_ring,
                    is_in_ring_target_inds_vector, is_in_ring_num_classes,
                    is_in_ring_target_types)
                is_aromatic_scores, target_is_aromatic, is_aromatic_target_types = get_only_target_info(
                    is_aromatic_scores, orig_is_aromatic,
                    is_aromatic_target_inds_vector, is_aromatic_num_classes,
                    is_aromatic_target_types)
                chirality_scores, target_chirality, chirality_target_types = get_only_target_info(
                    chirality_scores, orig_chirality,
                    chirality_target_inds_vector, chirality_num_classes,
                    chirality_target_types)

            num_components = len(target_nodes) + len(target_adj_mats)
            if params.embed_hs is True: num_components += len(target_hydrogens)
            # calculate score
            losses = []
            results_dict = {}
            metrics_to_print = []
            if params.target_data_structs in ['nodes', 'both', 'random'
                                              ] and len(target_nodes) > 0:
                node_loss = get_loss(target_nodes, node_scores,
                                     params.equalise,
                                     params.loss_normalisation_type,
                                     num_components, params.local_cpu)
                losses.append(node_loss)
                node_preds = torch.argmax(F.softmax(node_scores, -1), dim=-1)
                nodes_acc, mask_node_acc, replace_node_acc, recon_node_acc = get_ds_stats(
                    node_scores, target_nodes, node_target_types)
                carbon_nodes_correct, noncarbon_nodes_correct, carbon_nodes_acc, noncarbon_nodes_acc = \
                    get_majority_and_minority_stats(node_preds, target_nodes, 1)
                results_dict.update({
                    'nodes_acc': nodes_acc,
                    'carbon_nodes_acc': carbon_nodes_acc,
                    'noncarbon_nodes_acc': noncarbon_nodes_acc,
                    'mask_node_acc': mask_node_acc,
                    'replace_node_acc': replace_node_acc,
                    'recon_node_acc': recon_node_acc,
                    'node_loss': node_loss
                })
                metrics_to_print.extend([
                    'node_loss', 'nodes_acc', 'carbon_nodes_acc',
                    'noncarbon_nodes_acc'
                ])

                def node_property_computations(name,
                                               scores,
                                               targets,
                                               target_types,
                                               binary=False):
                    loss = get_loss(targets,
                                    scores,
                                    params.equalise,
                                    params.loss_normalisation_type,
                                    num_components,
                                    params.local_cpu,
                                    binary=binary)
                    losses.append(loss)
                    acc, mask_acc, replace_acc, recon_acc = get_ds_stats(
                        scores, targets, target_types)
                    results_dict.update({
                        '{}_acc'.format(name): acc,
                        'mask_{}_acc'.format(name): mask_acc,
                        'replace_{}_acc'.format(name): replace_acc,
                        'recon_{}_acc'.format(name): recon_acc,
                        '{}_loss'.format(name): loss
                    })
                    metrics_to_print.extend(['{}_loss'.format(name)])

                if params.embed_hs is True:
                    node_property_computations('hydrogen', hydrogen_scores,
                                               target_hydrogens,
                                               hydrogen_target_types)
                    node_property_computations('charge', charge_scores,
                                               target_charge,
                                               charge_target_types)
                    node_property_computations('is_in_ring', is_in_ring_scores,
                                               target_is_in_ring,
                                               is_in_ring_target_types)
                    node_property_computations('is_aromatic',
                                               is_aromatic_scores,
                                               target_is_aromatic,
                                               is_aromatic_target_types)
                    node_property_computations('chirality', chirality_scores,
                                               target_chirality,
                                               chirality_target_types)

            if params.target_data_structs in ['edges', 'both', 'random'
                                              ] and len(target_adj_mats) > 0:
                edge_loss = get_loss(target_adj_mats, edge_scores,
                                     params.equalise,
                                     params.loss_normalisation_type,
                                     num_components, params.local_cpu)
                losses.append(edge_loss)
                edge_preds = torch.argmax(F.softmax(edge_scores, -1), dim=-1)
                edges_acc, mask_edge_acc, replace_edge_acc, recon_edge_acc = get_ds_stats(
                    edge_scores, target_adj_mats, edge_target_types)
                no_edge_correct, edge_present_correct, no_edge_acc, edge_present_acc = \
                    get_majority_and_minority_stats(edge_preds, target_adj_mats, 0)
                results_dict.update({
                    'edges_acc': edges_acc,
                    'edge_present_acc': edge_present_acc,
                    'no_edge_acc': no_edge_acc,
                    'mask_edge_acc': mask_edge_acc,
                    'replace_edge_acc': replace_edge_acc,
                    'recon_edge_acc': recon_edge_acc,
                    'edge_loss': edge_loss
                })
                metrics_to_print.extend([
                    'edge_loss', 'edges_acc', 'edge_present_acc', 'no_edge_acc'
                ])

            if params.property_type is not None:
                property_loss = model.property_loss(
                    property_scores, properties) / params.batch_size
                losses.append(property_loss)
                results_dict.update({'property_loss': property_loss})
                metrics_to_print.extend(['property_loss'])

            loss = sum(losses)
            if params.no_update is False:
                (loss / params.grad_accum_iters).backward()
                grad_accum_iters += 1
                if grad_accum_iters % params.grad_accum_iters == 0:
                    # clip grad norm
                    if params.clip_grad_norm > -1:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       params.clip_grad_norm)
                    opt.step()
                    grad_accum_iters = 0
                if (total_iter + 1) <= params.warm_up_iters or (
                        total_iter + 1) % params.lr_decay_interval == 0:
                    scheduler.step(total_iter + 1)

            if params.check_pred_validity is True:
                if params.target_data_structs in ['nodes', 'both', 'random'
                                                  ] and len(target_nodes) > 0:
                    init_nodes[node_target_inds_vector] = node_preds
                if params.target_data_structs in [
                        'edges', 'both', 'random'
                ] and len(target_adj_mats) > 0:
                    init_edges[edge_target_coords_matrix] = edge_preds
                is_valid_list, is_connected_list = check_validity(
                    init_nodes, init_edges)
                percent_valid = np.mean(is_valid_list) * 100
                valid_percent_connected = np.mean(is_connected_list) * 100
                results_dict.update({
                    'percent_valid': percent_valid,
                    'percent_connected': valid_percent_connected
                })

            if params.suppress_train_log is False:
                log_string = ''
                for name, value in results_dict.items():
                    if name in metrics_to_print:
                        log_string += ', {} = {:.2f}'.format(name, value)
                log_string = 'total_iter = {0:d}, loss = {1:.2f}'.format(
                    total_iter,
                    loss.cpu().item()) + log_string
                logger.info(log_string)

            if params.target_frac_inc_after is not None and total_iter > 0 and total_iter % params.target_frac_inc_after == 0:
                train_data.node_target_frac = min(
                    train_data.node_target_frac +
                    params.target_frac_inc_amount, params.max_target_frac)
                train_data.edge_target_frac = min(
                    train_data.edge_target_frac +
                    params.target_frac_inc_amount, params.max_target_frac)
            results_dict.update(
                {'node_target_frac': train_data.node_target_frac})
            results_dict.update(
                {'edge_target_frac': train_data.edge_target_frac})

            if params.tensorboard and total_iter % int(
                    params.log_train_steps) == 0:
                results_dict.update({
                    'loss': loss,
                    'lr': opt.param_groups[0]['lr']
                })
                write_tensorboard(writer, 'train', results_dict, total_iter)

            dist_names = [
                ['mask_edge_correct_dist', 'mask_edge_incorrect_pred_dist',
                 'mask_edge_incorrect_true_dist'],
                ['replace_edge_correct_dist', 'replace_edge_incorrect_pred_dist',
                 'replace_edge_incorrect_true_dist'],
                ['recon_edge_correct_dist', 'recon_edge_incorrect_pred_dist',
                 'recon_edge_incorrect_true_dist'],
            ]
            distributions = np.zeros((3, 3, params.num_edge_types))
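            # distributions[i, j] accumulates, for target type i (mask, replace,
            # reconstruct), histogram j over the edge classes: correct
            # predictions, incorrect predictions, and the true labels at the
            # incorrectly predicted positions (matching dist_names above).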
            if total_iter > 0 and total_iter % params.val_after == 0:
                logger.info('Validating')
                val_loss, property_loss, node_loss, edge_loss, hydrogen_loss, num_data_points = 0, 0, 0, 0, 0, 0
                charge_loss, is_in_ring_loss, is_aromatic_loss, chirality_loss = 0, 0, 0, 0
                model.eval()
                set_seed_if(params.seed)
                nodes_correct, edges_correct, total_nodes, total_edges = 0, 0, 0, 0
                carbon_nodes_correct, noncarbon_nodes_correct, total_carbon_nodes, total_noncarbon_nodes = 0, 0, 0, 0
                mask_nodes_correct, replace_nodes_correct, recon_nodes_correct = 0, 0, 0
                total_mask_nodes, total_replace_nodes, total_recon_nodes = 0, 0, 0
                no_edge_correct, edge_present_correct, total_no_edges, total_edges_present = 0, 0, 0, 0
                mask_edges_correct, replace_edges_correct, recon_edges_correct = 0, 0, 0
                total_mask_edges, total_replace_edges, total_recon_edges = 0, 0, 0
                hydrogens_correct, mask_hydrogens_correct, replace_hydrogens_correct, recon_hydrogens_correct, \
                    total_mask_hydrogens, total_replace_hydrogens, total_recon_hydrogens,\
                    total_hydrogens = 0, 0, 0, 0, 0, 0, 0, 0
                is_valid_list, is_connected_list = [], []
                for init_nodes, init_edges, original_node_inds, original_adj_mats, node_masks, edge_masks,\
                    node_target_types, edge_target_types, init_hydrogens, original_hydrogens,\
                    init_charge, orig_charge, init_is_in_ring, orig_is_in_ring, init_is_aromatic, orig_is_aromatic,\
                    init_chirality, orig_chirality, hydrogen_target_types, charge_target_types, is_in_ring_target_types,\
                    is_aromatic_target_types, chirality_target_types in val_loader:

                    node_target_inds_vector = getattr(node_target_types != 0,
                                                      index_method)()
                    edge_target_coords_matrix = getattr(
                        edge_target_types != 0, index_method)()
                    hydrogen_target_inds_vector = getattr(
                        hydrogen_target_types != 0, index_method)()
                    charge_target_inds_vector = getattr(
                        charge_target_types != 0, index_method)()
                    is_in_ring_target_inds_vector = getattr(
                        is_in_ring_target_types != 0, index_method)()
                    is_aromatic_target_inds_vector = getattr(
                        is_aromatic_target_types != 0, index_method)()
                    chirality_target_inds_vector = getattr(
                        chirality_target_types != 0, index_method)()

                    if params.local_cpu is False:
                        init_nodes = init_nodes.cuda()
                        init_edges = init_edges.cuda()
                        original_node_inds = original_node_inds.cuda()
                        original_adj_mats = original_adj_mats.cuda()
                        node_masks = node_masks.cuda()
                        edge_masks = edge_masks.cuda()
                        node_target_types = node_target_types.cuda()
                        edge_target_types = edge_target_types.cuda()
                        init_hydrogens = init_hydrogens.cuda()
                        original_hydrogens = original_hydrogens.cuda()
                        init_charge = init_charge.cuda()
                        orig_charge = orig_charge.cuda()
                        init_is_in_ring = init_is_in_ring.cuda()
                        orig_is_in_ring = orig_is_in_ring.cuda()
                        init_is_aromatic = init_is_aromatic.cuda()
                        orig_is_aromatic = orig_is_aromatic.cuda()
                        init_chirality = init_chirality.cuda()
                        orig_chirality = orig_chirality.cuda()
                        hydrogen_target_types = hydrogen_target_types.cuda()
                        charge_target_types = charge_target_types.cuda()
                        is_in_ring_target_types = is_in_ring_target_types.cuda()
                        is_aromatic_target_types = is_aromatic_target_types.cuda()
                        chirality_target_types = chirality_target_types.cuda()
                        if params.embed_hs is True:
                            original_hydrogens = original_hydrogens.cuda()
                        if params.property_type is not None:
                            properties = properties.cuda()

                    batch_size = init_nodes.shape[0]

                    with torch.no_grad():
                        out = model(init_nodes, init_edges, node_masks,
                                    edge_masks, init_hydrogens, init_charge,
                                    init_is_in_ring, init_is_aromatic,
                                    init_chirality)
                        node_scores, edge_scores, hydrogen_scores, charge_scores, is_in_ring_scores,\
                        is_aromatic_scores, chirality_scores = out
                        if model.property_type is not None:
                            property_scores = out[-1]

                    node_num_classes = node_scores.shape[-1]
                    edge_num_classes = edge_scores.shape[-1]
                    hydrogen_num_classes = hydrogen_scores.shape[-1]
                    charge_num_classes = charge_scores.shape[-1]
                    is_in_ring_num_classes = is_in_ring_scores.shape[-1]
                    is_aromatic_num_classes = is_aromatic_scores.shape[-1]
                    chirality_num_classes = chirality_scores.shape[-1]

                    node_scores, target_nodes, node_target_types = get_only_target_info(
                        node_scores, original_node_inds,
                        node_target_inds_vector, node_num_classes,
                        node_target_types)
                    edge_scores, target_adj_mats, edge_target_types = get_only_target_info(
                        edge_scores, original_adj_mats,
                        edge_target_coords_matrix, edge_num_classes,
                        edge_target_types)

                    if params.embed_hs is True:
                        hydrogen_scores, target_hydrogens, hydrogen_target_types = get_only_target_info(
                            hydrogen_scores, original_hydrogens,
                            hydrogen_target_inds_vector, hydrogen_num_classes,
                            hydrogen_target_types)
                        charge_scores, target_charge, charge_target_types = get_only_target_info(
                            charge_scores, orig_charge,
                            charge_target_inds_vector, charge_num_classes,
                            charge_target_types)
                        is_in_ring_scores, target_is_in_ring, is_in_ring_target_types = get_only_target_info(
                            is_in_ring_scores, orig_is_in_ring,
                            is_in_ring_target_inds_vector,
                            is_in_ring_num_classes, is_in_ring_target_types)
                        is_aromatic_scores, target_is_aromatic, is_aromatic_target_types = get_only_target_info(
                            is_aromatic_scores, orig_is_aromatic,
                            is_aromatic_target_inds_vector,
                            is_aromatic_num_classes, is_aromatic_target_types)
                        chirality_scores, target_chirality, chirality_target_types = get_only_target_info(
                            chirality_scores, orig_chirality,
                            chirality_target_inds_vector,
                            chirality_num_classes, chirality_target_types)
                    num_data_points += batch_size

                    losses = []
                    if params.target_data_structs in [
                            'nodes', 'both', 'random'
                    ] and len(target_nodes) > 0:
                        weight = get_loss_weights(target_nodes, node_scores,
                                                  params.equalise,
                                                  params.local_cpu)
                        iter_node_loss = F.cross_entropy(node_scores,
                                                         target_nodes,
                                                         weight=weight,
                                                         reduction='sum')
                        losses.append(iter_node_loss)
                        node_loss += iter_node_loss
                        nodes_correct, mask_nodes_correct, replace_nodes_correct, recon_nodes_correct, total_mask_nodes, \
                        total_replace_nodes, total_recon_nodes, total_nodes = update_ds_stats(node_scores, target_nodes,
                            node_target_types, nodes_correct, mask_nodes_correct, replace_nodes_correct,
                            recon_nodes_correct, total_mask_nodes, total_replace_nodes, total_recon_nodes, total_nodes)
                        total_noncarbon_nodes += (target_nodes != 1).sum()
                        total_carbon_nodes += (target_nodes == 1).sum()
                        node_preds = torch.argmax(F.softmax(node_scores, -1),
                                                  dim=-1)
                        noncarbon_nodes_correct += torch.mul(
                            (node_preds == target_nodes),
                            (target_nodes != 1)).sum()
                        carbon_nodes_correct += torch.mul(
                            (node_preds == target_nodes),
                            (target_nodes == 1)).sum()
                        if params.check_pred_validity is True:
                            init_nodes[node_target_inds_vector] = node_preds

                        def val_node_property_loss_computation(
                                targets, scores, loss, binary=False):
                            weight = get_loss_weights(targets, scores,
                                                      params.equalise,
                                                      params.local_cpu)
                            if binary is True:
                                iter_loss = F.binary_cross_entropy_with_logits(
                                    scores,
                                    targets,
                                    weight=weight,
                                    reduction='sum')
                            else:
                                iter_loss = F.cross_entropy(scores,
                                                            targets,
                                                            weight=weight,
                                                            reduction='sum')
                            losses.append(iter_loss)
                            loss += iter_loss
                            return loss

                        if params.embed_hs is True:
                            hydrogen_loss = val_node_property_loss_computation(
                                target_hydrogens, hydrogen_scores,
                                hydrogen_loss)
                            hydrogens_correct, mask_hydrogens_correct, replace_hydrogens_correct, recon_hydrogens_correct,\
                            total_mask_hydrogens, total_replace_hydrogens, total_recon_hydrogens, total_hydrogens =\
                                update_ds_stats(hydrogen_scores, target_hydrogens, hydrogen_target_types, hydrogens_correct,
                                mask_hydrogens_correct, replace_hydrogens_correct, recon_hydrogens_correct,
                                total_mask_hydrogens, total_replace_hydrogens, total_recon_hydrogens, total_hydrogens)
                            charge_loss = val_node_property_loss_computation(
                                target_charge, charge_scores, charge_loss)
                            is_in_ring_loss = val_node_property_loss_computation(
                                target_is_in_ring, is_in_ring_scores,
                                is_in_ring_loss)
                            is_aromatic_loss = val_node_property_loss_computation(
                                target_is_aromatic, is_aromatic_scores,
                                is_aromatic_loss)
                            chirality_loss = val_node_property_loss_computation(
                                target_chirality, chirality_scores,
                                chirality_loss)

                    if params.target_data_structs in [
                            'edges', 'both', 'random'
                    ] and len(target_adj_mats) > 0:
                        weight = get_loss_weights(target_adj_mats, edge_scores,
                                                  params.equalise,
                                                  params.local_cpu)
                        iter_edge_loss = F.cross_entropy(edge_scores,
                                                         target_adj_mats,
                                                         weight=weight,
                                                         reduction='sum')
                        losses.append(iter_edge_loss)
                        edge_loss += iter_edge_loss
                        edges_correct, mask_edges_correct, replace_edges_correct, recon_edges_correct, total_mask_edges, \
                        total_replace_edges, total_recon_edges, total_edges = update_ds_stats(edge_scores, target_adj_mats,
                            edge_target_types, edges_correct, mask_edges_correct, replace_edges_correct,
                            recon_edges_correct, total_mask_edges, total_replace_edges, total_recon_edges, total_edges)
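                        # Class 0 means 'no edge'; bond-present and no-bond accuracies are tracked separately.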
                        total_edges_present += (target_adj_mats != 0).sum()
                        total_no_edges += (target_adj_mats == 0).sum()
                        edge_preds = torch.argmax(F.softmax(edge_scores, -1),
                                                  dim=-1)
                        edge_present_correct += torch.mul(
                            (edge_preds == target_adj_mats),
                            (target_adj_mats != 0)).sum()
                        no_edge_correct += torch.mul(
                            (edge_preds == target_adj_mats),
                            (target_adj_mats == 0)).sum()

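                        # Distribution of predicted edge classes for each corruption type
                        # (1=mask, 2=replace, 3=reconstruct).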
                        distributions += get_all_result_distributions(
                            edge_preds, target_adj_mats, edge_target_types,
                            [1, 2, 3], params.num_edge_types)
                        if params.check_pred_validity is True:
                            init_edges[edge_target_coords_matrix] = edge_preds

                    if params.property_type is not None:
                        iter_property_loss = model.property_loss(
                            property_scores, properties)
                        losses.append(iter_property_loss)
                        property_loss += iter_property_loss

                    loss = sum(losses).cpu().item()
                    val_loss += loss

                    if params.check_pred_validity is True:
                        is_valid_list, is_connected_list = check_validity(
                            init_nodes, init_edges, is_valid_list,
                            is_connected_list)

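                # Validation loop finished: turn the accumulated sums into average losses and accuracies.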
                if params.property_type is not None:
                    avg_property_loss = float(property_loss) / float(
                        num_data_points)
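                # 'by_total' divides every loss by one global component count;
                # 'by_component' averages each loss over its own target count and sums the averages.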
                if params.loss_normalisation_type == 'by_total':
                    num_components = float(total_nodes) + float(total_edges)
                    if params.embed_hs is True:
                        # Each node additionally predicts five auxiliary properties
                        num_components += float(total_nodes) * 5
                    avg_val_loss = float(val_loss) / float(num_components)
                    if params.target_data_structs in [
                            'nodes', 'both', 'random'
                    ] and total_nodes > 0:
                        avg_node_loss = float(node_loss) / float(
                            num_components)
                        if params.embed_hs is True:
                            avg_hydrogen_loss = float(
                                hydrogen_loss) / num_components
                            avg_charge_loss = float(
                                charge_loss) / num_components
                            avg_is_in_ring_loss = float(
                                is_in_ring_loss) / num_components
                            avg_is_aromatic_loss = float(
                                is_aromatic_loss) / num_components
                            avg_chirality_loss = float(
                                chirality_loss) / num_components
                    if params.target_data_structs in [
                            'edges', 'both', 'random'
                    ] and total_edges > 0:
                        avg_edge_loss = float(edge_loss) / float(
                            num_components)
                elif params.loss_normalisation_type == 'by_component':
                    avg_val_loss = 0
                    if params.target_data_structs in [
                            'nodes', 'both', 'random'
                    ] and total_nodes > 0:
                        avg_node_loss = float(node_loss) / float(total_nodes)
                        avg_val_loss += avg_node_loss
                        if params.embed_hs is True:
                            avg_hydrogen_loss = float(hydrogen_loss) / float(
                                total_hydrogens)
                            avg_charge_loss = float(charge_loss) / float(
                                total_nodes)
                            avg_is_in_ring_loss = float(
                                is_in_ring_loss) / float(total_nodes)
                            avg_is_aromatic_loss = float(
                                is_aromatic_loss) / float(total_nodes)
                            avg_chirality_loss = float(chirality_loss) / float(
                                total_nodes)
                            avg_val_loss += avg_hydrogen_loss + avg_charge_loss + avg_is_in_ring_loss + \
                                            avg_is_aromatic_loss + avg_chirality_loss
                    if params.target_data_structs in [
                            'edges', 'both', 'random'
                    ] and total_edges > 0:
                        avg_edge_loss = float(edge_loss) / float(total_edges)
                        avg_val_loss += avg_edge_loss
                    if params.property_type is not None:
                        avg_val_loss += avg_property_loss
                logger.info(
                    'Average validation loss: {0:.2f}'.format(avg_val_loss))
                val_iter = total_iter // params.val_after

                if params.check_pred_validity is True:
                    percent_valid = np.mean(is_valid_list) * 100
                    valid_percent_connected = np.mean(is_connected_list) * 100
                    logger.info('Percent valid: {}%'.format(percent_valid))
                    logger.info(
                        'Percent of valid molecules connected: {}%'.format(
                            valid_percent_connected))

                results_dict = {'loss': avg_val_loss}

                if params.target_data_structs in ['nodes', 'both', 'random'
                                                  ] and total_nodes > 0:
                    nodes_acc, noncarbon_nodes_acc, carbon_nodes_acc, mask_node_acc, replace_node_acc, recon_node_acc =\
                        accuracies_from_totals({nodes_correct: total_nodes, noncarbon_nodes_correct: total_noncarbon_nodes,
                        carbon_nodes_correct: total_carbon_nodes, mask_nodes_correct: total_mask_nodes,
                        replace_nodes_correct: total_replace_nodes, recon_nodes_correct: total_recon_nodes})
                    results_dict.update({
                        'nodes_acc': nodes_acc,
                        'carbon_nodes_acc': carbon_nodes_acc,
                        'noncarbon_nodes_acc': noncarbon_nodes_acc,
                        'mask_node_acc': mask_node_acc,
                        'replace_node_acc': replace_node_acc,
                        'recon_node_acc': recon_node_acc,
                        'node_loss': avg_node_loss
                    })
                    logger.info('Node loss: {0:.2f}'.format(avg_node_loss))
                    logger.info('Node accuracy: {0:.2f}%'.format(nodes_acc))
                    logger.info('Non-Carbon Node accuracy: {0:.2f}%'.format(
                        noncarbon_nodes_acc))
                    logger.info('Carbon Node accuracy: {0:.2f}%'.format(
                        carbon_nodes_acc))
                    logger.info(
                        'mask_node_acc {:.2f}, replace_node_acc {:.2f}, recon_node_acc {:.2f}'
                        .format(mask_node_acc, replace_node_acc,
                                recon_node_acc))
                    if params.embed_hs is True:
                        hydrogen_acc, mask_hydrogen_acc, replace_hydrogen_acc, recon_hydrogen_acc = accuracies_from_totals(
                            {
                                hydrogens_correct: total_hydrogens,
                                mask_hydrogens_correct: total_mask_hydrogens,
                                replace_hydrogens_correct:
                                total_replace_hydrogens,
                                recon_hydrogens_correct: total_recon_hydrogens
                            })
                        results_dict.update({
                            'hydrogen_acc': hydrogen_acc,
                            'mask_hydrogen_acc': mask_hydrogen_acc,
                            'replace_hydrogen_acc': replace_hydrogen_acc,
                            'recon_hydrogen_acc': recon_hydrogen_acc,
                            'hydrogen_loss': avg_hydrogen_loss
                        })
                        logger.info(
                            'Hydrogen loss: {0:.2f}'.format(avg_hydrogen_loss))
                        logger.info(
                            'Hydrogen accuracy: {0:.2f}%'.format(hydrogen_acc))
                        logger.info(
                            'mask_hydrogen_acc {:.2f}, replace_hydrogen_acc {:.2f}, recon_hydrogen_acc {:.2f}'
                            .format(mask_hydrogen_acc, replace_hydrogen_acc,
                                    recon_hydrogen_acc))

                        results_dict.update({
                            'charge_loss':
                            avg_charge_loss,
                            'is_in_ring_loss':
                            avg_is_in_ring_loss,
                            'is_aromatic_loss':
                            avg_is_aromatic_loss,
                            'chirality_loss':
                            avg_chirality_loss
                        })
                        logger.info(
                            'Charge loss: {0:.2f}'.format(avg_charge_loss))
                        logger.info('Is in ring loss: {0:.2f}'.format(
                            avg_is_in_ring_loss))
                        logger.info('Is aromatic loss: {0:.2f}'.format(
                            avg_is_aromatic_loss))
                        logger.info('Chirality loss: {0:.2f}'.format(
                            avg_chirality_loss))

                if params.target_data_structs in ['edges', 'both', 'random'
                                                  ] and total_edges > 0:
                    edges_acc, no_edge_acc, edge_present_acc, mask_edge_acc, replace_edge_acc, recon_edge_acc =\
                        accuracies_from_totals({edges_correct: total_edges, no_edge_correct: total_no_edges,
                        edge_present_correct: total_edges_present, mask_edges_correct: total_mask_edges,
                        replace_edges_correct: total_replace_edges, recon_edges_correct: total_recon_edges})
                    results_dict.update({
                        'edges_acc': edges_acc,
                        'edge_present_acc': edge_present_acc,
                        'no_edge_acc': no_edge_acc,
                        'mask_edge_acc': mask_edge_acc,
                        'replace_edge_acc': replace_edge_acc,
                        'recon_edge_acc': recon_edge_acc,
                        'edge_loss': avg_edge_loss
                    })
                    logger.info('Edge loss: {0:.2f}'.format(avg_edge_loss))
                    logger.info('Edge accuracy: {0:.2f}%'.format(edges_acc))
                    logger.info('Edge present accuracy: {0:.2f}%'.format(
                        edge_present_acc))
                    logger.info(
                        'No edge accuracy: {0:.2f}%'.format(no_edge_acc))
                    logger.info(
                        " mask_edge_acc {:.2f}, replace_edge_acc {:.2f}, recon_edge_acc {:.2f}"
                        .format(mask_edge_acc, replace_edge_acc,
                                recon_edge_acc))
                    logger.info('\n')
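                    # Log the predicted-class distribution for each (target type, edge class) pair.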
                    for i in range(distributions.shape[0]):
                        for j in range(distributions.shape[1]):
                            logger.info(dist_names[i][j] + ':\t' +
                                        str(distributions[i, j, :]))
                        logger.info('\n')

                if params.property_type is not None:
                    results_dict.update({'property_loss': avg_property_loss})
                    logger.info(
                        'Property loss: {0:.2f}'.format(avg_property_loss))

                if params.run_perturbation_analysis is True:
                    preds = run_perturbations(perturbation_loader, model,
                                              params.embed_hs, params.max_hs,
                                              params.perturbation_batch_size,
                                              params.local_cpu)
                    stats, percentages = aggregate_stats(preds)
                    logger.info('Percentages: {}'.format(
                        pp.pformat(percentages)))

                if params.tensorboard:
                    if params.run_perturbation_analysis is True:
                        writer.add_scalars('dev/perturbation_stability', {
                            str(key): val
                            for key, val in percentages.items()
                        }, val_iter)
                    write_tensorboard(writer, 'dev', results_dict, val_iter)

                if params.gen_num_samples > 0:
                    calculate_gen_benchmarks(generator, params.gen_num_samples,
                                             training_smiles, logger)

                logger.info("----------------------------------")

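                # DataParallel wraps the network in .module; unwrap it so the checkpoint loads on a single device.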
                model_state_dict = model.module.state_dict(
                ) if torch.cuda.device_count() > 1 else model.state_dict()

                best_loss = save_checkpoints(total_iter, avg_val_loss,
                                             best_loss, model_state_dict,
                                             opt.state_dict(), exp_path,
                                             logger, params.no_save,
                                             params.save_all)

                # Reset random seed
                set_seed_if(params.seed)
                logger.info('Validation complete')
            total_iter += 1
Example #7
    def __getitem__(self, index):
        set_seed_if(self.seed)
        # Create arrays with node empty index where node does not exist, node index everywhere else
        unpadded_node_inds = self.mol_nodeinds[index]

        max_nodes = self.max_nodes if self.pad is True else len(
            unpadded_node_inds)
        if self.binary_classification is True:
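            # Collapse node labels to a binary carbon (1) vs. non-carbon (0) task.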
            unpadded_node_inds[np.where(unpadded_node_inds != 1)] = 0

        node_inds = np.ones(max_nodes) * self.node_empty_index
        num_nodes = len(unpadded_node_inds)
        node_inds[:num_nodes] = unpadded_node_inds

        # Create matrices with edge empty index where edge would connect non-existent node, edge index everywhere else
        unpadded_adj_mat = self.adj_mats[index]
        if self.binary_classification is True:
            unpadded_adj_mat[np.where(unpadded_adj_mat != 0)] = 1

        # Make sure the bond (adjacency) matrix is symmetric
        assert check_symmetric(unpadded_adj_mat.astype(int))
        adj_mat = np.ones((max_nodes, max_nodes)) * self.edge_empty_index
        adj_mat[:num_nodes, :num_nodes] = unpadded_adj_mat

        # Create masks with 0 where node does not exist or edge would connect non-existent node, 1 everywhere else
        node_mask = torch.zeros(max_nodes)
        node_mask[:num_nodes] = 1
        edge_mask = torch.zeros((max_nodes, max_nodes))
        edge_mask[:num_nodes, :num_nodes] = 1
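        # Self-bonds are not allowed, so zero out the diagonal of the edge mask.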
        edge_mask[np.arange(num_nodes), np.arange(num_nodes)] = 0

        # *** Initialise nodes ***
        unpadded_num_hs = self.num_hs[index]
        unpadded_charge = self.charge[index] + abs(self.min_charge)
        unpadded_is_in_ring = self.is_in_ring[index]
        unpadded_is_aromatic = self.is_aromatic[index]
        unpadded_chirality = self.chirality[index]
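        # 'init' arrays are the model inputs (corrupted below); 'orig' arrays keep the
        # uncorrupted targets, both padded with a per-property default value.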
        init_hydrogens, orig_hydrogens = self.get_orig_and_init(
            unpadded_num_hs, max_nodes, num_nodes, self.h_empty_index)
        init_charge, orig_charge = self.get_orig_and_init(
            unpadded_charge, max_nodes, num_nodes, abs(self.min_charge))
        init_is_in_ring, orig_is_in_ring = self.get_orig_and_init(
            unpadded_is_in_ring, max_nodes, num_nodes, 0)
        init_is_aromatic, orig_is_aromatic = self.get_orig_and_init(
            unpadded_is_aromatic, max_nodes, num_nodes, 0)
        init_chirality, orig_chirality = self.get_orig_and_init(
            unpadded_chirality, max_nodes, num_nodes, 0)

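        # Target-type arrays start at 0 ('do not predict'); corruption marks positions with
        # 1 (mask), 2 (replace) or 3 (reconstruct).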
        init_nodes = np.copy(node_inds)
        node_target_types = np.zeros(node_mask.shape)
        hydrogen_target_types = np.zeros(node_mask.shape)
        charge_target_types = np.zeros(node_mask.shape)
        is_in_ring_target_types = np.zeros(node_mask.shape)
        is_aromatic_target_types = np.zeros(node_mask.shape)
        chirality_target_types = np.zeros(node_mask.shape)

        if self.mask_all_ring_properties is True:
            init_is_in_ring[:] = self.is_in_ring_mask_index
            init_is_aromatic[:] = self.is_aromatic_mask_index

        # *** Initialise edges ***

        # Get (row, column) coordinates of upper triangular part of adjacency matrix excluding diagonal. These are
        # potential target edges. Also get values of the matrix at these indices.
        init_edges = np.copy(adj_mat)
        edge_coords, edge_vals = [], []
        for i in range(num_nodes):
            for j in range(i + 1, num_nodes):
                edge_coords.append((i, j))
                edge_vals.append(unpadded_adj_mat[i, j])
        edge_coords = np.array(edge_coords)
        unpadded_edge_inds = np.array(edge_vals)
        num_edges = len(edge_coords)

        edge_target_types = np.zeros(edge_mask.shape)
        # *** return values ***

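        # Corrupt the graph (mask/replace a subset of nodes, node properties and edges) unless
        # corruption is disabled, e.g. when the sampler feeds back its own generations.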
        if self.do_not_corrupt is False:
            init_nodes, node_target_types, num_nodes, init_hydrogens, hydrogen_target_types, init_charge, \
            charge_target_types, init_is_in_ring, is_in_ring_target_types, init_is_aromatic, \
            is_aromatic_target_types, init_chirality, chirality_target_types, init_edges, \
            edge_target_types, num_edges = self.corrupt_graph(init_nodes, unpadded_node_inds, node_target_types,
                      num_nodes, init_hydrogens, unpadded_num_hs, hydrogen_target_types, init_charge, unpadded_charge,
                      charge_target_types, init_is_in_ring, unpadded_is_in_ring, is_in_ring_target_types,
                      init_is_aromatic, unpadded_is_aromatic, is_aromatic_target_types,
                      init_chirality, unpadded_chirality, chirality_target_types,
                      init_edges, edge_coords, unpadded_edge_inds,
                      edge_target_types, num_edges)

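        # When absent edges are encoded as zeros rather than a dedicated class, drop them from the edge mask.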
        if self.no_edge_present_type == 'zeros':
            edge_mask[np.where(init_edges == 0)] = 0

        ret_list = [
            torch.LongTensor(init_nodes),
            torch.LongTensor(init_edges),
            torch.LongTensor(node_inds),
            torch.LongTensor(adj_mat),
            node_mask,
            edge_mask,
            node_target_types.astype(np.int8),
            edge_target_types.astype(np.int8),
            torch.LongTensor(init_hydrogens),
            torch.LongTensor(orig_hydrogens),
            torch.LongTensor(init_charge),
            torch.LongTensor(orig_charge),
            torch.LongTensor(init_is_in_ring),
            torch.LongTensor(orig_is_in_ring),
            torch.LongTensor(init_is_aromatic),
            torch.LongTensor(orig_is_aromatic),
            torch.LongTensor(init_chirality),
            torch.LongTensor(orig_chirality),
            hydrogen_target_types.astype(np.int8),
            charge_target_types.astype(np.int8),
            is_in_ring_target_types.astype(np.int8),
            is_aromatic_target_types.astype(np.int8),
            chirality_target_types.astype(np.int8)
        ]

        return ret_list