Пример #1
0
def train_model(batches, test_batches, device, args, params):
    """
    Fit an AutoencoderLSTMClassifier and report its AUC after every epoch.

    The classifier is built from ``params`` (keys read: 'emb_dim',
    'max_len', 'enc_dim', 'batch_size', 'train_ae', 'lr', 'wd', 'epochs')
    and from ``args['ae_file']``.  After each epoch the train and test AUC
    are printed and appended, one value per line, to
    ``args['train_auc_file']`` and ``args['test_auc_file']``.

    Returns a ``(model, best_auc, best_roc)`` tuple: the trained model, the
    highest test AUC seen, and the ROC returned by ``evaluate`` for that
    epoch (``None`` if no epoch's test AUC exceeded 0).
    """
    # Binary cross-entropy is the training objective.
    criterion = nn.BCELoss()
    classifier = AutoencoderLSTMClassifier(
        params['emb_dim'], device, params['max_len'], 21, params['enc_dim'],
        params['batch_size'], args['ae_file'], params['train_ae'])
    # Put the weights on the requested device before building the optimizer.
    classifier.to(device)
    adam = optim.Adam(classifier.parameters(),
                      lr=params['lr'],
                      weight_decay=params['wd'])

    epoch_losses = []
    top_auc, top_roc = 0, None
    for epoch_number in range(1, params['epochs'] + 1):
        print('epoch:', epoch_number)
        started = time.time()

        # One optimisation pass over the training batches.
        epoch_losses.append(
            train_epoch(batches, classifier, criterion, adam, device))

        # AUC on the training data, logged to its own file.
        train_auc = evaluate(classifier, batches, device)[0]
        print('train auc:', train_auc)
        with open(args['train_auc_file'], 'a+') as log:
            print(train_auc, file=log)

        # AUC on the held-out data; remember the best epoch so far.
        test_auc, roc = evaluate(classifier, test_batches, device)
        if test_auc > top_auc:
            top_auc, top_roc = test_auc, roc
        print('test auc:', test_auc)
        with open(args['test_auc_file'], 'a+') as log:
            print(test_auc, file=log)

        print('one epoch time:', time.time() - started)
    return classifier, top_auc, top_roc
Пример #2
0
def predict(args):
    """
    Score TCR-peptide pairs read from a CSV file and print one line per pair.

    Reads from ``args``: ``model_type`` ('ae' or 'lstm'), ``ae_file``,
    ``model_file``, ``test_data_file``, ``device`` and, when ``model_file``
    is 'auto', ``dataset``, ``sampling`` and ``protein``.  Any of the three
    file attributes set to 'auto' is replaced in place with a default path.

    Each CSV row must hold exactly two fields, ``tcr,pep``; blank rows are
    skipped.  For the 'ae' model, rows whose TCR has 28 or more characters
    are dropped.  Output is ``tcr<TAB>pep<TAB>prediction`` on stdout.

    Raises:
        ValueError: if ``args.model_type`` is neither 'ae' nor 'lstm', or a
            non-blank CSV row does not have exactly two fields.
    """
    model_type = args.model_type
    # Fail early and clearly instead of hitting a NameError after all the I/O.
    if model_type not in ('ae', 'lstm'):
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                model_type))

    max_len = 28
    batch_size = 50

    # Amino-acid to index dictionaries
    amino_acids = list('ARNDCEQGHILKMFPSTWYV')
    if model_type == 'lstm':
        amino_to_ix = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
    else:
        pep_atox = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        tcr_atox = {
            amino: index
            for index, amino in enumerate(amino_acids + ['X'])
        }

    # Resolve 'auto' placeholders to the default file locations
    if args.ae_file == 'auto':
        args.ae_file = 'TCR_Autoencoder/tcr_autoencoder.pt'
    if args.model_file == 'auto':
        # An empty p_key is kept on purpose: it yields a double underscore,
        # which is how the existing default file names are spelled.
        p_key = 'protein' if args.protein else ''
        args.model_file = 'models' + '/' + '_'.join(
            [model_type, args.dataset, args.sampling, p_key, 'model.pt'])
    if args.test_data_file == 'auto':
        args.test_data_file = 'pairs_example.csv'

    # Read test data
    tcrs = []
    peps = []
    signs = []
    with open(args.test_data_file, 'r', newline='') as csv_file:
        for row in csv.reader(csv_file):
            if not row:
                # csv.reader yields [] for blank lines; nothing to score.
                continue
            tcr, pep = row
            if model_type == 'ae' and len(tcr) >= max_len:
                continue
            tcrs.append(tcr)
            peps.append(pep)
            # Placeholder label; the true label is unknown at prediction time.
            signs.append(0.0)
    # Keep the raw strings for printing.
    # NOTE(review): presumably lstm.convert_data rewrites the lists in place.
    tcrs_copy = tcrs.copy()
    peps_copy = peps.copy()

    # Load model
    device = args.device
    if model_type == 'ae':
        model = AutoencoderLSTMClassifier(10, device, max_len, 21, 30,
                                          batch_size, args.ae_file, False)
    else:
        model = DoubleLSTMClassifier(10, 30, 0.1, device)
    checkpoint = torch.load(args.model_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Predict
    if model_type == 'ae':
        test_batches = ae.get_full_batches(tcrs, peps, signs, tcr_atox,
                                           pep_atox, batch_size, max_len)
        preds = ae.predict(model, test_batches, device)
    else:
        lstm.convert_data(tcrs, peps, amino_to_ix)
        test_batches = lstm.get_full_batches(tcrs, peps, signs, batch_size,
                                             amino_to_ix)
        preds = lstm.predict(model, test_batches, device)

    # Print predictions
    for tcr, pep, pred in zip(tcrs_copy, peps_copy, preds):
        print('\t'.join([tcr, pep, str(pred)]))
Пример #3
0
def protein_test(args):
    """
    Evaluate a trained model separately on the peptides of each frequent protein.

    Proteins are ranked by their number of rows in 'McPAS-TCR.csv' (rows with
    'NA' as protein or peptide are ignored) and the 50 most frequent are
    tested.  For every protein that has at least one matching test sample the
    test AUC is printed, or 'NA' when the evaluation raises ValueError.

    Reads from ``args``: ``protein`` (must be truthy), ``model_type`` ('ae' or
    'lstm'), ``dataset`` (only 'mcpas' has a data file), ``ae_file``,
    ``test_data_file``, ``model_file``, ``device`` and, for 'auto' paths,
    ``sampling``.  'auto' file attributes are replaced in place.

    Returns:
        A list of ``(protein, roc)`` tuples, one per protein that was
        evaluated successfully.

    Raises:
        ValueError: if ``args.model_type`` or ``args.dataset`` is unsupported.
    """
    # Local import so this block does not depend on the file's import header.
    from collections import Counter, defaultdict

    assert args.protein
    model_type = args.model_type
    # Fail before any file or model I/O rather than with a late NameError.
    if model_type not in ('ae', 'lstm'):
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                model_type))
    if args.dataset != 'mcpas':
        raise ValueError(
            "unsupported dataset {!r}; only 'mcpas' has a protein "
            'data file'.format(args.dataset))
    datafile = 'McPAS-TCR.csv'

    def auto_path(suffix):
        # Default location; an empty p_key intentionally leaves a double
        # underscore so the name matches the files already on disk.
        p_key = 'protein' if args.protein else ''
        return 'memory_and_protein' + '/' + '_'.join(
            [args.model_type, args.dataset, args.sampling, p_key, suffix])

    # Amino-acid to index dictionaries
    amino_acids = list('ARNDCEQGHILKMFPSTWYV')
    if model_type == 'lstm':
        amino_to_ix = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
    else:
        pep_atox = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        tcr_atox = {
            amino: index
            for index, amino in enumerate(amino_acids + ['X'])
        }

    if args.ae_file == 'auto':
        args.ae_file = 'TCR_Autoencoder/tcr_autoencoder.pt'
    if args.test_data_file == 'auto':
        args.test_data_file = auto_path('test.pickle')
    if args.model_file == 'auto':
        args.model_file = auto_path('model.pt')

    # Read test data
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    with open(args.test_data_file, 'rb') as handle:
        test = pickle.load(handle)

    device = args.device
    if model_type == 'ae':
        test_tcrs, test_peps, test_signs = ae_get_lists_from_pairs(test, 28)
        model = AutoencoderLSTMClassifier(10, device, 28, 21, 30, 50,
                                          args.ae_file, False)
        evaluate_full = ae.evaluate_full
    else:
        test_tcrs, test_peps, test_signs = lstm_get_lists_from_pairs(test)
        model = DoubleLSTMClassifier(10, 30, 0.1, device)
        evaluate_full = lstm.evaluate_full
    # map_location lets a checkpoint saved on another device load here.
    checkpoint = torch.load(args.model_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Count rows per protein and collect each protein's peptides
    protein_counts = Counter()
    protein_peps = defaultdict(list)
    with open(datafile, 'r', encoding='unicode_escape') as file:
        file.readline()  # skip the header row
        for line in csv.reader(file):
            pep, protein = line[11], line[9]
            if protein == 'NA' or pep == 'NA':
                continue
            protein_counts[protein] += 1
            protein_peps[protein].append(pep)

    # McPAS most frequent proteins (from the original author's notes):
    #   NP177                Influenza
    #   Matrix protein (M1)  Influenza
    #   pp65                 Cytomegalovirus (CMV)
    #   BMLF-1               Epstein Barr virus (EBV)
    #   PB1                  Influenza
    # most_common orders by count, ties in first-seen order (deterministic).
    proteins = [name for name, _ in protein_counts.most_common(50)]

    rocs = []
    for protein in proteins:
        peps_of_protein = protein_peps[protein]
        protein_shows = [
            i for i, test_pep in enumerate(test_peps)
            if test_pep in peps_of_protein
        ]
        if not protein_shows:
            # No test sample for this protein: nothing to batch or evaluate.
            continue
        test_tcrs_protein = [test_tcrs[i] for i in protein_shows]
        test_peps_protein = [test_peps[i] for i in protein_shows]
        test_signs_protein = [test_signs[i] for i in protein_shows]
        if model_type == 'ae':
            test_batches_protein = ae.get_full_batches(test_tcrs_protein,
                                                       test_peps_protein,
                                                       test_signs_protein,
                                                       tcr_atox, pep_atox, 50,
                                                       28)
        else:
            lstm.convert_data(test_tcrs_protein, test_peps_protein,
                              amino_to_ix)
            test_batches_protein = lstm.get_full_batches(
                test_tcrs_protein, test_peps_protein, test_signs_protein, 50,
                amino_to_ix)
        try:
            test_auc, roc = evaluate_full(model, test_batches_protein, device)
        except ValueError:
            # NOTE(review): presumably the AUC is undefined for this subset
            # (e.g. a single class present) - confirm against evaluate_full.
            print('NA')
            continue
        # Key the ROC by the protein being tested, not by the stale `pep`
        # variable left over from the CSV loop.
        rocs.append((protein, roc))
        print(str(test_auc))
    return rocs
Пример #4
0
def pep_test(args):
    """
    Evaluate a trained model separately on each frequent peptide and save the results.

    Peptides are ranked by their number of rows in 'nine_class_testdata.csv'
    (column 1, 'NA' ignored) and the 249 most frequent are tested.  For every
    peptide with at least one matching test sample the test AUC is printed,
    or 'NA' when the evaluation raises ValueError.  The collected
    ``[pep, roc, test_auc]`` entries are pickled to 'pep_test_result_ae2' or
    'pep_test_result_lstm2' depending on the model type; the file is written
    even when no peptide could be evaluated (as an empty list).

    Reads from ``args``: ``model_type`` ('ae' or 'lstm'), ``dataset`` (only
    'mcpas' has a data file), ``ae_file``, ``test_data_file``,
    ``model_file``, ``device`` and, for 'auto' paths, ``sampling`` and
    ``protein``.  'auto' file attributes are replaced in place.

    Raises:
        ValueError: if ``args.model_type`` or ``args.dataset`` is unsupported.
    """
    # Local import so this block does not depend on the file's import header.
    from collections import Counter

    model_type = args.model_type
    # Fail before any file or model I/O rather than with a late NameError.
    if model_type not in ('ae', 'lstm'):
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                model_type))
    if args.dataset != 'mcpas':
        raise ValueError(
            "unsupported dataset {!r}; only 'mcpas' has a peptide "
            'data file'.format(args.dataset))
    datafile = 'nine_class_testdata.csv'
    # Chosen up front so the results file exists even if nothing is evaluated.
    filename = ('pep_test_result_ae2'
                if model_type == 'ae' else 'pep_test_result_lstm2')

    def auto_path(suffix):
        # Default location; an empty p_key intentionally leaves a double
        # underscore so the name matches the files already on disk.
        p_key = 'protein' if args.protein else ''
        return 'final_results' + '/' + '_'.join(
            [args.model_type, args.dataset, args.sampling, p_key, suffix])

    # Amino-acid to index dictionaries
    amino_acids = list('ARNDCEQGHILKMFPSTWYV')
    if model_type == 'lstm':
        amino_to_ix = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
    else:
        pep_atox = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        tcr_atox = {
            amino: index
            for index, amino in enumerate(amino_acids + ['X'])
        }

    if args.ae_file == 'auto':
        args.ae_file = 'TCR_Autoencoder/tcr_ae_dim_100.pt'
    if args.test_data_file == 'auto':
        args.test_data_file = auto_path('test.pickle')
    if args.model_file == 'auto':
        args.model_file = auto_path('model.pt')

    # Read test data
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    with open(args.test_data_file, 'rb') as handle:
        test = pickle.load(handle)

    device = args.device
    if model_type == 'ae':
        test_tcrs, test_peps, test_signs = ae_get_lists_from_pairs(test, 28)
        model = AutoencoderLSTMClassifier(10, device, 28, 21, 100, 50,
                                          args.ae_file, False)
        evaluate_full = ae.evaluate_full
    else:
        test_tcrs, test_peps, test_signs = lstm_get_lists_from_pairs(test)
        model = DoubleLSTMClassifier(10, 500, 0.1, device)
        evaluate_full = lstm.evaluate_full
    # map_location lets a checkpoint saved on another device load here.
    checkpoint = torch.load(args.model_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Count rows per peptide
    pep_counts = Counter()
    with open(datafile, 'r', encoding='unicode_escape') as file:
        file.readline()  # skip the header row
        for line in csv.reader(file):
            pep = line[1]
            if pep == 'NA':
                continue
            pep_counts[pep] += 1

    # McPAS most frequent peps (from the original author's notes):
    #   LPRRSGAAGA  Influenza
    #   GILGFVFTL   Influenza
    #   GLCTLVAML   Epstein Barr virus (EBV)
    #   NLVPMVATV   Cytomegalovirus (CMV)
    #   SSYRRPVGI   Influenza
    # most_common orders by count, ties in first-seen order (deterministic).
    peps = [name for name, _ in pep_counts.most_common(249)]

    auc = []
    for pep in peps:
        pep_shows = [
            i for i, test_pep in enumerate(test_peps) if pep == test_pep
        ]
        if not pep_shows:
            # No test sample for this peptide: nothing to batch or evaluate.
            continue
        test_tcrs_pep = [test_tcrs[i] for i in pep_shows]
        test_peps_pep = [test_peps[i] for i in pep_shows]
        test_signs_pep = [test_signs[i] for i in pep_shows]
        if model_type == 'ae':
            test_batches_pep = ae.get_full_batches(test_tcrs_pep,
                                                   test_peps_pep,
                                                   test_signs_pep, tcr_atox,
                                                   pep_atox, 50, 28)
        else:
            lstm.convert_data(test_tcrs_pep, test_peps_pep, amino_to_ix)
            test_batches_pep = lstm.get_full_batches(test_tcrs_pep,
                                                     test_peps_pep,
                                                     test_signs_pep, 50,
                                                     amino_to_ix)
        try:
            test_auc, roc = evaluate_full(model, test_batches_pep, device)
        except ValueError:
            # NOTE(review): presumably the AUC is undefined for this subset
            # (e.g. a single class present) - confirm against evaluate_full.
            print('NA')
            continue
        auc.append([pep, roc, test_auc])
        print(str(test_auc))

    # Context manager guarantees the results file is flushed and closed.
    with open(filename, 'wb') as out_file:
        pickle.dump(auc, out_file)
Пример #5
0
def pep_test(args):
    """
    Evaluate a trained model on the whole pickled test set and save the result.

    The test AUC is printed and ``[[roc, test_auc]]`` is pickled to
    'tpp_test_result_ae2' or 'tpp_test_result_lstm2' depending on the model
    type.

    Reads from ``args``: ``model_type`` ('ae' or 'lstm'), ``ae_file``,
    ``test_data_file``, ``model_file``, ``device`` and, for 'auto' paths,
    ``dataset``, ``sampling`` and ``protein``.  'auto' file attributes are
    replaced in place.

    Raises:
        ValueError: if ``args.model_type`` is neither 'ae' nor 'lstm'.
    """
    model_type = args.model_type
    # Fail before any file or model I/O rather than with a late NameError.
    if model_type not in ('ae', 'lstm'):
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                model_type))

    def auto_path(suffix):
        # Default location; an empty p_key intentionally leaves a double
        # underscore so the name matches the files already on disk.
        p_key = 'protein' if args.protein else ''
        return 'final_results' + '/' + '_'.join(
            [args.model_type, args.dataset, args.sampling, p_key, suffix])

    # Amino-acid to index dictionaries
    amino_acids = list('ARNDCEQGHILKMFPSTWYV')
    if model_type == 'lstm':
        amino_to_ix = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
    else:
        pep_atox = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        tcr_atox = {
            amino: index
            for index, amino in enumerate(amino_acids + ['X'])
        }

    if args.ae_file == 'auto':
        args.ae_file = 'TCR_Autoencoder/tcr_ae_dim_100.pt'
    if args.test_data_file == 'auto':
        args.test_data_file = auto_path('test.pickle')
    if args.model_file == 'auto':
        args.model_file = auto_path('model.pt')

    # Read test data
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    with open(args.test_data_file, 'rb') as handle:
        test = pickle.load(handle)

    device = args.device
    if model_type == 'ae':
        test_tcrs, test_peps, test_signs = ae_get_lists_from_pairs(test, 28)
        model = AutoencoderLSTMClassifier(10, device, 28, 21, 100, 50,
                                          args.ae_file, False)
    else:
        test_tcrs, test_peps, test_signs = lstm_get_lists_from_pairs(test)
        model = DoubleLSTMClassifier(10, 500, 0.1, device)
    # map_location lets a checkpoint saved on another device load here.
    checkpoint = torch.load(args.model_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Evaluate on the full test set
    if model_type == 'ae':
        test_batches_pep = ae.get_full_batches(test_tcrs, test_peps,
                                               test_signs, tcr_atox, pep_atox,
                                               50, 28)
        test_auc, roc = ae.evaluate_full(model, test_batches_pep, device)
        filename = 'tpp_test_result_ae2'
    else:
        lstm.convert_data(test_tcrs, test_peps, amino_to_ix)
        test_batches_pep = lstm.get_full_batches(test_tcrs, test_peps,
                                                 test_signs, 50, amino_to_ix)
        test_auc, roc = lstm.evaluate_full(model, test_batches_pep, device)
        filename = 'tpp_test_result_lstm2'

    auc = [[roc, test_auc]]
    print(str(test_auc))
    # Context manager guarantees the results file is flushed and closed.
    with open(filename, 'wb') as out_file:
        pickle.dump(auc, out_file)
Пример #6
0
def load_model_and_data(args):
    """
    Load the pickled train/test data and the trained model described by ``args``.

    Reads from ``args``: ``model_type`` ('ae' or 'lstm'), ``train_data_file``,
    ``test_data_file``, ``model_file``, ``device`` and, for 'auto' paths,
    ``dataset``, ``sampling`` and ``protein``.  'auto' file attributes are
    replaced in place with paths under 'save_results'.  The model's
    hyper-parameters are taken from the checkpoint's 'params' entry.

    For the 'ae' model, ``args.ae_file`` is always overwritten with the
    autoencoder file matching the checkpoint's 'enc_dim'.

    Returns:
        ``(model, [train_data, test_data])`` with the model in eval mode on
        ``args.device``.

    Raises:
        ValueError: if ``args.model_type`` is neither 'ae' nor 'lstm'.
    """
    model_type = args.model_type
    # Fail before any file I/O rather than with an UnboundLocalError at the
    # final return.
    if model_type not in ('ae', 'lstm'):
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                model_type))

    def auto_path(suffix):
        # Default location; an empty p_key intentionally leaves a double
        # underscore so the name matches the files already on disk.
        p_key = 'protein' if args.protein else ''
        return 'save_results' + '/' + '_'.join(
            [args.model_type, args.dataset, args.sampling, p_key, suffix])

    if args.train_data_file == 'auto':
        args.train_data_file = auto_path('train.pickle')
    if args.test_data_file == 'auto':
        args.test_data_file = auto_path('test.pickle')

    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    # Read train data
    with open(args.train_data_file, "rb") as file:
        train_data = pickle.load(file)
    # Read test data
    with open(args.test_data_file, "rb") as file:
        test_data = pickle.load(file)

    # Trained model
    if args.model_file == 'auto':
        args.model_file = auto_path('model.pt')

    # Load model; the checkpoint carries both the weights and the
    # hyper-parameters needed to rebuild the architecture.
    device = args.device
    checkpoint = torch.load(args.model_file, map_location=device)
    params = checkpoint['params']
    if model_type == 'ae':
        args.ae_file = 'TCR_Autoencoder/tcr_ae_dim_' + str(
            params['enc_dim']) + '.pt'
        model = AutoencoderLSTMClassifier(params['emb_dim'], device, 28, 21,
                                          params['enc_dim'],
                                          params['batch_size'], args.ae_file,
                                          False)
    else:
        model = DoubleLSTMClassifier(params['emb_dim'], params['lstm_dim'],
                                     params['dropout'], device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    data = [train_data, test_data]
    return model, data