def embed_sequence(x,
                   lm_embed,
                   lstm_stack,
                   proj,
                   include_lm=True,
                   final_only=False,
                   pool='none',
                   use_cuda=False):

    if len(x) == 0:
        return None

    alphabet = Uniprot21()
    x = x.upper()
    # convert to alphabet index
    x = alphabet.encode(x)
    x = torch.from_numpy(x)
    if use_cuda:
        x = x.cuda()

    # embed the sequence
    with torch.no_grad():
        x = x.long().unsqueeze(0)
        z = embed_stack(x,
                        lm_embed,
                        lstm_stack,
                        proj,
                        include_lm=include_lm,
                        final_only=final_only)
        # pool if needed
        z = z.squeeze(0)
        if pool == 'sum':
            z = z.sum(0)
        elif pool == 'max':
            z, _ = z.max(0)
        elif pool == 'avg':
            z = z.mean(0)
        z = z.cpu().numpy()

    return z
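# Hedged usage sketch (not part of the original scripts): a thin wrapper showing how
# embed_sequence above might be used to produce one pooled vector per protein. The
# `lm_embed`, `lstm_stack`, and `proj` arguments are assumed to come from a previously
# trained embedding model; how they are loaded is not shown here.
def embed_records(records, lm_embed, lstm_stack, proj, use_cuda=False):
    """records: iterable of (name, bytes) pairs -> dict mapping name to a 1-D vector."""
    vectors = {}
    for name, seq in records:
        z = embed_sequence(seq, lm_embed, lstm_stack, proj,
                           include_lm=True, final_only=False,
                           pool='avg', use_cuda=use_cuda)
        if z is not None:  # embed_sequence returns None for empty sequences
            vectors[name] = z
    return vectors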
def load_data():
    alphabet = Uniprot21()

    path = 'data/transmembrane/TOPCONS2_datasets/TM.3line'
    x_tm, y_tm = load_3line(path, alphabet)

    path = 'data/transmembrane/TOPCONS2_datasets/SP+TM.3line'
    x_tm_sp, y_tm_sp = load_3line(path, alphabet)

    path = 'data/transmembrane/TOPCONS2_datasets/Globular.3line'
    x_glob, y_glob = load_3line(path, alphabet)

    path = 'data/transmembrane/TOPCONS2_datasets/Globular+SP.3line'
    x_glob_sp, y_glob_sp = load_3line(path, alphabet)

    datasets = {
        'TM': (x_tm, y_tm),
        'SP+TM': (x_tm_sp, y_tm_sp),
        'Globular': (x_glob, y_glob),
        'Globular+SP': (x_glob_sp, y_glob_sp)
    }

    return datasets
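# Hedged usage sketch (not part of the original scripts): load_data above returns a
# dict mapping each TOPCONS split name to an (x, y) pair of encoded sequences and
# per-residue labels; a minimal consumer might simply report the size of each split.
def summarize_topology_datasets(datasets):
    for name, (x, y) in datasets.items():
        n_res = sum(len(x_) for x_ in x)
        print('#', name, '-', len(x), 'sequences,', n_res, 'residues')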
def main():
    import argparse
    parser = argparse.ArgumentParser('Script for evaluating contact map models.')
    parser.add_argument('model', help='path to saved model')
    parser.add_argument('--dataset', default='2.06 test', help='which dataset: "2.06 test", "2.07 test", or "2.07 new test" (default: 2.06 test)')
    parser.add_argument('--batch-size', default=10, type=int, help='number of sequences to process in each batch (default: 10)')
    parser.add_argument('-o', '--output', help='output file path (default: stdout)')
    parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use (-1: CPU, -2: GPU if available, n>=0: GPU n; default: -2)')
    args = parser.parse_args()

    # load the data
    if args.dataset == '2.06 test':
        fasta_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.test.fa'
        contact_paths = glob.glob('data/SCOPe/pdbstyle-2.06/*/*.png')
    elif args.dataset == '2.07 test' or args.dataset == '2.07 new test':
        fasta_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.07-new.fa'
        contact_paths = glob.glob('data/SCOPe/pdbstyle-2.07/*/*.png')
    else:
        raise Exception('Bad dataset argument ' + args.dataset)
    
    alphabet = Uniprot21()
    x,y,names = load_data(fasta_path, contact_paths, alphabet)

    ## set the device
    d = args.device
    use_cuda = (d != -1) and torch.cuda.is_available()
    if d >= 0:
        torch.cuda.set_device(d)

    if use_cuda:
        x = [x_.cuda() for x_ in x]
        y = [y_.cuda() for y_ in y]

    model = torch.load(args.model)
    model.eval()
    if use_cuda:
        model.cuda()

    # predict contact maps
    batch_size = args.batch_size
    dataset = ContactMapDataset(x, y)
    iterator = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=collate_lists)
    logits = []
    with torch.no_grad():
        for xmb,ymb in iterator:
            lmb = predict_minibatch(model, xmb, use_cuda)
            logits += lmb

    # calculate performance metrics
    lengths = np.array([len(x_) for x_ in x])
    logits = [logit.cpu().numpy() for logit in logits]
    y = [y_.cpu().numpy() for y_ in y]

    output = args.output
    if output is None:
        output = sys.stdout
    else:
        output = open(output, 'w')
    line = '\t'.join(['Distance', 'Precision', 'Recall', 'F1', 'AUPR', 'Precision@L', 'Precision@L/2', 'Precision@L/5'])
    print(line, file=output)
    output.flush()

    # for all contacts
    y_flat = []
    logits_flat = []
    for i in range(len(y)):
        yi = y[i]
        mask = (yi < 0)
        y_flat.append(yi[~mask])
        logits_flat.append(logits[i][~mask])

    # calculate precision, recall, F1, and area under the precision recall curve for all contacts
    precision = np.zeros(len(x))
    recall = np.zeros(len(x))
    F1 = np.zeros(len(x))
    AUPR = np.zeros(len(x))
    prL = np.zeros(len(x))
    prL2 = np.zeros(len(x))
    prL5 = np.zeros(len(x))
    for i in range(len(x)):
        pr,re,f1,aupr = calc_metrics(logits_flat[i], y_flat[i])
        precision[i] = pr
        recall[i] = re
        F1[i] = f1
        AUPR[i] = aupr

        order = np.argsort(logits_flat[i])[::-1]
        n = lengths[i]
        topL = order[:n]
        prL[i] = y_flat[i][topL].mean()
        topL2 = order[:n//2]
        prL2[i] = y_flat[i][topL2].mean()
        topL5 = order[:n//5]
        prL5[i] = y_flat[i][topL5].mean()

    template = 'All\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}'
    line = template.format(precision.mean(), recall.mean(), F1.mean(), AUPR.mean(), prL.mean(), prL2.mean(), prL5.mean())
    print(line, file=output)
    output.flush()

    # for Medium/Long range contacts
    y_flat = []
    logits_flat = []
    for i in range(len(y)):
        yi = y[i]
        mask = (yi < 0)

        # np.tril_indices with k=11 marks the lower triangle plus every pair with
        # sequence separation <= 11; masking these out keeps only medium/long-range
        # contacts (upper triangle, |i - j| >= 12)
        medlong = np.tril_indices(len(yi), k=11)
        medlong_mask = np.zeros((len(yi), len(yi)), dtype=np.uint8)
        medlong_mask[medlong] = 1
        mask = mask | (medlong_mask == 1)

        y_flat.append(yi[~mask])
        logits_flat.append(logits[i][~mask])

    # calculate precision, recall, F1, and area under the precision recall curve for medium/long-range contacts
    precision = np.zeros(len(x))
    recall = np.zeros(len(x))
    F1 = np.zeros(len(x))
    AUPR = np.zeros(len(x))
    prL = np.zeros(len(x))
    prL2 = np.zeros(len(x))
    prL5 = np.zeros(len(x))
    for i in range(len(x)):
        pr,re,f1,aupr = calc_metrics(logits_flat[i], y_flat[i])
        precision[i] = pr
        recall[i] = re
        F1[i] = f1
        AUPR[i] = aupr

        order = np.argsort(logits_flat[i])[::-1]
        n = lengths[i]
        topL = order[:n]
        prL[i] = y_flat[i][topL].mean()
        topL2 = order[:n//2]
        prL2[i] = y_flat[i][topL2].mean()
        topL5 = order[:n//5]
        prL5[i] = y_flat[i][topL5].mean()

    template = 'Medium/Long\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}'
    line = template.format(np.nanmean(precision), np.nanmean(recall), np.nanmean(F1),
                           np.nanmean(AUPR), np.nanmean(prL), np.nanmean(prL2), np.nanmean(prL5))
    print(line, file=output)
    output.flush()
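# Hedged, self-contained sketch (not part of the original scripts) of the
# precision-at-L statistic computed above: rank the predicted contact scores, keep
# the top k pairs (k = L, L/2, or L/5), and report the fraction that are true
# contacts. Assumes numpy is imported as np at module scope. For example,
# precision_at_k(logits_flat[i], y_flat[i], lengths[i]) reproduces prL[i] above.
def precision_at_k(scores, labels, k):
    order = np.argsort(scores)[::-1]   # highest-scoring pairs first
    return labels[order[:k]].mean()    # fraction of true contacts among the top k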
def main():
    import argparse
    parser = argparse.ArgumentParser(
        'Script for evaluating similarity model on SCOP test set.')

    parser.add_argument(
        'model',
        help=
        'path to saved model file or "nw-align" for Needleman-Wunsch alignment score baseline'
    )

    parser.add_argument('--dev',
                        action='store_true',
                        help='use train/dev split')

    parser.add_argument(
        '--batch-size',
        default=64,
        type=int,
        help='number of sequence pairs to process in each batch (default: 64)')

    parser.add_argument('-d',
                        '--device',
                        type=int,
                        default=-2,
                        help='compute device to use (-1: CPU, -2: GPU if available, n>=0: GPU n; default: -2)')

    parser.add_argument('--coarse',
                        action='store_true',
                        help='use coarse comparison rather than full SSA')

    args = parser.parse_args()

    scop_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.sampledpairs.txt'

    eval_paths = [
        ('2.06-test',
         'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.test.sampledpairs.txt'
         ),
        ('2.07-new',
         'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.07-new.allpairs.txt'
         )
    ]
    if args.dev:
        scop_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.train.sampledpairs.txt'

        eval_paths = [(
            '2.06-dev',
            'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.dev.sampledpairs.txt'
        )]

    ## load the data
    alphabet = Uniprot21()
    x0_train, x1_train, y_train = load_pairs(scop_train_path, alphabet)

    ## load the model
    if args.model == 'nw-align':
        model = NWAlign(alphabet)
    elif args.model in ['hhalign', 'phmmer', 'TMalign']:
        model = args.model
    else:
        model = torch.load(args.model)
        model.eval()

        ## set the device
        d = args.device
        use_cuda = (d != -1) and torch.cuda.is_available()
        if d >= 0:
            torch.cuda.set_device(d)

        if use_cuda:
            model.cuda()

        mode = 'align'
        if args.coarse:
            mode = 'coarse'
        model = TorchModel(model, use_cuda, mode=mode)

    batch_size = args.batch_size

    ## for calculating the classification accuracy, first find the best partitions using the training set
    if type(model) is str:
        path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.sampledpairs.' \
                + model + '.npy'
        scores = np.load(path)
        scores = scores.mean(1)
    else:
        scores = score_pairs(model, x0_train, x1_train, batch_size)
    thresholds = find_best_thresholds(scores, y_train)

    print(
        'Dataset\tAccuracy\tPearson\'s r\tSpearman\'s rho\tClass\tFold\tSuperfamily\tFamily'
    )

    accuracy, r, rho, aupr = calculate_metrics(scores, y_train, thresholds)

    template = '{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}'

    line = '2.06-train\t' + template.format(accuracy, r, rho, aupr[0], aupr[1],
                                            aupr[2], aupr[3])
    #line = '\t'.join(['2.06-train', str(accuracy), str(r), str(rho), str(aupr[0]), str(aupr[1]), str(aupr[2]), str(aupr[3])])
    print(line)

    for dset, path in eval_paths:
        x0_test, x1_test, y_test = load_pairs(path, alphabet)
        if type(model) is str:
            path = os.path.splitext(path)[0]
            path = path + '.' + model + '.npy'
            scores = np.load(path)
            scores = scores.mean(1)
        else:
            scores = score_pairs(model, x0_test, x1_test, batch_size)
        accuracy, r, rho, aupr = calculate_metrics(scores, y_test, thresholds)

        line = dset + '\t' + template.format(accuracy, r, rho, aupr[0],
                                             aupr[1], aupr[2], aupr[3])
        #line = '\t'.join([dset, str(accuracy), str(r), str(rho), str(aupr[0]), str(aupr[1]), str(aupr[2]), str(aupr[3])])
        print(line)
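# Hedged sketch (not part of the original scripts): the accuracy reported by
# calculate_metrics above depends on mapping a continuous similarity score to a
# discrete SCOP similarity level using the thresholds fit on the training pairs.
# A minimal version of that mapping, assuming `thresholds` is sorted in ascending
# order (an assumption, not the repository's find_best_thresholds) and that numpy
# is imported as np:
def scores_to_levels(scores, thresholds):
    # np.digitize counts, for each score, how many thresholds it meets or exceeds,
    # giving a predicted level in {0, ..., len(thresholds)}
    return np.digitize(scores, np.asarray(thresholds))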
def main():
    import argparse
    parser = argparse.ArgumentParser(
        'Script for evaluating contact map models.')
    parser.add_argument('model', help='path to saved model')
    parser.add_argument(
        '--batch-size',
        default=10,
        type=int,
        help='number of sequences to process in each batch (default: 10)')
    parser.add_argument('-o',
                        '--output',
                        help='output file path (default: stdout)')
    parser.add_argument('-d',
                        '--device',
                        type=int,
                        default=-2,
                        help='compute device to use (-1: CPU, -2: GPU if available, n>=0: GPU n; default: -2)')
    parser.add_argument('--individual', action='store_true', help='report metrics for each protein individually')
    args = parser.parse_args()

    # load the data
    fasta_path = 'data/casp12/casp12.fm-domains.seq.fa'
    contact_paths = glob.glob('data/casp12/domains_T0/*.png')

    alphabet = Uniprot21()
    baselines = None
    if args.model == 'baselines':
        x, y, names, baselines = load_data(fasta_path,
                                           contact_paths,
                                           alphabet,
                                           baselines=True)
    else:
        x, y, names = load_data(fasta_path, contact_paths, alphabet)

    if baselines is not None:
        output = args.output
        if output is None:
            output = sys.stdout
        else:
            output = open(output, 'w')

        lengths = np.array([len(x_) for x_ in x])
        calc_baselines(baselines,
                       y,
                       lengths,
                       names,
                       output=output,
                       individual=args.individual)

        sys.exit(0)

    ## set the device
    d = args.device
    use_cuda = (d != -1) and torch.cuda.is_available()
    if d >= 0:
        torch.cuda.set_device(d)

    if use_cuda:
        x = [x_.cuda() for x_ in x]
        y = [y_.cuda() for y_ in y]

    model = torch.load(args.model)
    model.eval()
    if use_cuda:
        model.cuda()

    # predict contact maps
    batch_size = args.batch_size
    dataset = ContactMapDataset(x, y)
    iterator = torch.utils.data.DataLoader(dataset,
                                           batch_size=batch_size,
                                           collate_fn=collate_lists)
    logits = []
    with torch.no_grad():
        for xmb, ymb in iterator:
            lmb = predict_minibatch(model, xmb, use_cuda)
            logits += lmb

    # calculate performance metrics
    lengths = np.array([len(x_) for x_ in x])
    logits = [logit.cpu().numpy() for logit in logits]
    y = [y_.cpu().numpy() for y_ in y]

    output = args.output
    if output is None:
        output = sys.stdout
    else:
        output = open(output, 'w')
    if args.individual:
        line = '\t'.join([
            'Distance', 'Protein', 'Precision', 'Recall', 'F1', 'AUPR',
            'Precision@L', 'Precision@L/2', 'Precision@L/5'
        ])
    else:
        line = '\t'.join([
            'Distance', 'Precision', 'Recall', 'F1', 'AUPR', 'Precision@L',
            'Precision@L/2', 'Precision@L/5'
        ])
    print(line, file=output)
    output.flush()

    # for all contacts
    y_flat = []
    logits_flat = []
    for i in range(len(y)):
        yi = y[i]
        mask = (yi < 0)
        y_flat.append(yi[~mask])
        logits_flat.append(logits[i][~mask])

    # calculate precision, recall, F1, and area under the precision recall curve for all contacts
    precision = np.zeros(len(x))
    recall = np.zeros(len(x))
    F1 = np.zeros(len(x))
    AUPR = np.zeros(len(x))
    prL = np.zeros(len(x))
    prL2 = np.zeros(len(x))
    prL5 = np.zeros(len(x))
    for i in range(len(x)):
        pr, re, f1, aupr = calc_metrics(logits_flat[i], y_flat[i])
        precision[i] = pr
        recall[i] = re
        F1[i] = f1
        AUPR[i] = aupr

        order = np.argsort(logits_flat[i])[::-1]
        n = lengths[i]
        topL = order[:n]
        prL[i] = y_flat[i][topL].mean()
        topL2 = order[:n // 2]
        prL2[i] = y_flat[i][topL2].mean()
        topL5 = order[:n // 5]
        prL5[i] = y_flat[i][topL5].mean()

    if args.individual:
        template = 'All\t{}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}'
        for i in range(len(x)):
            name = names[i]
            line = template.format(name, precision[i], recall[i], F1[i],
                                   AUPR[i], prL[i], prL2[i], prL5[i])
            print(line, file=output)
    else:
        template = 'All\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}'
        line = template.format(precision.mean(), recall.mean(), F1.mean(),
                               AUPR.mean(), prL.mean(), prL2.mean(),
                               prL5.mean())
        print(line, file=output)
    output.flush()

    # for Medium/Long range contacts
    y_flat = []
    logits_flat = []
    for i in range(len(y)):
        yi = y[i]
        mask = (yi < 0)

        medlong = np.tril_indices(len(yi), k=11)
        medlong_mask = np.zeros((len(yi), len(yi)), dtype=np.uint8)
        medlong_mask[medlong] = 1
        mask = mask | (medlong_mask == 1)

        y_flat.append(yi[~mask])
        logits_flat.append(logits[i][~mask])

    # calculate precision, recall, F1, and area under the precision recall curve for medium/long-range contacts
    precision = np.zeros(len(x))
    recall = np.zeros(len(x))
    F1 = np.zeros(len(x))
    AUPR = np.zeros(len(x))
    prL = np.zeros(len(x))
    prL2 = np.zeros(len(x))
    prL5 = np.zeros(len(x))
    for i in range(len(x)):
        pr, re, f1, aupr = calc_metrics(logits_flat[i], y_flat[i])
        precision[i] = pr
        recall[i] = re
        F1[i] = f1
        AUPR[i] = aupr

        order = np.argsort(logits_flat[i])[::-1]
        n = lengths[i]
        topL = order[:n]
        prL[i] = y_flat[i][topL].mean()
        topL2 = order[:n // 2]
        prL2[i] = y_flat[i][topL2].mean()
        topL5 = order[:n // 5]
        prL5[i] = y_flat[i][topL5].mean()

    if args.individual:
        template = 'Medium/Long\t{}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}'
        for i in range(len(x)):
            name = names[i]
            line = template.format(name, precision[i], recall[i], F1[i],
                                   AUPR[i], prL[i], prL2[i], prL5[i])
            print(line, file=output)
    else:
        template = 'Medium/Long\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}'
        line = template.format(precision.mean(), recall.mean(), F1.mean(),
                               AUPR.mean(), prL.mean(), prL2.mean(),
                               prL5.mean())
        print(line, file=output)
    output.flush()
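# Hedged, self-contained sketch (not part of the original scripts) of the
# medium/long-range mask built above: np.tril_indices(L, k=11) marks the lower
# triangle plus every pair with j - i <= 11, so masking those entries out keeps
# only contacts whose sequence separation is at least 12 residues. Assumes numpy
# is imported as np at module scope.
def medium_long_mask(L):
    short_range = np.zeros((L, L), dtype=bool)
    short_range[np.tril_indices(L, k=11)] = True   # lower triangle and short-range band
    return ~short_range                            # True where a pair is medium/long range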
def main():
    import argparse
    parser = argparse.ArgumentParser(
        'Script for training a multitask SCOP similarity and contact map prediction model')

    parser.add_argument('--dev',
                        action='store_true',
                        help='use train/dev split')

    parser.add_argument('--rnn-type',
                        choices=['lstm', 'gru'],
                        default='lstm',
                        help='type of RNN block to use (default: lstm)')
    parser.add_argument('--embedding-dim',
                        type=int,
                        default=100,
                        help='embedding dimension (default: 100)')
    parser.add_argument('--input-dim',
                        type=int,
                        default=512,
                        help='dimension of input to RNN (default: 512)')
    parser.add_argument('--rnn-dim',
                        type=int,
                        default=512,
                        help='hidden units of RNNs (default: 512)')
    parser.add_argument('--num-layers',
                        type=int,
                        default=3,
                        help='number of RNN layers (default: 3)')
    parser.add_argument('--dropout',
                        type=float,
                        default=0,
                        help='dropout probability (default: 0)')

    parser.add_argument(
        '--hidden-dim',
        type=int,
        default=50,
        help=
        'number of hidden units for comparison layer in contact prediction (default: 50)'
    )
    parser.add_argument(
        '--width',
        type=int,
        default=7,
        help='width of convolutional filter for contact prediction (default: 7)'
    )

    parser.add_argument('--epoch-size',
                        type=int,
                        default=100000,
                        help='number of examples per epoch (default: 100,000)')
    parser.add_argument(
        '--epoch-scale',
        type=int,
        default=5,
        help='report heldout performance every this many epochs (default: 5)')
    parser.add_argument('--num-epochs',
                        type=int,
                        default=100,
                        help='number of epochs (default: 100)')

    parser.add_argument(
        '--similarity-batch-size',
        type=int,
        default=64,
        help=
        'minibatch size for similarity prediction loss in pairs (default: 64)')
    parser.add_argument(
        '--contact-batch-size',
        type=int,
        default=10,
        help='minibatch size for contact prediction loss (default: 10)')

    parser.add_argument('--weight-decay',
                        type=float,
                        default=0,
                        help='L2 regularization (default: 0)')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')
    parser.add_argument(
        '--lambda',
        dest='lambda_',
        type=float,
        default=0.5,
        help=
        'weight on the similarity objective, contact map objective weight is one minus this (default: 0.5)'
    )

    parser.add_argument('--tau',
                        type=float,
                        default=0.5,
                        help='sampling proportion exponent (default: 0.5)')
    parser.add_argument(
        '--augment',
        type=float,
        default=0,
        help=
        'probability of resampling amino acid for data augmentation (default: 0)'
    )

    parser.add_argument('--lm',
                        help='pretrained LM to use as initial embedding')

    parser.add_argument('-o',
                        '--output',
                        help='output file path (default: stdout)')
    parser.add_argument('--save-prefix', help='path prefix for saving models')
    parser.add_argument('-d',
                        '--device',
                        type=int,
                        default=-2,
                        help='compute device to use (-1: CPU, -2: GPU if available, n>=0: GPU n; default: -2)')

    args = parser.parse_args()

    prefix = args.output

    ## set the device
    d = args.device
    use_cuda = (d != -1) and torch.cuda.is_available()
    if d >= 0:
        torch.cuda.set_device(d)

    ## make the datasets
    alphabet = Uniprot21()

    astral_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.fa'
    astral_test_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.test.fa'
    astral_testpairs_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.test.sampledpairs.txt'
    if args.dev:
        astral_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.train.fa'
        astral_test_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.dev.fa'
        astral_testpairs_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.dev.sampledpairs.txt'

    print('# loading training sequences:', astral_train_path, file=sys.stderr)
    x_train, structs_train, contacts_train = load_data(astral_train_path,
                                                       alphabet)
    if use_cuda:
        x_train = [x.cuda() for x in x_train]
        #contacts_train = [c.cuda() for c in contacts_train]
    print('# loaded', len(x_train), 'training sequences', file=sys.stderr)

    print('# loading test sequences:', astral_test_path, file=sys.stderr)
    x_test, _, contacts_test = load_data(astral_test_path, alphabet)
    if use_cuda:
        x_test = [x.cuda() for x in x_test]
        #contacts_test = [c.cuda() for c in contacts_test]
    print('# loaded',
          len(x_test),
          'contact map test sequences',
          file=sys.stderr)

    x0_test, x1_test, y_scop_test = load_scop_testpairs(
        astral_testpairs_path, alphabet)
    if use_cuda:
        x0_test = [x.cuda() for x in x0_test]
        x1_test = [x.cuda() for x in x1_test]
    print('# loaded', len(x0_test), 'scop test pairs', file=sys.stderr)

    ## make the dataset iterators

    # data augmentation by resampling amino acids
    augment = None
    p = 0
    if args.augment > 0:
        p = args.augment
        trans = torch.ones(len(alphabet), len(alphabet))
        trans = trans / trans.sum(1, keepdim=True)
        if use_cuda:
            trans = trans.cuda()
        augment = MultinomialResample(trans, p)
    print('# resampling amino acids with p:', p, file=sys.stderr)

    # SCOP structural similarity datasets
    scop_levels = torch.cumprod(
        (structs_train.unsqueeze(1) == structs_train.unsqueeze(0)).long(), 2)
    scop_train = AllPairsDataset(x_train, scop_levels, augment=augment)
    scop_test = PairedDataset(x0_test, x1_test, y_scop_test)

    # contact map datasets
    cmap_train = ContactMapDataset(x_train, contacts_train, augment=augment)
    cmap_test = ContactMapDataset(x_test, contacts_test)

    # iterators for contacts data
    batch_size = args.contact_batch_size
    cmap_train_iterator = torch.utils.data.DataLoader(cmap_train,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      collate_fn=collate_lists)
    cmap_test_iterator = torch.utils.data.DataLoader(cmap_test,
                                                     batch_size=batch_size,
                                                     collate_fn=collate_lists)

    # make the SCOP training iterator have same number of minibatches
    num_steps = len(cmap_train_iterator)
    batch_size = args.similarity_batch_size
    epoch_size = num_steps * batch_size

    similarity = scop_levels.numpy().sum(2)
    levels, counts = np.unique(similarity, return_counts=True)
    order = np.argsort(levels)
    levels = levels[order]
    counts = counts[order]

    tau = args.tau
    print('# using tau:', tau, file=sys.stderr)
    print('#', counts**tau / np.sum(counts**tau), file=sys.stderr)
    weights = counts**tau / counts
    weights = weights[similarity].ravel()
    sampler = torch.utils.data.sampler.WeightedRandomSampler(
        weights, epoch_size)
    N = epoch_size

    # iterators for similarity data
    scop_train_iterator = torch.utils.data.DataLoader(
        scop_train,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=collate_paired_sequences)
    scop_test_iterator = torch.utils.data.DataLoader(
        scop_test, batch_size=batch_size, collate_fn=collate_paired_sequences)

    report_steps = args.epoch_scale

    ## initialize the model
    rnn_type = args.rnn_type
    rnn_dim = args.rnn_dim
    num_layers = args.num_layers

    embedding_size = args.embedding_dim
    input_dim = args.input_dim
    dropout = args.dropout

    print('# initializing embedding model with:', file=sys.stderr)
    print('# embedding_size:', embedding_size, file=sys.stderr)
    print('# input_dim:', input_dim, file=sys.stderr)
    print('# rnn_dim:', rnn_dim, file=sys.stderr)
    print('# num_layers:', num_layers, file=sys.stderr)
    print('# dropout:', dropout, file=sys.stderr)

    lm = None
    if args.lm is not None:
        print('# using pretrained LM:', args.lm, file=sys.stderr)
        lm = torch.load(args.lm)
        lm.eval()
        ## do not update the LM parameters
        for param in lm.parameters():
            param.requires_grad = False

    embedding = src.models.embedding.StackedRNN(len(alphabet),
                                                input_dim,
                                                rnn_dim,
                                                embedding_size,
                                                nlayers=num_layers,
                                                dropout=dropout,
                                                lm=lm)

    # similarity prediction parameters
    similarity_kwargs = {}

    # contact map prediction parameters
    hidden_dim = args.hidden_dim
    width = args.width
    cmap_kwargs = {'hidden_dim': hidden_dim, 'width': width}

    model = src.models.multitask.SCOPCM(embedding,
                                        similarity_kwargs=similarity_kwargs,
                                        cmap_kwargs=cmap_kwargs)
    if use_cuda:
        model.cuda()

    ## setup training parameters and optimizer
    num_epochs = args.num_epochs

    weight_decay = args.weight_decay
    lr = args.lr

    print('# training with Adam: lr={}, weight_decay={}'.format(
        lr, weight_decay),
          file=sys.stderr)
    params = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)

    scop_weight = args.lambda_
    cmap_weight = 1 - scop_weight

    print('# weighting tasks with SIMILARITY: {:.3f}, CONTACTS: {:.3f}'.format(
        scop_weight, cmap_weight),
          file=sys.stderr)

    ## train the model
    print('# training model', file=sys.stderr)

    save_prefix = args.save_prefix
    output = args.output
    if output is None:
        output = sys.stdout
    else:
        output = open(output, 'w')
    digits = int(np.floor(np.log10(num_epochs))) + 1
    tokens = [
        'sim_loss', 'sim_mse', 'sim_acc', 'sim_r', 'sim_rho', 'cmap_loss',
        'cmap_pr', 'cmap_re', 'cmap_f1', 'cmap_aupr'
    ]
    line = '\t'.join(['epoch', 'split'] + tokens)
    print(line, file=output)

    prog_template = '# [{}/{}] training {:.1%} sim_loss={:.5f}, sim_acc={:.5f}, cmap_loss={:.5f}, cmap_f1={:.5f}'

    for epoch in range(num_epochs):
        # train epoch
        model.train()

        scop_n = 0
        scop_loss_accum = 0
        scop_mse_accum = 0
        scop_acc_accum = 0

        cmap_n = 0
        cmap_loss_accum = 0
        cmap_pp = 0
        cmap_pr_accum = 0
        cmap_gp = 0
        cmap_re_accum = 0

        for (cmap_x, cmap_y), (scop_x0, scop_x1,
                               scop_y) in zip(cmap_train_iterator,
                                              scop_train_iterator):

            # calculate gradients and metrics for similarity part
            loss, correct, mse, b = similarity_grad(model,
                                                    scop_x0,
                                                    scop_x1,
                                                    scop_y,
                                                    use_cuda,
                                                    weight=scop_weight)

            scop_n += b
            delta = b * (loss - scop_loss_accum)
            scop_loss_accum += delta / scop_n
            delta = correct - b * scop_acc_accum
            scop_acc_accum += delta / scop_n
            delta = b * (mse - scop_mse_accum)
            scop_mse_accum += delta / scop_n

            report = ((scop_n - b) // 100 < scop_n // 100)  # log progress roughly every 100 pairs

            # calculate the contact map prediction gradients and metrics
            loss, tp, gp_, pp_, b = contacts_grad(model,
                                                  cmap_x,
                                                  cmap_y,
                                                  use_cuda,
                                                  weight=cmap_weight)

            cmap_gp += gp_
            delta = tp - gp_ * cmap_re_accum
            cmap_re_accum += delta / cmap_gp

            cmap_pp += pp_
            delta = tp - pp_ * cmap_pr_accum
            cmap_pr_accum += delta / cmap_pp

            cmap_n += b
            delta = b * (loss - cmap_loss_accum)
            cmap_loss_accum += delta / cmap_n

            ## update the parameters
            optim.step()
            optim.zero_grad()
            model.clip()

            if report:
                f1 = 2 * cmap_pr_accum * cmap_re_accum / (cmap_pr_accum +
                                                          cmap_re_accum)
                line = prog_template.format(epoch + 1, num_epochs, scop_n / N,
                                            scop_loss_accum, scop_acc_accum,
                                            cmap_loss_accum, f1)
                print(line, end='\r', file=sys.stderr)
        print(' ' * 80, end='\r', file=sys.stderr)
        f1 = 2 * cmap_pr_accum * cmap_re_accum / (cmap_pr_accum +
                                                  cmap_re_accum)
        tokens = [
            scop_loss_accum, scop_mse_accum, scop_acc_accum, '-', '-',
            cmap_loss_accum, cmap_pr_accum, cmap_re_accum, f1, '-'
        ]
        tokens = [x if type(x) is str else '{:.5f}'.format(x) for x in tokens]

        line = '\t'.join([str(epoch + 1).zfill(digits), 'train'] + tokens)
        print(line, file=output)
        output.flush()

        # eval and save model
        if (epoch + 1) % report_steps == 0:
            model.eval()
            with torch.no_grad():
                scop_loss, scop_acc, scop_mse, scop_r, scop_rho = \
                        eval_similarity(model, scop_test_iterator, use_cuda)
                cmap_loss, cmap_pr, cmap_re, cmap_f1, cmap_aupr = \
                        eval_contacts(model, cmap_test_iterator, use_cuda)

            tokens = [
                scop_loss, scop_mse, scop_acc, scop_r, scop_rho, cmap_loss,
                cmap_pr, cmap_re, cmap_f1, cmap_aupr
            ]
            tokens = ['{:.5f}'.format(x) for x in tokens]

            line = '\t'.join([str(epoch + 1).zfill(digits), 'test'] + tokens)
            print(line, file=output)
            output.flush()

            # save the model
            if save_prefix is not None:
                save_path = save_prefix + '_epoch' + str(epoch + 1).zfill(digits) + '.sav'
                model.cpu()
                torch.save(model, save_path)
                if use_cuda:
                    model.cuda()
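# Hedged, self-contained sketch (not part of the original scripts) of the running
# averages maintained in the training loop above: each minibatch of size b nudges
# the accumulator so that it always equals the size-weighted mean of everything
# seen so far, without storing per-batch values. Assumes numpy is imported as np.
def check_streaming_mean():
    rng = np.random.RandomState(0)
    losses = rng.rand(50)                 # synthetic per-batch losses
    sizes = rng.randint(1, 11, size=50)   # synthetic batch sizes
    n = 0
    accum = 0.0
    for loss, b in zip(losses, sizes):
        n += b
        accum += b * (loss - accum) / n   # same update rule as scop_loss_accum above
    assert np.isclose(accum, np.average(losses, weights=sizes))
    return accum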
def main():
    # `parser`, `pfam_train`, and `pfam_test` are expected to be defined at module
    # scope in this script; the argparse setup is not shown in this excerpt.
    args = parser.parse_args()

    alph = Uniprot21()
    ntokens = len(alph)

    ## load the training sequences
    train_group, X_train = load_pfam(pfam_train, alph)
    print('# loaded',
          len(X_train),
          'sequences from',
          pfam_train,
          file=sys.stderr)

    ## load the testing sequences
    test_group, X_test = load_pfam(pfam_test, alph)
    print('# loaded',
          len(X_test),
          'sequences from',
          pfam_test,
          file=sys.stderr)

    ## initialize the model
    nin = ntokens + 1
    nout = ntokens
    embedding_dim = 21
    hidden_dim = args.hidden_dim
    num_layers = args.num_layers
    mask_idx = ntokens
    dropout = args.dropout

    tied = not args.untied

    model = src.models.sequence.BiLM(nin,
                                     nout,
                                     embedding_dim,
                                     hidden_dim,
                                     num_layers,
                                     mask_idx=mask_idx,
                                     dropout=dropout,
                                     tied=tied)
    print('# initialized model', file=sys.stderr)

    device = args.device
    use_cuda = torch.cuda.is_available() and (device == -2 or device >= 0)
    if device >= 0:
        torch.cuda.set_device(device)
    if use_cuda:
        model = model.cuda()

    ## form the data iterators and optimizer
    lr = args.lr
    l2 = args.l2
    solver = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2)

    def collate(xs):
        B = len(xs)
        N = max(len(x) for x in xs)
        lengths = np.array([len(x) for x in xs], dtype=int)

        order = np.argsort(lengths)[::-1]
        lengths = lengths[order]

        X = torch.LongTensor(B, N).zero_() + mask_idx
        for i in range(B):
            x = xs[order[i]]
            n = len(x)
            X[i, :n] = torch.from_numpy(x)
        return X, lengths

    mb = args.minibatch_size

    train_iterator = torch.utils.data.DataLoader(X_train,
                                                 batch_size=mb,
                                                 shuffle=True,
                                                 collate_fn=collate)
    test_iterator = torch.utils.data.DataLoader(X_test,
                                                batch_size=mb,
                                                collate_fn=collate)

    ## fit the model!

    print('# training model', file=sys.stderr)

    output = sys.stdout
    if args.output is not None:
        output = open(args.output, 'w')

    num_epochs = args.num_epochs
    clip = args.clip

    save_prefix = args.save_prefix
    digits = int(np.floor(np.log10(num_epochs))) + 1

    print('epoch\tsplit\tlog_p\tperplexity\taccuracy', file=output)
    output.flush()

    for epoch in range(num_epochs):
        # train epoch
        model.train()
        it = 0
        n = 0
        accuracy = 0
        loss_accum = 0
        for X, lengths in train_iterator:
            if use_cuda:
                X = X.cuda()
            X = Variable(X)
            logp = model(X)

            mask = (X != mask_idx)

            index = X * mask.long()
            loss = -logp.gather(2, index.unsqueeze(2)).squeeze(2)
            loss = torch.mean(loss.masked_select(mask))

            loss.backward()

            # clip the gradient
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

            solver.step()
            solver.zero_grad()

            _, y_hat = torch.max(logp, 2)
            correct = torch.sum((y_hat == X).masked_select(mask))
            #correct = torch.sum((y_hat == X)[mask.nonzero()].float())

            b = mask.long().sum().item()
            n += b
            delta = b * (loss.item() - loss_accum)
            loss_accum += delta / n
            delta = correct.item() - b * accuracy
            accuracy += delta / n

            b = X.size(0)
            it += b
            if (it - b) // 100 < it // 100:
                print(
                    '# [{}/{}] training {:.1%} loss={:.5f}, acc={:.5f}'.format(
                        epoch + 1, num_epochs, it / len(X_train), loss_accum,
                        accuracy),
                    end='\r',
                    file=sys.stderr)
        print(' ' * 80, end='\r', file=sys.stderr)

        perplex = np.exp(loss_accum)
        string = '\t'.join([str(epoch + 1).zfill(digits), 'train',
                            str(loss_accum), str(perplex), str(accuracy)])
        print(string, file=output)
        output.flush()

        # test epoch
        model.eval()
        it = 0
        n = 0
        accuracy = 0
        loss_accum = 0
        with torch.no_grad():
            for X, lengths in test_iterator:
                if use_cuda:
                    X = X.cuda()
                X = Variable(X)
                logp = model(X)

                mask = (X != mask_idx)

                index = X * mask.long()
                loss = -logp.gather(2, index.unsqueeze(2)).squeeze(2)
                loss = torch.mean(loss.masked_select(mask))

                _, y_hat = torch.max(logp, 2)
                correct = torch.sum((y_hat == X).masked_select(mask))

                b = mask.long().sum().item()
                n += b
                delta = b * (loss.item() - loss_accum)
                loss_accum += delta / n
                delta = correct.item() - b * accuracy
                accuracy += delta / n

                b = X.size(0)
                it += b
                if (it - b) // 100 < it // 100:
                    print(
                        '# [{}/{}] test {:.1%} loss={:.5f}, acc={:.5f}'.format(
                            epoch + 1, num_epochs, it / len(X_test),
                            loss_accum, accuracy),
                        end='\r',
                        file=sys.stderr)
        print(' ' * 80, end='\r', file=sys.stderr)

        perplex = np.exp(loss_accum)
        string = '\t'.join([str(epoch + 1).zfill(digits), 'test',
                            str(loss_accum), str(perplex), str(accuracy)])
        print(string, file=output)
        output.flush()

        ## save the model
        if save_prefix is not None:
            save_path = save_prefix + '_epoch' + str(epoch + 1).zfill(digits) + '.sav'
            model = model.cpu()
            torch.save(model, save_path)
            if use_cuda:
                model = model.cuda()
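# Hedged, self-contained sketch (not part of the original scripts) of the masked
# cross-entropy used in the language-model training loop above: gather the log
# probability assigned to each observed token and average only over non-padding
# positions. Assumes `import torch` at module scope.
def masked_nll(logp, X, mask_idx):
    # logp: (B, N, V) log probabilities; X: (B, N) token indices padded with mask_idx
    mask = (X != mask_idx)
    index = X * mask.long()                               # zero out padding indices before gather
    nll = -logp.gather(2, index.unsqueeze(2)).squeeze(2)  # (B, N) per-position loss
    return torch.mean(nll.masked_select(mask))            # average over real tokens only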
def main():
    import argparse
    parser = argparse.ArgumentParser(
        'Script for training and evaluating secondary structure prediction from sequence embeddings or k-mer features.')

    parser.add_argument(
        'features',
        help=
        'path to saved embedding model file or "1-", "3-", or "5-mer" for k-mer features'
    )

    parser.add_argument('--num-epochs',
                        type=int,
                        default=10,
                        help='number of epochs to train for (default: 10)')
    parser.add_argument('--all-hidden',
                        action='store_true',
                        help='use all hidden layers as features')

    parser.add_argument('-v',
                        '--print-examples',
                        default=0,
                        type=int,
                        help='number of examples to print (default: 0)')

    parser.add_argument('-o',
                        '--output',
                        help='output file path (default: stdout)')
    parser.add_argument('--save-prefix', help='path prefix for saving models')
    parser.add_argument('-d',
                        '--device',
                        type=int,
                        default=-2,
                        help='compute device to use (-1: CPU, -2: GPU if available, n>=0: GPU n; default: -2)')

    args = parser.parse_args()
    num_epochs = args.num_epochs

    ## load the data
    alphabet = Uniprot21()
    secstr = SecStr8

    names_train, x_train, y_train = load_secstr(secstr_train_path, alphabet,
                                                secstr)
    names_test, x_test, y_test = load_secstr(secstr_test_path, alphabet,
                                             secstr)

    sequences_test = [
        ''.join(alphabet[c] for c in x_test[i]) for i in range(len(x_test))
    ]

    y_train = np.concatenate(y_train, 0)

    ## set the device
    d = args.device
    use_cuda = (d != -1) and torch.cuda.is_available()
    if d >= 0:
        torch.cuda.set_device(d)

    if args.features == '1-mer':
        n = len(alphabet)
        x_test = [x.astype(int) for x in x_test]
    elif args.features == '3-mer':
        x_train, n = kmer_features(x_train, len(alphabet), 3)
        x_test, _ = kmer_features(x_test, len(alphabet), 3)
    elif args.features == '5-mer':
        x_train, n = kmer_features(x_train, len(alphabet), 5)
        x_test, _ = kmer_features(x_test, len(alphabet), 5)
    else:
        features = torch.load(args.features)
        features.eval()

        if use_cuda:
            features.cuda()

        features = TorchModel(features,
                              use_cuda,
                              full_features=args.all_hidden)
        batch_size = 32  # batch size for featurizing sequences

        with torch.no_grad():
            z_train = []
            for i in range(0, len(x_train), batch_size):
                for z in features(x_train[i:i + batch_size]):
                    z_train.append(z.cpu().numpy())
            x_train = z_train

            z_test = []
            for i in range(0, len(x_test), batch_size):
                for z in features(x_test[i:i + batch_size]):
                    z_test.append(z.cpu().numpy())
            x_test = z_test

        n = x_train[0].shape[1]
        del features
        del z_train
        del z_test

    print('split', 'epoch', 'loss', 'perplexity', 'accuracy')

    if args.features.endswith('-mer'):
        x_train = np.concatenate(x_train, 0)
        model = fit_kmer_potentials(x_train, y_train, n, len(secstr))
    else:
        x_train = torch.cat([torch.from_numpy(x) for x in x_train], 0)
        if use_cuda and not args.all_hidden:
            x_train = x_train.cuda()

        num_hidden = 1024
        model = nn.Sequential(nn.Linear(n, num_hidden), nn.ReLU(),
                              nn.Linear(num_hidden, num_hidden), nn.ReLU(),
                              nn.Linear(num_hidden, len(secstr)))

        y_train = torch.from_numpy(y_train).long()
        if use_cuda:
            y_train = y_train.cuda()
            model.cuda()

        fit_nn_potentials(model,
                          x_train,
                          y_train,
                          num_epochs=num_epochs,
                          use_cuda=use_cuda)

    if use_cuda:
        model.cuda()
    model.eval()

    num_examples = args.print_examples
    if num_examples > 0:
        names_examples = names_test[:num_examples]
        x_examples = x_test[:num_examples]
        y_examples = y_test[:num_examples]

    A = np.zeros((8, 3), dtype=np.float32)
    I = np.zeros(8, dtype=int)
    # helix
    A[0, 0] = 1.0
    A[3, 0] = 1.0
    A[4, 0] = 1.0
    I[0] = 0
    I[3] = 0
    I[4] = 0
    # sheet
    A[1, 1] = 1.0
    A[2, 1] = 1.0
    I[1] = 1
    I[2] = 1
    # coil
    A[5, 2] = 1.0
    A[6, 2] = 1.0
    A[7, 2] = 1.0
    I[5] = 2
    I[6] = 2
    I[7] = 2

    A = torch.from_numpy(A)
    I = torch.from_numpy(I)
    if use_cuda:
        A = A.cuda()
        I = I.cuda()

    n = 0
    acc_8 = 0
    acc_3 = 0
    loss_8 = 0
    loss_3 = 0

    x_test = torch.cat([torch.from_numpy(x) for x in x_test], 0)
    y_test = torch.cat([torch.from_numpy(y).long() for y in y_test], 0)

    if use_cuda and not args.all_hidden:
        x_test = x_test.cuda()
        y_test = y_test.cuda()

    mb = 256
    with torch.no_grad():
        for i in range(0, len(x_test), mb):
            x = x_test[i:i + mb]
            y = y_test[i:i + mb]

            if use_cuda:
                x = x.cuda()
                y = y.cuda()

            potentials = model(x).view(x.size(0), -1)

            ## 8-class SS
            l = F.cross_entropy(potentials, y).item()
            _, y_hat = potentials.max(1)
            correct = torch.sum((y == y_hat).float()).item()

            n += x.size(0)
            delta = x.size(0) * (l - loss_8)
            loss_8 += delta / n
            delta = correct - x.size(0) * acc_8
            acc_8 += delta / n

            ## 3-class SS
            y = I[y]
            p = F.softmax(potentials, 1)
            p = torch.mm(p, A)  # ss3 probabilities
            log_p = torch.log(p)
            l = F.nll_loss(log_p, y).item()
            _, y_hat = log_p.max(1)
            correct = torch.sum((y == y_hat).float()).item()

            delta = x.size(0) * (l - loss_3)
            loss_3 += delta / n
            delta = correct - x.size(0) * acc_3
            acc_3 += delta / n

    print('-', '-', '8-class', '-', '3-class', '-')
    print('split', 'perplexity', 'accuracy', 'perplexity', 'accuracy')
    print('test', np.exp(loss_8), acc_8, np.exp(loss_3), acc_3)

    if num_examples > 0:
        for i in range(num_examples):
            name = names_examples[i].decode('utf-8')
            x = x_examples[i]
            y = y_examples[i]

            seq = sequences_test[i]

            print('>' + name + ' sequence')
            print(seq)
            print('')

            ss = ''.join(secstr[c] for c in y)
            ss = ss.replace(' ', 'C')
            print('>' + name + ' secstr')
            print(ss)
            print('')

            x = torch.from_numpy(x)
            if use_cuda:
                x = x.cuda()
            potentials = model(x)
            _, y_hat = torch.max(potentials, 1)
            y_hat = y_hat.cpu().numpy()

            ss_hat = ''.join(secstr[c] for c in y_hat)
            ss_hat = ss_hat.replace(' ', 'C')
            print('>' + name + ' predicted')
            print(ss_hat)
            print('')
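# Hedged sketch (not part of the original scripts): the A matrix above groups the
# eight secondary-structure states into helix / sheet / coil, so 3-class
# probabilities are just the 8-class softmax probabilities summed within each
# group. The state ordering (0,3,4 -> helix; 1,2 -> sheet; 5,6,7 -> coil) follows
# the grouping used above and depends on the SecStr8 alphabet. Assumes `import torch`.
def ss8_to_ss3_probs(p8):
    # p8: (B, 8) tensor of 8-class probabilities -> (B, 3) tensor of 3-class probabilities
    groups = [[0, 3, 4], [1, 2], [5, 6, 7]]   # helix, sheet, coil
    return torch.stack([p8[:, g].sum(1) for g in groups], dim=1)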