コード例 #1
0
def input2output_embedder(args,
                          input_batch,
                          model,
                          random_seed_list=None,
                          max_out_len=90,
                          recover_seed=True):
    """Encode ``input_batch`` and decode it once for every seed.

    Args:
        args: Parsed arguments. ``args.seed`` is the fallback seed when
            ``random_seed_list`` is None, and the seed that is restored when
            ``recover_seed`` is True.
        input_batch: Iterable of objects exposing ``.to(device)``; each one
            is moved to ``model.device`` before encoding.
        model: Object providing ``device``, ``forward_encoder`` and
            ``decoder_test``.
        random_seed_list: Seeds to decode with. ``None`` means "use only
            ``args.seed``".
        max_out_len: Forwarded to ``model.decoder_test`` as ``max_len``.
        recover_seed: When exactly ``True``, ``set_seed(args.seed)`` is
            called again before returning.

    Returns:
        A list holding the decoder outputs of every seed, concatenated in
        the order the seeds were given.
    """
    # prepare input
    input_batch = tuple(data.to(model.device) for data in input_batch)

    if random_seed_list is None:
        # args.seed is used as a single seed elsewhere (set_seed(args.seed)),
        # so wrap a scalar to keep the loop below iterable.
        fallback_seed = args.seed
        if isinstance(fallback_seed, (list, tuple)):
            random_seed_list = fallback_seed
        else:
            random_seed_list = [fallback_seed]

    output_batch = []
    for seed in random_seed_list:
        # set seed
        set_seed(seed)

        # embedder encode; kept inside the loop so it runs after the reseed
        # (presumably the encoder samples stochastically -- TODO confirm)
        input_batch_emb, _ = model.forward_encoder(input_batch)

        # embedder decode (decode test = input is <bos> and multi for next char + embedding)
        output_batch += model.decoder_test(max_len=max_out_len,
                                           embedding=input_batch_emb)

    if recover_seed is True:
        set_seed(args.seed)
    return output_batch
コード例 #2
0
def get_dataloader(model,
                   args,
                   data,
                   batch_size=batch_size,
                   collate_fn=None,
                   shuffle=True):
    """Build a ``DataLoader`` over ``data`` for ``model``.

    Args:
        model: Object whose ``dataset`` provides ``use_atom_tokenizer``,
            ``string2tensor`` and ``c2i``, and which exposes ``device``.
        args: Parsed arguments; ``args.seed`` is used for seeding.
        data: Dataset handed to ``DataLoader``.
        batch_size: Batch size handed to ``DataLoader``.
        collate_fn: Optional custom collate function. When None, a default
            is used that sorts each batch longest-first and converts every
            string to a tensor.
        shuffle: Forwarded to ``DataLoader``.

    Returns:
        The configured ``DataLoader`` (``num_workers=0``).
    """
    if collate_fn is None:

        def collate_fn(train_data):
            # Sort the batch in place, longest first; length is counted in
            # atom tokens when the dataset uses the atom tokenizer, otherwise
            # in characters.
            if model.dataset.use_atom_tokenizer:
                train_data.sort(
                    key=lambda string: len(atomwise_tokenizer(string)),
                    reverse=True)
            else:
                train_data.sort(key=len, reverse=True)
            tensors = [
                model.dataset.string2tensor(string,
                                            model.dataset.c2i,
                                            device=model.device)
                for string in train_data
            ]
            return tensors

    # Reseed right before the loader is constructed, so construction always
    # starts from args.seed.
    set_seed(args.seed)

    def _seed_worker(worker_id):
        # worker_init_fn must be a callable taking the worker id; it is only
        # invoked by DataLoader when num_workers > 0.
        set_seed(args.seed)

    return DataLoader(data,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      collate_fn=collate_fn,
                      num_workers=0,
                      worker_init_fn=_seed_worker)
コード例 #3
0
                                         model,
                                         random_seed_list,
                                         max_out_len=args.test_max_len,
                                         recover_seed=False)

        discover_approved_drugs(args, model, model, drug_input2output_func)


# Script entry point. Everything below runs inside torch.no_grad(), i.e. with
# gradient tracking disabled for the whole run.
if __name__ == "__main__":

    with torch.no_grad():
        # parse arguments
        args = parse_arguments()

        # set seed
        set_seed(args.seed)

        # set device (CPU/GPU): CUDA when available, otherwise CPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print('Using device:', device)

        # disable rdkit error messages
        rdkit_no_error_print()

        # CDN mode is handled entirely by handle_CDN; the script exits right
        # after it, so the checkpoint loading below is skipped in that mode.
        if args.is_CDN is True:
            handle_CDN(args, device)
            exit()

        # load checkpoint: yields two translators (T_AB, T_BA), two models
        # (model_A, model_B) and the best criterion recorded so far
        # (NOTE(review): meanings inferred from the names -- confirm against
        # load_checkpoint).
        T_AB, T_BA, model_A, model_B, best_criterion = load_checkpoint(
            args, device)