def test_init_metrics() -> None:
    args, _, _ = init_pipeline("--no-save", "--no-visualize",
                               "--num-examples=100", "--epochs=1")
    metrics = MetricTracker(args, {})
    assert metrics

def viz() -> None:
    args, device, checkpoint = init_pipeline()
    dataset_loader = get_dataset_initializer(args.dataset)
    train_loader, _, init_params = dataset_loader.load_train_data(args, device)
    # Prefer the init params saved in the checkpoint when one was loaded.
    init_params = checkpoint.get("model_init", init_params)
    model = get_model_initializer(args.model)(*init_params).to(device)
    util.load_state_dict(checkpoint, model)
    sample_loader = util.get_sample_loader(train_loader)
    visualize(args, model, sample_loader)
    visualize_trained(args, model, sample_loader)

def test(*arg_list: str) -> None:
    args, device, checkpoint = init_pipeline(*arg_list)
    criterion = get_loss_initializer(args.loss)()
    test_loader = get_dataset_initializer(args.dataset).load_test_data(
        args, device)
    # Fall back to an empty init-param list if the checkpoint lacks one.
    init_params = checkpoint.get("model_init", [])
    model = get_model_initializer(args.model)(*init_params).to(device)
    util.load_state_dict(checkpoint, model)
    sample_loader = util.get_sample_loader(test_loader)
    model_summary(args, model, sample_loader)
    test_model(args, model, test_loader, criterion)

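# A minimal usage sketch, not from the source: `test` forwards its CLI flags
# straight to init_pipeline, so any flag used elsewhere in this file applies
# here as well, for example:
#
#     test("--num-examples=100")
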
def test_epoch_update(capsys: pytest.CaptureFixture[str],
                      example_batch: SimpleNamespace) -> None:
    args, _, _ = init_pipeline("--no-save", "--no-visualize", "--epochs=1",
                               "--log-interval=3")
    metrics = MetricTracker(args, {})
    num_batches = 4
    for i in range(num_batches):
        _ = metrics.batch_update(example_batch, i, num_batches, Mode.TRAIN)
    metrics.epoch_update(Mode.TRAIN)

    captured = capsys.readouterr().out
    assert captured == "Mode.TRAIN Loss: 0.2100 Accuracy: 66.67% \n"

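# A minimal sketch of the `example_batch` fixture these tests rely on; the
# real definition lives in the suite's conftest.py. The field names and tensor
# values below are assumptions, chosen only to match the asserted numbers:
# a 3-example batch with a mean loss of 0.21 and 2 correct predictions
# (66.67% accuracy).
import pytest
import torch
from types import SimpleNamespace


@pytest.fixture
def example_batch() -> SimpleNamespace:
    # Hypothetical fields: 3 examples, 2 of them classified correctly.
    output = torch.tensor([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]])
    target = torch.tensor([0, 0, 0])  # argmax matches only the first two rows
    return SimpleNamespace(loss=torch.tensor(0.21), output=output, target=target)
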
def test_one_batch_update(example_batch: SimpleNamespace) -> None:
    args, _, _ = init_pipeline("--no-save", "--no-visualize", "--epochs=1")
    metrics = MetricTracker(args, {})
    tqdm_dict = metrics.batch_update(example_batch, 0, 1, Mode.TRAIN)
    for key, value in tqdm_dict.items():
        tqdm_dict[key] = round(value, 2)
    result = [
        round(metric.epoch_avg, 2) for metric in metrics.metric_data.values()
    ]

    assert tqdm_dict == {"Loss": 0.21, "Accuracy": 0.67}
    # epoch_avg tracks raw totals: summed loss (0.21 * 3 examples = 0.63) and
    # the number of correct predictions (2).
    assert result == [0.63, 2]

def train(*arg_list: str) -> MetricTracker:
    args, device, checkpoint = init_pipeline(*arg_list)
    dataset_loader = get_dataset_initializer(args.dataset)
    train_loader, val_loader, init_params = dataset_loader.load_train_data(
        args, device)
    sample_loader = util.get_sample_loader(train_loader)
    model, criterion, optimizer, scheduler = load_model(
        args, device, init_params, sample_loader)
    # Resume model, optimizer, and scheduler state from the checkpoint, if any.
    util.load_state_dict(checkpoint, model, optimizer, scheduler)
    metrics = MetricTracker(args, checkpoint, dataset_loader.CLASS_LABELS)
    visualize(args, model, sample_loader, metrics)

    # Restore RNG state so a resumed run continues deterministically.
    util.set_rng_state(checkpoint)
    for _ in range(args.epochs):
        metrics.next_epoch()
        train_and_validate(args, model, train_loader, optimizer, criterion,
                           metrics, Mode.TRAIN)
        train_and_validate(args, model, val_loader, None, criterion, metrics,
                           Mode.VAL)
        if scheduler is not None:
            scheduler.step()

        if not args.no_save:
            # Save a full checkpoint each epoch, including all RNG states so
            # an interrupted run can resume exactly where it left off.
            checkpoint_dict = {
                "model_init": init_params,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": (None if scheduler is None
                                         else scheduler.state_dict()),
                "rng_state": random.getstate(),
                "np_rng_state": np.random.get_state(),
                "torch_rng_state": torch.get_rng_state(),
                "run_name": metrics.run_name,
                "metric_obj": metrics.json_repr(),
            }
            util.save_checkpoint(checkpoint_dict, metrics.is_best)

    # Re-enable gradients in case validation left them disabled.
    torch.set_grad_enabled(True)
    visualize_trained(args, model, sample_loader, metrics)
    return metrics

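# A minimal usage sketch (flags taken from the tests in this file): run one
# epoch without saving checkpoints or rendering visualizations, then inspect
# the returned tracker, e.g.:
#
#     tracker = train("--epochs=1", "--no-save", "--no-visualize")
#     print(tracker.run_name)  # same attribute stored in the checkpoint dict
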
def test_many_batch_update(example_batch: SimpleNamespace) -> None:
    args, _, _ = init_pipeline("--no-save", "--no-visualize", "--epochs=1",
                               "--log-interval=3")
    metrics = MetricTracker(args, {})
    num_batches = 4
    tqdm_dict = {}
    for i in range(num_batches):
        tqdm_dict = metrics.batch_update(example_batch, i, num_batches,
                                         Mode.TRAIN)
    for key, value in tqdm_dict.items():
        tqdm_dict[key] = round(value, 2)
    result = [
        round(metric.epoch_avg, 2) for metric in metrics.metric_data.values()
    ]

    assert tqdm_dict == {"Loss": 0.21, "Accuracy": 0.67}
    # All running averages have been flushed by the end of the epoch.
    assert all(metric.running_avg == 0
               for metric in metrics.metric_data.values())
    # Epoch totals accumulate across batches: 0.63 * 4 = 2.52 summed loss and
    # 2 * 4 = 8 correct predictions.
    assert result == [2.52, 8]
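
# Cross-check with test_epoch_update above (assuming a 3-example batch): the
# epoch totals 2.52 and 8 normalize over 4 * 3 = 12 examples to
# 2.52 / 12 = 0.21 and 8 / 12 = 66.67%, which is exactly the line that
# epoch_update prints.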