Code Example #1
def test(rank, model):
    idxs = test_idxs[rank]
    print("pid: {}".format(os.getpid()))
    result_summary = None

    molecules = MoleculeDatasetCSV(
        csv_file=args.D,
        corrupt_path=args.c, target=args.target, scaling=args.scale)

    loss_fn = get_loss(args)

    start_time = time.process_time()  # time.clock() was removed in Python 3.8
    for idx in idxs:
        molecule_loader = DataLoader(molecules, batch_size=1, sampler=SubsetRandomSampler([idx]),
                                     collate_fn=collate_fn, num_workers=0)
        for batch in molecule_loader:
            val_dict = model.validation_step(batch=batch, loss_fn=loss_fn)

            result_summary = pd.concat(
                [result_summary,
                 pd.DataFrame(({"idx": idx,
                                "loss": val_dict["batch_dict"][key]["loss"][0],
                                "pred": val_dict["batch_dict"][key]["pred"][0],
                                "true": val_dict["batch_dict"][key]["true"][0]}
                               for key in val_dict["batch_dict"].keys()),
                              index=[0])],
                axis=0)
        if debug:
            break
    end_time = time.process_time()

    print("evaluation finished in {} cpu seconds. writing results...".format(end_time - start_time))

    # convert the pred and true columns to numpy objects...have some messy shapes/etc so clean this up here
    result_summary.pred = result_summary.pred.apply(lambda x: x.data.numpy())
    result_summary.true = result_summary.true.apply(lambda x: x.data.numpy())

    result_summary = result_summary.reset_index()

    if debug:
        print(result_summary.head())

    else:
        result_summary.to_csv(output_path+"/test_results_{}.csv".format(rank))
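
The rank argument and the pid printout suggest this function is run once per worker process. A minimal launch sketch, assuming the model is shared across workers with share_memory() (the sharing strategy and worker count are assumptions, not part of the original example):

import torch.multiprocessing as mp

if __name__ == '__main__':
    model.share_memory()  # assumed: share weights across worker processes
    workers = []
    for rank in range(len(test_idxs)):
        p = mp.Process(target=test, args=(rank, model))
        p.start()
        workers.append(p)
    for p in workers:
        p.join()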
Code Example #2
def node_property_computations(name,
                               scores,
                               targets,
                               target_types,
                               binary=False):
    loss = get_loss(targets,
                    scores,
                    params.equalise,
                    params.loss_normalisation_type,
                    num_components,
                    params.local_cpu,
                    binary=binary)
    losses.append(loss)
    acc, mask_acc, replace_acc, recon_acc = get_ds_stats(
        scores, targets, target_types)
    results_dict.update({
        '{}_acc'.format(name): acc,
        'mask_{}_acc'.format(name): mask_acc,
        'replace_{}_acc'.format(name): replace_acc,
        'recon_{}_acc'.format(name): recon_acc,
        '{}_loss'.format(name): loss
    })
    metrics_to_print.extend(['{}_loss'.format(name)])
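
For context, this helper appears as a closure inside the training loop of code example #6 below, where it is called once per node property, e.g.:

node_property_computations('hydrogen', hydrogen_scores,
                           target_hydrogens, hydrogen_target_types)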
Code Example #3
def pretrain(cfg):
    print(cfg.pretty())
    pretrain_config_validator(cfg)
    fix_seed(cfg.seed)

    controller = load_pretrained_weights(
        NAO(**cfg.controller).to(0), cfg.pretrained_model_path)
    models = {'trunk': controller}
    dataset = get_dataset(seed=cfg.seed, **cfg.dataset)
    optimizers = {
        'trunk_optimizer':
        get_optimizer(parameters=models['trunk'].parameters(), **cfg.optimizer)
    }
    lr_schedulers = {
        'trunk_scheduler_by_iteration':
        get_scheduler(optimizer=optimizers['trunk_optimizer'], **cfg.scheduler)
    }
    loss_funcs = {
        'reconstruction_loss': torch.nn.NLLLoss(),
        'metric_loss': get_loss(**cfg.loss)
    }
    mining_funcs = {"tuple_miner": get_miner(**cfg.miner)}
    visualizers = [umap.UMAP(**params) for params in cfg.visualizers]
    end_of_iteration_hook = TensorboardHook(visualizers).end_of_iteration_hook
    end_of_epoch_hook = ModelSaverHook().end_of_epoch_hook
    get_trainer(
        models=models,
        optimizers=optimizers,
        lr_schedulers=lr_schedulers,
        loss_funcs=loss_funcs,
        mining_funcs=mining_funcs,
        dataset=dataset,
        end_of_iteration_hook=end_of_iteration_hook,
        end_of_epoch_hook=end_of_epoch_hook,
        **cfg.trainer,
    ).train()
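
cfg.pretty() and the **cfg unpacking indicate a pre-1.0 Hydra/OmegaConf config object. A minimal driver sketch under that assumption (the config file name is hypothetical):

import hydra

@hydra.main(config_path='conf/pretrain.yaml')  # hypothetical config file
def run(cfg):
    pretrain(cfg)

if __name__ == '__main__':
    run()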
Code Example #4
File: train.py  Project: nyu-dl/dl4chem-mgm
def main(params):
    model_cls = MODELS_DICT[params.model_name]
    model = model_cls(params)
    params, model, opt, scheduler, train_data, train_loader, val_dataset, val_loader, perturbation_loader, generator,\
    index_method, exp_path, training_smiles, pp, logger, writer, best_loss,\
    total_iter, grad_accum_iters = setup_data_and_model(params, model)

    for epoch in range(1, params.num_epochs+1):
        print('Starting epoch {}'.format(epoch), flush=True)
        for train_batch in train_loader:
            if total_iter % 100 == 0: print(total_iter, flush=True)
            if total_iter == params.max_steps:
                logger.info('Done training')
                break
            model.train()
            if hasattr(model, 'seq_model'): model.seq_model.eval()

            # Training step
            batch_init_graph, batch_orig_graph, batch_target_type_graph, _, \
            graph_properties, binary_graph_properties = train_batch
            # init is what goes into model
            # original are uncorrupted data (used for comparison with prediction in loss calculation)
            # masks are 1 in places corresponding to nodes that exist, 0 in other places (which are empty/padded)
            # target_types are 1 in places to mask, 2 in places to replace, 3 in places to reconstruct, 0 in places not to predict
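            # e.g. for a 5-node graph, a node-level target_type vector might be
            # tensor([1, 0, 2, 3, 0]): predict node 0 (masked), node 2 (replaced)
            # and node 3 (reconstructed); skip nodes 1 and 4 (illustrative values)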

            if params.local_cpu is False:
                batch_init_graph = batch_init_graph.to(torch.device('cuda:0'))
                batch_orig_graph = batch_orig_graph.to(torch.device('cuda:0'))
                dct_to_cuda_inplace(graph_properties)
                if binary_graph_properties: binary_graph_properties = binary_graph_properties.cuda()

            if grad_accum_iters % params.grad_accum_iters == 0:
                opt.zero_grad()

            _, batch_scores_graph, graph_property_scores = model(batch_init_graph, graph_properties,
                binary_graph_properties)

            num_components = sum([(v != 0).sum().numpy() for _, v in batch_target_type_graph.ndata.items()]) + \
                             sum([(v != 0).sum().numpy() for _, v in batch_target_type_graph.edata.items()])

            # calculate score
            losses = []
            results_dict = {}
            metrics_to_print = []
            if params.target_data_structs in ['nodes', 'both', 'random'] and \
                    batch_target_type_graph.ndata['node_type'].sum() > 0:
                node_losses = {}
                for name, target_type in batch_target_type_graph.ndata.items():
                    node_losses[name] = get_loss(
                        batch_orig_graph.ndata[name][target_type.numpy() != 0],
                        batch_scores_graph.ndata[name][target_type.numpy() != 0],
                        params.equalise, params.loss_normalisation_type, num_components, params.local_cpu)
                losses.extend(node_losses.values())
                results_dict.update(node_losses)
                metrics_to_print.extend(node_losses.keys())

            if params.target_data_structs in ['edges', 'both', 'random'] and \
                    batch_target_type_graph.edata['edge_type'].sum() > 0:
                edge_losses = {}
                for name, target_type in batch_target_type_graph.edata.items():
                    edge_losses[name] = get_loss(
                        batch_orig_graph.edata[name][target_type.numpy() != 0],
                        batch_scores_graph.edata[name][target_type.numpy() != 0],
                        params.equalise, params.loss_normalisation_type, num_components, params.local_cpu)
                losses.extend(edge_losses.values())
                results_dict.update(edge_losses)
                metrics_to_print.extend(edge_losses.keys())

            if params.predict_graph_properties is True:
                graph_property_losses = {}
                for name, scores in graph_property_scores.items():
                    graph_property_loss = normalise_loss(F.mse_loss(scores, graph_properties[name], reduction='sum'),
                                                         len(scores), num_components, params.loss_normalisation_type)
                    graph_property_losses[name] = graph_property_loss
                losses.extend(graph_property_losses.values())
                results_dict.update(graph_property_losses)
                metrics_to_print.extend(graph_property_losses.keys())

            loss = sum(losses)
            if params.no_update is False:
                (loss/params.grad_accum_iters).backward()
                grad_accum_iters += 1
                if grad_accum_iters % params.grad_accum_iters == 0:
                    # clip grad norm
                    if params.clip_grad_norm > -1:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), params.clip_grad_norm)
                    opt.step()
                    grad_accum_iters = 0
                if (total_iter+1) <= params.warm_up_iters or (total_iter+1) % params.lr_decay_interval == 0:
                    scheduler.step(total_iter+1)

            if params.suppress_train_log is False:
                log_string = ''
                for name, value in results_dict.items():
                    if name in metrics_to_print:
                        log_string += ', {} = {:.2f}'.format(name, value)
                log_string = 'total_iter = {0:d}, loss = {1:.2f}'.format(total_iter, loss.cpu().item()) + log_string
                logger.info(log_string)

            if params.target_frac_inc_after is not None and total_iter > 0 and total_iter % params.target_frac_inc_after == 0:
                train_data.node_target_frac = min(train_data.node_target_frac + params.target_frac_inc_amount,
                                                  params.max_target_frac)
                train_data.edge_target_frac = min(train_data.edge_target_frac + params.target_frac_inc_amount,
                                                  params.max_target_frac)
            results_dict.update({'node_target_frac': train_data.node_target_frac})
            results_dict.update({'edge_target_frac': train_data.edge_target_frac})

            if params.tensorboard and total_iter % int(params.log_train_steps) == 0:
                results_dict.update({'loss': loss, 'lr': opt.param_groups[0]['lr']})
                write_tensorboard(writer, 'train', results_dict, total_iter)

            if total_iter > 0 and total_iter % params.val_after == 0:
                logger.info('Validating')
                val_loss, num_data_points = 0, 0
                node_property_losses = {name: 0 for name in train_data.node_property_names}
                edge_property_losses = {name: 0 for name in train_data.edge_property_names}
                node_property_num_components = {name: 0 for name in train_data.node_property_names}
                edge_property_num_components = {name: 0 for name in train_data.edge_property_names}
                graph_property_losses = {name: 0 for name in params.graph_property_names}
                model.eval()
                set_seed_if(params.seed)
                for batch_init_graph, batch_orig_graph, batch_target_type_graph, _, \
                    graph_properties, binary_graph_properties in val_loader:

                    if params.local_cpu is False:
                        batch_init_graph = batch_init_graph.to(torch.device('cuda:0'))
                        batch_orig_graph = batch_orig_graph.to(torch.device('cuda:0'))
                        dct_to_cuda_inplace(graph_properties)
                        if binary_graph_properties: binary_graph_properties = binary_graph_properties.cuda()

                    with torch.no_grad():
                        _, batch_scores_graph, graph_property_scores = model(batch_init_graph, graph_properties,
                                                                             binary_graph_properties)

                    num_data_points += float(batch_orig_graph.batch_size)
                    losses = []
                    if params.target_data_structs in ['nodes', 'both', 'random'] and \
                            batch_target_type_graph.ndata['node_type'].sum() > 0:
                        for name, target_type in batch_target_type_graph.ndata.items():
                            iter_node_property_loss = F.cross_entropy(
                                batch_scores_graph.ndata[name][target_type.numpy() != 0],
                                batch_orig_graph.ndata[name][target_type.numpy() != 0], reduction='sum').cpu().item()
                            node_property_losses[name] += iter_node_property_loss
                            losses.append(iter_node_property_loss)
                            node_property_num_components[name] += float((target_type != 0).sum())

                    if params.target_data_structs in ['edges', 'both', 'random'] and \
                            batch_target_type_graph.edata['edge_type'].sum() > 0:
                        for name, target_type in batch_target_type_graph.edata.items():
                            iter_edge_property_loss = F.cross_entropy(
                                batch_scores_graph.edata[name][target_type.numpy() != 0],
                                batch_orig_graph.edata[name][target_type.numpy() != 0], reduction='sum').cpu().item()
                            edge_property_losses[name] += iter_edge_property_loss
                            losses.append(iter_edge_property_loss)
                            edge_property_num_components[name] += float((target_type != 0).sum())

                    if params.predict_graph_properties is True:
                        for name, scores in graph_property_scores.items():
                            iter_graph_property_loss = F.mse_loss(
                                scores, graph_properties[name], reduction='sum').cpu().item()
                            graph_property_losses[name] += iter_graph_property_loss
                            losses.append(iter_graph_property_loss)

                    val_loss += sum(losses)

                avg_node_property_losses, avg_edge_property_losses, avg_graph_property_losses = {}, {}, {}
                if params.loss_normalisation_type == 'by_total':
                    total_num_components = float(sum(node_property_num_components.values()) +
                                                 sum(edge_property_num_components.values()))
                    avg_val_loss = val_loss/total_num_components
                    if params.target_data_structs in ['nodes', 'both', 'random']:
                        for name, loss in node_property_losses.items():
                            avg_node_property_losses[name] = loss/total_num_components
                    if params.target_data_structs in ['edges', 'both', 'random']:
                        for name, loss in edge_property_losses.items():
                            avg_edge_property_losses[name] = loss/total_num_components
                    if params.predict_graph_properties is True:
                        for name, loss in graph_property_losses.items():
                            avg_graph_property_losses[name] = loss/total_num_components
                elif params.loss_normalisation_type == 'by_component':
                    avg_val_loss = 0
                    if params.target_data_structs in ['nodes', 'both', 'random']:
                        for name, loss in node_property_losses.items():
                            avg_node_property_losses[name] = loss/node_property_num_components[name]
                        avg_val_loss += sum(avg_node_property_losses.values())
                    if params.target_data_structs in ['edges', 'both', 'random']:
                        for name, loss in edge_property_losses.items():
                            avg_edge_property_losses[name] = loss/edge_property_num_components[name]
                        avg_val_loss += sum(avg_edge_property_losses.values())
                    if params.predict_graph_properties is True:
                        for name, loss in graph_property_losses.items():
                            avg_graph_property_losses[name] = loss/num_data_points
                        avg_val_loss += sum(avg_graph_property_losses.values())
                val_iter = total_iter // params.val_after

                results_dict = {'Validation_loss': avg_val_loss}
                for name, loss in avg_node_property_losses.items():
                    results_dict['{}_loss'.format(name)] = loss
                for name, loss in avg_edge_property_losses.items():
                    results_dict['{}_loss'.format(name)] = loss
                for name, loss in avg_graph_property_losses.items():
                    results_dict['{}_loss'.format(name)] = loss

                for name, loss in results_dict.items():
                    logger.info('{}: {:.2f}'.format(name, loss))

                if params.tensorboard:
                    write_tensorboard(writer, 'Dev', results_dict, val_iter)
                if params.gen_num_samples > 0:
                    calculate_gen_benchmarks(generator, params.gen_num_samples, training_smiles, logger)

                logger.info("----------------------------------")
                model_state_dict = model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict()
                best_loss = save_checkpoints(total_iter, avg_val_loss, best_loss, model_state_dict, opt.state_dict(),
                                             exp_path, logger, params.no_save, params.save_all)
                # Reset random seed
                set_seed_if(params.seed)
                logger.info('Validation complete')
            total_iter += 1
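
The two loss_normalisation_type branches in the validation block above reduce to simple arithmetic. A worked sketch with made-up numbers:

# summed cross-entropy per property, and the number of predicted components
node_losses, node_counts = {'node_type': 12.0}, {'node_type': 4.0}
edge_losses, edge_counts = {'edge_type': 30.0}, {'edge_type': 10.0}

# 'by_total': every loss is divided by the total component count
total = sum(node_counts.values()) + sum(edge_counts.values())                    # 14.0
avg_by_total = (sum(node_losses.values()) + sum(edge_losses.values())) / total   # 3.0

# 'by_component': each property is normalised by its own count, then summed
avg_by_component = (node_losses['node_type'] / node_counts['node_type']
                    + edge_losses['edge_type'] / edge_counts['edge_type'])       # 6.0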
Code Example #5
def train(zoom_size=4, model="mcnn", dataset="shtu_dataset"):
    """

    :type zoom_size: int
    :type model: str
    :type dataset: str

    """
    # load data
    if dataset == "shtu_dataset":
        print("train data loading..........")
        shanghaitech_dataset = shtu_dataset.ShanghaiTechDataset(
            mode="train", zoom_size=zoom_size)
        tech_loader = torch.utils.data.DataLoader(shanghaitech_dataset,
                                                  batch_size=1,
                                                  shuffle=True,
                                                  num_workers=8)
        print("test data loading............")
        test_data = shtu_dataset.ShanghaiTechDataset(mode="test")
        test_loader = torch.utils.data.DataLoader(test_data,
                                                  batch_size=1,
                                                  shuffle=False)
    elif dataset == "mall_dataset":
        print("train data loading..........")
        mall_data = mall_dataset.MallDataset(
            img_path="./mall_dataset/frames/",
            point_path="./mall_dataset/mall_gt.mat",
            zoom_size=zoom_size)
        tech_loader = torch.utils.data.DataLoader(mall_data,
                                                  batch_size=6,
                                                  shuffle=True,
                                                  num_workers=6)
        print("test data loading............")
        mall_test_data = mall_data
        test_loader = torch.utils.data.DataLoader(mall_test_data,
                                                  batch_size=6,
                                                  shuffle=False,
                                                  num_workers=6)
    number = len(tech_loader)
    print("init net...........")
    net = models[model]
    net = net.train().to(DEVICE)
    print("init optimizer..........")
    # optimizer = optim.Adam(filter(lambda p:p.requires_grad, net.parameters()), lr=learning_rate)
    # optimizer = optim.SGD(filter(lambda p:p.requires_grad, net.parameters()), lr=learning_rate, momentum=0.9)
    optimizer = optim.SGD(net.parameters(),
                          lr=1e-7,
                          momentum=0.95,
                          weight_decay=5 * 1e-4)
    print("start to train net.....")
    sum_loss = 0
    step = 0
    result = []
    epoch_index = -1
    min_mae = sys.maxsize
    # every 2 epochs (out of 2000), evaluate on the test set
    # and keep the best model
    for epoch in range(2000):
        for img, ground_truth in tech_loader:
            img = img.float().to(DEVICE)
            ground_truth = ground_truth.float().to(DEVICE)
            output = net(img)
            loss = utils.get_loss(output, ground_truth)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += float(loss)
            step += 1
            if step % (number // 2) == 0:
                print("{0} patches are done, loss: ".format(step),
                      sum_loss / (number // 2))
                sum_loss = 0

        if epoch % 2 == 0:
            sum_mae = 0.0
            sum_mse = 0.0
            for img, ground_truth in test_loader:
                img = img.float().to(DEVICE)
                ground_truth = ground_truth.float().to(DEVICE)
                output = net(img)
                mae, mse = utils.get_test_loss(output, ground_truth)
                sum_mae += float(mae)
                sum_mse += float(mse)
            if sum_mae / len(test_loader) < min_mae:
                min_mae = sum_mae / len(test_loader)
                min_mse = sum_mse / len(test_loader)
                result.append([min_mae, math.sqrt(min_mse)])
                torch.save(net.state_dict(), "./results/mall_result/mcnn.pkl")
            print("best_mae:%.1f, best_mse:%.1f" %
                  (min_mae, math.sqrt(min_mse)))
            epoch_index += 2
            print("{0} epoches / 2000 epoches are done".format(epoch_index))
        step = 0
    result = np.asarray(result)
    try:
        np.save("./results/mall_result/mcnn.npy", result)
    except IOError:
        # create the full nested directory before retrying
        os.makedirs("./results/mall_result", exist_ok=True)
        np.save("./results/mall_result/mcnn.npy", result)
    print("save successful!")
Code Example #6
def main(params):
    model_cls = MODELS_DICT[params.model_name]
    model = model_cls(params)
    params, model, opt, scheduler, train_data, train_loader, val_dataset, val_loader, perturbation_loader, generator,\
    index_method, exp_path, training_smiles, pp, logger, writer, best_loss,\
    total_iter, grad_accum_iters = setup_data_and_model(params, model)

    for epoch in range(1, params.num_epochs + 1):
        print('Starting epoch {}'.format(epoch), flush=True)
        for train_batch in train_loader:
            if total_iter % 100 == 0: print(total_iter, flush=True)
            if total_iter == params.max_steps:
                logger.info('Done training')
                break
            model.train()

            # Training step
            init_nodes, init_edges, original_node_inds, original_adj_mats, node_masks, edge_masks,\
            node_target_types, edge_target_types, init_hydrogens, original_hydrogens, init_charge,\
            orig_charge, init_is_in_ring, orig_is_in_ring, init_is_aromatic, orig_is_aromatic, init_chirality,\
            orig_chirality, hydrogen_target_types, charge_target_types, is_in_ring_target_types,\
            is_aromatic_target_types, chirality_target_types = train_batch
            # init is what goes into model
            # target_inds and target_coords are 1 at locations to be predicted, 0 elsewhere
            # target_inds and target_coords are now calculated here rather than in dataloader
            # original are uncorrupted data (used for comparison with prediction in loss calculation)
            # masks are 1 in places corresponding to nodes that exist, 0 in other places (which are empty/padded)
            # target_types are 1 in places to mask, 2 in places to replace, 3 in places to reconstruct, 0 in places not to predict

            node_target_inds_vector = getattr(node_target_types != 0,
                                              index_method)()
            edge_target_coords_matrix = getattr(edge_target_types != 0,
                                                index_method)()
            hydrogen_target_inds_vector = getattr(hydrogen_target_types != 0,
                                                  index_method)()
            charge_target_inds_vector = getattr(charge_target_types != 0,
                                                index_method)()
            is_in_ring_target_inds_vector = getattr(
                is_in_ring_target_types != 0, index_method)()
            is_aromatic_target_inds_vector = getattr(
                is_aromatic_target_types != 0, index_method)()
            chirality_target_inds_vector = getattr(chirality_target_types != 0,
                                                   index_method)()
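            # index_method comes from setup_data_and_model; presumably 'bool'
            # or 'byte' to build an indexing mask compatible with the installed
            # torch version (an assumption based on how the masks are used)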

            if params.local_cpu is False:
                init_nodes = init_nodes.cuda()
                init_edges = init_edges.cuda()
                original_node_inds = original_node_inds.cuda()
                original_adj_mats = original_adj_mats.cuda()
                node_masks = node_masks.cuda()
                edge_masks = edge_masks.cuda()
                node_target_types = node_target_types.cuda()
                edge_target_types = edge_target_types.cuda()
                init_hydrogens = init_hydrogens.cuda()
                original_hydrogens = original_hydrogens.cuda()
                init_charge = init_charge.cuda()
                orig_charge = orig_charge.cuda()
                init_is_in_ring = init_is_in_ring.cuda()
                orig_is_in_ring = orig_is_in_ring.cuda()
                init_is_aromatic = init_is_aromatic.cuda()
                orig_is_aromatic = orig_is_aromatic.cuda()
                init_chirality = init_chirality.cuda()
                orig_chirality = orig_chirality.cuda()
                hydrogen_target_types = hydrogen_target_types.cuda()
                charge_target_types = charge_target_types.cuda()
                is_in_ring_target_types = is_in_ring_target_types.cuda()
                is_aromatic_target_types = is_aromatic_target_types.cuda()
                chirality_target_types = chirality_target_types.cuda()
                if params.property_type is not None:
                    properties = properties.cuda()

            if grad_accum_iters % params.grad_accum_iters == 0:
                opt.zero_grad()

            out = model(init_nodes, init_edges, node_masks, edge_masks,
                        init_hydrogens, init_charge, init_is_in_ring,
                        init_is_aromatic, init_chirality)
            node_scores, edge_scores, hydrogen_scores, charge_scores, is_in_ring_scores, is_aromatic_scores,\
                chirality_scores = out
            node_num_classes = node_scores.shape[-1]
            edge_num_classes = edge_scores.shape[-1]
            hydrogen_num_classes = hydrogen_scores.shape[-1]
            charge_num_classes = charge_scores.shape[-1]
            is_in_ring_num_classes = is_in_ring_scores.shape[-1]
            is_aromatic_num_classes = is_aromatic_scores.shape[-1]
            chirality_num_classes = chirality_scores.shape[-1]

            if model.property_type is not None:
                property_scores = out[-1]

            # slice out target data structures
            node_scores, target_nodes, node_target_types = get_only_target_info(
                node_scores, original_node_inds, node_target_inds_vector,
                node_num_classes, node_target_types)
            edge_scores, target_adj_mats, edge_target_types = get_only_target_info(
                edge_scores, original_adj_mats, edge_target_coords_matrix,
                edge_num_classes, edge_target_types)
            if params.embed_hs is True:
                hydrogen_scores, target_hydrogens, hydrogen_target_types = get_only_target_info(
                    hydrogen_scores, original_hydrogens,
                    hydrogen_target_inds_vector, hydrogen_num_classes,
                    hydrogen_target_types)
                charge_scores, target_charge, charge_target_types = get_only_target_info(
                    charge_scores, orig_charge, charge_target_inds_vector,
                    charge_num_classes, charge_target_types)
                is_in_ring_scores, target_is_in_ring, is_in_ring_target_types = get_only_target_info(
                    is_in_ring_scores, orig_is_in_ring,
                    is_in_ring_target_inds_vector, is_in_ring_num_classes,
                    is_in_ring_target_types)
                is_aromatic_scores, target_is_aromatic, is_aromatic_target_types = get_only_target_info(
                    is_aromatic_scores, orig_is_aromatic,
                    is_aromatic_target_inds_vector, is_aromatic_num_classes,
                    is_aromatic_target_types)
                chirality_scores, target_chirality, chirality_target_types = get_only_target_info(
                    chirality_scores, orig_chirality,
                    chirality_target_inds_vector, chirality_num_classes,
                    chirality_target_types)

            num_components = len(target_nodes) + len(target_adj_mats)
            if params.embed_hs is True: num_components += len(target_hydrogens)
            # calculate score
            losses = []
            results_dict = {}
            metrics_to_print = []
            if params.target_data_structs in ['nodes', 'both', 'random'
                                              ] and len(target_nodes) > 0:
                node_loss = get_loss(target_nodes, node_scores,
                                     params.equalise,
                                     params.loss_normalisation_type,
                                     num_components, params.local_cpu)
                losses.append(node_loss)
                node_preds = torch.argmax(F.softmax(node_scores, -1), dim=-1)
                nodes_acc, mask_node_acc, replace_node_acc, recon_node_acc = get_ds_stats(
                    node_scores, target_nodes, node_target_types)
                carbon_nodes_correct, noncarbon_nodes_correct, carbon_nodes_acc, noncarbon_nodes_acc = \
                    get_majority_and_minority_stats(node_preds, target_nodes, 1)
                results_dict.update({
                    'nodes_acc': nodes_acc,
                    'carbon_nodes_acc': carbon_nodes_acc,
                    'noncarbon_nodes_acc': noncarbon_nodes_acc,
                    'mask_node_acc': mask_node_acc,
                    'replace_node_acc': replace_node_acc,
                    'recon_node_acc': recon_node_acc,
                    'node_loss': node_loss
                })
                metrics_to_print.extend([
                    'node_loss', 'nodes_acc', 'carbon_nodes_acc',
                    'noncarbon_nodes_acc'
                ])

                def node_property_computations(name,
                                               scores,
                                               targets,
                                               target_types,
                                               binary=False):
                    loss = get_loss(targets,
                                    scores,
                                    params.equalise,
                                    params.loss_normalisation_type,
                                    num_components,
                                    params.local_cpu,
                                    binary=binary)
                    losses.append(loss)
                    acc, mask_acc, replace_acc, recon_acc = get_ds_stats(
                        scores, targets, target_types)
                    results_dict.update({
                        '{}_acc'.format(name): acc,
                        'mask_{}_acc'.format(name): mask_acc,
                        'replace_{}_acc'.format(name): replace_acc,
                        'recon_{}_acc'.format(name): recon_acc,
                        '{}_loss'.format(name): loss
                    })
                    metrics_to_print.extend(['{}_loss'.format(name)])

                if params.embed_hs is True:
                    node_property_computations('hydrogen', hydrogen_scores,
                                               target_hydrogens,
                                               hydrogen_target_types)
                    node_property_computations('charge', charge_scores,
                                               target_charge,
                                               charge_target_types)
                    node_property_computations('is_in_ring', is_in_ring_scores,
                                               target_is_in_ring,
                                               is_in_ring_target_types)
                    node_property_computations('is_aromatic',
                                               is_aromatic_scores,
                                               target_is_aromatic,
                                               is_aromatic_target_types)
                    node_property_computations('chirality', chirality_scores,
                                               target_chirality,
                                               chirality_target_types)

            if params.target_data_structs in ['edges', 'both', 'random'
                                              ] and len(target_adj_mats) > 0:
                edge_loss = get_loss(target_adj_mats, edge_scores,
                                     params.equalise,
                                     params.loss_normalisation_type,
                                     num_components, params.local_cpu)
                losses.append(edge_loss)
                edge_preds = torch.argmax(F.softmax(edge_scores, -1), dim=-1)
                edges_acc, mask_edge_acc, replace_edge_acc, recon_edge_acc = get_ds_stats(
                    edge_scores, target_adj_mats, edge_target_types)
                no_edge_correct, edge_present_correct, no_edge_acc, edge_present_acc = \
                    get_majority_and_minority_stats(edge_preds, target_adj_mats, 0)
                results_dict.update({
                    'edges_acc': edges_acc,
                    'edge_present_acc': edge_present_acc,
                    'no_edge_acc': no_edge_acc,
                    'mask_edge_acc': mask_edge_acc,
                    'replace_edge_acc': replace_edge_acc,
                    'recon_edge_acc': recon_edge_acc,
                    'edge_loss': edge_loss
                })
                metrics_to_print.extend([
                    'edge_loss', 'edges_acc', 'edge_present_acc', 'no_edge_acc'
                ])

            if params.property_type is not None:
                property_loss = model.property_loss(
                    property_scores, properties) / params.batch_size
                losses.append(property_loss)
                results_dict.update({'property_loss': property_loss})
                metrics_to_print.extend(['property_loss'])

            loss = sum(losses)
            if params.no_update is False:
                (loss / params.grad_accum_iters).backward()
                grad_accum_iters += 1
                if grad_accum_iters % params.grad_accum_iters == 0:
                    # clip grad norm
                    if params.clip_grad_norm > -1:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       params.clip_grad_norm)
                    opt.step()
                    grad_accum_iters = 0
                if (total_iter + 1) <= params.warm_up_iters or (
                        total_iter + 1) % params.lr_decay_interval == 0:
                    scheduler.step(total_iter + 1)

            if params.check_pred_validity is True:
                if params.target_data_structs in ['nodes', 'both', 'random'
                                                  ] and len(target_nodes) > 0:
                    init_nodes[node_target_inds_vector] = node_preds
                if params.target_data_structs in [
                        'edges', 'both', 'random'
                ] and len(target_adj_mats) > 0:
                    init_edges[edge_target_coords_matrix] = edge_preds
                is_valid_list, is_connected_list = check_validity(
                    init_nodes, init_edges)
                percent_valid = np.mean(is_valid_list) * 100
                valid_percent_connected = np.mean(is_connected_list) * 100
                results_dict.update({
                    'percent_valid':
                    percent_valid,
                    'percent_connected':
                    valid_percent_connected
                })

            if params.suppress_train_log is False:
                log_string = ''
                for name, value in results_dict.items():
                    if name in metrics_to_print:
                        log_string += ', {} = {:.2f}'.format(name, value)
                log_string = 'total_iter = {0:d}, loss = {1:.2f}'.format(
                    total_iter,
                    loss.cpu().item()) + log_string
                logger.info(log_string)

            if params.target_frac_inc_after is not None and total_iter > 0 and total_iter % params.target_frac_inc_after == 0:
                train_data.node_target_frac = min(
                    train_data.node_target_frac +
                    params.target_frac_inc_amount, params.max_target_frac)
                train_data.edge_target_frac = min(
                    train_data.edge_target_frac +
                    params.target_frac_inc_amount, params.max_target_frac)
            results_dict.update(
                {'node_target_frac': train_data.node_target_frac})
            results_dict.update(
                {'edge_target_frac': train_data.edge_target_frac})

            if params.tensorboard and total_iter % int(
                    params.log_train_steps) == 0:
                results_dict.update({
                    'loss': loss,
                    'lr': opt.param_groups[0]['lr']
                })
                write_tensorboard(writer, 'train', results_dict, total_iter)

            dist_names = [
                ['mask_edge_correct_dist', 'mask_edge_incorrect_pred_dist',
                 'mask_edge_incorrect_true_dist'],
                ['replace_edge_correct_dist', 'replace_edge_incorrect_pred_dist',
                 'replace_edge_incorrect_true_dist'],
                ['recon_edge_correct_dist', 'recon_edge_incorrect_pred_dist',
                 'recon_edge_incorrect_true_dist'],
            ]
            distributions = np.zeros((3, 3, params.num_edge_types))
            if total_iter > 0 and total_iter % params.val_after == 0:
                logger.info('Validating')
                val_loss, property_loss, node_loss, edge_loss, hydrogen_loss, num_data_points = 0, 0, 0, 0, 0, 0
                charge_loss, is_in_ring_loss, is_aromatic_loss, chirality_loss = 0, 0, 0, 0
                model.eval()
                set_seed_if(params.seed)
                nodes_correct, edges_correct, total_nodes, total_edges = 0, 0, 0, 0
                carbon_nodes_correct, noncarbon_nodes_correct, total_carbon_nodes, total_noncarbon_nodes = 0, 0, 0, 0
                mask_nodes_correct, replace_nodes_correct, recon_nodes_correct = 0, 0, 0
                total_mask_nodes, total_replace_nodes, total_recon_nodes = 0, 0, 0
                no_edge_correct, edge_present_correct, total_no_edges, total_edges_present = 0, 0, 0, 0
                mask_edges_correct, replace_edges_correct, recon_edges_correct = 0, 0, 0
                total_mask_edges, total_replace_edges, total_recon_edges = 0, 0, 0
                hydrogens_correct, mask_hydrogens_correct, replace_hydrogens_correct, recon_hydrogens_correct, \
                    total_mask_hydrogens, total_replace_hydrogens, total_recon_hydrogens,\
                    total_hydrogens = 0, 0, 0, 0, 0, 0, 0, 0
                is_valid_list, is_connected_list = [], []
                for init_nodes, init_edges, original_node_inds, original_adj_mats, node_masks, edge_masks,\
                    node_target_types, edge_target_types, init_hydrogens, original_hydrogens,\
                    init_charge, orig_charge, init_is_in_ring, orig_is_in_ring, init_is_aromatic, orig_is_aromatic,\
                    init_chirality, orig_chirality, hydrogen_target_types, charge_target_types, is_in_ring_target_types,\
                    is_aromatic_target_types, chirality_target_types in val_loader:

                    node_target_inds_vector = getattr(node_target_types != 0,
                                                      index_method)()
                    edge_target_coords_matrix = getattr(
                        edge_target_types != 0, index_method)()
                    hydrogen_target_inds_vector = getattr(
                        hydrogen_target_types != 0, index_method)()
                    charge_target_inds_vector = getattr(
                        charge_target_types != 0, index_method)()
                    is_in_ring_target_inds_vector = getattr(
                        is_in_ring_target_types != 0, index_method)()
                    is_aromatic_target_inds_vector = getattr(
                        is_aromatic_target_types != 0, index_method)()
                    chirality_target_inds_vector = getattr(
                        chirality_target_types != 0, index_method)()

                    if params.local_cpu is False:
                        init_nodes = init_nodes.cuda()
                        init_edges = init_edges.cuda()
                        original_node_inds = original_node_inds.cuda()
                        original_adj_mats = original_adj_mats.cuda()
                        node_masks = node_masks.cuda()
                        edge_masks = edge_masks.cuda()
                        node_target_types = node_target_types.cuda()
                        edge_target_types = edge_target_types.cuda()
                        init_hydrogens = init_hydrogens.cuda()
                        original_hydrogens = original_hydrogens.cuda()
                        init_charge = init_charge.cuda()
                        orig_charge = orig_charge.cuda()
                        init_is_in_ring = init_is_in_ring.cuda()
                        orig_is_in_ring = orig_is_in_ring.cuda()
                        init_is_aromatic = init_is_aromatic.cuda()
                        orig_is_aromatic = orig_is_aromatic.cuda()
                        init_chirality = init_chirality.cuda()
                        orig_chirality = orig_chirality.cuda()
                        hydrogen_target_types = hydrogen_target_types.cuda()
                        charge_target_types = charge_target_types.cuda()
                        is_in_ring_target_types = is_in_ring_target_types.cuda()
                        is_aromatic_target_types = is_aromatic_target_types.cuda()
                        chirality_target_types = chirality_target_types.cuda()
                        if params.embed_hs is True:
                            original_hydrogens = original_hydrogens.cuda()
                        if params.property_type is not None:
                            properties = properties.cuda()

                    batch_size = init_nodes.shape[0]

                    with torch.no_grad():
                        out = model(init_nodes, init_edges, node_masks,
                                    edge_masks, init_hydrogens, init_charge,
                                    init_is_in_ring, init_is_aromatic,
                                    init_chirality)
                        node_scores, edge_scores, hydrogen_scores, charge_scores, is_in_ring_scores,\
                        is_aromatic_scores, chirality_scores = out
                        if model.property_type is not None:
                            property_scores = out[-1]

                    node_num_classes = node_scores.shape[-1]
                    edge_num_classes = edge_scores.shape[-1]
                    hydrogen_num_classes = hydrogen_scores.shape[-1]
                    charge_num_classes = charge_scores.shape[-1]
                    is_in_ring_num_classes = is_in_ring_scores.shape[-1]
                    is_aromatic_num_classes = is_aromatic_scores.shape[-1]
                    chirality_num_classes = chirality_scores.shape[-1]

                    node_scores, target_nodes, node_target_types = get_only_target_info(
                        node_scores, original_node_inds,
                        node_target_inds_vector, node_num_classes,
                        node_target_types)
                    edge_scores, target_adj_mats, edge_target_types = get_only_target_info(
                        edge_scores, original_adj_mats,
                        edge_target_coords_matrix, edge_num_classes,
                        edge_target_types)

                    if params.embed_hs is True:
                        hydrogen_scores, target_hydrogens, hydrogen_target_types = get_only_target_info(
                            hydrogen_scores, original_hydrogens,
                            hydrogen_target_inds_vector, hydrogen_num_classes,
                            hydrogen_target_types)
                        charge_scores, target_charge, charge_target_types = get_only_target_info(
                            charge_scores, orig_charge,
                            charge_target_inds_vector, charge_num_classes,
                            charge_target_types)
                        is_in_ring_scores, target_is_in_ring, is_in_ring_target_types = get_only_target_info(
                            is_in_ring_scores, orig_is_in_ring,
                            is_in_ring_target_inds_vector,
                            is_in_ring_num_classes, is_in_ring_target_types)
                        is_aromatic_scores, target_is_aromatic, is_aromatic_target_types = get_only_target_info(
                            is_aromatic_scores, orig_is_aromatic,
                            is_aromatic_target_inds_vector,
                            is_aromatic_num_classes, is_aromatic_target_types)
                        chirality_scores, target_chirality, chirality_target_types = get_only_target_info(
                            chirality_scores, orig_chirality,
                            chirality_target_inds_vector,
                            chirality_num_classes, chirality_target_types)
                    num_data_points += batch_size

                    losses = []
                    if params.target_data_structs in [
                            'nodes', 'both', 'random'
                    ] and len(target_nodes) > 0:
                        weight = get_loss_weights(target_nodes, node_scores,
                                                  params.equalise,
                                                  params.local_cpu)
                        iter_node_loss = F.cross_entropy(node_scores,
                                                         target_nodes,
                                                         weight=weight,
                                                         reduction='sum')
                        losses.append(iter_node_loss)
                        node_loss += iter_node_loss
                        nodes_correct, mask_nodes_correct, replace_nodes_correct, recon_nodes_correct, total_mask_nodes, \
                        total_replace_nodes, total_recon_nodes, total_nodes = update_ds_stats(node_scores, target_nodes,
                            node_target_types, nodes_correct, mask_nodes_correct, replace_nodes_correct,
                            recon_nodes_correct, total_mask_nodes, total_replace_nodes, total_recon_nodes, total_nodes)
                        total_noncarbon_nodes += (target_nodes != 1).sum()
                        total_carbon_nodes += (target_nodes == 1).sum()
                        node_preds = torch.argmax(F.softmax(node_scores, -1),
                                                  dim=-1)
                        noncarbon_nodes_correct += torch.mul(
                            (node_preds == target_nodes),
                            (target_nodes != 1)).sum()
                        carbon_nodes_correct += torch.mul(
                            (node_preds == target_nodes),
                            (target_nodes == 1)).sum()
                        if params.check_pred_validity is True:
                            init_nodes[node_target_inds_vector] = node_preds

                        def val_node_property_loss_computation(
                                targets, scores, loss, binary=False):
                            weight = get_loss_weights(targets, scores,
                                                      params.equalise,
                                                      params.local_cpu)
                            if binary is True:
                                iter_loss = F.binary_cross_entropy_with_logits(
                                    scores,
                                    targets,
                                    weight=weight,
                                    reduction='sum')
                            else:
                                iter_loss = F.cross_entropy(scores,
                                                            targets,
                                                            weight=weight,
                                                            reduction='sum')
                            losses.append(iter_loss)
                            loss += iter_loss
                            return loss

                        if params.embed_hs is True:
                            hydrogen_loss = val_node_property_loss_computation(
                                target_hydrogens, hydrogen_scores,
                                hydrogen_loss)
                            hydrogens_correct, mask_hydrogens_correct, replace_hydrogens_correct, recon_hydrogens_correct,\
                            total_mask_hydrogens, total_replace_hydrogens, total_recon_hydrogens, total_hydrogens =\
                                update_ds_stats(hydrogen_scores, target_hydrogens, hydrogen_target_types, hydrogens_correct,
                                mask_hydrogens_correct, replace_hydrogens_correct, recon_hydrogens_correct,
                                total_mask_hydrogens, total_replace_hydrogens, total_recon_hydrogens, total_hydrogens)
                            charge_loss = val_node_property_loss_computation(
                                target_charge, charge_scores, charge_loss)
                            is_in_ring_loss = val_node_property_loss_computation(
                                target_is_in_ring, is_in_ring_scores,
                                is_in_ring_loss)
                            is_aromatic_loss = val_node_property_loss_computation(
                                target_is_aromatic, is_aromatic_scores,
                                is_aromatic_loss)
                            chirality_loss = val_node_property_loss_computation(
                                target_chirality, chirality_scores,
                                chirality_loss)

                    if params.target_data_structs in [
                            'edges', 'both', 'random'
                    ] and len(target_adj_mats) > 0:
                        weight = get_loss_weights(target_adj_mats, edge_scores,
                                                  params.equalise,
                                                  params.local_cpu)
                        iter_edge_loss = F.cross_entropy(edge_scores,
                                                         target_adj_mats,
                                                         weight=weight,
                                                         reduction='sum')
                        losses.append(iter_edge_loss)
                        edge_loss += iter_edge_loss
                        edges_correct, mask_edges_correct, replace_edges_correct, recon_edges_correct, total_mask_edges, \
                        total_replace_edges, total_recon_edges, total_edges = update_ds_stats(edge_scores, target_adj_mats,
                            edge_target_types, edges_correct, mask_edges_correct, replace_edges_correct,
                            recon_edges_correct, total_mask_edges, total_replace_edges, total_recon_edges, total_edges)
                        total_edges_present += (target_adj_mats != 0).sum()
                        total_no_edges += (target_adj_mats == 0).sum()
                        edge_preds = torch.argmax(F.softmax(edge_scores, -1),
                                                  dim=-1)
                        edge_present_correct += torch.mul(
                            (edge_preds == target_adj_mats),
                            (target_adj_mats != 0)).sum()
                        no_edge_correct += torch.mul(
                            (edge_preds == target_adj_mats),
                            (target_adj_mats == 0)).sum()

                        distributions += get_all_result_distributions(
                            edge_preds, target_adj_mats, edge_target_types,
                            [1, 2, 3], params.num_edge_types)
                        if params.check_pred_validity is True:
                            init_edges[edge_target_coords_matrix] = edge_preds

                    if params.property_type is not None:
                        iter_property_loss = model.property_loss(
                            property_scores, properties)
                        losses.append(iter_property_loss)
                        property_loss += iter_property_loss

                    loss = sum(losses).cpu().item()
                    val_loss += loss

                    if params.check_pred_validity is True:
                        is_valid_list, is_connected_list = check_validity(
                            init_nodes, init_edges, is_valid_list,
                            is_connected_list)

                if params.property_type is not None:
                    avg_property_loss = float(property_loss) / float(
                        num_data_points)
                if params.loss_normalisation_type == 'by_total':
                    num_components = float(total_nodes) + float(total_edges)
                    if params.embed_hs is True:
                        # 5 extra node properties: hydrogens, charge,
                        # is_in_ring, is_aromatic, chirality
                        num_components += float(total_nodes) * 5
                    avg_val_loss = float(val_loss) / float(num_components)
                    if params.target_data_structs in [
                            'nodes', 'both', 'random'
                    ] and total_nodes > 0:
                        avg_node_loss = float(node_loss) / float(
                            num_components)
                        if params.embed_hs is True:
                            avg_hydrogen_loss = float(
                                hydrogen_loss) / num_components
                            avg_charge_loss = float(
                                charge_loss) / num_components
                            avg_is_in_ring_loss = float(
                                is_in_ring_loss) / num_components
                            avg_is_aromatic_loss = float(
                                is_aromatic_loss) / num_components
                            avg_chirality_loss = float(
                                chirality_loss) / num_components
                    if params.target_data_structs in [
                            'edges', 'both', 'random'
                    ] and total_edges > 0:
                        avg_edge_loss = float(edge_loss) / float(
                            num_components)
                elif params.loss_normalisation_type == 'by_component':
                    avg_val_loss = 0
                    if params.target_data_structs in [
                            'nodes', 'both', 'random'
                    ] and total_nodes > 0:
                        avg_node_loss = float(node_loss) / float(total_nodes)
                        avg_val_loss += avg_node_loss
                        if params.embed_hs is True:
                            avg_hydrogen_loss = float(hydrogen_loss) / float(
                                total_hydrogens)
                            avg_charge_loss = float(charge_loss) / float(
                                total_nodes)
                            avg_is_in_ring_loss = float(
                                is_in_ring_loss) / float(total_nodes)
                            avg_is_aromatic_loss = float(
                                is_aromatic_loss) / float(total_nodes)
                            avg_chirality_loss = float(chirality_loss) / float(
                                total_nodes)
                            avg_val_loss += avg_hydrogen_loss + avg_charge_loss + avg_is_in_ring_loss + \
                                            avg_is_aromatic_loss + avg_chirality_loss
                    if params.target_data_structs in [
                            'edges', 'both', 'random'
                    ] and total_edges > 0:
                        avg_edge_loss = float(edge_loss) / float(total_edges)
                        avg_val_loss += avg_edge_loss
                    if params.property_type is not None:
                        avg_val_loss += avg_property_loss
                logger.info(
                    'Average validation loss: {0:.2f}'.format(avg_val_loss))
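                # One validation round per val_after training iterations;
                # val_iter is used as the tensorboard step below.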
                val_iter = total_iter // params.val_after

                if params.check_pred_validity is True:
                    percent_valid = np.mean(is_valid_list) * 100
                    valid_percent_connected = np.mean(is_connected_list) * 100
                    logger.info('Percent valid: {}%'.format(percent_valid))
                    logger.info(
                        'Percent of valid molecules connected: {}%'.format(
                            valid_percent_connected))

                results_dict = {'loss': avg_val_loss}

                if params.target_data_structs in ['nodes', 'both', 'random'
                                                  ] and total_nodes > 0:
                    (nodes_acc, noncarbon_nodes_acc, carbon_nodes_acc, mask_node_acc,
                     replace_node_acc, recon_node_acc) = accuracies_from_totals({
                         nodes_correct: total_nodes,
                         noncarbon_nodes_correct: total_noncarbon_nodes,
                         carbon_nodes_correct: total_carbon_nodes,
                         mask_nodes_correct: total_mask_nodes,
                         replace_nodes_correct: total_replace_nodes,
                         recon_nodes_correct: total_recon_nodes})
                    results_dict.update({
                        'nodes_acc': nodes_acc,
                        'carbon_nodes_acc': carbon_nodes_acc,
                        'noncarbon_nodes_acc': noncarbon_nodes_acc,
                        'mask_node_acc': mask_node_acc,
                        'replace_node_acc': replace_node_acc,
                        'recon_node_acc': recon_node_acc,
                        'node_loss': avg_node_loss
                    })
                    logger.info('Node loss: {0:.2f}'.format(avg_node_loss))
                    logger.info('Node accuracy: {0:.2f}%'.format(nodes_acc))
                    logger.info('Non-Carbon Node accuracy: {0:.2f}%'.format(
                        noncarbon_nodes_acc))
                    logger.info('Carbon Node accuracy: {0:.2f}%'.format(
                        carbon_nodes_acc))
                    logger.info(
                        'mask_node_acc {:.2f}, replace_node_acc {:.2f}, recon_node_acc {:.2f}'
                        .format(mask_node_acc, replace_node_acc,
                                recon_node_acc))
                    if params.embed_hs is True:
                        hydrogen_acc, mask_hydrogen_acc, replace_hydrogen_acc, recon_hydrogen_acc = accuracies_from_totals(
                            {
                                hydrogens_correct: total_hydrogens,
                                mask_hydrogens_correct: total_mask_hydrogens,
                                replace_hydrogens_correct:
                                total_replace_hydrogens,
                                recon_hydrogens_correct: total_recon_hydrogens
                            })
                        results_dict.update({
                            'hydrogen_acc': hydrogen_acc,
                            'mask_hydrogen_acc': mask_hydrogen_acc,
                            'replace_hydrogen_acc': replace_hydrogen_acc,
                            'recon_hydrogen_acc': recon_hydrogen_acc,
                            'hydrogen_loss': avg_hydrogen_loss
                        })
                        logger.info(
                            'Hydrogen loss: {0:.2f}'.format(avg_hydrogen_loss))
                        logger.info(
                            'Hydrogen accuracy: {0:.2f}%'.format(hydrogen_acc))
                        logger.info(
                            'mask_hydrogen_acc {:.2f}, replace_hydrogen_acc {:.2f}, recon_hydrogen_acc {:.2f}'
                            .format(mask_hydrogen_acc, replace_hydrogen_acc,
                                    recon_hydrogen_acc))

                        results_dict.update({
                            'charge_loss':
                            avg_charge_loss,
                            'is_in_ring_loss':
                            avg_is_in_ring_loss,
                            'is_aromatic_loss':
                            avg_is_aromatic_loss,
                            'chirality_loss':
                            avg_chirality_loss
                        })
                        logger.info(
                            'Charge loss: {0:.2f}'.format(avg_charge_loss))
                        logger.info('Is in ring loss: {0:.2f}'.format(
                            avg_is_in_ring_loss))
                        logger.info('Is aromatic loss: {0:.2f}'.format(
                            avg_is_aromatic_loss))
                        logger.info('Chirality loss: {0:.2f}'.format(
                            avg_chirality_loss))

                if params.target_data_structs in ['edges', 'both', 'random'
                                                  ] and total_edges > 0:
                    (edges_acc, no_edge_acc, edge_present_acc, mask_edge_acc,
                     replace_edge_acc, recon_edge_acc) = accuracies_from_totals({
                         edges_correct: total_edges,
                         no_edge_correct: total_no_edges,
                         edge_present_correct: total_edges_present,
                         mask_edges_correct: total_mask_edges,
                         replace_edges_correct: total_replace_edges,
                         recon_edges_correct: total_recon_edges})
                    results_dict.update({
                        'edges_acc': edges_acc,
                        'edge_present_acc': edge_present_acc,
                        'no_edge_acc': no_edge_acc,
                        'mask_edge_acc': mask_edge_acc,
                        'replace_edge_acc': replace_edge_acc,
                        'recon_edge_acc': recon_edge_acc,
                        'edge_loss': avg_edge_loss
                    })
                    logger.info('Edge loss: {0:.2f}'.format(avg_edge_loss))
                    logger.info('Edge accuracy: {0:.2f}%'.format(edges_acc))
                    logger.info('Edge present accuracy: {0:.2f}%'.format(
                        edge_present_acc))
                    logger.info(
                        'No edge accuracy: {0:.2f}%'.format(no_edge_acc))
                    logger.info(
                        " mask_edge_acc {:.2f}, replace_edge_acc {:.2f}, recon_edge_acc {:.2f}"
                        .format(mask_edge_acc, replace_edge_acc,
                                recon_edge_acc))
                    logger.info('\n')
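                    # Print the per-category prediction distributions
                    # accumulated by get_all_result_distributions above.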
                    for i in range(distributions.shape[0]):
                        for j in range(distributions.shape[1]):
                            logger.info(dist_names[i][j] + ':\t' +
                                        str(distributions[i, j, :]))
                        logger.info('\n')

                if params.property_type is not None:
                    results_dict.update({'property_loss': avg_property_loss})
                    logger.info(
                        'Property loss: {0:.2f}'.format(avg_property_loss))

                if params.run_perturbation_analysis is True:
                    preds = run_perturbations(perturbation_loader, model,
                                              params.embed_hs, params.max_hs,
                                              params.perturbation_batch_size,
                                              params.local_cpu)
                    stats, percentages = aggregate_stats(preds)
                    logger.info('Percentages: {}'.format(
                        pp.pformat(percentages)))

                if params.tensorboard:
                    if params.run_perturbation_analysis is True:
                        writer.add_scalars('dev/perturbation_stability', {
                            str(key): val
                            for key, val in percentages.items()
                        }, val_iter)
                    write_tensorboard(writer, 'dev', results_dict, val_iter)

                if params.gen_num_samples > 0:
                    calculate_gen_benchmarks(generator, params.gen_num_samples,
                                             training_smiles, logger)

                logger.info("----------------------------------")

                model_state_dict = (model.module.state_dict()
                                    if torch.cuda.device_count() > 1 else
                                    model.state_dict())

                best_loss = save_checkpoints(total_iter, avg_val_loss,
                                             best_loss, model_state_dict,
                                             opt.state_dict(), exp_path,
                                             logger, params.no_save,
                                             params.save_all)

                # Reset random seed
                set_seed_if(params.seed)
                logger.info('Validation complete')
            total_iter += 1
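The per-category accuracies above come from an accuracies_from_totals helper that this snippet calls but does not define. A minimal sketch of what it could look like, assuming each dict entry maps a correct-count to its total and that percentages come back in insertion order:

def accuracies_from_totals(correct_to_total):
    # One percentage per (correct, total) entry; a zero total yields 0.0
    # rather than a division error.
    return tuple(
        float(correct) / float(total) * 100.0 if total > 0 else 0.0
        for correct, total in correct_to_total.items())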
Code example #7
def train(rank, args, model):

    torch.manual_seed(args.seed + rank)

    writer = SummaryWriter(log_path + "/" + "p{}".format(os.getpid()))

    writer.add_text('args', str(sys.argv))
    writer.add_text("target", str(args.target))
    writer.add_text("pid", str(os.getpid()))

    print("loading data...")

    molecules = MoleculeDatasetCSV(csv_file=args.D,
                                   corrupt_path=args.c,
                                   target=args.target,
                                   scaling=args.scale)

    batch_size = args.batch_size
    # The index files are assumed to hold 64-bit integers (np.int itself
    # has been removed from NumPy).
    train_idxs = np.fromfile(args.train_idxs, dtype=np.int64)
    val_idxs = np.fromfile(args.val_idxs, dtype=np.int64)

    molecule_loader_train = DataLoader(molecules,
                                       batch_size=batch_size,
                                       num_workers=0,
                                       collate_fn=collate_fn,
                                       sampler=SubsetRandomSampler(train_idxs))
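    # The validation loader yields a single full batch: batch_size equals
    # the number of validation indices.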
    molecule_loader_val = DataLoader(molecules,
                                     batch_size=val_idxs.shape[0],
                                     num_workers=0,
                                     collate_fn=collate_fn,
                                     sampler=SubsetRandomSampler(val_idxs))

    loss_fn = get_loss(args)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    global_step = 0
    for epoch in range(0, args.n_epochs):
        global_step = train_epoch(rank=rank,
                                  epoch=epoch,
                                  global_step=global_step,
                                  model=model,
                                  molecule_loader_train=molecule_loader_train,
                                  optimizer=optimizer,
                                  loss_fn=loss_fn,
                                  writer=writer)

        test_epoch(rank=rank,
                   epoch=epoch,
                   model=model,
                   molecule_loader_val=molecule_loader_val,
                   loss_fn=loss_fn,
                   writer=writer)

        print("Saving model checkpoint...")
        torch.save(
            model.state_dict(),
            checkpoint_path + "/" + "p{}".format(os.getpid()) +
            "_epoch{}".format(epoch) + "_params.pth")

        # Output training metrics
        print("Saving training metrics")

        writer.export_scalars_to_json(scalar_path + "/" +
                                      "p{}".format(os.getpid()) +
                                      "_epoch{}".format(epoch) +
                                      "_scalars.json")

    writer.close()
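train_epoch and test_epoch are called above but not defined in this snippet. A minimal sketch of a compatible train_epoch, assuming each batch unpacks into an (inputs, targets) pair (the project's collate_fn may produce a different layout):

def train_epoch(rank, epoch, global_step, model, molecule_loader_train,
                optimizer, loss_fn, writer):
    model.train()
    for inputs, targets in molecule_loader_train:
        optimizer.zero_grad()
        preds = model(inputs)
        loss = loss_fn(preds, targets)
        loss.backward()
        optimizer.step()
        # Log the per-batch training loss against this worker's step count.
        writer.add_scalar("train/loss", loss.item(), global_step)
        global_step += 1
    return global_step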