Example #1
0
def test_row_key():
    """``row_key`` decides which length each row is bucketed under."""
    rows = [[[1, 2, 3], [1, 2, 3]], [[1, 2], [1, 2]], [[3, 4, 5], [3, 4, 5]],
            [[3, 4], [3, 4]], [[6, 7, 8], [6, 7, 8]]]

    def make(key):
        # Same bucketing parameters for both scenarios; only the key differs.
        return BucketManager(rows,
                             1,
                             3,
                             batch_size=2,
                             shuffle=False,
                             row_key=key)

    # Keyed on the length of each row's first element: 3, 2, 3, 2, 3.
    by_first = make(lambda row: len(row[0]))
    assert by_first.buckets[0][by_first.DATA_KEY] == []
    assert by_first.buckets[1][by_first.END_INDEX_KEY] == 2
    assert by_first.buckets[2][by_first.END_INDEX_KEY] == 3
    assert len(next(by_first)) == 2
    next(by_first)
    next(by_first)
    with pytest.raises(StopIteration):
        next(by_first)

    # Keyed on the row itself: every row has exactly two elements.
    by_row = make(len)
    assert by_row.buckets[0][by_row.DATA_KEY] == []
    assert by_row.buckets[1][by_row.END_INDEX_KEY] == 5
    assert by_row.buckets[2][by_row.END_INDEX_KEY] == 0
Example #2
0
def test_reset():
    """``reset()`` makes an exhausted manager replay the same batches."""
    expected = [[1, 2, 3], [4, 5, 6]]
    manager = BucketManager([[1, 2, 3], [4, 5, 6]],
                            1,
                            3,
                            batch_size=2,
                            shuffle=False)

    def check_single_pass():
        # Both rows come back as one batch, after which iteration ends.
        assert next(manager) == expected
        with pytest.raises(StopIteration):
            next(manager)

    check_single_pass()
    manager.reset()
    check_single_pass()
Example #3
0
def test_bucket_width():
    """Bucket contents depend on ``bucket_width``."""
    rows = [[1, 2, 3], [4, 5, 6], [1, 2, 3], [2, 3], [5, 2]]

    # Width 3: the first bucket's END_INDEX_KEY covers all five rows.
    wide = BucketManager(rows, 3, 3, batch_size=2)
    assert wide.buckets[0][wide.END_INDEX_KEY] == 5

    # Width 2: the two 2-element rows and three 3-element rows split apart.
    narrow = BucketManager(rows, 2, 3, batch_size=2)
    assert narrow.buckets[0][narrow.END_INDEX_KEY] == 2
    assert narrow.buckets[1][narrow.END_INDEX_KEY] == 3
Example #4
0
def test_right_leak_minus_one():
    """With ``right_leak=1`` the short row is not batched with the long one.

    NOTE(review): this test is probabilistic. Bucket selection appears to be
    weighted by ``left_samples``; inflating bucket 1 makes picking bucket 4
    improbable, not impossible.
    """
    long_row = [[1, 2, 3, 4], [1, 2, 3, 4]]
    rows = [[[1, 2, 3, 4], [1, 2, 3, 4]], [[1, 2], [1, 2]]]
    manager = BucketManager(rows,
                            1,
                            5,
                            right_leak=1,
                            batch_size=2,
                            shuffle=False,
                            row_key=lambda row: len(row[0]))

    # Bias selection towards bucket 1 rather than bucket 4.
    manager.left_samples[1] = 1000
    first = next(manager)
    assert len(first) == 1
    assert long_row not in first
Example #5
0
def test_basic():
    """Every batch is full when each bucket's size is a multiple of batch_size.

    The rows have lengths 3, 4, 4, 3, 4, 3: three rows of each length. With
    ``batch_size=3`` every emitted batch should contain exactly three rows,
    and all six rows should be emitted.
    """
    data = [[1, 2, 3], [4, 5, 6, 7], [8, 9, 3, 2], [4, 5, 6], [3, 4, 5, 6],
            [1, 2, 3]]

    batch_size = 3
    bm = BucketManager(data, 1, 4, batch_size=batch_size)
    seen = 0
    for each in bm:
        assert len(each) == batch_size
        seen += len(each)
    # The loop above asserts nothing if the manager yields no batches, so
    # also require that every input row was produced exactly once in total.
    assert seen == len(data)
Example #6
0
def test_nonzero_start():
    """``min_len`` moves the lower edge of the first bucket."""
    rows = [[1, 2, 3], [4, 5, 6], [1, 2, 3], [2, 3], [5, 2]]
    width = 2

    # Default min_len: the two 2-element rows and the three 3-element rows
    # end up in separate buckets.
    default_start = BucketManager(rows, width, 3, batch_size=2)
    assert default_start.buckets[0][default_start.END_INDEX_KEY] == 2
    assert default_start.buckets[1][default_start.END_INDEX_KEY] == 3

    # min_len=2 (max_len=4): the first bucket's END_INDEX_KEY covers all rows.
    shifted = BucketManager(rows, width, 4, min_len=2, batch_size=2)
    assert shifted.buckets[0][shifted.END_INDEX_KEY] == 5
Example #7
0
def test_zero_length():
    """Empty rows need ``min_len=0``; without it construction raises IndexError."""
    rows = [[], [], [], [2, 3], [5, 2]]

    with pytest.raises(IndexError):
        BucketManager(rows, 2, 2, batch_size=2)

    # Width 2: three empty rows in one bucket, two 2-element rows in the next.
    narrow = BucketManager(rows, 2, 2, min_len=0, batch_size=2)
    assert narrow.buckets[0][narrow.END_INDEX_KEY] == 3
    assert narrow.buckets[1][narrow.END_INDEX_KEY] == 2

    # Width 3: the first bucket's END_INDEX_KEY covers all five rows.
    wide = BucketManager(rows, 3, 2, min_len=0, batch_size=2)
    assert wide.buckets[0][wide.END_INDEX_KEY] == 5
Example #8
0
def test_uneven():
    """A leftover row is emitted as a final short batch, then iteration stops."""
    manager = BucketManager([[1, 2, 3], [4, 5, 6], [1, 2, 3]],
                            1,
                            3,
                            batch_size=2)
    for expected_size in (2, 1):
        assert len(next(manager)) == expected_size
    with pytest.raises(StopIteration):
        next(manager)
Example #9
0
def test_uneven_unshuffled():
    """Without shuffling, rows come out in input order, last batch short."""
    manager = BucketManager([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                            1,
                            3,
                            batch_size=2,
                            shuffle=False)
    expected_batches = ([[1, 2, 3], [4, 5, 6]], [[7, 8, 9]])
    for want in expected_batches:
        assert next(manager) == want
    with pytest.raises(StopIteration):
        next(manager)
Example #10
0
def train_loop(train_rows,
               dev_rows,
               conf,
               checkpoint_callback=None,
               gpu_id=-1):
    """Train ``conf.model`` with patience-based early stopping.

    After every checkpoint the model is scored on the dev set. The monitored
    metric is ``valid_las`` when ``conf.model.alpha > 0`` and
    ``valid_mean_aux_acc`` otherwise. Patience resets to
    ``conf.checkpoint.patience`` whenever the monitored metric improves and
    drops by one otherwise. Training stops once patience runs out or the
    epoch counter reaches ``conf.max_epochs``.

    Args:
        train_rows: training rows; bucketed by ``len(row[0])``.
        dev_rows: dev rows, batched once up front via ``to_batches``.
        conf: configuration object (model, optimizer, checkpoint, bucket
            and batch-size settings).
        checkpoint_callback: optional ``callback(epoch, stats, improved=...)``
            called after every checkpoint.
        gpu_id: CUDA device id; a negative value skips moving the model to
            a GPU.

    Returns:
        The trained model (``conf.model``).

    Raises:
        ValueError: if ``train_rows`` is empty. The epoch counter divides
            by ``len(train_rows)``, which would otherwise fail with
            ZeroDivisionError only after a full checkpoint.
    """
    if len(train_rows) == 0:
        raise ValueError('train_rows must not be empty')

    model = conf.model
    if gpu_id >= 0:
        chainer.backends.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu(gpu_id)

    train_buckets = BucketManager(train_rows,
                                  conf.train_buckets.bucket_width,
                                  conf.dataset.train_max_sent_len,
                                  shuffle=True,
                                  batch_size=conf.batch_size,
                                  right_leak=conf.train_buckets.right_leak,
                                  row_key=lambda x: len(x[0]),
                                  loop_forever=True)
    dev_batches = tuple(to_batches(dev_rows, conf.dev_batch_size, sort=True))

    print('training max seq len ', train_buckets.max_len)

    opt = chainer.optimizers.Adam(alpha=conf.optimizer.learning_rate)
    opt.setup(model)
    opt.add_hook(
        chainer.optimizer.GradientClipping(threshold=conf.optimizer.grad_clip))

    e = 0
    best_valid_las = 0.
    best_valid_acc = 0.
    patience = conf.checkpoint.patience
    # checkpoint.every defines how often to checkpoint in multiples of
    # the batch size.  If it is <= 0 we checkpoint once per epoch.
    iters_per_epoch = len(train_rows)
    if conf.checkpoint.every > 0:
        cp_iters = conf.batch_size * conf.checkpoint.every
    else:
        cp_iters = iters_per_epoch
    current_iters = 0

    pbar = tqdm(desc='Epoch 0 - Patience %d' % patience)
    try:
        while e < conf.max_epochs:
            checkpoint_stats = dict()
            # train
            checkpoint_stats.update(
                train_epoch(model, opt, train_buckets, cp_iters,
                            conf.mtl_swap))

            # score dev set
            checkpoint_stats.update(
                eval_epoch(model,
                           dev_batches,
                           data_size=len(dev_rows),
                           label='valid',
                           num_labels=conf.model.num_labels))

            # Early stopping: LAS is monitored when alpha > 0, otherwise the
            # mean auxiliary accuracy.  best_valid_acc is refreshed on every
            # improvement in both modes.
            monitor_las = conf.model.alpha > 0.0
            if monitor_las:
                improved = checkpoint_stats['valid_las'] > best_valid_las
            else:
                improved = (checkpoint_stats['valid_mean_aux_acc'] >
                            best_valid_acc)
            if improved:
                if monitor_las:
                    best_valid_las = checkpoint_stats['valid_las']
                best_valid_acc = checkpoint_stats['valid_mean_aux_acc']
                patience = conf.checkpoint.patience
            else:
                patience -= 1
            checkpoint_stats.update(patience=patience)

            # Each checkpoint counts cp_iters rows towards the epoch counter;
            # integer division keeps e exact for large iteration counts.
            current_iters += cp_iters
            e = int(current_iters // iters_per_epoch)
            # NOTE: UAS and aux. accuracy shown here are the current
            # checkpoint's values, only LAS is the best seen so far.
            pbar.set_description(
                'Epoch %d - Patience %d - Best LAS: %.2f UAS: %.2f - Aux. acc: %.2f'
                % (e, patience, best_valid_las * 100,
                   checkpoint_stats['valid_uas'] * 100,
                   checkpoint_stats['valid_mean_aux_acc'] * 100))
            pbar.update()

            if checkpoint_callback is not None:
                checkpoint_callback(e, checkpoint_stats, improved=improved)

            # '<=' (not '==') so a non-positive configured patience still
            # stops training instead of counting down past zero forever.
            if patience <= 0:
                break
    finally:
        # Close the progress bar even if training or evaluation raises.
        pbar.close()
    return model
Example #11
0
def test_exact_max_len():
    """Rows exactly ``max_len`` long are accepted and land in bucket 1."""
    manager = BucketManager([[1, 2, 3], [4, 5, 6], [1, 2, 3]],
                            2,
                            3,
                            batch_size=2)
    assert manager.buckets[1][manager.END_INDEX_KEY] == 3
Example #12
0
def test_wrong_max_len():
    """Rows longer than ``max_len`` make the constructor raise IndexError."""
    too_long = [[1, 2, 3], [4, 5, 6], [1, 2, 3]]
    with pytest.raises(IndexError):
        BucketManager(too_long, 1, 2, batch_size=2)