Example #1
def test(arg_list=None):
    args, device, checkpoint = init_pipeline(arg_list)
    criterion = get_loss_initializer(args.loss)()
    test_loader = load_test_data(args, device)
    init_params = checkpoint.get("model_init", [])
    model = get_model_initializer(args.model)(*init_params).to(device)
    util.load_state_dict(checkpoint, model)
    torchsummary.summary(model, model.input_shape)

    test_model(test_loader, model, criterion)
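`util.load_state_dict` is called with just the checkpoint and the model here, and with an optimizer and scheduler in Example #3, so it presumably restores whichever state dicts are present. A minimal sketch of such a helper, assuming the checkpoint is the dict saved in Example #3 (this is not the project's actual implementation):

def load_state_dict_sketch(checkpoint, model, optimizer=None, scheduler=None):
    # Hypothetical helper: restore whatever state is present in the checkpoint.
    if not checkpoint:
        return  # fresh run, nothing to restore
    model.load_state_dict(checkpoint["model_state_dict"])
    if optimizer is not None and checkpoint.get("optimizer_state_dict"):
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    if scheduler is not None and checkpoint.get("scheduler_state_dict"):
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])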
Example #2
def viz():
    args, device, checkpoint = init_pipeline()
    train_loader, _, init_params = load_train_data(args, device)
    init_params = checkpoint.get('model_init', init_params)
    model = get_model_initializer(args.model)(*init_params).to(device)
    util.load_state_dict(checkpoint, model)

    sample_loader = iter(train_loader)
    visualize(model, sample_loader)
    visualize_trained(model, sample_loader)
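The `get_model_initializer(args.model)(*init_params)` pattern above turns a model name from the CLI into a constructor call. A minimal sketch of how such a lookup could work, assuming a hard-coded registry and a placeholder `TinyNet` model (both hypothetical; the real helper may resolve classes differently):

import torch.nn as nn

class TinyNet(nn.Module):
    # Placeholder model for illustration only.
    def __init__(self, in_features=4, num_classes=2):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.fc(x)

MODEL_REGISTRY = {"TinyNet": TinyNet}

def get_model_initializer_sketch(name):
    # Look the class up by name; calling the result with *init_params builds the model.
    return MODEL_REGISTRY[name]

model = get_model_initializer_sketch("TinyNet")(4, 2)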
Example #3
def train(arg_list=None):
    args, device, checkpoint = init_pipeline(arg_list)
    train_loader, val_loader, init_params = load_train_data(args, device)
    sample_loader = util.get_sample_loader(train_loader)
    model, criterion, optimizer, scheduler = load_model(
        args, device, init_params, sample_loader)
    util.load_state_dict(checkpoint, model, optimizer, scheduler)
    metrics = MetricTracker(args, checkpoint)
    if not args.no_visualize:
        metrics.add_network(model, sample_loader)
        visualize(model, sample_loader, metrics.run_name)

    util.set_rng_state(checkpoint)  # restore RNG state when resuming from a checkpoint
    for _ in range(args.epochs):
        metrics.next_epoch()
        train_and_validate(args, model, train_loader, optimizer, criterion,
                           metrics, Mode.TRAIN)
        train_and_validate(args, model, val_loader, None, criterion, metrics,
                           Mode.VAL)
        if args.scheduler:
            scheduler.step()

        if not args.no_save:
            util.save_checkpoint(
                {
                    "model_init": init_params,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict() if args.scheduler else None,
                    "rng_state": random.getstate(),
                    "np_rng_state": np.random.get_state(),
                    "torch_rng_state": torch.get_rng_state(),
                    "run_name": metrics.run_name,
                    "metric_obj": metrics.json_repr(),
                },
                metrics.is_best,
            )

    torch.set_grad_enabled(True)  # re-enable autograd globally after the final validation pass
    if not args.no_visualize:
        visualize_trained(model, sample_loader, metrics.run_name)

    return metrics
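Everything needed to resume or inspect a run is packed into the checkpoint dict saved above. A short sketch of reading it back, assuming `util.save_checkpoint` writes it with `torch.save` to a path like `checkpoints/latest.pt` (both the path and that assumption are guesses; the keys match this example):

import random

import numpy as np
import torch

checkpoint = torch.load("checkpoints/latest.pt", map_location="cpu")
# Restore the three RNG streams that were captured at save time.
random.setstate(checkpoint["rng_state"])
np.random.set_state(checkpoint["np_rng_state"])
torch.set_rng_state(checkpoint["torch_rng_state"])
print("resuming run:", checkpoint["run_name"])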
Example #4
def load_model(args, device, checkpoint, init_params, train_loader):
    criterion = LOSS_DICT[args.loss]()
    model = get_model_initializer(args.model)(*init_params).to(device)
    if args.model == 'UNet':
        # Freeze the first 20 parameter tensors; the optimizer below only
        # receives parameters that still require gradients.
        for ind, param in enumerate(model.parameters()):
            if ind < 20:
                param.requires_grad = False
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad,
                                   model.parameters()),
                            lr=args.lr)
    verify_model(model, train_loader, optimizer, criterion, device)
    util.load_state_dict(checkpoint, model, optimizer)
    return model, criterion, optimizer
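When `args.model == 'UNet'`, only the parameters left with `requires_grad=True` reach AdamW, so the first 20 parameter tensors stay fixed. A self-contained sanity check of that pattern using a toy stand-in model (the `nn.Sequential` stack is illustrative only):

import torch.nn as nn

model = nn.Sequential(*[nn.Linear(4, 4) for _ in range(15)])  # 30 parameter tensors
for ind, param in enumerate(model.parameters()):
    if ind < 20:
        param.requires_grad = False  # same freezing rule as the UNet branch above

frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"frozen: {frozen} parameters, trainable: {trainable} parameters")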
Example #5
def main():
    args, device, checkpoint = init_pipeline()
    train_loader, _, _, _, init_params = load_train_data(args, device)
    model = Model(*init_params).to(device)
    util.load_state_dict(checkpoint, model)
    visualize(model, train_loader)
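All five examples start from the same `args, device, checkpoint = init_pipeline(...)` triple. A hypothetical sketch of what that setup step might look like, assuming an argparse-based CLI and an optional resume path (the `--checkpoint` flag and its default are guesses; only the attributes used above, such as `model`, `loss`, `lr`, and `epochs`, are known from these examples):

import argparse

import torch

def init_pipeline_sketch(arg_list=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="TinyNet")
    parser.add_argument("--checkpoint", default="", help="hypothetical path to resume from")
    # The remaining flags seen in these examples (--loss, --lr, --epochs,
    # --scheduler, --no-save, --no-visualize) would be declared the same way.
    args = parser.parse_args(arg_list)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(args.checkpoint) if args.checkpoint else {}
    return args, device, checkpoint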