Пример #1
0
def model_eval():
    """Print the mean validation F1 of ``contact_net`` with pure post-processing.

    Takes a single batch from the module-level ``val_generator``, runs
    ``contact_net`` in eval mode without gradients, applies ``postprocess``
    with fixed arguments, thresholds the result at 0.5 and prints the average
    ``F1_low_tri`` score against the ground-truth contacts.

    Relies on module-level globals: ``contact_net``, ``val_generator``,
    ``device``, ``seq_len``, ``get_pe``, ``postprocess`` and ``F1_low_tri``.
    """
    contact_net.eval()
    contacts, seq_embeddings, matrix_reps, seq_lens = next(iter(val_generator))
    seq_embedding_batch = torch.Tensor(seq_embeddings.float()).to(device)

    # padding the states with all 0s; only the batch size of matrix_reps is
    # needed, so it is not copied to the device.
    state_pad = torch.zeros([matrix_reps.shape[0], seq_len,
                             seq_len]).to(device)
    PE_batch = get_pe(seq_lens, seq_len).float().to(device)

    with torch.no_grad():
        pred_contacts = contact_net(PE_batch, seq_embedding_batch, state_pad)

    u_no_train = postprocess(pred_contacts, seq_embedding_batch, 0.01, 0.1, 50,
                             1.0, True)
    # The metric runs on the CPU: copy prediction and ground truth there once
    # per batch instead of once per sample.
    map_no_train = (u_no_train > 0.5).float().cpu()
    contacts_cpu = torch.Tensor(contacts.float())
    f1_no_train_tmp = [
        F1_low_tri(map_no_train[i], contacts_cpu[i])
        for i in range(contacts_cpu.shape[0])
    ]
    print('Average val F1 score with pure post-processing: ',
          np.average(f1_no_train_tmp))
Пример #2
0
def model_eval_all_test():
    """Print test-set metrics of ``contact_net`` with pure post-processing.

    For every batch of the module-level ``test_generator`` the network is run
    without gradients, its output is passed through ``postprocess`` with fixed
    arguments and thresholded at 0.5. Each sample is scored with
    ``evaluate_exact`` and ``evaluate_shifted``; both return a
    ``(precision, recall, f1)`` triple, as the unpacking below requires. The
    averages of all six numbers are printed.

    Relies on module-level globals: ``contact_net``, ``test_generator``,
    ``device``, ``seq_len``, ``get_pe``, ``postprocess``, ``evaluate_exact``
    and ``evaluate_shifted``.

    Raises:
        ValueError: if ``test_generator`` yields no samples.
    """
    contact_net.eval()
    result_no_train = []
    result_no_train_shift = []
    for batch_n, (contacts, seq_embeddings, matrix_reps,
                  seq_lens) in enumerate(test_generator):
        if batch_n % 10 == 0:
            print('Batch number: ', batch_n)
        seq_embedding_batch = torch.Tensor(seq_embeddings.float()).to(device)

        # All-zero initial state; only the batch size of matrix_reps is used.
        state_pad = torch.zeros([matrix_reps.shape[0], seq_len,
                                 seq_len]).to(device)

        PE_batch = get_pe(seq_lens, seq_len).float().to(device)
        with torch.no_grad():
            pred_contacts = contact_net(PE_batch, seq_embedding_batch,
                                        state_pad)

        # only post-processing without learning
        u_no_train = postprocess(pred_contacts, seq_embedding_batch, 0.01, 0.1,
                                 50, 1.0, True)
        # Metrics run on the CPU: copy each tensor there once per batch
        # instead of once per sample and metric.
        map_no_train = (u_no_train > 0.5).float().cpu()
        contacts_cpu = torch.Tensor(contacts.float())
        sample_ids = range(contacts_cpu.shape[0])
        result_no_train += [
            evaluate_exact(map_no_train[i], contacts_cpu[i])
            for i in sample_ids
        ]
        result_no_train_shift += [
            evaluate_shifted(map_no_train[i], contacts_cpu[i])
            for i in sample_ids
        ]

    if not result_no_train:
        raise ValueError('test_generator yielded no samples')

    nt_exact_p, nt_exact_r, nt_exact_f1 = zip(*result_no_train)
    nt_shift_p, nt_shift_r, nt_shift_f1 = zip(*result_no_train_shift)

    print('Average testing F1 score with pure post-processing: ',
          np.average(nt_exact_f1))

    print('Average testing F1 score with pure post-processing allow shift: ',
          np.average(nt_shift_f1))

    print('Average testing precision with pure post-processing: ',
          np.average(nt_exact_p))

    print('Average testing precision with pure post-processing allow shift: ',
          np.average(nt_shift_p))

    print('Average testing recall with pure post-processing: ',
          np.average(nt_exact_r))

    print('Average testing recall with pure post-processing allow shift: ',
          np.average(nt_shift_r))
Пример #3
0
def model_eval_all_test(test_generator, contact_net, lag_pp_net, device):
    """Evaluate ``contact_net`` with zero-parameter and learned post-processing.

    For every batch the contact network runs without gradients. Its output
    becomes a binary contact map in two ways:

    * "zero parameter pp": ``postprocess`` with fixed arguments, thresholded
      at 0.5;
    * "learning post-processing": the last element returned by
      ``lag_pp_net``, thresholded at 0.5.

    Each sample is scored with ``evaluate_exact`` and ``evaluate_shifted``.
    Both return a ``(precision, recall, f1)`` triple, as the unpacking below
    requires. All averages are printed.

    Args:
        test_generator: iterable of ``(contacts, seq_embeddings, matrix_reps,
            seq_lens)`` batches (``matrix_reps`` is not used).
        contact_net: network producing the raw contact scores.
        lag_pp_net: learned post-processing network; its last output is used.
        device: torch device the networks run on.

    Returns:
        dict: the learned post-processing's per-sample scores under
        ``exact_p``, ``exact_r``, ``exact_f1``, ``shift_p``, ``shift_r`` and
        ``shift_f1``. Also holds ``seq_lens`` and the length-weighted F1
        scores ``exact_weighted_f1`` and ``shift_weighted_f1``.

    Raises:
        ValueError: if ``test_generator`` yields no samples.
    """
    contact_net.eval()
    lag_pp_net.eval()
    result_no_train = []
    result_no_train_shift = []
    result_pp = []
    result_pp_shift = []
    seq_lens_list = []

    for batch_n, (contacts, seq_embeddings, _,
                  seq_lens) in enumerate(test_generator):
        if batch_n % 10 == 0:
            print('Batch number: ', batch_n)
        seq_embedding_batch = torch.Tensor(seq_embeddings.float()).to(device)

        # All-zero initial state shaped like the ground-truth contact maps.
        state_pad = torch.zeros(contacts.shape).to(device)

        PE_batch = get_pe(seq_lens, contacts.shape[-1]).float().to(device)
        with torch.no_grad():
            pred_contacts = contact_net(PE_batch, seq_embedding_batch,
                                        state_pad)
            a_pred_list = lag_pp_net(pred_contacts, seq_embedding_batch)

        # only post-processing without learning
        u_no_train = postprocess(pred_contacts, seq_embedding_batch, 0.01, 0.1,
                                 50, 1.0, True)

        # Metrics run on the CPU: copy each tensor there once per batch
        # instead of once per sample and metric.
        contacts_cpu = torch.Tensor(contacts.float())
        map_no_train = (u_no_train > 0.5).float().cpu()
        # the learning pp result
        final_pred = (a_pred_list[-1].cpu() > 0.5).float()
        sample_ids = range(contacts_cpu.shape[0])

        result_no_train += [
            evaluate_exact(map_no_train[i], contacts_cpu[i])
            for i in sample_ids
        ]
        result_no_train_shift += [
            evaluate_shifted(map_no_train[i], contacts_cpu[i])
            for i in sample_ids
        ]
        result_pp += [
            evaluate_exact(final_pred[i], contacts_cpu[i]) for i in sample_ids
        ]
        result_pp_shift += [
            evaluate_shifted(final_pred[i], contacts_cpu[i])
            for i in sample_ids
        ]
        seq_lens_list += list(seq_lens)

    if not result_pp:
        raise ValueError('test_generator yielded no samples')

    nt_exact_p, nt_exact_r, nt_exact_f1 = zip(*result_no_train)
    nt_shift_p, nt_shift_r, nt_shift_f1 = zip(*result_no_train_shift)

    pp_exact_p, pp_exact_r, pp_exact_f1 = zip(*result_pp)
    pp_shift_p, pp_shift_r, pp_shift_f1 = zip(*result_pp_shift)
    print('Average testing F1 score with learning post-processing: ',
          np.average(pp_exact_f1))
    print('Average testing F1 score with zero parameter pp: ',
          np.average(nt_exact_f1))

    print(
        'Average testing F1 score with learning post-processing allow shift: ',
        np.average(pp_shift_f1))
    print('Average testing F1 score with zero parameter pp allow shift: ',
          np.average(nt_shift_f1))

    print('Average testing precision with learning post-processing: ',
          np.average(pp_exact_p))
    print('Average testing precision with zero parameter pp: ',
          np.average(nt_exact_p))

    print(
        'Average testing precision with learning post-processing allow shift: ',
        np.average(pp_shift_p))
    print('Average testing precision with zero parameter pp allow shift: ',
          np.average(nt_shift_p))

    print('Average testing recall with learning post-processing: ',
          np.average(pp_exact_r))
    print('Average testing recall with zero parameter pp : ',
          np.average(nt_exact_r))

    print('Average testing recall with learning post-processing allow shift: ',
          np.average(pp_shift_r))
    print('Average testing recall with zero parameter pp allow shift: ',
          np.average(nt_shift_r))

    # Per-sample F1 weighted by sequence length; assumes one seq_lens entry
    # per scored sample, as the element-wise product requires.
    seq_lens_arr = np.array(seq_lens_list)
    total_len = np.sum(seq_lens_list)
    result_dict = {
        'exact_p': pp_exact_p,
        'exact_r': pp_exact_r,
        'exact_f1': pp_exact_f1,
        'shift_p': pp_shift_p,
        'shift_r': pp_shift_r,
        'shift_f1': pp_shift_f1,
        'seq_lens': seq_lens_list,
        'exact_weighted_f1': np.sum(
            np.array(pp_exact_f1) * seq_lens_arr / total_len),
        'shift_weighted_f1': np.sum(
            np.array(pp_shift_f1) * seq_lens_arr / total_len),
    }
    return result_dict
Пример #4
0
def model_eval_all_test(
        chunks_per_seq=15,
        output_path='../results/rnastralign_long_pure_pp_evaluation_dict.pickle'):
    """Evaluate long sequences with pure post-processing and pickle the results.

    Iterates over the module-level ``test_generator_1800``. For each batch the
    first element along dim 0 of the embedding, positional-encoding and
    contact tensors is used. ``contact_net`` runs without gradients. Its
    output is passed through ``postprocess`` with fixed arguments and
    thresholded at 0.5. Every row is scored with ``evaluate_exact`` and
    ``evaluate_shifted``, each returning a ``(precision, recall, f1)`` triple.
    NaN scores are replaced by 0 before averaging.

    Args:
        chunks_per_seq: number of consecutive per-row F1 scores averaged into
            one per-sequence score for the length-weighted F1. The default 15
            is the value hard-coded in the original code. It assumes every
            sequence contributes exactly that many rows — TODO confirm
            against ``test_generator_1800``.
        output_path: where the result dict is pickled.

    Returns:
        dict: the same dict that is written to ``output_path``.

    Raises:
        ValueError: if ``test_generator_1800`` yields no samples.

    Relies on module-level globals: ``contact_net``, ``test_generator_1800``,
    ``device``, ``postprocess``, ``evaluate_exact`` and ``evaluate_shifted``.
    """
    import pickle  # uses the C accelerator (_pickle) automatically

    contact_net.eval()
    result_no_train = []
    result_no_train_shift = []
    seq_lens_list = []

    for batch_n, (seq_embedding_batch, PE_batch, contacts_batch, _, _, _,
                  seq_lens) in enumerate(test_generator_1800):
        if batch_n % 10 == 0:
            print('Batch number: ', batch_n)
        seq_embedding_batch = seq_embedding_batch[0].to(device)
        PE_batch = PE_batch[0].to(device)
        contacts_cpu = contacts_batch[0].cpu()
        # All-zero placeholder state; the (1, 2, 2) shape comes from the
        # original code — presumably unused by the network here, TODO confirm.
        state_pad = torch.zeros(1, 2, 2).to(device)

        with torch.no_grad():
            pred_contacts = contact_net(PE_batch, seq_embedding_batch,
                                        state_pad)

        # only post-processing without learning
        u_no_train = postprocess(pred_contacts, seq_embedding_batch, 0.01, 0.1,
                                 50, 1.0, True)
        # Copy to the CPU once per batch instead of once per row and metric.
        map_no_train = (u_no_train > 0.5).float().cpu()
        row_ids = range(contacts_cpu.shape[0])
        result_no_train += [
            evaluate_exact(map_no_train[i], contacts_cpu[i]) for i in row_ids
        ]
        result_no_train_shift += [
            evaluate_shifted(map_no_train[i], contacts_cpu[i])
            for i in row_ids
        ]
        seq_lens_list += list(seq_lens)

    if not result_no_train:
        raise ValueError('test_generator_1800 yielded no samples')

    nt_exact_p, nt_exact_r, nt_exact_f1 = zip(*result_no_train)
    nt_shift_p, nt_shift_r, nt_shift_f1 = zip(*result_no_train_shift)

    nt_exact_p = np.nan_to_num(np.array(nt_exact_p))
    nt_exact_r = np.nan_to_num(np.array(nt_exact_r))
    nt_exact_f1 = np.nan_to_num(np.array(nt_exact_f1))

    nt_shift_p = np.nan_to_num(np.array(nt_shift_p))
    nt_shift_r = np.nan_to_num(np.array(nt_shift_r))
    nt_shift_f1 = np.nan_to_num(np.array(nt_shift_f1))

    print('Average testing F1 score with pure post-processing: ',
          np.average(nt_exact_f1))
    print('Average testing F1 score with pure post-processing allow shift: ',
          np.average(nt_shift_f1))
    print('Average testing precision with pure post-processing: ',
          np.average(nt_exact_p))
    print('Average testing precision with pure post-processing allow shift: ',
          np.average(nt_shift_p))
    print('Average testing recall with pure post-processing: ',
          np.average(nt_exact_r))
    print('Average testing recall with pure post-processing allow shift: ',
          np.average(nt_shift_r))

    # Collapse each sequence's consecutive row scores into one score.
    seq_slices = [
        slice(i * chunks_per_seq, (i + 1) * chunks_per_seq)
        for i in range(len(seq_lens_list))
    ]
    nt_exact_f1_agg = [np.average(nt_exact_f1[s]) for s in seq_slices]
    nt_shift_f1_agg = [np.average(nt_shift_f1[s]) for s in seq_slices]

    seq_lens_arr = np.array(seq_lens_list)
    total_len = np.sum(seq_lens_list)
    result_dict = {
        'exact_p': nt_exact_p,
        'exact_r': nt_exact_r,
        'exact_f1': nt_exact_f1,
        'shift_p': nt_shift_p,
        'shift_r': nt_shift_r,
        'shift_f1': nt_shift_f1,
        'seq_lens': seq_lens_list,
        'exact_weighted_f1': np.sum(
            np.array(nt_exact_f1_agg) * seq_lens_arr / total_len),
        'shift_weighted_f1': np.sum(
            np.array(nt_shift_f1_agg) * seq_lens_arr / total_len),
    }
    with open(output_path, 'wb') as f:
        pickle.dump(result_dict, f)
    return result_dict
Пример #5
0
        state_pad = torch.zeros(1, 2, 2).to(device)

        PE_batch = get_pe(seq_lens, 600).float().to(device)
        contact_masks = torch.Tensor(contact_map_masks(seq_lens,
                                                       600)).to(device)
        pred_contacts = contact_net(PE_batch, seq_embedding_batch, state_pad)

        # Compute loss
        loss_u = criterion_bce_weighted(pred_contacts * contact_masks,
                                        contacts_batch)

        # print(steps_done)
        if steps_done % OUT_STEP == 0:
            print('Stage 1, epoch for 600: {}, step: {}, loss: {}'.format(
                epoch, steps_done, loss_u))
            u_no_train = postprocess(pred_contacts, seq_embedding_batch, 0.01,
                                     0.1, 50, 1.0, True)
            map_no_train = (u_no_train > 0.5).float()
            f1_no_train_tmp = list(
                map(
                    lambda i: F1_low_tri(map_no_train.cpu()[i],
                                         contacts_batch.cpu()[i]),
                    range(contacts_batch.shape[0])))
            print('Average train F1 score for 600 with pure post-processing: ',
                  np.average(f1_no_train_tmp))

        # Optimize the model
        u_optimizer.zero_grad()
        loss_u.backward()
        u_optimizer.step()
        steps_done = steps_done + 1
        if steps_done % 600 == 0:
def model_eval_all_test(report_from=1780):
    """Evaluate pure post-processing on short and long test sequences.

    Two generators are consumed in order, sharing one batch counter:

    * ``test_generator`` yields ``(contacts, seq_embeddings, matrix_reps,
      seq_lens)`` batches, with positional encodings built for length 600;
    * ``test_generator_1800`` yields pre-built tensors. The first element
      along dim 0 of each is used.

    For every batch ``contact_net`` runs without gradients. The output is
    passed through ``postprocess`` with fixed arguments and thresholded at
    0.5. Each row is scored with ``evaluate_exact`` and ``evaluate_shifted``,
    each returning a ``(precision, recall, f1)`` triple. NaN scores are set to
    0, and averages over results ``report_from`` onward are printed.

    Args:
        report_from: index of the first per-row result included in the
            printed averages. The default 1780 is the value hard-coded in the
            original code. Presumably it skips the ``test_generator`` results
            so that only the long sequences are reported — TODO confirm.

    Raises:
        ValueError: if neither generator yields any samples.

    Relies on module-level globals: ``contact_net``, ``test_generator``,
    ``test_generator_1800``, ``device``, ``get_pe``, ``postprocess``,
    ``evaluate_exact`` and ``evaluate_shifted``.
    """
    contact_net.eval()
    result_no_train = []
    result_no_train_shift = []

    def _score_batch(pred_contacts, seq_embedding_batch, contacts_batch):
        """Post-process one batch and append its exact and shifted scores."""
        # only post-processing without learning
        u_no_train = postprocess(pred_contacts, seq_embedding_batch, 0.01,
                                 0.1, 50, 1.0, True)
        # Copy to the CPU once per batch instead of once per row and metric.
        map_no_train = (u_no_train > 0.5).float().cpu()
        contacts_cpu = contacts_batch.cpu()
        row_ids = range(contacts_cpu.shape[0])
        result_no_train.extend(
            evaluate_exact(map_no_train[i], contacts_cpu[i]) for i in row_ids)
        result_no_train_shift.extend(
            evaluate_shifted(map_no_train[i], contacts_cpu[i])
            for i in row_ids)

    # All-zero placeholder state; the (1, 2, 2) shape comes from the original
    # code — presumably unused by the network here, TODO confirm.
    state_pad = torch.zeros(1, 2, 2).to(device)

    batch_n = 0
    for contacts, seq_embeddings, _, seq_lens in test_generator:
        if batch_n % 10 == 0:
            print('Batch number: ', batch_n)
        batch_n += 1
        seq_embedding_batch = torch.Tensor(seq_embeddings.float()).to(device)
        PE_batch = get_pe(seq_lens, 600).float().to(device)
        with torch.no_grad():
            pred_contacts = contact_net(PE_batch, seq_embedding_batch,
                                        state_pad)
        _score_batch(pred_contacts, seq_embedding_batch,
                     torch.Tensor(contacts.float()))

    for (seq_embedding_batch, PE_batch, contacts_batch, _, _, _,
         _) in test_generator_1800:
        if batch_n % 10 == 0:
            print('Batch number: ', batch_n)
        batch_n += 1
        seq_embedding_batch = seq_embedding_batch[0].to(device)
        PE_batch = PE_batch[0].to(device)
        with torch.no_grad():
            pred_contacts = contact_net(PE_batch, seq_embedding_batch,
                                        state_pad)
        _score_batch(pred_contacts, seq_embedding_batch, contacts_batch[0])

    if not result_no_train:
        raise ValueError('test generators yielded no samples')

    nt_exact_p, nt_exact_r, nt_exact_f1 = zip(*result_no_train)
    nt_shift_p, nt_shift_r, nt_shift_f1 = zip(*result_no_train_shift)

    nt_exact_p = np.nan_to_num(np.array(nt_exact_p))
    nt_exact_r = np.nan_to_num(np.array(nt_exact_r))
    nt_exact_f1 = np.nan_to_num(np.array(nt_exact_f1))

    nt_shift_p = np.nan_to_num(np.array(nt_shift_p))
    nt_shift_r = np.nan_to_num(np.array(nt_shift_r))
    nt_shift_f1 = np.nan_to_num(np.array(nt_shift_f1))

    print('Average testing F1 score with pure post-processing: ',
          np.average(nt_exact_f1[report_from:]))
    print('Average testing F1 score with pure post-processing allow shift: ',
          np.average(nt_shift_f1[report_from:]))
    print('Average testing precision with pure post-processing: ',
          np.average(nt_exact_p[report_from:]))
    print('Average testing precision with pure post-processing allow shift: ',
          np.average(nt_shift_p[report_from:]))
    print('Average testing recall with pure post-processing: ',
          np.average(nt_exact_r[report_from:]))
    print('Average testing recall with pure post-processing allow shift: ',
          np.average(nt_shift_r[report_from:]))
Пример #7
0
def model_eval_all_test():
    """Evaluate long sequences with zero-parameter and learned post-processing.

    Iterates over the module-level ``test_generator_1800``; the first element
    along dim 0 of the chunked embedding and positional-encoding tensors is
    used. ``contact_net``'s chunk predictions are merged by
    ``combine_chunk_u_maps_no_replace(pred, comb_index, 6)``, which presumably
    produces a full-length map (TODO confirm). The merged map gets a batch
    dimension and is turned into a binary contact map in two ways:

    * "zero parameter pp": ``postprocess`` with fixed arguments, thresholded
      at 0.5;
    * "learning post-processing": the last element returned by
      ``lag_pp_net``, thresholded at 0.5.

    The ground truth is cropped to its first 1800 positions in both spatial
    dimensions. Each row is scored with ``evaluate_exact`` and
    ``evaluate_shifted``, each returning a ``(precision, recall, f1)`` triple,
    and the averages are printed.

    Raises:
        ValueError: if ``test_generator_1800`` yields no samples.

    Relies on module-level globals: ``contact_net``, ``lag_pp_net``,
    ``test_generator_1800``, ``device``, ``combine_chunk_u_maps_no_replace``,
    ``postprocess``, ``evaluate_exact`` and ``evaluate_shifted``.
    """
    contact_net.eval()
    lag_pp_net.eval()
    result_no_train = []
    result_no_train_shift = []
    result_pp = []
    result_pp_shift = []

    # All-zero placeholder state; the (1, 2, 2) shape comes from the original
    # code — presumably unused by the network here, TODO confirm.
    state_pad = torch.zeros(1, 2, 2).to(device)

    # for long sequences
    for batch_n, (seq_embedding_batch, PE_batch, _, comb_index,
                  seq_embeddings, contacts,
                  _) in enumerate(test_generator_1800):
        if batch_n % 10 == 0:
            print('Batch number: ', batch_n)

        seq_embedding_batch = seq_embedding_batch[0].to(device)
        PE_batch = PE_batch[0].to(device)
        seq_embedding = torch.Tensor(seq_embeddings.float()).to(device)

        with torch.no_grad():
            pred_contacts = contact_net(PE_batch, seq_embedding_batch,
                                        state_pad)
            pred_u_map = combine_chunk_u_maps_no_replace(
                pred_contacts, comb_index, 6)
            pred_u_map = pred_u_map.unsqueeze(0)
            a_pred_list = lag_pp_net(pred_u_map, seq_embedding)

        # ground truth, cropped to the first 1800 positions (stays on the CPU)
        contacts_cpu = torch.Tensor(contacts.float()[:, :1800, :1800])
        # only post-processing, with zero parameters
        u_no_train = postprocess(pred_u_map, seq_embedding, 0.01, 0.1, 50,
                                 1.0, True)
        # Copy predictions to the CPU once per batch, not per row and metric.
        map_no_train = (u_no_train > 0.5).float().cpu()
        # the learning pp result
        final_pred = (a_pred_list[-1].cpu() > 0.5).float()
        row_ids = range(contacts_cpu.shape[0])

        result_no_train += [
            evaluate_exact(map_no_train[i], contacts_cpu[i]) for i in row_ids
        ]
        result_no_train_shift += [
            evaluate_shifted(map_no_train[i], contacts_cpu[i])
            for i in row_ids
        ]
        result_pp += [
            evaluate_exact(final_pred[i], contacts_cpu[i]) for i in row_ids
        ]
        result_pp_shift += [
            evaluate_shifted(final_pred[i], contacts_cpu[i]) for i in row_ids
        ]

    if not result_pp:
        raise ValueError('test_generator_1800 yielded no samples')

    nt_exact_p, nt_exact_r, nt_exact_f1 = zip(*result_no_train)
    nt_shift_p, nt_shift_r, nt_shift_f1 = zip(*result_no_train_shift)

    pp_exact_p, pp_exact_r, pp_exact_f1 = zip(*result_pp)
    pp_shift_p, pp_shift_r, pp_shift_f1 = zip(*result_pp_shift)
    print('Average testing F1 score with learning post-processing: ', np.average(pp_exact_f1))
    print('Average testing F1 score with zero parameter pp: ', np.average(nt_exact_f1))

    print('Average testing F1 score with learning post-processing allow shift: ', np.average(pp_shift_f1))
    print('Average testing F1 score with zero parameter pp allow shift: ', np.average(nt_shift_f1))

    print('Average testing precision with learning post-processing: ', np.average(pp_exact_p))
    print('Average testing precision with zero parameter pp: ', np.average(nt_exact_p))

    print('Average testing precision with learning post-processing allow shift: ', np.average(pp_shift_p))
    print('Average testing precision with zero parameter pp allow shift: ', np.average(nt_shift_p))

    print('Average testing recall with learning post-processing: ', np.average(pp_exact_r))
    print('Average testing recall with zero parameter pp : ', np.average(nt_exact_r))

    print('Average testing recall with learning post-processing allow shift: ', np.average(pp_shift_r))
    print('Average testing recall with zero parameter pp allow shift: ', np.average(nt_shift_r))