from datetime import datetime
from pathlib import Path

import torch

import ignite.distributed as idist
from ignite.contrib.engines import common
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Events
from ignite.handlers import Checkpoint, global_step_from_engine
from ignite.metrics import Accuracy, Loss
from ignite.utils import manual_seed, setup_logger

# Helpers such as get_dataflow, initialize, create_trainer, create_evaluator,
# log_basic_info, log_metrics and get_save_handler are defined elsewhere in this file.


def training(local_rank, config):
    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    logger = setup_logger(name="CIFAR10-Training", distributed_rank=local_rank)
    log_basic_info(logger, config)

    output_path = config["output_path"]
    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        folder_name = f"{config['model']}_backend-{idist.backend()}-{idist.get_world_size()}_{now}"
        output_path = Path(output_path) / folder_name
        if not output_path.exists():
            output_path.mkdir(parents=True)
        config["output_path"] = output_path.as_posix()
        logger.info(f"Output path: {config['output_path']}")

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

        if config["with_clearml"]:
            try:
                from clearml import Task
            except ImportError:
                # Backwards-compatibility for legacy Trains SDK
                from trains import Task

            task = Task.init("CIFAR10-Training", task_name=output_path.stem)
            task.connect_configuration(config)
            task.connect(config)

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = get_dataflow(config)
    config["num_iters_per_epoch"] = len(train_loader)
    model, optimizer, criterion, lr_scheduler = initialize(config)

    logger.info(f"# model parameters (M): {sum(m.numel() for m in model.parameters()) * 1e-6}")

    # Create trainer for current task
    trainer = create_trainer(model, optimizer, criterion, lr_scheduler, train_loader.sampler, config, logger)

    # Let's now setup evaluator engine to perform model's validation and compute metrics
    metrics = {
        "Accuracy": Accuracy(),
        "Loss": Loss(criterion),
    }

    # We define two evaluators as they won't have exactly similar roles:
    # - `evaluator` will save the best model based on validation score
    evaluator = create_evaluator(model, metrics=metrics, config=config)
    train_evaluator = create_evaluator(model, metrics=metrics, config=config)

    if config["smoke_test"]:
        logger.info("Reduce the size of training and test dataloaders as smoke_test=True")

        def get_batches(loader):
            loader_iter = iter(loader)
            return [next(loader_iter) for _ in range(5)]

        train_loader = get_batches(train_loader)
        test_loader = get_batches(test_loader)

    if config["with_pbar"] and rank == 0:
        ProgressBar(desc="Evaluation (train)", persist=False).attach(train_evaluator)
        ProgressBar(desc="Evaluation (val)", persist=False).attach(evaluator)

    def run_validation(engine):
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train", state.metrics)
        state = evaluator.run(test_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test", state.metrics)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=config["validate_every"]) | Events.COMPLETED,
        run_validation,
    )

    if rank == 0:
        # Setup TensorBoard logging on trainer and evaluators. Logged values are:
        #  - Training metrics, e.g. running average loss values
        #  - Learning rate
        #  - Evaluation train/test metrics
        evaluators = {"training": train_evaluator, "test": evaluator}
        tb_logger = common.setup_tb_logging(output_path, trainer, optimizer, evaluators=evaluators)

        # Store 1 best model by validation accuracy starting from num_epochs / 2:
        best_model_handler = Checkpoint(
            {"model": model},
            get_save_handler(config),
            filename_prefix="best",
            n_saved=1,
            global_step_transform=global_step_from_engine(trainer),
            score_name="test_accuracy",
            score_function=Checkpoint.get_default_score_fn("Accuracy"),
        )
        evaluator.add_event_handler(
            Events.COMPLETED(lambda *_: trainer.state.epoch > config["num_epochs"] // 2),
            best_model_handler,
        )

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception as e:
        logger.exception("")
        raise e

    if rank == 0:
        tb_logger.close()
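
# A minimal launch sketch, not part of the original snippet: in PyTorch-Ignite,
# a `training(local_rank, config)` function like the one above is spawned on
# every process through `idist.Parallel`. The config values below are
# illustrative assumptions; only the keys read by `training` are grounded.
def run(backend=None, nproc_per_node=None, **spawn_kwargs):
    config = {
        "seed": 543,
        "model": "resnet18",
        "output_path": "/tmp/output-cifar10",
        "num_epochs": 24,
        "validate_every": 3,
        # ... remaining keys (batch_size, with_clearml, smoke_test, with_pbar, ...)
    }
    spawn_kwargs["nproc_per_node"] = nproc_per_node
    # `parallel.run(f, x)` calls `f(local_rank, x)` on each spawned process
    with idist.Parallel(backend=backend, **spawn_kwargs) as parallel:
        parallel.run(training, config)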
# Variant of `training` used to exercise checkpointing and resuming: it can
# stop on a given iteration (same imports and helpers as above).
def training(local_rank, config):
    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    logger = setup_logger(name="CIFAR10-Training", distributed_rank=local_rank)
    log_basic_info(logger, config)

    output_path = config["output_path"]
    if rank == 0:
        if config["stop_iteration"] is None:
            now = datetime.now().strftime("%Y%m%d-%H%M%S")
        else:
            now = f"stop-on-{config['stop_iteration']}"

        folder_name = f"{config['model']}_backend-{idist.backend()}-{idist.get_world_size()}_{now}"
        output_path = Path(output_path) / folder_name
        if not output_path.exists():
            output_path.mkdir(parents=True)
        config["output_path"] = output_path.as_posix()
        logger.info(f"Output path: {config['output_path']}")

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

        if config["with_clearml"]:
            try:
                from clearml import Task
            except ImportError:
                # Backwards-compatibility for legacy Trains SDK
                from trains import Task

            task = Task.init("CIFAR10-Training", task_name=output_path.stem)
            task.connect_configuration(config)
            # Log hyper parameters
            hyper_params = [
                "model",
                "batch_size",
                "momentum",
                "weight_decay",
                "num_epochs",
                "learning_rate",
                "num_warmup_epochs",
            ]
            task.connect({k: config[k] for k in hyper_params})

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = get_dataflow(config)
    config["num_iters_per_epoch"] = len(train_loader)
    model, optimizer, criterion, lr_scheduler = initialize(config)

    # Create trainer for current task
    trainer = create_trainer(model, optimizer, criterion, lr_scheduler, train_loader.sampler, config, logger)

    # Let's now setup evaluator engine to perform model's validation and compute metrics
    metrics = {
        "Accuracy": Accuracy(),
        "Loss": Loss(criterion),
    }

    # We define two evaluators as they won't have exactly similar roles:
    # - `evaluator` will save the best model based on validation score
    evaluator = create_evaluator(model, metrics=metrics, config=config)
    train_evaluator = create_evaluator(model, metrics=metrics, config=config)

    def run_validation(engine):
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train", state.metrics)
        state = evaluator.run(test_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test", state.metrics)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=config["validate_every"]) | Events.COMPLETED,
        run_validation,
    )

    if rank == 0:
        # Setup TensorBoard logging on trainer and evaluators. Logged values are:
        #  - Training metrics, e.g. running average loss values
        #  - Learning rate
        #  - Evaluation train/test metrics
        evaluators = {"training": train_evaluator, "test": evaluator}
        tb_logger = common.setup_tb_logging(output_path, trainer, optimizer, evaluators=evaluators)

        # Store 2 best models by validation accuracy starting from num_epochs / 2:
        best_model_handler = Checkpoint(
            {"model": model},
            get_save_handler(config),
            filename_prefix="best",
            n_saved=2,
            global_step_transform=global_step_from_engine(trainer),
            score_name="test_accuracy",
            score_function=Checkpoint.get_default_score_fn("Accuracy"),
        )
        evaluator.add_event_handler(
            Events.COMPLETED(lambda *_: trainer.state.epoch > config["num_epochs"] // 2),
            best_model_handler,
        )

    # In order to check training resuming we can stop training on a given iteration
    if config["stop_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["stop_iteration"]))
        def _():
            logger.info(f"Stop training on {trainer.state.iteration} iteration")
            trainer.terminate()

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception as e:
        logger.exception("")
        raise e

    if rank == 0:
        tb_logger.close()
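
# A hedged sketch of how a run stopped via `stop_iteration` could be resumed,
# not part of the original snippet: the checkpoint path and the `to_load` keys
# are assumptions about what the trainer-side checkpoint actually stores.
def resume_from(checkpoint_path, model, optimizer, lr_scheduler, trainer):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    to_load = {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler, "trainer": trainer}
    # Restores each object's state_dict from the matching key in the checkpoint
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)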
import pickle
from functools import partial

import matplotlib.pyplot as plt

from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, TerminateOnNan
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor

# `setup_grid`, `plot_toy`, `cached` and `create_evaluator` are project-specific
# helpers defined elsewhere in this repository.


def create_trainer(model, tasks, optims, loaders, args):
    # Buffers of intermediate representations and per-task predictions,
    # used only for plotting on the toy ('dummy') dataset
    zt = []
    zt_task = {'left': [], 'right': []}
    if args.dataset.name == 'dummy':
        lim = 2.5
        lims = [[-lim, lim], [-lim, lim]]
        grid = setup_grid(lims, 1000)

    def trainer_step(engine, batch):
        model.train()
        for optim in optims:
            optim.zero_grad()

        # Batch data
        x, y = batch
        x = convert_tensor(x.float(), args.device)
        y = [convert_tensor(y_, args.device) for y_ in y]

        training_loss = 0.
        losses = []

        # Intermediate representation is cached so the multi-task backward
        # does not recompute the shared forward pass
        with cached():
            preds = model(x)
            if args.dataset.name == 'dummy':
                zt.append(model.rep.detach().clone())

            for pred_i, task_i in zip(preds, tasks):
                loss_i = task_i.loss(pred_i, y[task_i.index])
                if args.dataset.name == 'dummy':
                    loss_i = loss_i.mean(dim=0)
                    zt_task[task_i.name].append(pred_i.detach().clone())

                # Track losses
                losses.append(loss_i)
                training_loss += loss_i.item() * task_i.weight

            if args.dataset.name == 'dummy' and (engine.state.epoch == engine.state.max_epochs
                                                 or engine.state.epoch % args.training.plot_every == 0):
                fig = plot_toy(grid, model, tasks, [zt, zt_task['left'], zt_task['right']],
                               engine.state.iteration - 1, levels=20, lims=lims)
                fig.savefig(f'plots/step_{engine.state.iteration - 1}.png')
                plt.close(fig)

            model.backward(losses)

        for optim in optims:
            # Run the optimizers
            optim.step()

        return training_loss, losses

    trainer = Engine(trainer_step)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'loss')
    for i, task_i in enumerate(tasks):
        output_transform = partial(lambda idx, x: x[1][idx], i)
        RunningAverage(output_transform=output_transform).attach(trainer, f'train_{task_i.name}')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=['loss'] + [f'train_{t.name}' for t in tasks])

    # Validation
    validator = create_evaluator(model, tasks, args)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validator(trainer):
        validator.run(loaders['val'])
        metrics = validator.state.metrics
        loss = 0.
        for task_i in tasks:
            loss += metrics[f'loss_{task_i.name}'] * task_i.weight
        trainer.state.metrics['val_loss'] = loss

    # Checkpoints
    model_checkpoint = {'model': model}
    handler = ModelCheckpoint('checkpoints', 'latest', require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=args.training.save_every), handler, model_checkpoint)

    # Persist the engine state alongside the weights, both periodically and at the
    # end of training (the original registered two identically named handlers,
    # one of them on a filtered COMPLETED event that could skip the final save)
    @trainer.on(Events.EPOCH_COMPLETED(every=args.training.save_every) | Events.COMPLETED)
    def save_state(engine):
        with open('checkpoints/state.pkl', 'wb') as f:
            pickle.dump(engine.state, f)

    handler = ModelCheckpoint('checkpoints', 'best', require_empty=False,
                              score_function=lambda e: -e.state.metrics['val_loss'])
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=args.training.save_every), handler, model_checkpoint)
    trainer.add_event_handler(Events.COMPLETED, handler, model_checkpoint)

    return trainer
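
# A minimal usage sketch of `create_trainer`, not part of the original snippet:
# `build_experiment` is a hypothetical helper standing in for the project's own
# setup code, and `args.training.epochs` is an assumed config field.
def main(args):
    model, tasks, optims, loaders = build_experiment(args)  # hypothetical helper
    trainer = create_trainer(model, tasks, optims, loaders, args)
    # Validation, progress bars and checkpointing are already wired up inside
    # create_trainer, so a single run() call drives the whole experiment
    trainer.run(loaders['train'], max_epochs=args.training.epochs)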