from datetime import datetime
from typing import Any, cast

import torch
from torch import nn
from torch.utils.data import DataLoader

from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, global_step_from_engine
from ignite.contrib.metrics.regression import R2Score
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
from ignite.handlers import Checkpoint, DiskSaver
from ignite.metrics import Loss


def train_network(model: nn.Module, training_loader: DataLoader, validation_loader: DataLoader):
    """Trains the given neural network model.

    Parameters
    ----------
    model (nn.Module): The PyTorch model to be trained
    training_loader (DataLoader): Training data loader
    validation_loader (DataLoader): Validation data loader
    """
    device = "cuda:0" if cast(Any, torch).cuda.is_available() else "cpu"
    if device == "cuda:0":
        model.cuda()

    optimizer = cast(Any, torch).optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)

    save_handler = Checkpoint(
        {"model": model, "optimizer": optimizer, "trainer": trainer},
        DiskSaver("dist/models", create_dir=True),
        n_saved=2,
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=100), save_handler)

    # Create a logger
    tb_logger = TensorboardLogger(
        log_dir="logs/training" + datetime.now().strftime("-%Y%m%d-%H%M%S"),
        flush_secs=1,
    )
    tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED,
        tag="training",
        output_transform=lambda loss: {"loss": loss},
    )

    # Training evaluator
    training_evaluator = create_supervised_evaluator(
        model, metrics={"r2": R2Score(), "MSELoss": Loss(criterion)}, device=device
    )
    tb_logger.attach_output_handler(
        training_evaluator,
        event_name=Events.EPOCH_COMPLETED,
        tag="training",
        metric_names=["MSELoss", "r2"],
        global_step_transform=global_step_from_engine(trainer),
    )

    # Validation evaluator
    evaluator = create_supervised_evaluator(
        model, metrics={"r2": R2Score(), "MSELoss": Loss(criterion)}, device=device
    )
    tb_logger.attach_output_handler(
        evaluator,
        event_name=Events.EPOCH_COMPLETED,
        tag="validation",
        metric_names=["MSELoss", "r2"],
        global_step_transform=global_step_from_engine(trainer),
    )

    @trainer.on(Events.EPOCH_COMPLETED(every=10))
    def log_training_results(trainer):
        training_evaluator.run(training_loader)
        metrics = training_evaluator.state.metrics
        print(
            f"Training Results - Epoch: {trainer.state.epoch}",
            f" Avg r2: {metrics['r2']:.2f} Avg loss: {metrics['MSELoss']:.2f}",
        )

    @trainer.on(Events.EPOCH_COMPLETED(every=10))
    def log_validation_results(trainer):
        evaluator.run(validation_loader)
        metrics = evaluator.state.metrics
        print(
            f"Validation Results - Epoch: {trainer.state.epoch}",
            f" Avg r2: {metrics['r2']:.2f} Avg loss: {metrics['MSELoss']:.2f}\n",
        )

    trainer.run(training_loader, max_epochs=int(1e6))
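
# A minimal sketch of how train_network() above could be driven end-to-end.
# The toy model, synthetic tensors, and batch size below are illustrative
# assumptions, not part of the original snippet. Note that train_network()
# runs until interrupted (max_epochs=int(1e6)).
def example_train_network():
    from torch.utils.data import TensorDataset

    # Tiny regression model matching the MSE loss and R2Score metric used above.
    toy_model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))

    # Synthetic data: 256 samples of 8 features with scalar targets.
    X, y = torch.randn(256, 8), torch.randn(256, 1)
    train_loader = DataLoader(TensorDataset(X[:192], y[:192]), batch_size=32)
    val_loader = DataLoader(TensorDataset(X[192:], y[192:]), batch_size=32)

    train_network(toy_model, train_loader, val_loader)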
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    trainer.logger = setup_logger("Trainer")

    if sys.version_info > (3,):
        from ignite.contrib.metrics.gpu_info import GpuInfo

        try:
            GpuInfo().attach(trainer)
        except RuntimeError:
            print(
                "INFO: By default, in this example it is possible to log GPU information (used memory, utilization). "
                "As there is no pynvml python package installed, GPU information won't be logged. Otherwise, please "
                "install it : `pip install pynvml`"
            )

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    tb_logger = TensorboardLogger(log_dir=log_dir)

    tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        tag="training",
        output_transform=lambda loss: {"batchloss": loss},
        metric_names="all",
    )

    for tag, evaluator in [("training", train_evaluator), ("validation", validation_evaluator)]:
        tb_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag=tag,
            metric_names=["loss", "accuracy"],
            global_step_transform=global_step_from_engine(trainer),
        )

    tb_logger.attach_opt_params_handler(
        trainer, event_name=Events.ITERATION_COMPLETED(every=100), optimizer=optimizer
    )

    tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
    tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))
    tb_logger.attach(trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
    tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))

    def score_function(engine):
        return engine.state.metrics["accuracy"]

    model_checkpoint = ModelCheckpoint(
        log_dir,
        n_saved=2,
        filename_prefix="best",
        score_function=score_function,
        score_name="validation_accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model})

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    tb_logger.close()
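
# A minimal sketch of invoking run() above. The hyperparameter values here are
# illustrative assumptions, not prescribed defaults.
def example_run():
    run(train_batch_size=64, val_batch_size=1000, epochs=10, lr=0.01,
        momentum=0.5, log_dir="tensorboard_logs")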
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="data/korean/",
                        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint", type=str, default="gpt2",
                        help="Path, url or short name of the model")
    parser.add_argument("--model_version", type=str, default='v4', help="Version of the model")
    parser.add_argument("--num_candidates", type=int, default=2, help="Number of candidates for training")
    parser.add_argument("--max_history", type=int, default=30,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size", type=int, default=1, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=1, help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate")
    parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient")
    parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--n_epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--eval_before_start", action='store_true',
                        help="If true start with a first evaluation before training")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="",
                        help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process.
    # logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    torch.manual_seed(42)

    def get_kogpt2_tokenizer(model_path=None):
        if not model_path:
            model_path = 'taeminlee/kogpt2'
        tokenizer = GPT2Tokenizer.from_pretrained(model_path)
        return tokenizer

    tokenizer = get_kogpt2_tokenizer()
    optimizer_class = AdamW
    model = get_kogpt2_model()
    model.to(args.device)
    optimizer = optimizer_class(model.parameters(), lr=args.lr)

    # tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint, unk_token='<|unkwn|>')
    SPECIAL_TOKENS_DICT = {'additional_special_tokens': SPECIAL_TOKENS}
    print("SPECIAL TOKENS")
    print(SPECIAL_TOKENS)
    tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
    for value in SPECIAL_TOKENS:
        logger.info("Assigning %s to the %s key of the tokenizer", value, value)
        setattr(tokenizer, value, value)
    model.resize_token_embeddings(len(tokenizer))

    s = ' '.join(act_name) + ' '.join(slot_name)
    print(tokenizer.decode(tokenizer.encode(s)))
    print(len(act_name) + len(slot_name), len(tokenizer.encode(s)))

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        lm_loss, mc_loss, *_ = model(*batch)
        loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)
    trainer.logger.setLevel(logging.INFO)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)
    evaluator.logger.setLevel(logging.INFO)

    # Attach evaluation to trainer: we evaluate at the end of each epoch and, optionally, before training starts
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED,
                                  lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED,
                                    lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
        "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model,
    # configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED,
            lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.writer.log_dir = tb_logger.writer.file_writer.get_logdir()
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training", metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        # tb_logger.attach(evaluator,
        #                  log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()),
        #                                            global_step_transform=global_step_from_engine(trainer)),
        #                  event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag="validation",
            metric_names=list(metrics.keys()),
            global_step_transform=global_step_from_engine(trainer))

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation

        torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint
    # (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1],
                  os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
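
# average_distributed_scalar is referenced in the metrics above but not defined
# in this snippet. Below is a sketch consistent with its usage (averaging a
# metric value over all processes when training is distributed); this is an
# assumed implementation, not necessarily the original one.
def average_distributed_scalar(scalar, args):
    if args.local_rank == -1:
        # Not distributed: return the value unchanged.
        return scalar
    # Divide the local value by the world size, then sum across all processes.
    scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size()
    torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
    return scalar_t.item()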
def attach_handlers(run, model, optimizer, learning_rule, trainer, evaluator, train_loader, val_loader, params):
    # Metrics
    UnitConvergence(model[0], learning_rule.norm).attach(trainer.engine, 'unit_conv')

    # Tqdm logger
    pbar = ProgressBar(persist=True, bar_format=config.IGNITE_BAR_FORMAT)
    pbar.attach(trainer.engine, metric_names='all')
    tqdm_logger = TqdmLogger(pbar=pbar)
    # noinspection PyTypeChecker
    tqdm_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        global_step_transform=global_step_from_engine(trainer.engine),
    )

    # Evaluator
    evaluator.attach(trainer.engine, Events.EPOCH_COMPLETED(every=100), train_loader, val_loader)

    # Learning rate scheduling: linearly decay the learning rate to zero over the run
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                                     lr_lambda=lambda epoch: 1 - epoch / params['epochs'])
    lr_scheduler = LRScheduler(lr_scheduler)
    trainer.engine.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler)

    # Model checkpointing
    mc_handler = ModelCheckpoint(config.MODELS_DIR, run.replace('/', '-'), n_saved=1, create_dir=True,
                                 require_empty=False, global_step_transform=global_step_from_engine(trainer.engine))
    trainer.engine.add_event_handler(Events.EPOCH_COMPLETED, mc_handler, {'m': model})

    # Create a TensorBoard logger
    tb_logger = TensorboardLogger(log_dir=os.path.join(config.TENSORBOARD_DIR, run))
    images, labels = next(iter(train_loader))
    tb_logger.writer.add_graph(copy.deepcopy(model).cpu(), images)
    tb_logger.writer.add_hparams(params, {})
    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        trainer.engine,
        event_name=Events.EPOCH_COMPLETED,
        tag="train",
        metric_names=["unit_conv"]
    )
    input_shape = tuple(next(iter(train_loader))[0].shape[1:])
    tb_logger.attach(trainer.engine,
                     log_handler=WeightsImageHandler(model, input_shape),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer.engine, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.EPOCH_STARTED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=WeightsScalarHandler(model, layer_names=['linear1', 'linear2']),
    #                  event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=WeightsHistHandler(model, layer_names=['linear1', 'linear2']),
    #                  event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsHistHandler(model, layer_names=['batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=NumActivationsScalarHandler(model, layer_names=['repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.mean,
    #                                                       layer_names=['batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.std,
    #                                                       layer_names=['batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)

    return tb_logger
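
# The LambdaLR factor above scales the base learning rate by (1 - epoch / epochs),
# i.e. a linear decay from the initial rate to zero over the run. A short,
# self-contained check of that schedule (a hypothetical illustration, not part
# of the original snippet):
def example_linear_decay():
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda epoch: 1 - epoch / 10)
    for _ in range(10):
        opt.step()
        sched.step()
    # After 10 of 10 epochs the learning rate has decayed to zero.
    assert opt.param_groups[0]["lr"] == 0.0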
# Create evaluators
evaluator = create_evaluator(model, metrics=metrics)
train_evaluator = create_evaluator(model, metrics=metrics, tag='train')

# Add validation logging
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), evaluate_model)

# Add step length update at the end of each epoch
trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: lr_scheduler.step())

# Add TensorBoard logging
tb_logger = TensorboardLogger(log_dir=os.path.join(working_dir, 'tb_logs'))

# Logging iteration loss
tb_logger.attach_output_handler(
    engine=trainer,
    event_name=Events.ITERATION_COMPLETED,
    tag='training',
    output_transform=lambda loss: {"batch loss": loss}
)

# Logging epoch training metrics
tb_logger.attach_output_handler(
    engine=train_evaluator,
    event_name=Events.EPOCH_COMPLETED,
    tag="training",
    metric_names=["loss", "accuracy", "precision", "recall", "f1", "topKCatAcc"],
    global_step_transform=global_step_from_engine(trainer),
)

# Logging epoch validation metrics
tb_logger.attach_output_handler(
    engine=evaluator,
    event_name=Events.EPOCH_COMPLETED,
    tag="validation",
    metric_names=["loss", "accuracy", "precision", "recall", "f1", "topKCatAcc"],
    global_step_transform=global_step_from_engine(trainer),
)
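
# create_evaluator, evaluate_model, trainer, metrics, lr_scheduler and
# working_dir are defined elsewhere in the surrounding script. For context,
# a sketch of what create_evaluator might look like, consistent with how it is
# called above (an assumed implementation, not the original):
def create_evaluator(model, metrics, tag='val'):
    import torch
    from ignite.engine import create_supervised_evaluator

    # `tag` is accepted for naming/logging purposes elsewhere; it does not
    # change the evaluation logic here.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return create_supervised_evaluator(model, metrics=metrics, device=device)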
def run_training(
    model,
    optimizer,
    scheduler,
    output_path,
    train_loader,
    val_loader,
    epochs,
    patience,
    epochs_pretrain,
    mixed_precision,
    classes_weights,
):
    # trainer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if classes_weights is not None:
        classes_weights = classes_weights.to(device)
    crit = nn.CrossEntropyLoss(weight=classes_weights)
    metrics = {"accuracy": Accuracy(), "loss": Loss(crit)}
    model.to(device)
    trainer = create_supervised_trainer_with_pretraining(
        model, optimizer, crit, device=device,
        epochs_pretrain=epochs_pretrain, mixed_precision=mixed_precision,
    )
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    val_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

    # Out paths
    path_ckpt = os.path.join(output_path, "model_ckpt")
    log_dir = os.path.join(output_path, "log_dir")
    os.makedirs(log_dir, exist_ok=True)

    # tensorboard
    tb_logger = TensorboardLogger(log_dir=log_dir)
    tb_logger.attach_output_handler(
        train_evaluator,
        event_name=Events.EPOCH_COMPLETED,
        tag="training",
        metric_names=["accuracy", "loss"],
        # Use the trainer's epoch as the global step so training metrics align
        # with validation metrics on the TensorBoard x-axis.
        global_step_transform=global_step_from_engine(trainer),
    )
    tb_logger.attach_output_handler(
        val_evaluator,
        event_name=Events.EPOCH_COMPLETED,
        tag="validation",
        metric_names=["accuracy", "loss"],
        global_step_transform=global_step_from_engine(trainer),
    )

    # training progress
    pbar = ProgressBar(persist=True, position=0)
    pbar.attach(trainer, metric_names="all")

    def log_training_results(engine):
        train_evaluator.run(train_loader)
        val_evaluator.run(val_loader)
        train_loss = train_evaluator.state.metrics["loss"]
        val_loss = val_evaluator.state.metrics["loss"]
        train_acc = train_evaluator.state.metrics["accuracy"]
        val_acc = val_evaluator.state.metrics["accuracy"]
        pbar.log_message(
            "Training Results - Epoch: {} Loss: {:.6f} Accuracy: {:.6f}".format(
                engine.state.epoch, train_loss, train_acc
            )
        )
        pbar.log_message(
            "Validation Results - Epoch: {} Loss: {:.6f} Accuracy: {:.6f}".format(
                engine.state.epoch, val_loss, val_acc
            )
        )
        pbar.n = pbar.last_print_n = 0

    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_training_results)

    # def get_val_loss(engine):
    #     return -engine.state.metrics['loss']

    def get_val_acc(engine):
        return engine.state.metrics["accuracy"]

    # checkpoint and early stopping
    checkpointer = ModelCheckpoint(
        path_ckpt,
        "model",
        score_function=get_val_acc,
        score_name="accuracy",
        require_empty=False,
    )
    early_stopper = EarlyStopping(patience, get_val_acc, trainer)

    to_save = {"optimizer": optimizer, "model": model}
    if scheduler is not None:
        to_save["scheduler"] = scheduler
    val_evaluator.add_event_handler(Events.COMPLETED, checkpointer, to_save)
    val_evaluator.add_event_handler(Events.COMPLETED, early_stopper)
    if scheduler is not None:
        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # free resources
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda _: _empty_cache())
    train_evaluator.add_event_handler(Events.ITERATION_COMPLETED, lambda _: _empty_cache())
    val_evaluator.add_event_handler(Events.ITERATION_COMPLETED, lambda _: _empty_cache())

    trainer.run(train_loader, max_epochs=epochs)
    tb_logger.close()

    # Evaluation with best model
    model.load_state_dict(torch.load(glob.glob(os.path.join(path_ckpt, "*.pt*"))[0])["model"])
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    val_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    train_evaluator.run(train_loader)
    val_evaluator.run(val_loader)
    _pretty_print("Evaluating best model")
    pbar.log_message(
        "Best model on training set - Loss: {:.6f} Accuracy: {:.6f}".format(
            train_evaluator.state.metrics["loss"],
            train_evaluator.state.metrics["accuracy"],
        )
    )
    pbar.log_message(
        "Best model on validation set - Loss: {:.6f} Accuracy: {:.6f}".format(
            val_evaluator.state.metrics["loss"], val_evaluator.state.metrics["accuracy"]
        )
    )

    return model, train_evaluator.state.metrics, val_evaluator.state.metrics
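
# _empty_cache and _pretty_print are referenced above but not defined in this
# snippet. Minimal sketches consistent with their usage (assumed
# implementations, not necessarily the originals):
def _empty_cache():
    # Release cached GPU memory between engine iterations.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _pretty_print(msg):
    # Print a message framed by separator lines for readability.
    print("=" * 40)
    print(msg)
    print("=" * 40)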
def attach_handlers(run, model, optimizer, trainer, train_evaluator, evaluator, train_loader, val_loader, params):
    # Tqdm logger
    pbar = ProgressBar(persist=True, bar_format=config.IGNITE_BAR_FORMAT)
    pbar.attach(trainer.engine, metric_names='all')
    tqdm_logger = TqdmLogger(pbar=pbar)
    # noinspection PyTypeChecker
    tqdm_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    # noinspection PyTypeChecker
    tqdm_logger.attach_output_handler(
        train_evaluator.engine,
        event_name=Events.COMPLETED,
        tag="train",
        global_step_transform=global_step_from_engine(trainer.engine),
    )

    # Evaluators
    train_evaluator.attach(trainer.engine, Events.EPOCH_COMPLETED, train_loader)
    evaluator.attach(trainer.engine, Events.EPOCH_COMPLETED, data=val_loader)

    # Learning rate scheduling
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', verbose=True,
                                                              patience=5, factor=0.5)
    evaluator.engine.add_event_handler(
        Events.COMPLETED,
        lambda engine: lr_scheduler.step(engine.state.metrics['accuracy']))

    # Early stopping
    es_handler = EarlyStopping(
        patience=15,
        score_function=lambda engine: engine.state.metrics['accuracy'],
        trainer=trainer.engine,
        cumulative_delta=True,
        min_delta=0.0001)
    if 'train_all' in params and params['train_all']:
        train_evaluator.engine.add_event_handler(Events.COMPLETED, es_handler)
    else:
        evaluator.engine.add_event_handler(Events.COMPLETED, es_handler)
    es_handler.logger.setLevel(logging.DEBUG)

    # Model checkpoints
    name = run.replace('/', '-')
    mc_handler = ModelCheckpoint(
        config.MODELS_DIR,
        name,
        n_saved=1,
        create_dir=True,
        require_empty=False,
        score_name='acc',
        score_function=lambda engine: engine.state.metrics['accuracy'],
        global_step_transform=global_step_from_engine(trainer.engine))
    evaluator.engine.add_event_handler(Events.EPOCH_COMPLETED, mc_handler, {'m': model})

    # TensorBoard logger
    tb_logger = TensorboardLogger(log_dir=os.path.join(config.TENSORBOARD_DIR, run))
    images, labels = next(iter(train_loader))
    tb_logger.writer.add_graph(copy.deepcopy(model).cpu(), images)
    tb_logger.writer.add_hparams(params, {'hparam/dummy': 0})
    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        train_evaluator.engine,
        event_name=Events.COMPLETED,
        tag="train",
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    input_shape = tuple(next(iter(train_loader))[0].shape[1:])
    tb_logger.attach(trainer.engine,
                     log_handler=WeightsImageHandler(model, input_shape),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer.engine, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.EPOCH_STARTED)
    # tb_logger.attach(trainer.engine, log_handler=WeightsScalarHandler(model), event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsHistHandler(model, layer_names=['linear1', 'batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=NumActivationsScalarHandler(model, layer_names=['linear1', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.mean,
    #                                                       layer_names=['linear1', 'batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.std,
    #                                                       layer_names=['linear1', 'batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)

    return es_handler, tb_logger
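
# Both attach_handlers variants above rely on a project-level `config` module
# (plus wrapper objects exposing an .engine attribute and custom handlers such
# as TqdmLogger and WeightsImageHandler). A sketch of the constants that module
# would need to provide (assumed values, not the originals):
#
# config.py
IGNITE_BAR_FORMAT = '{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]'
MODELS_DIR = 'models'          # where ModelCheckpoint writes weights
TENSORBOARD_DIR = 'runs'       # root of the TensorBoard log directories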
class Ignite_Trainer(Trainer):

    def __init__(self, config=None, cmd_args=None, framework='ignite', model=None, device=None,
                 optimizer=None, scheduler=None, criterion=None, train_loader=None, val_loader=None,
                 data_transforms=None):
        super().__init__(config=config,
                         cmd_args=cmd_args,
                         framework=framework,
                         model=model,
                         device=device,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         criterion=criterion,
                         train_loader=train_loader,
                         val_loader=val_loader,
                         data_transforms=data_transforms)
        self.train_engine = None
        self.evaluator = None
        self.train_evaluator = None
        self.tb_logger = None

    def create_trainer(self):
        # Create the gradient scaler once so its loss-scale state persists across
        # iterations (creating it inside the step would reset it every batch).
        scaler = GradScaler(enabled=self.config.MODEL.WITH_AMP)

        # Define any training logic for iteration update
        def train_step(engine, batch):
            # Get the images and labels for this batch
            x, y = batch[0].to(self.device), batch[1].to(self.device)

            # Set the model into training mode
            self.model.train()

            # Zero parameter gradients
            self.optimizer.zero_grad()

            # Update the model
            if self.config.MODEL.WITH_GRAD_SCALE:
                with autocast(enabled=self.config.MODEL.WITH_AMP):
                    y_pred = self.model(x)
                    loss = self.criterion(y_pred, y)
                scaler.scale(loss).backward()
                scaler.step(self.optimizer)
                scaler.update()
            else:
                with torch.set_grad_enabled(True):
                    y_pred = self.model(x)
                    loss = self.criterion(y_pred, y)
                    loss.backward()
                    self.optimizer.step()
            # Note: with ReduceLROnPlateau, the *scheduler's* step() call needs the
            # validation loss at the end of the epoch, so it is handled through an
            # evaluator event handler (see create_callbacks) rather than here.

            return loss.item()

        # Define trainer engine
        trainer = Engine(train_step)
        return trainer

    def create_evaluator(self, metrics, tag='val'):
        # Evaluation step function
        @torch.no_grad()
        def evaluate_step(engine: Engine, batch):
            self.model.eval()
            x, y = batch[0].to(self.device), batch[1].to(self.device)
            if self.config.MODEL.WITH_GRAD_SCALE:
                with autocast(enabled=self.config.MODEL.WITH_AMP):
                    y_pred = self.model(x)
            else:
                y_pred = self.model(x)
            return y_pred, y

        # Create the evaluator object
        evaluator = Engine(evaluate_step)

        # Attach the metrics
        for name, metric in metrics.items():
            metric.attach(evaluator, name)

        return evaluator

    def evaluate_model(self):
        epoch = self.train_engine.state.epoch

        # Training Metrics
        train_state = self.train_evaluator.run(self.train_loader)
        tr_accuracy = train_state.metrics['accuracy']
        tr_precision = train_state.metrics['precision']
        tr_recall = train_state.metrics['recall']
        tr_f1 = train_state.metrics['f1']
        tr_topKCatAcc = train_state.metrics['topKCatAcc']
        tr_loss = train_state.metrics['loss']

        # Validation Metrics
        val_state = self.evaluator.run(self.val_loader)
        val_accuracy = val_state.metrics['accuracy']
        val_precision = val_state.metrics['precision']
        val_recall = val_state.metrics['recall']
        val_f1 = val_state.metrics['f1']
        val_topKCatAcc = val_state.metrics['topKCatAcc']
        val_loss = val_state.metrics['loss']

        print("Epoch: {:0>4} TrAcc: {:.3f} ValAcc: {:.3f} TrPrec: {:.3f} ValPrec: {:.3f} "
              "TrRec: {:.3f} ValRec: {:.3f} TrF1: {:.3f} ValF1: {:.3f} TrTopK: {:.3f} "
              "ValTopK: {:.3f} TrLoss: {:.3f} ValLoss: {:.3f}"
              .format(epoch, tr_accuracy, val_accuracy, tr_precision, val_precision,
                      tr_recall, val_recall, tr_f1, val_f1,
                      tr_topKCatAcc, val_topKCatAcc, tr_loss, val_loss))

    def add_logging(self):
        # Add validation logging
        self.train_engine.add_event_handler(Events.EPOCH_COMPLETED(every=1), self.evaluate_model)

        # Add step length update at the end of each epoch
        self.train_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda _: self.scheduler.step())

    def add_tensorboard_logging(self, logging_dir=None):
        # Add TensorBoard logging
        if logging_dir is None:
            logging_dir = os.path.join(self.config.DIRS.WORKING_DIR, 'tb_logs')
        else:
            logging_dir = os.path.join(logging_dir, 'tb_logs')
        print('Tensorboard logging saving to:: {} ...'.format(logging_dir), end='')

        self.tb_logger = TensorboardLogger(log_dir=logging_dir)

        # Logging iteration loss
        self.tb_logger.attach_output_handler(
            engine=self.train_engine,
            event_name=Events.ITERATION_COMPLETED,
            tag='training',
            output_transform=lambda loss: {"batch loss": loss})

        # Logging epoch training metrics
        self.tb_logger.attach_output_handler(
            engine=self.train_evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag="training",
            metric_names=["loss", "accuracy", "precision", "recall", "f1", "topKCatAcc"],
            global_step_transform=global_step_from_engine(self.train_engine),
        )

        # Logging epoch validation metrics
        self.tb_logger.attach_output_handler(
            engine=self.evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag="validation",
            metric_names=["loss", "accuracy", "precision", "recall", "f1", "topKCatAcc"],
            global_step_transform=global_step_from_engine(self.train_engine),
        )

        # Attach the logger to the trainer to log model's weights as a histogram after each epoch
        self.tb_logger.attach(self.train_engine,
                              event_name=Events.EPOCH_COMPLETED,
                              log_handler=WeightsHistHandler(self.model))

        # Attach the logger to the trainer to log model's gradients as a histogram after each epoch
        self.tb_logger.attach(self.train_engine,
                              event_name=Events.EPOCH_COMPLETED,
                              log_handler=GradsHistHandler(self.model))

        print('Tensorboard Logging...', end='')
        print('done')

    def create_callbacks(self, best_model_only=True):
        ## SETUP CALLBACKS
        print('[INFO] Creating callback functions for training loop...', end='')

        # If using ReduceLROnPlateau then need to add an event to handle the step() call with loss:
        if self.config.TRAIN.SCHEDULER.TYPE == 'ReduceLROnPlateau':
            self.evaluator.add_event_handler(Events.COMPLETED, self.scheduler)
        else:
            print('No special event handling required for LR Scheduler....', end='')

        # Early Stopping - stops training if the validation loss does not improve
        # for EARLY_STOPPING_PATIENCE epochs
        handler = EarlyStopping(patience=self.config.EARLY_STOPPING_PATIENCE,
                                score_function=score_function_loss,
                                trainer=self.train_engine)
        self.evaluator.add_event_handler(Events.COMPLETED, handler)
        print('Early Stopping ({} epochs)...'.format(self.config.EARLY_STOPPING_PATIENCE), end='')

        # Model checkpointing
        self._create_ignite_model_checkpointer(best_model_only=best_model_only)

    def run(self, logging_dir=None, best_model_only=True):
        # assert self.model is not None, '[ERROR] No model object loaded. Please load a PyTorch model torch.nn object into the class object.'
        # assert (self.train_loader is not None) or (self.val_loader is not None), '[ERROR] You must specify data loaders.'
        for key in self.trainer_status.keys():
            assert self.trainer_status[key], \
                '[ERROR] The {} has not been generated and you cannot proceed.'.format(key)
        print('[INFO] Trainer pass OK for training.')

        # TRAIN ENGINE
        # Create the objects for training
        self.train_engine = self.create_trainer()

        # METRICS AND EVALUATION
        # Metrics - running average
        RunningAverage(output_transform=lambda x: x).attach(self.train_engine, 'loss')

        # Metrics - epochs
        metrics = {
            'accuracy': Accuracy(),
            'recall': Recall(average=True),
            'precision': Precision(average=True),
            'f1': Fbeta(beta=1),
            'topKCatAcc': TopKCategoricalAccuracy(k=5),
            'loss': Loss(self.criterion)
        }

        # Create evaluators
        self.evaluator = self.create_evaluator(metrics=metrics)
        self.train_evaluator = self.create_evaluator(metrics=metrics, tag='train')

        # LOGGING
        # Create logging to terminal
        self.add_logging()

        # Create Tensorboard logging
        self.add_tensorboard_logging(logging_dir=logging_dir)

        ## CALLBACKS
        self.create_callbacks(best_model_only=best_model_only)

        ## TRAIN
        # Train the model
        print('[INFO] Executing model training...')
        self.train_engine.run(self.train_loader, max_epochs=self.config.TRAIN.NUM_EPOCHS)
        print('[INFO] Model training is complete.')

    def update_model_from_checkpoint(self, checkpoint_file=None, load_to_device=True):
        '''
        Function to take a saved checkpoint of the model's weights and load it into the model.
        '''
        assert self.trainer_status['model'], \
            '[ERROR] You must create the model to load the weights. Use Trainer.create_model() method to first create your model, then load weights.'
        assert checkpoint_file is not None, \
            '[ERROR] You must provide the full path and name of the .pt file containing the saved weights of the model you want to update.'

        try:
            # Load the weights of the checkpointed model from the PT file
            self.model.load_state_dict(torch.load(f=checkpoint_file))
        except Exception as err:
            raise Exception('[ERROR] Something went wrong with loading the weights into the model.') from err
        else:
            print('[INFO] Successfully loaded weights into the model from weights file:: {}'.format(checkpoint_file))

        if load_to_device:
            self.model.to(self.device)
            print('[INFO] Successfully updated model and pushed it to the device {}'.format(self.device))
            # Print summary of model
            summary(self.model,
                    batch_size=self.config.TRAIN.BATCH_SIZE,
                    input_size=(3,
                                self.config.DATA.TRANSFORMS.PARAMS.DEFAULT.img_crop_size,
                                self.config.DATA.TRANSFORMS.PARAMS.DEFAULT.img_crop_size))
        else:
            print('[INFO] Successfully updated model but NOT pushed it to the device {}'.format(self.device))
            # Print summary of model
            summary(self.model,
                    device='cpu',
                    batch_size=self.config.TRAIN.BATCH_SIZE,
                    input_size=(3,
                                self.config.DATA.TRANSFORMS.PARAMS.DEFAULT.img_crop_size,
                                self.config.DATA.TRANSFORMS.PARAMS.DEFAULT.img_crop_size))

    def convert_to_torchscript(self, checkpoint_file=None, torchscript_model_path=None,
                               method='trace', return_jit_model=False):
        assert self.trainer_status['model'], \
            '[ERROR] You must create the model to load the weights. Use Trainer.create_model() method to first create your model, then load weights.'
        assert checkpoint_file is not None, \
            '[ERROR] You must provide the path and name of a PyTorch Ignite checkpoint file of model weights [checkpoint_file].'

        # Update the Trainer class attribute model with model weights file
        self.update_model_from_checkpoint(checkpoint_file=checkpoint_file)

        if torchscript_model_path is None:
            torchscript_model_path = os.path.join(os.getcwd(), 'torchscript_model.pt')

        if method == 'trace':
            assert self.trainer_status['val_loader'], \
                '[ERROR] You must create the validation loader in order to load images. Use Trainer.create_dataloaders() method to create access to image batches.'
            # Create an image batch
            X, _ = next(iter(self.val_loader))
            # Push the input images to the device
            X = X.to(self.device)
            # Trace the model
            jit_model = torch.jit.trace(self.model, (X))
            # Write the traced module of the model to disk
            print('[INFO] Torchscript file being saved to temporary location:: {}'.format(torchscript_model_path))
            jit_model.save(torchscript_model_path)
        elif method == 'script':
            # Script the model
            jit_model = torch.jit.script(self.model)
            # Write the scripted module of the model to disk
            print('[INFO] Torchscript file being saved to temporary location:: {}'.format(torchscript_model_path))
            jit_model.save(torchscript_model_path)

        if return_jit_model:
            return jit_model

    def _create_ignite_model_checkpointer(self, best_model_only=True):
        '''
        Function to create an Ignite model checkpointer based on validation accuracy
        (best_model_only == True), or at every epoch (best_model_only == False).
        '''
        print('Model Checkpointing...', end='')
        if best_model_only:
            print('best model checkpointing...', end='')
            # best model checkpointer, based on validation accuracy.
            self.model_checkpointer = ModelCheckpoint(
                dirname=self.config.DIRS.WORKING_DIR,
                filename_prefix='caltech_birds_ignite_best',
                score_function=score_function_acc,
                score_name='val_acc',
                n_saved=2,
                create_dir=True,
                save_as_state_dict=True,
                require_empty=False,
                global_step_transform=global_step_from_engine(self.train_engine))
            self.evaluator.add_event_handler(
                Events.COMPLETED, self.model_checkpointer,
                {self.config.MODEL.MODEL_NAME: self.model})
        else:
            # Checkpoint the model at every epoch
            print('every epoch model checkpointing...', end='')
            self.model_checkpointer = ModelCheckpoint(
                dirname=self.config.DIRS.WORKING_DIR,
                filename_prefix='caltech_birds_ignite',
                n_saved=2,
                create_dir=True,
                save_as_state_dict=True,
                require_empty=False)
            self.train_engine.add_event_handler(
                Events.EPOCH_COMPLETED, self.model_checkpointer,
                {self.config.MODEL.MODEL_NAME: self.model})
        print('Done')
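
# score_function_loss and score_function_acc are referenced by the early
# stopping and checkpointing handlers above but are not defined in this
# snippet. Sketches consistent with their usage (assumed implementations):
def score_function_loss(engine):
    # EarlyStopping treats higher scores as better, so negate the loss.
    return -engine.state.metrics['loss']


def score_function_acc(engine):
    return engine.state.metrics['accuracy']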
import ignite.distributed as idist
from ignite.contrib.handlers import TensorboardLogger, global_step_from_engine
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
from ignite.handlers import ModelCheckpoint
from ignite.metrics import Accuracy, Loss
from ignite.utils import manual_seed, setup_logger


def training(rank, config):
    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    # Define output folder:
    config.output = "/tmp/output"

    model = idist.auto_model(config.model)
    optimizer = idist.auto_optim(config.optimizer)
    criterion = config.criterion

    train_set, val_set = config.train_set, config.val_set
    train_loader = idist.auto_dataloader(train_set, batch_size=config.train_batch_size)
    val_loader = idist.auto_dataloader(val_set, batch_size=config.val_batch_size)

    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    trainer.logger = setup_logger("Trainer")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED(every=config.val_interval))
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    if rank == 0:
        tb_logger = TensorboardLogger(log_dir=config.output)

        tb_logger.attach_output_handler(
            trainer,
            event_name=Events.ITERATION_COMPLETED(every=100),
            tag="training",
            output_transform=lambda loss: {"batchloss": loss},
            metric_names="all",
        )

        for tag, evaluator in [("training", train_evaluator), ("validation", validation_evaluator)]:
            tb_logger.attach_output_handler(
                evaluator,
                event_name=Events.EPOCH_COMPLETED,
                tag=tag,
                metric_names=["loss", "accuracy"],
                global_step_transform=global_step_from_engine(trainer),
            )

        tb_logger.attach_opt_params_handler(
            trainer, event_name=Events.ITERATION_COMPLETED(every=100), optimizer=optimizer
        )

    model_checkpoint = ModelCheckpoint(
        config.output,
        n_saved=2,
        filename_prefix="best",
        score_name="accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model})

    trainer.run(train_loader, max_epochs=config.num_epochs)

    if rank == 0:
        tb_logger.close()
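
# A sketch of launching training() above with ignite.distributed. The backend
# and nproc_per_node values are illustrative, and `config` is assumed to be
# built elsewhere with the attributes used above:
def example_launch(config):
    with idist.Parallel(backend="nccl", nproc_per_node=2) as parallel:
        # Each spawned process calls training(local_rank, config).
        parallel.run(training, config)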