def train(arg_list=None):
    """Train a model (optionally resuming from a checkpoint) and return the last val loss.

    Args:
        arg_list: Optional list of CLI-style argument strings; None falls back
            to the process arguments inside init_pipeline.

    Returns:
        Validation loss of the final epoch.
    """
    args, device, checkpoint = init_pipeline(arg_list)
    train_loader, val_loader, init_params = load_train_data(args, device)
    model, criterion, optimizer = load_model(
        args, device, checkpoint, init_params, train_loader)
    run_name, metrics = init_metrics(args, checkpoint)

    if args.visualize:
        metrics.add_network(model, train_loader)

    # Restore RNG state so a resumed run continues deterministically.
    util.set_rng_state(checkpoint)

    start_epoch = metrics.epoch + 1
    for epoch in range(start_epoch, start_epoch + args.epochs):
        print(f'Epoch [{epoch}/{start_epoch + args.epochs - 1}]')
        metrics.next_epoch()
        tr_loss = train_and_validate(model, train_loader, optimizer, criterion,
                                     metrics, Mode.TRAIN, args.binary)
        val_loss = train_and_validate(model, val_loader, optimizer, criterion,
                                      metrics, Mode.VAL, args.binary)
        is_best = metrics.update_best_metric(val_loss)
        # Checkpoint every epoch, capturing all RNG streams for exact resume.
        util.save_checkpoint(
            {
                'model_init': init_params,
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'rng_state': random.getstate(),
                'np_rng_state': np.random.get_state(),
                'torch_rng_state': torch.get_rng_state(),
                'run_name': run_name,
                'metric_obj': metrics.json_repr(),
            },
            run_name,
            is_best)
    return val_loss
def test_init_metrics():
    """MetricTracker can be constructed from freshly parsed pipeline args."""
    args, _, _ = init_pipeline(
        ["--no-save", "--no-visualize", "--num-examples=100", "--epochs=1"])
    metrics = MetricTracker(args, {})
    assert metrics
def test(arg_list=None):
    """Load a checkpointed model, print its summary, and run it on the test set.

    Args:
        arg_list: Optional list of CLI-style argument strings; None falls back
            to the process arguments inside init_pipeline.
    """
    args, device, checkpoint = init_pipeline(arg_list)
    criterion = get_loss_initializer(args.loss)()
    test_loader = load_test_data(args, device)
    # Fall back to no constructor args when the checkpoint lacks model_init.
    init_params = checkpoint.get("model_init", [])
    model = get_model_initializer(args.model)(*init_params).to(device)
    util.load_state_dict(checkpoint, model)
    torchsummary.summary(model, model.input_shape)
    test_model(test_loader, model, criterion)
def viz():
    """Visualize a model (restored from a checkpoint when available) on training samples."""
    args, device, checkpoint = init_pipeline()
    train_loader, _, init_params = load_train_data(args, device)
    # Prefer the init params recorded in the checkpoint over the dataset defaults.
    init_params = checkpoint.get('model_init', init_params)
    model = get_model_initializer(args.model)(*init_params).to(device)
    util.load_state_dict(checkpoint, model)
    sample_loader = iter(train_loader)
    visualize(model, sample_loader)
    visualize_trained(model, sample_loader)
def evaluate():
    """Evaluate a freshly constructed Binary model on the validation set and print the result."""
    args, device, checkpoint = init_pipeline()
    _, val_loader, _, _, _ = load_train_data(args, device)
    model = Binary()
    # NOTE(review): loading pretrained weights from checkpoints/AG/model_best.pth.tar
    # was disabled here, so the model is evaluated with its initial weights.
    result = evaluate_model(model, val_loader, device)
    print(result)
def train(arg_list=None):
    """Run the full train/validate loop with optional scheduler, saving, and visualization.

    Args:
        arg_list: Optional list of CLI-style argument strings; None falls back
            to the process arguments inside init_pipeline.

    Returns:
        The populated MetricTracker for the run.
    """
    args, device, checkpoint = init_pipeline(arg_list)
    train_loader, val_loader, init_params = load_train_data(args, device)
    sample_loader = util.get_sample_loader(train_loader)
    model, criterion, optimizer, scheduler = load_model(
        args, device, init_params, sample_loader)
    util.load_state_dict(checkpoint, model, optimizer, scheduler)
    metrics = MetricTracker(args, checkpoint)

    if not args.no_visualize:
        metrics.add_network(model, sample_loader)
        visualize(model, sample_loader, metrics.run_name)

    # Restore RNG state so a resumed run continues deterministically.
    util.set_rng_state(checkpoint)
    for _ in range(args.epochs):
        metrics.next_epoch()
        train_and_validate(args, model, train_loader, optimizer, criterion,
                           metrics, Mode.TRAIN)
        # Validation passes no optimizer, so no weight updates happen.
        train_and_validate(args, model, val_loader, None, criterion,
                           metrics, Mode.VAL)
        if args.scheduler:
            scheduler.step()
        if not args.no_save:
            util.save_checkpoint(
                {
                    "model_init": init_params,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict":
                        scheduler.state_dict() if args.scheduler else None,
                    "rng_state": random.getstate(),
                    "np_rng_state": np.random.get_state(),
                    "torch_rng_state": torch.get_rng_state(),
                    "run_name": metrics.run_name,
                    "metric_obj": metrics.json_repr(),
                },
                metrics.is_best,
            )

    # Re-enable autograd (validation disables it) before final visualization.
    torch.set_grad_enabled(True)
    if not args.no_visualize:
        visualize_trained(model, sample_loader, metrics.run_name)
    return metrics
def train(arg_list: list[str] | None = None) -> None:
    """Load the training split for the configured dataset and train a model on it.

    Args:
        arg_list: Optional list of CLI-style argument strings; None falls back
            to the process arguments inside init_pipeline.
    """
    args, checkpoint = init_pipeline(arg_list)
    (
        (train_images, train_labels),
        (_, _),  # test split (test_images, test_labels) is unused here
        class_labels,
        init_params,
    ) = get_dataset_initializer(args.dataset).load_train_data()
    model = load_model(checkpoint, init_params, train_images, train_labels)
    train_and_validate(args, model, train_images, train_labels, class_labels)
def test_one_batch_update(example_batch):
    """A single batch update produces the expected tqdm dict and epoch averages."""
    batch = example_batch
    args, _, _ = init_pipeline(['--no-save', '--no-visualize', '--epochs=1'])
    metrics = MetricTracker(args, {})
    tqdm_dict = metrics.batch_update(
        0, 1, batch.batch_size, batch.data, batch.loss,
        batch.output, batch.target, Mode.TRAIN)
    rounded = {key: round(val, 2) for key, val in tqdm_dict.items()}
    epoch_avgs = [round(metric.epoch_avg, 2)
                  for metric in metrics.metric_data.values()]
    assert rounded == {'Loss': 0.21, 'Accuracy': 0.67}
    assert epoch_avgs == [0.63, 2.0]
def test_epoch_update(capsys, example_batch):
    """epoch_update prints the epoch summary line after a full epoch of batches."""
    batch = example_batch
    args, _, _ = init_pipeline(
        ['--no-save', '--no-visualize', '--epochs=1', '--log-interval=3'])
    metrics = MetricTracker(args, {})
    num_batches = 4
    for i in range(num_batches):
        metrics.batch_update(
            i, num_batches, batch.batch_size, batch.data, batch.loss,
            batch.output, batch.target, Mode.TRAIN)
    metrics.epoch_update(Mode.TRAIN)
    captured = capsys.readouterr().out
    assert captured == 'Mode.TRAIN Loss: 0.2100 Accuracy: 66.67% \n'
def test_many_batch_update(example_batch):
    """Running averages reset each log interval while epoch averages keep accumulating.

    Feeds the same batch four times with --log-interval=3, then checks the final
    tqdm dict, the zeroed running averages, and the accumulated epoch averages.
    """
    batch = example_batch
    args, _, _ = init_pipeline(
        ['--no-save', '--no-visualize', '--epochs=1', '--log-interval=3'])
    metrics = MetricTracker(args, {})
    num_batches = 4
    tqdm_dict = {}
    for i in range(num_batches):
        tqdm_dict = metrics.batch_update(
            i, num_batches, batch.batch_size, batch.data, batch.loss,
            batch.output, batch.target, Mode.TRAIN)
    rounded = {key: round(val, 2) for key, val in tqdm_dict.items()}
    epoch_avgs = [round(metric.epoch_avg, 2)
                  for metric in metrics.metric_data.values()]
    assert rounded == {'Loss': 0.21, 'Accuracy': 0.67}
    # Generator expression instead of a throwaway list inside all() (ruff C419).
    assert all(metric.running_avg == 0. for metric in metrics.metric_data.values())
    assert epoch_avgs == [2.52, 8.]
def main():
    """Restore a Model from a checkpoint and visualize it on the training loader."""
    args, device, checkpoint = init_pipeline()
    train_loader, _, _, _, init_params = load_train_data(args, device)
    model = Model(*init_params).to(device)
    util.load_state_dict(checkpoint, model)
    visualize(model, train_loader)