Пример #1
0
def get_val_loss(vocab, args, model):
    """Compute the mean per-example validation loss of ``model``.

    Iterates over the module-level validation split (``val_ids``,
    ``post_content``, ``qa_dict``) in fixed order via ``batch_iter``, pads
    posts/questions/answers with ``pad_sequence`` and scores the model output
    against the labels with ``InfoNCELoss``.

    Args:
        vocab: vocabulary object passed through to ``batch_iter``.
        args: run configuration; ``args.device`` and ``args.batch_size`` are read.
        model: the model to evaluate. It is put in eval mode for the duration
            of the call and is always left in train mode afterwards (also when
            an exception propagates).

    Returns:
        float: the criterion values summed over all validation batches,
        divided by the number of validation examples.

    Raises:
        ValueError: if the validation iterator yields no examples.
    """
    # Presumably sum-reduced (it replaced CrossEntropyLoss(reduction='sum')),
    # so dividing by the example count below gives a per-example mean.
    # Follow args.device instead of unconditionally calling .cuda().
    criterion = InfoNCELoss(size_average=False).to(device=args.device)

    total_util_loss = 0.0
    util_examples = 0

    model.eval()
    try:
        # Validation never back-propagates; skip autograd bookkeeping.
        with torch.no_grad():
            for ids, posts, questions, answers, labels in batch_iter(
                    val_ids, post_content, qa_dict, vocab, args.batch_size,
                    shuffle=False):
                util_examples += len(ids)

                padded_posts, post_pad_idx = pad_sequence(args.device, posts)
                padded_questions, question_pad_idx = pad_sequence(args.device, questions)
                padded_answers, answer_pad_idx = pad_sequence(args.device, answers)

                pqa_probs = model(ids, (padded_posts, post_pad_idx),
                                  (padded_questions, question_pad_idx),
                                  (padded_answers, answer_pad_idx))

                labels = torch.tensor(labels).to(device=args.device)
                total_util_loss += criterion(pqa_probs, labels).item()
    finally:
        model.train()

    if util_examples == 0:
        raise ValueError("validation set is empty; cannot compute a mean loss")
    return total_util_loss / util_examples
def train():
    """Train an ``EVPI`` model with validation-driven early stopping.

    Configuration is read from the module-level ``args``; training data comes
    from the module-level ``train_ids``, ``post_content`` and ``qa_dict`` via
    ``batch_iter`` (fixed order, ``shuffle=False``).

    Every ``args.log_every`` iterations the average training loss is printed
    to stderr. Every ``args.valid_iter`` iterations ``get_val_loss`` is run:

    * if the validation loss is the lowest seen so far, the whole model is
      saved with ``torch.save`` to ``args.model_save_path`` and the optimiser
      state to ``args.model_save_path + '.optim'``;
    * otherwise patience is incremented. When it reaches ``args.patience``,
      a trial is counted. After ``args.max_num_trials`` trials the function
      returns early. Before that, the checkpoint is reloaded with ``load``
      (defined elsewhere — presumably the counterpart of ``torch.save``), the
      learning rate is multiplied by ``args.lr_decay``, and a fresh Adam
      optimiser is restored from the saved state.

    Returns:
        None. Progress goes to stdout/stderr; results go to the checkpoints.
    """
    log_every = args.log_every
    valid_iter = args.valid_iter
    train_iter = 0
    cum_loss = 0
    avg_loss = 0
    patience = 0
    num_trial = 0
    hist_valid_scores = []
    begin_time = time.time()

    vocab = get_vocab(args.vocab_file)
    model = EVPI(args, vocab)

    if args.use_embed == 1:
        model.load_vector(args, vocab)

    print("Placing model on ", args.device)
    if args.device == 'cuda':
        model.cuda()

    lr = args.lr
    optim = torch.optim.Adam(list(model.parameters()), lr=lr)

    # Follow args.device instead of unconditionally calling .cuda(), so the
    # loss module sits on the same device as the model outputs and labels.
    criterion = InfoNCELoss().to(device=args.device)

    print("Beginning Training")
    model.train()

    for ep in range(args.max_epochs):
        for ids, posts, questions, answers, labels in batch_iter(
                train_ids, post_content, qa_dict, vocab, args.batch_size,
                shuffle=False):
            train_iter += 1

            optim.zero_grad()

            padded_posts, post_pad_idx = pad_sequence(args.device, posts)
            padded_questions, question_pad_idx = pad_sequence(args.device, questions)
            padded_answers, answer_pad_idx = pad_sequence(args.device, answers)

            pqa_probs = model(ids, (padded_posts, post_pad_idx),
                              (padded_questions, question_pad_idx),
                              (padded_answers, answer_pad_idx))

            labels = torch.tensor(labels).to(device=args.device)
            total_loss = criterion(pqa_probs, labels)

            # One host sync per batch instead of two.
            batch_loss = total_loss.item()
            avg_loss += batch_loss
            cum_loss += batch_loss

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(list(model.parameters()), args.clip_grad)
            optim.step()

            if train_iter % log_every == 0:
                print('epoch %d, iter %d, avg.loss %.6f, time elapsed %.2f'
                      % (ep + 1, train_iter, avg_loss / log_every,
                         time.time() - begin_time), file=sys.stderr)
                begin_time = time.time()
                avg_loss = 0

            if train_iter % valid_iter == 0:
                print('epoch %d, iter %d, cum.loss %.2f, time elapsed %.2f'
                      % (ep + 1, train_iter, cum_loss / valid_iter,
                         time.time() - begin_time), file=sys.stderr)
                cum_loss = 0

                print("Begin Validation ", file=sys.stderr)
                model.eval()
                val_loss = get_val_loss(vocab, args, model)
                model.train()

                print('validation: iter %d, loss %f' % (train_iter, val_loss), file=sys.stderr)

                is_better = not hist_valid_scores or val_loss < min(hist_valid_scores)
                hist_valid_scores.append(val_loss)

                if is_better:
                    patience = 0
                    print("Save the current model and optimiser state")
                    torch.save(model, args.model_save_path)
                    torch.save(optim.state_dict(), args.model_save_path + '.optim')
                elif patience < args.patience:
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == args.patience:
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == args.max_num_trials:
                            print('early stop!', file=sys.stderr)
                            return

                        lr = lr * args.lr_decay

                        print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)
                        model = load(args.model_save_path)
                        model.train()

                        print('restore parameters of the optimizers', file=sys.stderr)
                        optim = torch.optim.Adam(list(model.parameters()), lr=lr)
                        optim.load_state_dict(torch.load(args.model_save_path + '.optim'))
                        # Make sure every tensor in the optimiser state lives on args.device.
                        for state in optim.state.values():
                            for k, v in state.items():
                                if isinstance(v, torch.Tensor):
                                    state[k] = v.to(args.device)
                        # The restored state carries the old lr; apply the decayed one.
                        for group in optim.param_groups:
                            group['lr'] = lr

                        patience = 0
    print("Training Finished", file=sys.stderr)
Пример #3
0
def get_val_loss(vocab, args, model, ep):
    """Compute the mean per-example validation loss from cached embeddings.

    Batches come from ``batch_iter`` over the module-level validation split in
    fixed order (``shuffle=False``). Batch ``i`` is paired with entry ``i`` of
    the module-level caches ``valid_p_embeddings``, ``valid_q_embeddings`` and
    ``valid_a_embeddings``, so those caches must have been built in the same
    batch order. Each cache entry is converted with ``torch.from_numpy``, so it
    must be a NumPy array.

    Args:
        vocab: vocabulary passed through to ``batch_iter``.
        args: run configuration; ``args.device`` and ``args.batch_size`` are read.
        model: called as ``model(posts, questions, answers)`` on the cached
            embeddings. It is put in eval mode for the call and always left in
            train mode afterwards.
        ep: current epoch. Unused here; kept for caller compatibility.

    Returns:
        float: the criterion values summed over all validation batches,
        divided by the number of validation examples.

    Raises:
        ValueError: if the validation iterator yields no examples.
    """
    # Presumably sum-reduced (it replaced CrossEntropyLoss(reduction='sum')),
    # so dividing by the example count below yields a per-example mean; with
    # a mean-reduced loss the result would be shrunk by about the batch size.
    # Follow args.device instead of a hard-coded 'cuda'.
    criterion = InfoNCELoss(size_average=False).to(device=args.device)

    total_util_loss = 0.0
    util_examples = 0

    model.eval()
    try:
        # Validation never back-propagates; skip autograd bookkeeping.
        with torch.no_grad():
            batches = batch_iter(val_ids, post_content, qa_dict, vocab,
                                 args.batch_size, shuffle=False)
            for valid_epoch_iter, (ids, posts, questions, answers, labels) in enumerate(batches):
                print("Validation Iteration {}".format(valid_epoch_iter))

                util_examples += len(ids)

                posts_embeddings = torch.from_numpy(
                    valid_p_embeddings[valid_epoch_iter]).float().to(args.device)
                questions_embeddings = torch.from_numpy(
                    valid_q_embeddings[valid_epoch_iter]).float().to(args.device)
                answers_embeddings = torch.from_numpy(
                    valid_a_embeddings[valid_epoch_iter]).float().to(args.device)

                pqa_probs = model(posts_embeddings, questions_embeddings, answers_embeddings)

                labels = torch.tensor(labels).to(device=args.device)
                total_util_loss += criterion(pqa_probs, labels).item()
    finally:
        model.train()

    if util_examples == 0:
        raise ValueError("validation set is empty; cannot compute a mean loss")
    return total_util_loss / util_examples
Пример #4
0
def train():
    """Train an ``EVPI`` model on pre-computed sentence embeddings.

    Configuration is read from the module-level ``args``. Per-batch embeddings
    for posts, questions and answers are unpickled once from
    ``p_large_embeddings.pickle``, ``q_large_embeddings.pickle`` and
    ``a_large_embeddings.pickle``. Batch ``i`` of every epoch (from
    ``batch_iter`` over ``train_ids`` with ``shuffle=False``) is paired with
    entry ``i`` of each cache; only the batch labels are taken from
    ``batch_iter``.

    Every ``args.log_every`` iterations the average training loss is printed
    to stderr. Every ``args.valid_iter`` iterations ``get_val_loss`` is run:

    * if the validation loss is the lowest seen so far, the whole model is
      saved with ``torch.save`` to ``args.model_save_path`` and the optimiser
      state to ``args.model_save_path + '.optim'``;
    * otherwise patience is incremented. When it reaches ``args.patience``,
      a trial is counted. After ``args.max_num_trials`` trials the function
      returns early. Before that, the checkpoint is reloaded with ``load``
      (defined elsewhere — presumably the counterpart of ``torch.save``), the
      learning rate is multiplied by ``args.lr_decay``, and a fresh Adam
      optimiser is restored from the saved state.

    Returns:
        None. Progress goes to stdout/stderr; results go to the checkpoints.
    """
    device = args.device

    log_every = args.log_every
    valid_iter = args.valid_iter
    train_iter = 0
    cum_loss = 0
    avg_loss = 0
    patience = 0
    num_trial = 0
    hist_valid_scores = []
    begin_time = time.time()

    vocab = get_vocab(args.vocab_file)
    model = EVPI(args, vocab)

    if args.use_embed == 1:
        model.load_vector(args, vocab)

    print("Placing model on ", args.device)
    if args.device == 'cuda':
        model = model.cuda()

    lr = args.lr
    optim = torch.optim.Adam(list(model.parameters()), lr=lr)

    criterion = InfoNCELoss().to(device=device)

    print("Beginning Training")
    model.train()

    # Load the per-batch embedding caches. Context managers close the file
    # handles, which the previous pickle.load(open(...)) calls leaked.
    # NOTE(review): unpickling can execute arbitrary code; these files must
    # come from a trusted source (e.g. generated by this project).
    with open("a_large_embeddings.pickle", "rb") as f:
        a_embeddings = pickle.load(f)
    print("A embeddings loaded", a_embeddings[0].shape)

    with open("q_large_embeddings.pickle", "rb") as f:
        q_embeddings = pickle.load(f)
    print("Q embeddings loaded", q_embeddings[0].shape)

    with open("p_large_embeddings.pickle", "rb") as f:
        p_embeddings = pickle.load(f)
    print("P embeddings loaded", p_embeddings[0].shape)

    for ep in range(args.max_epochs):
        # shuffle=False is required: the caches are indexed by batch
        # position, so batch order must be identical in every epoch.
        batches = batch_iter(train_ids, post_content, qa_dict, vocab,
                             args.batch_size, shuffle=False)
        for epoch_iter, (_ids, _posts, _questions, _answers, labels) in enumerate(batches):
            train_iter += 1

            optim.zero_grad()

            posts_embeddings = torch.from_numpy(p_embeddings[epoch_iter]).float().to(args.device)
            questions_embeddings = torch.from_numpy(q_embeddings[epoch_iter]).float().to(args.device)
            answers_embeddings = torch.from_numpy(a_embeddings[epoch_iter]).float().to(args.device)

            pqa_probs = model(posts_embeddings, questions_embeddings, answers_embeddings)
            labels = torch.tensor(labels).to(device=args.device)

            total_loss = criterion(pqa_probs, labels)

            # One host sync per batch instead of two.
            batch_loss = total_loss.item()
            avg_loss += batch_loss
            cum_loss += batch_loss

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(list(model.parameters()), args.clip_grad)
            optim.step()

            if train_iter % log_every == 0:
                print('epoch %d, iter %d, avg.loss %.6f, time elapsed %.2f'
                      % (ep + 1, train_iter, avg_loss / log_every,
                         time.time() - begin_time), file=sys.stderr)
                begin_time = time.time()
                avg_loss = 0

            if train_iter % valid_iter == 0:
                print('epoch %d, iter %d, cum.loss %.2f, time elapsed %.2f'
                      % (ep + 1, train_iter, cum_loss / valid_iter,
                         time.time() - begin_time), file=sys.stderr)
                cum_loss = 0

                print("Begin Validation ", file=sys.stderr)
                model.eval()
                val_loss = get_val_loss(vocab, args, model, ep)
                model.train()

                print('validation: iter %d, loss %f' % (train_iter, val_loss), file=sys.stderr)

                is_better = not hist_valid_scores or val_loss < min(hist_valid_scores)
                hist_valid_scores.append(val_loss)

                if is_better:
                    patience = 0
                    print("Save the current model and optimiser state")
                    torch.save(model, args.model_save_path)
                    torch.save(optim.state_dict(), args.model_save_path + '.optim')
                elif patience < args.patience:
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == args.patience:
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == args.max_num_trials:
                            print('early stop!', file=sys.stderr)
                            return

                        lr = lr * args.lr_decay

                        print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)
                        model = load(args.model_save_path)
                        model.train()

                        print('restore parameters of the optimizers', file=sys.stderr)
                        optim = torch.optim.Adam(list(model.parameters()), lr=lr)
                        optim.load_state_dict(torch.load(args.model_save_path + '.optim'))
                        # Make sure every tensor in the optimiser state lives on args.device.
                        for state in optim.state.values():
                            for k, v in state.items():
                                if isinstance(v, torch.Tensor):
                                    state[k] = v.to(args.device)
                        # The restored state carries the old lr; apply the decayed one.
                        for group in optim.param_groups:
                            group['lr'] = lr

                        patience = 0
    print("Training Finished", file=sys.stderr)