Example #1
def evaluate(args):
    with open(args.data, 'rb') as f:
        test_dataset: SNLIDataset = pickle.load(f)
    word_vocab = test_dataset.word_vocab
    label_vocab = test_dataset.label_vocab
    model = SNLIModel(num_classes=len(label_vocab),
                      num_words=len(word_vocab),
                      word_dim=args.word_dim,
                      hidden_dim=args.hidden_dim,
                      clf_hidden_dim=args.clf_hidden_dim,
                      clf_num_layers=args.clf_num_layers,
                      use_leaf_rnn=args.leaf_rnn,
                      intra_attention=args.intra_attention,
                      use_batchnorm=args.batchnorm,
                      dropout_prob=args.dropout)
    num_params = sum(np.prod(p.size()) for p in model.parameters())
    num_embedding_params = np.prod(model.word_embedding.weight.size())
    print(f'# of parameters: {num_params}')
    print(f'# of word embedding parameters: {num_embedding_params}')
    print(f'# of parameters (excluding word embeddings): '
          f'{num_params - num_embedding_params}')
    model.load_state_dict(torch.load(args.model))
    model.eval()
    if args.gpu > -1:
        model.cuda(args.gpu)
    test_data_loader = DataLoader(dataset=test_dataset,
                                  batch_size=args.batch_size,
                                  collate_fn=test_dataset.collate)
    num_correct = 0
    num_data = len(test_dataset)
    for batch in test_data_loader:
        pre = wrap_with_variable(batch['pre'], volatile=True, gpu=args.gpu)
        hyp = wrap_with_variable(batch['hyp'], volatile=True, gpu=args.gpu)
        pre_length = wrap_with_variable(batch['pre_length'],
                                        volatile=True,
                                        gpu=args.gpu)
        hyp_length = wrap_with_variable(batch['hyp_length'],
                                        volatile=True,
                                        gpu=args.gpu)
        label = wrap_with_variable(batch['label'], volatile=True, gpu=args.gpu)
        logits = model(pre=pre,
                       pre_length=pre_length,
                       hyp=hyp,
                       hyp_length=hyp_length)
        label_pred = logits.max(1)[1].squeeze(1)
        num_correct_batch = torch.eq(label, label_pred).long().sum()
        num_correct_batch = unwrap_scalar_variable(num_correct_batch)
        num_correct += num_correct_batch
    print(f'# data: {num_data}')
    print(f'# correct: {num_correct}')
    print(f'Accuracy: {num_correct / num_data:.4f}')
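Examples #1 and #3 target the pre-0.4 PyTorch Variable API and lean on two helpers, wrap_with_variable and unwrap_scalar_variable, that are not shown on this page. A minimal sketch of what they might look like, inferred from the call sites (the exact signatures are assumptions):

from torch.autograd import Variable


def wrap_with_variable(tensor, volatile, gpu):
    # Move the tensor to the given GPU (gpu == -1 means stay on CPU) and wrap
    # it in a Variable; volatile=True disables graph construction at inference.
    if gpu > -1:
        tensor = tensor.cuda(gpu)
    return Variable(tensor, volatile=volatile)


def unwrap_scalar_variable(var):
    # Pull the Python number out of a one-element Variable; pass plain
    # numbers through unchanged.
    if isinstance(var, Variable):
        return var.data[0]
    return var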
Example #2
def evaluate(args):
    with open(args.data, 'rb') as f:
        test_dataset: SNLIDataset = pickle.load(f)
    word_vocab = test_dataset.word_vocab
    label_vocab = test_dataset.label_vocab
    model = SNLIModel(num_classes=len(label_vocab),
                      num_words=len(word_vocab),
                      word_dim=args.word_dim,
                      hidden_dim=args.hidden_dim,
                      clf_hidden_dim=args.clf_hidden_dim,
                      clf_num_layers=args.clf_num_layers,
                      use_leaf_rnn=args.leaf_rnn,
                      intra_attention=args.intra_attention,
                      use_batchnorm=args.batchnorm,
                      dropout_prob=args.dropout,
                      bidirectional=args.bidirectional)
    num_params = sum(np.prod(p.size()) for p in model.parameters())
    num_embedding_params = np.prod(model.word_embedding.weight.size())
    print(f'# of parameters: {num_params}')
    print(f'# of word embedding parameters: {num_embedding_params}')
    print(f'# of parameters (excluding word embeddings): '
          f'{num_params - num_embedding_params}')
    model.load_state_dict(torch.load(args.model, map_location='cpu'))
    model.eval()
    model.to(args.device)
    torch.set_grad_enabled(False)
    test_data_loader = DataLoader(dataset=test_dataset,
                                  batch_size=args.batch_size,
                                  collate_fn=test_dataset.collate)
    num_correct = 0
    num_data = len(test_dataset)
    for batch in test_data_loader:
        pre = batch['pre'].to(args.device)
        hyp = batch['hyp'].to(args.device)
        pre_length = batch['pre_length'].to(args.device)
        hyp_length = batch['hyp_length'].to(args.device)
        label = batch['label'].to(args.device)
        logits = model(pre=pre,
                       pre_length=pre_length,
                       hyp=hyp,
                       hyp_length=hyp_length)
        label_pred = logits.max(1)[1]
        num_correct_batch = torch.eq(label, label_pred).long().sum()
        num_correct_batch = num_correct_batch.item()
        num_correct += num_correct_batch
    print(f'# data: {num_data}')
    print(f'# correct: {num_correct}')
    print(f'Accuracy: {num_correct / num_data:.4f}')
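Example #2 reads every hyperparameter from an args namespace. A hypothetical argparse front end that would drive the evaluate function above; the flag names mirror the attributes the code accesses, while the defaults are purely illustrative:

import argparse

import torch


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', required=True, help='pickled SNLIDataset')
    parser.add_argument('--model', required=True, help='saved state_dict path')
    parser.add_argument('--word-dim', type=int, default=300)
    parser.add_argument('--hidden-dim', type=int, default=300)
    parser.add_argument('--clf-hidden-dim', type=int, default=1024)
    parser.add_argument('--clf-num-layers', type=int, default=1)
    parser.add_argument('--leaf-rnn', action='store_true')
    parser.add_argument('--intra-attention', action='store_true')
    parser.add_argument('--batchnorm', action='store_true')
    parser.add_argument('--bidirectional', action='store_true')
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--device',
                        default='cuda' if torch.cuda.is_available() else 'cpu')
    # argparse converts dashes to underscores, so '--word-dim' becomes
    # args.word_dim, matching the attribute names used in evaluate().
    return parser.parse_args()


if __name__ == '__main__':
    evaluate(parse_args())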
Example #3
def train(args):
    with open(args.train_data, 'rb') as f:
        train_dataset: SNLIDataset = pickle.load(f)
    with open(args.valid_data, 'rb') as f:
        valid_dataset: SNLIDataset = pickle.load(f)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=2,
                              collate_fn=train_dataset.collate,
                              pin_memory=True)
    valid_loader = DataLoader(dataset=valid_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=2,
                              collate_fn=valid_dataset.collate,
                              pin_memory=True)
    word_vocab = train_dataset.word_vocab
    label_vocab = train_dataset.label_vocab

    model = SNLIModel(num_classes=len(label_vocab),
                      num_words=len(word_vocab),
                      word_dim=args.word_dim,
                      hidden_dim=args.hidden_dim,
                      clf_hidden_dim=args.clf_hidden_dim,
                      clf_num_layers=args.clf_num_layers,
                      use_leaf_rnn=args.leaf_rnn,
                      use_batchnorm=args.batchnorm,
                      intra_attention=args.intra_attention,
                      dropout_prob=args.dropout)
    if args.glove:
        logging.info('Loading GloVe pretrained vectors...')
        model.word_embedding.weight.data.zero_()
        glove_weight = load_glove(
            path=args.glove,
            vocab=word_vocab,
            init_weight=model.word_embedding.weight.data.numpy())
        glove_weight[word_vocab.pad_id] = 0
        model.word_embedding.weight.data.set_(torch.FloatTensor(glove_weight))
    if args.fix_word_embedding:
        logging.info('Will not update word embeddings')
        model.word_embedding.weight.requires_grad = False
    if args.gpu > -1:
        logging.info(f'Using GPU {args.gpu}')
        model.cuda(args.gpu)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(params=params)
    criterion = nn.CrossEntropyLoss()

    train_summary_writer = tensorboard.FileWriter(
        logdir=os.path.join(args.save_dir, 'log', 'train'), flush_secs=10)
    valid_summary_writer = tensorboard.FileWriter(
        logdir=os.path.join(args.save_dir, 'log', 'valid'), flush_secs=10)

    def run_iter(batch, is_training):
        model.train(is_training)
        pre = wrap_with_variable(batch['pre'],
                                 volatile=not is_training,
                                 gpu=args.gpu)
        hyp = wrap_with_variable(batch['hyp'],
                                 volatile=not is_training,
                                 gpu=args.gpu)
        pre_length = wrap_with_variable(batch['pre_length'],
                                        volatile=not is_training,
                                        gpu=args.gpu)
        hyp_length = wrap_with_variable(batch['hyp_length'],
                                        volatile=not is_training,
                                        gpu=args.gpu)
        label = wrap_with_variable(batch['label'],
                                   volatile=not is_training,
                                   gpu=args.gpu)
        logits = model(pre=pre,
                       pre_length=pre_length,
                       hyp=hyp,
                       hyp_length=hyp_length)
        label_pred = logits.max(1)[1]
        accuracy = torch.eq(label, label_pred).float().mean()
        loss = criterion(input=logits, target=label)
        if is_training:
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm(parameters=params, max_norm=5)
            optimizer.step()
        return loss, accuracy

    def add_scalar_summary(summary_writer, name, value, step):
        value = unwrap_scalar_variable(value)
        summ = summary.scalar(name=name, scalar=value)
        summary_writer.add_summary(summary=summ, global_step=step)

    num_train_batches = len(train_loader)
    validate_every = num_train_batches // 10
    best_valid_accuracy = 0
    iter_count = 0
    for epoch_num in range(1, args.max_epoch + 1):
        logging.info(f'Epoch {epoch_num}: start')
        for batch_iter, train_batch in enumerate(train_loader):
            if args.anneal_temperature and iter_count % 500 == 0:
                gamma = 0.00001
                new_temperature = max([0.5, math.exp(-gamma * iter_count)])
                model.encoder.gumbel_temperature = new_temperature
                logging.info(
                    f'Iter #{iter_count}: '
                    f'Set Gumbel temperature to {new_temperature:.4f}')
            train_loss, train_accuracy = run_iter(batch=train_batch,
                                                  is_training=True)
            iter_count += 1
            add_scalar_summary(summary_writer=train_summary_writer,
                               name='loss',
                               value=train_loss,
                               step=iter_count)
            add_scalar_summary(summary_writer=train_summary_writer,
                               name='accuracy',
                               value=train_accuracy,
                               step=iter_count)

            if (batch_iter + 1) % validate_every == 0:
                valid_loss_sum = valid_accuracy_sum = 0
                num_valid_batches = len(valid_loader)
                for valid_batch in valid_loader:
                    valid_loss, valid_accuracy = run_iter(batch=valid_batch,
                                                          is_training=False)
                    valid_loss_sum += unwrap_scalar_variable(valid_loss)
                    valid_accuracy_sum += unwrap_scalar_variable(
                        valid_accuracy)
                valid_loss = valid_loss_sum / num_valid_batches
                valid_accuracy = valid_accuracy_sum / num_valid_batches
                add_scalar_summary(summary_writer=valid_summary_writer,
                                   name='loss',
                                   value=valid_loss,
                                   step=iter_count)
                add_scalar_summary(summary_writer=valid_summary_writer,
                                   name='accuracy',
                                   value=valid_accuracy,
                                   step=iter_count)
                progress = epoch_num + batch_iter / num_train_batches
                logging.info(f'Epoch {progress:.2f}: '
                             f'valid loss = {valid_loss:.4f}, '
                             f'valid accuracy = {valid_accuracy:.4f}')
                if valid_accuracy > best_valid_accuracy:
                    best_valid_accuracy = valid_accuracy
                    model_filename = (f'model-{progress:.2f}'
                                      f'-{valid_loss:.4f}'
                                      f'-{valid_accuracy:.4f}.pkl')
                    model_path = os.path.join(args.save_dir, model_filename)
                    torch.save(model.state_dict(), model_path)
                    print(f'Saved the new best model to {model_path}')
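Examples #3 and #5 also call a load_glove helper that is not shown here. A rough sketch of such a loader, assuming a plain-text GloVe file of "word v1 v2 ..." lines; the word_to_id lookup below is an assumption about the vocab object (only pad_id appears in the original code):

import numpy as np


def load_glove(path, vocab, init_weight):
    # Start from the model's initial embedding matrix and overwrite the rows
    # of words that have a pretrained GloVe vector.
    weight = init_weight.copy()
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            word, *values = line.rstrip().split(' ')
            index = vocab.word_to_id(word)  # assumed vocab API
            if index is not None:
                weight[index] = np.asarray(values, dtype=np.float32)
    return weight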
Example #4
def main():
    args = parse_args()
    params = Params.from_file(args.params)
    save_dir = Path(args.save)
    save_dir.mkdir(parents=True)

    params.to_file(save_dir / 'params.json')

    train_params, model_params = params.pop('train'), params.pop('model')

    random_seed = train_params.pop_int('random_seed', 2019)
    torch.manual_seed(random_seed)
    random.seed(random_seed)

    log_filename = save_dir / 'stdout.log'
    sys.stdout = TeeLogger(filename=log_filename,
                           terminal=sys.stdout,
                           file_friendly_terminal_output=False)
    sys.stderr = TeeLogger(filename=log_filename,
                           terminal=sys.stderr,
                           file_friendly_terminal_output=False)

    tokenizer = WordTokenizer(
        start_tokens=['<s>'],
        end_tokens=['</s>'],
    )
    token_indexer = SingleIdTokenIndexer(lowercase_tokens=True)
    dataset_reader = SnliReader(tokenizer=tokenizer,
                                token_indexers={'tokens': token_indexer})

    train_labeled_dataset_path = train_params.pop('train_labeled_dataset_path')
    train_unlabeled_dataset_path = train_params.pop(
        'train_unlabeled_dataset_path', None)
    train_labeled_dataset = dataset_reader.read(train_labeled_dataset_path)
    train_labeled_dataset = filter_dataset_by_length(
        dataset=train_labeled_dataset, max_length=30)
    if train_unlabeled_dataset_path is not None:
        train_unlabeled_dataset = dataset_reader.read(
            train_unlabeled_dataset_path)
        train_unlabeled_dataset = filter_dataset_by_length(
            dataset=train_unlabeled_dataset, max_length=30)
    else:
        train_unlabeled_dataset = []

    valid_dataset = dataset_reader.read(train_params.pop('valid_dataset_path'))

    vocab = Vocabulary.from_instances(
        instances=train_labeled_dataset + train_unlabeled_dataset,
        max_vocab_size=train_params.pop_int('max_vocab_size', None))
    vocab.save_to_files(save_dir / 'vocab')

    labeled_batch_size = train_params.pop_int('labeled_batch_size')
    unlabeled_batch_size = train_params.pop_int('unlabeled_batch_size')
    labeled_iterator = BasicIterator(batch_size=labeled_batch_size)
    unlabeled_iterator = BasicIterator(batch_size=unlabeled_batch_size)
    labeled_iterator.index_with(vocab)
    unlabeled_iterator.index_with(vocab)

    if not train_unlabeled_dataset:
        unlabeled_iterator = None

    model = SNLIModel(params=model_params, vocab=vocab)
    optimizer = optim.Adam(params=model.parameters(),
                           lr=train_params.pop_float('lr', 1e-3))
    summary_writer = SummaryWriter(log_dir=save_dir / 'log')

    kl_anneal_rate = train_params.pop_float('kl_anneal_rate', None)
    if kl_anneal_rate is None:
        kl_weight_scheduler = None
    else:
        kl_weight_scheduler = (lambda step: min(1.0, kl_anneal_rate * step))
        model.kl_weight = 0.0

    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      labeled_iterator=labeled_iterator,
                      unlabeled_iterator=unlabeled_iterator,
                      train_labeled_dataset=train_labeled_dataset,
                      train_unlabeled_dataset=train_unlabeled_dataset,
                      validation_dataset=valid_dataset,
                      summary_writer=summary_writer,
                      serialization_dir=save_dir,
                      num_epochs=train_params.pop('num_epochs', 50),
                      iters_per_epoch=(len(train_labeled_dataset) //
                                       labeled_batch_size),
                      write_summary_every=100,
                      validate_every=2000,
                      patience=2,
                      clip_grad_max_norm=5,
                      kl_weight_scheduler=kl_weight_scheduler,
                      cuda_device=train_params.pop_int('cuda_device', 0),
                      early_stop=train_params.pop_bool('early_stop', True))
    trainer.train()
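main() trims both SNLI datasets with a filter_dataset_by_length helper that is not shown on this page. A plausible sketch, assuming AllenNLP instances whose 'premise' and 'hypothesis' TextFields come from SnliReader (the helper itself is hypothetical):

def filter_dataset_by_length(dataset, max_length):
    # Keep only instances whose premise and hypothesis are both at most
    # max_length tokens long.
    return [
        instance for instance in dataset
        if len(instance.fields['premise'].tokens) <= max_length
        and len(instance.fields['hypothesis'].tokens) <= max_length
    ]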
Example #5
def train(args):
    with open(args.train_data, 'rb') as f:
        train_dataset: SNLIDataset = pickle.load(f)
    with open(args.valid_data, 'rb') as f:
        valid_dataset: SNLIDataset = pickle.load(f)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=2,
                              collate_fn=train_dataset.collate,
                              pin_memory=True)
    valid_loader = DataLoader(dataset=valid_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=2,
                              collate_fn=valid_dataset.collate,
                              pin_memory=True)
    word_vocab = train_dataset.word_vocab
    label_vocab = train_dataset.label_vocab

    model = SNLIModel(num_classes=len(label_vocab),
                      num_words=len(word_vocab),
                      word_dim=args.word_dim,
                      hidden_dim=args.hidden_dim,
                      clf_hidden_dim=args.clf_hidden_dim,
                      clf_num_layers=args.clf_num_layers,
                      use_leaf_rnn=args.leaf_rnn,
                      use_batchnorm=args.batchnorm,
                      intra_attention=args.intra_attention,
                      dropout_prob=args.dropout,
                      bidirectional=args.bidirectional)
    if args.glove:
        logging.info('Loading GloVe pretrained vectors...')
        glove_weight = load_glove(
            path=args.glove,
            vocab=word_vocab,
            init_weight=model.word_embedding.weight.data.numpy())
        glove_weight[word_vocab.pad_id] = 0
        model.word_embedding.weight.data.set_(torch.FloatTensor(glove_weight))
    if args.fix_word_embedding:
        logging.info('Will not update word embeddings')
        model.word_embedding.weight.requires_grad = False
    model.to(args.device)
    logging.info(f'Using device {args.device}')
    if args.optimizer == 'adam':
        optimizer_class = optim.Adam
    elif args.optimizer == 'adagrad':
        optimizer_class = optim.Adagrad
    elif args.optimizer == 'adadelta':
        optimizer_class = optim.Adadelta
    else:
        raise ValueError(f'Unknown optimizer: {args.optimizer}')
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optimizer_class(params=params, weight_decay=args.l2reg)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                               mode='max',
                                               factor=0.5,
                                               patience=10,
                                               verbose=True)
    criterion = nn.CrossEntropyLoss()

    train_summary_writer = SummaryWriter(
        log_dir=os.path.join(args.save_dir, 'log', 'train'))
    valid_summary_writer = SummaryWriter(
        log_dir=os.path.join(args.save_dir, 'log', 'valid'))

    def run_iter(batch, is_training):
        model.train(is_training)
        pre = batch['pre'].to(args.device)
        hyp = batch['hyp'].to(args.device)
        pre_length = batch['pre_length'].to(args.device)
        hyp_length = batch['hyp_length'].to(args.device)
        label = batch['label'].to(args.device)
        logits = model(pre=pre,
                       pre_length=pre_length,
                       hyp=hyp,
                       hyp_length=hyp_length)
        label_pred = logits.max(1)[1]
        accuracy = torch.eq(label, label_pred).float().mean()
        loss = criterion(input=logits, target=label)
        if is_training:
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(parameters=params, max_norm=5)
            optimizer.step()
        return loss, accuracy

    def add_scalar_summary(summary_writer, name, value, step):
        if torch.is_tensor(value):
            value = value.item()
        summary_writer.add_scalar(tag=name,
                                  scalar_value=value,
                                  global_step=step)

    num_train_batches = len(train_loader)
    validate_every = num_train_batches // 10
    best_valid_accuracy = 0
    iter_count = 0
    for epoch_num in range(args.max_epoch):
        logging.info(f'Epoch {epoch_num}: start')
        for batch_iter, train_batch in enumerate(train_loader):
            if iter_count % args.anneal_temperature_every == 0:
                rate = args.anneal_temperature_rate
                new_temperature = max([0.5, math.exp(-rate * iter_count)])
                model.encoder.gumbel_temperature = new_temperature
                logging.info(
                    f'Iter #{iter_count}: '
                    f'Set Gumbel temperature to {new_temperature:.4f}')
            train_loss, train_accuracy = run_iter(batch=train_batch,
                                                  is_training=True)
            iter_count += 1
            add_scalar_summary(summary_writer=train_summary_writer,
                               name='loss',
                               value=train_loss,
                               step=iter_count)
            add_scalar_summary(summary_writer=train_summary_writer,
                               name='accuracy',
                               value=train_accuracy,
                               step=iter_count)

            if (batch_iter + 1) % validate_every == 0:
                torch.set_grad_enabled(False)
                valid_loss_sum = valid_accuracy_sum = 0
                num_valid_batches = len(valid_loader)
                for valid_batch in valid_loader:
                    valid_loss, valid_accuracy = run_iter(batch=valid_batch,
                                                          is_training=False)
                    valid_loss_sum += valid_loss.item()
                    valid_accuracy_sum += valid_accuracy.item()
                torch.set_grad_enabled(True)
                valid_loss = valid_loss_sum / num_valid_batches
                valid_accuracy = valid_accuracy_sum / num_valid_batches
                scheduler.step(valid_accuracy)
                add_scalar_summary(summary_writer=valid_summary_writer,
                                   name='loss',
                                   value=valid_loss,
                                   step=iter_count)
                add_scalar_summary(summary_writer=valid_summary_writer,
                                   name='accuracy',
                                   value=valid_accuracy,
                                   step=iter_count)
                progress = epoch_num + batch_iter / num_train_batches
                logging.info(f'Epoch {progress:.2f}: '
                             f'valid loss = {valid_loss:.4f}, '
                             f'valid accuracy = {valid_accuracy:.4f}')
                if valid_accuracy > best_valid_accuracy:
                    best_valid_accuracy = valid_accuracy
                    model_filename = (f'model-{progress:.2f}'
                                      f'-{valid_loss:.4f}'
                                      f'-{valid_accuracy:.4f}.pkl')
                    model_path = os.path.join(args.save_dir, model_filename)
                    torch.save(model.state_dict(), model_path)
                    print(f'Saved the new best model to {model_path}')
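The Gumbel temperature annealing in Examples #3 and #5 follows max(0.5, exp(-rate * iter_count)). A standalone check of that schedule; the rate matches the 0.00001 used in Example #3, and the step values are only illustrative:

import math


def gumbel_temperature(iter_count, rate=1e-5):
    # Same rule as the training loops above: exponential decay in the
    # iteration count, floored at 0.5.
    return max(0.5, math.exp(-rate * iter_count))


for step in (0, 10000, 50000, 100000):
    print(step, round(gumbel_temperature(step), 4))
# 0 -> 1.0, 10000 -> 0.9048, 50000 -> 0.6065, 100000 -> 0.5 (floor reached)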