def train():
    """Run the full pipeline: preprocess, train MF, build I2I, generate candidates, train and validate the ranker."""
    print('Preprocessing raw data')
    preprocessor = Preprocessor()
    preprocessor.preprocess()

    dataset = Dataset(preprocessor)

    print('Training MF')
    mf = MF(preprocessor, dataset)
    mf.train_or_load_if_exists()

    print('Building I2I')
    i2i = Item2Item(dataset)

    print('Generating candidates')
    candidate_generator = CandidateGenerator(preprocessor, dataset, mf, i2i)
    X_train, y_train, q_train, q_train_reader = candidate_generator.generate_train()
    X_val, y_val, q_val, q_val_reader = candidate_generator.generate_val()

    import pickle
    # Cache the generated candidate sets so they can be inspected or reused later.
    try:
        with open('puke.pkl', 'wb') as f:
            pickle.dump((X_train, y_train, q_train, q_train_reader,
                         X_val, y_val, q_val, q_val_reader), f)
    except OSError as e:
        print("Couldn't save puke.pkl:", e)

    print('Training ranker')
    ranker = Ranker()
    ranker.train(X_train, y_train, q_train, X_val, y_val, q_val)
    ranker.save()

    print('Validating ranker')
    rank_scores = ranker.rank(X_val)
    print('ndcg', dataset.validate_ndcg(y_val, q_val, q_val_reader, rank_scores))
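
For quick iteration on the ranker alone, the pickled tuple can be reloaded instead of regenerating candidates. A minimal sketch (the load_candidates helper is hypothetical; it assumes the exact 8-tuple layout dumped above):

import pickle

def load_candidates(path='puke.pkl'):
    # Hypothetical helper: restore the candidate sets written by train().
    with open(path, 'rb') as f:
        return pickle.load(f)

(X_train, y_train, q_train, q_train_reader,
 X_val, y_val, q_val, q_val_reader) = load_candidates()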
Example #2
    # Select the ranker implementation requested on the command line.
    if args.which_ranker == 'ranker':
        from ranker import Ranker
    elif args.which_ranker == 'masker_ranker':
        from masker_ranker import Ranker
    else:
        raise ValueError('unknown ranker: %r' % args.which_ranker)
    model = Ranker(vocab_src, vocab_tgt, args.embed_dim, args.ff_embed_dim,
                   args.num_heads, args.dropout, args.num_layers)
    model = model.cuda()  # assumes a CUDA-capable GPU
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    train_data = DataLoader(args.train_data, vocab_src, vocab_tgt,
                            args.train_batch_size, True)
    dev_data = DataLoader(args.dev_data, vocab_src, vocab_tgt,
                          args.dev_batch_size, True)

    model.train()
    # Running totals for reporting; best_dev_acc tracks the best dev accuracy seen.
    loss_accumulated = 0.
    acc_accumulated = 0.
    batches_processed = 0
    best_dev_acc = 0.
    for epoch in range(args.epochs):
        for src_input, tgt_input in train_data:
            optimizer.zero_grad()
            loss, acc = model(src_input, tgt_input)  # forward returns (loss, accuracy)

            loss_accumulated += loss.item()
            acc_accumulated += acc
            batches_processed += 1
            loss.backward()
            # Clip gradients to stabilize training before the update step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()
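
The excerpt cuts off before the dev pass that best_dev_acc is initialized for. A minimal sketch of an end-of-epoch evaluation, assuming model(src_input, tgt_input) returns (loss, acc) on dev batches just as it does on training batches (the checkpoint filename is hypothetical):

        # Hypothetical end-of-epoch dev evaluation.
        model.eval()
        dev_acc, dev_batches = 0., 0
        with torch.no_grad():
            for src_input, tgt_input in dev_data:
                _, acc = model(src_input, tgt_input)
                dev_acc += acc
                dev_batches += 1
        dev_acc /= max(dev_batches, 1)
        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            torch.save(model.state_dict(), 'best_ranker.pt')  # hypothetical path
        model.train()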