def training(local_rank, config, logger, with_clearml):
    rank = idist.get_rank()
    manual_seed(config.seed + local_rank)

    train_loader = config.train_loader
    val_loader = config.val_loader
    train_eval_loader = config.train_eval_loader

    model, optimizer, criterion = utils.initialize(config)

    # Setup trainer for this specific task
    trainer = create_trainer(model, optimizer, criterion, train_loader.sampler, config, logger, with_clearml)

    # Setup evaluators
    num_classes = config.num_classes
    cm_metric = ConfusionMatrix(num_classes=num_classes)

    val_metrics = {
        "IoU": IoU(cm_metric),
        "mIoU_bg": mIoU(cm_metric),
    }

    if ("val_metrics" in config) and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    evaluator = create_evaluator(model, val_metrics, config, with_clearml, tag="val")
    train_evaluator = create_evaluator(model, val_metrics, config, with_clearml, tag="train")

    val_interval = config.get("val_interval", 1)

    # Run validation every val_interval epochs, at the end of training,
    # and at the beginning if config.start_by_validation is True
    event = Events.EPOCH_COMPLETED(every=val_interval)
    if config.num_epochs % val_interval != 0:
        event |= Events.COMPLETED
    if config.get("start_by_validation", False):
        event |= Events.STARTED

    @trainer.on(event)
    def run_validation():
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_eval_loader)
        utils.log_metrics(logger, epoch, state.times["COMPLETED"], "Train", state.metrics)
        state = evaluator.run(val_loader)
        utils.log_metrics(logger, epoch, state.times["COMPLETED"], "Test", state.metrics)

    score_metric_name = "mIoU_bg"
    if "es_patience" in config:
        common.add_early_stopping_by_val_score(config.es_patience, evaluator, trainer, metric_name=score_metric_name)

    # Store the 2 best models by validation mIoU_bg score:
    common.gen_save_best_models_by_val_score(
        save_handler=utils.get_save_handler(config.output_path.as_posix(), with_clearml),
        evaluator=evaluator,
        models=model,
        metric_name=score_metric_name,
        n_saved=2,
        trainer=trainer,
        tag="val",
    )

    # Setup Tensorboard logger
    if rank == 0:
        tb_logger = common.setup_tb_logging(
            config.output_path.as_posix(),
            trainer,
            optimizer,
            evaluators={"training": train_evaluator, "validation": evaluator},
        )

        # Log validation predictions as images.
        # We define a custom event filter to log images less frequently (to reduce storage size):
        # - we plot images with masks of the middle validation batch,
        # - once every 3 validations, and
        # - at the end of the training
        def custom_event_filter(_, val_iteration):
            c1 = val_iteration == len(val_loader) // 2
            c2 = trainer.state.epoch % (config.get("val_interval", 1) * 3) == 0
            c2 |= trainer.state.epoch == config.num_epochs
            return c1 and c2

        # Image denormalization function to plot predictions with images
        mean = config.get("mean", (0.485, 0.456, 0.406))
        std = config.get("std", (0.229, 0.224, 0.225))
        img_denormalize = partial(data.denormalize, mean=mean, std=std)

        tb_logger.attach(
            evaluator,
            log_handler=vis.predictions_gt_images_handler(
                img_denormalize_fn=img_denormalize,
                n_images=15,
                another_engine=trainer,
                prefix_tag="validation",
            ),
            event_name=Events.ITERATION_COMPLETED(event_filter=custom_event_filter),
        )

    # Log confusion matrix to ClearML:
    if with_clearml:
        trainer.add_event_handler(Events.COMPLETED, compute_and_log_cm, cm_metric, trainer.state.iteration)

    trainer.run(train_loader, max_epochs=config.num_epochs)

    if idist.get_rank() == 0:
        tb_logger.close()
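
# The `training` function above is written to be spawned once per process by
# ignite.distributed. Below is a minimal launch sketch; it is not part of the original
# example, and the backend value, process count, logger name and the `build_config`
# helper are illustrative assumptions only.
import ignite.distributed as idist
from ignite.utils import setup_logger


def run_training_sketch():
    config = build_config()  # hypothetical helper returning the attribute-style config used above
    logger = setup_logger(name="Segmentation-Training")
    with idist.Parallel(backend="nccl", nproc_per_node=2) as parallel:
        # Parallel.run passes the per-process local rank as the first argument
        parallel.run(training, config, logger=logger, with_clearml=False)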
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, config, logger):
    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - `RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations
    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        token_type_ids = batch["token_type_ids"]
        labels = batch["label"].view(-1, 1)

        if input_ids.device != device:
            input_ids = input_ids.to(device, non_blocking=True, dtype=torch.long)
            attention_mask = attention_mask.to(device, non_blocking=True, dtype=torch.long)
            token_type_ids = token_type_ids.to(device, non_blocking=True, dtype=torch.long)
            labels = labels.to(device, non_blocking=True, dtype=torch.float)

        model.train()

        with autocast(enabled=with_amp):
            y_pred = model(input_ids, attention_mask, token_type_ids)
            loss = criterion(y_pred, labels)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = [
        "batch loss",
    ]

    if config["log_every_iters"] == 0:
        # Disable logging training metrics:
        metric_names = None
        config["log_every_iters"] = 15

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=utils.get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        log_every_iters=config["log_every_iters"],
        with_pbars=not config["with_clearml"],
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
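
# Sketch of the batch layout that `train_step` above consumes: a dict produced by a
# tokenizer-based Dataset / collate_fn. The field names follow the code above; the vocab
# size, batch size and sequence length below are illustrative assumptions, not values from
# the original example.
import torch


def make_dummy_batch(batch_size=4, max_length=256, vocab_size=30522):
    return {
        "input_ids": torch.randint(0, vocab_size, (batch_size, max_length), dtype=torch.long),
        "attention_mask": torch.ones(batch_size, max_length, dtype=torch.long),
        "token_type_ids": torch.zeros(batch_size, max_length, dtype=torch.long),
        # one binary label per sample; `train_step` reshapes it to (batch_size, 1)
        "label": torch.randint(0, 2, (batch_size,)).float(),
    }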
def create_trainer(model, optimizer, criterion, train_sampler, config, logger, with_clearml):
    device = config.device
    prepare_batch = data.prepare_image_mask

    # Setup trainer
    accumulation_steps = config.get("accumulation_steps", 1)
    model_output_transform = config.get("model_output_transform", lambda x: x)

    with_amp = config.get("with_amp", True)
    scaler = GradScaler(enabled=with_amp)

    def forward_pass(batch):
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=True)
        with autocast(enabled=with_amp):
            y_pred = model(x)
            y_pred = model_output_transform(y_pred)
            # Scale the loss so gradients accumulated over `accumulation_steps` iterations average out
            loss = criterion(y_pred, y) / accumulation_steps
        return loss

    def amp_backward_pass(engine, loss):
        scaler.scale(loss).backward()
        # Step the optimizer only once every `accumulation_steps` iterations
        if engine.state.iteration % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    def hvd_amp_backward_pass(engine, loss):
        # Horovod requires an explicit allreduce synchronization before the scaled optimizer step
        scaler.scale(loss).backward()
        optimizer.synchronize()
        with optimizer.skip_synchronize():
            scaler.step(optimizer)
            scaler.update()
        optimizer.zero_grad()

    if idist.backend() == "horovod" and with_amp:
        backward_pass = hvd_amp_backward_pass
    else:
        backward_pass = amp_backward_pass

    def training_step(engine, batch):
        loss = forward_pass(batch)
        output = {"supervised batch loss": loss.item()}
        backward_pass(engine, loss)
        return output

    trainer = Engine(training_step)
    trainer.logger = logger

    output_names = [
        "supervised batch loss",
    ]
    lr_scheduler = config.lr_scheduler

    to_save = {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "trainer": trainer,
        "amp": scaler,
    }

    save_every_iters = config.get("save_every_iters", 1000)

    common.setup_common_training_handlers(
        trainer,
        train_sampler,
        to_save=to_save,
        save_every_iters=save_every_iters,
        save_handler=utils.get_save_handler(config.output_path.as_posix(), with_clearml),
        lr_scheduler=lr_scheduler,
        output_names=output_names,
        with_pbars=not with_clearml,
        log_every_iters=1,
    )

    resume_from = config.get("resume_from", None)
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
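
# `model_output_transform` above lets the loss work with models whose forward pass does not
# return a plain tensor. A minimal sketch, assuming a torchvision segmentation model
# (e.g. DeepLabV3) that returns a dict with an "out" entry; the config wiring shown at the
# end is a hypothetical assignment, not part of the original example.
def extract_out_tensor(model_output):
    # torchvision.models.segmentation models return {"out": ..., "aux": ...}
    return model_output["out"]


# config.model_output_transform = extract_out_tensor  # hypothetical config wiring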
def training(local_rank, config):
    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    logger = setup_logger(name="IMDB-Training", distributed_rank=local_rank)

    log_basic_info(logger, config)

    output_path = config["output_dir"]
    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        folder_name = f"{config['model']}_backend-{idist.backend()}-{idist.get_world_size()}_{now}"
        output_path = Path(output_path) / folder_name
        if not output_path.exists():
            output_path.mkdir(parents=True)
        config["output_dir"] = output_path.as_posix()
        logger.info(f"Output path: {config['output_dir']}")

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

        if config["with_clearml"]:
            try:
                from clearml import Task
            except ImportError:
                # Backwards-compatibility for legacy Trains SDK
                from trains import Task

            task = Task.init("IMDB-Training", task_name=output_path.stem)
            task.connect_configuration(config)
            # Log hyper parameters
            hyper_params = [
                "model",
                "dropout",
                "n_fc",
                "batch_size",
                "max_length",
                "weight_decay",
                "num_epochs",
                "learning_rate",
                "num_warmup_epochs",
            ]
            task.connect({k: config[k] for k in hyper_params})

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = get_dataflow(config)

    config["num_iters_per_epoch"] = len(train_loader)
    model, optimizer, criterion, lr_scheduler = initialize(config)

    # Create trainer for current task
    trainer = create_trainer(model, optimizer, criterion, lr_scheduler, train_loader.sampler, config, logger)

    # Let's now setup evaluator engine to perform model's validation and compute metrics
    metrics = {
        "Accuracy": Accuracy(output_transform=utils.thresholded_output_transform),
        "Loss": Loss(criterion),
    }

    # We define two evaluators as they won't have exactly the same role:
    # - `evaluator` will save the best model based on validation score
    evaluator = create_evaluator(model, metrics, config, tag="val")
    train_evaluator = create_evaluator(model, metrics, config, tag="train")

    def run_validation(engine):
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train", state.metrics)
        state = evaluator.run(test_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test", state.metrics)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=config["validate_every"]) | Events.COMPLETED | Events.STARTED, run_validation
    )

    if rank == 0:
        # Setup TensorBoard logging on trainer and evaluators. Logged values are:
        #  - Training metrics, e.g. running average loss values
        #  - Learning rate
        #  - Evaluation train/test metrics
        evaluators = {"training": train_evaluator, "test": evaluator}
        tb_logger = common.setup_tb_logging(
            output_path, trainer, optimizer, evaluators=evaluators, log_every_iters=config["log_every_iters"]
        )

    # Store the 2 best models by validation accuracy, starting from num_epochs / 2:
    best_model_handler = Checkpoint(
        {"model": model},
        utils.get_save_handler(config),
        filename_prefix="best",
        n_saved=2,
        global_step_transform=global_step_from_engine(trainer),
        score_name="test_accuracy",
        score_function=Checkpoint.get_default_score_fn("Accuracy"),
    )
    evaluator.add_event_handler(
        Events.COMPLETED(lambda *_: trainer.state.epoch > config["num_epochs"] // 2), best_model_handler
    )

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception as e:
        logger.exception("")
        raise e

    if rank == 0:
        tb_logger.close()
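
# `utils.thresholded_output_transform`, referenced by the Accuracy metric above, lives in the
# example's utils module. Below is a minimal sketch of what such a transform can look like,
# assuming the model emits one logit per sample and the labels are 0/1 floats (as in the
# `train_step` above): the logits are squashed with a sigmoid and rounded so that Accuracy
# compares hard predictions with the targets.
import torch


def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = torch.round(torch.sigmoid(y_pred))
    return y_pred, y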