Ejemplo n.º 1
0
def train(args):
    """Train an ``SNLIModel`` on pickled SNLI datasets (legacy PyTorch API).

    Loads the training and validation ``SNLIDataset`` objects from the pickle
    files named by ``args.train_data`` / ``args.valid_data``, builds an
    ``SNLIModel`` from the hyperparameters in ``args``, optionally initialises
    the word embeddings from GloVe, and trains with Adam for
    ``args.max_epoch`` epochs (numbered from 1).  Training loss/accuracy are
    logged as TensorBoard scalars every iteration under
    ``<save_dir>/log/train``; validation loss/accuracy are computed roughly
    ten times per epoch and logged under ``<save_dir>/log/valid``.  Whenever
    validation accuracy improves, the model's ``state_dict`` is written to
    ``args.save_dir``.

    Args:
        args: Parsed command-line namespace.  Attributes read here:
            ``train_data``, ``valid_data``, ``batch_size``, ``word_dim``,
            ``hidden_dim``, ``clf_hidden_dim``, ``clf_num_layers``,
            ``leaf_rnn``, ``batchnorm``, ``intra_attention``, ``dropout``,
            ``glove``, ``fix_word_embedding``, ``gpu`` (``-1`` keeps the
            model on the CPU), ``save_dir``, ``max_epoch`` and
            ``anneal_temperature``.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; these paths must point at trusted, locally produced datasets.
    with open(args.train_data, 'rb') as f:
        train_dataset: SNLIDataset = pickle.load(f)
    with open(args.valid_data, 'rb') as f:
        valid_dataset: SNLIDataset = pickle.load(f)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=2,
                              collate_fn=train_dataset.collate,
                              pin_memory=True)
    valid_loader = DataLoader(dataset=valid_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=2,
                              collate_fn=valid_dataset.collate,
                              pin_memory=True)
    # Both vocabularies are taken from the training set.
    word_vocab = train_dataset.word_vocab
    label_vocab = train_dataset.label_vocab

    model = SNLIModel(num_classes=len(label_vocab),
                      num_words=len(word_vocab),
                      word_dim=args.word_dim,
                      hidden_dim=args.hidden_dim,
                      clf_hidden_dim=args.clf_hidden_dim,
                      clf_num_layers=args.clf_num_layers,
                      use_leaf_rnn=args.leaf_rnn,
                      use_batchnorm=args.batchnorm,
                      intra_attention=args.intra_attention,
                      dropout_prob=args.dropout)
    if args.glove:
        logging.info('Loading GloVe pretrained vectors...')
        # The embedding table is zeroed before being handed to load_glove as
        # init_weight (presumably words without a GloVe vector keep this
        # zero row -- confirm against load_glove).
        model.word_embedding.weight.data.zero_()
        glove_weight = load_glove(
            path=args.glove,
            vocab=word_vocab,
            init_weight=model.word_embedding.weight.data.numpy())
        # The padding token always gets an all-zero embedding.
        glove_weight[word_vocab.pad_id] = 0
        model.word_embedding.weight.data.set_(torch.FloatTensor(glove_weight))
    if args.fix_word_embedding:
        logging.info('Will not update word embeddings')
        model.word_embedding.weight.requires_grad = False
    if args.gpu > -1:
        logging.info(f'Using GPU {args.gpu}')
        model.cuda(args.gpu)
    # Frozen parameters (e.g. fixed word embeddings) are excluded from both
    # the optimizer and gradient clipping.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(params=params)
    criterion = nn.CrossEntropyLoss()

    train_summary_writer = tensorboard.FileWriter(
        logdir=os.path.join(args.save_dir, 'log', 'train'), flush_secs=10)
    valid_summary_writer = tensorboard.FileWriter(
        logdir=os.path.join(args.save_dir, 'log', 'valid'), flush_secs=10)

    def run_iter(batch, is_training):
        """Run the model on one batch and return ``(loss, accuracy)``.

        In training mode this also back-propagates, clips the gradient norm
        of ``params`` to 5 and takes an optimizer step.  In evaluation mode
        the inputs are wrapped with ``volatile=True`` (legacy PyTorch flag
        that skips building the autograd graph).
        """
        model.train(is_training)
        pre = wrap_with_variable(batch['pre'],
                                 volatile=not is_training,
                                 gpu=args.gpu)
        hyp = wrap_with_variable(batch['hyp'],
                                 volatile=not is_training,
                                 gpu=args.gpu)
        pre_length = wrap_with_variable(batch['pre_length'],
                                        volatile=not is_training,
                                        gpu=args.gpu)
        hyp_length = wrap_with_variable(batch['hyp_length'],
                                        volatile=not is_training,
                                        gpu=args.gpu)
        label = wrap_with_variable(batch['label'],
                                   volatile=not is_training,
                                   gpu=args.gpu)
        logits = model(pre=pre,
                       pre_length=pre_length,
                       hyp=hyp,
                       hyp_length=hyp_length)
        # Predicted class = argmax over dimension 1 of the logits.
        label_pred = logits.max(1)[1]
        accuracy = torch.eq(label, label_pred).float().mean()
        loss = criterion(input=logits, target=label)
        if is_training:
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm(parameters=params, max_norm=5)
            optimizer.step()
        return loss, accuracy

    def add_scalar_summary(summary_writer, name, value, step):
        """Write ``value`` as TensorBoard scalar ``name`` at ``step``."""
        value = unwrap_scalar_variable(value)
        summ = summary.scalar(name=name, scalar=value)
        summary_writer.add_summary(summary=summ, global_step=step)

    num_train_batches = len(train_loader)
    # Validate roughly ten times per epoch.  Clamp to at least 1: with fewer
    # than 10 training batches the plain floor division yields 0 and the
    # modulo check below would raise ZeroDivisionError.
    validate_every = max(1, num_train_batches // 10)
    best_valid_accuracy = 0
    iter_count = 0
    for epoch_num in range(1, args.max_epoch + 1):
        logging.info(f'Epoch {epoch_num}: start')
        for batch_iter, train_batch in enumerate(train_loader):
            # Every 500 iterations, decay the Gumbel temperature as
            # exp(-gamma * iter_count), floored at 0.5.
            if args.anneal_temperature and iter_count % 500 == 0:
                gamma = 0.00001
                new_temperature = max([0.5, math.exp(-gamma * iter_count)])
                model.encoder.gumbel_temperature = new_temperature
                logging.info(
                    f'Iter #{iter_count}: '
                    f'Set Gumbel temperature to {new_temperature:.4f}')
            train_loss, train_accuracy = run_iter(batch=train_batch,
                                                  is_training=True)
            iter_count += 1
            add_scalar_summary(summary_writer=train_summary_writer,
                               name='loss',
                               value=train_loss,
                               step=iter_count)
            add_scalar_summary(summary_writer=train_summary_writer,
                               name='accuracy',
                               value=train_accuracy,
                               step=iter_count)

            if (batch_iter + 1) % validate_every == 0:
                # Mean of per-batch loss/accuracy over the validation set.
                valid_loss_sum = valid_accuracy_sum = 0
                num_valid_batches = len(valid_loader)
                for valid_batch in valid_loader:
                    valid_loss, valid_accuracy = run_iter(batch=valid_batch,
                                                          is_training=False)
                    valid_loss_sum += unwrap_scalar_variable(valid_loss)
                    valid_accuracy_sum += unwrap_scalar_variable(
                        valid_accuracy)
                valid_loss = valid_loss_sum / num_valid_batches
                valid_accuracy = valid_accuracy_sum / num_valid_batches
                add_scalar_summary(summary_writer=valid_summary_writer,
                                   name='loss',
                                   value=valid_loss,
                                   step=iter_count)
                add_scalar_summary(summary_writer=valid_summary_writer,
                                   name='accuracy',
                                   value=valid_accuracy,
                                   step=iter_count)
                # Fractional epoch label used in the log line and in the
                # checkpoint filename.
                progress = epoch_num + batch_iter / num_train_batches
                logging.info(f'Epoch {progress:.2f}: '
                             f'valid loss = {valid_loss:.4f}, '
                             f'valid accuracy = {valid_accuracy:.4f}')
                if valid_accuracy > best_valid_accuracy:
                    best_valid_accuracy = valid_accuracy
                    model_filename = (f'model-{progress:.2f}'
                                      f'-{valid_loss:.4f}'
                                      f'-{valid_accuracy:.4f}.pkl')
                    model_path = os.path.join(args.save_dir, model_filename)
                    torch.save(model.state_dict(), model_path)
                    logging.info(f'Saved the new best model to {model_path}')
Ejemplo n.º 2
0
def train(args):
    """Train an ``SNLIModel`` on pickled SNLI datasets.

    Loads the training and validation ``SNLIDataset`` objects from the pickle
    files named by ``args.train_data`` / ``args.valid_data``, builds an
    ``SNLIModel`` from the hyperparameters in ``args``, optionally initialises
    the word embeddings from GloVe, and trains for ``args.max_epoch`` epochs
    (numbered from 0).  The learning rate is halved by ``ReduceLROnPlateau``
    when validation accuracy stops improving.  Training loss/accuracy are
    logged to TensorBoard every iteration under ``<save_dir>/log/train``;
    validation loss/accuracy are computed roughly ten times per epoch and
    logged under ``<save_dir>/log/valid``.  Whenever validation accuracy
    improves, the model's ``state_dict`` is written to ``args.save_dir``.

    Args:
        args: Parsed command-line namespace.  Attributes read here:
            ``train_data``, ``valid_data``, ``batch_size``, ``word_dim``,
            ``hidden_dim``, ``clf_hidden_dim``, ``clf_num_layers``,
            ``leaf_rnn``, ``batchnorm``, ``intra_attention``, ``dropout``,
            ``bidirectional``, ``glove``, ``fix_word_embedding``,
            ``device``, ``optimizer`` (``'adam'``, ``'adagrad'`` or
            ``'adadelta'``), ``l2reg``, ``save_dir``, ``max_epoch``,
            ``anneal_temperature_every`` and ``anneal_temperature_rate``.

    Raises:
        ValueError: If ``args.optimizer`` is not one of the supported names.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; these paths must point at trusted, locally produced datasets.
    with open(args.train_data, 'rb') as f:
        train_dataset: SNLIDataset = pickle.load(f)
    with open(args.valid_data, 'rb') as f:
        valid_dataset: SNLIDataset = pickle.load(f)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=2,
                              collate_fn=train_dataset.collate,
                              pin_memory=True)
    valid_loader = DataLoader(dataset=valid_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=2,
                              collate_fn=valid_dataset.collate,
                              pin_memory=True)
    # Both vocabularies are taken from the training set.
    word_vocab = train_dataset.word_vocab
    label_vocab = train_dataset.label_vocab

    model = SNLIModel(num_classes=len(label_vocab),
                      num_words=len(word_vocab),
                      word_dim=args.word_dim,
                      hidden_dim=args.hidden_dim,
                      clf_hidden_dim=args.clf_hidden_dim,
                      clf_num_layers=args.clf_num_layers,
                      use_leaf_rnn=args.leaf_rnn,
                      use_batchnorm=args.batchnorm,
                      intra_attention=args.intra_attention,
                      dropout_prob=args.dropout,
                      bidirectional=args.bidirectional)
    if args.glove:
        logging.info('Loading GloVe pretrained vectors...')
        # The current (CPU) embedding table is passed as init_weight;
        # presumably rows for words without a GloVe vector keep it --
        # confirm against load_glove.
        glove_weight = load_glove(
            path=args.glove,
            vocab=word_vocab,
            init_weight=model.word_embedding.weight.data.numpy())
        # The padding token always gets an all-zero embedding.
        glove_weight[word_vocab.pad_id] = 0
        model.word_embedding.weight.data.set_(torch.FloatTensor(glove_weight))
    if args.fix_word_embedding:
        logging.info('Will not update word embeddings')
        model.word_embedding.weight.requires_grad = False
    model.to(args.device)
    logging.info(f'Using device {args.device}')

    optimizer_classes = {
        'adam': optim.Adam,
        'adagrad': optim.Adagrad,
        'adadelta': optim.Adadelta,
    }
    # Fail with a clear message instead of an UnboundLocalError when an
    # unsupported optimizer name is given.
    try:
        optimizer_class = optimizer_classes[args.optimizer]
    except KeyError:
        raise ValueError(
            f'Unknown optimizer {args.optimizer!r}; '
            f'expected one of {sorted(optimizer_classes)}') from None
    # Frozen parameters (e.g. fixed word embeddings) are excluded from both
    # the optimizer and gradient clipping.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optimizer_class(params=params, weight_decay=args.l2reg)
    # mode='max': the scheduler is stepped with validation accuracy below.
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                               mode='max',
                                               factor=0.5,
                                               patience=10,
                                               verbose=True)
    criterion = nn.CrossEntropyLoss()

    train_summary_writer = SummaryWriter(
        log_dir=os.path.join(args.save_dir, 'log', 'train'))
    valid_summary_writer = SummaryWriter(
        log_dir=os.path.join(args.save_dir, 'log', 'valid'))

    def run_iter(batch, is_training):
        """Run the model on one batch and return ``(loss, accuracy)``.

        Both values are returned as tensors.  In training mode this also
        back-propagates, clips the gradient norm of ``params`` to 5 and takes
        an optimizer step.
        """
        model.train(is_training)
        pre = batch['pre'].to(args.device)
        hyp = batch['hyp'].to(args.device)
        pre_length = batch['pre_length'].to(args.device)
        hyp_length = batch['hyp_length'].to(args.device)
        label = batch['label'].to(args.device)
        logits = model(pre=pre,
                       pre_length=pre_length,
                       hyp=hyp,
                       hyp_length=hyp_length)
        # Predicted class = argmax over dimension 1 of the logits.
        label_pred = logits.max(1)[1]
        accuracy = torch.eq(label, label_pred).float().mean()
        loss = criterion(input=logits, target=label)
        if is_training:
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(parameters=params, max_norm=5)
            optimizer.step()
        return loss, accuracy

    def add_scalar_summary(summary_writer, name, value, step):
        """Write ``value`` (tensor or number) as scalar ``name`` at ``step``."""
        if torch.is_tensor(value):
            value = value.item()
        summary_writer.add_scalar(tag=name,
                                  scalar_value=value,
                                  global_step=step)

    num_train_batches = len(train_loader)
    # Validate roughly ten times per epoch.  Clamp to at least 1: with fewer
    # than 10 training batches the plain floor division yields 0 and the
    # modulo check below would raise ZeroDivisionError.
    validate_every = max(1, num_train_batches // 10)
    best_valid_accuracy = 0
    iter_count = 0
    try:
        for epoch_num in range(args.max_epoch):
            logging.info(f'Epoch {epoch_num}: start')
            for batch_iter, train_batch in enumerate(train_loader):
                # Every `anneal_temperature_every` iterations, decay the
                # Gumbel temperature as exp(-rate * iter_count), floored at
                # 0.5.
                if iter_count % args.anneal_temperature_every == 0:
                    rate = args.anneal_temperature_rate
                    new_temperature = max([0.5, math.exp(-rate * iter_count)])
                    model.encoder.gumbel_temperature = new_temperature
                    logging.info(
                        f'Iter #{iter_count}: '
                        f'Set Gumbel temperature to {new_temperature:.4f}')
                train_loss, train_accuracy = run_iter(batch=train_batch,
                                                      is_training=True)
                iter_count += 1
                add_scalar_summary(summary_writer=train_summary_writer,
                                   name='loss',
                                   value=train_loss,
                                   step=iter_count)
                add_scalar_summary(summary_writer=train_summary_writer,
                                   name='accuracy',
                                   value=train_accuracy,
                                   step=iter_count)

                if (batch_iter + 1) % validate_every != 0:
                    continue

                # Mean of per-batch loss/accuracy over the validation set.
                # torch.no_grad() restores the previous grad mode even if
                # validation raises, unlike paired set_grad_enabled calls.
                valid_loss_sum = valid_accuracy_sum = 0
                num_valid_batches = len(valid_loader)
                with torch.no_grad():
                    for valid_batch in valid_loader:
                        valid_loss, valid_accuracy = run_iter(
                            batch=valid_batch, is_training=False)
                        valid_loss_sum += valid_loss.item()
                        valid_accuracy_sum += valid_accuracy.item()
                valid_loss = valid_loss_sum / num_valid_batches
                valid_accuracy = valid_accuracy_sum / num_valid_batches
                scheduler.step(valid_accuracy)
                add_scalar_summary(summary_writer=valid_summary_writer,
                                   name='loss',
                                   value=valid_loss,
                                   step=iter_count)
                add_scalar_summary(summary_writer=valid_summary_writer,
                                   name='accuracy',
                                   value=valid_accuracy,
                                   step=iter_count)
                # Fractional epoch label used in the log line and in the
                # checkpoint filename.
                progress = epoch_num + batch_iter / num_train_batches
                logging.info(f'Epoch {progress:.2f}: '
                             f'valid loss = {valid_loss:.4f}, '
                             f'valid accuracy = {valid_accuracy:.4f}')
                if valid_accuracy > best_valid_accuracy:
                    best_valid_accuracy = valid_accuracy
                    model_filename = (f'model-{progress:.2f}'
                                      f'-{valid_loss:.4f}'
                                      f'-{valid_accuracy:.4f}.pkl')
                    model_path = os.path.join(args.save_dir, model_filename)
                    torch.save(model.state_dict(), model_path)
                    logging.info(f'Saved the new best model to {model_path}')
    finally:
        # Flush pending events and release the event files even when
        # training stops early (exception or KeyboardInterrupt).
        train_summary_writer.close()
        valid_summary_writer.close()