예제 #1
0
def evaluate(args):
    """Evaluate a saved SNLI model on a pickled dataset and print its accuracy.

    Also prints the total number of model parameters, the number of
    word-embedding parameters, and the number excluding word embeddings.

    Args:
        args: Parsed command-line namespace. Reads ``data`` (path to a pickled
            ``SNLIDataset``), ``model`` (path to a saved ``state_dict``),
            ``gpu`` (device index, or a negative value for CPU),
            ``batch_size`` and the model hyper-parameters ``word_dim``,
            ``hidden_dim``, ``clf_hidden_dim``, ``clf_num_layers``,
            ``leaf_rnn``, ``intra_attention``, ``batchnorm`` and ``dropout``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load dataset
    # files from trusted sources.
    with open(args.data, 'rb') as f:
        test_dataset: SNLIDataset = pickle.load(f)
    word_vocab = test_dataset.word_vocab
    label_vocab = test_dataset.label_vocab
    model = SNLIModel(num_classes=len(label_vocab),
                      num_words=len(word_vocab),
                      word_dim=args.word_dim,
                      hidden_dim=args.hidden_dim,
                      clf_hidden_dim=args.clf_hidden_dim,
                      clf_num_layers=args.clf_num_layers,
                      use_leaf_rnn=args.leaf_rnn,
                      intra_attention=args.intra_attention,
                      use_batchnorm=args.batchnorm,
                      dropout_prob=args.dropout)
    num_params = sum(np.prod(p.size()) for p in model.parameters())
    num_embedding_params = np.prod(model.word_embedding.weight.size())
    print(f'# of parameters: {num_params}')
    print(f'# of word embedding parameters: {num_embedding_params}')
    print(f'# of parameters (excluding word embeddings): '
          f'{num_params - num_embedding_params}')
    # Map every stored tensor onto the CPU first so that a checkpoint saved
    # from a GPU run can also be evaluated on a CPU-only machine; the model is
    # moved to the requested GPU below if one was asked for.
    state_dict = torch.load(args.model,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    model.eval()
    if args.gpu > -1:
        model.cuda(args.gpu)
    test_data_loader = DataLoader(dataset=test_dataset,
                                  batch_size=args.batch_size,
                                  collate_fn=test_dataset.collate)
    num_correct = 0
    num_data = len(test_dataset)
    for batch in test_data_loader:
        pre = wrap_with_variable(batch['pre'], volatile=True, gpu=args.gpu)
        hyp = wrap_with_variable(batch['hyp'], volatile=True, gpu=args.gpu)
        pre_length = wrap_with_variable(batch['pre_length'],
                                        volatile=True,
                                        gpu=args.gpu)
        hyp_length = wrap_with_variable(batch['hyp_length'],
                                        volatile=True,
                                        gpu=args.gpu)
        label = wrap_with_variable(batch['label'], volatile=True, gpu=args.gpu)
        logits = model(pre=pre,
                       pre_length=pre_length,
                       hyp=hyp,
                       hyp_length=hyp_length)
        # ``max(1)`` returns (values, indices); the indices are the predicted
        # label ids. ``view(-1)`` makes them 1-D regardless of whether this
        # PyTorch version keeps the reduced dimension (``squeeze(1)`` fails
        # on versions where it is already dropped).
        label_pred = logits.max(1)[1].view(-1)
        num_correct_batch = torch.eq(label, label_pred).long().sum()
        num_correct_batch = unwrap_scalar_variable(num_correct_batch)
        num_correct += num_correct_batch
    print(f'# data: {num_data}')
    print(f'# correct: {num_correct}')
    print(f'Accuracy: {num_correct / num_data:.4f}')
예제 #2
0
def train(args):
    """Train an SNLI model, validating periodically and checkpointing the best.

    Validation runs every ``len(train_loader) // 10`` training batches (at
    least every batch). Training and validation loss/accuracy are written as
    TensorBoard scalars under ``<save_dir>/log/train`` and
    ``<save_dir>/log/valid``. Whenever validation accuracy improves, the
    model's ``state_dict`` is saved to
    ``<save_dir>/model-<progress>-<valid_loss>-<valid_accuracy>.pkl``.

    Args:
        args: Parsed command-line namespace. Reads ``train_data`` and
            ``valid_data`` (paths to pickled ``SNLIDataset`` objects),
            ``batch_size``, ``gpu`` (device index, or a negative value for
            CPU), ``glove`` (optional path to pretrained vectors),
            ``fix_word_embedding``, ``anneal_temperature``, ``max_epoch``,
            ``save_dir`` and the model hyper-parameters ``word_dim``,
            ``hidden_dim``, ``clf_hidden_dim``, ``clf_num_layers``,
            ``leaf_rnn``, ``batchnorm``, ``intra_attention`` and ``dropout``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load dataset
    # files from trusted sources.
    with open(args.train_data, 'rb') as f:
        train_dataset: SNLIDataset = pickle.load(f)
    with open(args.valid_data, 'rb') as f:
        valid_dataset: SNLIDataset = pickle.load(f)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=2,
                              collate_fn=train_dataset.collate,
                              pin_memory=True)
    valid_loader = DataLoader(dataset=valid_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=2,
                              collate_fn=valid_dataset.collate,
                              pin_memory=True)
    word_vocab = train_dataset.word_vocab
    label_vocab = train_dataset.label_vocab

    model = SNLIModel(num_classes=len(label_vocab),
                      num_words=len(word_vocab),
                      word_dim=args.word_dim,
                      hidden_dim=args.hidden_dim,
                      clf_hidden_dim=args.clf_hidden_dim,
                      clf_num_layers=args.clf_num_layers,
                      use_leaf_rnn=args.leaf_rnn,
                      use_batchnorm=args.batchnorm,
                      intra_attention=args.intra_attention,
                      dropout_prob=args.dropout)
    if args.glove:
        logging.info('Loading GloVe pretrained vectors...')
        # Start from an all-zero matrix; load_glove presumably overwrites the
        # rows of words it finds and keeps init_weight elsewhere -- TODO
        # confirm against load_glove.
        model.word_embedding.weight.data.zero_()
        glove_weight = load_glove(
            path=args.glove,
            vocab=word_vocab,
            init_weight=model.word_embedding.weight.data.numpy())
        # The padding token always gets an all-zero embedding.
        glove_weight[word_vocab.pad_id] = 0
        model.word_embedding.weight.data.set_(torch.FloatTensor(glove_weight))
    if args.fix_word_embedding:
        logging.info('Will not update word embeddings')
        model.word_embedding.weight.requires_grad = False
    if args.gpu > -1:
        logging.info(f'Using GPU {args.gpu}')
        model.cuda(args.gpu)
    # Only trainable parameters go to the optimizer and to gradient clipping,
    # so a frozen embedding matrix is excluded from both.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(params=params)
    criterion = nn.CrossEntropyLoss()

    train_summary_writer = tensorboard.FileWriter(
        logdir=os.path.join(args.save_dir, 'log', 'train'), flush_secs=10)
    valid_summary_writer = tensorboard.FileWriter(
        logdir=os.path.join(args.save_dir, 'log', 'valid'), flush_secs=10)

    def run_iter(batch, is_training):
        """Run one forward pass and, when training, one optimizer step.

        Returns:
            ``(loss, accuracy)`` for the batch: the criterion's loss and the
            fraction of correctly predicted labels.
        """
        model.train(is_training)
        pre, hyp, pre_length, hyp_length, label = (
            wrap_with_variable(batch[key],
                               volatile=not is_training,
                               gpu=args.gpu)
            for key in ('pre', 'hyp', 'pre_length', 'hyp_length', 'label'))
        logits = model(pre=pre,
                       pre_length=pre_length,
                       hyp=hyp,
                       hyp_length=hyp_length)
        # Predicted label id per example. ``view(-1)`` keeps it 1-D even on
        # PyTorch versions whose ``max`` keeps the reduced dimension, so
        # ``torch.eq`` never broadcasts it against ``label``.
        label_pred = logits.max(1)[1].view(-1)
        accuracy = torch.eq(label, label_pred).float().mean()
        loss = criterion(input=logits, target=label)
        if is_training:
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm(parameters=params, max_norm=5)
            optimizer.step()
        return loss, accuracy

    def add_scalar_summary(summary_writer, name, value, step):
        """Write ``value`` as the TensorBoard scalar ``name`` at ``step``."""
        value = unwrap_scalar_variable(value)
        summ = summary.scalar(name=name, scalar=value)
        summary_writer.add_summary(summary=summ, global_step=step)

    def validate():
        """Return ``(loss, accuracy)`` averaged over the validation batches.

        This is the mean of per-batch means, so a smaller final batch is
        weighted the same as a full one.
        """
        num_valid_batches = len(valid_loader)
        loss_sum = accuracy_sum = 0
        for valid_batch in valid_loader:
            loss, accuracy = run_iter(batch=valid_batch, is_training=False)
            loss_sum += unwrap_scalar_variable(loss)
            accuracy_sum += unwrap_scalar_variable(accuracy)
        return loss_sum / num_valid_batches, accuracy_sum / num_valid_batches

    num_train_batches = len(train_loader)
    # Validate every tenth of an epoch. Clamp to 1 so that a training set
    # with fewer than ten batches does not cause a modulo-by-zero below.
    validate_every = max(1, num_train_batches // 10)
    best_valid_accuracy = 0
    iter_count = 0
    try:
        for epoch_num in range(1, args.max_epoch + 1):
            logging.info(f'Epoch {epoch_num}: start')
            for batch_iter, train_batch in enumerate(train_loader):
                if args.anneal_temperature and iter_count % 500 == 0:
                    # Exponential decay of the Gumbel temperature, floored
                    # at 0.5, updated every 500 iterations.
                    gamma = 0.00001
                    new_temperature = max([0.5, math.exp(-gamma * iter_count)])
                    model.encoder.gumbel_temperature = new_temperature
                    logging.info(
                        f'Iter #{iter_count}: '
                        f'Set Gumbel temperature to {new_temperature:.4f}')
                train_loss, train_accuracy = run_iter(batch=train_batch,
                                                      is_training=True)
                iter_count += 1
                add_scalar_summary(summary_writer=train_summary_writer,
                                   name='loss',
                                   value=train_loss,
                                   step=iter_count)
                add_scalar_summary(summary_writer=train_summary_writer,
                                   name='accuracy',
                                   value=train_accuracy,
                                   step=iter_count)

                if (batch_iter + 1) % validate_every != 0:
                    continue
                valid_loss, valid_accuracy = validate()
                add_scalar_summary(summary_writer=valid_summary_writer,
                                   name='loss',
                                   value=valid_loss,
                                   step=iter_count)
                add_scalar_summary(summary_writer=valid_summary_writer,
                                   name='accuracy',
                                   value=valid_accuracy,
                                   step=iter_count)
                # Fractional epoch, used in the log line and the filename.
                progress = epoch_num + batch_iter / num_train_batches
                logging.info(f'Epoch {progress:.2f}: '
                             f'valid loss = {valid_loss:.4f}, '
                             f'valid accuracy = {valid_accuracy:.4f}')
                if valid_accuracy > best_valid_accuracy:
                    best_valid_accuracy = valid_accuracy
                    model_filename = (f'model-{progress:.2f}'
                                      f'-{valid_loss:.4f}'
                                      f'-{valid_accuracy:.4f}.pkl')
                    model_path = os.path.join(args.save_dir, model_filename)
                    torch.save(model.state_dict(), model_path)
                    logging.info(f'Saved the new best model to {model_path}')
    finally:
        # Close the writers so summaries written since the last periodic
        # flush are not lost, including when training is interrupted.
        train_summary_writer.close()
        valid_summary_writer.close()