class Trainer:
    """Config-driven training loop with AMP and optional DistributedDataParallel.

    ``run`` builds the dictionary, datasets/dataloaders and model from dotted
    class paths in ``cfg``, optionally performs a learning-rate warm-up, then
    alternates ``train_epoch`` / ``val_epoch``. Checkpoints and TensorBoard
    output are written from local rank 0 only.
    """

    def __init__(self, cfg):
        seed_torch(1029)
        self.cfg = cfg
        self.start_epoch = -1
        self.n_iters_elapsed = 0
        self.device = self.cfg.GPU_IDS
        # Global batch size: per-GPU batch size times the number of GPUs.
        self.batch_size = self.cfg.DATASET.TRAIN.BATCH_SIZE * len(self.cfg.GPU_IDS)

        self.n_steps_per_epoch = None
        if cfg.local_rank == 0:
            # The string id shadows the ``experiment_id`` method on this instance.
            self.experiment_id = self.experiment_id(self.cfg)
            # NOTE(review): ``ckpts``/``tb_writer`` exist only on local rank 0, yet
            # ``run`` calls ``self.ckpts`` on every rank when PRETRAIN_MODEL is set;
            # confirm how multi-process resume is expected to work.
            self.ckpts = Checkpoints(logger, self.cfg.CHECKPOINT_DIR, self.experiment_id)
            self.tb_writer = DummyWriter(log_dir="%s/%s" % (self.cfg.TENSORBOARD_LOG_DIR, self.experiment_id))

    def experiment_id(self, cfg):
        """Return ``<EXPERIMENT_NAME>#<model class name>#<YYYY_MM_DD_HH_MM_SS>``."""
        return f"{cfg.EXPERIMENT_NAME}#{cfg.USE_MODEL.split('.')[-1]}#{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}"

    def _parser_dict(self):
        """Load ``DATASET.DICTIONARY`` and return its ``DATASET.DICTIONARY_NAME`` entry."""
        dictionary = CommonConfiguration.from_yaml(self.cfg.DATASET.DICTIONARY)
        return dictionary[self.cfg.DATASET.DICTIONARY_NAME]

    def _parser_datasets(self):
        """Build train/val datasets, samplers and dataloaders from the config.

        Requires ``self.dictionary`` to be set beforehand (``run`` does this).

        :return: ``(datasets, dataloaders, data_samplers, dataset_sizes)``, each a
                 dict keyed by ``'train'`` and ``'val'``.
        """
        cfg = self.cfg
        stages = ['train', 'val']
        *dataset_str_parts, dataset_class_str = cfg.DATASET.CLASS.split(".")
        dataset_class = getattr(import_module(".".join(dataset_str_parts)), dataset_class_str)

        datasets = {x: dataset_class(data_cfg=cfg.DATASET[x.upper()], dictionary=self.dictionary, transform=None,
                                     target_transform=None, stage=x) for x in stages}

        if cfg.distributed:
            data_samplers = {x: DistributedSampler(datasets[x], shuffle=cfg.DATASET[x.upper()].SHUFFLE)
                             for x in stages}
        else:
            data_samplers = {'train': RandomSampler(datasets['train']),
                             'val': SequentialSampler(datasets['val'])}

        # Prefer the dataset's own collate function when it defines one.
        collate_fn = getattr(dataset_class, 'collate_fn', default_collate)
        dataloaders = {
            x: PrefetchDataLoader(datasets[x], batch_size=cfg.DATASET[x.upper()].BATCH_SIZE,
                                  sampler=data_samplers[x],
                                  num_workers=cfg.DATASET[x.upper()].NUM_WORKER,
                                  collate_fn=collate_fn,
                                  pin_memory=True, drop_last=True)
            for x in stages}
        dataset_sizes = {x: len(datasets[x]) for x in stages}
        return datasets, dataloaders, data_samplers, dataset_sizes

    def _parser_model(self):
        """Instantiate ``cfg.USE_MODEL`` on the GPU (with SyncBatchNorm when distributed)."""
        *model_mod_str_parts, model_class_str = self.cfg.USE_MODEL.split(".")
        model_class = getattr(import_module(".".join(model_mod_str_parts)), model_class_str)
        model = model_class(dictionary=self.dictionary)

        if self.cfg.distributed:
            model = SyncBatchNorm.convert_sync_batchnorm(model).cuda()
        else:
            model = model.cuda()

        return model

    def clip_grad(self, model):
        """Clip ``model`` gradients by norm or by value according to ``cfg.GRAD_CLIP``.

        :raises ValueError: if ``GRAD_CLIP.TYPE`` is neither ``'norm'`` nor ``'value'``.
        """
        if self.cfg.GRAD_CLIP.TYPE == "norm":
            clip_method = clip_grad_norm_
        elif self.cfg.GRAD_CLIP.TYPE == "value":
            clip_method = clip_grad_value_
        else:
            raise ValueError(
                f"Only support 'norm' and 'value' as the grad_clip type, but {self.cfg.GRAD_CLIP.TYPE} is given."
            )

        clip_method(model.parameters(), self.cfg.GRAD_CLIP.VALUE)

    @staticmethod
    def _unpack_model_output(out):
        """Split a model output into ``(losses, predicts)``.

        Models return either a bare loss dict or a ``(losses, predicts)`` tuple;
        for a bare loss dict ``predicts`` is ``None``.
        """
        if isinstance(out, tuple):
            losses, predicts = out
            return losses, predicts
        return out, None

    @staticmethod
    def _targets_to_cuda(targets):
        """Move a batch of targets to the current CUDA device.

        Accepts a tensor, a list of tensors / numpy arrays / dicts of tensors, or
        a dict whose values are tensors or lists of tensors / numpy arrays (other
        dict values are left untouched). Dicts are updated in place and returned.
        Empty lists are returned unchanged.
        """
        if isinstance(targets, list):
            if not targets:
                return targets
            first = targets[0]
            if isinstance(first, torch.Tensor):
                return [t.cuda() for t in targets]
            if isinstance(first, np.ndarray):
                return [torch.from_numpy(t).cuda() for t in targets]
            return [{k: v.cuda() for k, v in t.items()} for t in targets]
        if isinstance(targets, dict):
            for k, v in targets.items():
                if isinstance(v, torch.Tensor):
                    targets[k] = v.cuda()
                elif isinstance(v, list) and v:
                    if isinstance(v[0], torch.Tensor):
                        targets[k] = [t.cuda() for t in v]
                    elif isinstance(v[0], np.ndarray):
                        targets[k] = [torch.from_numpy(t).cuda() for t in v]
            return targets
        return targets.cuda()

    def run_step(self, scaler, model, sample, optimizer, lossLogger, performanceLogger, prefix):
        '''
            Run one forward pass (and, for ``prefix == 'train'``, an AMP backward/step).
            :param scaler: ``amp.GradScaler`` used in the train path (unused otherwise)
            :param model: model to run; called as ``model(imgs, targets, prefix)``
            :param sample: a batch dict with ``'image'`` and ``'target'`` entries
            :param optimizer: optimizer stepped in the train path (unused otherwise)
            :param lossLogger: optional; updated with (reduced, if distributed) losses
            :param performanceLogger: optional; updated with targets and predictions
            :param prefix: train or val or infer
            :return: the loss dict returned by the model
        '''
        imgs, targets = sample['image'], sample['target']
        imgs = [img.cuda() for img in imgs] if isinstance(imgs, list) else imgs.cuda()
        targets = self._targets_to_cuda(targets)

        if prefix == 'train':
            optimizer.zero_grad()

            with amp.autocast(enabled=True):
                losses, predicts = self._unpack_model_output(model(imgs, targets, prefix))

            # Backward runs outside autocast on the scaled loss.
            scaler.scale(losses["loss"]).backward()

            if self.cfg.GRAD_CLIP and self.cfg.GRAD_CLIP.VALUE:
                # Gradients are still multiplied by the loss scale here; unscale them
                # first so the clip threshold applies to the true gradient values.
                # scaler.step() will not unscale a second time.
                scaler.unscale_(optimizer)
                self.clip_grad(model)

            # Skips optimizer.step() if the gradients contain infs/NaNs.
            scaler.step(optimizer)
            scaler.update()
        else:
            losses, predicts = self._unpack_model_output(model(imgs, targets, prefix))

        if lossLogger is not None:
            if self.cfg.distributed:
                # Reduce losses over all GPUs for logging purposes.
                loss_dict_reduced = reduce_dict(losses)
                lossLogger.update(**loss_dict_reduced)
                del loss_dict_reduced
            else:
                lossLogger.update(**losses)

        if performanceLogger is not None and predicts is not None:
            if self.cfg.distributed:
                # Reduce predictions over all GPUs for logging purposes.
                predicts_dict_reduced = reduce_dict(predicts)
                performanceLogger.update(targets, predicts_dict_reduced)
                del predicts_dict_reduced
            else:
                performanceLogger.update(targets, predicts)
            del predicts

        del imgs, targets
        return losses

    def warm_up(self, scaler, model, dataloader, cfg, prefix='train'):
        """Run the learning-rate warm-up with a temporary optimizer.

        The lr for each iteration comes from ``get_warmup_lr``. The iteration that
        reaches ``cfg.WARMUP.ITERS`` breaks before stepping, so
        ``WARMUP.ITERS - 1`` steps are taken. The dataloader is re-iterated as
        needed; an empty dataloader ends the warm-up instead of looping forever.
        """
        optimizer = build_optimizer(cfg, model)
        model.train()

        cur_iter = 0
        while cur_iter < cfg.WARMUP.ITERS:
            n_batches = 0
            for sample in dataloader:
                n_batches += 1
                cur_iter += 1
                if cur_iter >= cfg.WARMUP.ITERS:
                    break
                lr = get_warmup_lr(cur_iter, cfg)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
                losses = self.run_step(scaler, model, sample, optimizer, None, None, prefix)

                if self.cfg.local_rank == 0:
                    template = "[iter {}/{}, lr {}] Total train loss: {:.4f} \n" "{}"
                    logger.info(
                        template.format(
                            cur_iter, cfg.WARMUP.ITERS, round(get_current_lr(optimizer), 6),
                            losses["loss"].item(),
                            "\n".join(
                                ["{}: {:.4f}".format(n, l.item()) for n, l in losses.items() if n != "loss"]),
                        )
                    )
            if n_batches == 0:
                # Without this guard an empty dataloader would spin the while-loop forever.
                logger.warning("Warm-up dataloader yielded no batches; skipping warm-up.")
                break
        del optimizer

    def run(self):
        """Set up everything from the config and run the full train/val loop."""
        self.cfg = init_distributed(self.cfg)
        cfg = self.cfg

        self.dictionary = self._parser_dict()
        datasets, dataloaders, data_samplers, dataset_sizes = self._parser_datasets()
        model_ft = self._parser_model()

        # Scale learning rate based on global batch size.
        if cfg.SCALE_LR:
            cfg.INIT_LR = cfg.INIT_LR * float(self.batch_size) / cfg.SCALE_LR

        scaler = amp.GradScaler(enabled=True)
        if cfg.WARMUP.NAME is not None and cfg.WARMUP.ITERS:
            logger.info('Start warm-up ... ')
            self.warm_up(scaler, model_ft, dataloaders['train'], cfg)
            logger.info('finish warm-up!')

        optimizer_ft = build_optimizer(cfg, model_ft)
        lr_scheduler_ft = build_lr_scheduler(cfg, optimizer_ft)

        if cfg.distributed:
            model_ft = DDP(model_ft, device_ids=[cfg.local_rank], output_device=(cfg.local_rank))

        freeze_models(model_ft)

        if self.cfg.PRETRAIN_MODEL is not None:
            if self.cfg.RESUME:
                self.start_epoch = self.ckpts.resume_checkpoint(model_ft, optimizer_ft)
            else:
                self.start_epoch = self.ckpts.load_checkpoint(self.cfg.PRETRAIN_MODEL, model_ft, optimizer_ft)

        # Network-graph visualisation; disabled by the trailing ``and False``.
        if self.cfg.TENSORBOARD_MODEL and False:
            self.tb_writer.add_graph(model_ft, (model_ft.dummy_input.cuda(),))

        # Keep both attribute names populated (``__init__`` declares n_steps_per_epoch).
        self.steps_per_epoch = int(dataset_sizes['train'] // self.batch_size)
        self.n_steps_per_epoch = self.steps_per_epoch

        best_acc = 0.0
        best_perf_rst = None
        for epoch in range(self.start_epoch + 1, self.cfg.N_MAX_EPOCHS):
            if cfg.distributed:
                # Reshuffle the distributed shards differently every epoch.
                data_samplers['train'].set_epoch(epoch)
            self.train_epoch(scaler, epoch, model_ft, datasets['train'], dataloaders['train'], optimizer_ft)
            lr_scheduler_ft.step()

            is_last_epoch = epoch == self.cfg.N_MAX_EPOCHS - 1
            if self.cfg.DATASET.VAL and (not epoch % cfg.EVALUATOR.EVAL_INTERVALS or is_last_epoch):
                acc, perf_rst = self.val_epoch(epoch, model_ft, datasets['val'], dataloaders['val'])

                if cfg.local_rank == 0 and best_acc < acc:
                    self.ckpts.autosave_checkpoint(model_ft, epoch, 'best', optimizer_ft)
                    best_acc = acc
                    best_perf_rst = perf_rst

            if not epoch % cfg.N_EPOCHS_TO_SAVE_MODEL and cfg.local_rank == 0:
                self.ckpts.autosave_checkpoint(model_ft, epoch, 'last', optimizer_ft)

        if best_perf_rst is not None:
            logger.info(best_perf_rst.replace("(val)", "(best)"))

        if cfg.local_rank == 0:
            self.tb_writer.close()

        # Every process in the group should tear it down, rank 0 included.
        if cfg.distributed:
            dist.destroy_process_group()
        torch.cuda.empty_cache()

    def train_epoch(self, scaler, epoch, model, dataset, dataloader, optimizer, prefix="train"):
        """Train ``model`` for one epoch, logging progress and TensorBoard scalars on rank 0."""
        model.train()

        _timer = Timer()
        lossLogger = LossLogger()
        performanceLogger = build_evaluator(self.cfg, dataset)

        num_iters = len(dataloader)
        for i, sample in enumerate(dataloader):
            self.n_iters_elapsed += 1
            _timer.tic()
            self.run_step(scaler, model, sample, optimizer, lossLogger, performanceLogger, prefix)
            torch.cuda.synchronize()
            _timer.toc()

            if (i + 1) % self.cfg.N_ITERS_TO_DISPLAY_STATUS == 0 and self.cfg.local_rank == 0:
                # NOTE(review): if ``_timer.diff`` is the duration of the last
                # tic/toc only, multiplying by N_ITERS_TO_DISPLAY_STATUS overstates
                # throughput -- confirm Timer semantics.
                template = "[epoch {}/{}, iter {}/{}, lr {}] Total train loss: {:.4f} " "(ips = {:.2f})\n" "{}"
                logger.info(
                    template.format(
                        epoch, self.cfg.N_MAX_EPOCHS - 1, i, num_iters - 1,
                        round(get_current_lr(optimizer), 6),
                        lossLogger.meters["loss"].value,
                        self.batch_size * self.cfg.N_ITERS_TO_DISPLAY_STATUS / _timer.diff,
                        "\n".join(
                            ["{}: {:.4f}".format(n, l.value) for n, l in lossLogger.meters.items() if n != "loss"]),
                    )
                )

        if self.cfg.TENSORBOARD and self.cfg.local_rank == 0:
            for n, l in lossLogger.meters.items():
                self.tb_writer.add_scalar(f"loss/{prefix}_{n}", l.global_avg, epoch)
            performances = performanceLogger.evaluate()
            if performances is not None and len(performances):
                for k, v in performances.items():
                    self.tb_writer.add_scalar(f"performance/{prefix}_{k}", v, epoch)

        # Weight histograms; disabled by the trailing ``and False``.
        if self.cfg.TENSORBOARD_WEIGHT and False:
            for name, param in model.named_parameters():
                layer, attr = os.path.splitext(name)
                attr = attr[1:]
                self.tb_writer.add_histogram("{}/{}".format(layer, attr), param, epoch)

    @torch.no_grad()
    def val_epoch(self, epoch, model, dataset, dataloader, prefix="val"):
        """Evaluate ``model`` on ``dataloader``.

        :return: ``(acc, perf_log)`` where ``acc`` is ``performances['performance']``
                 from the evaluator and ``perf_log`` is the formatted report string.
                 Both are computed on every rank; logging happens on rank 0 only.
        """
        model.eval()

        lossLogger = LossLogger()
        performanceLogger = build_evaluator(self.cfg, dataset)

        for sample in dataloader:
            self.run_step(None, model, sample, None, lossLogger, performanceLogger, prefix)

        # Evaluate on every rank: the returned accuracy is needed everywhere,
        # not only where TensorBoard logging is enabled.
        performances = performanceLogger.evaluate() or {}

        if self.cfg.TENSORBOARD and self.cfg.local_rank == 0:
            for n, l in lossLogger.meters.items():
                self.tb_writer.add_scalar(f"loss/{prefix}_{n}", l.global_avg, epoch)
            for k, v in performances.items():
                self.tb_writer.add_scalar(f"performance/{prefix}_{k}", v, epoch)

        perf_log = f"\n------------ Performances ({prefix}) ----------\n"
        for k, v in performances.items():
            perf_log += "{:}: {:.4f}\n".format(k, v)
        perf_log += "------------------------------------\n"

        if self.cfg.local_rank == 0:
            template = "[epoch {}] Total {} loss : {:.4f} " "\n" "{}"
            logger.info(
                template.format(
                    epoch, prefix, lossLogger.meters["loss"].global_avg,
                    "\n".join(
                        ["{}: {:.4f}".format(n, l.global_avg) for n, l in lossLogger.meters.items() if n != "loss"]),
                )
            )
            logger.info(perf_log)

        acc = performances['performance']

        return acc, perf_log
# Example #2
class Trainer:
    """Config-driven training loop (single process, optional ``data_parallel``).

    ``run`` builds the dictionary, datasets/dataloaders, model, optimizer and
    lr scheduler from the config, then alternates ``train_epoch`` /
    ``val_epoch`` and saves checkpoints.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.start_epoch = -1
        self.n_iters_elapsed = 0
        self.device = self.cfg.GPU_IDS
        # Per-GPU batch size (used in throughput logging) and the total batch
        # size handed to the DataLoader.
        self.batch_size = self.cfg.BATCH_SIZE
        self.batch_size_all = self.cfg.BATCH_SIZE * len(self.cfg.GPU_IDS)

        self.n_steps_per_epoch = None
        self.logger = logging.getLogger("pytorch")
        # The string id shadows the ``experiment_id`` method on this instance.
        self.experiment_id = self.experiment_id(self.cfg)
        self.checkpoints = Checkpoints(self.logger, self.cfg.CHECKPOINT_DIR,
                                       self.experiment_id)
        self.tb_writer = DummyWriter(
            log_dir="%s/%s" %
            (self.cfg.TENSORBOARD_LOG_DIR, self.experiment_id))

    def experiment_id(self, cfg):
        """Return ``<name>#<model class>#<optimizer>#<lr scheduler>#<timestamp>``."""
        return f"{cfg.EXPERIMENT_NAME}#{cfg.USE_MODEL.split('.')[-1]}#{cfg.OPTIMIZER.TYPE}#{cfg.LR_SCHEDULER.TYPE}" \
               f"#{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}"

    @staticmethod
    def _first_value(mapping):
        """Return the value of the first entry yielded by ``mapping.items()``.

        :raises ValueError: if ``mapping`` has no entries.
        """
        for _, value in mapping.items():
            return value
        raise ValueError("dictionary configuration defines no entries")

    def _parser_dict(self):
        """Load ``DATASET.DICTIONARY`` and return its first entry."""
        dictionary = CommonConfiguration.from_yaml(self.cfg.DATASET.DICTIONARY)
        # ``dict.items()`` is a view, not an iterator, so ``next()`` cannot be
        # applied to it directly; the helper iterates it instead.
        return self._first_value(dictionary)

    def _parser_datasets(self):
        """Build train/val datasets and dataloaders.

        :return: ``(datasets, dataloaders)``, both dicts keyed by ``'train'``/``'val'``.
        """
        cfg = self.cfg
        transforms = prepare_transforms_seg()
        target_transforms = prepare_transforms_mask()
        *dataset_str_parts, dataset_class_str = cfg.DATASET.CLASS.split(".")
        dataset_class = getattr(
            importlib.import_module(".".join(dataset_str_parts)),
            dataset_class_str)
        datasets = {
            x: dataset_class(data_cfg=cfg.DATASET[x.upper()],
                             transform=transforms[x],
                             target_transform=target_transforms[x],
                             stage=x)
            for x in ['train', 'val']
        }
        dataloaders = {
            x: DataLoader(datasets[x],
                          batch_size=self.batch_size_all,
                          num_workers=cfg.NUM_WORKERS,
                          shuffle=cfg.DATASET[x.upper()].SHUFFLE,
                          collate_fn=detection_collate,
                          pin_memory=True)
            for x in ['train', 'val']
        }
        return datasets, dataloaders

    def _parser_model(self, dictionary):
        """Instantiate ``cfg.USE_MODEL`` with ``dictionary``."""
        *model_mod_str_parts, model_class_str = self.cfg.USE_MODEL.split(".")
        model_class = getattr(
            importlib.import_module(".".join(model_mod_str_parts)),
            model_class_str)
        model = model_class(dictionary=dictionary)
        return model

    def run(self):
        """Set up everything from the config and run the full train/val loop."""
        cfg = self.cfg

        dictionary = self._parser_dict()
        datasets, dataloaders = self._parser_datasets()
        model_ft = self._parser_model(dictionary)
        optimizer_ft = parser_optimizer(cfg, model_ft)
        lr_scheduler_ft = parser_lr_scheduler(cfg, optimizer_ft)

        if self.cfg.PRETRAIN_MODEL is not None:
            if self.cfg.RESUME:
                self.start_epoch = self.checkpoints.load_checkpoint(
                    self.cfg.PRETRAIN_MODEL, model_ft, optimizer_ft,
                    lr_scheduler_ft)
            else:
                self.checkpoints.load_checkpoint(self.cfg.PRETRAIN_MODEL,
                                                 model_ft)

        if torch.cuda.is_available():
            model_ft = model_ft.cuda()
            cudnn.benchmark = True
            # Optimizer state loaded from a checkpoint must live on the model's device.
            for state in optimizer_ft.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.cuda()

        # Network-graph visualisation; disabled by the trailing ``and False``.
        if self.cfg.TENSORBOARD_MODEL and False:
            self.tb_writer.add_graph(model_ft, (model_ft.dummy_input.cuda(), ))

        if self.cfg.HALF:
            model_ft.half()

        # Number of batches per epoch; used to compute TensorBoard steps.
        self.n_steps_per_epoch = len(dataloaders['train'])

        best_acc = 0.0
        for epoch in range(self.start_epoch + 1, self.cfg.N_MAX_EPOCHS):
            self.train_epoch(epoch, model_ft, dataloaders['train'],
                             optimizer_ft, lr_scheduler_ft, None)
            if self.cfg.DATASET.VAL:
                acc = self.val_epoch(epoch,
                                     model_ft,
                                     dataloaders['val'],
                                     optimizer=optimizer_ft,
                                     lr_scheduler=lr_scheduler_ft)
                if best_acc < acc:
                    self.checkpoints.autosave_checkpoint(
                        model_ft, epoch, 'best', optimizer_ft, lr_scheduler_ft)
                    best_acc = acc
                    # A new best was just saved; skip this epoch's periodic autosave.
                    continue

            if not epoch % cfg.N_EPOCHS_TO_SAVE_MODEL:
                self.checkpoints.autosave_checkpoint(model_ft, epoch,
                                                     'autosave', optimizer_ft,
                                                     lr_scheduler_ft)

        self.tb_writer.close()

    def train_epoch(self,
                    epoch,
                    model,
                    dataloader,
                    optimizer,
                    lr_scheduler,
                    grad_normalizer=None,
                    prefix="train"):
        """Train ``model`` for one epoch, then step ``lr_scheduler`` once.

        NOTE(review): when ``cfg.GRADNORM`` is set, ``grad_normalizer`` must be
        provided; ``run`` currently passes ``None``.
        """
        model.train()

        _timer = Timer()
        lossMeter = LossMeter()
        perfMeter = PerfMeter()

        for i, (imgs, labels) in enumerate(dataloader):
            _timer.tic()
            optimizer.zero_grad()

            if self.cfg.HALF:
                imgs = imgs.half()

            if len(self.device) > 1:
                out = data_parallel(model, (imgs, labels, prefix),
                                    device_ids=self.device,
                                    output_device=self.device[0])
            else:
                imgs = imgs.cuda()
                labels = [label.cuda() for label in labels] if isinstance(
                    labels, list) else labels.cuda()
                out = model(imgs, labels, prefix)

            if not isinstance(out, tuple):
                losses, performances = out, None
            else:
                losses, performances = out

            if losses["all_loss"].sum().requires_grad:
                if self.cfg.GRADNORM is not None:
                    grad_normalizer.adjust_losses(losses)
                    grad_normalizer.adjust_grad(model, losses)
                else:
                    losses["all_loss"].sum().backward()

            optimizer.step()

            self.n_iters_elapsed += 1

            _timer.toc()

            lossMeter.__add__(losses)

            if performances is not None and all(performances):
                perfMeter.put(performances)

            if (i + 1) % self.cfg.N_ITERS_TO_DISPLAY_STATUS == 0:
                avg_losses = lossMeter.average()
                # NOTE(review): if ``_timer.total_time`` accumulates over the whole
                # epoch, this "ips" drifts downward -- confirm Timer semantics.
                template = "[epoch {}/{}, iter {}, lr {}] Total train loss: {:.4f} " "(ips = {:.2f} )\n" "{}"
                self.logger.info(
                    template.format(
                        epoch,
                        self.cfg.N_MAX_EPOCHS,
                        i,
                        round(get_current_lr(optimizer), 6),
                        avg_losses["all_loss"],
                        self.batch_size * self.cfg.N_ITERS_TO_DISPLAY_STATUS /
                        _timer.total_time,
                        "\n".join([
                            "{}: {:.4f}".format(n, l)
                            for n, l in avg_losses.items() if n != "all_loss"
                        ]),
                    ))

                if self.cfg.TENSORBOARD:
                    tb_step = int((epoch * self.n_steps_per_epoch + i) /
                                  self.cfg.N_ITERS_TO_DISPLAY_STATUS)
                    for n, l in avg_losses.items():
                        self.tb_writer.add_scalar(f"loss/{prefix}_{n}", l,
                                                  tb_step)

                lossMeter.clear()

            del imgs, labels, losses, performances

        lr_scheduler.step()

        if self.cfg.TENSORBOARD and len(perfMeter):
            avg_perf = perfMeter.average()
            for k, v in avg_perf.items():
                self.tb_writer.add_scalar(f"performance/{prefix}_{k}", v,
                                          epoch)

        # Weight histograms; disabled by the trailing ``and False``.
        if self.cfg.TENSORBOARD_WEIGHT and False:
            for name, param in model.named_parameters():
                layer, attr = os.path.splitext(name)
                attr = attr[1:]
                self.tb_writer.add_histogram("{}/{}".format(layer, attr),
                                             param, epoch)

    @torch.no_grad()
    def val_epoch(self,
                  epoch,
                  model,
                  dataloader,
                  optimizer=None,
                  lr_scheduler=None,
                  prefix="val"):
        """Evaluate ``model`` on ``dataloader`` and return ``avg_perf['all_perf']``.

        ``optimizer`` and ``lr_scheduler`` are accepted for interface
        compatibility and are not used here.
        """
        model.eval()

        lossMeter = LossMeter()
        perfMeter = PerfMeter()

        for (imgs, labels) in dataloader:

            if self.cfg.HALF:
                imgs = imgs.half()

            if len(self.device) > 1:
                # Gather outputs on the same device as train_epoch does.
                losses, performances = data_parallel(
                    model, (imgs, labels, prefix),
                    device_ids=self.device,
                    output_device=self.device[0])
            else:
                imgs = imgs.cuda()
                labels = [label.cuda() for label in labels] if isinstance(
                    labels, list) else labels.cuda()
                losses, performances = model(imgs, labels, prefix)

            lossMeter.__add__(losses)
            perfMeter.__add__(performances)

            del imgs, labels, losses, performances

        avg_losses = lossMeter.average()
        avg_perf = perfMeter.average()

        template = "[epoch {}] Total {} loss : {:.4f} " "\n" "{}"
        self.logger.info(
            template.format(
                epoch,
                prefix,
                avg_losses["all_loss"],
                "\n".join([
                    "{}: {:.4f}".format(n, l) for n, l in avg_losses.items()
                    if n != "all_loss"
                ]),
            ))

        if self.cfg.TENSORBOARD:
            for n, l in avg_losses.items():
                self.tb_writer.add_scalar(f"loss/{prefix}_{n}", l, epoch)
            for k, v in avg_perf.items():
                self.tb_writer.add_scalar(f"performance/{prefix}_{k}", v,
                                          epoch)

        perf_log_str = f"\n------------ Performances ({prefix}) ----------\n"
        for k, v in avg_perf.items():
            perf_log_str += "{:}: {:.4f}\n".format(k, v)
        perf_log_str += "------------------------------------\n"
        self.logger.info(perf_log_str)

        acc = avg_perf['all_perf']

        del avg_losses, avg_perf
        return acc
# Example #3
class Trainer:
    """Training driver that uses native ``torch.cuda.amp`` mixed precision.

    Datasets and the model are resolved from dotted import paths in the
    config (``cfg.DATASET.CLASS`` and ``cfg.USE_MODEL``). ``run()`` builds the
    data pipeline, model, optimizer and LR scheduler, optionally loads a
    checkpoint, and then alternates ``train_epoch`` / ``val_epoch``.

    The checkpoint manager (``self.ckpts``) and TensorBoard writer
    (``self.tb_writer``) are only created on local rank 0.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.start_epoch = -1
        self.n_iters_elapsed = 0
        self.device = self.cfg.GPU_IDS
        # BATCH_SIZE scaled by the number of configured GPUs; used as the
        # training DataLoader batch size in _parser_datasets.
        self.batch_size = self.cfg.BATCH_SIZE * len(self.cfg.GPU_IDS)

        self.n_steps_per_epoch = None
        if cfg.local_rank == 0:
            # Rebinds ``self.experiment_id`` from the method to its string
            # result; the two lines below use it as a string.
            self.experiment_id = self.experiment_id(self.cfg)
            self.ckpts = Checkpoints(logger, self.cfg.CHECKPOINT_DIR,
                                     self.experiment_id)
            self.tb_writer = DummyWriter(
                log_dir="%s/%s" %
                (self.cfg.TENSORBOARD_LOG_DIR, self.experiment_id))

    def experiment_id(self, cfg):
        """Return ``<EXPERIMENT_NAME>#<model class name>#<YYYY_mm_dd_HH_MM_SS>``."""
        return f"{cfg.EXPERIMENT_NAME}#{cfg.USE_MODEL.split('.')[-1]}#{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}"

    def _parser_dict(self):
        """Load the dictionary YAML and return the entry named by the config."""
        dictionary = CommonConfiguration.from_yaml(self.cfg.DATASET.DICTIONARY)
        return dictionary[self.cfg.DATASET.DICTIONARY_NAME]

    def _parser_datasets(self):
        """Build train/val datasets, samplers and DataLoaders.

        Returns:
            ``(datasets, dataloaders, data_samplers)``, each a dict keyed by
            ``'train'`` and ``'val'``.
        """
        # Use the trainer's own (possibly init_distributed-updated) config
        # rather than any module-level ``cfg``.
        cfg = self.cfg
        stages = ['train', 'val']

        *dataset_str_parts, dataset_class_str = cfg.DATASET.CLASS.split(".")
        dataset_class = getattr(import_module(".".join(dataset_str_parts)),
                                dataset_class_str)

        datasets = {
            x: dataset_class(data_cfg=cfg.DATASET[x.upper()],
                             dictionary=self.dictionary,
                             transform=None,
                             target_transform=None,
                             stage=x)
            for x in stages
        }

        if cfg.distributed:
            data_samplers = {
                x: DistributedSampler(datasets[x],
                                      shuffle=cfg.DATASET[x.upper()].SHUFFLE)
                for x in stages
            }
        else:
            data_samplers = {
                'train': RandomSampler(datasets['train']),
                'val': SequentialSampler(datasets['val']),
            }

        collate_fn = getattr(dataset_class, 'collate_fn', default_collate)
        # Train uses the scaled batch size; validation runs one sample at a time.
        dataloaders = {
            x: DataLoader(datasets[x],
                          batch_size=batch_size,
                          sampler=data_samplers[x],
                          num_workers=cfg.NUM_WORKERS,
                          collate_fn=collate_fn,
                          pin_memory=True,
                          drop_last=True)
            for x, batch_size in zip(stages, [self.batch_size, 1])
        }

        return datasets, dataloaders, data_samplers

    def _parser_model(self):
        """Instantiate ``cfg.USE_MODEL`` and move it to CUDA (SyncBN if distributed)."""
        *model_mod_str_parts, model_class_str = self.cfg.USE_MODEL.split(".")
        model_class = getattr(import_module(".".join(model_mod_str_parts)),
                              model_class_str)
        model = model_class(dictionary=self.dictionary)

        if self.cfg.distributed:
            model = SyncBatchNorm.convert_sync_batchnorm(model).cuda()
        else:
            model = model.cuda()

        return model

    @staticmethod
    def _to_cuda(imgs, targets):
        """Move one batch onto the current CUDA device.

        ``imgs`` may be a tensor or a list of tensors. ``targets`` may be a
        tensor, a list of tensors, or a list of dicts of tensors.

        Returns:
            ``(imgs, targets)`` with the same container structure.
        """
        if isinstance(imgs, list):
            imgs = [img.cuda() for img in imgs]
        else:
            imgs = imgs.cuda()

        if isinstance(targets, list):
            if isinstance(targets[0], torch.Tensor):
                targets = [t.cuda() for t in targets]
            else:
                targets = [{k: v.cuda() for k, v in t.items()} for t in targets]
        else:
            targets = targets.cuda()

        return imgs, targets

    def run(self):
        """Set up everything from the config and run the epoch loop."""
        self.cfg = init_distributed(self.cfg)
        cfg = self.cfg

        self.dictionary = self._parser_dict()
        _, dataloaders, _ = self._parser_datasets()
        model_ft = self._parser_model()

        # Scale learning rate based on global batch size
        # cfg.INIT_LR = cfg.INIT_LR * float(self.batch_size_all) / 256
        optimizer_ft = parser_optimizer(cfg, model_ft)
        lr_scheduler_ft = parser_lr_scheduler(cfg, optimizer_ft)

        if cfg.distributed:
            model_ft = DDP(model_ft,
                           device_ids=[cfg.local_rank],
                           output_device=(cfg.local_rank))

        # Parameter names to freeze (full or partial match). With only an
        # empty string, ``any(freeze)`` is False and nothing is frozen.
        freeze = [
            '',
        ]
        if any(freeze):
            for k, v in model_ft.named_parameters():
                if any(x in k for x in freeze):
                    print('freezing %s' % k)
                    v.requires_grad = False

        if self.cfg.PRETRAIN_MODEL is not None:
            # NOTE(review): self.ckpts is only created on local rank 0 in
            # __init__, so this raises AttributeError on other ranks when a
            # pretrained model is configured in distributed mode — confirm.
            if self.cfg.RESUME:
                self.start_epoch = self.ckpts.load_checkpoint(
                    self.cfg.PRETRAIN_MODEL, model_ft, optimizer_ft,
                    lr_scheduler_ft)
            else:
                self.ckpts.load_checkpoint(self.cfg.PRETRAIN_MODEL, model_ft)

        # Net-graph visualisation is deliberately disabled (``and False``).
        if self.cfg.TENSORBOARD_MODEL and False:
            self.tb_writer.add_graph(model_ft, (model_ft.dummy_input.cuda(), ))

        # One optimizer step per training batch. Previously this iterated the
        # whole dataset and summed len() of every sample, which loaded every
        # item and produced a value unrelated to the step count.
        self.n_steps_per_epoch = len(dataloaders['train'])

        best_acc = 0.0
        scaler = amp.GradScaler(enabled=True)
        for epoch in range(self.start_epoch + 1, self.cfg.N_MAX_EPOCHS):
            if cfg.distributed:
                # Reshuffle differently each epoch across ranks.
                dataloaders['train'].sampler.set_epoch(epoch)
            self.train_epoch(scaler, epoch, model_ft, dataloaders['train'],
                             optimizer_ft)
            lr_scheduler_ft.step()

            if self.cfg.DATASET.VAL:
                acc = self.val_epoch(epoch, model_ft, dataloaders['val'])

                if cfg.local_rank == 0 and best_acc < acc:
                    self.ckpts.autosave_checkpoint(model_ft, epoch, 'best',
                                                   optimizer_ft,
                                                   lr_scheduler_ft)
                    best_acc = acc

            if not epoch % cfg.N_EPOCHS_TO_SAVE_MODEL and cfg.local_rank == 0:
                self.ckpts.autosave_checkpoint(model_ft, epoch, 'autosave',
                                               optimizer_ft,
                                               lr_scheduler_ft)

        if cfg.local_rank == 0:
            self.tb_writer.close()

        # Every rank must tear down its own process group (previously rank 0
        # never did); skip entirely when not running distributed.
        if cfg.distributed and dist.is_initialized():
            dist.destroy_process_group()
        torch.cuda.empty_cache()

    def train_epoch(self,
                    scaler,
                    epoch,
                    model,
                    dataloader,
                    optimizer,
                    prefix="train"):
        """Train ``model`` for one epoch with autocast + gradient scaling.

        Args:
            scaler: ``amp.GradScaler`` used for backward and optimizer steps.
            epoch: current epoch index (logging only).
            model: called as ``model(imgs, targets, prefix)``; returns either a
                loss dict or a ``(losses, predicts)`` tuple.
            dataloader: yields dicts with ``'image'`` and ``'target'`` keys.
            optimizer: optimizer stepped through ``scaler``.
            prefix: forwarded to the model and used in TensorBoard tags.
        """
        model.train()

        _timer = Timer()
        lossLogger = LossLogger()
        performanceLogger = MetricLogger(self.dictionary, self.cfg)

        for i, sample in enumerate(dataloader):
            imgs, targets = sample['image'], sample['target']
            _timer.tic()
            optimizer.zero_grad()

            imgs, targets = self._to_cuda(imgs, targets)

            # Only the forward pass runs under autocast; backward ops reuse
            # the dtypes autocast chose for the corresponding forward ops.
            with amp.autocast(enabled=True):
                out = model(imgs, targets, prefix)

            if isinstance(out, tuple):
                losses, predicts = out
            else:
                losses, predicts = out, None

            self.n_iters_elapsed += 1

            # scaler.step() unscales gradients first and skips
            # optimizer.step() if they contain infs/NaNs; update() then
            # adjusts the scale factor for the next iteration.
            scaler.scale(losses["loss"]).backward()
            scaler.step(optimizer)
            scaler.update()

            _timer.toc()

            if (i + 1) % self.cfg.N_ITERS_TO_DISPLAY_STATUS == 0:
                if self.cfg.distributed:
                    # Reduce losses over all GPUs for logging purposes.
                    loss_dict_reduced = reduce_dict(losses)
                    lossLogger.update(**loss_dict_reduced)
                    del loss_dict_reduced
                else:
                    lossLogger.update(**losses)

                if predicts is not None:
                    if self.cfg.distributed:
                        # Reduce performances over all GPUs for logging.
                        predicts_dict_reduced = reduce_dict(predicts)
                        performanceLogger.update(targets,
                                                 predicts_dict_reduced)
                        del predicts_dict_reduced
                    else:
                        # Same (targets, predicts) call shape as the
                        # distributed branch and val_epoch.
                        performanceLogger.update(targets, predicts)
                    del predicts

                if self.cfg.local_rank == 0:
                    # NOTE(review): if Timer.diff is the duration of only the
                    # last tic/toc interval, multiplying by
                    # N_ITERS_TO_DISPLAY_STATUS overstates throughput —
                    # confirm Timer semantics.
                    template = "[epoch {}/{}, iter {}, lr {}] Total train loss: {:.4f} " "(ips = {:.2f})\n" "{}"
                    logger.info(
                        template.format(
                            epoch,
                            self.cfg.N_MAX_EPOCHS,
                            i,
                            round(get_current_lr(optimizer), 6),
                            lossLogger.meters["loss"].value,
                            self.batch_size *
                            self.cfg.N_ITERS_TO_DISPLAY_STATUS / _timer.diff,
                            "\n".join([
                                "{}: {:.4f}".format(n, l.value)
                                for n, l in lossLogger.meters.items()
                                if n != "loss"
                            ]),
                        ))

            del imgs, targets, losses

        if self.cfg.TENSORBOARD and self.cfg.local_rank == 0:
            for n, l in lossLogger.meters.items():
                self.tb_writer.add_scalar(f"loss/{prefix}_{n}", l.global_avg,
                                          epoch)
            performances = performanceLogger.compute()
            for k, v in performances.items():
                self.tb_writer.add_scalar(f"performance/{prefix}_{k}", v,
                                          epoch)

        # Weight histograms are deliberately disabled (``and False``).
        if self.cfg.TENSORBOARD_WEIGHT and False:
            for name, param in model.named_parameters():
                layer, attr = os.path.splitext(name)
                attr = attr[1:]
                self.tb_writer.add_histogram("{}/{}".format(layer, attr),
                                             param, epoch)

    @torch.no_grad()
    def val_epoch(self, epoch, model, dataloader, prefix="val"):
        """Evaluate ``model`` on ``dataloader`` and return ``performances['performance']``.

        Losses and metrics are logged to TensorBoard / the logger on local
        rank 0 only.
        """
        model.eval()

        lossLogger = LossLogger()
        performanceLogger = MetricLogger(self.dictionary, self.cfg)

        for sample in dataloader:
            imgs, targets = self._to_cuda(sample['image'], sample['target'])

            losses, predicts = model(imgs, targets, prefix)

            if self.cfg.distributed:
                # Reduce losses over all GPUs for logging purposes.
                loss_dict_reduced = reduce_dict(losses)
                lossLogger.update(**loss_dict_reduced)
                del loss_dict_reduced
            else:
                lossLogger.update(**losses)

            if predicts is not None:
                if self.cfg.distributed:
                    # Reduce performances over all GPUs for logging purposes.
                    predicts_dict_reduced = reduce_dict(predicts)
                    performanceLogger.update(targets, predicts_dict_reduced)
                    del predicts_dict_reduced
                else:
                    performanceLogger.update(targets, predicts)
                del predicts

            del imgs, targets, losses

        performances = performanceLogger.compute()
        if self.cfg.TENSORBOARD and self.cfg.local_rank == 0:
            for n, l in lossLogger.meters.items():
                self.tb_writer.add_scalar(f"loss/{prefix}_{n}", l.global_avg,
                                          epoch)
            for k, v in performances.items():
                self.tb_writer.add_scalar(f"performance/{prefix}_{k}", v,
                                          epoch)

        if self.cfg.local_rank == 0:
            template = "[epoch {}] Total {} loss : {:.4f} " "\n" "{}"
            logger.info(
                template.format(
                    epoch,
                    prefix,
                    lossLogger.meters["loss"].global_avg,
                    "\n".join([
                        "{}: {:.4f}".format(n, l.global_avg)
                        for n, l in lossLogger.meters.items() if n != "loss"
                    ]),
                ))

            perf_log_str = f"\n------------ Performances ({prefix}) ----------\n"
            perf_log_str += "".join(
                "{:}: {:.4f}\n".format(k, v) for k, v in performances.items())
            perf_log_str += "------------------------------------\n"
            logger.info(perf_log_str)

        return performances['performance']
Beispiel #4
0
class Trainer:
    """Training driver built on apex-style amp (``amp.initialize`` / ``amp.scale_loss``).

    Datasets and the model are resolved from dotted import paths in the
    config (``cfg.DATASET.CLASS`` and ``cfg.USE_MODEL``). ``run()`` builds the
    data pipeline, model, optimizer and LR scheduler, optionally loads a
    checkpoint, and then alternates ``train_epoch`` / ``val_epoch``.

    The checkpoint manager (``self.ckpts``) and TensorBoard writer
    (``self.tb_writer``) are only created on local rank 0.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.start_epoch = -1
        self.n_iters_elapsed = 0
        self.device = self.cfg.GPU_IDS
        # Per-loader batch size, plus the batch size summed over all GPUs.
        self.batch_size = self.cfg.BATCH_SIZE
        self.batch_size_all = self.cfg.BATCH_SIZE * len(self.cfg.GPU_IDS)

        self.n_steps_per_epoch = None
        if cfg.local_rank == 0:
            # Rebinds ``self.experiment_id`` from the method to its string
            # result; the two lines below use it as a string.
            self.experiment_id = self.experiment_id(self.cfg)
            self.ckpts = Checkpoints(logger, self.cfg.CHECKPOINT_DIR,
                                     self.experiment_id)
            self.tb_writer = DummyWriter(
                log_dir="%s/%s" %
                (self.cfg.TENSORBOARD_LOG_DIR, self.experiment_id))

    def experiment_id(self, cfg):
        """Return ``<name>#<model class>#<optimizer>#<lr scheduler>#<timestamp>``."""
        return f"{cfg.EXPERIMENT_NAME}#{cfg.USE_MODEL.split('.')[-1]}#{cfg.OPTIMIZER.TYPE}#{cfg.LR_SCHEDULER.TYPE}" \
               f"#{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}"

    @staticmethod
    def _first_value(mapping):
        """Return the value of the first entry of ``mapping``.

        ``dict.items()`` returns a view, not an iterator, so it has to be
        wrapped in ``iter()`` before ``next()`` can be applied.

        Raises:
            ValueError: if ``mapping`` is empty.
        """
        try:
            return next(iter(mapping.items()))[1]
        except StopIteration:
            raise ValueError("dictionary file contains no entries") from None

    def _parser_dict(self):
        """Load the dictionary YAML and return its first entry."""
        dictionary = CommonConfiguration.from_yaml(self.cfg.DATASET.DICTIONARY)
        return self._first_value(dictionary)

    def _parser_datasets(self, dictionary):
        """Build train/val datasets, samplers and DataLoaders.

        Returns:
            ``(datasets, dataloaders, data_samplers)``, each a dict keyed by
            ``'train'`` and ``'val'``.
        """
        # Use the trainer's own (possibly init_distributed-updated) config
        # rather than any module-level ``cfg``.
        cfg = self.cfg
        stages = ['train', 'val']

        *dataset_str_parts, dataset_class_str = cfg.DATASET.CLASS.split(".")
        dataset_class = getattr(import_module(".".join(dataset_str_parts)),
                                dataset_class_str)

        datasets = {
            x: dataset_class(data_cfg=cfg.DATASET[x.upper()],
                             dictionary=dictionary,
                             transform=None,
                             target_transform=None,
                             stage=x)
            for x in stages
        }

        if cfg.distributed:
            data_samplers = {
                x: DistributedSampler(datasets[x],
                                      shuffle=cfg.DATASET[x.upper()].SHUFFLE)
                for x in stages
            }
        else:
            data_samplers = {
                'train': RandomSampler(datasets['train']),
                'val': SequentialSampler(datasets['val']),
            }

        collate_fn = getattr(dataset_class, 'collate', default_collate)
        dataloaders = {
            x: DataLoader(datasets[x],
                          batch_size=self.batch_size,
                          sampler=data_samplers[x],
                          num_workers=cfg.NUM_WORKERS,
                          collate_fn=collate_fn,
                          pin_memory=True,
                          drop_last=True)
            for x in stages
        }

        return datasets, dataloaders, data_samplers

    def _parser_model(self, dictionary):
        """Instantiate ``cfg.USE_MODEL`` and move it to CUDA (SyncBN if distributed)."""
        *model_mod_str_parts, model_class_str = self.cfg.USE_MODEL.split(".")
        model_class = getattr(import_module(".".join(model_mod_str_parts)),
                              model_class_str)
        model = model_class(dictionary=dictionary)

        if self.cfg.distributed:
            model = convert_syncbn_model(model).cuda()
        else:
            model = model.cuda()

        return model

    @staticmethod
    def _to_cuda(imgs, targets):
        """Move one batch onto the current CUDA device.

        ``imgs`` may be a tensor or a list of tensors. ``targets`` may be a
        tensor, a list of tensors, or a list of dicts of tensors.

        Returns:
            ``(imgs, targets)`` with the same container structure.
        """
        if isinstance(imgs, list):
            imgs = [img.cuda() for img in imgs]
        else:
            imgs = imgs.cuda()

        if isinstance(targets, list):
            if isinstance(targets[0], torch.Tensor):
                targets = [t.cuda() for t in targets]
            else:
                targets = [{k: v.cuda() for k, v in t.items()} for t in targets]
        else:
            targets = targets.cuda()

        return imgs, targets

    def run(self):
        """Set up everything from the config and run the epoch loop."""
        self.cfg = init_distributed(self.cfg)
        cfg = self.cfg

        dictionary = self._parser_dict()
        _, dataloaders, _ = self._parser_datasets(dictionary)
        model_ft = self._parser_model(dictionary)

        test_only = False
        if test_only:
            '''
            confmat = evaluate(model_ft, data_loader_test, device=device, num_classes=num_classes)
            print(confmat)
            '''
            return

        # Scale learning rate based on global batch size
        # cfg.INIT_LR = cfg.INIT_LR * float(self.batch_size_all) / 256
        optimizer_ft = parser_optimizer(cfg, model_ft)
        lr_scheduler_ft = parser_lr_scheduler(cfg, optimizer_ft)

        # amp must wrap model/optimizer before the DDP wrapper is applied.
        model_ft, optimizer_ft = amp.initialize(model_ft,
                                                optimizer_ft,
                                                opt_level=cfg.APEX_LEVEL,
                                                verbosity=0)

        if cfg.distributed:
            model_ft = DistributedDataParallel(model_ft, delay_allreduce=True)

        if self.cfg.PRETRAIN_MODEL is not None:
            # NOTE(review): self.ckpts is only created on local rank 0 in
            # __init__, so this raises AttributeError on other ranks when a
            # pretrained model is configured in distributed mode — confirm.
            if self.cfg.RESUME:
                self.start_epoch = self.ckpts.load_checkpoint(
                    self.cfg.PRETRAIN_MODEL, model_ft, optimizer_ft,
                    lr_scheduler_ft, amp)
            else:
                self.ckpts.load_checkpoint(self.cfg.PRETRAIN_MODEL, model_ft)

        # Net-graph visualisation is deliberately disabled (``and False``).
        if self.cfg.TENSORBOARD_MODEL and False:
            self.tb_writer.add_graph(model_ft, (model_ft.dummy_input.cuda(), ))

        # One optimizer step per training batch. Previously this iterated the
        # whole dataset and summed len() of every sample, which loaded every
        # item and produced a value unrelated to the step count.
        self.n_steps_per_epoch = len(dataloaders['train'])

        best_acc = 0.0
        for epoch in range(self.start_epoch + 1, self.cfg.N_MAX_EPOCHS):
            if cfg.distributed:
                # Reshuffle differently each epoch across ranks.
                dataloaders['train'].sampler.set_epoch(epoch)
            self.train_epoch(epoch, model_ft, dataloaders['train'],
                             optimizer_ft, lr_scheduler_ft, None)
            lr_scheduler_ft.step()

            if self.cfg.DATASET.VAL:
                acc = self.val_epoch(epoch, model_ft, dataloaders['val'],
                                     optimizer_ft, lr_scheduler_ft)

                if cfg.local_rank == 0 and best_acc < acc:
                    self.ckpts.autosave_checkpoint(model_ft, epoch, 'best',
                                                   optimizer_ft,
                                                   lr_scheduler_ft, amp)
                    best_acc = acc
                    # A new best checkpoint replaces this epoch's autosave.
                    continue

            if not epoch % cfg.N_EPOCHS_TO_SAVE_MODEL and cfg.local_rank == 0:
                self.ckpts.autosave_checkpoint(model_ft, epoch, 'autosave',
                                               optimizer_ft,
                                               lr_scheduler_ft, amp)

        if cfg.local_rank == 0:
            self.tb_writer.close()

        # Every rank must tear down its own process group (previously rank 0
        # never did); skip entirely when not running distributed.
        if cfg.distributed and dist.is_initialized():
            dist.destroy_process_group()
        torch.cuda.empty_cache()

    def train_epoch(self,
                    epoch,
                    model,
                    dataloader,
                    optimizer,
                    lr_scheduler,
                    grad_normalizer=None,
                    prefix="train"):
        """Train ``model`` for one epoch with amp loss scaling.

        Args:
            epoch: current epoch index (logging only).
            model: called as ``model(imgs, targets, prefix)``; returns either a
                loss dict or a ``(losses, performances)`` tuple.
            dataloader: yields ``(imgs, targets)`` pairs.
            optimizer: optimizer stepped once per batch.
            lr_scheduler: accepted for interface compatibility; unused here.
            grad_normalizer: accepted for interface compatibility; unused here.
            prefix: forwarded to the model and used in TensorBoard tags.
        """
        model.train()

        _timer = Timer()
        lossLogger = MetricLogger(delimiter="  ")
        performanceLogger = MetricLogger(delimiter="  ")

        for i, (imgs, targets) in enumerate(dataloader):
            _timer.tic()
            optimizer.zero_grad()

            imgs, targets = self._to_cuda(imgs, targets)

            out = model(imgs, targets, prefix)

            if isinstance(out, tuple):
                losses, performances = out
            else:
                losses, performances = out, None

            self.n_iters_elapsed += 1

            with amp.scale_loss(losses["loss"], optimizer) as scaled_loss:
                scaled_loss.backward()

            optimizer.step()

            torch.cuda.synchronize()
            _timer.toc()

            if (i + 1) % self.cfg.N_ITERS_TO_DISPLAY_STATUS == 0:
                if self.cfg.distributed:
                    # Reduce losses over all GPUs for logging purposes.
                    loss_dict_reduced = reduce_dict(losses)
                    lossLogger.update(**loss_dict_reduced)
                else:
                    lossLogger.update(**losses)

                if performances is not None and all(performances):
                    if self.cfg.distributed:
                        # Reduce performances over all GPUs for logging.
                        performance_dict_reduced = reduce_dict(performances)
                        performanceLogger.update(**performance_dict_reduced)
                    else:
                        performanceLogger.update(**performances)

                if self.cfg.local_rank == 0:
                    # NOTE(review): if Timer.total_time accumulates over the
                    # whole epoch, this ips figure shrinks as the epoch
                    # progresses — confirm Timer semantics.
                    template = "[epoch {}/{}, iter {}, lr {}] Total train loss: {:.4f} " "(ips = {:.2f})\n" "{}"
                    logger.info(
                        template.format(
                            epoch,
                            self.cfg.N_MAX_EPOCHS,
                            i,
                            round(get_current_lr(optimizer), 6),
                            lossLogger.meters["loss"].value,
                            self.batch_size *
                            self.cfg.N_ITERS_TO_DISPLAY_STATUS /
                            _timer.total_time,
                            "\n".join([
                                "{}: {:.4f}".format(n, l.value)
                                for n, l in lossLogger.meters.items()
                                if n != "loss"
                            ]),
                        ))

            del imgs, targets

        if self.cfg.TENSORBOARD and self.cfg.local_rank == 0:
            for n, l in lossLogger.meters.items():
                self.tb_writer.add_scalar(f"loss/{prefix}_{n}", l.global_avg,
                                          epoch)
            for k, v in performanceLogger.meters.items():
                self.tb_writer.add_scalar(f"performance/{prefix}_{k}",
                                          v.global_avg, epoch)

        # Weight histograms are deliberately disabled (``and False``).
        if self.cfg.TENSORBOARD_WEIGHT and False:
            for name, param in model.named_parameters():
                layer, attr = os.path.splitext(name)
                attr = attr[1:]
                self.tb_writer.add_histogram("{}/{}".format(layer, attr),
                                             param, epoch)

    @torch.no_grad()
    def val_epoch(self,
                  epoch,
                  model,
                  dataloader,
                  optimizer=None,
                  lr_scheduler=None,
                  prefix="val"):
        """Evaluate ``model`` and return the ``'performance'`` meter's global average.

        ``optimizer`` and ``lr_scheduler`` are accepted for interface
        compatibility and are unused. Logging happens on local rank 0 only.
        """
        model.eval()

        lossLogger = MetricLogger(delimiter="  ")
        performanceLogger = MetricLogger(delimiter="  ")

        for imgs, targets in dataloader:
            imgs, targets = self._to_cuda(imgs, targets)

            losses, performances = model(imgs, targets, prefix)

            if self.cfg.distributed:
                # Reduce losses over all GPUs for logging purposes.
                loss_dict_reduced = reduce_dict(losses)
                lossLogger.update(**loss_dict_reduced)
            else:
                lossLogger.update(**losses)

            if performances is not None and all(performances):
                if self.cfg.distributed:
                    # Reduce performances over all GPUs for logging purposes.
                    performance_dict_reduced = reduce_dict(performances)
                    performanceLogger.update(**performance_dict_reduced)
                else:
                    performanceLogger.update(**performances)

            del imgs, targets

        if self.cfg.TENSORBOARD and self.cfg.local_rank == 0:
            for n, l in lossLogger.meters.items():
                self.tb_writer.add_scalar(f"loss/{prefix}_{n}", l.global_avg,
                                          epoch)
            for k, v in performanceLogger.meters.items():
                self.tb_writer.add_scalar(f"performance/{prefix}_{k}",
                                          v.global_avg, epoch)

        if self.cfg.local_rank == 0:
            template = "[epoch {}] Total {} loss : {:.4f} " "\n" "{}"
            logger.info(
                template.format(
                    epoch,
                    prefix,
                    lossLogger.meters["loss"].global_avg,
                    "\n".join([
                        "{}: {:.4f}".format(n, l.global_avg)
                        for n, l in lossLogger.meters.items() if n != "loss"
                    ]),
                ))

            perf_log_str = f"\n------------ Performances ({prefix}) ----------\n"
            perf_log_str += "".join(
                "{:}: {:.4f}\n".format(k, v.global_avg)
                for k, v in performanceLogger.meters.items())
            perf_log_str += "------------------------------------\n"
            logger.info(perf_log_str)

        return performanceLogger.meters['performance'].global_avg