def test(arg_list=None):
    """Evaluate a checkpointed model with ``test_model``.

    The model class comes from ``args.model``. Its constructor arguments are
    read from the checkpoint's ``"model_init"`` entry (no arguments if that key
    is absent). Its weights are restored with ``util.load_state_dict``. A
    ``torchsummary`` summary is printed before ``test_model`` is run on the
    test loader.

    Args:
        arg_list: forwarded unchanged to ``init_pipeline``.
    """
    args, device, checkpoint = init_pipeline(arg_list)

    loss_factory = get_loss_initializer(args.loss)
    criterion = loss_factory()
    loader = load_test_data(args, device)

    ctor_args = checkpoint.get("model_init", [])
    model_factory = get_model_initializer(args.model)
    net = model_factory(*ctor_args).to(device)
    util.load_state_dict(checkpoint, net)

    torchsummary.summary(net, net.input_shape)
    test_model(loader, net, criterion)
def viz(arg_list=None):
    """Restore a model from a checkpoint and run the visualization helpers.

    The model is built from ``args.model``. Its constructor arguments come from
    the checkpoint's ``"model_init"`` entry when present, otherwise from the
    values returned by ``load_train_data``. After the weights are restored,
    ``visualize`` and then ``visualize_trained`` are called with batches drawn
    from the training loader.

    Args:
        arg_list: optional argument list forwarded to ``init_pipeline``, the
            same convention ``test`` and ``train`` use. When ``None`` (the
            default), ``init_pipeline`` is called with no arguments, exactly
            as before.
    """
    # Only forward arg_list when explicitly given, so plain ``viz()`` keeps
    # calling ``init_pipeline()`` exactly as it always did.
    if arg_list is None:
        args, device, checkpoint = init_pipeline()
    else:
        args, device, checkpoint = init_pipeline(arg_list)

    train_loader, _, init_params = load_train_data(args, device)
    # Prefer the constructor arguments saved with the checkpoint.
    init_params = checkpoint.get('model_init', init_params)
    model = get_model_initializer(args.model)(*init_params).to(device)
    util.load_state_dict(checkpoint, model)

    # One iterator is shared by both calls: any batches ``visualize`` pulls
    # from it are not seen again by ``visualize_trained``.
    sample_loader = iter(train_loader)
    visualize(model, sample_loader)
    visualize_trained(model, sample_loader)
def train(arg_list=None):
    """Run the training loop and return the metric tracker for the run.

    Builds the loaders and the model/criterion/optimizer/scheduler, then
    restores any checkpointed state. For each of ``args.epochs`` epochs it
    runs a training pass and then a validation pass via
    ``train_and_validate``. After that it steps the scheduler when
    ``args.scheduler`` is set, and saves a checkpoint unless ``args.no_save``
    is set. Unless ``args.no_visualize`` is set, visualizations are produced
    before and after training.

    Args:
        arg_list: forwarded unchanged to ``init_pipeline``.

    Returns:
        The ``MetricTracker`` used to record the run.
    """
    args, device, checkpoint = init_pipeline(arg_list)
    train_loader, val_loader, init_params = load_train_data(args, device)
    sample_loader = util.get_sample_loader(train_loader)

    # NOTE(review): this unpacks four values from a four-argument
    # load_model call; a load_model with a different signature (it takes a
    # checkpoint and returns three values) also exists in this source --
    # confirm which definition is actually in scope here.
    model, criterion, optimizer, scheduler = load_model(
        args, device, init_params, sample_loader
    )
    util.load_state_dict(checkpoint, model, optimizer, scheduler)

    metrics = MetricTracker(args, checkpoint)
    if not args.no_visualize:
        metrics.add_network(model, sample_loader)
        visualize(model, sample_loader, metrics.run_name)

    # RNG state is restored only after the optional visualization above;
    # keep this ordering.
    util.set_rng_state(checkpoint)

    for _epoch in range(args.epochs):
        metrics.next_epoch()
        # Training pass gets the optimizer; validation pass gets None.
        train_and_validate(args, model, train_loader, optimizer, criterion, metrics, Mode.TRAIN)
        train_and_validate(args, model, val_loader, None, criterion, metrics, Mode.VAL)
        if args.scheduler:
            scheduler.step()

        if args.no_save:
            continue

        # Python / NumPy / torch RNG states are stored alongside the weights,
        # presumably so util.set_rng_state can resume the same sequences.
        snapshot = {
            "model_init": init_params,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict() if args.scheduler else None,
            "rng_state": random.getstate(),
            "np_rng_state": np.random.get_state(),
            "torch_rng_state": torch.get_rng_state(),
            "run_name": metrics.run_name,
            "metric_obj": metrics.json_repr(),
        }
        util.save_checkpoint(snapshot, metrics.is_best)

    # Re-enable autograd globally; presumably the validation pass disabled
    # it -- confirm in train_and_validate.
    torch.set_grad_enabled(True)
    if not args.no_visualize:
        visualize_trained(model, sample_loader, metrics.run_name)
    return metrics
def load_model(args, device, checkpoint, init_params, train_loader, *, num_frozen_params=20):
    """Build the criterion, model and optimizer, then restore checkpointed state.

    Args:
        args: parsed options; ``args.loss``, ``args.model`` and ``args.lr``
            are read.
        device: device the model is moved to.
        checkpoint: passed to ``util.load_state_dict`` together with the model
            and optimizer.
        init_params: positional constructor arguments for the model class.
        train_loader: forwarded to ``verify_model``.
        num_frozen_params: when ``args.model == "UNet"``, the number of leading
            entries of ``model.parameters()`` whose ``requires_grad`` is set to
            ``False``. Must be non-negative. Defaults to 20, the value that was
            previously hard-coded.

    Returns:
        A ``(model, criterion, optimizer)`` tuple.
    """
    from itertools import islice

    criterion = LOSS_DICT[args.loss]()
    model = get_model_initializer(args.model)(*init_params).to(device)
    if args.model == "UNet":
        # Freeze only the leading parameter tensors; islice stops after the
        # prefix instead of walking every parameter.
        for param in islice(model.parameters(), num_frozen_params):
            param.requires_grad = False
    # Only trainable parameters are handed to AdamW, so frozen ones are
    # never updated by this optimizer.
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
    verify_model(model, train_loader, optimizer, criterion, device)
    util.load_state_dict(checkpoint, model, optimizer)
    return model, criterion, optimizer
def main():
    """Restore a ``Model`` from the checkpoint and visualize it on training data.

    Constructor arguments are the fifth value returned by ``load_train_data``.
    The model's weights are loaded via ``util.load_state_dict`` before it is
    passed, with the training loader, to ``visualize``.
    """
    args, device, checkpoint = init_pipeline()
    loader, _unused_a, _unused_b, _unused_c, ctor_args = load_train_data(args, device)

    net = Model(*ctor_args)
    net = net.to(device)
    util.load_state_dict(checkpoint, net)

    visualize(net, loader)