Example No. 1
def validation(args, model_name, model_in, model_out, T, epoch, boundaries, random_seed_list, fig=None, ax=None):
    # evaluation mode
    model_in.eval()
    model_out.eval()
    T.eval()

    # dataset loader
    valid_loader = get_dataloader(model_in, args, model_in.dataset.validset, batch_size=args.validation_batch_size, shuffle=False)

    # number of samples in the validset
    validset_len = len(model_in.dataset.validset)

    # convert a tensor back to a molecule SMILES string
    def input_tensor2string(input_tensor):
        return model_in.tensor2string(input_tensor)

    # union of both train sets (used for the novelty metric)
    trainset = set(model_in.dataset.trainset).union(model_out.dataset.trainset)

    # generate an output molecule from an input molecule
    def local_input2output(input_batch):
        return input2output(args, input_batch, model_in, T, model_out, random_seed_list,
                            max_out_len=args.validation_max_len)

    # use general validation function
    avg_similarity, avg_property, avg_SR, avg_validity, avg_novelty, avg_diversity =\
        general_validation(args, local_input2output, input_tensor2string, boundaries, valid_loader, validset_len, model_name, trainset, epoch,
        fig=fig, ax=ax)

    # back to train mode
    model_in.train()
    model_out.train()
    T.train()

    return avg_similarity, avg_property, avg_SR, avg_validity, avg_novelty, avg_diversity
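
A minimal usage sketch (not from the source) showing how the validation routine above would typically be driven once per epoch from a training loop. model_A, model_B, T, boundaries, fig, and ax are assumed to be prepared by the setup code; get_random_list is the seed helper used in the later examples.

random_seed_list = get_random_list(args.num_retries)
for epoch in range(args.epoch_init, args.epochs + 1):
    # ... one training epoch over the A/B train loaders ...
    avg_similarity, avg_property, avg_SR, avg_validity, avg_novelty, avg_diversity = \
        validation(args, 'UGMMT', model_A, model_B, T, epoch,
                   boundaries, random_seed_list, fig=fig, ax=ax)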
Example No. 2
def check_testset(args, model_in, model_out, input2output_func):
    testset_path = 'dataset/' + args.property + '/' + args.testset_filename
    results_file_path = args.plots_folder + '/' + args.property + '/UGMMT_' + args.property + '_test.txt'
    valid_results_file_path = args.plots_folder + '/' + args.property + '/valid_UGMMT_' + args.property + '_test.txt'
    print(' ')
    print('Loading testset from file   => ' + testset_path)

    # train set for novelty
    trainset = set(model_in.dataset.trainset).union(model_out.dataset.trainset)

    # build testset from filename
    testset_list = filname2testset(testset_path, model_in, model_out)
    test_loader = get_dataloader(model_in,
                                 args,
                                 testset_list,
                                 batch_size=args.batch_size,
                                 shuffle=False)

    # generate random seeds
    random_seed_list = get_random_list(args.num_retries)

    def input2output_func_aux(input_batch):
        return input2output_func(input_batch, random_seed_list)

    def input2smiles(input):
        return model_in.tensor2string(input)  # to smiles

    # generate results file
    generate_results_file(test_loader, input2output_func_aux, input2smiles,
                          results_file_path)

    # result file -> valid results + property and similarity for output molecules
    process_results_file(results_file_path, args, valid_results_file_path,
                         trainset)

    # calculate metrics
    validity_mean, validity_std, \
    diversity_mean, diversity_std, \
    novelty_mean, novelty_std, \
    property_mean, property_std, \
    similarity_mean, similarity_std, \
    SR_mean, SR_std = \
                        valid_results_file_to_metrics(valid_results_file_path, args, len(testset_list))

    # print results
    print(' ')
    print('Property => ' + args.property)
    print('property => mean: ' + str(round(property_mean, 3)) + '   std: ' +
          str(round(property_std, 3)))
    print('fingerprint Similarity => mean: ' + str(round(similarity_mean, 3)) +
          '   std: ' + str(round(similarity_std, 3)))
    print('SR => mean: ' + str(round(SR_mean, 3)) + '   std: ' +
          str(round(SR_std, 3)))
    print('validity => mean: ' + str(round(validity_mean, 3)) + '   std: ' +
          str(round(validity_std, 3)))
    print('novelty => mean: ' + str(round(novelty_mean, 3)) + '   std: ' +
          str(round(novelty_std, 3)))
    print('diversity => mean: ' + str(round(diversity_mean, 3)) + '   std: ' +
          str(round(diversity_std, 3)))
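
The metric computation itself is delegated to valid_results_file_to_metrics, which is not shown here. A hypothetical sketch of the per-metric aggregation it would perform, mirroring the mean/std pairs printed above (the helper name and the choice of population std are assumptions):

import statistics

def aggregate_metric(per_seed_values):
    # hypothetical helper: mean and population std over one metric's per-seed values
    return statistics.fmean(per_seed_values), statistics.pstdev(per_seed_values)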
Example No. 3
def intermediate_analysis(args, model_in, input2all_func):
    assert args.use_EETN

    # get data
    data_path = 'dataset/' + args.property + '/A_train.txt'
    print(' ')
    print('Intermediate analysis: Loading data from file   => ' + data_path)

    # build dataset from filename
    df = pd.read_csv(data_path, header=None)
    dataset_list = list(df.iloc[:, 0])

    # get random molecule
    random_seed_list = get_random_list(args.num_samples,
                                       last=len(dataset_list) - 1)

    for seed in random_seed_list:
        rand_mol = dataset_list.pop(seed)

        # get intermediate for the *** RAND *** molecule
        rand_mol_loader = get_dataloader(model_in,
                                         args, [rand_mol],
                                         batch_size=1,
                                         shuffle=False)
        for i, input_batch in enumerate(rand_mol_loader):
            rand_enc, rand_trans, [rand_out_mol] = input2all_func(input_batch)

        # if the source output molecule is invalid, skip to the next molecule
        if not is_valid_molecule(rand_out_mol, args.property):
            continue

        sim_rand_all_sort_list = [(mol, similarity_calc(rand_mol, mol))
                                  for mol in dataset_list]
        sim_rand_all_sort_list.sort(key=lambda x: x[1], reverse=True)
        neighbours_near_list = [
            mol for (mol, _) in sim_rand_all_sort_list[:args.num_neighbours]
        ]
        neighbours_far_list = [
            mol for (mol, _) in sim_rand_all_sort_list[args.num_neighbours:2 *
                                                       args.num_neighbours]
        ]
        # neighbours_far_list = [mol for (mol, _) in sim_rand_all_sort_list[-args.num_neighbours:]]

        # get intermediate for *** NEAR *** molecules
        near_loader = get_dataloader(model_in,
                                     args,
                                     neighbours_near_list,
                                     batch_size=args.num_neighbours,
                                     shuffle=False)
        for i, input_batch in enumerate(near_loader):
            near_enc, near_trans, near_out_mol = input2all_func(input_batch)

        # get intermediate for *** FAR *** molecules
        far_loader = get_dataloader(model_in,
                                    args,
                                    neighbours_far_list,
                                    batch_size=args.num_neighbours,
                                    shuffle=False)
        for i, input_batch in enumerate(far_loader):
            far_enc, far_trans, far_out_mol = input2all_func(input_batch)

        p = 3  # rounding precision for the printouts below
        print(' ')
        print('seed   => ' + str(seed))
        # analyse *** FAR ***
        far_input_sim_mean, far_input_sim_std, far_enc_dis_mean, far_enc_dis_std, far_trans_dis_mean, far_trans_dis_std, far_out_sim_mean, far_out_sim_std =\
            analyze_intermediate_outputs(rand_mol, rand_enc, rand_trans, rand_out_mol, neighbours_far_list, far_enc, far_trans, far_out_mol)
        print('*** FAR (NOT similar) ***')
        print('Input mol similarity => mean: ' +
              str(round(far_input_sim_mean, p)) + '   std: ' +
              str(round(far_input_sim_std, p)))
        print('A_enc emb distance => mean: ' +
              str(round(far_enc_dis_mean, p)) + '   std: ' +
              str(round(far_enc_dis_std, p)))
        print('Translator emb distance => mean: ' +
              str(round(far_trans_dis_mean, p)) + '   std: ' +
              str(round(far_trans_dis_std, p)))
        print('Output mol similarity => mean: ' +
              str(round(far_out_sim_mean, p)) + '   std: ' +
              str(round(far_out_sim_std, p)))

        # analyse *** NEAR ***
        near_input_sim_mean, near_input_sim_std, near_enc_dis_mean, near_enc_dis_std, near_trans_dis_mean, near_trans_dis_std, near_out_sim_mean, near_out_sim_std =\
            analyze_intermediate_outputs(rand_mol, rand_enc, rand_trans, rand_out_mol, neighbours_near_list, near_enc, near_trans, near_out_mol)
        print('*** NEAR (similar) ***')
        print('Input mol similarity => mean: ' +
              str(round(near_input_sim_mean, p)) + '   std: ' +
              str(round(near_input_sim_std, p)))
        print('A_enc emb distance => mean: ' +
              str(round(near_enc_dis_mean, p)) + '   std: ' +
              str(round(near_enc_dis_std, p)))
        print('Translator emb distance => mean: ' +
              str(round(near_trans_dis_mean, p)) + '   std: ' +
              str(round(near_trans_dis_std, p)))
        print('Output mol similarity => mean: ' +
              str(round(near_out_sim_mean, p)) + '   std: ' +
              str(round(near_out_sim_std, p)))
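
analyze_intermediate_outputs is defined elsewhere in the repository; a hypothetical sketch of what it computes, inferred from how its eight return values are printed above (the exact distance and similarity choices are assumptions):

import torch

def analyze_intermediate_outputs(rand_mol, rand_enc, rand_trans, rand_out_mol,
                                 neighbour_list, nb_enc, nb_trans, nb_out_mols):
    # compare the random molecule to each neighbour at four stages:
    # input SMILES similarity, encoder embedding distance,
    # translator embedding distance, and output SMILES similarity
    input_sims = torch.tensor([similarity_calc(rand_mol, mol) for mol in neighbour_list])
    enc_dis = (nb_enc - rand_enc).norm(dim=-1)      # L2 distance per neighbour
    trans_dis = (nb_trans - rand_trans).norm(dim=-1)
    out_sims = torch.tensor([similarity_calc(rand_out_mol, mol) for mol in nb_out_mols])
    return (input_sims.mean().item(), input_sims.std().item(),
            enc_dis.mean().item(), enc_dis.std().item(),
            trans_dis.mean().item(), trans_dis.std().item(),
            out_sims.mean().item(), out_sims.std().item())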
Example No. 4
def discover_approved_drugs(args, model_in, model_out, input2output_func):
    print(' ')
    print('Loading approved drugs dataset from file   => ' +
          args.FDA_approved_filename)
    # build testset from filename
    approved_drugs_list, smiles_name_dict = filname2testset(
        args.FDA_approved_filename, model_in, model_out, drugs=True)
    approved_drugs_list = [
        drug for drug in approved_drugs_list
        if len(drug) <= args.approved_drugs_max_len
    ]
    approved_drugs_set = set(approved_drugs_list)
    print('Final number of drugs (after length removal): ' +
          str(len(approved_drugs_set)))
    print(' ')

    # testset loader
    approved_drugs_loader = get_dataloader(model_in,
                                           args,
                                           approved_drugs_list,
                                           batch_size=args.batch_size,
                                           shuffle=False)

    all_drug_matches = dict()
    num_valid_molecule_matches = 0
    unique_drugs_found = set()
    unique_valid_molecules_found = set()
    for i, input_batch in enumerate(approved_drugs_loader):
        current_batch_size = len(input_batch)

        # generate random seeds
        random_seed_list = get_random_list(args.approved_drugs_retries)

        # generate output batch
        output_batch = input2output_func(input_batch, random_seed_list)

        # for every molecule
        for j, input in enumerate(input_batch):
            output_unique_molecule_smiles_set = set(
                output_batch[j::current_batch_size])

            # num valid molecules matches
            output_valid_unique_molecule_smiles_list = \
                [output_molecule_smiles for output_molecule_smiles in output_unique_molecule_smiles_set if is_valid_molecule(output_molecule_smiles, args.property)]
            num_unique_valid_molecules_per_input = len(
                output_valid_unique_molecule_smiles_list)
            num_valid_molecule_matches += num_unique_valid_molecules_per_input
            unique_valid_molecules_found = unique_valid_molecules_found.union(
                output_valid_unique_molecule_smiles_list)

            intersection_set = approved_drugs_set.intersection(
                output_unique_molecule_smiles_set)
            if intersection_set:  # intersection_set is not empty
                input_molecule_smiles = model_in.tensor2string(
                    input)  # to smiles
                all_drug_matches[input_molecule_smiles] = intersection_set
                unique_drugs_found = unique_drugs_found.union(intersection_set)

            drug_index = i * current_batch_size + j
            if drug_index % 30 == 0:
                print('Drug #' + str(drug_index + 1) +
                      ': Unique approved drugs found so far: ' +
                      str(len(unique_drugs_found)))

    # print detailed results
    similarity_list = []
    out_property_list = []
    property_improvement = []
    for input_molecule_smiles, output_molecule_smiles_set in all_drug_matches.items():
        print(' ')
        property_value_in = property_calc(input_molecule_smiles, args.property)
        print('Input: ' + str(input_molecule_smiles) + '. Property value: ' +
              str(round(property_value_in, 4)) + '  Drug name: ' +
              smiles_name_dict[input_molecule_smiles])
        for output_molecule_smiles in output_molecule_smiles_set:
            property_value_out = property_calc(output_molecule_smiles,
                                               args.property)
            similarity_in_out = similarity_calc(input_molecule_smiles,
                                                output_molecule_smiles)
            print('Output: ' + str(output_molecule_smiles) +
                  '. Property value: ' + str(round(property_value_out, 4)) +
                  '  Drug name: ' + smiles_name_dict[output_molecule_smiles])
            print('Similarity: ' + str(round(similarity_in_out, 4)))
            similarity_list.append(similarity_in_out)
            out_property_list.append(property_value_out)
            property_improvement.append(property_value_out - property_value_in)

    # print final results
    average_drug_matches_similarity = sum(similarity_list) / len(
        similarity_list) if len(similarity_list) > 0 else -1
    average_out_drugs_property_val = sum(out_property_list) / len(
        out_property_list) if len(out_property_list) > 0 else -1
    average_property_improvement = sum(property_improvement) / len(
        property_improvement) if len(property_improvement) > 0 else -1
    unique_drugs_per_unique_legal_molecule = 100 * len(
        unique_drugs_found) / len(unique_valid_molecules_found) if len(
            unique_valid_molecules_found) > 0 else -1
    drug_matches_per_valid_matches = 100 * len(
        out_property_list
    ) / num_valid_molecule_matches if num_valid_molecule_matches > 0 else -1

    print(' ')
    print('Property => ' + args.property)
    print('Translation direction   => ' + args.test_direction)

    print('Unique approved drugs found: ' + str(len(unique_drugs_found)))
    print('Unique legal molecules generated: ' +
          str(len(unique_valid_molecules_found)))
    print('Unique approved drugs per unique legal molecules generated: ' +
          str(round(unique_drugs_per_unique_legal_molecule, 5)) + '%')

    print('Drugs matches: ' + str(len(out_property_list)))
    print('All valid matches: ' + str(num_valid_molecule_matches))
    print('Unique drug-drug matches per drug-valid molecule matches: ' +
          str(round(drug_matches_per_valid_matches, 5)) + '%')

    print('Average generated drugs property value  => ' +
          str(round(average_out_drugs_property_val, 3)))
    print('Average drug matches fingerprint Similarity => ' +
          str(round(average_drug_matches_similarity, 3)))
    print('Average drug matches property value improvement  => ' +
          str(round(average_property_improvement, 3)))
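
is_valid_molecule, similarity_calc, and property_calc are imported helpers. Since the printouts describe the similarity as a fingerprint similarity, a plausible RDKit-based sketch of similarity_calc is shown below (Tanimoto over Morgan/ECFP4 bit vectors is an assumption; the repository may use different fingerprint settings):

from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs

def similarity_calc(smiles_a, smiles_b):
    # parse both SMILES strings; treat unparsable input as zero similarity
    mol_a, mol_b = Chem.MolFromSmiles(smiles_a), Chem.MolFromSmiles(smiles_b)
    if mol_a is None or mol_b is None:
        return 0.0
    # Morgan (ECFP4-like) bit-vector fingerprints, then Tanimoto similarity
    fp_a = AllChem.GetMorganFingerprintAsBitVect(mol_a, 2, nBits=2048)
    fp_b = AllChem.GetMorganFingerprintAsBitVect(mol_b, 2, nBits=2048)
    return DataStructs.TanimotoSimilarity(fp_a, fp_b)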
Example No. 5
    # optimizers for ablation
    if args.gan_loss:
        optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=args.init_lr, betas=(0.5, 0.999))
        optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=args.init_lr, betas=(0.5, 0.999))


    # scheduler
    lr_scheduler_T = torch.optim.lr_scheduler.LambdaLR(optimizer_T, lr_lambda=LambdaLR(args.epochs, args.epoch_init, args.epoch_decay).step)
    # schedulers for ablation
    if args.gan_loss:
        lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(args.epochs, args.epoch_init, args.epoch_decay).step)
        lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(args.epochs, args.epoch_init, args.epoch_decay).step)


    # train dataloaders
    A_train_loader = get_dataloader(model_A, args, model_A.dataset.trainset, args.batch_size, collate_fn=None, shuffle=True)
    B_train_loader = get_dataloader(model_B, args, model_B.dataset.trainset, args.batch_size, collate_fn=None, shuffle=True)

    # buffer for max_size last fake samples for ablation
    if args.gan_loss:
        fake_A_buffer = ReplayBuffer(max_size=50)
        fake_B_buffer = ReplayBuffer(max_size=50)
    else:
        fake_A_buffer, fake_B_buffer = None, None


    # for early stopping
    best_criterion = None
    runs_without_improvement = 0

    # generate random seeds
    random_seed_list = get_random_list(args.num_retries)
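
The schedulers above expect a LambdaLR helper object with a step method, and the GAN ablation expects a ReplayBuffer of recent fake samples. A plausible sketch of both, following the common CycleGAN-style implementations these call signatures suggest (assumed, not taken from this repository):

import random
import torch

class LambdaLR:
    # linear decay rule: keep the learning rate constant until epoch_decay,
    # then decay it linearly to zero by the final epoch
    def __init__(self, epochs, epoch_init, epoch_decay):
        self.epochs = epochs
        self.epoch_init = epoch_init
        self.epoch_decay = epoch_decay

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.epoch_init - self.epoch_decay) / (self.epochs - self.epoch_decay)

class ReplayBuffer:
    # keep up to max_size previously generated samples; half the time,
    # hand the discriminator an old sample instead of the newest one
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, batch):
        out = []
        for sample in batch:
            sample = sample.unsqueeze(0)
            if len(self.data) < self.max_size:
                self.data.append(sample)
                out.append(sample)
            elif random.random() > 0.5:
                idx = random.randrange(self.max_size)
                out.append(self.data[idx].clone())
                self.data[idx] = sample
            else:
                out.append(sample)
        return torch.cat(out)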