Example #1
    def test_one_batch_update(example_batch):
        data, loss, output, target, batch_size = (
            example_batch.data, example_batch.loss, example_batch.output,
            example_batch.target, example_batch.batch_size)
        arg_list = ['--no-save', '--no-visualize', '--epochs=1']
        args, _, _ = init_pipeline(arg_list)
        metrics = MetricTracker(args, {})

        tqdm_dict = metrics.batch_update(0, 1, batch_size, data, loss, output, target, Mode.TRAIN)

        for key in tqdm_dict:
            tqdm_dict[key] = round(tqdm_dict[key], 2)
        result = [round(metric.epoch_avg, 2) for metric in metrics.metric_data.values()]
        assert tqdm_dict == {'Loss': 0.21, 'Accuracy': 0.67}
        assert result == [0.63, 2.0]
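The example_batch fixture used throughout these tests is not shown in the source. Below is a minimal sketch, assuming a simple named tuple of tensors; every value is hypothetical and chosen only to be consistent with the assertions above (batch size 3, two of three predictions correct, loss 0.21):

    import collections

    import pytest
    import torch

    ExampleBatch = collections.namedtuple(
        'ExampleBatch', ['data', 'loss', 'output', 'target', 'batch_size'])

    @pytest.fixture
    def example_batch():
        # Hypothetical values: two of the three argmax predictions match the
        # targets, giving the 0.67 accuracy asserted above.
        output = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
        target = torch.tensor([0, 1, 1])
        data = torch.zeros(3, 1, 28, 28)
        loss = torch.tensor(0.21)
        return ExampleBatch(data, loss, output, target, batch_size=3)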
Example #2
    def test_init_metrics():
        arg_list = ["--no-save", "--no-visualize", "--num-examples=100", "--epochs=1"]
        args, _, _ = init_pipeline(arg_list)

        metrics = MetricTracker(args, {})

        assert metrics
Example #3
    def test_epoch_update(capsys, example_batch):
        data, loss, output, target, batch_size = (
            example_batch.data, example_batch.loss, example_batch.output,
            example_batch.target, example_batch.batch_size)
        arg_list = ['--no-save', '--no-visualize', '--epochs=1', '--log-interval=3']
        args, _, _ = init_pipeline(arg_list)
        metrics = MetricTracker(args, {})
        num_batches = 4
        for i in range(num_batches):
            _ = metrics.batch_update(i, num_batches, batch_size,
                                     data, loss, output, target, Mode.TRAIN)

        metrics.epoch_update(Mode.TRAIN)

        captured = capsys.readouterr().out
        assert captured == 'Mode.TRAIN Loss: 0.2100 Accuracy: 66.67% \n'
Example #4
def init_metrics(args, checkpoint):
    run_name = checkpoint.get('run_name', util.get_run_name(args))
    metric_checkpoint = checkpoint.get('metric_obj', {})
    metrics = MetricTracker(args.metric_names, run_name, args.log_interval,
                            **metric_checkpoint)
    # Save the args used for this run to the checkpoint folder.
    with open(os.path.join(run_name, 'args.json'), 'w') as f:
        json.dump(args.__dict__, f, indent=4)
    return run_name, metrics
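A hypothetical usage sketch for init_metrics, assuming init_pipeline returns (args, device, checkpoint) as in Example #6 and that a fresh run yields an empty checkpoint dict; the flags are borrowed from the tests above:

    # Fresh run: with an empty checkpoint, init_metrics falls back to a new
    # run name and starts the metrics from their defaults.
    args, _, checkpoint = init_pipeline(['--no-save', '--no-visualize', '--epochs=1'])
    run_name, metrics = init_metrics(args, checkpoint)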
Example #5
    def test_many_batch_update(example_batch):
        data, loss, output, target, batch_size = (
            example_batch.data, example_batch.loss, example_batch.output,
            example_batch.target, example_batch.batch_size)
        arg_list = ['--no-save', '--no-visualize', '--epochs=1', '--log-interval=3']
        args, _, _ = init_pipeline(arg_list)
        metrics = MetricTracker(args, {})
        num_batches = 4

        for i in range(num_batches):
            tqdm_dict = metrics.batch_update(i, num_batches, batch_size,
                                             data, loss, output, target, Mode.TRAIN)

        for key in tqdm_dict:
            tqdm_dict[key] = round(tqdm_dict[key], 2)
        result = [round(metric.epoch_avg, 2) for metric in metrics.metric_data.values()]
        assert tqdm_dict == {'Loss': 0.21, 'Accuracy': 0.67}
        # Running averages appear to be flushed every log interval
        # (--log-interval=3 here), so they are back at zero after the loop.
        assert all(metric.running_avg == 0.0
                   for metric in metrics.metric_data.values())
        assert result == [2.52, 8.]
Example #6
def train(arg_list=None):
    args, device, checkpoint = init_pipeline(arg_list)
    train_loader, val_loader, init_params = load_train_data(args, device)
    sample_loader = util.get_sample_loader(train_loader)
    model, criterion, optimizer, scheduler = load_model(
        args, device, init_params, sample_loader)
    util.load_state_dict(checkpoint, model, optimizer, scheduler)
    metrics = MetricTracker(args, checkpoint)
    if not args.no_visualize:
        metrics.add_network(model, sample_loader)
        visualize(model, sample_loader, metrics.run_name)

    util.set_rng_state(checkpoint)  # restore RNG state when resuming a run
    for _ in range(args.epochs):
        metrics.next_epoch()
        train_and_validate(args, model, train_loader, optimizer, criterion,
                           metrics, Mode.TRAIN)
        train_and_validate(args, model, val_loader, None, criterion, metrics,
                           Mode.VAL)
        if args.scheduler:
            scheduler.step()

        if not args.no_save:
            util.save_checkpoint(
                {
                    "model_init": init_params,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": (scheduler.state_dict()
                                             if args.scheduler else None),
                    "rng_state": random.getstate(),
                    "np_rng_state": np.random.get_state(),
                    "torch_rng_state": torch.get_rng_state(),
                    "run_name": metrics.run_name,
                    "metric_obj": metrics.json_repr(),
                },
                metrics.is_best,
            )

    torch.set_grad_enabled(True)  # ensure gradients are enabled after evaluation
    if not args.no_visualize:
        visualize_trained(model, sample_loader, metrics.run_name)

    return metrics
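A minimal, hypothetical entry point for train; the flag values are illustrative and reuse options that appear elsewhere in this section:

    if __name__ == '__main__':
        # train(None) presumably lets init_pipeline parse sys.argv instead.
        final_metrics = train(['--epochs=2', '--no-visualize'])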