def run(conf: DictConfig, local_rank=0, distributed=False):
    epochs = conf.train.epochs
    epoch_length = conf.train.epoch_length
    torch.manual_seed(conf.general.seed)

    if distributed:
        rank = dist.get_rank()
        num_replicas = dist.get_world_size()
        torch.cuda.set_device(local_rank)
    else:
        rank = 0
        num_replicas = 1
        torch.cuda.set_device(conf.general.gpu)
    device = torch.device('cuda')
    loader_args = dict()
    master_node = rank == 0

    if master_node:
        print(conf.pretty())
    if num_replicas > 1:
        epoch_length = epoch_length // num_replicas
        loader_args = dict(rank=rank, num_replicas=num_replicas)

    train_dl = create_train_loader(conf.data, **loader_args)
    if epoch_length < 1:
        epoch_length = len(train_dl)

    metric_names = list(conf.logging.stats)
    metrics = create_metrics(metric_names, device if distributed else None)

    G = instantiate(conf.model.G).to(device)
    D = instantiate(conf.model.D).to(device)
    G_loss = instantiate(conf.loss.G).to(device)
    D_loss = instantiate(conf.loss.D).to(device)
    G_opt = instantiate(conf.optim.G, G.parameters())
    D_opt = instantiate(conf.optim.D, D.parameters())

    # smoothed (weight-averaged) copy of G, kept on the master node only
    G_ema = None
    if master_node and conf.G_smoothing.enabled:
        G_ema = instantiate(conf.model.G)
        if not conf.G_smoothing.use_cpu:
            G_ema = G_ema.to(device)
        G_ema.load_state_dict(G.state_dict())
        G_ema.requires_grad_(False)

    to_save = {
        'G': G,
        'D': D,
        'G_loss': G_loss,
        'D_loss': D_loss,
        'G_opt': G_opt,
        'D_opt': D_opt,
        'G_ema': G_ema
    }

    if master_node and conf.logging.model:
        logging.info(G)
        logging.info(D)

    if distributed:
        ddp_kwargs = dict(device_ids=[local_rank], output_device=local_rank)
        G = torch.nn.parallel.DistributedDataParallel(G, **ddp_kwargs)
        D = torch.nn.parallel.DistributedDataParallel(D, **ddp_kwargs)

    train_options = {
        'train': dict(conf.train),
        'snapshot': dict(conf.snapshots),
        'smoothing': dict(conf.G_smoothing),
        'distributed': distributed
    }

    bs_dl = int(conf.data.loader.batch_size) * num_replicas
    bs_eff = conf.train.batch_size
    if bs_eff % bs_dl:
        raise AttributeError(
            "Effective batch size should be divisible by the data-loader batch size "
            "multiplied by the number of devices in use")
    # assumes the master node uses the same data-loader batch size as the other replicas
    upd_interval = max(bs_eff // bs_dl, 1)
    train_options['train']['update_interval'] = upd_interval

    if epoch_length < len(train_dl):
        # ideally epoch_length should be tied to the effective batch_size only,
        # but the ignite trainer counts data-loader iterations
        epoch_length *= upd_interval

    train_loop, sample_images = create_train_closures(
        G, D, G_loss, D_loss, G_opt, D_opt,
        G_ema=G_ema, device=device, options=train_options)
    trainer = create_trainer(train_loop, metrics, device, num_replicas)
    to_save['trainer'] = trainer

    every_iteration = Events.ITERATION_COMPLETED
    trainer.add_event_handler(every_iteration, TerminateOnNan())

    cp = conf.checkpoints
    pbar = None

    if master_node:
        log_freq = conf.logging.iter_freq
        log_event = Events.ITERATION_COMPLETED(every=log_freq)
        pbar = ProgressBar(persist=False)
        trainer.add_event_handler(Events.EPOCH_STARTED, on_epoch_start)
        trainer.add_event_handler(log_event, log_iter, pbar, log_freq)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, log_epoch)
        pbar.attach(trainer, metric_names=metric_names)
        setup_checkpoints(trainer, to_save, epoch_length, conf)
        setup_snapshots(trainer, sample_images, conf)

    if 'load' in cp.keys() and cp.load is not None:
        if master_node:
            logging.info("Resume from a checkpoint: {}".format(cp.load))
            trainer.add_event_handler(Events.STARTED, _upd_pbar_iter_from_cp, pbar)
        Checkpoint.load_objects(to_load=to_save,
                                checkpoint=torch.load(cp.load, map_location=device))

    try:
        trainer.run(train_dl, max_epochs=epochs, epoch_length=epoch_length)
    except Exception:
        import traceback
        logging.error(traceback.format_exc())

    if pbar is not None:
        pbar.close()
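# Usage sketch (not part of the original script): a minimal example of how
# this `run` function could be launched with Hydra and torch.distributed.
# The config path/name and the NCCL backend are assumptions for illustration;
# adjust them to the actual project layout.
import os

import hydra
import torch.distributed as dist
from omegaconf import DictConfig


@hydra.main(config_path="config", config_name="train_gan")
def launch(conf: DictConfig) -> None:
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    distributed = world_size > 1
    if distributed:
        # one process per GPU, e.g. started via `torchrun --nproc_per_node=N`
        dist.init_process_group(backend="nccl")
    run(conf, local_rank=local_rank, distributed=distributed)
    if distributed:
        dist.destroy_process_group()


if __name__ == "__main__":
    launch()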
def main(parser_args):
    """Create the train and validation engines, add handlers to both, and then
    run the train engine to perform training and validation.

    Args:
        parser_args (dict): parsed arguments
    """
    dataloader_train, dataloader_validation = get_dataloaders(parser_args)
    criterion = nn.CrossEntropyLoss()
    unet = SphericalUNet(parser_args.pooling_class, parser_args.n_pixels,
                         parser_args.depth, parser_args.laplacian_type,
                         parser_args.kernel_size)
    # unet = torch.jit.script(unet)
    unet, device = init_device(parser_args.device, unet)
    lr = parser_args.learning_rate
    optimizer = optim.Adam(unet.parameters(), lr=lr)
    # number of trainable parameters
    print(sum(p.numel() for p in unet.parameters() if p.requires_grad))

    def trainer(engine, batch):
        """Update function for the train engine; called once per batch, in
        every epoch.

        Args:
            engine (ignite.engine): train engine
            batch (:obj:`torch.utils.data.dataloader`): batch from the train dataloader

        Returns:
            dict: train loss for that batch
        """
        unet.train()
        optimizer.zero_grad()
        data, labels = batch.x, batch.y
        labels = labels.to(device)
        data = data.to(device)
        output = unet(data)
        B, V, C = output.shape
        B_labels, V_labels, C_labels = labels.shape
        output = output.view(B * V, C)
        labels = labels.view(B_labels * V_labels, C_labels).max(1)[1]
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        return {'loss': loss.item()}

    writer = SummaryWriter(parser_args.tensorboard_path)
    engine_train = Engine(trainer)
    RunningAverage(output_transform=lambda x: x['loss']).attach(engine_train, 'loss')

    def prepare_batch(batch, device, non_blocking):
        """Prepare a batch: move its tensors to the device with the given options."""
        return (
            convert_tensor(batch.x, device=device, non_blocking=non_blocking),
            convert_tensor(batch.y, device=device, non_blocking=non_blocking),
        )

    engine_validate = create_supervised_evaluator(
        model=unet,
        metrics={"AP": EpochMetric(average_precision_compute_fn)},
        device=device,
        output_transform=validate_output_transform,
        prepare_batch=prepare_batch)

    engine_train.add_event_handler(
        Events.EPOCH_STARTED,
        lambda x: print("Starting Epoch: {}".format(x.state.epoch)))
    engine_train.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    @engine_train.on(Events.EPOCH_COMPLETED)
    def epoch_validation(engine):
        """Handler that runs the validation engine at the end of each train
        engine epoch.

        Args:
            engine (ignite.engine): train engine
        """
        print("beginning validation epoch")
        engine_validate.run(dataloader_validation)

    reduce_lr_plateau = ReduceLROnPlateau(
        optimizer,
        mode=parser_args.reducelronplateau_mode,
        factor=parser_args.reducelronplateau_factor,
        patience=parser_args.reducelronplateau_patience,
    )

    @engine_validate.on(Events.EPOCH_COMPLETED)
    def update_reduce_on_plateau(engine):
        """Handler that reduces the learning rate on plateau at the end of each
        validation engine epoch.

        Args:
            engine (ignite.engine): validation engine
        """
        ap = engine.state.metrics["AP"]
        mean_average_precision = np.mean(ap[1:])
        reduce_lr_plateau.step(mean_average_precision)

    @engine_validate.on(Events.EPOCH_COMPLETED)
    def save_epoch_results(engine):
        """Handler that logs the metrics at the end of each validation engine
        epoch.

        Args:
            engine (ignite.engine): validation engine
        """
        ap = engine.state.metrics["AP"]
        mean_average_precision = np.mean(ap[1:])
        print("Average precisions:", ap)
        print("mAP:", mean_average_precision)
        writer.add_scalars(
            "metrics", {
                "mean average precision (AR+TC)": mean_average_precision,
                "AR average precision": ap[2],
                "TC average precision": ap[1]
            },
            engine_train.state.epoch,
        )
        writer.close()

    step_scheduler = StepLR(optimizer,
                            step_size=parser_args.steplr_step_size,
                            gamma=parser_args.steplr_gamma)
    scheduler = create_lr_scheduler_with_warmup(
        step_scheduler,
        warmup_start_value=parser_args.warmuplr_warmup_start_value,
        warmup_end_value=parser_args.warmuplr_warmup_end_value,
        warmup_duration=parser_args.warmuplr_warmup_duration,
    )
    engine_validate.add_event_handler(Events.EPOCH_COMPLETED, scheduler)

    earlystopper = EarlyStopping(
        patience=parser_args.earlystopping_patience,
        score_function=lambda x: -x.state.metrics["AP"][1],
        trainer=engine_train)
    engine_validate.add_event_handler(Events.EPOCH_COMPLETED, earlystopper)

    add_tensorboard(engine_train, optimizer, unet,
                    log_dir=parser_args.tensorboard_path)

    pbar = ProgressBar()
    pbar.attach(engine_train, metric_names=['loss'])

    engine_train.run(dataloader_train, max_epochs=parser_args.n_epochs)
    pbar.close()
    torch.save(unet.state_dict(), parser_args.model_save_path + "unet_state.pt")
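# Illustration (not part of the original training script): a small,
# self-contained check of the reshaping performed inside `trainer` above.
# (B, V, C) predictions and one-hot (B, V, C) labels are flattened to
# (B*V, C) logits and (B*V,) class indices before CrossEntropyLoss.
# The tensor sizes below are made up.
def _check_label_reshape():
    import torch
    import torch.nn as nn

    B, V, C = 2, 6, 3                                  # batch, vertices, classes
    output = torch.randn(B, V, C)                      # stand-in for unet(data)
    labels = torch.eye(C)[torch.randint(C, (B, V))]    # one-hot labels, (B, V, C)

    flat_output = output.view(B * V, C)
    flat_labels = labels.view(B * V, C).max(1)[1]      # one-hot -> class index
    loss = nn.CrossEntropyLoss()(flat_output, flat_labels)
    print(flat_output.shape, flat_labels.shape, loss.item())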
def run(conf: DictConfig, local_rank=0, distributed=False):
    epochs = conf.train.epochs
    epoch_length = conf.train.epoch_length
    torch.manual_seed(conf.seed)

    if distributed:
        rank = dist.get_rank()
        num_replicas = dist.get_world_size()
        torch.cuda.set_device(local_rank)
    else:
        rank = 0
        num_replicas = 1
        torch.cuda.set_device(conf.gpu)
    device = torch.device('cuda')
    loader_args = dict(mean=conf.data.mean, std=conf.data.std)
    master_node = rank == 0

    if master_node:
        print(conf.pretty())
    if num_replicas > 1:
        epoch_length = epoch_length // num_replicas
        loader_args["rank"] = rank
        loader_args["num_replicas"] = num_replicas

    train_dl = create_train_loader(conf.data.train, **loader_args)
    valid_dl = create_val_loader(conf.data.val, **loader_args)
    if epoch_length < 1:
        epoch_length = len(train_dl)

    model = instantiate(conf.model).to(device)
    model_ema, update_ema = setup_ema(conf, model, device=device,
                                      master_node=master_node)
    optim = build_optimizer(conf.optim, model)

    scheduler_kwargs = dict()
    if "schedule.OneCyclePolicy" in conf.lr_scheduler["class"]:
        scheduler_kwargs["cycle_steps"] = epoch_length
    lr_scheduler: Scheduler = instantiate(conf.lr_scheduler, optim,
                                          **scheduler_kwargs)

    use_amp = False
    if conf.use_apex:
        import apex
        from apex import amp
        logging.debug("Nvidia's Apex package is available")
        model, optim = amp.initialize(model, optim, **conf.amp)
        use_amp = True
        if master_node:
            logging.info("Using AMP with opt_level={}".format(conf.amp.opt_level))
    else:
        apex, amp = None, None

    to_save = dict(model=model, optim=optim)
    if use_amp:
        to_save["amp"] = amp
    if model_ema is not None:
        to_save["model_ema"] = model_ema

    if master_node and conf.logging.model:
        logging.info(model)

    if distributed:
        sync_bn = conf.distributed.sync_bn
        if apex is not None:
            if sync_bn:
                model = apex.parallel.convert_syncbn_model(model)
            model = apex.parallel.distributed.DistributedDataParallel(
                model, delay_allreduce=True)
        else:
            if sync_bn:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank)

    upd_interval = conf.optim.step_interval
    ema_interval = conf.smoothing.interval_it * upd_interval
    clip_grad = conf.optim.clip_grad

    _handle_batch_train = build_process_batch_func(conf.data, stage="train",
                                                   device=device)
    _handle_batch_val = build_process_batch_func(conf.data, stage="val",
                                                 device=device)

    def _update(eng: Engine, batch: Batch) -> FloatDict:
        model.train()
        batch = _handle_batch_train(batch)
        losses: Dict = model(*batch)
        stats = {k: v.item() for k, v in losses.items()}
        loss = losses["loss"]
        del losses

        if use_amp:
            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        it = eng.state.iteration
        if not it % upd_interval:
            if clip_grad > 0:
                params = amp.master_params(optim) if use_amp else model.parameters()
                torch.nn.utils.clip_grad_norm_(params, clip_grad)
            optim.step()
            optim.zero_grad()
            lr_scheduler.step_update(it)
        if not it % ema_interval:
            update_ema()

        eng.state.lr = optim.param_groups[0]["lr"]
        return stats

    calc_map = conf.validate.calc_map
    min_score = conf.validate.get("min_score", -1)
    model_val = model
    if conf.train.skip and model_ema is not None:
        model_val = model_ema.to(device)

    def _validate(eng: Engine, batch: Batch) -> FloatDict:
        model_val.eval()
        images, targets = _handle_batch_val(batch)
        with torch.no_grad():
            out: Dict = model_val(images, targets)
        pred_boxes = out.pop("detections")
        stats = {k: v.item() for k, v in out.items()}

        if calc_map:
            pred_boxes = pred_boxes.detach().cpu().numpy()
            true_boxes = targets['bbox'].cpu().numpy()
            img_scale = targets['img_scale'].cpu().numpy()
            # yxyx -> xyxy
            true_boxes = true_boxes[:, :, [1, 0, 3, 2]]
            # xyxy -> xywh
            true_boxes[:, :, [2, 3]] -= true_boxes[:, :, [0, 1]]
            # scale downsized boxes to match predictions on a full-sized image
            true_boxes *= img_scale[:, None, None]
            scores = []
            for i in range(len(images)):
                mask = pred_boxes[i, :, 4] >= min_score
                s = calculate_image_precision(true_boxes[i],
                                              pred_boxes[i, mask, :4],
                                              thresholds=IOU_THRESHOLDS,
                                              form='coco')
                scores.append(s)
            stats['map'] = np.mean(scores)

        return stats

    train_metric_names = list(conf.logging.out.train)
    train_metrics = create_metrics(train_metric_names,
                                   device if distributed else None)

    val_metric_names = list(conf.logging.out.val)
    if calc_map:
        from utils.metric import calculate_image_precision, IOU_THRESHOLDS
        val_metric_names.append('map')
    val_metrics = create_metrics(val_metric_names,
                                 device if distributed else None)

    trainer = build_engine(_update, train_metrics)
    evaluator = build_engine(_validate, val_metrics)
    to_save['trainer'] = trainer

    every_iteration = Events.ITERATION_COMPLETED
    trainer.add_event_handler(every_iteration, TerminateOnNan())

    if distributed:
        dist_bn = conf.distributed.dist_bn
        if dist_bn in ["reduce", "broadcast"]:
            from timm.utils import distribute_bn

            @trainer.on(Events.EPOCH_COMPLETED)
            def _distribute_bn_stats(eng: Engine):
                reduce = dist_bn == "reduce"
                if master_node:
                    logging.info("Distributing BN stats...")
                distribute_bn(model, num_replicas, reduce)

    sampler = train_dl.sampler
    if isinstance(sampler, (CustomSampler, DistributedSampler)):

        @trainer.on(Events.EPOCH_STARTED)
        def _set_epoch(eng: Engine):
            sampler.set_epoch(eng.state.epoch - 1)

    @trainer.on(Events.EPOCH_COMPLETED)
    def _scheduler_step(eng: Engine):
        # eng.state.epoch starts from 1, so there is no need to add 1 here
        ep = eng.state.epoch
        lr_scheduler.step(ep)

    cp = conf.checkpoints
    pbar, pbar_vis = None, None

    if master_node:
        log_interval = conf.logging.interval_it
        log_event = Events.ITERATION_COMPLETED(every=log_interval)
        pbar = ProgressBar(persist=False)
        pbar.attach(trainer, metric_names=train_metric_names)
        pbar.attach(evaluator, metric_names=val_metric_names)

        for engine, name in zip([trainer, evaluator], ['train', 'val']):
            engine.add_event_handler(Events.EPOCH_STARTED, on_epoch_start)
            engine.add_event_handler(log_event, log_iter, pbar,
                                     interval_it=log_interval, name=name)
            engine.add_event_handler(Events.EPOCH_COMPLETED, log_epoch, name=name)

        setup_checkpoints(trainer, to_save, epoch_length, conf)

    if 'load' in cp.keys() and cp.load is not None:
        if master_node:
            logging.info("Resume from a checkpoint: {}".format(cp.load))
            trainer.add_event_handler(Events.STARTED, _upd_pbar_iter_from_cp, pbar)
        resume_from_checkpoint(to_save, cp, device=device)
        state = trainer.state
        # the epoch counter starts from 1
        lr_scheduler.step(state.epoch - 1)
        state.max_epochs = epochs

    @trainer.on(Events.EPOCH_COMPLETED(every=conf.validate.interval_ep))
    def _run_validation(eng: Engine):
        if distributed:
            torch.cuda.synchronize(device)
        evaluator.run(valid_dl)

    skip_train = conf.train.skip
    if master_node and conf.visualize.enabled:
        vis_eng = evaluator if skip_train else trainer
        setup_visualizations(vis_eng, model, valid_dl, device, conf,
                             force_run=skip_train)

    try:
        if skip_train:
            evaluator.run(valid_dl)
        else:
            trainer.run(train_dl, max_epochs=epochs, epoch_length=epoch_length)
    except Exception:
        import traceback
        logging.error(traceback.format_exc())

    for pb in [pbar, pbar_vis]:
        if pb is not None:
            pb.close()
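# Illustration (not part of the original training script): a tiny numpy check
# of the box conversion in `_validate` above. Ground-truth boxes arrive as
# (y1, x1, y2, x2), are reordered to (x1, y1, x2, y2), converted to
# (x, y, w, h) and rescaled to full-image coordinates. The numbers are made up.
def _check_box_conversion():
    import numpy as np

    true_boxes = np.array([[[10., 20., 50., 80.]]])       # (batch=1, boxes=1, yxyx)
    img_scale = np.array([2.0])                           # per-image scale factor

    true_boxes = true_boxes[:, :, [1, 0, 3, 2]]           # yxyx -> xyxy
    true_boxes[:, :, [2, 3]] -= true_boxes[:, :, [0, 1]]  # xyxy -> xywh
    true_boxes *= img_scale[:, None, None]                # back to full-size coords
    print(true_boxes)                                     # [[[40. 20. 120. 80.]]]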
def run(conf: DictConfig):
    epochs = conf.train.epochs
    epoch_length = conf.train.epoch_length
    torch.manual_seed(conf.general.seed)

    dist_conf = conf.distributed
    local_rank = dist_conf.local_rank
    backend = dist_conf.backend
    distributed = backend is not None
    use_tpu = conf.tpu.enabled

    if use_tpu:
        rank = xm.get_ordinal()
        num_replicas = xm.xrt_world_size()
        device = xm.xla_device()
    else:
        if distributed:
            rank = dist.get_rank()
            num_replicas = dist.get_world_size()
            torch.cuda.set_device(local_rank)
        else:
            rank = 0
            num_replicas = 1
            torch.cuda.set_device(conf.general.gpu)
        device = torch.device('cuda')

    if rank == 0:
        print(conf.pretty())

    if num_replicas > 1:
        epoch_length = epoch_length // num_replicas
        loader_args = dict(rank=rank, num_replicas=num_replicas)
    else:
        loader_args = dict()

    train_dl = create_train_loader(conf.data.train, epoch_length=epoch_length,
                                   **loader_args)
    valid_dl = create_val_loader(conf.data.val, **loader_args)
    train_sampler = train_dl.sampler
    if epoch_length < 1:
        epoch_length = len(train_dl)

    if use_tpu:
        train_dl = pl.ParallelLoader(train_dl, [device])
        valid_dl = pl.ParallelLoader(valid_dl, [device])

    model = instantiate(conf.model).to(device)
    if distributed:
        model = DistributedDataParallel(model, device_ids=[local_rank],
                                        output_device=local_rank)
        model.to_y = model.module.to_y
    if rank == 0 and conf.logging.model:
        print(model)

    loss = instantiate(conf.loss)
    optim = instantiate(conf.optimizer,
                        filter(lambda x: x.requires_grad, model.parameters()))
    metrics = create_metrics(loss.keys(), device if distributed else None)

    build_trainer_fn = create_tpu_trainer if use_tpu else create_trainer
    trainer = build_trainer_fn(model, loss, optim, device, conf, metrics)
    evaluator = create_evaluator(model, loss, device, metrics)

    every_iteration = Events.ITERATION_COMPLETED

    if 'lr_scheduler' in conf.keys():
        # TODO: total_steps is wrong, it works only for one-cycle
        lr_scheduler = instantiate(conf.lr_scheduler, optim,
                                   total_steps=epoch_length)
        trainer.add_event_handler(every_iteration, lambda _: lr_scheduler.step())
        if isinstance(lr_scheduler, torch.optim.lr_scheduler.OneCycleLR):
            initial_state = lr_scheduler.state_dict()
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED(every=epoch_length),
                lambda _: lr_scheduler.load_state_dict(initial_state))
    else:
        lr_scheduler = None

    trainer.add_event_handler(every_iteration, TerminateOnNan())

    cp = conf.train.checkpoints
    to_save = {
        'trainer': trainer,
        'model': model.module if distributed else model,
        'optimizer': optim,
        'lr_scheduler': lr_scheduler
    }
    save_path = cp.get('base_dir', os.getcwd())

    if rank == 0:
        log_freq = conf.logging.iter_freq
        log_event = Events.ITERATION_COMPLETED(every=log_freq)
        pbar = ProgressBar(persist=False)
        for engine, name in zip([trainer, evaluator], ['train', 'val']):
            engine.add_event_handler(Events.EPOCH_STARTED, on_epoch_start)
            engine.add_event_handler(log_event, log_iter, trainer, pbar, name,
                                     log_freq)
            engine.add_event_handler(Events.EPOCH_COMPLETED, log_epoch, trainer,
                                     name)
            pbar.attach(engine, metric_names=loss.keys())

        if 'load' in cp.keys() and cp.load:
            logging.info("Resume from a checkpoint: {}".format(cp.load))
            trainer.add_event_handler(Events.STARTED, _upd_pbar_iter_from_cp, pbar)
        logging.info("Saving checkpoints to {}".format(save_path))

    if rank == 0 or use_tpu:
        max_cp = max(int(cp.get('max_checkpoints', 1)), 1)
        Saver = TpuDiskSaver if use_tpu else DiskSaver
        save = Saver(save_path, create_dir=True, require_empty=True)
        make_checkpoint = Checkpoint(to_save, save, n_saved=max_cp)
        cp_iter = cp.interval_iteration
        cp_epoch = cp.interval_epoch
        if cp_iter > 0:
            save_event = Events.ITERATION_COMPLETED(every=cp_iter)
            trainer.add_event_handler(save_event, make_checkpoint)
        if cp_epoch > 0:
            if cp_iter < 1 or epoch_length % cp_iter:
                save_event = Events.EPOCH_COMPLETED(every=cp_epoch)
                trainer.add_event_handler(save_event, make_checkpoint)

    if 'load' in cp.keys() and cp.load:
        Checkpoint.load_objects(to_load=to_save,
                                checkpoint=torch.load(cp.load, map_location=device))

    assert train_sampler is not None
    trainer.add_event_handler(
        Events.EPOCH_STARTED,
        lambda e: train_sampler.set_epoch(e.state.epoch - 1))

    def run_validation(e: Engine):
        if distributed:
            torch.cuda.synchronize(device)
        if use_tpu:
            xm.rendezvous('validate_{}'.format(e.state.iteration))
            valid_it = valid_dl.per_device_loader(device)
            evaluator.run(valid_it, epoch_length=len(valid_dl))
        else:
            evaluator.run(valid_dl)

    eval_event = Events.EPOCH_COMPLETED(every=conf.validate.interval)
    trainer.add_event_handler(eval_event, run_validation)

    try:
        if conf.train.skip:
            evaluator.run(valid_dl)
        else:
            loader = train_dl
            if use_tpu:
                # need to catch StopIteration before ignite, otherwise it will crash
                loader = iter(_regenerate(train_dl, device))
            trainer.run(loader, max_epochs=epochs, epoch_length=epoch_length)
    except Exception:
        import traceback
        print(traceback.format_exc())

    if rank == 0:
        pbar.close()
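# Illustration (not part of the original script): `_regenerate` is referenced
# above but not shown in this excerpt. A minimal sketch of what such a helper
# might look like, assuming it is meant to yield batches from a torch_xla
# ParallelLoader indefinitely so that Ignite never sees a bare StopIteration;
# this is an illustrative guess, not the original implementation.
def _regenerate(parallel_loader, device):
    while True:
        # per_device_loader() returns a fresh iterator over one pass of the data
        for batch in parallel_loader.per_device_loader(device):
            yield batch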