def setup_trainer(model, optimizer, device, data_parallel: bool) -> Engine:

    def update(trainer, batch: Tuple[torch.Tensor]):
        model.train()
        optimizer.zero_grad()
        if isinstance(batch, (tuple, list)):
            assert len(batch) == 1
            batch = batch[0]
        else:
            assert isinstance(batch, torch.Tensor)
        batch = batch.to(device)
        if not data_parallel:
            masked_batch = mask_for_forward(batch)  # replace -1 with some other token
            lm_logits = model(masked_batch)[0]
            loss = calculate_lm_loss(lm_logits, batch)
            loss.backward()
        else:
            # handling of -1 as padding is not implemented
            losses = model(batch, lm_labels=batch)
            losses.backward(torch.ones_like(losses))
            loss = losses.mean()
        optimizer.step()
        return loss.item()

    trainer = Engine(update)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    return trainer
def create_engine():
    engine = Engine(update)
    pbar = ProgressBar()
    engine.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    pbar.attach(engine,
                event_name=Events.EPOCH_COMPLETED,
                closing_event_name=Events.COMPLETED)
    return engine
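# --- Illustrative sketch (not part of any snippet above; toy names assumed) ---
# A minimal, self-contained example of the pattern shared by all of these
# snippets: attach TerminateOnNan so the run stops as soon as the update
# function produces a non-finite output.
from ignite.engine import Engine, Events
from ignite.handlers import TerminateOnNan


def toy_update(engine, batch):
    return batch  # pretend the batch value is the training loss


toy_trainer = Engine(toy_update)
toy_trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

state = toy_trainer.run([1.0, 0.5, float("nan"), 0.2], max_epochs=1)
print(state.iteration)  # 3 -- the run terminates at the NaN output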
def _setup_trainer_handlers(self, trainer):
    # Setup timer to measure training time
    timer = setup_timer(trainer)
    self._setup_log_training_loss(trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_time(engine):
        self.logger.info("One epoch training time (seconds): {}".format(timer.value()))

    last_model_saver = ModelCheckpoint(self.log_dir.as_posix(),
                                       filename_prefix="checkpoint",
                                       save_interval=self.trainer_checkpoint_interval,
                                       n_saved=1,
                                       atomic=True,
                                       create_dir=True,
                                       save_as_state_dict=True)
    model_name = get_object_name(self.model)
    to_save = {
        model_name: self.model,
        "optimizer": self.optimizer,
    }
    if self.lr_scheduler is not None:
        to_save['lr_scheduler'] = self.lr_scheduler
    trainer.add_event_handler(Events.ITERATION_COMPLETED, last_model_saver, to_save)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
def test_terminate_on_nan_and_inf(state_output, should_terminate):
    torch.manual_seed(12)

    def update_fn(engine, batch):
        pass

    trainer = Engine(update_fn)
    trainer.state = State()
    h = TerminateOnNan()

    trainer.state.output = state_output
    if isinstance(state_output, np.ndarray):
        h._output_transform = lambda x: x.tolist()
    h(trainer)
    assert trainer.should_terminate == should_terminate
def test_with_terminate_on_inf():
    torch.manual_seed(12)

    data = [
        1.0,
        0.8,
        torch.rand(4, 4),
        (1.0 / torch.randint(0, 2, size=(4, )).type(torch.float), torch.tensor(1.234)),
        torch.rand(5),
        torch.asin(torch.randn(4, 4)),
        0.0,
        1.0,
    ]

    def update_fn(engine, batch):
        return batch

    trainer = Engine(update_fn)
    h = TerminateOnNan()
    trainer.add_event_handler(Events.ITERATION_COMPLETED, h)
    trainer.run(data, max_epochs=2)
    assert trainer.state.iteration == 4
def _add_event_handlers(self):
    """
    Adds a progressbar and a summary writer to output the current training status.
    Adds event handlers to output common messages and update the progressbar.
    """
    progressbar_description = 'TRAINING => loss: {:.6f}'
    progressbar = tqdm(initial=0,
                       leave=False,
                       total=len(self.train_loader),
                       desc=progressbar_description.format(0))
    writer = SummaryWriter(self.log_directory)

    @self.trainer_engine.on(Events.ITERATION_COMPLETED)
    def log_training_loss(trainer):
        writer.add_scalar('loss', trainer.state.output)
        progressbar.desc = progressbar_description.format(trainer.state.output)
        progressbar.update(1)

    @self.trainer_engine.on(Events.EPOCH_COMPLETED)
    def log_training_results(trainer):
        progressbar.n = progressbar.last_print_n = 0
        self.evaluator.run(self.train_loader)
        metrics = self.evaluator.state.metrics
        for key, value in metrics.items():
            writer.add_scalar(key, value)
        tqdm.write('\nTraining Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}\n'.format(
            trainer.state.epoch, metrics['accuracy'], metrics['loss']))

    @self.trainer_engine.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        progressbar.n = progressbar.last_print_n = 0
        self.evaluator.run(self.val_loader)
        metrics = self.evaluator.state.metrics
        for key, value in metrics.items():
            writer.add_scalar(key, value)
        tqdm.write('\nValidation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}\n'.format(
            trainer.state.epoch, metrics['accuracy'], metrics['loss']))

    # create a Checkpoint handler that can be used to periodically
    # save model objects to disc.
    checkpoint_saver = ModelCheckpoint(self.checkpoint_directory,
                                       filename_prefix='net',
                                       save_interval=1,
                                       n_saved=5,
                                       atomic=True,
                                       create_dir=True,
                                       save_as_state_dict=False,
                                       require_empty=False)
    self.trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_saver, {'train': self.model})
    self.trainer_engine.add_event_handler(Events.COMPLETED, checkpoint_saver, {'complete': self.model})
    self.trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
def test_without_terminate_on_nan_inf():
    data = [1.0, 0.8, torch.rand(4, 4), (torch.rand(5), torch.rand(5, 4)), 0.0, 1.0]

    def update_fn(engine, batch):
        return batch

    trainer = Engine(update_fn)
    h = TerminateOnNan()
    trainer.add_event_handler(Events.ITERATION_COMPLETED, h)
    trainer.run(data, max_epochs=2)
    assert trainer.state.iteration == len(data) * 2
def _setup_common_training_handlers(trainer, to_save=None, save_every_iters=1000, output_path=None,
                                    lr_scheduler=None, with_gpu_stats=True, output_names=None,
                                    with_pbars=True, with_pbar_on_iters=True,
                                    log_every_iters=100, device='cuda'):
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step())
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:
        if output_path is None:
            raise ValueError("If to_save argument is provided then output_path argument should be also defined")
        checkpoint_handler = ModelCheckpoint(dirname=output_path, filename_prefix="training")
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler, to_save)

    if with_gpu_stats:
        GpuInfo().attach(trainer, name='gpu', event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

    if output_names is not None:

        def output_transform(x, index, name):
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, torch.Tensor):
                return x
            else:
                raise ValueError("Unhandled type of update_function's output. "
                                 "It should be either a mapping or a sequence, but given {}".format(type(x)))

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform, index=i, name=n),
                           epoch_bound=False, device=device).attach(trainer, n)

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(trainer, metric_names='all',
                                              event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

        ProgressBar(persist=True, bar_format="").attach(trainer, event_name=Events.EPOCH_STARTED,
                                                        closing_event_name=Events.COMPLETED)
def test_with_terminate_on_nan():
    torch.manual_seed(12)

    data = [1.0, 0.8, (torch.rand(4, 4), torch.rand(4, 4)),
            torch.rand(5), torch.asin(torch.randn(4, 4)), 0.0, 1.0]

    def update_fn(engine, batch):
        return batch

    trainer = Engine(update_fn)
    h = TerminateOnNan()
    trainer.add_event_handler(Events.ITERATION_COMPLETED, h)
    trainer.run(data, max_epochs=2)
    assert trainer.state.iteration == 5
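# --- Illustrative sketch (NOT the ignite implementation) ---
# A simplified handler showing roughly what the tests above expect from
# TerminateOnNan: walk engine.state.output and terminate the run when any
# number or tensor entry is NaN or infinite; non-numeric entries are ignored.
import numbers
import torch
from ignite.engine import Engine


class TerminateOnNonFinite:

    def __init__(self, output_transform=lambda x: x):
        self._output_transform = output_transform

    def __call__(self, engine: Engine) -> None:

        def _is_bad(value):
            if isinstance(value, numbers.Number):
                value = torch.tensor(value)
            if isinstance(value, torch.Tensor):
                return not bool(torch.isfinite(value).all())
            if isinstance(value, (list, tuple)):
                return any(_is_bad(v) for v in value)
            return False  # strings etc. are ignored, as in test_terminate_on_nan_and_inf

        if _is_bad(self._output_transform(engine.state.output)):
            engine.terminate()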
def __init__(self, name, model, log_dir, lr, lr_decay_step, adam=False):
    """
    Initialize to train the given model.

    :param name: The name of the model to be trained.
    :param model: The model to be trained.
    :param log_dir: String. The log directory of the tensorboard.
    :param lr: Float. The learning rate.
    :param lr_decay_step: Integer. The amount of steps the learning rate decays.
    :param adam: Bool. Whether to use adam optimizer or not.
    """
    super(Trainer, self).__init__(self.update_model)
    self.model = model

    # tqdm
    ProgressBar(persist=True).attach(self)

    # Optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    if adam:
        self.optimizer = torch.optim.Adam(params, lr=lr)
    else:
        self.optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9)

    # Scheduler
    if lr_decay_step > 0:
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=lr_decay_step, gamma=0.1)
        self.add_event_handler(Events.EPOCH_COMPLETED, lambda e: e.scheduler.step())
    else:
        self.scheduler = None

    # Terminate if nan values found
    self.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    # Tensorboard logging
    self.tb_logger = TensorboardLogger(log_dir=os.path.join(log_dir, name))
    self.add_event_handler(Events.COMPLETED, lambda x: self.tb_logger.close())
    self.tb_logger.attach(self,
                          log_handler=OptimizerParamsHandler(self.optimizer),
                          event_name=Events.EPOCH_COMPLETED)
    self.tb_logger.attach(self,
                          log_handler=OutputHandler(tag='training', output_transform=lambda x: {
                              'rpn_box_loss': round(self.state.output['loss_rpn_box_reg'].item(), 4),
                              'rpn_cls_loss': round(self.state.output['loss_objectness'].item(), 4),
                              'roi_box_loss': round(self.state.output['loss_box_reg'].item(), 4),
                              'roi_cls_loss': round(self.state.output['loss_classifier'].item(), 4)
                          }),
                          event_name=Events.EPOCH_COMPLETED)

    # Run on GPU (cuda) if available
    if torch.cuda.is_available():
        torch.cuda.set_device(int(get_free_gpu()))
        model.cuda(torch.cuda.current_device())
def _finetune(self, train_dl, val_dl, criterion, iter_num):
    print("Recovery")
    self.model.to_rank = False
    finetune_epochs = config["pruning"]["finetune_epochs"].get()
    optimizer_constructor = optimizer_constructor_from_config(config)
    optimizer = optimizer_constructor(self.model.parameters())
    finetune_engine = create_supervised_trainer(self.model, optimizer, criterion, self.device)

    # progress bar
    pbar = Progbar(train_dl, metrics='none')
    finetune_engine.add_event_handler(Events.ITERATION_COMPLETED, pbar)

    # log training loss
    if self.writer:
        finetune_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                          lambda engine: log_training_loss(engine, self.writer))

    # terminate on Nan
    finetune_engine.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    # model checkpoints
    checkpoint = ModelCheckpoint(config["pruning"]["out_path"].get(),
                                 require_empty=False,
                                 filename_prefix=f"pruning_iteration_{iter_num}",
                                 save_interval=1)
    finetune_engine.add_event_handler(Events.COMPLETED, checkpoint, {"weights": self.model.cpu()})

    # add early stopping
    validation_evaluator = create_supervised_evaluator(self.model, device=self.device, metrics=self._metrics)
    if config["pruning"]["early_stopping"].get():

        def _score_function(evaluator):
            return -evaluator.state.metrics["loss"]

        early_stop = EarlyStopping(config["pruning"]["patience"].get(), _score_function, finetune_engine)
        validation_evaluator.add_event_handler(Events.EPOCH_COMPLETED, early_stop)

    finetune_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                      lambda engine: run_evaluator(engine, validation_evaluator, val_dl))

    for handler_dict in self._finetune_handlers:
        finetune_engine.add_event_handler(handler_dict["event_name"], handler_dict["handler"],
                                          *handler_dict["args"], **handler_dict["kwargs"])

    # run training engine
    finetune_engine.run(train_dl, max_epochs=finetune_epochs)
def _train(self, model, optimizer, train_loader, max_epochs, **kwargs):
    trainer = create_supervised_trainer(model, optimizer, self.criterion)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    val_metrics = {"smape": SymmetricMeanAbsolutePercentageError()}
    evaluator = create_supervised_evaluator(model, metrics=val_metrics)

    @trainer.on(Events.COMPLETED)
    def log_training_results(trainer):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        self._print_line("Interval {} Training Results - Epoch: {} Avg smape: {:.2f}".format(
            kwargs.get("interval"), trainer.state.epoch, metrics["smape"]))

    trainer.run(train_loader, max_epochs=max_epochs)
    return (model, optimizer)
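# --- Illustrative sketch (toy model, optimizer and data; names assumed) ---
# The create_supervised_trainer + TerminateOnNan combination used in several
# snippets here, reduced to a self-contained, runnable form.
import torch
import torch.nn as nn
from ignite.engine import Events, create_supervised_trainer
from ignite.handlers import TerminateOnNan

toy_model = nn.Linear(4, 1)
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
toy_trainer = create_supervised_trainer(toy_model, toy_optimizer, nn.MSELoss())
toy_trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

toy_data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(5)]
toy_trainer.run(toy_data, max_epochs=2)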
def _add_event_handlers(self):
    """Add event handlers to output common messages and update the progressbar."""
    self.trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    self.trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, self._event_log_training_output)
    self.trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, self._event_update_progressbar_step)
    self.trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, self._event_update_step_counter)
    self.trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, self._event_log_training_results)
    self.trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, self._event_log_validation_results)
    self.trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, self._event_save_trainer_checkpoint)
    self.trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, self._event_reset_progressbar)
    self.trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, self._event_update_epoch_counter)
    self.trainer_engine.add_event_handler(Events.COMPLETED, self._event_cleanup)
def test_terminate_on_nan_and_inf():
    torch.manual_seed(12)

    def update_fn(engine, batch):
        pass

    trainer = Engine(update_fn)
    trainer.state = State()
    h = TerminateOnNan()

    trainer.state.output = 1.0
    h(trainer)
    assert not trainer.should_terminate

    trainer.state.output = torch.tensor(123.45)
    h(trainer)
    assert not trainer.should_terminate

    trainer.state.output = torch.asin(torch.randn(10, ))
    h(trainer)
    assert trainer.should_terminate
    trainer.should_terminate = False

    trainer.state.output = np.array([1.0, 2.0])
    h._output_transform = lambda x: x.tolist()
    h(trainer)
    assert not trainer.should_terminate
    h._output_transform = lambda x: x

    trainer.state.output = torch.asin(torch.randn(4, 4))
    h(trainer)
    assert trainer.should_terminate
    trainer.should_terminate = False

    trainer.state.output = (10.0, 1.0 / torch.randint(0, 2, size=(4, )).type(torch.float), 1.0)
    h(trainer)
    assert trainer.should_terminate
    trainer.should_terminate = False

    trainer.state.output = (1.0, torch.tensor(1.0), "abc")
    h(trainer)
    assert not trainer.should_terminate

    trainer.state.output = 1.0 / torch.randint(0, 2, size=(4, 4)).type(torch.float)
    h(trainer)
    assert trainer.should_terminate
    trainer.should_terminate = False

    trainer.state.output = (float("nan"), 10.0)
    h(trainer)
    assert trainer.should_terminate
    trainer.should_terminate = False

    trainer.state.output = float("inf")
    h(trainer)
    assert trainer.should_terminate
    trainer.should_terminate = False

    trainer.state.output = [float("nan"), 10.0]
    h(trainer)
    assert trainer.should_terminate
    trainer.should_terminate = False
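# --- Illustrative sketch (toy names assumed) ---
# When engine.state.output is a structured object (for example a dict of
# losses, as in several trainers below), TerminateOnNan accepts an
# output_transform that selects which values are checked for NaN/inf.
from ignite.engine import Engine, Events
from ignite.handlers import TerminateOnNan


def dict_update(engine, batch):
    return {"loss": batch, "tag": "ignored-string"}


dict_trainer = Engine(dict_update)
dict_trainer.add_event_handler(
    Events.ITERATION_COMPLETED,
    TerminateOnNan(output_transform=lambda out: out["loss"]),
)
dict_trainer.run([1.0, float("inf"), 1.0], max_epochs=1)  # stops at iteration 2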
def run(conf: DictConfig, local_rank=0, distributed=False): epochs = conf.train.epochs epoch_length = conf.train.epoch_length torch.manual_seed(conf.general.seed) if distributed: rank = dist.get_rank() num_replicas = dist.get_world_size() torch.cuda.set_device(local_rank) else: rank = 0 num_replicas = 1 torch.cuda.set_device(conf.general.gpu) device = torch.device('cuda') loader_args = dict() master_node = rank == 0 if master_node: print(conf.pretty()) if num_replicas > 1: epoch_length = epoch_length // num_replicas loader_args = dict(rank=rank, num_replicas=num_replicas) train_dl = create_train_loader(conf.data, **loader_args) if epoch_length < 1: epoch_length = len(train_dl) metric_names = list(conf.logging.stats) metrics = create_metrics(metric_names, device if distributed else None) G = instantiate(conf.model.G).to(device) D = instantiate(conf.model.D).to(device) G_loss = instantiate(conf.loss.G).to(device) D_loss = instantiate(conf.loss.D).to(device) G_opt = instantiate(conf.optim.G, G.parameters()) D_opt = instantiate(conf.optim.D, D.parameters()) G_ema = None if master_node and conf.G_smoothing.enabled: G_ema = instantiate(conf.model.G) if not conf.G_smoothing.use_cpu: G_ema = G_ema.to(device) G_ema.load_state_dict(G.state_dict()) G_ema.requires_grad_(False) to_save = { 'G': G, 'D': D, 'G_loss': G_loss, 'D_loss': D_loss, 'G_opt': G_opt, 'D_opt': D_opt, 'G_ema': G_ema } if master_node and conf.logging.model: logging.info(G) logging.info(D) if distributed: ddp_kwargs = dict(device_ids=[ local_rank, ], output_device=local_rank) G = torch.nn.parallel.DistributedDataParallel(G, **ddp_kwargs) D = torch.nn.parallel.DistributedDataParallel(D, **ddp_kwargs) train_options = { 'train': dict(conf.train), 'snapshot': dict(conf.snapshots), 'smoothing': dict(conf.G_smoothing), 'distributed': distributed } bs_dl = int(conf.data.loader.batch_size) * num_replicas bs_eff = conf.train.batch_size if bs_eff % bs_dl: raise AttributeError( "Effective batch size should be divisible by data-loader batch size " "multiplied by number of devices in use" ) # until there is no special bs for master node... 
upd_interval = max(bs_eff // bs_dl, 1) train_options['train']['update_interval'] = upd_interval if epoch_length < len(train_dl): # ideally epoch_length should be tied to the effective batch_size only # and the ignite trainer counts data-loader iterations epoch_length *= upd_interval train_loop, sample_images = create_train_closures(G, D, G_loss, D_loss, G_opt, D_opt, G_ema=G_ema, device=device, options=train_options) trainer = create_trainer(train_loop, metrics, device, num_replicas) to_save['trainer'] = trainer every_iteration = Events.ITERATION_COMPLETED trainer.add_event_handler(every_iteration, TerminateOnNan()) cp = conf.checkpoints pbar = None if master_node: log_freq = conf.logging.iter_freq log_event = Events.ITERATION_COMPLETED(every=log_freq) pbar = ProgressBar(persist=False) trainer.add_event_handler(Events.EPOCH_STARTED, on_epoch_start) trainer.add_event_handler(log_event, log_iter, pbar, log_freq) trainer.add_event_handler(Events.EPOCH_COMPLETED, log_epoch) pbar.attach(trainer, metric_names=metric_names) setup_checkpoints(trainer, to_save, epoch_length, conf) setup_snapshots(trainer, sample_images, conf) if 'load' in cp.keys() and cp.load is not None: if master_node: logging.info("Resume from a checkpoint: {}".format(cp.load)) trainer.add_event_handler(Events.STARTED, _upd_pbar_iter_from_cp, pbar) Checkpoint.load_objects(to_load=to_save, checkpoint=torch.load(cp.load, map_location=device)) try: trainer.run(train_dl, max_epochs=epochs, epoch_length=epoch_length) except Exception as e: import traceback logging.error(traceback.format_exc()) if pbar is not None: pbar.close()
def main(parser_args): """Main function to create trainer engine, add handlers to train and validation engines. Then runs train engine to perform training and validation. Args: parser_args (dict): parsed arguments """ dataloader_train, dataloader_validation = get_dataloaders(parser_args) criterion = nn.CrossEntropyLoss() unet = SphericalUNet(parser_args.pooling_class, parser_args.n_pixels, parser_args.depth, parser_args.laplacian_type, parser_args.kernel_size) unet, device = init_device(parser_args.device, unet) lr = parser_args.learning_rate optimizer = optim.Adam(unet.parameters(), lr=lr) def trainer(engine, batch): """Train Function to define train engine. Called for every batch of the train engine, for each epoch. Args: engine (ignite.engine): train engine batch (:obj:`torch.utils.data.dataloader`): batch from train dataloader Returns: :obj:`torch.tensor` : train loss for that batch and epoch """ unet.train() data, labels = batch labels = labels.to(device) data = data.to(device) output = unet(data) B, V, C = output.shape B_labels, V_labels, C_labels = labels.shape output = output.view(B * V, C) labels = labels.view(B_labels * V_labels, C_labels).max(1)[1] loss = criterion(output, labels) optimizer.zero_grad() loss.backward() optimizer.step() return loss.item() writer = SummaryWriter(parser_args.tensorboard_path) engine_train = Engine(trainer) engine_validate = create_supervised_evaluator( model=unet, metrics={"AP": EpochMetric(average_precision_compute_fn)}, device=device, output_transform=validate_output_transform) engine_train.add_event_handler( Events.EPOCH_STARTED, lambda x: print("Starting Epoch: {}".format(x.state.epoch))) engine_train.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) @engine_train.on(Events.EPOCH_COMPLETED) def epoch_validation(engine): """Handler to run the validation engine at the end of the train engine's epoch. 
Args: engine (ignite.engine): train engine """ print("beginning validation epoch") engine_validate.run(dataloader_validation) reduce_lr_plateau = ReduceLROnPlateau( optimizer, mode=parser_args.reducelronplateau_mode, factor=parser_args.reducelronplateau_factor, patience=parser_args.reducelronplateau_patience, ) @engine_validate.on(Events.EPOCH_COMPLETED) def update_reduce_on_plateau(engine): """Handler to reduce the learning rate on plateau at the end of the validation engine's epoch Args: engine (ignite.engine): validation engine """ ap = engine.state.metrics["AP"] mean_average_precision = np.mean(ap[1:]) reduce_lr_plateau.step(mean_average_precision) @engine_validate.on(Events.EPOCH_COMPLETED) def save_epoch_results(engine): """Handler to save the metrics at the end of the validation engine's epoch Args: engine (ignite.engine): validation engine """ ap = engine.state.metrics["AP"] mean_average_precision = np.mean(ap[1:]) print("Average precisions:", ap) print("mAP:", mean_average_precision) writer.add_scalars( "metrics", { "mean average precision (AR+TC)": mean_average_precision, "AR average precision": ap[2], "TC average precision": ap[1] }, engine_train.state.epoch, ) writer.close() step_scheduler = StepLR(optimizer, step_size=parser_args.steplr_step_size, gamma=parser_args.steplr_gamma) scheduler = create_lr_scheduler_with_warmup( step_scheduler, warmup_start_value=parser_args.warmuplr_warmup_start_value, warmup_end_value=parser_args.warmuplr_warmup_end_value, warmup_duration=parser_args.warmuplr_warmup_duration, ) engine_validate.add_event_handler(Events.EPOCH_COMPLETED, scheduler) earlystopper = EarlyStopping( patience=parser_args.earlystopping_patience, score_function=lambda x: -x.state.metrics["AP"][1], trainer=engine_train) engine_validate.add_event_handler(Events.EPOCH_COMPLETED, earlystopper) add_tensorboard(engine_train, optimizer, unet, log_dir=parser_args.tensorboard_path) engine_train.run(dataloader_train, max_epochs=parser_args.n_epochs) torch.save(unet.state_dict(), parser_args.model_save_path + "unet_state.pt")
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler, loss_fn, metrics): device = cfg['device_ids'][0] if torch.cuda.is_available( ) else 'cpu' #默认主卡设置为 max_epochs = cfg['max_epochs'] # create trainer if cfg['multi_gpu']: #多卡时,不需要传入loss_fn trainer = create_supervised_dp_trainer(model.train(), optimizer, device=device) else: trainer = create_supervised_trainer(model.train(), optimizer, loss_fn, device=device) trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) RunningAverage(output_transform=lambda x: x).attach(trainer, 'avg_loss') # create pbar len_train_loader = len(train_loader) pbar = tqdm(total=len_train_loader) ########################################################################################## ########### Events.ITERATION_COMPLETED ############# ########################################################################################## # 每 log_period 轮迭代结束输出train_loss @trainer.on(Events.ITERATION_COMPLETED) def log_training_loss(engine): log_period = cfg['log_period'] log_per_iter = int(log_period * len_train_loader) if int( log_period * len_train_loader) >= 1 else 1 # 计算打印周期 current_iter = (engine.state.iteration - 1) % len_train_loader + 1 + ( engine.state.epoch - 1) * len_train_loader # 计算当前 iter lr = optimizer.state_dict()['param_groups'][0]['lr'] if current_iter % log_per_iter == 0: pbar.write("Epoch[{}] Iteration[{}] lr {:.7f} Loss {:.7f}".format( engine.state.epoch, current_iter, lr, engine.state.metrics['avg_loss'])) pbar.update(log_per_iter) # lr_scheduler @trainer.on(Events.ITERATION_COMPLETED) def adjust_lr_scheduler(engine): if isinstance(scheduler, lr_scheduler.CyclicLR): scheduler.step() @trainer.on(Events.ITERATION_COMPLETED) def update_swa(engine): if isinstance(scheduler, lr_scheduler.CyclicLR): if cfg['enable_swa']: swa_period = 2 * cfg['lr_scheduler']['step_size_up'] current_iter = ( engine.state.iteration - 1) % len_train_loader + 1 + ( engine.state.epoch - 1) * len_train_loader # 计算当前 iter if current_iter % swa_period == 0: optimizer.update_swa() @trainer.on(Events.ITERATION_COMPLETED) def update_bn(engine): if isinstance(scheduler, lr_scheduler.CyclicLR): save_period = 2 * cfg['lr_scheduler']['step_size_up'] current_iter = ( engine.state.iteration - 1) % len_train_loader + 1 + ( engine.state.epoch - 1) * len_train_loader # 计算当前 iter if current_iter % save_period == 0 and current_iter >= save_period * 2: # 从第 4 个周期开始存 save_dir = cfg['save_dir'] if not os.path.isdir(save_dir): os.makedirs(save_dir) if cfg['enable_swa']: optimizer.swap_swa_sgd() optimizer.bn_update(train_loader, model, device=device) model_name = os.path.join( save_dir, cfg['model']['type'] + '_' + cfg['tag'] + "_" + str(current_iter) + ".pth") if cfg['multi_gpu']: save_pth = { 'model': model.module.model.state_dict(), 'cfg': cfg } torch.save(save_pth, model_name) else: save_pth = {'model': model.state_dict(), 'cfg': cfg} torch.save(save_pth, model_name) ########################################################################################## ################## Events.EPOCH_COMPLETED ############### ########################################################################################## @trainer.on(Events.EPOCH_COMPLETED) def save_temp_epoch(engine): save_dir = cfg['save_dir'] if not os.path.isdir(save_dir): os.makedirs(save_dir) epoch = engine.state.epoch if epoch % 1 == 0: model_name = os.path.join( save_dir, cfg['model']['type'] + '_' + cfg['tag'] + "_temp.pth") if cfg['multi_gpu']: save_pth = { 'model': model.module.model.state_dict(), 'cfg': cfg } 
torch.save(save_pth, model_name) else: save_pth = {'model': model.state_dict(), 'cfg': cfg} torch.save(save_pth, model_name) @trainer.on(Events.EPOCH_COMPLETED) def reset_pbar(engine): pbar.reset() trainer.run(train_loader, max_epochs=max_epochs) pbar.close()
def run(args): train_iter, valid_iter, test_iter, indexed_vector = load_dataset(args) # iters_per_epoch = len(train_iter) // 100 * 100 # 取整百,比如train dataset是7463,batch16,则每epoch有466.4 iteration iters_per_epoch = len(train_iter) model = LSTMClassifier(indexed_vector, hidden_dim=args.nhid, output_dim=args.nclass, num_layers=args.nlayers, dropout=args.dropout, bidirectional=args.bi) optimizer = optim.Adam( filter(lambda param: param.requires_grad, model.parameters())) criterion = nn.CrossEntropyLoss() trainer = create_supervised_trainer(model=model, optimizer=optimizer, loss_fn=criterion, device=args.device) train_evaluator = create_supervised_evaluator(model=model, metrics={ 'accuracy': Accuracy(), 'precision': Precision(), 'recall': Recall(), 'loss': Loss(criterion) }, device=args.device) valid_evaluator = create_supervised_evaluator(model=model, metrics={ 'accuracy': Accuracy(), 'precision': Precision(), 'recall': Recall(), 'loss': Loss(criterion) }, device=args.device) def loss_score(engine): loss = engine.state.output return -loss # 分数越高越好,所以loss取负 def acc_score(engine): accuracy = engine.state.metrics['accuracy'] return accuracy @trainer.on(Events.EPOCH_COMPLETED) def log_training_loss(engine): train_iter_num = engine.state.iteration logger.info("Epoch {} Iteration {}: Loss {:.4f}" "".format(engine.state.epoch, train_iter_num, engine.state.output)) @trainer.on(Events.ITERATION_COMPLETED) def log_validation_results(engine): train_iter_num = engine.state.iteration if train_iter_num > iters_per_epoch and train_iter_num % args.log_interval == 0: valid_evaluator.run(valid_iter) metrics = valid_evaluator.state.metrics logger.info( "Validation Results - Epoch {}, Iter {}: Avg accuracy {}, Precision {}, Recall {}, valid loss {:.4f}" "".format(engine.state.epoch, train_iter_num, metrics['accuracy'], metrics['precision'].tolist(), metrics['recall'].tolist(), metrics['loss'])) # train的每ITERATION检查loss是否是 "Nan" # 是的话终止训练 terminateonnan = TerminateOnNan(output_transform=lambda output: output) trainer.add_event_handler(Events.ITERATION_COMPLETED, terminateonnan) checkpoint_handler = ModelCheckpoint( dirname='./data/saved_models/', filename_prefix='checkpoint', score_function=acc_score, score_name="acc", save_interval=None, # 按次数周期保存 n_saved=3, require_empty=False, # 强制覆盖 create_dir=True, save_as_state_dict=False) # 因为valid的epoch往往是1,所以Events.EPOCH_COMPLETED和Events.COMPLETED是一样的 valid_evaluator.add_event_handler(Events.COMPLETED, checkpoint_handler, {'model': model}) patience = int(args.early_stop * (iters_per_epoch / args.log_interval)) earlystop_handler = EarlyStopping(patience=patience, score_function=acc_score, trainer=trainer) earlystop_handler._logger = logger valid_evaluator.add_event_handler(Events.COMPLETED, earlystop_handler) trainer_bar = ProgressBar() trainer_bar.attach(trainer, output_transform=lambda x: {'loss': x}) trainer.run(data=train_iter, max_epochs=args.epochs) logger.info("Best model: Epoch {}, Train iters {}, Valid iters {}, acc {}" "".format(earlystop_handler.best_state['epoch'], \ earlystop_handler.best_state['iters'], \ earlystop_handler.best_state['valid_iters'], earlystop_handler.best_score)) logger.info("Valid results in best model: {}".format( earlystop_handler.best_state['metrics'])) logger.info("Best models info: {}".format(str( checkpoint_handler._saved))) # [(0.65,['model_6_acc=0.65.pth']),...] 
best_models_info = { 'model_args': str(args.__dict__), 'checkpint_saved': checkpoint_handler._saved, 'train_epoch': earlystop_handler.best_state['epoch'], 'train_iters': earlystop_handler.best_state['iters'], 'valid_iters': earlystop_handler.best_state['valid_iters'], 'best_model_path': checkpoint_handler._saved[-1][1] [0], #checkpoint_handler._saved按sore升序排列的 'best_score': checkpoint_handler._saved[-1][0], 'score_function': checkpoint_handler._score_function.__name__, 'valid_results': { 'accuracy': earlystop_handler.best_state['metrics']['accuracy'], 'precision': earlystop_handler.best_state['metrics']['precision'].tolist(), 'recall': earlystop_handler.best_state['metrics']['recall'].tolist(), 'loss': earlystop_handler.best_state['metrics']['loss'] } } print(checkpoint_handler._saved) pprint(str(best_models_info)) # exit() # with open('./data/pkl/best_models_path.pkl', 'wb') as f: # pickle.dump(list(map(lambda model_info: model_info[1][0], checkpoint_handler._saved)), f) with open('./data/saved_models/best_models_info.json', 'w') as f: f.write(repr(best_models_info)) def test(test_iter, args): # with open('./data/pkl/best_models_path.pkl', 'rb') as f: # best_models = pickle.load(f) with open('./data/saved_models/best_models_info.json', 'r') as f: best_models_info = eval(f.read()) print("best models info: {}".format(best_models_info)) logger.info("best model path: {}".format( best_models_info['best_model_path'])) model = torch.load(best_models_info['best_model_path'], map_location=args.device) test_evaluator = create_supervised_evaluator(model=model, metrics={ 'accuracy': Accuracy(), 'precision': Precision(), 'recall': Recall(), 'loss': Loss(criterion) }, device=args.device) @test_evaluator.on(Events.COMPLETED) def log_test_results(engine): metrics = engine.state.metrics logger.info("Test Results: Avg accuracy: {}, Precision: {}, Recall: {}, Loss: {}" "".format( metrics['accuracy'], \ metrics['precision'].tolist(), metrics['recall'].tolist(), metrics['loss'] ) ) test_evaluator.run(test_iter) test(valid_iter, args)
def run(self, train_loader, val_loader, test_loader): """Perform model training and evaluation on holdout dataset.""" ## attach certain metrics to trainer ## CpuInfo().attach(self.trainer, "cpu_util") Loss(self.loss).attach(self.trainer, "loss") ###### configure evaluator settings ###### def get_output_transform(target: str, collapse_y: bool = False): return lambda out: metric_output_transform( out, self.loss, target, collapse_y=collapse_y) graph_num_classes = len(self.graph_classes) node_num_classes = len(self.node_classes) node_num_classes = 2 if node_num_classes == 1 else node_num_classes node_output_transform = get_output_transform("node") node_output_transform_collapsed = get_output_transform("node", collapse_y=True) graph_output_transform = get_output_transform("graph") graph_output_transform_collapsed = get_output_transform( "graph", collapse_y=True) # metrics we are interested in base_metrics: dict = { 'loss': Loss(self.loss), "cpu_util": CpuInfo(), 'node_accuracy_avg': Accuracy(output_transform=node_output_transform, is_multilabel=False), 'node_accuracy': LabelwiseAccuracy(output_transform=node_output_transform, is_multilabel=False), "node_recall": Recall(output_transform=node_output_transform_collapsed, is_multilabel=False, average=False), "node_precision": Precision(output_transform=node_output_transform_collapsed, is_multilabel=False, average=False), "node_f1_score": Fbeta(1, output_transform=node_output_transform_collapsed, average=False), "node_c_matrix": ConfusionMatrix(node_num_classes, output_transform=node_output_transform_collapsed, average=None) } metrics = dict(**base_metrics) # settings for the evaluator evaluator_settings = { "device": self.device, "loss_fn": self.loss, "node_classes": self.node_classes, "graph_classes": self.graph_classes, "non_blocking": True, "metrics": OrderedDict(sorted(metrics.items(), key=lambda m: m[0])), "pred_collector_function": self._pred_collector_function } ## configure evaluators ## val_evaluator = None if len(val_loader): val_evaluator = create_supervised_evaluator( self.model, **evaluator_settings) # configure behavior for early stopping if self.stopper: val_evaluator.add_event_handler(Events.COMPLETED, self.stopper) # configure behavior for checkpoint saving val_evaluator.add_event_handler(Events.COMPLETED, self.best_checkpoint_handler) val_evaluator.add_event_handler(Events.COMPLETED, self.latest_checkpoint_handler) else: self.trainer.add_event_handler(Events.COMPLETED, self.latest_checkpoint_handler) test_evaluator = None if len(test_loader): test_evaluator = create_supervised_evaluator( self.model, **evaluator_settings) ############################# @self.trainer.on(Events.STARTED) def log_training_start(trainer): self.custom_print("Start training...") @self.trainer.on(Events.EPOCH_COMPLETED) def compute_metrics(trainer): """Compute evaluation metric values after each epoch.""" epoch = trainer.state.epoch self.custom_print(f"Finished epoch {epoch:03d}!") if len(val_loader): self.persist_collection = True val_evaluator.run(val_loader) self._save_collected_predictions( prefix=f"validation_epoch{epoch:03}") # write metrics to file self.write_metrics(trainer, val_evaluator, suffix="validation") @self.trainer.on(Events.COMPLETED) def log_training_complete(trainer): """Trigger evaluation on test set if training is completed.""" epoch = trainer.state.epoch suffix = "(Early Stopping)" if epoch < self.epochs else "" self.custom_print("Finished after {:03d} epochs! 
{}".format( epoch, suffix)) # load best model for evaluation self.custom_print("Load best model for final evaluation...") last_checkpoint: str = self.best_checkpoint_handler.last_checkpoint or self.latest_checkpoint_handler.last_checkpoint best_checkpoint_path = os.path.join(self.save_path, last_checkpoint) checkpoint_path_dict: dict = { "latest_checkpoint_path": best_checkpoint_path # we want to load states from the best checkpoint as "latest" configuration for testing } self.model, self.optimizer, self.trainer, _, _, _ = self._load_checkpoint( self.model, self.optimizer, self.trainer, None, None, None, checkpoint_path_dict=checkpoint_path_dict) if len(test_loader): self.persist_collection = True test_evaluator.run(test_loader) self._save_collected_predictions(prefix="test_final") # write metrics to file self.write_metrics(trainer, test_evaluator, suffix="test") # terminate training if Nan values are produced self.trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) # start the actual training self.custom_print(f"Train for a maximum of {self.epochs} epochs...") self.trainer.run(train_loader, max_epochs=self.epochs)
def __init__(self, module, device, train_loss, train_loader, opt, lr_scheduler,
             max_epochs, max_grad_norm, test_metrics, test_loader, epochs_per_test,
             early_stopping, valid_loss, valid_loader, max_bad_valid_epochs,
             visualizer, writer, should_checkpoint_latest, should_checkpoint_best_valid):
    self._module = module
    self._module.to(device)
    self._device = device

    self._train_loss = train_loss
    self._train_loader = train_loader

    self._opt = opt
    self._lr_scheduler = lr_scheduler
    self._max_epochs = max_epochs
    self._max_grad_norm = max_grad_norm

    self._test_metrics = test_metrics
    self._test_loader = test_loader
    self._epochs_per_test = epochs_per_test

    self._valid_loss = valid_loss
    self._valid_loader = valid_loader
    self._max_bad_valid_epochs = max_bad_valid_epochs
    self._best_valid_loss = float("inf")
    self._num_bad_valid_epochs = 0

    self._visualizer = visualizer
    self._writer = writer
    self._should_checkpoint_best_valid = should_checkpoint_best_valid

    ### Training
    self._trainer = Engine(self._train_batch)
    AverageMetric().attach(self._trainer)
    ProgressBar(persist=True).attach(self._trainer, ["loss"])
    self._trainer.add_event_handler(Events.EPOCH_STARTED, lambda _: self._module.train())
    self._trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    self._trainer.add_event_handler(Events.ITERATION_COMPLETED, self._log_training_info)
    if should_checkpoint_latest:
        self._trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: self._save_checkpoint("latest"))

    ### Validation
    if early_stopping:
        self._validator = Engine(self._validate_batch)
        AverageMetric().attach(self._validator)
        ProgressBar(persist=False, desc="Validating").attach(self._validator)
        self._trainer.add_event_handler(Events.EPOCH_COMPLETED, self._validate)
        self._validator.add_event_handler(Events.EPOCH_STARTED, lambda _: self._module.eval())

    ### Testing
    self._tester = Engine(self._test_batch)
    AverageMetric().attach(self._tester)
    ProgressBar(persist=False, desc="Testing").attach(self._tester)
    self._trainer.add_event_handler(Events.EPOCH_COMPLETED, self._test)
    self._tester.add_event_handler(Events.EPOCH_STARTED, lambda _: self._module.eval())
def run(conf: DictConfig, local_rank=0, distributed=False): epochs = conf.train.epochs epoch_length = conf.train.epoch_length torch.manual_seed(conf.seed) if distributed: rank = dist.get_rank() num_replicas = dist.get_world_size() torch.cuda.set_device(local_rank) else: rank = 0 num_replicas = 1 torch.cuda.set_device(conf.gpu) device = torch.device('cuda') loader_args = dict(mean=conf.data.mean, std=conf.data.std) master_node = rank == 0 if master_node: print(conf.pretty()) if num_replicas > 1: epoch_length = epoch_length // num_replicas loader_args["rank"] = rank loader_args["num_replicas"] = num_replicas train_dl = create_train_loader(conf.data.train, **loader_args) valid_dl = create_val_loader(conf.data.val, **loader_args) if epoch_length < 1: epoch_length = len(train_dl) model = instantiate(conf.model).to(device) model_ema, update_ema = setup_ema(conf, model, device=device, master_node=master_node) optim = build_optimizer(conf.optim, model) scheduler_kwargs = dict() if "schedule.OneCyclePolicy" in conf.lr_scheduler["class"]: scheduler_kwargs["cycle_steps"] = epoch_length lr_scheduler: Scheduler = instantiate(conf.lr_scheduler, optim, **scheduler_kwargs) use_amp = False if conf.use_apex: import apex from apex import amp logging.debug("Nvidia's Apex package is available") model, optim = amp.initialize(model, optim, **conf.amp) use_amp = True if master_node: logging.info("Using AMP with opt_level={}".format( conf.amp.opt_level)) else: apex, amp = None, None to_save = dict(model=model, optim=optim) if use_amp: to_save["amp"] = amp if model_ema is not None: to_save["model_ema"] = model_ema if master_node and conf.logging.model: logging.info(model) if distributed: sync_bn = conf.distributed.sync_bn if apex is not None: if sync_bn: model = apex.parallel.convert_syncbn_model(model) model = apex.parallel.distributed.DistributedDataParallel( model, delay_allreduce=True) else: if sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[ local_rank, ], output_device=local_rank) upd_interval = conf.optim.step_interval ema_interval = conf.smoothing.interval_it * upd_interval clip_grad = conf.optim.clip_grad _handle_batch_train = build_process_batch_func(conf.data, stage="train", device=device) _handle_batch_val = build_process_batch_func(conf.data, stage="val", device=device) def _update(eng: Engine, batch: Batch) -> FloatDict: model.train() batch = _handle_batch_train(batch) losses: Dict = model(*batch) stats = {k: v.item() for k, v in losses.items()} loss = losses["loss"] del losses if use_amp: with amp.scale_loss(loss, optim) as scaled_loss: scaled_loss.backward() else: loss.backward() it = eng.state.iteration if not it % upd_interval: if clip_grad > 0: params = amp.master_params( optim) if use_amp else model.parameters() torch.nn.utils.clip_grad_norm_(params, clip_grad) optim.step() optim.zero_grad() lr_scheduler.step_update(it) if not it % ema_interval: update_ema() eng.state.lr = optim.param_groups[0]["lr"] return stats calc_map = conf.validate.calc_map min_score = conf.validate.get("min_score", -1) model_val = model if conf.train.skip and model_ema is not None: model_val = model_ema.to(device) def _validate(eng: Engine, batch: Batch) -> FloatDict: model_val.eval() images, targets = _handle_batch_val(batch) with torch.no_grad(): out: Dict = model_val(images, targets) pred_boxes = out.pop("detections") stats = {k: v.item() for k, v in out.items()} if calc_map: pred_boxes = pred_boxes.detach().cpu().numpy() 
true_boxes = targets['bbox'].cpu().numpy() img_scale = targets['img_scale'].cpu().numpy() # yxyx -> xyxy true_boxes = true_boxes[:, :, [1, 0, 3, 2]] # xyxy -> xywh true_boxes[:, :, [2, 3]] -= true_boxes[:, :, [0, 1]] # scale downsized boxes to match predictions on a full-sized image true_boxes *= img_scale[:, None, None] scores = [] for i in range(len(images)): mask = pred_boxes[i, :, 4] >= min_score s = calculate_image_precision(true_boxes[i], pred_boxes[i, mask, :4], thresholds=IOU_THRESHOLDS, form='coco') scores.append(s) stats['map'] = np.mean(scores) return stats train_metric_names = list(conf.logging.out.train) train_metrics = create_metrics(train_metric_names, device if distributed else None) val_metric_names = list(conf.logging.out.val) if calc_map: from utils.metric import calculate_image_precision, IOU_THRESHOLDS val_metric_names.append('map') val_metrics = create_metrics(val_metric_names, device if distributed else None) trainer = build_engine(_update, train_metrics) evaluator = build_engine(_validate, val_metrics) to_save['trainer'] = trainer every_iteration = Events.ITERATION_COMPLETED trainer.add_event_handler(every_iteration, TerminateOnNan()) if distributed: dist_bn = conf.distributed.dist_bn if dist_bn in ["reduce", "broadcast"]: from timm.utils import distribute_bn @trainer.on(Events.EPOCH_COMPLETED) def _distribute_bn_stats(eng: Engine): reduce = dist_bn == "reduce" if master_node: logging.info("Distributing BN stats...") distribute_bn(model, num_replicas, reduce) sampler = train_dl.sampler if isinstance(sampler, (CustomSampler, DistributedSampler)): @trainer.on(Events.EPOCH_STARTED) def _set_epoch(eng: Engine): sampler.set_epoch(eng.state.epoch - 1) @trainer.on(Events.EPOCH_COMPLETED) def _scheduler_step(eng: Engine): # it starts from 1, so we don't need to add 1 here ep = eng.state.epoch lr_scheduler.step(ep) cp = conf.checkpoints pbar, pbar_vis = None, None if master_node: log_interval = conf.logging.interval_it log_event = Events.ITERATION_COMPLETED(every=log_interval) pbar = ProgressBar(persist=False) pbar.attach(trainer, metric_names=train_metric_names) pbar.attach(evaluator, metric_names=val_metric_names) for engine, name in zip([trainer, evaluator], ['train', 'val']): engine.add_event_handler(Events.EPOCH_STARTED, on_epoch_start) engine.add_event_handler(log_event, log_iter, pbar, interval_it=log_interval, name=name) engine.add_event_handler(Events.EPOCH_COMPLETED, log_epoch, name=name) setup_checkpoints(trainer, to_save, epoch_length, conf) if 'load' in cp.keys() and cp.load is not None: if master_node: logging.info("Resume from a checkpoint: {}".format(cp.load)) trainer.add_event_handler(Events.STARTED, _upd_pbar_iter_from_cp, pbar) resume_from_checkpoint(to_save, cp, device=device) state = trainer.state # epoch counter start from 1 lr_scheduler.step(state.epoch - 1) state.max_epochs = epochs @trainer.on(Events.EPOCH_COMPLETED(every=conf.validate.interval_ep)) def _run_validation(eng: Engine): if distributed: torch.cuda.synchronize(device) evaluator.run(valid_dl) skip_train = conf.train.skip if master_node and conf.visualize.enabled: vis_eng = evaluator if skip_train else trainer setup_visualizations(vis_eng, model, valid_dl, device, conf, force_run=skip_train) try: if skip_train: evaluator.run(valid_dl) else: trainer.run(train_dl, max_epochs=epochs, epoch_length=epoch_length) except Exception as e: import traceback logging.error(traceback.format_exc()) for pb in [pbar, pbar_vis]: if pb is not None: pbar.close()
def run(args, random_seed=0): # Set random seed np.random.seed(random_seed) torch.manual_seed(random_seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False writer = SummaryWriter(os.path.join(args['log_dir'], args['name'])) print('Loading model....') training_history = {'CrossEntropy': [], 'Accuracy': []} testing_history = {'CrossEntropy': [], 'Accuracy': []} model_dict = args['config']['model'] model = FastWeights(**model_dict) device = torch.device(args['config']['device']) model = model.to(device) print('Loading data....') train_loader, test_loader = load_data(args['config']['batch_size'], args['config']['workers']) params = [p for p in model.parameters() if p.requires_grad] # optimizer = torch.optim.SGD( # params, lr=args['config']['lr'], # momentum=args['config']['momentum'], # weight_decay=args['config']['weight_decay'] # ) optimizer = torch.optim.Adam(params, lr=args['config']['lr']) if args['config']['scheduler'] == 'multi': lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer, milestones=args['config']['lr_steps'], gamma=args['config']['lr_gamma']) elif args['config']['scheduler'] == 'step': lr_scheduler = torch.optim.lr_scheduler.StepLR( optimizer, milestones=args['config']['lr_step_size'], gamma=args['config']['lr_gamma']) elif args['config']['scheduler'] == 'reduce': lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode=args['config']['reduce_type'], factor=args['config']['lr_gamma'], ) elif args['config']['scheduler'] == 'cyclic': lr_scheduler = torch.optim.lr_scheduler.CyclicLR( optimizer, base_lr=args['config']['lr'], max_lr=10 * args['config']['lr']) else: lr_scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lr_lambda=lambda epoch: 1) criterion = nn.CrossEntropyLoss() def evaluate_function(engine, batch): model.eval() with torch.no_grad(): inputs, targets = batch if device: inputs = inputs.to(device) targets = targets.to(device) inputs = torch.transpose(inputs, 0, 1) preds = model(inputs) return preds, targets def process_function(engine, batch): model.train() optimizer.zero_grad() inputs, targets = batch if device: inputs = inputs.to(device) targets = targets.to(device) inputs = torch.transpose(inputs, 0, 1) preds = model(inputs) loss = criterion(preds, targets) loss.backward() if args['config']['max_norm'] > 0: nn.utils.clip_grad_norm_(model.parameters(), max_norm=args['config']['max_norm']) optimizer.step() if args['config']['scheduler'] == 'cyclic': lr_scheduler.step() else: pass return loss.item() trainer = Engine(process_function) evaluator = Engine(evaluate_function) train_evaluator = Engine(evaluate_function) RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss') Loss(criterion, output_transform=lambda x: [x[0], x[1]]).attach( evaluator, 'CrossEntropy') Accuracy().attach(evaluator, 'Accuracy') Loss(criterion, output_transform=lambda x: [x[0], x[1]]).attach( train_evaluator, 'CrossEntropy') Accuracy().attach(train_evaluator, 'Accuracy') pbar = ProgressBar(persist=True, bar_format="") pbar.attach(trainer, ['loss']) @trainer.on(Events.STARTED) def resume_training(engine): if args['config']['resume'] > 0: checkpoint = torch.load(os.path.join( args['dir'], args['config']['output_dir'], f"{args['name']}_{args['config']['resume']}.pth"), map_location=device) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) engine.state.epoch = args['config']['resume'] def val_score(engine): 
evaluator.run(test_loader) metrics = evaluator.state.metrics avg_loss = metrics['CrossEntropy'] return -avg_loss def checkpointer(engine): checkpoint = { 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': engine.state.epoch, 'args': args } save_on_master( checkpoint, os.path.join(args['dir'], args['config']['output_dir'], f"{args['name']}_{engine.state.epoch}.pth")) trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer) trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) def score_function(engine): metrics = evaluator.state.metrics avg_loss = metrics['CrossEntropy'] return -avg_loss def print_trainer_logs(engine): train_evaluator.run(train_loader) metrics = train_evaluator.state.metrics avg_loss = metrics['CrossEntropy'] avg_acc = metrics['Accuracy'] * 100 training_history['CrossEntropy'].append(avg_loss) training_history['Accuracy'].append(avg_acc) writer.add_scalar("training/avg_loss", avg_loss, engine.state.epoch) writer.add_scalar("training/avg_accuracy", avg_acc, engine.state.epoch) print("Training Results - Epoch: {} ".format(engine.state.epoch), "Avg loss: {:.4f} ".format(avg_loss), "Avg Acc: {:.4f} ".format(avg_acc)) trainer.add_event_handler(Events.EPOCH_COMPLETED, print_trainer_logs) def log_validation_results(engine): evaluator.run(test_loader) metrics = evaluator.state.metrics avg_loss = metrics['CrossEntropy'] avg_acc = metrics['Accuracy'] * 100 if args['config']['scheduler'] == 'reduce': lr_scheduler.step(-avg_loss) elif args['config']['scheduler'] == 'cyclic': pass else: lr_scheduler.step() testing_history['CrossEntropy'].append(avg_loss) testing_history['Accuracy'].append(avg_acc) writer.add_scalar("validation/avg_loss", avg_loss, engine.state.epoch) writer.add_scalar("validation/avg_accuracy", avg_acc, engine.state.epoch) print("Validation Results - Epoch: {} ".format(engine.state.epoch), "Avg loss: {:.4f} ".format(avg_loss), "Avg Acc: {:.4f} ".format(avg_acc)) trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results) handler = EarlyStopping(patience=args['config']['patience'], score_function=score_function, trainer=trainer) evaluator.add_event_handler(Events.COMPLETED, handler) print('Training....') trainer.run(train_loader, max_epochs=args['config']['epochs']) writer.close() np.save(os.path.join(args['dir'], f"{args['name']}_traininglog.npy"), [training_history]) np.save(os.path.join(args['dir'], f"{args['name']}_testinglog.npy"), [testing_history])
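# --- Illustrative sketch (toy engines; names assumed) ---
# The EarlyStopping pattern shared by several snippets above: an evaluator
# scores each run, and it is the *trainer* that gets stopped once the score
# fails to improve for `patience` consecutive evaluations.
from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping

toy_trainer = Engine(lambda engine, batch: batch)
toy_evaluator = Engine(lambda engine, batch: batch)


def neg_loss_score(engine):
    # EarlyStopping treats higher scores as better, so negate a loss-like output
    return -engine.state.output


stopper = EarlyStopping(patience=3, score_function=neg_loss_score, trainer=toy_trainer)
toy_evaluator.add_event_handler(Events.COMPLETED, stopper)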
def train(train_data_path, valid_data_path, config, out_dir="./explainability",
          batch_size=64, lr=1e-4, epochs=100):
    train_dataset = dataset_classes[config['type']](train_data_path, config)
    valid_dataset = dataset_classes[config['type']](valid_data_path, config)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=8)
    val_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=8)

    model = Model(valid_dataset, config)

    path_cp = "./explainability_checkpoints/" + out_dir
    os.makedirs(path_cp, exist_ok=True)
    with open(path_cp + "/config.json", 'w') as config_file:
        json.dump(config, config_file)

    optimizer = Adam(model.parameters(), lr=lr)
    trainer = create_supervised_trainer(model, optimizer, weighted_binary_cross_entropy, device=model.device)
    validation_evaluator = create_evaluator(model)

    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")

    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names="all")

    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    best_model_handler = ModelCheckpoint(dirname="./explainability_checkpoints/" + out_dir,
                                         filename_prefix="best",
                                         n_saved=1,
                                         global_step_transform=global_step_from_engine(trainer),
                                         score_name="val_ap",
                                         score_function=lambda engine: engine.state.metrics['ap'],
                                         require_empty=False)
    validation_evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {'model': model})

    tb_logger = TensorboardLogger(log_dir='./explainability_tensorboard/' + out_dir)
    tb_logger.attach(
        trainer,
        log_handler=OutputHandler(tag="training",
                                  output_transform=lambda loss: {"batchloss": loss},
                                  metric_names="all"),
        event_name=Events.ITERATION_COMPLETED(every=100),
    )
    tb_logger.attach(
        validation_evaluator,
        log_handler=OutputHandler(tag="validation", metric_names=["ap"], another_engine=trainer),
        event_name=Events.EPOCH_COMPLETED,
    )
    # tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_COMPLETED(every=100))
    # tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
    # tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))
    # tb_logger.attach(trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
    # tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))

    @trainer.on(Events.EPOCH_COMPLETED(every=5))
    def log_validation_results(engine):
        validation_evaluator.run(val_loader)
        metrics = validation_evaluator.state.metrics
        pbar.log_message(
            f"Validation Results - Epoch: {engine.state.epoch} ap: {metrics['ap']}"
            # f1: {metrics['f1']}, p: {metrics['p']}, r: {metrics['r']}
        )
        pbar.n = pbar.last_print_n = 0

    trainer.run(train_loader, max_epochs=epochs)
def do_train(cfg,model,train_loader,val_loader,optimizer,scheduler,loss_fn,metrics,image_3_dataloader=None,image_4_dataloader=None):
    device = cfg.MODEL.DEVICE if torch.cuda.is_available() else 'cpu'
    epochs = cfg.SOLVER.MAX_EPOCHS
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("Trainer")
    logger.info("Start training")
    trainer = create_supervised_trainer(model.train(),optimizer,loss_fn,device=device)
    trainer.add_event_handler(Events.ITERATION_COMPLETED,TerminateOnNan())
    evaluator = create_supervised_evaluator(model.eval(),metrics={"pixel_error":metrics},device=device)
    # evaluator_trainer = create_supervised_evaluator(model.eval(),metrics={"pixel_error":(metrics)},device=device)
    timer = Timer(average=True)
    timer.attach(trainer,start=Events.EPOCH_STARTED,resume=Events.ITERATION_STARTED,pause=Events.ITERATION_COMPLETED,step=Events.ITERATION_COMPLETED)
    RunningAverage(output_transform=lambda x:x).attach(trainer,'avg_loss')

    # log the training loss every log_period iterations
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        len_train_loader = len(train_loader)
        log_period = int(cfg.LOG_PERIOD*len_train_loader)
        iter = (engine.state.iteration-1)%len_train_loader + 1 + engine.state.epoch*len_train_loader
        if iter % log_period == 0:
            iter = (engine.state.iteration-1)%len_train_loader + 1
            logger.info("Epoch[{}] Iteration[{}/{}] Loss {:.7f}".format(engine.state.epoch,iter,len_train_loader,engine.state.metrics['avg_loss']))

    @trainer.on(Events.EPOCH_COMPLETED)
    def save(engine):
        epoch = engine.state.epoch
        print("epoch: "+str(epoch))
        if epoch%1 == 0:
            model_name=os.path.join(cfg.OUTPUT.DIR_NAME+"model/","epoch_"+str(engine.state.epoch)+"_"+cfg.TAG+"_"+cfg.MODEL.NET_NAME+".pth")
            torch.save(model.module.state_dict(),model_name)

    # compute the validation metric every val_period iterations
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_val_metric(engine):
        len_train_loader = len(train_loader)
        iter = (engine.state.iteration-1)%len_train_loader + 1 + engine.state.epoch*len_train_loader
        val_period = int(cfg.VAL_PERIOD*len_train_loader)
        if iter % val_period == 0:
            pass
            # print the validation results
            # evaluator.run(val_loader)
            # metrics = evaluator.state.metrics
            # avg_loss = metrics["pixel_error"]
            # logger.info("Validation Result - Epoch: {} Avg Pixel Accuracy: {:.7f} ".format(engine.state.epoch,avg_loss))
            ######################
            #
            # run TTA forward for each test image
            # cfg.TOOLS.image_n = 3
            # image_3_predict = tta_forward(cfg,image_3_dataloader,model.eval())
            # pil_image_3 = Image.fromarray(image_3_predict)
            # image_3_save_path = "iter_" + str(iter) + "_" + "image_3_predict.png"
            # pil_image_3.save(os.path.join(r"./output",image_3_save_path))
            # image_3_label_save_path = "iter_" + str(iter) + "_" + "vis_" + "image_3_predict.jpg"
            # source_image_3 = cv.imread("./output/source/image_3.png")
            # mask_3 = label_resize_vis(image_3_predict,source_image_3)
            # cv.imwrite(os.path.join(r"./output",image_3_label_save_path),mask_3)
            # cfg.TOOLS.image_n = 4
            # image_4_predict = tta_forward(cfg,image_4_dataloader,model.eval())
            # pil_image_4 = Image.fromarray(image_4_predict)
            # image_4_save_path = "iter_" + str(iter) + "_" + "image_4_predict.png"
            # pil_image_4.save(os.path.join(r"./output",image_4_save_path))
            # image_4_label_save_path = "iter_" + str(iter) + "_" + "vis_" + "image_4_predict.jpg"
            # source_image_4 = cv.imread("./output/source/image_4.png")
            # mask_4 = label_resize_vis(image_4_predict,source_image_4)
            # cv.imwrite(os.path.join(r"./output",image_4_label_save_path),mask_4)

            # watch the loss: when pixel accuracy stops improving, adjust the learning rate
            if cfg.SOLVER.LR_SCHEDULER == "StepLR":
                lr = optimizer.state_dict()['param_groups'][0]['lr']
                scheduler.step()
                new_lr = optimizer.state_dict()['param_groups'][0]['lr']
            # elif cfg.SOLVER.LR_SCHEDULER == "ReduceLROnPlateau":
            #     lr = optimizer.state_dict()['param_groups'][0]['lr']
            #     scheduler.step(-avg_loss)
            #     new_lr = optimizer.state_dict()['param_groups'][0]['lr']
            #     print(new_lr,lr)
            #     if new_lr != lr:
            #         cfg.SOLVER.LR_SCHEDULER_REPEAT = cfg.SOLVER.LR_SCHEDULER_REPEAT - 1
            #         if cfg.SOLVER.LR_SCHEDULER_REPEAT < 0: trainer.terminate()  # cap the number of lr reductions; stop training once the lr has been lowered too many times
            elif cfg.SOLVER.LR_SCHEDULER == "CosineAnnealingLR":
                lr = optimizer.state_dict()['param_groups'][0]['lr']
                scheduler.step()
                new_lr = optimizer.state_dict()['param_groups'][0]['lr']
                pass
            if new_lr != lr:
                print(new_lr,lr)

    # @trainer.on(Events.EPOCH_COMPLETED)
    # def log_training_result(engine):
    #     if engine.state.epoch % 5 == 0:
    #         evaluator_trainer.run(train_loader)
    #         metrics = evaluator_trainer.state.metrics
    #         avg_loss = metrics["pixel_error"]
    #         logger.info("Training Result - Epoch: {} Avg Pixel Error: {:.7f} ".format(engine.state.epoch,avg_loss))

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info("Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]".format(engine.state.epoch,timer.value()*timer.step_count, train_loader.batch_size / timer.value()))
        timer.reset()

    # def score_pixel_error(engine):
    #     error = evaluator.state.metrics['pixel_error']
    #     return error
    # handler_ModelCheckpoint_pixel_error = ModelCheckpoint(dirname=cfg.OUTPUT.DIR_NAME+"model/",filename_prefix=cfg.TAG+"_"+cfg.MODEL.NET_NAME,
    #     score_function=score_pixel_error,n_saved=cfg.OUTPUT.N_SAVED,create_dir=True,score_name=cfg.SOLVER.CRITERION,require_empty=False)
    # evaluator.add_event_handler(Events.EPOCH_COMPLETED,handler_ModelCheckpoint_pixel_error,{'model':model.module.state_dict()})
    trainer.run(train_loader,max_epochs=epochs)
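# Hedged sketch (stand-in model and data): stepping a plain torch StepLR from an
# Ignite epoch event and reading the current learning rate back from the optimizer,
# the same bookkeeping the handler above does for its StepLR/CosineAnnealingLR branches.
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from ignite.engine import Engine, Events

model = torch.nn.Linear(2, 1)
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=1, gamma=0.5)

def update(engine, batch):
    optimizer.zero_grad()
    loss = model(torch.randn(1, 2)).sum()
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update)

@trainer.on(Events.EPOCH_COMPLETED)
def step_scheduler(engine):
    old_lr = optimizer.state_dict()['param_groups'][0]['lr']
    scheduler.step()
    new_lr = optimizer.state_dict()['param_groups'][0]['lr']
    if new_lr != old_lr:
        print("epoch {}: lr {} -> {}".format(engine.state.epoch, old_lr, new_lr))

trainer.run([0, 1], max_epochs=3)   # the lr is halved after every epoch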
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler, metrics, device):
    def _prepare_batch(batch, device=None, non_blocking=False):
        """Prepare batch for training: pass to a device with options. """
        x, y = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(y, device=device, non_blocking=non_blocking))

    def create_supervised_dp_trainer(
            model,
            optimizer,
            device=None,
            non_blocking=False,
            prepare_batch=_prepare_batch,
            output_transform=lambda x, y, y_pred, loss: loss.item()):
        """Factory function for creating a trainer for supervised models.

        Args:
            model (`torch.nn.Module`): the model to train.
            optimizer (`torch.optim.Optimizer`): the optimizer to use.
            device (str, optional): device type specification (default: None). Applies to both model and batches.
            non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur
                asynchronously with respect to the host. For other cases, this argument has no effect.
            prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and
                outputs tuple of tensors `(batch_x, batch_y)`.
            output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns
                value to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.

        Note: `engine.state.output` for this engine is defined by the `output_transform` parameter and is the loss
            of the processed batch by default.

        Returns:
            Engine: a trainer engine with supervised update function.
        """
        if device:
            model.to(device)

        def _update(engine, batch):
            # model.train()
            optimizer.zero_grad()
            x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
            with autocast():
                total_loss = model(x, y)
                total_loss = total_loss.mean()  # average the losses returned by the model
            # Scales the loss so gradients are amplified for AMP.
            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            writer.add_scalar("total loss", total_loss.cpu().data.numpy())
            scaler.update()
            # total_loss.backward()
            # optimizer.step()
            return output_transform(x, y, None, total_loss)

        return Engine(_update)

    scaler = torch.cuda.amp.GradScaler()
    master_device = device[0]  # use the first GPU as the master device by default
    trainer = create_supervised_dp_trainer(model, optimizer, device=master_device)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'avg_loss')
    log_dir = cfg['log_dir']
    writer = SummaryWriter(log_dir=log_dir)

    # create pbar
    len_train_loader = len(train_loader)
    pbar = tqdm(total=len_train_loader)

    froze_num_layers = cfg['warm_up']['froze_num_lyers']
    if cfg['multi_gpu']:
        freeze_layers(model.module, froze_num_layers)
    else:
        freeze_layers(model, froze_num_layers)

    # In Finetuning mode (large patches, small batches) freeze all BatchNorm layers of the model.
    # In Normal mode, freeze the configured number of layers.
    if 'mode' in cfg and cfg['mode'] == "Finetuning":
        if cfg['multi_gpu']:
            fix_bn(model.module)
        else:
            fix_bn(model)

    ##########################################################################################
    ###########               Events.ITERATION_COMPLETED                        #############
    ##########################################################################################
    # log the training loss every log_period iterations
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        log_period = cfg['log_period']
        log_per_iter = int(log_period * len_train_loader) if int(
            log_period * len_train_loader) >= 1 else 1  # logging period in iterations
        current_iter = (engine.state.iteration - 1) % len_train_loader + 1 + (
            engine.state.epoch - 1) * len_train_loader  # current global iteration
        lr = optimizer.state_dict()['param_groups'][0]['lr']
        if current_iter % log_per_iter == 0:
            pbar.write("Epoch[{}] Iteration[{}] lr {:.7f} Loss {:.7f}".format(
                engine.state.epoch, current_iter, lr, engine.state.metrics['avg_loss']))
            pbar.update(log_per_iter)
            writer.add_scalar('loss', engine.state.metrics['avg_loss'], current_iter)

    # lr_scheduler warm-up
    @trainer.on(Events.ITERATION_COMPLETED)
    def lr_scheduler_iteration(engine):
        scheduler.ITERATION_COMPLETED()
        current_iter = (engine.state.iteration - 1) % len_train_loader + 1 + (
            engine.state.epoch - 1) * len_train_loader  # current global iteration
        length = cfg['warm_up']['length']
        min_lr = cfg['warm_up']['min_lr']
        max_lr = cfg['warm_up']['max_lr']
        froze_num_layers = cfg['warm_up']['froze_num_lyers']
        if current_iter < length:
            # linear warm-up from min_lr to max_lr over the first `length` iterations
            lr = (max_lr - min_lr) / length * current_iter
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            # pbar.write("lr: {}".format(lr))
        if current_iter == length:
            if 'mode' in cfg and cfg['mode'] == "Finetuning":
                pass
            else:
                # In Normal mode, unfreeze the layers once warm-up ends
                pass
            # if cfg['multi_gpu']:
            #     freeze_layers(model.module,froze_num_layers)
            # else:
            #     freeze_layers(model,froze_num_layers)
            # for param_group in optimizer.param_groups:
            #     param_group['lr'] = cfg['optimizer']['lr']

    @trainer.on(Events.EPOCH_COMPLETED)
    def lr_scheduler_epoch(engine):
        scheduler.EPOCH_COMPLETED()

    ##########################################################################################
    ##################              Events.EPOCH_COMPLETED                    ###############
    ##########################################################################################
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_temp_epoch(engine):
        save_dir = cfg['save_dir']
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        epoch = engine.state.epoch
        if epoch % 1 == 0:
            model_name = os.path.join(save_dir, cfg['tag'] + "_temp.pth")
            # import pdb; pdb.set_trace()
            if cfg['multi_gpu']:
                save_pth = {'model': model.module.state_dict(), 'cfg': cfg}
                torch.save(save_pth, model_name)
            else:
                save_pth = {'model': model.state_dict(), 'cfg': cfg}
                torch.save(save_pth, model_name)
        if epoch % 10 == 0:
            model_name = os.path.join(save_dir, cfg['tag'] + "_" + str(epoch) + ".pth")
            if cfg['multi_gpu']:
                save_pth = {'model': model.module.state_dict(), 'cfg': cfg}
                torch.save(save_pth, model_name)
            else:
                save_pth = {'model': model.state_dict(), 'cfg': cfg}
                torch.save(save_pth, model_name)

    @trainer.on(Events.EPOCH_COMPLETED)
    def calu_acc(engine):
        epoch = engine.state.epoch
        if epoch % 10 == 0:
            model.eval()
            num_correct = 0
            num_example = 0
            torch.cuda.empty_cache()
            with torch.no_grad():
                for image, target in tqdm(train_loader):
                    image, target = image.to(master_device), target.to(master_device)
                    pred_logit_dict = model(image, target)
                    pred_logit = [
                        value for value in pred_logit_dict.values()
                        if value is not None
                    ]
                    pred_logit = pred_logit[0]
                    indices = torch.max(pred_logit, dim=1)[1]
                    correct = torch.eq(indices, target).view(-1)
                    num_correct += torch.sum(correct).item()
                    num_example += correct.shape[0]
                acc = (num_correct / num_example)
                pbar.write("Acc: {}".format(acc))
                writer.add_scalar("Acc", acc, epoch)
            torch.cuda.empty_cache()
            model.train()
            # In Finetuning mode (large patches, small batches) freeze all BatchNorm layers of the model.
            # In Normal mode, freeze the configured number of layers.
            if 'mode' in cfg and cfg['mode'] == "Finetuning":
                if cfg['multi_gpu']:
                    fix_bn(model.module)
                else:
                    fix_bn(model)

    @trainer.on(Events.EPOCH_COMPLETED)
    def reset_pbar(engine):
        pbar.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def reset_dataset(engine):
        # only for jr's custom train_dataset: shuffle it manually
        if hasattr(train_loader.dataset, 'shuffle'):
            pbar.write("shuffle train_dataloader")
            train_loader.dataset.shuffle()

    max_epochs = cfg['max_epochs']
    trainer.run(train_loader, max_epochs=max_epochs)
    pbar.close()
def _run_training(self, train_data, valid_data=[], test_data=[], tb_log_dir=None, split_num=None, verbose=True): # setup epochs = self.training_config.get("epochs") optimizer_config = self.training_config.get("optimizer_config") early_stopping = self.training_config.get("early_stopping") checkpoint_saving = self.training_config.get("checkpoint_saving") graph_dataset_config = self.training_config.get("graph_dataset_config") device = self.training_config.get("device") extraction_target = self.training_config.get("extraction_target") graph_dataset_config["device"] = device graph_dataset_config["extraction_target"] = extraction_target if split_num is not None: tb_log_dir += "-Split{}".format(split_num + 1) # prepare graph-sets graph_dataset = GraphDataset(train_data, valid_data=valid_data, test_data=test_data, graph_dataset_config=graph_dataset_config) train_loader, val_loader, test_loader = graph_dataset.get_loaders() worker_init_fn = graph_dataset.init_fn # create model, optimizer, loss model = self.model_class(self.model_config) model = model.to(device) optimizer = self.optimizer_class(model.parameters(), **optimizer_config) # apparently, we have to do this self.model = model self.optimizer = optimizer # load model from checkpoint if available if self.trained_model_checkpoint is not None: self.custom_print("load transfer-learning checkpoint...") model, optimizer = self._prepare_trained_model(model, optimizer) loss = self.loss_class() loss_name = "mse" evaluator_settings = { "device": device, "extraction_target": extraction_target, "pred_collector_function": lambda x: self._pred_collector_function(x), "metrics": { loss_name: Loss(loss) } } ## configure trainer ## trainer = create_supervised_trainer( model, optimizer, loss, device=device, extraction_target=extraction_target) ############################################### ## configure evaluators for each data source ## train_evaluator = create_supervised_evaluator(model, **evaluator_settings) val_evaluator = create_supervised_evaluator(model, **evaluator_settings) test_evaluator = create_supervised_evaluator(model, **evaluator_settings) # configure behavior for early stopping if early_stopping is not None: stopper = EarlyStopping(patience=early_stopping, score_function=self.score_function, trainer=trainer) val_evaluator.add_event_handler(Events.COMPLETED, stopper) # configure behavior for checkpoint saving if checkpoint_saving is not None: save_handler = None if self.test_mode and self.validation_mode: self.custom_print("Use LocalSaveHandler...") save_handler = LocalSaveHandler(self) else: self.custom_print("Use IgniteSaveHandler...") save_handler = DiskSaver(self.save_path, create_dir=True, require_empty=False) saver = Checkpoint( { "model_state_dict": model, "optimizer_state_dict": optimizer }, save_handler, filename_prefix='{}_best'.format(self.dataset.name), score_name="val_loss", score_function=self.score_function, global_step_transform=global_step_from_engine(trainer), n_saved=1) train_evaluator.add_event_handler(Events.COMPLETED, saver) @trainer.on(Events.STARTED) def log_training_start(trainer): self.custom_print("Split: {}".format(split_num + 1)) @trainer.on(Events.COMPLETED) def log_training_complete(trainer): """Trigger evaluation on test set if training is completed.""" epoch = trainer.state.epoch suffix = "(Early Stopping)" if epoch < epochs else "" self.custom_print("Finished after {:03d} epochs! 
{}".format( epoch, suffix)) embedding_list = [] def _graph_embedding_function(tensor, idx): while idx >= len(embedding_list): embedding_list.append([]) embedding_list[idx].append(tensor.cpu().detach().numpy()) if self.test_mode and self.validation_mode: checkpoint_dict = self.best_model_checkpoint self.custom_print( "Load best model checkpoint by validation loss... Epoch: {}" .format(checkpoint_dict["epoch"])) model, optimizer = self._load_checkpoint( self.model, self.optimizer, checkpoint_dict["checkpoint"]) self.model.graph_embedding_function = _graph_embedding_function self.persist_pred = True if not self.test_mode: return test_evaluator.run(test_loader) @trainer.on(Events.EPOCH_COMPLETED) def compute_metrics(engine): """Compute evaluation metric values after each epoch.""" train_evaluator.run(train_loader) if hasattr(self.model, "node_counter"): self.custom_print(self.model.node_counter) if self.validation_mode: val_evaluator.run(val_loader) # terminate training if Nan values are produced trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) # create tensorboard-logger tb_logger = create_tb_logger(model, optimizer, trainer, train_evaluator, val_evaluator, test_evaluator, log_dir=tb_log_dir, verbose=verbose, custom_print=self.custom_print, loss_name=loss_name) with torch.autograd.detect_anomaly(): trainer.run(train_loader, max_epochs=epochs) tb_logger.close() if not self.test_mode: return 0, 0 test_acc = test_evaluator.state.metrics["accuracy"] test_loss = test_evaluator.state.metrics["mse"] return test_acc, test_loss
print(f"TRAINING IS DONE FOR {RUN_NAME} RUN.") pbar = ProgressBar() checkpointer = ModelCheckpoint( CHECKPOINTS_RUN_DIR_PATH, filename_prefix=RUN_NAME.lower(), n_saved=None, score_function=lambda engine: round(engine.state.metrics['WRA'], 3), score_name='WRA', atomic=True, require_empty=True, create_dir=True, archived=False, global_step_transform=global_step_from_engine(trainer)) nan_handler = TerminateOnNan() coslr = CosineAnnealingScheduler(opt, "lr", start_value=LR, end_value=LR / 4, cycle_size=TOTAL_UPDATE_STEPS // 1) evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'_': mude}) trainer.add_event_handler(Events.ITERATION_COMPLETED, nan_handler) trainer.add_event_handler(Events.ITERATION_COMPLETED, coslr) GpuInfo().attach(trainer, name='gpu') pbar.attach(trainer, output_transform=lambda output: {'loss': output['loss']},
def add_events(engines, dataloaders, model, optimizer, device, save_dir, args): trainer, valid_evaluator, test_evaluator = engines train_dl, valid_dl, test_dl = dataloaders if args.valid_on == 'Loss': score_fn = lambda engine: -engine.state.metrics[args.valid_on] elif args.valid_on == 'Product': score_fn = lambda engine: engine.state.metrics[ 'MRR'] * engine.state.metrics['HR@10'] elif args.valid_on == 'RMS': score_fn = lambda engine: engine.state.metrics[ 'MRR']**2 + engine.state.metrics['HR@10']**2 else: score_fn = lambda engine: engine.state.metrics[args.valid_on] # LR Scheduler if args.lr_scheduler == 'restart': scheduler = CosineAnnealingScheduler(optimizer, 'lr', start_value=args.lr, end_value=args.lr * 0.01, cycle_size=len(train_dl), cycle_mult=args.cycle_mult) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler, 'lr_scheduler') elif args.lr_scheduler == 'triangle': scheduler = make_slanted_triangular_lr_scheduler( optimizer, n_events=args.n_epochs * len(train_dl), lr_max=args.lr) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler, 'lr_scheduler') elif args.lr_scheduler == 'none': pass else: raise NotImplementedError # EarlyStopping trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) valid_evaluator.add_event_handler( Events.COMPLETED, EarlyStopping(args.patience, score_function=score_fn, trainer=trainer)) # Training Loss RunningAverage(output_transform=lambda x: x, alpha=args.avg_alpha).attach(trainer, 'loss') # Checkpoint ckpt_handler = ModelCheckpoint(save_dir, 'best', score_function=score_fn, score_name=args.valid_on, n_saved=1) valid_evaluator.add_event_handler(Events.COMPLETED, ckpt_handler, {'model': model}) # Timer timer = Timer(average=True) timer.attach(trainer, resume=Events.EPOCH_STARTED, step=Events.EPOCH_COMPLETED) # Progress Bar if args.pbar: pbar = ProgressBar() pbar.attach(trainer, ['loss']) log_msg = pbar.log_message else: log_msg = print cpe_valid = CustomPeriodicEvent(n_epochs=args.valid_every) cpe_valid.attach(trainer) valid_metrics_history = [] @trainer.on( getattr(cpe_valid.Events, f'EPOCHS_{args.valid_every}_COMPLETED')) def evaluate_on_valid(engine): state = valid_evaluator.run(valid_dl) metrics = state.metrics valid_metrics_history.append(metrics) msg = f'Epoch: {engine.state.epoch:3d} AvgTime: {timer.value():3.1f}s TrainLoss: {engine.state.metrics["loss"]:.4f} ' msg += ' '.join([ f'{k}: {v:.4f}' for k, v in metrics.items() if k in ['Loss', 'MRR', 'HR@10'] ]) log_msg(msg) @trainer.on(Events.COMPLETED) def evaluate_on_test(engine): pth_file = [ f for f in pathlib.Path(save_dir).iterdir() if f.name.endswith('pth') ][0] log_msg(f'Load Best Model: {str(pth_file)}') model.load_state_dict(torch.load(pth_file, map_location=device)) # Rerun on Valid for log. 
valid_state = valid_evaluator.run(valid_dl) engine.state.valid_metrics = valid_state.metrics # Test test_state = test_evaluator.run(test_dl) engine.state.test_metrics = test_state.metrics engine.state.valid_metrics_history = valid_metrics_history msg = f'[Test] ' msg += ' '.join([ f'{k}: {v:.4f}' for k, v in test_state.metrics.items() if k in ['Loss', 'MRR', 'HR@10'] ]) log_msg(msg) # Tensorboard if args.tensorboard: tb_logger = TensorboardLogger(log_dir=str(save_dir / 'tb_log')) # Loss tb_logger.attach(trainer, log_handler=OutputHandler( tag='training', output_transform=lambda x: x), event_name=Events.ITERATION_COMPLETED) # Metrics tb_logger.attach(valid_evaluator, log_handler=OutputHandler( tag='validation', metric_names=['Loss', 'MRR', 'HR@10'], another_engine=trainer), event_name=Events.EPOCH_COMPLETED) # Optimizer tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED) # Parameters # tb_logger.attach(trainer, # log_handler=WeightsScalarHandler(model), # event_name=Events.ITERATION_COMPLETED) # tb_logger.attach(trainer, # log_handler=GradsScalarHandler(model), # event_name=Events.ITERATION_COMPLETED) @trainer.on(Events.COMPLETED) def close_tb(engine): tb_logger.close()
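# Note: CustomPeriodicEvent, used above for the periodic validation hook, has been
# deprecated in ignite in favour of filtered events. A hedged sketch of the equivalent
# "validate every N epochs" hook with event filtering (values are illustrative):
from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: 0.0)
valid_every = 5

@trainer.on(Events.EPOCH_COMPLETED(every=valid_every))
def evaluate_on_valid(engine):
    print("running validation at epoch {}".format(engine.state.epoch))

trainer.run([0], max_epochs=10)   # fires at epochs 5 and 10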
def _setup_common_training_handlers( trainer: Engine, to_save: Optional[Mapping] = None, save_every_iters: int = 1000, output_path: Optional[str] = None, lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None, with_gpu_stats: bool = False, output_names: Optional[Iterable[str]] = None, with_pbars: bool = True, with_pbar_on_iters: bool = True, log_every_iters: int = 100, stop_on_nan: bool = True, clear_cuda_cache: bool = True, save_handler: Optional[Union[Callable, BaseSaveHandler]] = None, **kwargs: Any, ) -> None: if output_path is not None and save_handler is not None: raise ValueError( "Arguments output_path and save_handler are mutually exclusive. Please, define only one of them" ) if stop_on_nan: trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) if lr_scheduler is not None: if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler): trainer.add_event_handler( Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()) elif isinstance(lr_scheduler, LRScheduler): trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler) else: trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler) if torch.cuda.is_available() and clear_cuda_cache: trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache) if to_save is not None: if output_path is None and save_handler is None: raise ValueError( "If to_save argument is provided then output_path or save_handler arguments should be also defined" ) if output_path is not None: save_handler = DiskSaver(dirname=output_path, require_empty=False) checkpoint_handler = Checkpoint(to_save, cast(Union[Callable, BaseSaveHandler], save_handler), filename_prefix="training", **kwargs) trainer.add_event_handler( Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler) if with_gpu_stats: GpuInfo().attach( trainer, name="gpu", event_name=Events.ITERATION_COMPLETED( every=log_every_iters) # type: ignore[arg-type] ) if output_names is not None: def output_transform(x: Any, index: int, name: str) -> Any: if isinstance(x, Mapping): return x[name] elif isinstance(x, Sequence): return x[index] elif isinstance(x, (torch.Tensor, numbers.Number)): return x else: raise TypeError( "Unhandled type of update_function's output. " f"It should either mapping or sequence, but given {type(x)}" ) for i, n in enumerate(output_names): RunningAverage(output_transform=partial(output_transform, index=i, name=n), epoch_bound=False).attach(trainer, n) if with_pbars: if with_pbar_on_iters: ProgressBar(persist=False).attach( trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)) ProgressBar(persist=True, bar_format="").attach(trainer, event_name=Events.EPOCH_STARTED, closing_event_name=Events.COMPLETED)
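# Hedged usage sketch: the helper above backs the public wrapper assumed to be
# ignite.contrib.engines.common.setup_common_training_handlers; model, optimizer,
# data and paths below are stand-ins.
import torch
from ignite.engine import Engine
from ignite.contrib.engines import common

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

def update(engine, batch):
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    return {"batch_loss": loss.item()}

trainer = Engine(update)
common.setup_common_training_handlers(
    trainer,
    to_save={"model": model, "optimizer": optimizer},
    save_every_iters=100,
    output_path="./common_handlers_demo",
    output_names=["batch_loss"],   # exposes a RunningAverage named "batch_loss"
    with_pbars=True,
    stop_on_nan=True,              # attaches TerminateOnNan, as in the helper above
)
trainer.run([torch.randn(2, 4) for _ in range(5)], max_epochs=1)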
def create_trainer(model, tasks, optims, loaders, args): zt = [] zt_task = {'left': [], 'right': []} if args.dataset.name == 'dummy': lim = 2.5 lims = [[-lim, lim], [-lim, lim]] grid = setup_grid(lims, 1000) def trainer_step(engine, batch): model.train() for optim in optims: optim.zero_grad() # Batch data x, y = batch x = convert_tensor(x.float(), args.device) y = [convert_tensor(y_, args.device) for y_ in y] training_loss = 0. losses = [] # Intermediate representation with cached(): preds = model(x) if args.dataset.name == 'dummy': zt.append(model.rep.detach().clone()) for pred_i, task_i in zip(preds, tasks): loss_i = task_i.loss(pred_i, y[task_i.index]) if args.dataset.name == 'dummy': loss_i = loss_i.mean(dim=0) zt_task[task_i.name].append(pred_i.detach().clone()) # Track losses losses.append(loss_i) training_loss += loss_i.item() * task_i.weight if args.dataset.name == 'dummy' and ( engine.state.epoch == engine.state.max_epochs or engine.state.epoch % args.training.plot_every == 0): fig = plot_toy(grid, model, tasks, [zt, zt_task['left'], zt_task['right']], trainer.state.iteration - 1, levels=20, lims=lims) fig.savefig(f'plots/step_{engine.state.iteration - 1}.png') plt.close(fig) model.backward(losses) for optim in optims: # Run the optimizers optim.step() return training_loss, losses trainer = Engine(trainer_step) trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'loss') for i, task_i in enumerate(tasks): output_transform = partial(lambda idx, x: x[1][idx], i) RunningAverage(output_transform=output_transform).attach( trainer, f'train_{task_i.name}') pbar = ProgressBar() pbar.attach(trainer, metric_names=['loss'] + [f'train_{t.name}' for t in tasks]) # Validation validator = create_evaluator(model, tasks, args) @trainer.on(Events.EPOCH_COMPLETED) def run_validator(trainer): validator.run(loaders['val']) metrics = validator.state.metrics loss = 0. for task_i in tasks: loss += metrics[f'loss_{task_i.name}'] * task_i.weight trainer.state.metrics['val_loss'] = loss # Checkpoints model_checkpoint = {'model': model} handler = ModelCheckpoint('checkpoints', 'latest', require_empty=False) trainer.add_event_handler( Events.EPOCH_COMPLETED(every=args.training.save_every), handler, model_checkpoint) @trainer.on(Events.EPOCH_COMPLETED(every=args.training.save_every)) def save_state(engine): with open('checkpoints/state.pkl', 'wb') as f: pickle.dump(engine.state, f) @trainer.on(Events.COMPLETED(every=args.training.save_every)) def save_state(engine): with open('checkpoints/state.pkl', 'wb') as f: pickle.dump(engine.state, f) handler = ModelCheckpoint( 'checkpoints', 'best', require_empty=False, score_function=(lambda e: -e.state.metrics['val_loss'])) trainer.add_event_handler( Events.EPOCH_COMPLETED(every=args.training.save_every), handler, model_checkpoint) trainer.add_event_handler(Events.COMPLETED, handler, model_checkpoint) return trainer
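# Hedged sketch (task names and losses are placeholders): the per-task RunningAverage
# pattern used above -- the update returns (total_loss, [per_task_losses]) and each
# entry gets its own running metric; functools.partial binds the task index so the
# lambda does not capture the loop variable by reference.
from functools import partial
from ignite.engine import Engine
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar

task_names = ["left", "right"]

def update(engine, batch):
    per_task = [0.3, 0.7]              # stand-in per-task losses
    return sum(per_task), per_task     # (total, [per-task])

trainer = Engine(update)
RunningAverage(output_transform=lambda out: out[0]).attach(trainer, "loss")
for i, name in enumerate(task_names):
    RunningAverage(output_transform=partial(lambda idx, out: out[1][idx], i)).attach(
        trainer, "train_{}".format(name))

ProgressBar().attach(trainer,
                     metric_names=["loss"] + ["train_{}".format(n) for n in task_names])
trainer.run(range(4), max_epochs=1)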