Example #1
0
def train(arg_list=None):
    """Train and validate a model for ``args.epochs`` epochs, checkpointing every epoch.

    Training resumes at ``metrics.epoch + 1`` (the tracker is built from the
    checkpoint by ``init_metrics``). Each epoch runs a training pass and a
    validation pass. Then ``metrics.update_best_metric`` is fed the validation
    loss, and a checkpoint is saved. The checkpoint holds the model/optimizer
    state, the Python/NumPy/torch RNG states, the run name and the
    serialized metrics.

    Args:
        arg_list: Optional argument list forwarded to ``init_pipeline``.

    Returns:
        The validation loss of the last epoch, or ``None`` if no epoch was run
        (``args.epochs <= 0``).
    """
    args, device, checkpoint = init_pipeline(arg_list)
    train_loader, val_loader, init_params = load_train_data(args, device)
    model, criterion, optimizer = load_model(args, device, checkpoint,
                                             init_params, train_loader)
    run_name, metrics = init_metrics(args, checkpoint)
    if args.visualize:
        metrics.add_network(model, train_loader)

    util.set_rng_state(checkpoint)
    start_epoch = metrics.epoch + 1
    last_epoch = start_epoch + args.epochs - 1
    # Bound before the loop so a zero-epoch run returns None instead of
    # raising UnboundLocalError at the final return.
    val_loss = None
    for epoch in range(start_epoch, last_epoch + 1):
        print(f'Epoch [{epoch}/{last_epoch}]')
        metrics.next_epoch()
        train_and_validate(model, train_loader, optimizer, criterion,
                           metrics, Mode.TRAIN, args.binary)
        val_loss = train_and_validate(model, val_loader, optimizer, criterion,
                                      metrics, Mode.VAL, args.binary)
        is_best = metrics.update_best_metric(val_loss)
        util.save_checkpoint(
            {
                'model_init': init_params,
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'rng_state': random.getstate(),
                'np_rng_state': np.random.get_state(),
                'torch_rng_state': torch.get_rng_state(),
                'run_name': run_name,
                'metric_obj': metrics.json_repr()
            }, run_name, is_best)

    return val_loss
Example #2
0
def viz():
    """Rebuild a model, restore its checkpoint weights and render visualizations.

    The constructor arguments come from the checkpoint's ``'model_init'``
    entry when present. Otherwise they fall back to the parameters derived
    from the training data by ``load_train_data``.
    """
    args, device, checkpoint = init_pipeline()
    train_loader, _val_loader, data_params = load_train_data(args, device)

    # NOTE(review): assumes `checkpoint` is dict-like even when no checkpoint
    # was found (e.g. an empty dict); a None here would raise. Confirm.
    model_args = checkpoint.get('model_init', data_params)
    model_factory = get_model_initializer(args.model)
    net = model_factory(*model_args).to(device)
    util.load_state_dict(checkpoint, net)

    # One iterator is shared by both calls, so visualize_trained continues
    # from wherever visualize stopped drawing batches (if it draws any).
    batches = iter(train_loader)
    visualize(net, batches)
    visualize_trained(net, batches)
Example #3
0
def evaluate():
    """Evaluate a ``Binary`` model on the validation split and print the result.

    NOTE(review): the checkpoint returned by ``init_pipeline`` is never
    applied to the model. Unless ``Binary()`` loads weights itself, evaluation
    therefore runs on its initial weights. An earlier version loaded weights
    from 'checkpoints/AG/model_best.pth.tar' here; confirm which is intended.
    """
    args, device, _checkpoint = init_pipeline()
    _train_loader, val_loader, _, _, _ = load_train_data(args, device)
    print(evaluate_model(Binary(), val_loader, device))
Example #4
0
def train(arg_list=None):
    """Train a model for ``args.epochs`` epochs and return its metric tracker.

    Each epoch does the following:
    - runs one training pass and one validation pass;
    - steps the LR scheduler when ``args.scheduler`` is set;
    - unless ``args.no_save`` is set, saves a checkpoint with the model,
      optimizer, scheduler and RNG state.

    Unless ``args.no_visualize`` is set, visualizations are produced both
    before and after training.

    Args:
        arg_list: Optional argument list forwarded to ``init_pipeline``.

    Returns:
        The ``MetricTracker`` that accumulated metrics during the run.
    """
    args, device, checkpoint = init_pipeline(arg_list)
    train_loader, val_loader, init_params = load_train_data(args, device)
    sample_loader = util.get_sample_loader(train_loader)
    model, criterion, optimizer, scheduler = load_model(
        args, device, init_params, sample_loader)
    # Restore any saved model/optimizer/scheduler state before training.
    util.load_state_dict(checkpoint, model, optimizer, scheduler)
    metrics = MetricTracker(args, checkpoint)

    if not args.no_visualize:
        metrics.add_network(model, sample_loader)
        visualize(model, sample_loader, metrics.run_name)

    util.set_rng_state(checkpoint)
    for _epoch_idx in range(args.epochs):
        metrics.next_epoch()
        train_and_validate(args, model, train_loader, optimizer, criterion,
                           metrics, Mode.TRAIN)
        # Validation is given no optimizer (None), presumably so no update
        # step can run.
        train_and_validate(args, model, val_loader, None, criterion,
                           metrics, Mode.VAL)
        if args.scheduler:
            scheduler.step()
        if args.no_save:
            continue

        state = {
            "model_init": init_params,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": (
                scheduler.state_dict() if args.scheduler else None),
            "rng_state": random.getstate(),
            "np_rng_state": np.random.get_state(),
            "torch_rng_state": torch.get_rng_state(),
            "run_name": metrics.run_name,
            "metric_obj": metrics.json_repr(),
        }
        util.save_checkpoint(state, metrics.is_best)

    # Turn autograd back on globally.
    # NOTE(review): presumably the validation pass disabled it; confirm.
    torch.set_grad_enabled(True)
    if not args.no_visualize:
        visualize_trained(model, sample_loader, metrics.run_name)

    return metrics
Example #5
0
def main():
    """Build ``Model``, restore its checkpoint weights and visualize it.

    The constructor parameters are the last value returned by
    ``load_train_data``. Visualization uses the training loader.
    """
    args, device, checkpoint = init_pipeline()
    train_loader, _, _, _, init_params = load_train_data(args, device)
    network = Model(*init_params)
    network = network.to(device)
    util.load_state_dict(checkpoint, network)
    visualize(network, train_loader)