Example #1
0
def train(arg_list=None):
    """Run the full training loop, checkpointing after every epoch.

    Args:
        arg_list: optional list of CLI-style argument strings forwarded to
            init_pipeline; None means arguments are parsed from sys.argv.

    Returns:
        The validation loss from the last completed epoch, or None when
        args.epochs is 0 and no epoch runs.
    """
    args, device, checkpoint = init_pipeline(arg_list)
    train_loader, val_loader, init_params = load_train_data(args, device)
    model, criterion, optimizer = load_model(args, device, checkpoint,
                                             init_params, train_loader)
    run_name, metrics = init_metrics(args, checkpoint)
    if args.visualize:
        metrics.add_network(model, train_loader)
        # visualize(model, train_loader, run_name)

    # Restore RNG streams so a resumed run continues the same random sequence.
    util.set_rng_state(checkpoint)
    # Fix: val_loss was previously unbound — `return val_loss` raised
    # NameError whenever args.epochs == 0 and the loop body never ran.
    val_loss = None
    start_epoch = metrics.epoch + 1
    for epoch in range(start_epoch, start_epoch + args.epochs):
        print(f'Epoch [{epoch}/{start_epoch + args.epochs - 1}]')
        metrics.next_epoch()
        # Training-pass loss is tracked inside `metrics`; the return value
        # was previously bound to an unused local (tr_loss), now dropped.
        train_and_validate(model, train_loader, optimizer, criterion,
                           metrics, Mode.TRAIN, args.binary)
        val_loss = train_and_validate(model, val_loader, optimizer, criterion,
                                      metrics, Mode.VAL, args.binary)
        is_best = metrics.update_best_metric(val_loss)
        # Persist full training state (model, optimizer, all RNG states, and
        # metric history) so the run can be resumed exactly where it stopped.
        util.save_checkpoint(
            {
                'model_init': init_params,
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'rng_state': random.getstate(),
                'np_rng_state': np.random.get_state(),
                'torch_rng_state': torch.get_rng_state(),
                'run_name': run_name,
                'metric_obj': metrics.json_repr()
            }, run_name, is_best)

    return val_loss
    def test_init_metrics():
        """MetricTracker should construct successfully from an empty checkpoint."""
        cli_args = ["--no-save", "--no-visualize", "--num-examples=100", "--epochs=1"]
        args, _, _ = init_pipeline(cli_args)

        tracker = MetricTracker(args, {})

        assert tracker
Example #3
0
def test(arg_list=None):
    """Restore a model from its checkpoint and evaluate it on the test set."""
    args, device, checkpoint = init_pipeline(arg_list)
    loader = load_test_data(args, device)
    loss_fn = get_loss_initializer(args.loss)()
    model_args = checkpoint.get("model_init", [])
    net = get_model_initializer(args.model)(*model_args).to(device)
    util.load_state_dict(checkpoint, net)
    # Print a layer-by-layer summary before running the evaluation pass.
    torchsummary.summary(net, net.input_shape)

    test_model(loader, net, loss_fn)
Example #4
0
File: viz.py  Project: TylerYep/maestro
def viz():
    """Load a model (preferring checkpointed init params) and visualize it."""
    args, device, checkpoint = init_pipeline()
    loader, _, default_params = load_train_data(args, device)
    # Prefer the init params stored in the checkpoint; fall back to the
    # ones derived from the dataset when no checkpoint exists.
    model_params = checkpoint.get('model_init', default_params)
    net = get_model_initializer(args.model)(*model_params).to(device)
    util.load_state_dict(checkpoint, net)

    samples = iter(loader)
    visualize(net, samples)
    visualize_trained(net, samples)
Example #5
0
def evaluate():
    """Score a freshly-constructed Binary model on the validation split."""
    args, device, checkpoint = init_pipeline()
    _, val_loader, _, _, _ = load_train_data(args, device)

    net = Binary()
    # NOTE: restoring pretrained weights from
    # checkpoints/AG/model_best.pth.tar is intentionally disabled; the model
    # is evaluated with its initial weights.
    result = evaluate_model(net, val_loader, device)
    print(result)
Example #6
0
def train(arg_list=None):
    """Train a model end-to-end with per-epoch checkpointing.

    Args:
        arg_list: optional list of CLI-style argument strings forwarded to
            init_pipeline; None means arguments are parsed from sys.argv.

    Returns:
        The MetricTracker holding the run's metric history.
    """
    args, device, checkpoint = init_pipeline(arg_list)
    train_loader, val_loader, init_params = load_train_data(args, device)
    sample_loader = util.get_sample_loader(train_loader)
    model, criterion, optimizer, scheduler = load_model(
        args, device, init_params, sample_loader)
    util.load_state_dict(checkpoint, model, optimizer, scheduler)
    metrics = MetricTracker(args, checkpoint)
    if not args.no_visualize:
        metrics.add_network(model, sample_loader)
        visualize(model, sample_loader, metrics.run_name)

    # Resume the RNG streams so restarted runs reproduce the same sequence.
    util.set_rng_state(checkpoint)
    for _ in range(args.epochs):
        metrics.next_epoch()
        train_and_validate(args, model, train_loader, optimizer, criterion,
                           metrics, Mode.TRAIN)
        # Validation pass: optimizer is None, so no gradient updates happen.
        train_and_validate(args, model, val_loader, None, criterion, metrics,
                           Mode.VAL)
        if args.scheduler:
            scheduler.step()

        if not args.no_save:
            # Snapshot of everything needed to resume: model/optimizer/
            # scheduler weights, all RNG states, and the metric history.
            state = {
                "model_init": init_params,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict":
                    scheduler.state_dict() if args.scheduler else None,
                "rng_state": random.getstate(),
                "np_rng_state": np.random.get_state(),
                "torch_rng_state": torch.get_rng_state(),
                "run_name": metrics.run_name,
                "metric_obj": metrics.json_repr(),
            }
            util.save_checkpoint(state, metrics.is_best)

    # Validation may have disabled gradients; re-enable before visualizing.
    torch.set_grad_enabled(True)
    if not args.no_visualize:
        visualize_trained(model, sample_loader, metrics.run_name)

    return metrics
Example #7
0
def train(arg_list: list[str] | None = None) -> None:
    """Load the configured dataset and model, then train with validation."""
    args, checkpoint = init_pipeline(arg_list)
    dataset = get_dataset_initializer(args.dataset)
    (
        (train_images, train_labels),
        (_, _),
        # (test_images, test_labels),
        class_labels,
        init_params,
    ) = dataset.load_train_data()
    model = load_model(checkpoint, init_params, train_images, train_labels)
    # Visualization and RNG-restore hooks are currently disabled:
    # add_network(model, train_loader, device)
    # visualize(model, train_loader, class_labels, device, run_name)
    # util.set_rng_state(checkpoint)
    train_and_validate(args, model, train_images, train_labels, class_labels)
Example #8
0
    def test_one_batch_update(example_batch):
        """A single batch_update should report the expected averaged metrics."""
        batch = example_batch
        args, _, _ = init_pipeline(['--no-save', '--no-visualize', '--epochs=1'])
        metrics = MetricTracker(args, {})

        tqdm_dict = metrics.batch_update(0, 1, batch.batch_size, batch.data,
                                         batch.loss, batch.output,
                                         batch.target, Mode.TRAIN)

        rounded = {key: round(val, 2) for key, val in tqdm_dict.items()}
        epoch_avgs = [round(m.epoch_avg, 2) for m in metrics.metric_data.values()]
        assert rounded == {'Loss': 0.21, 'Accuracy': 0.67}
        assert epoch_avgs == [0.63, 2.0]
Example #9
0
    def test_epoch_update(capsys, example_batch):
        """epoch_update should print one line of averaged epoch metrics."""
        batch = example_batch
        cli_args = ['--no-save', '--no-visualize', '--epochs=1', '--log-interval=3']
        args, _, _ = init_pipeline(cli_args)
        metrics = MetricTracker(args, {})
        total_batches = 4
        for batch_idx in range(total_batches):
            metrics.batch_update(batch_idx, total_batches, batch.batch_size,
                                 batch.data, batch.loss, batch.output,
                                 batch.target, Mode.TRAIN)

        metrics.epoch_update(Mode.TRAIN)

        assert capsys.readouterr().out == 'Mode.TRAIN Loss: 0.2100 Accuracy: 66.67% \n'
Example #10
0
    def test_many_batch_update(example_batch):
        """After a full epoch of batches, running averages reset to zero while
        the per-epoch accumulators hold the summed totals."""
        batch = example_batch
        cli_args = ['--no-save', '--no-visualize', '--epochs=1', '--log-interval=3']
        args, _, _ = init_pipeline(cli_args)
        metrics = MetricTracker(args, {})
        total_batches = 4

        for batch_idx in range(total_batches):
            tqdm_dict = metrics.batch_update(batch_idx, total_batches,
                                             batch.batch_size, batch.data,
                                             batch.loss, batch.output,
                                             batch.target, Mode.TRAIN)

        rounded = {key: round(val, 2) for key, val in tqdm_dict.items()}
        epoch_avgs = [round(m.epoch_avg, 2) for m in metrics.metric_data.values()]
        assert rounded == {'Loss': 0.21, 'Accuracy': 0.67}
        assert all(m.running_avg == 0. for m in metrics.metric_data.values())
        assert epoch_avgs == [2.52, 8.]
Example #11
0
File: viz.py  Project: TylerYep/lung-xray
def main():
    """Restore a trained model from its checkpoint and visualize it."""
    args, device, checkpoint = init_pipeline()
    loader, _, _, _, init_params = load_train_data(args, device)
    net = Model(*init_params).to(device)
    util.load_state_dict(checkpoint, net)
    visualize(net, loader)