示例#1
0
def check_accuracy(model, descriptor, loader):
    """Evaluate ``model`` on every batch of ``loader`` and log the accuracy.

    Args:
        model: Network called as
            ``model(dialogues, dialogue_lens, all_cats, all_spatial)``. Its
            output is reduced with ``max(1)``, so dimension 1 is treated as
            the candidate axis.
        descriptor: Identifier forwarded to ``log_print``.
        loader: Iterable yielding
            ``(dialogues, all_cats, all_spatial, correct_objs)`` batches.

    Returns:
        The fraction of correct predictions as a float, or ``0.0`` when the
        loader yields no samples (instead of raising ``ZeroDivisionError``).
    """
    num_correct = 0
    num_samples = 0
    model.eval()  # Put the model in test mode (the opposite of model.train())

    for dialogues, all_cats, all_spatial, correct_objs in loader:
        # The wrapped targets returned by make_vars are not needed here:
        # predictions are compared against the raw targets on the CPU below.
        dialogues_var, all_cats_var, all_spatial_var, _, dialogue_lens = \
            make_vars(dialogues, all_cats, all_spatial, correct_objs, volatile=True)

        scores = model(dialogues_var, dialogue_lens, all_cats_var, all_spatial_var)
        _, preds = scores.data.cpu().max(1)
        # int() keeps the counter a plain Python integer whether .sum()
        # hands back a number or a 0-dim tensor.
        num_correct += int((preds == torch.LongTensor(correct_objs)).sum())
        num_samples += preds.size(0)

    # Guard against an empty loader so evaluation never divides by zero.
    acc = float(num_correct) / num_samples if num_samples else 0.0
    log_print(descriptor, 'Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    return acc
示例#2
0
def train(model, descriptor, loader_train, num_epochs, print_every=100):
    """Train ``model`` for ``num_epochs`` passes, saving its weights after each.

    Args:
        model: Network exposing ``train()``, ``train_step(...)`` and
            ``state_dict()``.
        descriptor: Identifier used for logging and for resolving the
            checkpoint path via ``data.get_saved_model``.
        loader_train: Iterable yielding
            ``(features, in_seqs, out_seqs, seq_masks)`` batches.
        num_epochs: Number of passes over ``loader_train``.
        print_every: Log the loss every ``print_every`` iterations.
    """
    for epoch_idx in range(num_epochs):
        tqdm.write('Starting epoch {} / {}'.format(epoch_idx + 1, num_epochs))
        model.train()

        for step, (features, in_seqs, out_seqs, seq_masks) in enumerate(loader_train):
            # Move every tensor of the batch to the GPU; none needs gradients.
            batch_vars = [
                Variable(tensor.cuda(), requires_grad=False)
                for tensor in (features, in_seqs, out_seqs, seq_masks)
            ]

            # NOTE(review): only the first two entries along the leading
            # dimension of each tensor reach train_step -- confirm this
            # truncation is intentional and not a debugging leftover.
            loss = model.train_step(*[var[:2] for var in batch_vars])

            if step % print_every != 0:
                continue
            log_print(descriptor, 't = {}, loss = {:.4}'.format(step + 1, loss.data[0]))

        # Overwrite the checkpoint with the weights from the end of this epoch.
        torch.save(model.state_dict(), data.get_saved_model(descriptor))
示例#3
0
def check_accuracy(model, descriptor, loader):
    """Evaluate ``model`` on every batch of ``loader`` and log the accuracy.

    Args:
        model: Network called as ``model(tokens, q_lens, features, cats)``.
            Its output is reduced with ``max(1)``, so dimension 1 is treated
            as the class axis.
        descriptor: Identifier forwarded to ``log_print``.
        loader: Iterable yielding
            ``(tokens, q_lens, features, cats, answers)`` batches.

    Returns:
        The fraction of correct predictions as a float, or ``0.0`` when the
        loader yields no samples (instead of raising ``ZeroDivisionError``).
    """
    num_correct = 0
    num_samples = 0
    model.eval()  # Put the model in test mode (the opposite of model.train())

    for tokens, q_lens, features, cats, answers in loader:
        tokens_var = Variable(tokens.cuda(), volatile=True)
        q_lens_var = Variable(q_lens.cuda(), volatile=True)
        features_var = Variable(features.cuda(), volatile=True)
        cats_var = Variable(cats.cuda(), volatile=True)

        scores = model(tokens_var, q_lens_var, features_var, cats_var)
        _, preds = scores.data.cpu().max(1)
        # int() keeps the counter a plain Python integer whether .sum()
        # hands back a number or a 0-dim tensor.
        num_correct += int((preds == answers).sum())
        num_samples += preds.size(0)

    # Guard against an empty loader so evaluation never divides by zero.
    acc = float(num_correct) / num_samples if num_samples else 0.0
    log_print(
        descriptor,
        'Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    return acc
示例#4
0
def train(model,
          descriptor,
          loader_valid_local,
          loader_train_local,
          loader_test_local,
          num_epochs,
          print_every=1000):
    """Train ``model``, checkpoint the best epoch, then evaluate that checkpoint.

    Validation accuracy is measured before training and after every epoch.
    The weights are saved whenever validation accuracy improves; the first
    epoch always saves, so the checkpoint reloaded afterwards comes from this
    run even if accuracy never rises above zero. Finally the checkpoint is
    loaded into a fresh ``OracleLiteNet`` and evaluated on the training,
    validation and test loaders.

    Args:
        model: Network exposing ``train()``, ``train_step(...)`` and
            ``state_dict()``.
        descriptor: Identifier used for logging and for resolving the
            checkpoint path via ``data.get_saved_model``.
        loader_valid_local: Validation batches.
        loader_train_local: Training batches of
            ``(tokens, q_lens, features, cats, answers)``.
        loader_test_local: Test batches.
        num_epochs: Number of passes over the training loader. With ``0`` no
            checkpoint is written and whatever already exists at the
            checkpoint path is loaded for the final evaluation.
        print_every: Log the loss every ``print_every`` iterations.
    """
    start_log(descriptor)
    log_print(descriptor, 'Getting accuracy on validation set')
    check_accuracy(model, descriptor, loader_valid_local)
    current_max_val_acc = 0
    checkpoint_saved = False
    for epoch in range(num_epochs):
        log_print(descriptor,
                  'Starting epoch {} / {}'.format(epoch + 1, num_epochs))
        model.train()

        # Wrap the loader (not the enumerate object) so tqdm can read the
        # loader's length and display a total.
        for t, (tokens, q_lens, features, cats,
                answers) in enumerate(tqdm(loader_train_local)):
            tokens_var = Variable(tokens.cuda(), requires_grad=False)
            q_lens_var = Variable(q_lens.cuda(), requires_grad=False)
            features_var = Variable(features.cuda(), requires_grad=False)
            cats_var = Variable(cats.cuda(), requires_grad=False)
            answers_var = Variable(answers.cuda(), requires_grad=False)

            loss = model.train_step(tokens_var, q_lens_var, features_var,
                                    cats_var, answers_var)

            if t % print_every == 0:
                log_print(descriptor,
                          't = {}, loss = {:.4}'.format(t + 1, loss.data[0]))

        log_print(descriptor, 'Getting accuracy on validation set')
        accuracy = check_accuracy(model, descriptor, loader_valid_local)
        # Always save after the first epoch; afterwards only on improvement.
        if not checkpoint_saved or accuracy > current_max_val_acc:
            current_max_val_acc = accuracy
            torch.save(model.state_dict(), data.get_saved_model(descriptor))
            checkpoint_saved = True

    # Evaluate the best checkpoint rather than the last-epoch weights.
    best_model = OracleLiteNet().cuda()
    best_model.load_state_dict(torch.load(data.get_saved_model(descriptor)))

    log_print(descriptor, 'Getting accuracy on training set')
    check_accuracy(best_model, descriptor, loader_train_local)
    log_print(descriptor, 'Getting accuracy on validation set')
    check_accuracy(best_model, descriptor, loader_valid_local)
    log_print(descriptor, 'Getting accuracy on test set')
    check_accuracy(best_model, descriptor, loader_test_local)
示例#5
0
    return pred_idx == correct_obj


# Identifier passed to log_print by the evaluation loop below.
descriptor = 'eval_questioner_reinforce_lstm1_fc2_sample'

# VocabTagger and GuessWhatAgents are defined outside this excerpt.
vocab_tagger = VocabTagger()
# Agents configured with the 'questioner_reinforce_lstm1_fc2' questioner.
agents = GuessWhatAgents(questioner='questioner_reinforce_lstm1_fc2')
# Forwarded to data.get_processed_file when each split's game file is opened.
# NOTE(review): presumably selects a reduced dataset -- confirm.
small = True

for split in ('train', 'valid'):
    for seen_obj in (True, False):
        if split == 'valid' and seen_obj == True:
            continue

        log_print(
            descriptor,
            'Playing game with images in: {} | seen objects: {}'.format(
                split, seen_obj))

        with open(data.get_processed_file('game', split, small), 'rb') as f:
            data_img_names, data_raw_objs, data_all_cats, data_all_spatial, data_correct_obj = pickle.load(
                f)

        num_total = len(data_img_names)
        num_correct = 0
        for i in range(num_total):
            if i % 100 == 0:
                print(i)

            num_correct += play_game(i, seen_obj)

        log_print(
示例#6
0
        
        if question_ids[0] == vocab_tagger.vocab_map.stop:
            dialogue_probs.append(probs[0])
            dialogue_outputs.append(outputs[0])
            break
            
        dialogue_probs.extend(probs)
        dialogue_outputs.extend(outputs)

        answer_id = game.answer(question_ids)
        answer = vocab_tagger.get_answer(answer_id)

    pred_idx = game.guess()
    reward = float(pred_idx == correct_obj)
    
    log_print(descriptor, 'i = {} | b = {} | r = {}'.format(i, baseline, reward))
    
    dialogue_log_probs = torch.log(torch.cat(dialogue_probs))
    adjusted_reward = reward - baseline
    
    J = torch.sum(torch.mul(dialogue_log_probs, adjusted_reward))
    
    for output in dialogue_outputs:
        output.reinforce(adjusted_reward)
    
    optimizer.zero_grad()
    J.backward()
    optimizer.step()
    
    baseline = BASELINE_ALPHA * reward + (1 - BASELINE_ALPHA) * baseline
示例#7
0
def train(model, descriptor, loader_valid_local, loader_train_local, loader_test_local, num_epochs, print_every=500):
    """Train ``model``, checkpoint the best epoch, then evaluate that checkpoint.

    Validation accuracy is measured before training and after every epoch.
    The weights are saved whenever validation accuracy improves; the first
    epoch always saves, so the checkpoint reloaded afterwards comes from this
    run even if accuracy never rises above zero. Finally the checkpoint is
    loaded into a fresh ``GuesserNet`` and evaluated on the training,
    validation and test loaders.

    Args:
        model: Network exposing ``train()``, ``train_step(...)`` and
            ``state_dict()``.
        descriptor: Identifier used for logging and for resolving the
            checkpoint path via ``data.get_saved_model``.
        loader_valid_local: Validation batches.
        loader_train_local: Training batches of
            ``(dialogues, all_cats, all_spatial, correct_objs)``.
        loader_test_local: Test batches.
        num_epochs: Number of passes over the training loader. With ``0`` no
            checkpoint is written and whatever already exists at the
            checkpoint path is loaded for the final evaluation.
        print_every: Log the loss every ``print_every`` iterations.
    """
    log_print(descriptor, 'Getting accuracy on validation set')
    check_accuracy(model, descriptor, loader_valid_local)
    current_max_val_acc = 0
    checkpoint_saved = False
    for epoch in range(num_epochs):
        log_print(descriptor, 'Starting epoch {} / {}'.format(epoch + 1, num_epochs))
        model.train()

        # Wrap the loader (not the enumerate object) so tqdm can read the
        # loader's length and display a total.
        for t, (dialogues, all_cats, all_spatial, correct_objs) in enumerate(tqdm(loader_train_local)):
            dialogues_var, all_cats_var, all_spatial_var, correct_objs_var, dialogue_lens = \
                make_vars(dialogues, all_cats, all_spatial, correct_objs, requires_grad=False)

            loss = model.train_step(
                dialogues_var, dialogue_lens, all_cats_var, all_spatial_var, correct_objs_var
            )

            if t % print_every == 0:
                log_print(descriptor, 't = {}, loss = {:.4}'.format(t + 1, loss.data[0]))

        log_print(descriptor, 'Getting accuracy on validation set')
        accuracy = check_accuracy(model, descriptor, loader_valid_local)
        # Always save after the first epoch; afterwards only on improvement.
        if not checkpoint_saved or accuracy > current_max_val_acc:
            current_max_val_acc = accuracy
            torch.save(model.state_dict(), data.get_saved_model(descriptor))
            checkpoint_saved = True
        print(current_max_val_acc)

    # Evaluate the best checkpoint rather than the last-epoch weights.
    best_model = GuesserNet().cuda()
    best_model.load_state_dict(torch.load(data.get_saved_model(descriptor)))

    print(current_max_val_acc)
    log_print(descriptor, 'Getting accuracy on training set')
    check_accuracy(best_model, descriptor, loader_train_local)
    log_print(descriptor, 'Getting accuracy on validation set')
    check_accuracy(best_model, descriptor, loader_valid_local)
    log_print(descriptor, 'Getting accuracy on test set')
    check_accuracy(best_model, descriptor, loader_test_local)