def add_progress_bar_eval(evaluator, validation_loader):
    """
    "I can't believe it's not Keras"
    Running average accuracy and loss metrics + TQDM progress bar
    """
    validation_history = {'accuracy': [], 'loss': []}
    RunningAverage(output_transform=lambda x: x[0]).attach(evaluator, 'loss')
    RunningAverage(Accuracy(output_transform=lambda x: (x[0], x[1]))).attach(
        evaluator, 'accuracy')
    prog_bar = ProgressBar()
    prog_bar.attach(evaluator, ['accuracy'])
    # prog_bar.pbar_cls = tqdm.tqdm

    from ignite.handlers import Timer
    timer = Timer(average=True)
    timer.attach(evaluator,
                 start=Events.EPOCH_STARTED,
                 resume=Events.EPOCH_STARTED,
                 pause=Events.EPOCH_COMPLETED,
                 step=Events.EPOCH_COMPLETED)

    @evaluator.on(Events.EPOCH_COMPLETED)
    def log_validation_results(evaluator):
        metrics = evaluator.state.metrics
        accuracy = metrics['accuracy'] * 100
        # the running average above is attached under the name 'loss',
        # not 'nll' as the original read
        loss = metrics['loss']
        validation_history['accuracy'].append(accuracy)
        validation_history['loss'].append(loss)
        val_msg = "Valid Epoch {}: acc: {:.2f}% loss: {:.2f}, eval time: {:.2f}s".format(
            evaluator.state.epoch, accuracy, loss, timer.value())
        prog_bar.log_message(val_msg)
def __init__(self, num_iters=100, prepare_batch=None, device="cuda"):
    from ignite.handlers import Timer

    def upload_to_gpu(engine, batch):
        if prepare_batch is not None:
            x, y = prepare_batch(batch, device=device, non_blocking=False)

    self.num_iters = num_iters
    self.benchmark_dataflow = Engine(upload_to_gpu)

    @self.benchmark_dataflow.on(Events.ITERATION_COMPLETED(once=num_iters))
    def stop_benchmark_dataflow(engine):
        engine.terminate()

    # also check is_initialized(): dist.get_rank() raises if the default
    # process group has not been set up, even when dist is available
    if dist.is_available() and dist.is_initialized() and dist.get_rank() == 0:

        @self.benchmark_dataflow.on(
                Events.ITERATION_COMPLETED(every=num_iters // 100))
        def show_progress_benchmark_dataflow(engine):
            print(".", end=" ")

    self.timer = Timer(average=False)
    self.timer.attach(self.benchmark_dataflow,
                      start=Events.EPOCH_STARTED,
                      resume=Events.ITERATION_STARTED,
                      pause=Events.ITERATION_COMPLETED,
                      step=Events.ITERATION_COMPLETED)
def __init__(self, model, config, evaluator, data_loader, tb_writer, run_info,
             logger, checkpoint_dir):
    """
    Creates a new trainer object for training a model.

    :param model: model to train. Needs to inherit from the BaseModel class.
    :param config: dictionary containing the whole configuration of the experiment
    :param evaluator: instance of the evaluator class, used to run evaluation on a specified schedule
    :param data_loader: pytorch data loader providing the training data
    :param tb_writer: tensorboardX summary writer
    :param run_info: sacred run info for logging training progress
    :param logger: python logger object
    :param checkpoint_dir: directory path for storing checkpoints
    """
    self.run_info = run_info
    self.logger = logger
    self.data_loader = data_loader
    self.evaluator = evaluator
    self.engine = Engine(self._step)
    self.model = model
    self.config = config
    self.train_cfg = config['train']
    self.tb_writer = tb_writer
    self.pbar = ProgressBar(ascii=True, desc='* Epoch')
    self.timer = Timer(average=True)
    self.save_last_checkpoint_handler = ModelCheckpoint(
        checkpoint_dir, 'last',
        save_interval=self.train_cfg['save_interval'],
        n_saved=self.train_cfg['save_n_last'],
        require_empty=False)
    self.add_handler()
def setup_timer(engine):
    timer = Timer(average=True)
    timer.attach(engine,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED)
    return timer
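# A minimal usage sketch for the helper above; `trainer` and the dummy
# process function are assumptions, not part of the original. Note that no
# `step` event is attached, so even with average=True the step count stays
# at zero and timer.value() reports the accumulated per-iteration processing
# time of the current epoch, with the gaps between iterations (data loading)
# excluded.
from ignite.engine import Engine, Events
from ignite.handlers import Timer

trainer = Engine(lambda engine, batch: None)  # placeholder process function
timer = setup_timer(trainer)

@trainer.on(Events.EPOCH_COMPLETED)
def log_processing_time(engine):
    # accumulated time spent inside iterations during this epoch
    print(f"Epoch {engine.state.epoch}: {timer.value():.3f}s processing time")

trainer.run([0] * 10, max_epochs=2)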
def attach_events(self, description, environment=None, save_file=None):
    tim = Timer()
    tim.attach(
        self.engine,
        start=Events.STARTED,
        step=Events.ITERATION_COMPLETED,
    )

    log_interval = 100
    plot_interval = 10

    @self.engine.on(Events.ITERATION_COMPLETED)
    def print_training_loss(engine):
        iter = engine.state.iteration - 1
        if iter % log_interval == 0:
            print("Epoch[{}] Iteration: {} Time: {} Loss: {:.2f}".format(
                engine.state.epoch, iter,
                str(datetime.timedelta(seconds=int(tim.value()))),
                engine.state.output['loss']))

    if environment:
        vis = visdom.Visdom(env=environment)

        def create_plot_window(vis, xlabel, ylabel, title):
            return vis.line(X=np.array([1]), Y=np.array([np.nan]),
                            opts=dict(xlabel=xlabel, ylabel=ylabel, title=title))

        train_loss_window = create_plot_window(
            vis, '#Iterations', 'Loss', 'Training Loss {0}'.format(description))

        @self.engine.on(Events.ITERATION_COMPLETED)
        def plot_training_loss(engine):
            iter = engine.state.iteration - 1
            if iter % plot_interval == 0:
                vis.line(X=np.array([engine.state.iteration]),
                         Y=np.array([engine.state.output['loss']]),
                         update='append', win=train_loss_window)
def do_validate(cfg, model, val_loader):
    device = cfg.MODEL.DEVICE
    if device == "cuda":
        torch.cuda.set_device(cfg.MODEL.CUDA)
    evaluator = create_evaluator(model, device=device)
    RunningAverage(output_transform=lambda x: x).attach(evaluator, 'eva_avg_acc')
    timer = Timer(average=True)
    timer.attach(evaluator,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)
    acc_list = list()

    @evaluator.on(Events.ITERATION_COMPLETED)
    def log_accuracy(engine):
        iter = (engine.state.iteration - 1) % len(val_loader) + 1
        print("Iteration[{}/{}]".format(iter, len(val_loader)))
        acc_list.append(engine.state.metrics['eva_avg_acc'])

    evaluator.run(val_loader)
    # '{:.1%}' (one decimal place) rather than the original '{:1%}'
    print("Validation Accuracy: {:.1%}".format(np.array(acc_list).mean()))
class DataflowBenchmark:
    def __init__(self, num_iters=100, prepare_batch=None):
        from ignite.handlers import Timer

        device = idist.device()

        def upload_to_gpu(engine, batch):
            if prepare_batch is not None:
                x, y = prepare_batch(batch, device=device, non_blocking=False)

        self.num_iters = num_iters
        self.benchmark_dataflow = Engine(upload_to_gpu)

        @self.benchmark_dataflow.on(Events.ITERATION_COMPLETED(once=num_iters))
        def stop_benchmark_dataflow(engine):
            engine.terminate()

        if idist.get_rank() == 0:

            @self.benchmark_dataflow.on(
                    Events.ITERATION_COMPLETED(every=num_iters // 100))
            def show_progress_benchmark_dataflow(engine):
                print(".", end=" ")

        self.timer = Timer(average=False)
        self.timer.attach(
            self.benchmark_dataflow,
            start=Events.EPOCH_STARTED,
            resume=Events.ITERATION_STARTED,
            pause=Events.ITERATION_COMPLETED,
            step=Events.ITERATION_COMPLETED,
        )

    def attach(self, trainer, train_loader):
        from torch.utils.data import DataLoader

        @trainer.on(Events.STARTED)
        def run_benchmark(_):
            if idist.get_rank() == 0:
                print("-" * 50)
                print(" - Dataflow benchmark")

            self.benchmark_dataflow.run(train_loader)
            t = self.timer.value()

            if idist.get_rank() == 0:
                print(" ")
                print(" Total time ({} iterations) : {:.5f} seconds".format(
                    self.num_iters, t))
                print(" time per iteration : {} seconds".format(t / self.num_iters))

                if isinstance(train_loader, DataLoader):
                    num_images = train_loader.batch_size * self.num_iters
                    print(" number of images / s : {}".format(num_images / t))

                print("-" * 50)
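# A hedged usage sketch for the class above; `trainer`, `train_loader`, and
# `prepare_batch` are placeholders assumed to exist in the calling code.
# The benchmark pulls the first `num_iters` batches from the loader, moves
# them to the current device, and prints timing stats when the trainer
# fires Events.STARTED.
benchmark = DataflowBenchmark(num_iters=100, prepare_batch=prepare_batch)
benchmark.attach(trainer, train_loader)
trainer.run(train_loader, max_epochs=1)  # stats are printed at startup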
def __init__(self, model, criterion, optimizer, lr_scheduler=None,
             metrics=None, test_metrics=None, save_path=".", name="Net"):
    self.model = model
    self.criterion = criterion
    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
    self.metrics = metrics or {}
    self.test_metrics = test_metrics
    if test_metrics is None:
        # copy self.metrics, which is guaranteed to be a dict even when
        # `metrics` was passed as None (the original copied `metrics` directly)
        self.test_metrics = self.metrics.copy()
        if 'loss' in self.metrics and isinstance(self.metrics['loss'], TrainLoss):
            self.test_metrics['loss'] = Loss(criterion=criterion)
    self.save_path = os.path.join(save_path, 'trainer')
    self.name = name

    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    log_dir = os.path.join(save_path, 'runs', self.name, current_time)
    self.writer = SummaryWriter(log_dir)

    self.metric_history = defaultdict(list)
    self.device = 'cuda' if CUDA else 'cpu'
    self._timer = Timer()
    self._epochs = 0

    self.model.to(self.device)
def decorate_trainer(trainer, evaluator, scheduler, trainloader, valloader,
                     sampler, workspace_dir, logfilename, logterm, max_epochs,
                     checkpointer, model):
    trainer.setup_logger('trainer', workspace_dir, get_rank(), logfilename)
    trainer.setup_metric_logger()
    evaluator.setup_logger('evaluator', workspace_dir, get_rank(), logfilename)

    timer = Timer(average=True)
    timer.attach(trainer, step=E.ITERATION_COMPLETED)

    @trainer.on(E.EPOCH_STARTED)
    def start_epoch(engine):
        epoch = engine.state.epoch
        engine.logger.info(f'start training epoch {epoch}')
        if sampler is not None:
            sampler.set_epoch(epoch)

    trainer.add_event_handler(E.ITERATION_COMPLETED, scheduler)

    @trainer.on(E.ITERATION_COMPLETED)
    def log_results(engine):
        engine.metric_logger.update(**engine.state.output)
        global_iter = engine.state.iteration
        if global_iter % logterm == 0:
            epoch = engine.state.epoch
            iter_per_epoch = len(trainloader)
            local_iter = global_iter - (engine.state.epoch - 1) * iter_per_epoch
            iter_remain = max_epochs * iter_per_epoch - global_iter
            elapse = timer._elapsed()
            elapse = str(datetime.timedelta(seconds=int(elapse)))
            eta = iter_remain * timer.value()
            eta = str(datetime.timedelta(seconds=int(eta)))
            logstr = f'elapse: {elapse} eta: {eta} '
            logstr += f'Epoch [{epoch}/{max_epochs}] '
            logstr += f'[{local_iter}/{iter_per_epoch}] ({global_iter}) '
            logstr += f'lr: {engine.optimizer.param_groups[0]["lr"]:.3e} '
            logstr += str(engine.metric_logger)
            engine.logger.info(logstr)

    if evaluator is not None and valloader is not None:

        @trainer.on(E.EPOCH_COMPLETED)
        def validate(engine):
            epoch = engine.state.epoch
            evaluator.logger.info(f'start evaluation epoch {epoch}')
            evaluator.run()
            result = evaluator.state.metrics
            fmtstr = f'Epoch {epoch} validation result: '
            for k, v in result.items():
                fmtstr += f'{k}: {v:.4f} '
            evaluator.logger.info(fmtstr)

    if get_rank() == 0:
        # the original passed `{'epoch', model}`, a set literal;
        # ModelCheckpoint expects a mapping of name -> object to save
        trainer.add_event_handler(E.EPOCH_COMPLETED, checkpointer,
                                  {'model': model})
def main(dataset_path, batch_size=256, max_epochs=10):
    assert torch.cuda.is_available()
    assert torch.backends.cudnn.enabled, \
        "NVIDIA/Apex:Amp requires cudnn backend to be enabled."
    torch.backends.cudnn.benchmark = True

    device = "cuda"
    train_loader, test_loader, eval_train_loader = get_train_eval_loaders(
        dataset_path, batch_size=batch_size)

    model = wide_resnet50_2(num_classes=100).to(device)
    optimizer = SGD(model.parameters(), lr=0.01)
    criterion = CrossEntropyLoss().to(device)

    def train_step(engine, batch):
        x = convert_tensor(batch[0], device, non_blocking=True)
        y = convert_tensor(batch[1], device, non_blocking=True)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        return loss.item()

    trainer = Engine(train_step)
    timer = Timer(average=True)
    timer.attach(trainer, step=Events.EPOCH_COMPLETED)
    ProgressBar(persist=True).attach(
        trainer, output_transform=lambda out: {"batch loss": out})

    metrics = {"Accuracy": Accuracy(), "Loss": Loss(criterion)}
    evaluator = create_supervised_evaluator(model, metrics=metrics,
                                            device=device, non_blocking=True)

    def log_metrics(engine, title):
        for name in metrics:
            print("\t{} {}: {:.2f}".format(title, name, engine.state.metrics[name]))

    @trainer.on(Events.COMPLETED)
    def run_validation(_):
        print("- Mean elapsed time for 1 epoch: {}".format(timer.value()))
        print("- Metrics:")
        with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "Train"):
            evaluator.run(eval_train_loader)
        with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "Test"):
            evaluator.run(test_loader)

    trainer.run(train_loader, max_epochs=max_epochs)
def run(cfg, train_loader, tr_comp, saver, trainer, valid_dict):
    # TODO resume
    # trainer = Engine(...)
    # trainer.load_state_dict(state_dict)
    # trainer.run(data)

    # checkpoint
    handler = ModelCheckpoint(saver.model_dir, 'train', n_saved=3, create_dir=True)
    checkpoint_params = tr_comp.state_dict()
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, checkpoint_params)

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metrics to attach on trainer
    names = ["Acc", "Loss"]
    names.extend(tr_comp.loss_function_map.keys())
    for n in names:
        RunningAverage(output_transform=Run(n)).attach(trainer, n)

    @trainer.on(Events.EPOCH_COMPLETED)
    def adjust_learning_rate(engine):
        tr_comp.scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED(every=cfg.TRAIN.LOG_ITER_PERIOD))
    def log_training_loss(engine):
        message = f"Epoch[{engine.state.epoch}], " + \
                  f"Iteration[{engine.state.iteration}/{len(train_loader)}], " + \
                  f"Base Lr: {tr_comp.scheduler.get_last_lr()[0]:.2e}, "
        for loss_name in engine.state.metrics.keys():
            message += f"{loss_name}: {engine.state.metrics[loss_name]:.4f}, "
        if tr_comp.xent and tr_comp.xent.learning_weight:
            message += f"xentWeight: {tr_comp.xent.uncertainty.mean().item():.4f}, "
        logger.info(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch,
                            timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        logger.info('-' * 80)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED(every=cfg.EVAL.EPOCH_PERIOD))
    def log_validation_results(engine):
        logger.info(f"Valid - Epoch: {engine.state.epoch}")
        eval_multi_dataset(cfg, valid_dict, tr_comp)

    trainer.run(train_loader, max_epochs=cfg.TRAIN.MAX_EPOCHS)
def _start_timer(self):
    timer = Timer(average=True)
    timer.attach(
        self._trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
    )
    return timer
def engine_eval_geomreg(cfg, mode):
    prepare_config_eval(cfg)

    ckpt_path = cfg.eval.general.ckpt_path
    gpu = cfg.general.gpu
    root_path = cfg.log.root_path
    seed = cfg.general.seed

    eu.redirect_stdout(root_path, 'eval_geomreg-{}'.format(mode))
    eu.print_config(cfg)
    eu.seed_random(seed)

    device = eu.get_device(gpu)
    dataloader = get_dataloader_eval_geomreg(cfg, mode)
    num_batches = len(dataloader)

    render_model, desc_model = get_models(cfg)
    render_model.to(device)
    render_model.eval_mode()
    render_model.print_params('render_model')
    desc_model.to(device)
    desc_model.eval_mode()
    desc_model.print_params('desc_model')

    assert eu.is_not_empty(ckpt_path)
    render_model.load(ckpt_path)
    desc_model.load(ckpt_path)

    engine = Engine(
        functools.partial(step_eval_geomreg,
                          render_model=render_model,
                          desc_model=desc_model,
                          device=device,
                          cfg=cfg))
    timer = Timer(average=True)
    timer.attach(engine,
                 start=Events.EPOCH_STARTED,
                 pause=Events.EPOCH_COMPLETED,
                 resume=Events.ITERATION_STARTED,
                 step=Events.ITERATION_COMPLETED)
    engine.add_event_handler(Events.ITERATION_COMPLETED,
                             eu.print_eval_log,
                             timer=timer,
                             num_batches=num_batches)
    engine.add_event_handler(Events.EXCEPTION_RAISED, eu.handle_exception)
    engine.run(dataloader, 1)

    return root_path
def timer_metric(engine, name='timer'):
    timer = Timer(average=True)
    timer.attach(engine,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    def handler(_engine):
        _engine.state.metrics[name] = timer.value()

    engine.add_event_handler(event_name=Events.ITERATION_COMPLETED, handler=handler)
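# A small sketch of how the helper above could be used; the trivial process
# function and data are placeholders, not from the original. The timer value
# appears in `engine.state.metrics` under the chosen name after every
# iteration.
from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: batch)  # placeholder process function
timer_metric(trainer, name='sec_per_iter')

@trainer.on(Events.EPOCH_COMPLETED)
def report(engine):
    # average seconds per iteration, data loading excluded
    print(f"avg iteration time: {engine.state.metrics['sec_per_iter']:.4f}s")

trainer.run(range(100), max_epochs=1)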
def create_trainer(self):
    self.trainer = Engine(self.train_step)
    self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.K_step)
    self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.log_metrics, 'train')
    self.trainer.add_event_handler(Events.ITERATION_COMPLETED, self.write_metrics, 'train')
    self.pbar = ProgressBar()
    self.timer = Timer(average=True)
    self.timer.attach(self.trainer,
                      start=Events.EPOCH_STARTED,
                      resume=Events.ITERATION_STARTED,
                      pause=Events.ITERATION_COMPLETED,
                      step=Events.ITERATION_COMPLETED)
    self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.print_times)
def _create_timer(self):
    """
    Create and attach a new timer to the trainer, registering callbacks.

    :return: the newly created timer
    :type: ignite.handlers.Timer
    """
    timer = Timer(average=True)
    timer.attach(self.trainer_engine,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)
    return timer
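# The start/resume/pause/step wiring above is the recipe most snippets in
# this collection use. A minimal self-contained sketch of what each event
# does to the timer (the no-op engine is an assumption for illustration):
from ignite.engine import Engine, Events
from ignite.handlers import Timer

engine = Engine(lambda e, batch: None)

timer = Timer(average=True)
timer.attach(
    engine,
    start=Events.EPOCH_STARTED,       # reset the timer when an epoch starts
    resume=Events.ITERATION_STARTED,  # start counting when an iteration begins
    pause=Events.ITERATION_COMPLETED, # stop counting while the next batch loads
    step=Events.ITERATION_COMPLETED,  # increment step_count once per iteration
)
engine.run([0, 1, 2], max_epochs=1)
# with average=True, value() = accumulated (resume..pause) time / step_count,
# i.e. mean processing time per iteration with dataflow excluded
print(timer.value())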
def attach_evaluator_events(evaluator, experiment_dir, data_set: str):
    # Timer initialization
    timer_iter = Timer()
    timer_iter.attach(evaluator,
                      start=Events.ITERATION_STARTED,
                      resume=Events.ITERATION_STARTED,
                      pause=Events.ITERATION_COMPLETED,
                      step=Events.ITERATION_COMPLETED)
    timer_epoch = Timer()
    timer_epoch.attach(evaluator,
                       start=Events.EPOCH_STARTED,
                       resume=Events.ITERATION_STARTED,
                       pause=Events.ITERATION_COMPLETED,
                       step=Events.ITERATION_COMPLETED)

    # Evaluator iteration events
    evaluator.add_event_handler(Events.ITERATION_COMPLETED,
                                log_eval_iter_screen,
                                timer=timer_iter)

    # End-of-evaluation events
    eval_dir = osp.join(experiment_dir, 'inference_results')
    evaluator.add_event_handler(Events.COMPLETED,
                                evaluate_model_without_training,
                                eval_dir=eval_dir,
                                timer=timer_epoch,
                                data_set=data_set)
def visdom_loss_handler(trainer, modules_dict, model_name, environment, description):
    """
    Attaches plots and metrics to trainer.

    This handler creates or connects to an environment on a running Visdom
    dashboard and creates a line plot that tracks the loss value returned by
    the training closure as a function of the number of iterations.

    See the documentation for Ignite (https://github.com/pytorch/ignite) and
    Visdom (https://github.com/facebookresearch/visdom) for more information.

    Note: `trainer`, `environment`, and `description` were free variables in
    the original snippet; they are taken as parameters here so the function
    is self-contained.
    """
    tim = Timer()
    tim.attach(
        trainer,
        start=Events.STARTED,
        step=Events.ITERATION_COMPLETED,
    )

    vis = visdom.Visdom(env=environment)

    def create_plot_window(vis, xlabel, ylabel, title):
        return vis.line(X=np.array([1]), Y=np.array([np.nan]),
                        opts=dict(xlabel=xlabel, ylabel=ylabel, title=title))

    train_loss_window = create_plot_window(vis, '#Iterations', 'Loss', description)

    log_interval = 10

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = engine.state.iteration - 1
        if iter % log_interval == 0:
            print("Epoch[{}] Iteration: {} Time: {} Loss: {:.2f}".format(
                engine.state.epoch, iter,
                str(datetime.timedelta(seconds=int(tim.value()))),
                engine.state.output))
            vis.line(X=np.array([engine.state.iteration]),
                     Y=np.array([engine.state.output]),
                     update='append', win=train_loss_window)

    save_interval = 50
    handler = ModelCheckpoint('/tmp/models', model_name,
                              save_interval=save_interval, n_saved=5,
                              create_dir=True, require_empty=False)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, handler, modules_dict)
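# A hedged usage sketch, relying on the parameters made explicit above;
# `trainer` and `model` are assumed to exist, and 'mymodel'/'main' are
# illustrative names only. This plots the trainer's scalar loss output to
# the given Visdom environment and checkpoints the model every 50 iterations
# under /tmp/models.
visdom_loss_handler(trainer, {'model': model}, 'mymodel',
                    environment='main', description='Training Loss')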
def __init__(self, G, D, criterionG, criterionD, optimizerG, optimizerD,
             lr_schedulerG=None, lr_schedulerD=None, make_latent=None,
             metrics=None, save_path=".", name="GAN", gan_type='gan'):
    self.G = G
    self.D = D
    self.criterionG = criterionG
    self.criterionD = criterionD
    self.optimizerG = optimizerG
    self.optimizerD = optimizerD
    self.lr_schedulerG = lr_schedulerG
    self.lr_schedulerD = lr_schedulerD
    self.make_latent = make_latent
    self.metrics = metrics or {}
    self.name = name

    root = Path(save_path).expanduser().absolute()
    self.save_path = root / 'gan_trainer' / self.name

    self.metric_history = defaultdict(list)
    self.device = 'cuda' if CUDA else 'cpu'
    self._timer = Timer()
    self._iterations = 0

    self.G.to(self.device)
    self.D.to(self.device)

    assert gan_type in ['gan', 'acgan', 'cgan', 'infogan']
    if gan_type == 'gan':
        self.create_fn = create_gan_trainer
    elif gan_type == 'acgan':
        self.create_fn = create_acgan_trainer
    elif gan_type == 'cgan':
        self.create_fn = create_cgan_trainer
    elif gan_type == 'infogan':
        self.create_fn = create_infogan_trainer
def __init__(self, model, criterion, optimizer, lr_scheduler=None,
             metrics=None, test_metrics=None, save_path=".", name="Net",
             fp16=False, lr_step_on_iter=None):
    self.fp16 = fp16
    self.device = 'cuda' if CUDA else 'cpu'
    model.to(self.device)
    if self.fp16:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level="O1", verbosity=0)
    self.model = model
    self.criterion = criterion
    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
    self.metrics = metrics or {}
    self.test_metrics = test_metrics
    if test_metrics is None:
        # copy self.metrics so this also works when `metrics` was None
        self.test_metrics = self.metrics.copy()
        if 'loss' in self.metrics and isinstance(self.metrics['loss'], TrainLoss):
            self.test_metrics['loss'] = Loss(criterion=criterion)
    self.save_path = os.path.join(save_path, 'trainer')
    self.name = name
    self.lr_step_on_iter = lr_step_on_iter

    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    log_dir = os.path.join(save_path, 'runs', self.name, current_time)
    self.writer = SummaryWriter(log_dir)

    self.metric_history = defaultdict(list)
    self._timer = Timer()
    self._epochs = 0
    self._verbose = True
def warp_common_handler(engine, option, networks_to_save, monitoring_metrics,
                        add_message, use_folder_pathes):
    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(engine, metric_names=monitoring_metrics)

    timer = Timer(average=True)
    timer.attach(engine,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    create_plots = make_handle_create_plots(option.output_dir, LOGS_FNAME, PLOT_FNAME)
    checkpoint_handler = ModelCheckpoint(option.output_dir, CKPT_PREFIX,
                                         save_interval=option.save_interval,
                                         n_saved=option.n_saved,
                                         require_empty=False,
                                         create_dir=True,
                                         save_as_state_dict=True)
    engine.add_event_handler(Events.ITERATION_COMPLETED, checkpoint_handler,
                             to_save=networks_to_save)
    engine.add_event_handler(Events.ITERATION_COMPLETED, create_plots)
    engine.add_event_handler(
        Events.EXCEPTION_RAISED,
        make_handle_handle_exception(checkpoint_handler, networks_to_save, create_plots))
    engine.add_event_handler(
        Events.STARTED, make_handle_make_dirs(option.output_dir, use_folder_pathes))
    engine.add_event_handler(Events.STARTED, make_move_html(option.output_dir))
    engine.add_event_handler(Events.STARTED, make_create_option_data(option))
    engine.add_event_handler(Events.EPOCH_COMPLETED,
                             make_handle_print_times(timer, pbar))
    engine.add_event_handler(
        Events.ITERATION_COMPLETED,
        make_handle_print_logs(option.output_dir, option.epochs,
                               option.print_freq, pbar, add_message))
    return engine
def __init__(self):
    self._dataflow_timer = Timer()
    self._processing_timer = Timer()
    self._event_handlers_timer = Timer()

    self.dataflow_times = None
    self.processing_times = None
    self.event_handlers_times = None

    self._events = [
        Events.EPOCH_STARTED, Events.EPOCH_COMPLETED,
        Events.ITERATION_STARTED, Events.ITERATION_COMPLETED,
        Events.GET_BATCH_STARTED, Events.GET_BATCH_COMPLETED,
        Events.COMPLETED
    ]
    self._fmethods = [
        self._as_first_epoch_started, self._as_first_epoch_completed,
        self._as_first_iter_started, self._as_first_iter_completed,
        self._as_first_get_batch_started, self._as_first_get_batch_completed,
        self._as_first_completed
    ]
    self._lmethods = [
        self._as_last_epoch_started, self._as_last_epoch_completed,
        self._as_last_iter_started, self._as_last_iter_completed,
        self._as_last_get_batch_started, self._as_last_get_batch_completed,
        self._as_last_completed
    ]
def __init__(self):
    self._dataflow_timer = Timer()
    self._processing_timer = Timer()
    self._event_handlers_timer = Timer()

    self.dataflow_times = None
    self.processing_times = None
    self.event_handlers_times = None
def __init__(self) -> None:
    self._dataflow_timer = Timer()
    self._processing_timer = Timer()
    self._event_handlers_timer = Timer()

    self.dataflow_times = []  # type: List[float]
    self.processing_times = []  # type: List[float]
    self.event_handlers_times = {}  # type: Dict[EventEnum, Dict[str, List[float]]]
def __init__(self, G, D, criterionG, criterionD, optimizerG, optimizerD,
             lr_schedulerG=None, lr_schedulerD=None, metrics=None,
             device=None, save_path=".", name="Net"):
    self.G = G
    self.D = D
    self.criterionG = criterionG
    self.criterionD = criterionD
    self.optimizerG = optimizerG
    self.optimizerD = optimizerD
    self.lr_schedulerG = lr_schedulerG
    self.lr_schedulerD = lr_schedulerD
    # avoids the mutable default argument `metrics={}` of the original
    self.metrics = metrics or {}
    self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    self.save_path = save_path
    self.name = name

    self.metric_history = defaultdict(list)
    self._print_callbacks = set([lambda msg: print(msg, end='')])
    self._weixin_logined = False
    self._timer = Timer()
    self._epochs = 0

    self.G.to(self.device)
    self.D.to(self.device)
def __init__(self) -> None:
    self._dataflow_timer = Timer()
    self._processing_timer = Timer()
    self._event_handlers_timer = Timer()

    self.dataflow_times = torch.zeros(1)
    self.processing_times = torch.zeros(1)
    self.event_handlers_times = {}  # type: Dict[EventEnum, torch.Tensor]

    self._events = [
        Events.EPOCH_STARTED,
        Events.EPOCH_COMPLETED,
        Events.ITERATION_STARTED,
        Events.ITERATION_COMPLETED,
        Events.GET_BATCH_STARTED,
        Events.GET_BATCH_COMPLETED,
        Events.COMPLETED,
    ]
    self._fmethods = [
        self._as_first_epoch_started,
        self._as_first_epoch_completed,
        self._as_first_iter_started,
        self._as_first_iter_completed,
        self._as_first_get_batch_started,
        self._as_first_get_batch_completed,
        self._as_first_completed,
    ]
    self._lmethods = [
        self._as_last_epoch_started,
        self._as_last_epoch_completed,
        self._as_last_iter_started,
        self._as_last_iter_completed,
        self._as_last_get_batch_started,
        self._as_last_get_batch_completed,
        self._as_last_completed,
    ]
# loss function
loss_fn = nn.CrossEntropyLoss()

# optimizer
optimizer = optim.SGD(model.parameters(), lr=init_lr,
                      momentum=0.9, weight_decay=5e-4)

# scheduler
scheduler = CosineAnnealingScheduler(optimizer, 'lr', init_lr, end_lr,
                                     4 * len(trainloader),
                                     cycle_mult=1.5, start_value_mult=0.1)
scheduler = create_lr_scheduler_with_warmup(scheduler,
                                            warmup_start_value=0.,
                                            warmup_end_value=init_lr,
                                            warmup_duration=len(trainloader))

# create trainer
trainer = create_trainer(model, optimizer, loss_fn, device=device)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

# add timer for each iteration
timer = Timer(average=False)

# logging training loss
def log_loss(engine):
    i = engine.state.iteration
    e = engine.state.epoch
    if i % 100 == 0:
        print('[Iters {:0>7d}/{:0>2d}, {:.2f}s/100 iters, lr={:.4E}] loss={:.4f}'.format(
            i, e, timer.value(), optimizer.param_groups[0]['lr'],
            engine.state.output))
        timer.reset()

trainer.add_event_handler(Events.ITERATION_COMPLETED, log_loss)

# Evaluation (the metrics dict was truncated in the original; closed here)
metrics = {
    'loss': Loss(loss_fn),
    'acc': Accuracy()
}
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh, logittransform, gan, disc_lr, sn, flowgan, eval_every,
         ld_on_samples, weight_gan, weight_prior, weight_logdet,
         jac_reg_lambda, affine_eps, no_warm_up, optim_name, clamp, svd_every,
         eval_only, no_actnorm, affine_scale_eps, actnorm_max_scale,
         no_conv_actnorm, affine_max_scale, actnorm_eps, init_sample,
         no_split, disc_arch, weight_entropy_reg, db):

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset, batch_size=batch_size,
                                   shuffle=True, num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size,
                                  shuffle=False, num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition, logittransform, sn, affine_eps,
                 no_actnorm, affine_scale_eps, actnorm_max_scale,
                 no_conv_actnorm, affine_max_scale, actnorm_eps, no_split)
    model = model.to(device)

    if disc_arch == 'mine':
        discriminator = mine.Discriminator(image_shape[-1])
    elif disc_arch == 'biggan':
        discriminator = cgan_models.Discriminator(
            image_channels=image_shape[-1], conditional_D=False)
    elif disc_arch == 'dcgan':
        discriminator = DCGANDiscriminator(image_shape[0], 64, image_shape[-1])
    elif disc_arch == 'inv':
        discriminator = InvDiscriminator(
            image_shape, hidden_channels, K, L, actnorm_scale,
            flow_permutation, flow_coupling, LU_decomposed, num_classes,
            learn_top, y_condition, logittransform, sn, affine_eps,
            no_actnorm, affine_scale_eps, actnorm_max_scale, no_conv_actnorm,
            affine_max_scale, actnorm_eps, no_split)
    discriminator = discriminator.to(device)

    D_optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, discriminator.parameters()),
        lr=disc_lr, betas=(.5, .99), weight_decay=0)

    if optim_name == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr, betas=(.5, .99),
                               weight_decay=0)
    elif optim_name == 'adamax':
        optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    if not no_warm_up:
        lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lr_lambda)

    iteration_fieldnames = [
        'global_iteration', 'fid', 'sample_pad', 'train_bpd', 'eval_bpd',
        'pad', 'batch_real_acc', 'batch_fake_acc', 'batch_acc'
    ]
    iteration_logger = CSVLogger(fieldnames=iteration_fieldnames,
                                 filename=os.path.join(output_dir,
                                                       'iteration_log.csv'))
    iteration_fieldnames = [
        'global_iteration', 'condition_num', 'max_sv', 'min_sv',
        'inverse_condition_num', 'inverse_max_sv', 'inverse_min_sv'
    ]
    svd_logger = CSVLogger(fieldnames=iteration_fieldnames,
                           filename=os.path.join(output_dir, 'svd_log.csv'))

    # `test_iter` is used below to build the inception and recon batches, so
    # this line must be live (it was commented out in the original snippet)
    test_iter = test_loader.__iter__()
    N_inception = 1000
    x_real_inception = torch.cat([
        test_iter.__next__()[0].to(device)
        for _ in range(N_inception // args.batch_size + 1)
    ], 0)[:N_inception]
    x_real_inception = x_real_inception + .5

    x_for_recon = test_iter.__next__()[0].to(device)

    def gan_step(engine, batch):
        assert not y_condition
        if 'iter_ind' in dir(engine):
            engine.iter_ind += 1
        else:
            engine.iter_ind = -1
        losses = {}
        model.train()
        discriminator.train()

        x, y = batch
        x = x.to(device)

        def run_noised_disc(discriminator, x):
            x = uniform_binning_correction(x)[0]
            return discriminator(x)

        real_acc = fake_acc = acc = 0
        if weight_gan > 0:
            fake = generate_from_noise(model, x.size(0), clamp=clamp)

            D_real_scores = run_noised_disc(discriminator, x.detach())
            D_fake_scores = run_noised_disc(discriminator, fake.detach())

            ones_target = torch.ones((x.size(0), 1), device=x.device)
            zeros_target = torch.zeros((x.size(0), 1), device=x.device)

            D_real_accuracy = torch.sum(
                torch.round(F.sigmoid(D_real_scores)) == ones_target
            ).float() / ones_target.size(0)
            D_fake_accuracy = torch.sum(
                torch.round(F.sigmoid(D_fake_scores)) == zeros_target
            ).float() / zeros_target.size(0)

            D_real_loss = F.binary_cross_entropy_with_logits(D_real_scores,
                                                             ones_target)
            D_fake_loss = F.binary_cross_entropy_with_logits(D_fake_scores,
                                                             zeros_target)

            D_loss = (D_real_loss + D_fake_loss) / 2
            gp = gradient_penalty(x.detach(), fake.detach(),
                                  lambda _x: run_noised_disc(discriminator, _x))
            D_loss_plus_gp = D_loss + 10 * gp
            D_optimizer.zero_grad()
            D_loss_plus_gp.backward()
            D_optimizer.step()

            # Train generator
            fake = generate_from_noise(model, x.size(0), clamp=clamp,
                                       guard_nans=False)
            G_loss = F.binary_cross_entropy_with_logits(
                run_noised_disc(discriminator, fake),
                torch.ones((x.size(0), 1), device=x.device))

            # Trace
            real_acc = D_real_accuracy.item()
            fake_acc = D_fake_accuracy.item()
            acc = .5 * (D_fake_accuracy.item() + D_real_accuracy.item())

        z, nll, y_logits, (prior, logdet) = model.forward(x, None,
                                                          return_details=True)
        train_bpd = nll.mean().item()

        loss = 0
        if weight_gan > 0:
            loss = loss + weight_gan * G_loss
        if weight_prior > 0:
            loss = loss + weight_prior * -prior.mean()
        if weight_logdet > 0:
            loss = loss + weight_logdet * -logdet.mean()

        if weight_entropy_reg > 0:
            _, _, _, (sample_prior, sample_logdet) = model.forward(
                fake, None, return_details=True)
            # notice this is actually "decreasing" sample likelihood.
            loss = loss + weight_entropy_reg * (sample_prior.mean() +
                                                sample_logdet.mean())
        # Jac Reg
        if jac_reg_lambda > 0:
            # Sample
            x_samples = generate_from_noise(model, args.batch_size,
                                            clamp=clamp).detach()
            x_samples.requires_grad_()
            z = model.forward(x_samples, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            sample_foward_jac = compute_jacobian_regularizer(x_samples, all_z,
                                                             n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            randz = torch.randn(zshape).to(device)
            randz = torch.autograd.Variable(randz, requires_grad=True)
            images = model(z=randz, y_onehot=None, temperature=1,
                           reverse=True, batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [randz] + other_zs
            sample_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # Data
            x.requires_grad_()
            z = model.forward(x, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            data_foward_jac = compute_jacobian_regularizer(x, all_z, n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            z.requires_grad_()
            images = model(z=z, y_onehot=None, temperature=1, reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [z] + other_zs
            data_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # loss = loss + jac_reg_lambda * (sample_foward_jac + sample_inverse_jac)
            loss = loss + jac_reg_lambda * (sample_foward_jac +
                                            sample_inverse_jac +
                                            data_foward_jac +
                                            data_inverse_jac)

        if not eval_only:
            optimizer.zero_grad()
            loss.backward()
            if not db:
                assert max_grad_clip == max_grad_norm == 0
            if max_grad_clip > 0:
                torch.nn.utils.clip_grad_value_(model.parameters(),
                                                max_grad_clip)
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)

            # Replace NaN gradients with 0
            for p in model.parameters():
                if p.requires_grad and p.grad is not None:
                    g = p.grad.data
                    g[g != g] = 0
            optimizer.step()

        if engine.iter_ind % 100 == 0:
            with torch.no_grad():
                fake = generate_from_noise(model, x.size(0), clamp=clamp)
                z = model.forward(fake, None, return_details=True)[0]
            print("Z max min")
            print(z.max().item(), z.min().item())
            if (fake != fake).float().sum() > 0:
                title = 'NaNs'
            else:
                title = "Good"
            grid = make_grid((postprocess(fake.detach().cpu(), dataset)[:30]),
                             nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.title(title)
            plt.savefig(os.path.join(output_dir,
                                     f'sample_{engine.iter_ind}.png'))

        if engine.iter_ind % eval_every == 0:

            def check_all_zero_except_leading(x):
                return x % 10**np.floor(np.log10(x)) == 0

            if engine.iter_ind == 0 or check_all_zero_except_leading(engine.iter_ind):
                torch.save(model.state_dict(),
                           os.path.join(output_dir,
                                        f'ckpt_sd_{engine.iter_ind}.pt'))

            model.eval()
            with torch.no_grad():
                # Plot recon
                fpath = os.path.join(output_dir, '_recon',
                                     f'recon_{engine.iter_ind}.png')
                sample_pad = run_recon_evolution(
                    model,
                    generate_from_noise(model, args.batch_size,
                                        clamp=clamp).detach(),
                    fpath)
                print(f"Iter: {engine.iter_ind}, Recon Sample PAD: {sample_pad}")

                pad = run_recon_evolution(model, x_for_recon, fpath)
                print(f"Iter: {engine.iter_ind}, Recon PAD: {pad}")
                pad = pad.item()
                sample_pad = sample_pad.item()

                # Inception score
                sample = torch.cat([
                    generate_from_noise(model, args.batch_size, clamp=clamp)
                    for _ in range(N_inception // args.batch_size + 1)
                ], 0)[:N_inception]
                sample = sample + .5

                if (sample != sample).float().sum() > 0:
                    print("Sample NaNs")
                    raise
                else:
                    fid = run_fid(x_real_inception.clamp_(0, 1),
                                  sample.clamp_(0, 1))
                    print(f'fid: {fid}, global_iter: {engine.iter_ind}')

                # Eval BPD
                eval_bpd = np.mean([
                    model.forward(x.to(device), None,
                                  return_details=True)[1].mean().item()
                    for x, _ in test_loader
                ])

                stats_dict = {
                    'global_iteration': engine.iter_ind,
                    'fid': fid,
                    'train_bpd': train_bpd,
                    'pad': pad,
                    'eval_bpd': eval_bpd,
                    'sample_pad': sample_pad,
                    'batch_real_acc': real_acc,
                    'batch_fake_acc': fake_acc,
                    'batch_acc': acc
                }
                iteration_logger.writerow(stats_dict)
                plot_csv(iteration_logger.filename)
            model.train()

        # the original `engine.iter_ind + 2 % svd_every` binds `%` before `+`;
        # parenthesized to match the intended "every svd_every iterations"
        if (engine.iter_ind + 2) % svd_every == 0:
            model.eval()
            svd_dict = {}
            ret = utils.computeSVDjacobian(x_for_recon, model)
            D_for, D_inv = ret['D_for'], ret['D_inv']
            cn = float(D_for.max() / D_for.min())
            cn_inv = float(D_inv.max() / D_inv.min())
            svd_dict['global_iteration'] = engine.iter_ind
            svd_dict['condition_num'] = cn
            svd_dict['max_sv'] = float(D_for.max())
            svd_dict['min_sv'] = float(D_for.min())
            svd_dict['inverse_condition_num'] = cn_inv
            svd_dict['inverse_max_sv'] = float(D_inv.max())
            svd_dict['inverse_min_sv'] = float(D_inv.min())
            svd_logger.writerow(svd_dict)
            # plot_utils.plot_stability_stats(output_dir)
            # plot_utils.plot_individual_figures(output_dir, 'svd_log.csv')
            model.train()

        if eval_only:
            sys.exit()

        # Dummy
        losses['total_loss'] = torch.mean(nll).item()
        return losses

    def eval_step(engine, batch):
        model.eval()
        x, y = batch
        x = x.to(device)
        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll, y_logits, y_weight, y,
                                        multi_class, reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')
        return losses

    trainer = Engine(gan_step)
    # else:
    #     trainer = Engine(step)

    checkpoint_handler = ModelCheckpoint(output_dir, 'glow', save_interval=5,
                                         n_saved=1, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        print("Loading...")
        print(saved_model)
        loaded = torch.load(saved_model)
        # if 'Glow' in str(type(loaded)):
        #     model = loaded
        # else:
        #     raise
        # # if 'Glow' in str(type(loaded)):
        # #     loaded = loaded.state_dict()
        model.load_state_dict(loaded)
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        if saved_model:
            return
        model.train()

        print("Initializing Actnorm...")
        init_batches = []
        init_targets = []

        if n_init_batches == 0:
            model.set_actnorm_init()
            return

        with torch.no_grad():
            if init_sample:
                generate_from_noise(model,
                                    args.batch_size * args.n_init_batches)
            else:
                for batch, target in islice(train_loader, None, n_init_batches):
                    init_batches.append(batch)
                    init_targets.append(target)

                init_batches = torch.cat(init_batches).to(device)
                assert init_batches.shape[0] == n_init_batches * batch_size

                if y_condition:
                    init_targets = torch.cat(init_targets).to(device)
                else:
                    init_targets = None

                model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        if not no_warm_up:
            scheduler.step()
        metrics = evaluator.state.metrics
        losses = ', '.join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])
        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]')
        timer.reset()

    trainer.run(train_loader, epochs)
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             loss_fn, num_query, start_epoch):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model, optimizer, loss_fn, cfg=cfg,
                                        device=device)
    evaluator = create_supervised_evaluator(
        model,
        metrics={'r1_mAP': R1_mAP(num_query, max_rank=50,
                                  feat_norm=cfg.TEST.FEAT_NORM)},
        device=device)
    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME,
                                   checkpoint_period, n_saved=epochs,
                                   require_empty=False, start_iter=start_epoch)
    timer = Timer(average=True)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {
        'model': model,
        'optimizer': optimizer
    })
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metrics to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch
        engine.state.total_iteration = 0

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()
        engine.state.iteration = 0

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_period == 0:
            logger.info(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                .format(engine.state.epoch, iter, len(train_loader),
                        engine.state.metrics['avg_loss'],
                        engine.state.metrics['avg_acc'],
                        scheduler.get_lr()[0]))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch, timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @evaluator.on(Events.ITERATION_COMPLETED)
    def log_evaluate_extract_features(engine):
        iter = (engine.state.iteration - 1) % len(val_loader) + 1
        if iter % log_period == 0:
            logger.info("Extract Features Iteration[{}/{}]".format(
                iter, len(val_loader)))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0 or engine.state.epoch > 120:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)
def run(args):
    train_loader, val_loader = get_data_loaders(args.dir, args.batch_size,
                                                args.num_workers)
    if args.seed is not None:
        torch.manual_seed(args.seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    num_classes = CityscapesDataset.num_instance_classes() + 1
    model = models.box2pix(num_classes=num_classes)
    model.init_from_googlenet()
    writer = create_summary_writer(model, train_loader, args.log_dir)

    if torch.cuda.device_count() > 1:
        print("Using %d GPU(s)" % torch.cuda.device_count())
        model = nn.DataParallel(model)
    model = model.to(device)

    semantics_criterion = nn.CrossEntropyLoss(ignore_index=255)
    offsets_criterion = nn.MSELoss()
    box_criterion = BoxLoss(num_classes, gamma=2)
    multitask_criterion = MultiTaskLoss().to(device)
    box_coder = BoxCoder()

    optimizer = optim.Adam([{'params': model.parameters(), 'weight_decay': 5e-4},
                            {'params': multitask_criterion.parameters()}],
                           lr=args.lr)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            multitask_criterion.load_state_dict(checkpoint['multitask'])
            print("Loaded checkpoint '{}' (Epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    def _prepare_batch(batch, non_blocking=True):
        x, instance, boxes, labels = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(instance, device=device, non_blocking=non_blocking),
                convert_tensor(boxes, device=device, non_blocking=non_blocking),
                convert_tensor(labels, device=device, non_blocking=non_blocking))

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, instance, boxes, labels = _prepare_batch(batch)
        boxes, labels = box_coder.encode(boxes, labels)
        loc_preds, conf_preds, semantics_pred, offsets_pred = model(x)

        semantics_loss = semantics_criterion(semantics_pred, instance)
        offsets_loss = offsets_criterion(offsets_pred, instance)
        box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds, labels)

        loss = multitask_criterion(semantics_loss, offsets_loss, box_loss, conf_loss)
        loss.backward()
        optimizer.step()

        return {
            'loss': loss.item(),
            'loss_semantics': semantics_loss.item(),
            'loss_offsets': offsets_loss.item(),
            'loss_ssdbox': box_loss.item(),
            'loss_ssdclass': conf_loss.item()
        }

    trainer = Engine(_update)
    checkpoint_handler = ModelCheckpoint(args.output_dir, 'checkpoint',
                                         save_interval=1, n_saved=10,
                                         require_empty=False, create_dir=True,
                                         save_as_state_dict=False)
    timer = Timer(average=True)

    # attach running average metrics
    train_metrics = ['loss', 'loss_semantics', 'loss_offsets',
                     'loss_ssdbox', 'loss_ssdclass']
    for m in train_metrics:
        transform = partial(lambda x, metric: x[metric], metric=m)
        RunningAverage(output_transform=transform).attach(trainer, m)

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=train_metrics)

    # Build the checkpoint dict when the handler fires; the original built it
    # once before `trainer.run`, at which point `trainer.state` is not
    # populated yet, so reading `trainer.state.epoch` there would fail.
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        checkpoint = {
            'model': model.state_dict(),
            'epoch': engine.state.epoch,
            'optimizer': optimizer.state_dict(),
            'multitask': multitask_criterion.state_dict()
        }
        checkpoint_handler(engine, {'checkpoint': checkpoint})

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            x, instance, boxes, labels = _prepare_batch(batch)
            loc_preds, conf_preds, semantics, offsets_pred = model(x)
            boxes_preds, labels_preds, scores_preds = box_coder.decode(
                loc_preds, F.softmax(conf_preds, dim=1), score_thresh=0.01)

            semantics_loss = semantics_criterion(semantics, instance)
            offsets_loss = offsets_criterion(offsets_pred, instance)
            box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds, labels)

            semantics_pred = semantics.argmax(dim=1)
            instances = helper.assign_pix2box(semantics_pred, offsets_pred,
                                              boxes_preds, labels_preds)

        return {
            'loss': (semantics_loss, offsets_loss,
                     {'box_loss': box_loss, 'conf_loss': conf_loss}),
            'objects': (boxes_preds, labels_preds, scores_preds, boxes, labels),
            'semantics': semantics_pred,
            'instances': instances
        }

    train_evaluator = Engine(_inference)
    Loss(multitask_criterion,
         output_transform=lambda x: x['loss']).attach(train_evaluator, 'loss')
    MeanAveragePrecision(num_classes,
                         output_transform=lambda x: x['objects']).attach(
                             train_evaluator, 'objects')
    IntersectionOverUnion(num_classes,
                          output_transform=lambda x: x['semantics']).attach(
                              train_evaluator, 'semantics')

    evaluator = Engine(_inference)
    Loss(multitask_criterion,
         output_transform=lambda x: x['loss']).attach(evaluator, 'loss')
    MeanAveragePrecision(num_classes,
                         output_transform=lambda x: x['objects']).attach(
                             evaluator, 'objects')
    IntersectionOverUnion(num_classes,
                          output_transform=lambda x: x['semantics']).attach(
                              evaluator, 'semantics')

    @trainer.on(Events.STARTED)
    def initialize(engine):
        if args.resume:
            engine.state.epoch = args.start_epoch

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message("Epoch [{}/{}] done. Time per batch: {:.3f}[s]".format(
            engine.state.epoch, engine.state.max_epochs, timer.value()))
        timer.reset()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iteration = (engine.state.iteration - 1) % len(train_loader) + 1
        if iteration % args.log_interval == 0:
            writer.add_scalar("training/loss", engine.state.output['loss'],
                              engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        loss = metrics['loss']
        mean_ap = metrics['objects']
        iou = metrics['semantics']
        # arguments reordered to match the format string; the original passed
        # `loss` first and read `epochs`/`max_epochs` off the wrong engine
        pbar.log_message(
            'Training results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}'
            .format(engine.state.epoch, engine.state.max_epochs, loss,
                    mean_ap, iou * 100.0))
        writer.add_scalar("train-val/loss", loss, engine.state.epoch)
        writer.add_scalar("train-val/mAP", mean_ap, engine.state.epoch)
        writer.add_scalar("train-val/IoU", iou, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        loss = metrics['loss']
        mean_ap = metrics['objects']
        iou = metrics['semantics']
        pbar.log_message(
            'Validation results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}'
            .format(engine.state.epoch, engine.state.max_epochs, loss,
                    mean_ap, iou * 100.0))
        writer.add_scalar("validation/loss", loss, engine.state.epoch)
        writer.add_scalar("validation/mAP", mean_ap, engine.state.epoch)
        writer.add_scalar("validation/IoU", iou, engine.state.epoch)

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")
            checkpoint_handler(engine, {'model_exception': model})
        else:
            raise e

    @trainer.on(Events.COMPLETED)
    def save_final_model(engine):
        checkpoint_handler(engine, {'final': model})

    trainer.run(train_loader, max_epochs=args.epochs)
    writer.close()