def __init__(
        self,
        job_dir,
        num_examples,
        learning_rate,
        batch_size,
        epochs,
        num_workers,
        seed,
):
    """Set up training state: ignite engines, logging and bookkeeping.

    Args:
        job_dir: output directory forwarded to the base class; also the
            tensorboard summary-writer log dir.
        num_examples: number of training examples.
        learning_rate: optimizer learning rate.
        batch_size: mini-batch size.
        epochs: number of training epochs.
        num_workers: dataloader worker count.
        seed: RNG seed forwarded to the base class.
    """
    super(PyTorchModel, self).__init__(job_dir=job_dir, seed=seed)
    # Hyper-parameters.
    self.num_examples = num_examples
    self.learning_rate = learning_rate
    self.batch_size = batch_size
    self.epochs = epochs
    self.num_workers = num_workers
    # Logging / tensorboard.
    self.summary_writer = tensorboard.SummaryWriter(log_dir=self.job_dir)
    self.logger = utils.setup_logger(
        name=f"{__name__}.{self.__class__.__name__}", distributed_rank=0
    )
    # Ignite engines wrapping the step methods implemented on the class.
    self.trainer = engine.Engine(self.train_step)
    self.evaluator = engine.Engine(self.tune_step)
    # Members populated later during build/training.
    self._network = None
    self._optimizer = None
    self._metrics = None
    self.device = distributed.device()
    self.best_state = None
    self.counter = 0
def create_vae_trainer(model, optimizer, crt, metrics=None,
                       device=th.device('cpu'), non_blocking=True) -> ie.Engine:
    """Build an ignite training engine for the VAE model.

    Args:
        model: VAE; called as ``model(x, num_samples=1)`` and expected to
            return ``(recon, cls, temporal_latents, class_latent, mean, var, _)``.
        optimizer: torch optimizer stepped once per batch.
        crt: criterion returning ``(ce, l1, kld)``; must expose ``kld_factor``.
        metrics: optional mapping name -> ignite metric to attach.
        device: target device; the model is moved there when truthy.
        non_blocking: forwarded to ``prepare_batch`` for async transfers.

    Returns:
        ie.Engine: engine whose per-iteration output is the detached
        predictions/latents/inputs plus the scalar loss terms.
    """
    if device:
        model.to(device)

    def _update(_engine, batch):
        model.train()
        optimizer.zero_grad()
        _in, _cls_gt, _recon_gt = prepare_batch(batch, device=device,
                                                non_blocking=non_blocking)
        # Single-sample forward pass through the VAE.
        _recon_pred, _cls_pred, _temporal_latents, _class_latent, _mean, _var, _ = model(
            _in, num_samples=1)
        ce, l1, kld = crt(_recon_pred, _cls_pred, _recon_gt, _cls_gt, _mean, _var)
        # Combined objective: ce + l1 + KL term scaled by the criterion's factor.
        (ce + l1 + crt.kld_factor * kld).backward()
        optimizer.step()
        # Detach tensors so engine state does not retain the autograd graph.
        return (_recon_pred.detach(), _cls_pred.detach(),
                _temporal_latents.detach(), _class_latent.detach(),
                _mean.detach(), _var.detach(), _in.detach(), _cls_gt.detach(),
                ce.item(), l1.item(), kld.item(), crt.kld_factor)

    _engine = ie.Engine(_update)
    if metrics is not None:
        for name, metric in metrics.items():
            metric.attach(_engine, name)
    return _engine
def create_ae_trainer(model, optimizer, crt, metrics=None,
                      device=th.device('cpu'), non_blocking=True) -> ie.Engine:
    """Build an ignite training engine for the autoencoder.

    Each step performs one optimization update and returns the detached
    predictions/embeddings/inputs plus the two scalar loss terms so that
    attached metrics can consume them.
    """
    if device:
        model.to(device)

    def _train_step(_engine, batch):
        model.train()
        optimizer.zero_grad()
        inputs, cls_target, recon_target = prepare_batch(
            batch, device=device, non_blocking=non_blocking)
        recon_pred, cls_pred, temporal_embeds, class_embed = model(inputs)
        ce, l1 = crt(recon_pred, cls_pred, recon_target, cls_target)
        total_loss = ce + l1
        total_loss.backward()
        optimizer.step()
        return (recon_pred.detach(), cls_pred.detach(),
                temporal_embeds.detach(), class_embed.detach(),
                inputs.detach(), cls_target.detach(), ce.item(), l1.item())

    trainer = ie.Engine(_train_step)
    for name, metric in (metrics or {}).items():
        metric.attach(trainer, name)
    return trainer
def create_trainer(self, optimizer: optim.Optimizer, device: torch.device) -> engine.Engine:
    """Create :class:`ignite.engine.Engine` trainer.

    Args:
        optimizer (optim.Optimizer): torch optimizer.
        device (torch.device): selected device.

    Returns:
        engine.Engine: training Engine.
    """

    def _update(engine: engine.Engine, batch: dict):
        # Move the batch tensors to the target device (mutates the dict).
        batch["src"] = batch["src"].to(device)
        batch["trg"] = batch["trg"].to(device)
        self.train()
        optimizer.zero_grad()
        # Forward pass conditioned on the full target sequence.
        gen_probs = self.forward(batch["src"], batch["trg"])
        # Flatten to (batch*steps, vocab) vs (batch*steps,); the target drops
        # its first token — assumes `forward` emits predictions for positions
        # 1..T (shifted by one). TODO confirm against the model definition.
        loss = self.criterion(
            gen_probs.view(-1, self.vocab_size),
            batch["trg"][:, 1:].contiguous().view(-1))
        loss.backward()
        optimizer.step()
        return loss.item()

    return engine.Engine(_update)
def create_cls_trainer(model, optimizer, crt, metrics=None,
                       device=th.device('cpu'), non_blocking=True) -> ie.Engine:
    """Build an ignite training engine for the classifier.

    Each step runs one optimization update and returns
    ``(loss_value, predictions, cls_target)`` for attached metrics.
    """
    if device:
        model.to(device)

    def _train_step(_engine, batch):
        model.train()
        optimizer.zero_grad()
        inputs, cls_target, _ = prepare_batch(
            batch, device=device, non_blocking=non_blocking)
        predictions, temporal_embeds, class_embed = model(inputs)
        loss = crt(predictions, cls_target)
        loss.backward()
        optimizer.step()
        return loss.item(), predictions.detach(), cls_target.detach()

    trainer = ie.Engine(_train_step)
    for name, metric in (metrics or {}).items():
        metric.attach(trainer, name)
    return trainer
def __init__(self, patience, score_name, evaluator_name, mode='max'):
    """Early-stopping handler keyed on a named metric of a named evaluator.

    Args:
        patience: rounds without improvement before stopping.
        score_name: key looked up in ``engine.state.metrics``.
        evaluator_name: name of the evaluator whose metrics are read.
        mode: ``'max'`` to maximize the metric, ``'min'`` to minimize it.

    Raises:
        ValueError: if ``mode`` is neither ``'min'`` nor ``'max'``.
    """
    if mode not in ('min', 'max'):
        raise ValueError(
            f'mode must be min or max. mode value found is {mode}')

    def _score(e):
        # Negate the metric in 'min' mode so the base class always maximizes.
        value = e.state.metrics[score_name]
        return value if mode == 'max' else -value

    # The base class requires a trainer; a no-op dummy engine satisfies it.
    super(EarlyStopping, self).__init__(
        patience,
        score_function=_score,
        trainer=engine.Engine(lambda _engine, _batch: None))
    self.evaluator_name = evaluator_name
def predict(self, loader: _data.DataLoader) -> Tensor:
    """Run the model over *loader* and return the concatenated predictions.

    Args:
        loader (_data.DataLoader): batches fed directly to ``self.model``.

    Returns:
        Tensor: all batch outputs concatenated along dim 0, then with dims
        0 and 1 swapped.
    """

    def estimation_update(engine: _engine.Engine, batch) -> dict:
        # Process function: a plain forward pass; output collected via handler.
        return {"y_pred": self.model(batch)}

    estimator = _engine.Engine(estimation_update)
    result = []

    @estimator.on(_engine.Events.ITERATION_COMPLETED)
    def save_results(engine: _engine.Engine) -> None:
        # Collect each batch's prediction; detach drops autograd state.
        output = engine.state.output['y_pred'].detach()
        result.append(output)
        # Free cached GPU memory between batches (no-op on CPU).
        torch.cuda.empty_cache()

    self.model.eval()
    # NOTE(review): eval mode is set but there is no torch.no_grad() around
    # the forward pass — the graph is built each batch until .detach().
    batches = VocalExtractor.get_number_of_batches(loader)
    estimator.run(loader, epoch_length=batches, max_epochs=1)
    result = torch.cat(result, dim=0)
    return result.transpose(0, 1)
def create_evaluator(self, device: torch.device) -> engine.Engine:
    """Create :class:`ignite.engine.Engine` evaluator.

    Args:
        device (torch.device): selected device.

    Returns:
        engine.Engine: evaluator engine.
    """

    def _eval_step(engine: engine.Engine, batch: dict):
        # Move tensors to the target device (mutates the batch dict, matching
        # what the training step does).
        batch["src"] = batch["src"].to(device)
        batch["trg"] = batch["trg"].to(device)
        self.eval()
        max_steps = batch["trg"].shape[1] - 1
        generated, _ = self.inference(batch["src"], max_steps)
        # Skip the first target token so the reference aligns with `generated`.
        return generated, batch["trg"][:, 1:]

    return engine.Engine(_eval_step)
def create_cls_evaluator(model, metrics=None, device=th.device('cpu'),
                         non_blocking=True) -> ie.Engine:
    """Build an ignite evaluation engine for the classifier.

    Runs the model in eval mode under ``no_grad`` and returns, per batch,
    ``(predictions, cls_target, temporal_embeds, class_embed)`` so attached
    metrics can consume them.
    """
    if device:
        model.to(device)

    def _eval_step(_engine, batch):
        model.eval()
        with th.no_grad():
            inputs, cls_target, _ = prepare_batch(
                batch, device=device, non_blocking=non_blocking)
            predictions, temporal_embeds, class_embed = model(inputs)
            return predictions, cls_target, temporal_embeds, class_embed

    evaluator = ie.Engine(_eval_step)
    for name, metric in (metrics or {}).items():
        metric.attach(evaluator, name)
    return evaluator
def create_mask_rcnn_trainer(model: nn.Module, optimizer: optim.Optimizer,
                             device=None, non_blocking: bool = False):
    """Build an ignite training engine for Mask R-CNN.

    The model is expected to return a dict of named loss tensors; the step
    sums them, backpropagates, and returns the per-term scalars plus their
    total under the ``'loss'`` key.
    """
    if device:
        model.to(device)

    def _to_device(batch):
        # Reuses ignite's internal batch-preparation helper (private API).
        return engine._prepare_batch(batch, device=device,
                                     non_blocking=non_blocking)

    def _update(_engine, batch):
        model.train()
        optimizer.zero_grad()
        image, targets = _to_device(batch)
        loss_dict = model(image, targets)
        total = sum(loss_dict.values())
        total.backward()
        optimizer.step()
        out = {k: v.item() for k, v in loss_dict.items()}
        out['loss'] = total.item()
        return out

    return engine.Engine(_update)
def create_vae_evaluator(model, metrics=None, device=None,
                         num_samples: int = None, non_blocking=True) -> ie.Engine:
    """Build an ignite evaluation engine for the VAE model.

    Args:
        model: VAE; called as ``model(x, num_samples=num_samples)``.
        metrics: optional mapping name -> ignite metric to attach.
        device: target device; the model is moved there when truthy.
        num_samples: number of latent samples per input. Annotated ``int``
            but defaults to ``None`` — presumably the model handles ``None``
            as "use its own default"; TODO confirm.
        non_blocking: forwarded to ``prepare_batch``.

    Returns:
        ie.Engine: engine whose per-iteration output is the model outputs
        plus the inputs and class targets (no detach needed — ``no_grad``).
    """
    if device:
        model.to(device)

    def _inference(_engine, batch):
        model.eval()
        with th.no_grad():
            _in, _cls_gt, _recon_gt = prepare_batch(batch, device=device,
                                                    non_blocking=non_blocking)
            _recon_pred, _cls_pred, _temp_lat, _cls_lat, _mean, _var, _vote = model(
                _in, num_samples=num_samples)
            return _recon_pred, _cls_pred, _temp_lat, _cls_lat, _mean, _var, _in, _cls_gt, _vote

    _engine = ie.Engine(_inference)
    if metrics is not None:
        for name, metric in metrics.items():
            metric.attach(_engine, name)
    return _engine
def create_mask_rcnn_evaluator(model: nn.Module, metrics, device=None,
                               non_blocking: bool = False):
    """Build an ignite evaluation engine for Mask R-CNN.

    Args:
        model: Mask R-CNN; kept in *train* mode on purpose (see comment
            below) so it still returns its loss dict.
        metrics: mapping name -> ignite metric; attached to the engine.
        device: target device; the model is moved there when truthy.
        non_blocking: forwarded to ignite's batch preparation.

    Returns:
        the evaluator engine whose per-iteration output is
        ``(loss_scalars_dict, batch_size)``.
    """
    if device:
        model.to(device)
    # Uses ignite's private helper to move batches to the device.
    fn_prepare_batch = lambda batch: engine._prepare_batch(batch, device=device, non_blocking=non_blocking)

    def _update(engine, batch):
        # warning(will.brennan) - not putting model in eval mode because we want the losses!
        with torch.no_grad():
            image, targets = fn_prepare_batch(batch)
            losses = model(image, targets)
            # Convert each loss tensor to a float, then add their total.
            losses = {k: v.item() for k, v in losses.items()}
            losses['loss'] = sum(losses.values())
            # note(will.brennan) - an ugly hack for metrics...
            return (losses, len(image))

    evaluator = engine.Engine(_update)
    for name, metric in metrics.items():
        metric.attach(evaluator, name)
    return evaluator
def create_trainer(self, profile: Profile, shared: Storage, logger: Logger,
                   model: nn.Module, loss_function: nn.Module,
                   optimizer: optim.Optimizer, lr_scheduler: Any,
                   output_transform=lambda x, y, y_pred, loss: loss.item(),
                   **kwargs) -> engine.Engine:
    """
    Build the trainer engine.

    Re-implement this function when you want to customize the updating
    actions of training.

    Args:
        profile: Runtime profile defined in TOML file.
        shared: Shared storage in the whole lifecycle.
        logger: The logger named with this Task.
        model: The model to train.
        loss_function: The loss function to train.
        optimizer: The optimizer to train.
        lr_scheduler: The scheduler to control the learning rate.
        output_transform: The action to transform the output of the model.

    Returns:
        The trainer engine.
    """
    # Optional settings fall back to safe defaults when absent from profile.
    if 'device' in profile:
        device_type = profile.device
    else:
        device_type = 'cpu'
    if 'non_blocking' in profile:
        non_blocking = profile.non_blocking
    else:
        non_blocking = False
    if 'deterministic' in profile:
        deterministic = profile.deterministic
    else:
        deterministic = False

    def _update(_engine: engine.Engine, _batch: Tuple[torch.Tensor, ...]):
        model.train()
        optimizer.zero_grad()
        x, y = self.prepare_train_batch(profile, shared, logger, _batch,
                                        device=device_type,
                                        non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_function(y_pred, y)
        loss.backward()
        optimizer.step()
        # NOTE(review): the scheduler is stepped per iteration with the loss,
        # which matches ReduceLROnPlateau-style schedulers — confirm for others.
        if lr_scheduler is not None:
            lr_scheduler.step(loss)
        return output_transform(x, y, y_pred, loss)

    # Use ignite's DeterministicEngine when reproducibility was requested.
    trainer = engine.Engine(
        _update) if not deterministic else engine.DeterministicEngine(
            _update)
    return trainer
def __init__(self, model, params, config, eval_data_iter=None,
             optimizer="adam", grad_clip_norm=5.0, grad_noise_weight=0.01):
    """Set up the estimator: training engine, summaries and evaluation hooks.

    Args:
        model: the model to train; must expose ``compute_loss`` and
            ``loss_fn`` (used as metrics below).
        params: model parameters (stored as-is).
        config: run configuration; fields used here include ``device``,
            ``model_dir``, ``save_summary_steps`` and ``evaluate_steps``.
        eval_data_iter: evaluation data iterator; required whenever
            ``config.evaluate_steps != 0``.
        optimizer: optimizer name (string), resolved later.
        grad_clip_norm: gradient clipping norm.
        grad_noise_weight: gradient noise weight.

    Raises:
        ValueError: if evaluation is enabled but no ``eval_data_iter`` given.
    """
    self._logger = logging.getLogger(__name__)
    self.model = model
    self.params = params
    self.config = config
    self.device = config.device
    self.model_dir = config.model_dir
    self.metrics = {}
    self._optimizer_str = optimizer
    self.grad_clip_norm = grad_clip_norm
    self.grad_noise_weight = grad_noise_weight
    self.train_engine = engine.Engine(self._update_fn)
    self.train_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                        self._print_eta_handler)
    if config.save_summary_steps > 0:
        # Summary buffer and tensorboardX writer.
        self._summary_writer = SummaryWriter(config.model_dir)
        self.summary = SummaryBuffer(self._summary_writer)
        # Try to attach summary writer to the model.
        # NOTE(review): broad `except Exception` plus `print(e)` — this
        # swallows any attach-time failure; consider logging the exception
        # via self._logger instead of print.
        try:
            model.attach_summary_writer(self.summary)
            get_worth_manager().attach_summary_writer(self.summary)
            self.train_engine.add_event_handler(
                Events.ITERATION_COMPLETED, self.summary.writing_handler,
                config.save_summary_steps)
        except Exception as e:
            print(e)
            self._logger.warning(
                "Can't attach summary writer to this model. Subclass EstimatableModel to access the "
                "summary writer.")
    else:
        self._summary_writer = None
    # Set to True after writing graph information summary.
    self._graph_written = False
    if config.evaluate_steps != 0 and eval_data_iter is None:
        raise ValueError(
            "eval_data_iter should be provided for config.evaluate_steps != 0"
        )
    self.eval_data_iter = eval_data_iter
    # Evaluation cadence: either once per epoch, or every N iterations.
    if config.evaluate_steps == EstimatorConfig.AFTER_EACH_EPOCH:
        self.train_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                            self._eval_handler)
        self._eval_summary_writer = SummaryWriter(
            os.path.join(config.model_dir, "eval"))
        self.eval_summary = SummaryBuffer(self._eval_summary_writer)
    elif config.evaluate_steps > 0:
        self.train_engine.add_event_handler(Events.ITERATION_COMPLETED,
                                            self._eval_handler,
                                            config.evaluate_steps)
        self._eval_summary_writer = SummaryWriter(
            os.path.join(config.model_dir, "eval"))
        self.eval_summary = SummaryBuffer(self._eval_summary_writer)
    # Default metrics computed from the model's own loss functions.
    self.add_metric("loss", Loss(model.compute_loss))
    self.add_metric("xentropy_loss", Loss(model.loss_fn))
    self._reload_eval_engine()
    self._built = False
def train(args, trial, is_train=True, study=None):
    """Adversarially train a TTS model while dispelling speaker identity.

    Alternates a generator update (reconstruction minus a speaker-entropy
    bonus) with a discriminator update (speaker classification), logging and
    validating with ignite. Returns the final validation loss.

    Args:
        args: CLI namespace (dataset, filelist, hparams, batch_size, ...).
        trial: hyper-parameter trial exposing ``parameters``.
        is_train: when True, checkpoints and saves the final model.
        study: optional study used to stop underperforming trials.

    Returns:
        float: the evaluator's final ``loss`` metric.
    """
    hparams = HPARAMS[args.hparams]
    print(hparams)
    if hparams.model_type in {"bjorn"}:
        Dataset = src.dataset.xTSDatasetSpeakerIdEmbedding

        def prepare_batch(batch, device, non_blocking):
            # Move every tensor in the batch, then regroup for the model.
            for i in range(len(batch)):
                batch[i] = batch[i].to(device)
            batch_x, batch_y, _, emb = batch
            return (batch_x, batch_y, emb), batch_y
    else:
        Dataset = src.dataset.xTSDatasetSpeakerId
        prepare_batch = prepare_batch_3

    train_path_loader = PATH_LOADERS[args.dataset](ROOT, args.filelist + "-train")
    valid_path_loader = PATH_LOADERS[args.dataset](ROOT, args.filelist + "-valid")
    train_dataset = Dataset(hparams, train_path_loader, transforms=TRAIN_TRANSFORMS)
    valid_dataset = Dataset(hparams, valid_path_loader, transforms=VALID_TRANSFORMS)
    kwargs = dict(batch_size=args.batch_size, collate_fn=collate_fn)
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dataset, shuffle=False, **kwargs)

    num_speakers = len(train_path_loader.speaker_to_id)
    dataset_parameters = DATASET_PARAMETERS[args.dataset]
    dataset_parameters["num_speakers"] = num_speakers
    hparams = update_namespace(hparams, trial.parameters)
    model = MODELS[hparams.model_type](dataset_parameters, hparams)
    model_speaker = MODELS_SPEAKER[hparams.model_speaker_type](
        hparams.encoder_embedding_dim, num_speakers)

    if model.speaker_info is SpeakerInfo.EMBEDDING and hparams.embedding_normalize:
        model.embedding_stats = train_dataset.embedding_stats
    if hparams.drop_frame_rate:
        path_mel_mean = os.path.join("output", "mel-mean",
                                     f"{args.dataset}-{args.filelist}.npz")
        mel_mean = cache(compute_mel_mean, path_mel_mean)(train_dataset)["mel_mean"]
        mel_mean = torch.tensor(mel_mean).float().to(DEVICE)
        mel_mean = mel_mean.unsqueeze(0).unsqueeze(0)
        model.decoder.mel_mean = mel_mean

    model_name = f"{args.dataset}_{args.filelist}_{args.hparams}_dispel"
    model_path = f"output/models/{model_name}.pth"

    # Initialize model from existing one.
    if args.model_path is not None:
        model.load_state_dict(torch.load(args.model_path, map_location=DEVICE))
    if hasattr(hparams, "model_speaker_path"):
        model_speaker.load_state_dict(
            torch.load(hparams.model_speaker_path, map_location=DEVICE))

    optimizer = torch.optim.Adam(model.parameters(), lr=trial.parameters["lr"])  # 0.001
    optimizer_speaker = torch.optim.Adam(model_speaker.parameters(), lr=0.001)
    mse_loss = nn.MSELoss()

    def loss_reconstruction(pred, true):
        # The model returns two predictions; penalize both against the target.
        pred1, pred2 = pred
        return mse_loss(pred1, true) + mse_loss(pred2, true)

    # Weight of the speaker-entropy bonus in the generator loss.
    if hasattr(hparams, "loss_speaker_weight"):
        λ = hparams.loss_speaker_weight
    else:
        λ = 0.0002

    model.to(DEVICE)
    model_speaker.to(DEVICE)

    def step(engine, batch):
        model.train()
        model_speaker.train()
        x, y = prepare_batch(batch, device=DEVICE, non_blocking=True)
        i = batch[2].to(DEVICE)
        # Generator: generates audio and dispels speaker identity
        y_pred, z = model.forward_emb(x)
        i_pred = model_speaker.forward(z)
        entropy_s = (-i_pred.exp() * i_pred).sum(dim=1).mean()  # entropy on speakers
        loss_r = loss_reconstruction(y_pred, y)  # reconstruction
        loss_g = loss_r - λ * entropy_s  # generator
        optimizer.zero_grad()
        # retain_graph: the discriminator loss below reuses this graph.
        loss_g.backward(retain_graph=True)
        optimizer.step()
        # Discriminator: predicts speaker identity
        optimizer_speaker.zero_grad()
        loss_s = F.nll_loss(i_pred, i)
        loss_s.backward()
        optimizer_speaker.step()
        return {
            "loss-generator": loss_g.item(),
            "loss-reconstruction": loss_r.item(),
            "loss-speaker": loss_s.item(),
            # FIX: was the live tensor; .item() makes it a plain float like
            # the other entries and stops the engine state from holding the
            # graph/GPU tensor between iterations.
            "entropy-speaker": entropy_s.item(),
        }

    trainer = engine.Engine(step)
    evaluator = engine.create_supervised_evaluator(
        model,
        metrics={"loss": ignite.metrics.Loss(loss_reconstruction)},
        device=DEVICE,
        prepare_batch=prepare_batch,
    )

    @trainer.on(engine.Events.ITERATION_COMPLETED)
    def log_training_loss(trainer):
        print(
            "Epoch {:3d} | Loss gen.: {:+8.6f} = {:8.6f} - λ * {:8.6f} | Loss disc.: {:8.6f}"
            .format(
                trainer.state.epoch,
                trainer.state.output["loss-generator"],
                trainer.state.output["loss-reconstruction"],
                trainer.state.output["entropy-speaker"],
                trainer.state.output["loss-speaker"],
            ))

    @trainer.on(engine.Events.ITERATION_COMPLETED(every=EVERY_K_ITERS))
    def log_validation_loss(trainer):
        evaluator.run(valid_loader)
        metrics = evaluator.state.metrics
        print("Epoch {:3d} Valid loss: {:8.6f} ←".format(
            trainer.state.epoch, metrics["loss"]))

    lr_reduce = lr_scheduler.ReduceLROnPlateau(optimizer, verbose=args.verbose,
                                               **LR_REDUCE_PARAMS)

    @evaluator.on(engine.Events.COMPLETED)
    def update_lr_reduce(engine):
        loss = engine.state.metrics["loss"]
        lr_reduce.step(loss)

    @evaluator.on(engine.Events.COMPLETED)
    def terminate_study(engine):
        """Stops underperforming trials."""
        if study and study.should_trial_stop(trial=trial):
            trainer.terminate()

    def score_function(engine):
        return -engine.state.metrics["loss"]

    early_stopping_handler = ignite.handlers.EarlyStopping(
        patience=PATIENCE, score_function=score_function, trainer=trainer)
    evaluator.add_event_handler(engine.Events.COMPLETED, early_stopping_handler)

    if is_train:
        # FIX: parameter renamed from *args, which shadowed the outer CLI
        # `args` namespace inside this function.
        def global_step_transform(*_unused):
            return trainer.state.iteration // EVERY_K_ITERS

        checkpoint_handler = ignite.handlers.ModelCheckpoint(
            "output/models/checkpoints",
            model_name,
            score_name="objective",
            score_function=score_function,
            n_saved=5,
            require_empty=False,
            create_dir=True,
            global_step_transform=global_step_transform,
        )
        evaluator.add_event_handler(engine.Events.COMPLETED, checkpoint_handler,
                                    {"model": model})

    trainer.run(train_loader, max_epochs=args.max_epochs)

    if is_train:
        torch.save(model.state_dict(), model_path)
        print("Last model @", model_path)
        model_best_path = link_best_model(model_name)
        print("Best model @", model_best_path)

    return evaluator.state.metrics["loss"]
def __init__(self, dataset, device, max_epochs=1):
    """Wrap an ignite engine around this object's ``_update`` step.

    Args:
        dataset: data source the engine will iterate over.
        device: device the update step should use.
        max_epochs: number of passes over ``dataset`` (default 1).
    """
    super(Engine, self).__init__()
    # Plain configuration state.
    self.device = device
    self.max_epochs = max_epochs
    self.dataset = dataset
    # The underlying ignite engine; `_update` is implemented on the class.
    self.engine = e.Engine(self._update)
        # (fragment) Tail of the ignite `eval_step` process function: package
        # the predicted masks for this batch into the engine output.
        return_dict = {
            'input_filename': batch['input_filename'],
            'mask': masks
        }
        if 'TAPNet' in args.model:
            # for TAPNet, update attention maps after each iteration
            eval_loader.dataset.update_attmaps(output_logsoftmax_np, batch['idx'].numpy())
            # for TAPNet, return extra internal values
            return_dict['attmap'] = add_params['attmap']
        return return_dict

# Build the evaluation engine and attach a progress bar.
evaluator = engine.Engine(eval_step)
eval_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
#valid_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
eval_pbar.attach(evaluator)

# evaluate after iter finish
@evaluator.on(engine.Events.ITERATION_COMPLETED)
def evaluator_epoch_comp_callback(engine):
    # save masks for each batch
    batch_output = engine.state.output
    input_filenames = batch_output['input_filename']
    masks = batch_output['mask']
    # NOTE(review): the loop body is not shown in this fragment.
    for i, input_filename in enumerate(input_filenames):
def evaluator_epoch_comp_callback(engine): # save masks for each batch batch_output = engine.state.output input_filenames = batch_output['input_filename'] masks = batch_output['mask'] for i, input_filename in enumerate(input_filenames): mask = cv2.resize(masks[i], dsize=(utils.cropped_width, utils.cropped_height), interpolation=cv2.INTER_AREA) # if pad: # h_start, w_start = utils.h_start, utils.w_start # h, w = mask.shape # # recover to original shape # full_mask = np.zeros((original_height, original_width)) # full_mask[h_start:h_start + h, w_start:w_start + w] = t_mask # mask = full_mask #print("Input Filename-->", input_filename) #instrument_folder_name = input_filename.parent.parent.name instrument_folder_name = os.path.basename( os.path.dirname(os.path.dirname(input_filename))) #print("instrument_folder_name-->", instrument_folder_name) # mask_folder/instrument_dataset_x/problem_type_masks/framexxx.png mask_folder = mask_save_dir / instrument_folder_name / utils.mask_folder[ args.problem_type] mask_folder.mkdir(exist_ok=True, parents=True) mask_filename = mask_folder / os.path.basename(input_filename) #print("mask_filename-->", mask_filename) cv2.imwrite(str(mask_filename), mask) if 'TAPNet' in args.model: attmap = batch_output['attmap'][i] attmap_folder = mask_save_dir / instrument_folder_name / '_'.join( args.problem_type, 'attmaps') attmap_folder.mkdir(exist_ok=True, parents=True) attmap_filename = attmap_folder / input_filename.name cv2.imwrite(str(attmap_filename), attmap) evaluator.run(eval_loader) # validator engine validator = engine.Engine(valid_step) # monitor loss valid_ra_loss = imetrics.RunningAverage( output_transform=lambda x: x['loss'], alpha=0.98) valid_ra_loss.attach(validator, 'valid_ra_loss') # monitor validation loss over epoch valid_loss = imetrics.Loss(loss_func, output_transform=lambda x: (x['output'], x['target'])) valid_loss.attach(validator, 'valid_loss') # monitor <data> mean metrics valid_data_miou = imetrics.RunningAverage( 
output_transform=lambda x: x['iou'].data_mean()['mean'], alpha=0.98) valid_data_miou.attach(validator, 'mIoU') valid_data_mdice = imetrics.RunningAverage( output_transform=lambda x: x['dice'].data_mean()['mean'], alpha=0.98) valid_data_mdice.attach(validator, 'mDice') # show metrics on progress bar (after every iteration) valid_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True) valid_metric_names = ['valid_ra_loss', 'mIoU', 'mDice'] valid_pbar.attach(validator, metric_names=valid_metric_names) # ## monitor ignite IoU (the same as iou we are using) ### # cm = imetrics.ConfusionMatrix(num_classes, # output_transform=lambda x: (x['output'], x['target'])) # imetrics.IoU(cm, # ignore_index=0 # ).attach(validator, 'iou') # # monitor ignite mean iou (over all classes even not exist in gt) # mean_iou = imetrics.mIoU(cm, # ignore_index=0 # ).attach(validator, 'mean_iou') @validator.on(engine.Events.STARTED) def validator_start_callback(engine): pass @validator.on(engine.Events.EPOCH_STARTED) def validator_epoch_start_callback(engine): engine.state.epoch_metrics = { # directly use definition to calculate 'iou': MetricRecord(), 'dice': MetricRecord(), 'confusion_matrix': np.zeros((num_classes, num_classes), dtype=np.uint32), } # evaluate after iter finish @validator.on(engine.Events.ITERATION_COMPLETED) def validator_iter_comp_callback(engine): pass # evaluate after epoch finish @validator.on(engine.Events.EPOCH_COMPLETED) def validator_epoch_comp_callback(engine): # log ignite metrics # logging_logger.info(engine.state.metrics) # ious = engine.state.metrics['iou'] # msg = 'IoU: ' # for ins_id, iou in enumerate(ious): # msg += '{:d}: {:.3f}, '.format(ins_id + 1, iou) # logging_logger.info(msg) # logging_logger.info('nonzero mean IoU for all data: {:.3f}'.format(ious[ious > 0].mean())) # log monitored epoch metrics epoch_metrics = engine.state.epoch_metrics ######### NOTICE: Two metrics are available but different ########## ### 1. 
mean metrics for all data calculated by confusion matrix #### ''' compared with using confusion_matrix[1:, 1:] in original code, we use the full confusion matrix and only present non-background result ''' confusion_matrix = epoch_metrics['confusion_matrix'] # [1:, 1:] ious = calculate_iou(confusion_matrix) dices = calculate_dice(confusion_matrix) mean_ious = np.mean(list(ious.values())) mean_dices = np.mean(list(dices.values())) std_ious = np.std(list(ious.values())) std_dices = np.std(list(dices.values())) logging_logger.info('mean IoU: %.3f, std: %.3f, for each class: %s' % (mean_ious, std_ious, ious)) logging_logger.info('mean Dice: %.3f, std: %.3f, for each class: %s' % (mean_dices, std_dices, dices)) ### 2. mean metrics for all data calculated by definition ### iou_data_mean = epoch_metrics['iou'].data_mean() dice_data_mean = epoch_metrics['dice'].data_mean() logging_logger.info('data (%d) mean IoU: %.3f, std: %.3f' % (len(iou_data_mean['items']), iou_data_mean['mean'], iou_data_mean['std'])) logging_logger.info('data (%d) mean Dice: %.3f, std: %.3f' % (len(dice_data_mean['items']), dice_data_mean['mean'], dice_data_mean['std'])) # record metrics in trainer every epoch # trainer.state.metrics_records[trainer.state.epoch] = \ # {'miou': mean_ious, 'std_miou': std_ious, # 'mdice': mean_dices, 'std_mdice': std_dices} trainer.state.metrics_records[trainer.state.epoch] = \ {'miou': iou_data_mean['mean'], 'std_miou': iou_data_mean['std'], 'mdice': dice_data_mean['mean'], 'std_mdice': dice_data_mean['std']} # log interal variables(attention maps, outputs, etc.) 
on validation def tb_log_valid_iter_vars(engine, logger, event_name): log_tag = 'valid_iter' output = engine.state.output batch_size = output['output'].shape[0] res_grid = tvutils.make_grid( torch.cat([ output['output_argmax'].unsqueeze(1), output['target'].unsqueeze(1), ]), padding=2, normalize=False, # show origin image nrow=batch_size).cpu() logger.writer.add_image(tag='%s (outputs, targets)' % (log_tag), img_tensor=res_grid) if 'TAPNet' in args.model: # log attention maps and other internal values inter_vals_grid = tvutils.make_grid(torch.cat([ output['attmap'], ]), padding=2, normalize=True, nrow=batch_size).cpu() logger.writer.add_image(tag='%s internal vals' % (log_tag), img_tensor=inter_vals_grid) def tb_log_valid_epoch_vars(engine, logger, event_name): log_tag = 'valid_iter' # log monitored epoch metrics epoch_metrics = engine.state.epoch_metrics confusion_matrix = epoch_metrics['confusion_matrix'] # [1:, 1:] ious = calculate_iou(confusion_matrix) dices = calculate_dice(confusion_matrix) mean_ious = np.mean(list(ious.values())) mean_dices = np.mean(list(dices.values())) logger.writer.add_scalar('mIoU', mean_ious, engine.state.epoch) logger.writer.add_scalar('mIoU', mean_dices, engine.state.epoch) if args.tb_log: # log internal values tb_logger.attach(validator, log_handler=tb_log_valid_iter_vars, event_name=engine.Events.ITERATION_COMPLETED) tb_logger.attach(validator, log_handler=tb_log_valid_epoch_vars, event_name=engine.Events.EPOCH_COMPLETED) # tb_logger.attach(validator, log_handler=OutputHandler('valid_iter', valid_metric_names), # event_name=engine.Events.ITERATION_COMPLETED) tb_logger.attach(validator, log_handler=OutputHandler('valid_epoch', ['valid_loss']), event_name=engine.Events.EPOCH_COMPLETED) # score function for model saving ckpt_score_function = lambda engine: \ np.mean(list(calculate_iou(engine.state.epoch_metrics['confusion_matrix']).values())) # ckpt_score_function = lambda engine: engine.state.epoch_metrics['iou'].data_mean()['mean'] 
ckpt_filename_prefix = 'fold_%d' % fold # model saving handler model_ckpt_handler = handlers.ModelCheckpoint( dirname=args.model_save_dir, filename_prefix=ckpt_filename_prefix, score_function=ckpt_score_function, create_dir=True, require_empty=False, save_as_state_dict=True, atomic=True) validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED, handler=model_ckpt_handler, to_save={ 'model': model, }) # early stop # trainer=trainer, but should be handled by validator early_stopping = handlers.EarlyStopping(patience=args.es_patience, score_function=ckpt_score_function, trainer=trainer) validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED, handler=early_stopping) # evaluate after epoch finish @trainer.on(engine.Events.EPOCH_COMPLETED) def trainer_epoch_comp_callback(engine): validator.run(valid_loader) trainer.run(train_loader, max_epochs=args.max_epochs) if args.tb_log: # close tb_logger tb_logger.close() return trainer.state.metrics_records
def process_fold(fold, args):
    """Restore the checkpoint for one fold, generate segmentation masks for
    its evaluation split, and accumulate per-image IoU against ground truth.

    Side effects: writes mask (and TAPNet attention-map) images under
    ``args.mask_save_dir``, appends the fold's mean IoU to the global
    ``Average_batch_IoU`` list, and writes the overall mean to the global
    file handle ``f``.
    """
    num_classes = utils.problem_class[args.problem_type]
    factor = utils.problem_factor[args.problem_type]
    # inputs are RGB images (3 * h * w)
    # outputs are 2d multilabel segmentation maps (h * w)
    # NOTE(review): eval() on a CLI-supplied model name executes arbitrary
    # code — acceptable for a research script, but keep args trusted.
    model = eval(args.model)(in_channels=3, num_classes=num_classes)
    # data parallel for multi-GPU
    model = nn.DataParallel(model, device_ids=args.device_ids).cuda()

    ckpt_dir = Path(args.ckpt_dir)
    # ckpt for this fold: fold_<fold>_model_<epoch>.pth
    print("ckpt_dir--> ", ckpt_dir)
    # NOTE(review): '[0-99]' is a character class (0-9 plus a redundant 9),
    # not a numeric range; it matches any digit before '*'.
    filenames = glob.glob(args.ckpt_dir + 'fold_%d_model_[0-99]*.pth' % fold)
    print("Filename--> ", filenames)
    # IndexError here means no checkpoint matched the pattern for this fold.
    ckpt_filename = filenames[0]
    # load state dict
    model.load_state_dict(torch.load(str(ckpt_filename)))
    logging.info('Restored model [{}] fold {}.'.format(args.model, fold))

    # segmentation mask save directory
    mask_save_dir = Path(args.mask_save_dir) / ckpt_dir.name
    mask_save_dir.mkdir(exist_ok=True, parents=True)

    eval_transform = Compose(
        [
            Normalize(p=1),
            PadIfNeeded(
                min_height=args.input_height, min_width=args.input_width, p=1),
            # optional
            Resize(height=args.input_height, width=args.input_width, p=1),
            # CenterCrop(height=args.input_height, width=args.input_width, p=1)
        ],
        p=1)

    # train/valid filenames,
    # we evaluate and generate masks on validation set
    train_filenames, valid_filenames = utils.trainval_split(
        args.train_dir, fold)

    eval_num_workers = args.num_workers
    eval_batch_size = args.batch_size
    # additional ds args
    if 'TAPNet' in args.model:
        # in eval, num_workers should be set to 0 for sequences
        eval_num_workers = 0
        # in eval, batch_size should be set to 1 for sequences
        eval_batch_size = 1

    # additional eval dataset kws
    # NOTE(review): despite the comment above about evaluating on the
    # validation set, train_filenames is used here — confirm intent.
    eval_ds_kwargs = {
        'filenames': train_filenames,
        'problem_type': args.problem_type,
        'transform': eval_transform,
        'model': args.model,
        'mode': 'eval',
    }
    # valid dataloader
    eval_loader = DataLoader(
        dataset=RobotSegDataset(**eval_ds_kwargs),
        shuffle=False,  # in eval, no need to shuffle
        num_workers=eval_num_workers,
        batch_size=eval_batch_size,  # in valid time. have to use one image by one
        pin_memory=True)

    # process function for ignite engine
    def eval_step(engine, batch):
        with torch.no_grad():
            model.eval()
            inputs = batch['input'].cuda(non_blocking=True)
            # additional arguments
            add_params = {}
            # for TAPNet, add attention maps
            if 'TAPNet' in args.model:
                add_params['attmap'] = batch['attmap'].cuda(non_blocking=True)
            outputs = model(inputs, **add_params)
            output_logsoftmax_np = torch.softmax(outputs, dim=1).cpu().numpy()
            # output_classes and target_classes: <b, h, w>
            output_classes = output_logsoftmax_np.argmax(axis=1)
            # Scale class indices into a grayscale-visible mask.
            masks = (output_classes * factor).astype(np.uint8)
            return_dict = {
                'input_filename': batch['input_filename'],
                'mask': masks
            }
            if 'TAPNet' in args.model:
                # for TAPNet, update attention maps after each iteration
                eval_loader.dataset.update_attmaps(output_logsoftmax_np,
                                                   batch['idx'].numpy())
                # for TAPNet, return extra internal values
                return_dict['attmap'] = add_params['attmap']
            return return_dict

    # eval engine
    evaluator = engine.Engine(eval_step)
    eval_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    eval_pbar.attach(evaluator)

    # evaluate after iter finish
    @evaluator.on(engine.Events.ITERATION_COMPLETED)
    def evaluator_epoch_comp_callback(engine):
        global Average_batch_IoU
        # save masks for each batch
        batch_output = engine.state.output
        input_filenames = batch_output['input_filename']
        masks = batch_output['mask']
        iou = []
        for i, input_filename in enumerate(input_filenames):
            mask = cv2.resize(masks[i],
                              dsize=(utils.cropped_width, utils.cropped_height),
                              interpolation=cv2.INTER_AREA)
            instrument_folder_name = os.path.basename(
                os.path.dirname(os.path.dirname(input_filename)))
            binary_mask = Path(args.type_mask)
            gt_folder = os.path.dirname(
                os.path.dirname(input_filename)) / binary_mask
            gt_filename = gt_folder / os.path.basename(input_filename)
            # mask_folder/instrument_dataset_x/problem_type_masks/framexxx.png
            mask_folder = mask_save_dir / instrument_folder_name / utils.mask_folder[
                args.problem_type]
            mask_folder.mkdir(exist_ok=True, parents=True)
            mask_filename = mask_folder / os.path.basename(input_filename)
            gt_mask = cv2.imread(str(gt_filename), cv2.CV_8UC1)
            cv2.imwrite(str(mask_filename), mask)
            assert (mask.shape == gt_mask.shape)
            image_iou = get_iou(mask, gt_mask)
            # FIX (idiom): `== False` comparison replaced by `not`.
            if not math.isnan(image_iou):
                iou.append(image_iou)
            if 'TAPNet' in args.model:
                attmap = batch_output['attmap'][i]
                # FIX: str.join takes a single iterable; the original
                # '_'.join(args.problem_type, 'attmaps') raised TypeError.
                attmap_folder = mask_save_dir / instrument_folder_name / '_'.join(
                    [args.problem_type, 'attmaps'])
                attmap_folder.mkdir(exist_ok=True, parents=True)
                attmap_filename = attmap_folder / os.path.basename(
                    input_filename)
                cv2.imwrite(str(attmap_filename), attmap)
        # Mean IoU of this batch (ignoring NaNs) into the global accumulator.
        Average_batch_IoU.append(np.nanmean(iou))

    evaluator.run(eval_loader)
    print("Average_batch_IoU-->", np.nanmean(Average_batch_IoU))
    f.write(str(np.nanmean(Average_batch_IoU)))
    f.write('\n')
def train_fold(fold, args):
    """Train and validate a segmentation model on one cross-validation fold.

    Builds the model, data loaders, optimizer, ignite trainer/validator engines
    with metrics, TensorBoard logging, checkpointing and early stopping, then
    runs training for ``args.max_epochs`` epochs.

    Args:
        fold: index of the cross-validation fold to train on.
        args: parsed CLI namespace; must carry model/optimizer/data settings
            (``model``, ``lr``, ``batch_size``, ``semi``, ``tb_log``, ...) plus
            the pre-built ``logging_logger`` and (optionally) ``tb_logger``.

    Returns:
        dict mapping epoch number -> {'miou', 'std_miou', 'mdice', 'std_mdice'}
        recorded after each validation epoch.
    """
    # loggers
    logging_logger = args.logging_logger
    if args.tb_log:
        tb_logger = args.tb_logger

    num_classes = utils.problem_class[args.problem_type]

    # init model
    # NOTE(review): eval() on args.model assumes a trusted command line; it
    # resolves the model class by name from the current namespace.
    model = eval(args.model)(in_channels=3, num_classes=num_classes, bn=False)
    model = nn.DataParallel(model, device_ids=args.device_ids).cuda()

    # transform for train/valid data
    train_transform, valid_transform = get_transform(args.model)

    # loss function(s); the semi-supervised variant is only built when needed
    loss_func = LossMulti(num_classes, args.jaccard_weight)
    if args.semi:
        loss_func_semi = LossMultiSemi(num_classes, args.jaccard_weight,
                                       args.semi_loss_alpha, args.semi_method)

    # train/valid filenames
    train_filenames, valid_filenames = utils.trainval_split(args.train_dir, fold)

    # DataLoader and Dataset args
    train_shuffle = True
    train_ds_kwargs = {
        'filenames': train_filenames,
        'problem_type': args.problem_type,
        'transform': train_transform,
        'model': args.model,
        'mode': 'train',
        'semi': args.semi,
    }
    valid_num_workers = args.num_workers
    valid_batch_size = args.batch_size
    if 'TAPNet' in args.model:
        # for TAPNet, cancel default shuffle; use self-defined shuffle in the
        # torch Dataset instead (frame sequences must stay ordered per batch)
        train_shuffle = False
        train_ds_kwargs['batch_size'] = args.batch_size
        train_ds_kwargs['mf'] = args.mf
    if args.semi:
        train_ds_kwargs['semi_method'] = args.semi_method
        train_ds_kwargs['semi_percentage'] = args.semi_percentage

    # additional valid dataset kwargs
    valid_ds_kwargs = {
        'filenames': valid_filenames,
        'problem_type': args.problem_type,
        'transform': valid_transform,
        'model': args.model,
        'mode': 'valid',
    }
    if 'TAPNet' in args.model:
        # in validation, num_workers must be 0 and batch_size 1 so frames are
        # consumed strictly in sequence (attention maps are propagated frame
        # to frame)
        valid_num_workers = 0
        valid_batch_size = 1
        valid_ds_kwargs['mf'] = args.mf

    # train dataloader
    train_loader = DataLoader(
        dataset=RobotSegDataset(**train_ds_kwargs),
        shuffle=train_shuffle,  # False disables pytorch shuffling (TAPNet)
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        pin_memory=True
    )
    # valid dataloader
    valid_loader = DataLoader(
        dataset=RobotSegDataset(**valid_ds_kwargs),
        shuffle=False,  # in validation, no need to shuffle
        num_workers=valid_num_workers,
        batch_size=valid_batch_size,  # 1 for TAPNet: one image at a time
        pin_memory=True
    )

    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           weight_decay=args.weight_decay)

    # ignite trainer process function
    def train_step(engine, batch):
        """One optimization step; returns a dict consumed by the metrics."""
        model.train()
        optimizer.zero_grad()
        # additional params to feed into model
        add_params = {}
        inputs = batch['input'].cuda(non_blocking=True)
        with torch.no_grad():
            targets = batch['target'].cuda(non_blocking=True)
        # for TAPNet, add attention maps
        if 'TAPNet' in args.model:
            add_params['attmap'] = batch['attmap'].cuda(non_blocking=True)
        outputs = model(inputs, **add_params)
        loss_kwargs = {}
        if args.semi:
            loss_kwargs['labeled'] = batch['labeled']
            if args.semi_method == 'rev_flow':
                loss_kwargs['optflow'] = batch['optflow']
            loss = loss_func_semi(outputs, targets, **loss_kwargs)
        else:
            loss = loss_func(outputs, targets, **loss_kwargs)
        loss.backward()
        optimizer.step()
        return_dict = {
            'output': outputs,
            'target': targets,
            'loss_kwargs': loss_kwargs,
            'loss': loss.item(),
        }
        # for TAPNet, update attention maps after each iteration
        if 'TAPNet' in args.model:
            # output_classes and target_classes: <b, h, w>
            output_softmax_np = torch.softmax(outputs, dim=1).detach().cpu().numpy()
            train_loader.dataset.update_attmaps(output_softmax_np,
                                                batch['abs_idx'].numpy())
            return_dict['attmap'] = add_params['attmap']
        return return_dict

    # init trainer
    trainer = engine.Engine(train_step)

    # lr scheduler: step decay, advanced at the start of every epoch
    step_scheduler = optim.lr_scheduler.StepLR(optimizer,
        step_size=args.lr_decay_epochs, gamma=args.lr_decay)
    lr_scheduler = c_handlers.param_scheduler.LRScheduler(step_scheduler)
    trainer.add_event_handler(engine.Events.EPOCH_STARTED, lr_scheduler)

    @trainer.on(engine.Events.STARTED)
    def trainer_start_callback(engine):
        logging_logger.info('training fold {}, {} train / {} valid files'. \
            format(fold, len(train_filenames), len(valid_filenames)))
        # resume training from ckpt fold_<fold>_model_<epoch>.pth
        if args.resume:
            ckpt_dir = Path(args.ckpt_dir)
            # BUG FIX: Path.glob returns a generator, which is not
            # subscriptable; take the first match with next(). Match the
            # epoch-extracting regex against the bare filename (.name), not
            # the full Path, since re.match anchors at the string start.
            ckpt_filename = next(ckpt_dir.glob('fold_%d_model_[0-9]*.pth' % fold))
            res = re.match(r'fold_%d_model_(\d+).pth' % fold, ckpt_filename.name)
            # restore epoch
            engine.state.epoch = int(res.groups()[0])
            # load model state dict
            model.load_state_dict(torch.load(str(ckpt_filename)))
            logging_logger.info('restore model [{}] from epoch {}.'.format(
                args.model, engine.state.epoch))
        else:
            logging_logger.info('train model [{}] from scratch'.format(args.model))
        # record metrics history every epoch
        engine.state.metrics_records = {}

    @trainer.on(engine.Events.EPOCH_STARTED)
    def trainer_epoch_start_callback(engine):
        # log learning rate on pbar
        train_pbar.log_message('model: %s, problem type: %s, fold: %d, lr: %.5f, batch size: %d' % \
            (args.model, args.problem_type, fold, lr_scheduler.get_param(), args.batch_size))
        # for TAPNet, change dataset schedule to random after the first epoch
        if 'TAPNet' in args.model and engine.state.epoch > 1:
            train_loader.dataset.set_dataset_schedule("shuffle")

    # monitor loss: running average over iterations
    train_ra_loss = imetrics.RunningAverage(
        output_transform=lambda x: x['loss'], alpha=0.98)
    train_ra_loss.attach(trainer, 'train_ra_loss')

    # monitor train loss over epoch
    if args.semi:
        train_loss = imetrics.Loss(loss_func_semi,
            output_transform=lambda x: (x['output'], x['target'], x['loss_kwargs']))
    else:
        train_loss = imetrics.Loss(loss_func,
            output_transform=lambda x: (x['output'], x['target']))
    train_loss.attach(trainer, 'train_loss')

    # progress bar
    train_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    train_metric_names = ['train_ra_loss']
    train_pbar.attach(trainer, metric_names=train_metric_names)

    # tensorboardX: log train info
    if args.tb_log:
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer, 'lr'),
            event_name=engine.Events.EPOCH_STARTED)
        tb_logger.attach(trainer, log_handler=OutputHandler('train_iter', train_metric_names),
            event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OutputHandler('train_epoch', ['train_loss']),
            event_name=engine.Events.EPOCH_COMPLETED)
        tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model, reduction=torch.norm),
            event_name=engine.Events.ITERATION_COMPLETED)

    # ignite validator process function
    def valid_step(engine, batch):
        """Forward pass + per-frame IoU/Dice accumulation (no gradients)."""
        with torch.no_grad():
            model.eval()
            inputs = batch['input'].cuda(non_blocking=True)
            targets = batch['target'].cuda(non_blocking=True)
            # additional arguments
            add_params = {}
            # for TAPNet, add attention maps
            if 'TAPNet' in args.model:
                add_params['attmap'] = batch['attmap'].cuda(non_blocking=True)
            # output logits
            outputs = model(inputs, **add_params)
            # loss
            loss = loss_func(outputs, targets)
            output_softmaxs = torch.softmax(outputs, dim=1)
            output_argmaxs = output_softmaxs.argmax(dim=1)
            # output_classes and target_classes: <b, h, w>
            output_classes = output_argmaxs.cpu().numpy()
            target_classes = targets.cpu().numpy()
            # record current batch metrics
            iou_mRecords = MetricRecord()
            dice_mRecords = MetricRecord()
            cm_b = np.zeros((num_classes, num_classes), dtype=np.uint32)
            for output_class, target_class in zip(output_classes, target_classes):
                # per-frame metrics via confusion matrix
                cm = calculate_confusion_matrix_from_arrays(
                    output_class, target_class, num_classes)
                iou_mRecords.update_record(calculate_iou(cm))
                dice_mRecords.update_record(calculate_dice(cm))
                cm_b += cm
            # accumulate batch metrics to engine state
            engine.state.epoch_metrics['confusion_matrix'] += cm_b
            engine.state.epoch_metrics['iou'].merge(iou_mRecords)
            engine.state.epoch_metrics['dice'].merge(dice_mRecords)
            return_dict = {
                'loss': loss.item(),
                'output': outputs,
                'output_argmax': output_argmaxs,
                'target': targets,
                # for monitoring
                'iou': iou_mRecords,
                'dice': dice_mRecords,
            }
            if 'TAPNet' in args.model:
                # for TAPNet, update attention maps after each iteration
                valid_loader.dataset.update_attmaps(
                    output_softmaxs.cpu().numpy(), batch['abs_idx'].numpy())
                # for TAPNet, return extra internal values
                return_dict['attmap'] = add_params['attmap']
            return return_dict

    # validator engine
    validator = engine.Engine(valid_step)

    # monitor loss
    valid_ra_loss = imetrics.RunningAverage(
        output_transform=lambda x: x['loss'], alpha=0.98)
    valid_ra_loss.attach(validator, 'valid_ra_loss')

    # monitor validation loss over epoch
    valid_loss = imetrics.Loss(loss_func,
        output_transform=lambda x: (x['output'], x['target']))
    valid_loss.attach(validator, 'valid_loss')

    # monitor per-data mean metrics (running average of batch means)
    valid_data_miou = imetrics.RunningAverage(
        output_transform=lambda x: x['iou'].data_mean()['mean'], alpha=0.98)
    valid_data_miou.attach(validator, 'mIoU')
    valid_data_mdice = imetrics.RunningAverage(
        output_transform=lambda x: x['dice'].data_mean()['mean'], alpha=0.98)
    valid_data_mdice.attach(validator, 'mDice')

    # show metrics on progress bar (after every iteration)
    valid_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    valid_metric_names = ['valid_ra_loss', 'mIoU', 'mDice']
    valid_pbar.attach(validator, metric_names=valid_metric_names)

    @validator.on(engine.Events.EPOCH_STARTED)
    def validator_epoch_start_callback(engine):
        # fresh accumulators for this validation epoch
        engine.state.epoch_metrics = {
            # directly use definition to calculate
            'iou': MetricRecord(),
            'dice': MetricRecord(),
            'confusion_matrix': np.zeros((num_classes, num_classes),
                                         dtype=np.uint32),
        }

    # evaluate after epoch finish
    @validator.on(engine.Events.EPOCH_COMPLETED)
    def validator_epoch_comp_callback(engine):
        # log monitored epoch metrics
        epoch_metrics = engine.state.epoch_metrics

        # 1. mean metrics for all data calculated by confusion matrix.
        # The full confusion matrix (incl. background) is used; only
        # non-background results are presented downstream.
        confusion_matrix = epoch_metrics['confusion_matrix']
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)
        mean_ious = np.mean(list(ious.values()))
        mean_dices = np.mean(list(dices.values()))
        std_ious = np.std(list(ious.values()))
        std_dices = np.std(list(dices.values()))
        logging_logger.info('mean IoU: %.3f, std: %.3f, for each class: %s' %
            (mean_ious, std_ious, ious))
        logging_logger.info('mean Dice: %.3f, std: %.3f, for each class: %s' %
            (mean_dices, std_dices, dices))

        # 2. mean metrics for all data calculated by definition
        iou_data_mean = epoch_metrics['iou'].data_mean()
        dice_data_mean = epoch_metrics['dice'].data_mean()
        logging_logger.info('data (%d) mean IoU: %.3f, std: %.3f' %
            (len(iou_data_mean['items']), iou_data_mean['mean'], iou_data_mean['std']))
        logging_logger.info('data (%d) mean Dice: %.3f, std: %.3f' %
            (len(dice_data_mean['items']), dice_data_mean['mean'], dice_data_mean['std']))

        # record metrics in trainer every epoch
        trainer.state.metrics_records[trainer.state.epoch] = \
            {'miou': iou_data_mean['mean'], 'std_miou': iou_data_mean['std'],
             'mdice': dice_data_mean['mean'], 'std_mdice': dice_data_mean['std']}

    # log internal variables (attention maps, outputs, etc.) on validation
    def tb_log_valid_iter_vars(engine, logger, event_name):
        log_tag = 'valid_iter'
        output = engine.state.output
        batch_size = output['output'].shape[0]
        res_grid = tvutils.make_grid(torch.cat([
            output['output_argmax'].unsqueeze(1),
            output['target'].unsqueeze(1),
        ]), padding=2,
            normalize=False,  # show origin image
            nrow=batch_size).cpu()
        logger.writer.add_image(tag='%s (outputs, targets)' % (log_tag),
                                img_tensor=res_grid)
        if 'TAPNet' in args.model:
            # log attention maps and other internal values
            inter_vals_grid = tvutils.make_grid(torch.cat([
                output['attmap'],
            ]), padding=2, normalize=True, nrow=batch_size).cpu()
            logger.writer.add_image(tag='%s internal vals' % (log_tag),
                                    img_tensor=inter_vals_grid)

    def tb_log_valid_epoch_vars(engine, logger, event_name):
        # log monitored epoch metrics
        epoch_metrics = engine.state.epoch_metrics
        confusion_matrix = epoch_metrics['confusion_matrix']
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)
        mean_ious = np.mean(list(ious.values()))
        mean_dices = np.mean(list(dices.values()))
        logger.writer.add_scalar('mIoU', mean_ious, engine.state.epoch)
        # BUG FIX: mean Dice was logged under the 'mIoU' tag, overwriting
        # the IoU curve in TensorBoard
        logger.writer.add_scalar('mDice', mean_dices, engine.state.epoch)

    if args.tb_log:
        # log internal values
        tb_logger.attach(validator, log_handler=tb_log_valid_iter_vars,
            event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator, log_handler=tb_log_valid_epoch_vars,
            event_name=engine.Events.EPOCH_COMPLETED)
        tb_logger.attach(validator, log_handler=OutputHandler('valid_epoch', ['valid_loss']),
            event_name=engine.Events.EPOCH_COMPLETED)

    # score function for model saving: dataset-wide mean IoU from the
    # accumulated confusion matrix
    ckpt_score_function = lambda engine: \
        np.mean(list(calculate_iou(engine.state.epoch_metrics['confusion_matrix']).values()))

    ckpt_filename_prefix = 'fold_%d' % fold

    # model saving handler
    model_ckpt_handler = handlers.ModelCheckpoint(
        dirname=args.model_save_dir,
        filename_prefix=ckpt_filename_prefix,
        score_function=ckpt_score_function,
        create_dir=True,
        require_empty=False,
        save_as_state_dict=True,
        atomic=True)
    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED,
        handler=model_ckpt_handler,
        to_save={
            'model': model,
        })

    # early stop: trainer=trainer, but triggered by the validator's score
    early_stopping = handlers.EarlyStopping(patience=args.es_patience,
        score_function=ckpt_score_function,
        trainer=trainer
    )
    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED,
        handler=early_stopping)

    # evaluate after epoch finish
    @trainer.on(engine.Events.EPOCH_COMPLETED)
    def trainer_epoch_comp_callback(engine):
        validator.run(valid_loader)

    trainer.run(train_loader, max_epochs=args.max_epochs)

    if args.tb_log:
        # close tb_logger
        tb_logger.close()

    return trainer.state.metrics_records
def create_supervised_trainer(model, optimizer, loss_fn, device=None, non_blocking=False,
                              prepare_batch=engine._prepare_batch, check_nan=False,
                              grad_clip=None, output_predictions=False):
    """As ignite.engine.create_supervised_trainer, but may also optionally perform:
    - NaN checking on predictions (in a more debuggable way than
      ignite.handlers.TerminateOnNaN)
    - Gradient clipping
    - Recording the predictions made by the model

    Arguments: (as ignite.engine.create_supervised_trainer, plus)
        check_nan: Optional boolean specifying whether the engine should check
            predictions for NaN values. Defaults to False. If True, and a NaN value
            is encountered, then a RuntimeError will be raised with attributes 'x',
            'y', 'y_pred', 'model', detailing the feature, label, prediction and
            model, respectively, on which this occurred.
        grad_clip: Optional number, boolean or None, specifying the value to clip
            the infinity-norm of the gradient to. Defaults to None. If False or None
            then no gradient clipping will be applied. If True then the gradient is
            clipped to 1.0.
        output_predictions: Optional boolean specifying whether the engine should
            record the predictions the model made on a batch. Defaults to False.
            If True then state.output will be a tuple of (loss, predictions). If
            False then state.output will just be the loss. (Not wrapped in a tuple.)
    """
    if device:
        model.to(device)

    # normalize the boolean forms of grad_clip into a numeric value or None
    if grad_clip is False:
        grad_clip = None
    elif grad_clip is True:
        grad_clip = 1.0

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        if check_nan and torch.isnan(y_pred).any():
            # attach the offending tensors so the failure can be debugged
            e = RuntimeError('Model generated NaN value.')
            e.y = y
            e.y_pred = y_pred
            e.x = x
            e.model = model
            raise e
        loss = loss_fn(y_pred, y)
        loss.backward()
        if grad_clip is not None:
            # clip by infinity norm, as documented above
            nnutils.clip_grad_norm_(model.parameters(), grad_clip, norm_type='inf')
        optimizer.step()
        if output_predictions:
            return loss.item(), y_pred
        else:
            return loss.item()

    return engine.Engine(_update)
def _reload_eval_engine(self): self.eval_engine = engine.Engine(self._eval_fn) if len(self.metrics) > 0: for name, metric in self.metrics.items(): metric.attach(self.eval_engine, name)
def run(experiment_name: str,
        visdom_host: str,
        visdom_port: int,
        visdom_env_path: str,
        model_class: str,
        model_args: Dict[str, Any],
        optimizer_class: str,
        optimizer_args: Dict[str, Any],
        dataset_class: str,
        dataset_args: Dict[str, Any],
        batch_train: int,
        batch_test: int,
        workers_train: int,
        workers_test: int,
        transforms: List[Dict[str, Union[str, Dict[str, Any]]]],
        epochs: int,
        log_interval: int,
        saved_models_path: str,
        performance_metrics: Optional = None,
        scheduler_class: Optional[str] = None,
        scheduler_args: Optional[Dict[str, Any]] = None,
        model_suffix: Optional[str] = None,
        setup_suffix: Optional[str] = None,
        orig_stdout: Optional[io.TextIOBase] = None):
    """Run one training experiment end-to-end with visdom monitoring.

    Classes (model, optimizer, dataset, transforms, scheduler) are resolved
    from their dotted-path names via ``_utils.load_class`` and instantiated
    from the corresponding ``*_args`` dicts. Training/validation curves are
    plotted to a visdom environment and the best checkpoints (by the summed
    metrics flagged ``is_checkpoint``) are written under ``saved_models_path``.

    Note: entries of ``transforms`` are mutated in place (their resolved
    ``train``/``test`` flags are written back).
    """
    with _utils.tqdm_stdout(orig_stdout) as orig_stdout:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # split configured transforms into train/test pipelines
        transforms_train = list()
        transforms_test = list()
        for idx, transform in enumerate(transforms):
            use_train = transform.get('train', True)
            use_test = transform.get('test', True)
            transform = _utils.load_class(
                transform['class'])(**transform['args'])
            if use_train:
                transforms_train.append(transform)
            if use_test:
                transforms_test.append(transform)
            transforms[idx]['train'] = use_train
            transforms[idx]['test'] = use_test
        transforms_train = tv.transforms.Compose(transforms_train)
        transforms_test = tv.transforms.Compose(transforms_test)

        # data
        Dataset: Type = _utils.load_class(dataset_class)
        train_loader, eval_loader = _utils.get_data_loaders(
            Dataset, dataset_args, batch_train, batch_test, workers_train,
            workers_test, transforms_train, transforms_test)

        # model and optimizer
        Network: Type = _utils.load_class(model_class)
        model: _interfaces.AbstractNet = Network(**model_args)
        model = model.to(device)
        Optimizer: Type = _utils.load_class(optimizer_class)
        optimizer: torch.optim.Optimizer = Optimizer(model.parameters(),
                                                     **optimizer_args)

        # optional LR scheduler
        if scheduler_class is not None:
            Scheduler: Type = _utils.load_class(scheduler_class)
            if scheduler_args is None:
                scheduler_args = dict()
            scheduler: Optional[
                torch.optim.lr_scheduler._LRScheduler] = Scheduler(
                    optimizer, **scheduler_args)
        else:
            scheduler = None

        # short name = the capital letters of the network class name
        model_short_name = ''.join(
            [c for c in Network.__name__ if c == c.upper()])
        model_name = '{}{}'.format(
            model_short_name,
            '-{}'.format(model_suffix) if model_suffix is not None else '')
        visdom_env_name = '{}_{}_{}{}'.format(
            Dataset.__name__, experiment_name, model_name,
            '-{}'.format(setup_suffix) if setup_suffix is not None else '')
        vis, vis_pid = _visdom.get_visdom_instance(visdom_host, visdom_port,
                                                   visdom_env_name,
                                                   visdom_env_path)

        prog_bar_epochs = tqdm.tqdm(total=epochs,
                                    desc='Epochs',
                                    file=orig_stdout,
                                    dynamic_ncols=True,
                                    unit='epoch')
        prog_bar_iters = tqdm.tqdm(desc='Batches',
                                   file=orig_stdout,
                                   dynamic_ncols=True)

        tqdm.tqdm.write(f'\n{repr(model)}\n')
        tqdm.tqdm.write('Total number of parameters: {:.2f}M'.format(
            sum(p.numel() for p in model.parameters()) / 1e6))

        def training_step(_: ieng.Engine,
                          batch: _interfaces.TensorPair) -> float:
            # NOTE: returns the scalar loss (loss.item()), hence the float
            # annotation; the model is expected to return (output, loss)
            model.train()
            optimizer.zero_grad()
            x, y = batch
            x = x.to(device)
            y = y.to(device)
            _, loss = model(x, y)
            loss.backward(retain_graph=False)
            optimizer.step(None)
            return loss.item()

        def eval_step(_: ieng.Engine,
                      batch: _interfaces.TensorPair) -> _interfaces.TensorPair:
            model.eval()
            with torch.no_grad():
                x, y = batch
                x = x.to(device)
                y = y.to(device)
                y_pred = model(x)
                return y_pred, y

        trainer = ieng.Engine(training_step)
        validator_train = ieng.Engine(eval_step)
        validator_eval = ieng.Engine(eval_step)

        # placeholder for summary window
        vis.text(text='',
                 win=experiment_name,
                 env=visdom_env_name,
                 opts={
                     'title': 'Summary',
                     'width': 940,
                     'height': 416
                 },
                 append=vis.win_exists(experiment_name, visdom_env_name))

        default_metrics = {
            "Loss": {
                "window_name": None,
                "x_label": "#Epochs",
                "y_label": model.loss_fn_name,
                "width": 940,
                "height": 416,
                "lines": [{
                    "line_label": "SMA",
                    "object": imet.RunningAverage(output_transform=lambda x: x),
                    "test": False,
                    "update_rate": "iteration"
                }, {
                    "line_label": "Val.",
                    "object": imet.Loss(model.loss_fn)
                }]
            }
        }
        # BUG FIX: performance_metrics defaults to None; unpacking None with
        # ** raises TypeError, so substitute an empty dict first
        performance_metrics = {**default_metrics, **(performance_metrics or {})}

        # wire every configured metric line to its engines and visdom windows
        checkpoint_metrics = list()
        for scope_name, scope in performance_metrics.items():
            scope['window_name'] = scope.get('window_name',
                                             scope_name) or scope_name
            for line in scope['lines']:
                if 'object' not in line:
                    line['object']: imet.Metric = _utils.load_class(
                        line['class'])(**line['args'])
                line['metric_label'] = '{}: {}'.format(scope['window_name'],
                                                       line['line_label'])
                line['update_rate'] = line.get('update_rate', 'epoch')
                line_suffixes = list()
                if line['update_rate'] == 'iteration':
                    # iteration-rate lines attach to the trainer only
                    line['object'].attach(trainer, line['metric_label'])
                    line['train'] = False
                    line['test'] = False
                    line_suffixes.append(' Train.')
                if line.get('train', True):
                    line['object'].attach(validator_train,
                                          line['metric_label'])
                    line_suffixes.append(' Train.')
                if line.get('test', True):
                    line['object'].attach(validator_eval,
                                          line['metric_label'])
                    line_suffixes.append(' Eval.')
                if line.get('is_checkpoint', False):
                    checkpoint_metrics.append(line['metric_label'])
                for line_suffix in line_suffixes:
                    _visdom.plot_line(
                        vis=vis,
                        window_name=scope['window_name'],
                        env=visdom_env_name,
                        line_label=line['line_label'] + line_suffix,
                        x_label=scope['x_label'],
                        y_label=scope['y_label'],
                        width=scope['width'],
                        height=scope['height'],
                        draw_marker=(line['update_rate'] == 'epoch'))

        if checkpoint_metrics:
            score_name = 'performance'

            def get_score(engine: ieng.Engine) -> float:
                # only score when the loader is in validation mode;
                # missing metrics simply contribute nothing
                current_mode = getattr(
                    engine.state.dataloader.iterable.dataset,
                    dataset_args['training']['key'])
                val_mode = dataset_args['training']['no']
                score = 0.0
                if current_mode == val_mode:
                    for metric_name in checkpoint_metrics:
                        try:
                            score += engine.state.metrics[metric_name]
                        except KeyError:
                            pass
                return score

            model_saver = ihan.ModelCheckpoint(os.path.join(
                saved_models_path, visdom_env_name),
                filename_prefix=visdom_env_name,
                score_name=score_name,
                score_function=get_score,
                n_saved=3,
                save_as_state_dict=True,
                require_empty=False,
                create_dir=True)
            validator_eval.add_event_handler(ieng.Events.EPOCH_COMPLETED,
                                             model_saver, {model_name: model})

        @trainer.on(ieng.Events.EPOCH_STARTED)
        def reset_progress_iterations(engine: ieng.Engine):
            # restart the batch progress bar for the new epoch
            prog_bar_iters.clear()
            prog_bar_iters.n = 0
            prog_bar_iters.last_print_n = 0
            prog_bar_iters.start_t = time.time()
            prog_bar_iters.last_print_t = time.time()
            prog_bar_iters.total = len(engine.state.dataloader)

        @trainer.on(ieng.Events.ITERATION_COMPLETED)
        def log_training(engine: ieng.Engine):
            prog_bar_iters.update(1)
            num_iter = (engine.state.iteration - 1) % len(train_loader) + 1
            # a NaN/inf loss aborts the run
            early_stop = np.isnan(engine.state.output) or np.isinf(
                engine.state.output)
            if num_iter % log_interval == 0 or num_iter == len(
                    train_loader) or early_stop:
                tqdm.tqdm.write(
                    'Epoch[{}] Iteration[{}/{}] Loss: {:.4f}'.format(
                        engine.state.epoch, num_iter, len(train_loader),
                        engine.state.output))
                x_pos = engine.state.epoch + num_iter / len(train_loader) - 1
                for scope_name, scope in performance_metrics.items():
                    for line in scope['lines']:
                        if line['update_rate'] == 'iteration':
                            line_label = '{} Train.'.format(line['line_label'])
                            line_value = engine.state.metrics[
                                line['metric_label']]
                            if engine.state.epoch > 1:
                                _visdom.plot_line(
                                    vis=vis,
                                    window_name=scope['window_name'],
                                    env=visdom_env_name,
                                    line_label=line_label,
                                    x_label=scope['x_label'],
                                    y_label=scope['y_label'],
                                    x=np.full(1, x_pos),
                                    y=np.full(1, line_value))
            if early_stop:
                tqdm.tqdm.write(
                    colored('Early stopping due to invalid loss value.',
                            'red'))
                trainer.terminate()

        def log_validation(engine: ieng.Engine, train: bool = True):
            # run the chosen validator over the chosen loader, plot every
            # epoch-rate metric line and print a one-line summary
            if train:
                run_type = 'Train.'
                data_loader = train_loader
                validator = validator_train
            else:
                run_type = 'Eval.'
                data_loader = eval_loader
                validator = validator_eval
            prog_bar_validation = tqdm.tqdm(data_loader,
                                            desc=f'Validation {run_type}',
                                            file=orig_stdout,
                                            dynamic_ncols=True,
                                            leave=False)
            validator.run(prog_bar_validation)
            prog_bar_validation.clear()
            prog_bar_validation.close()
            tqdm_info = ['Epoch: {}'.format(engine.state.epoch)]
            for scope_name, scope in performance_metrics.items():
                for line in scope['lines']:
                    if line['update_rate'] == 'epoch':
                        try:
                            line_label = '{} {}'.format(
                                line['line_label'], run_type)
                            line_value = validator.state.metrics[
                                line['metric_label']]
                            _visdom.plot_line(vis=vis,
                                              window_name=scope['window_name'],
                                              env=visdom_env_name,
                                              line_label=line_label,
                                              x_label=scope['x_label'],
                                              y_label=scope['y_label'],
                                              x=np.full(1, engine.state.epoch),
                                              y=np.full(1, line_value),
                                              draw_marker=True)
                            tqdm_info.append('{}: {:.4f}'.format(
                                line_label, line_value))
                        except KeyError:
                            # metric not attached for this run type
                            pass
            tqdm.tqdm.write('{} results - {}'.format(run_type,
                                                     '; '.join(tqdm_info)))

        @trainer.on(ieng.Events.EPOCH_COMPLETED)
        def log_validation_train(engine: ieng.Engine):
            log_validation(engine, True)

        @trainer.on(ieng.Events.EPOCH_COMPLETED)
        def log_validation_eval(engine: ieng.Engine):
            log_validation(engine, False)
            if engine.state.epoch == 1:
                # publish the experiment summary once, after the first epoch
                summary = _utils.build_summary_str(
                    experiment_name=experiment_name,
                    model_short_name=model_name,
                    model_class=model_class,
                    model_args=model_args,
                    optimizer_class=optimizer_class,
                    optimizer_args=optimizer_args,
                    dataset_class=dataset_class,
                    dataset_args=dataset_args,
                    transforms=transforms,
                    epochs=epochs,
                    batch_train=batch_train,
                    log_interval=log_interval,
                    saved_models_path=saved_models_path,
                    scheduler_class=scheduler_class,
                    scheduler_args=scheduler_args)
                _visdom.create_summary_window(
                    vis=vis,
                    visdom_env_name=visdom_env_name,
                    experiment_name=experiment_name,
                    summary=summary)
                vis.save([visdom_env_name])
            prog_bar_epochs.update(1)
            if scheduler is not None:
                scheduler.step(engine.state.epoch)

        trainer.run(train_loader, max_epochs=epochs)

        # teardown: stop a visdom server we spawned ourselves, release loaders
        # and close the progress bars
        if vis_pid is not None:
            tqdm.tqdm.write('Stopping visdom')
            os.kill(vis_pid, signal.SIGTERM)
        del vis
        del train_loader
        del eval_loader
        prog_bar_iters.clear()
        prog_bar_iters.close()
        prog_bar_epochs.clear()
        prog_bar_epochs.close()
        tqdm.tqdm.write('\n')
def create_evaluator(self, profile: Profile, shared: Storage, logger: Logger, model: nn.Module, loss_function: nn.Module, optimizer: optim.Optimizer, lr_scheduler: Any, output_transform=lambda x, y, y_pred: (y_pred, y), **kwargs) -> engine.Engine: """ Args: profile: Runtime profile defined in TOML file. shared: Shared storage in the whole lifecycle. logger: The logger named with this Task. model: The model to train. loss_function: The loss function to train. optimizer: The optimizer to train. lr_scheduler: The scheduler to control the learning rate. output_transform: The action to transform the output of the model. Returns: The evaluator engine. """ if 'device' in profile: device_type = profile.device else: device_type = 'cpu' if 'non_blocking' in profile: non_blocking = profile.non_blocking else: non_blocking = False if 'deterministic' in profile: deterministic = profile.deterministic else: deterministic = False _metrics = {} self.register_metrics(profile, shared, logger, _metrics) def _inference(_engine: engine.Engine, _batch: Tuple[torch.Tensor]): model.eval() with torch.no_grad(): x, y = self.prepare_validate_batch(profile, shared, logger, _batch, device=device_type, non_blocking=non_blocking) y_pred = model(x) return output_transform(x, y, y_pred) evaluator = engine.DeterministicEngine( _inference) if deterministic else engine.Engine(_inference) for name, metric in _metrics.items(): metric.attach(evaluator, name) return evaluator
x, y = batch['payload'], batch['target'] ypred = CLF (x) loss = LFN (ypred, y.squeeze(1)) loss.backward() OPM.step() return loss.item() def eval_step(engine, batch): CLF.eval() with t.no_grad(): x, y = batch['payload'], batch['target'] y = y.squeeze (1) ypred = CLF (x) return ypred, y TRAINER = ie.Engine (train_step) EVALUATOR = ie.Engine (eval_step) for name, metric in VAL_METRICS.items(): metric.attach (EVALUATOR, name) ######################### TO_CHECKP = { "trainer":TRAINER, "evaluator":EVALUATOR, "model":CLF, "optimizer":OPM, } tckp = ih.Checkpoint ( to_save = TO_CHECKP, save_handler = ih.DiskSaver (MDIR, require_empty=False), n_saved=10, )