Ejemplo n.º 1
0
    def __init__(self,
                 misc_args,
                 log_period=20,
                 number_val_iter=1,
                 tensorboard_logger=None):
        """Set up smoothed loss/metric trackers, timers and tensorboard hooks.

        Args:
            misc_args: namespace of misc options (kept for later logging).
            log_period: output logging period in SGD iterations.
            number_val_iter: number of validation iterations to run.
            tensorboard_logger: optional tensorboard writer.
        """
        self.misc_args = misc_args
        # How often (in SGD iterations) stats are emitted.
        self.LOG_PERIOD = log_period
        self.tblogger = tensorboard_logger
        # Stats keys that are never forwarded to tensorboard.
        self.tb_ignored_keys = ['iter', 'eta']
        self.iter_timer = Timer()
        # Median-filter window size for all smoothed scalars.
        self.WIN_SZ = 20

        def _new_window():
            # Factory: each tracked key gets its own smoothing window.
            return SmoothedValue(self.WIN_SZ)

        self.smoothed_losses = defaultdict(_new_window)
        self.smoothed_metrics = defaultdict(_new_window)
        self.smoothed_total_loss = SmoothedValue(self.WIN_SZ)

        # Per-inner-iteration accumulators for args.iter_size support.
        self.inner_total_loss = []
        self.inner_losses = defaultdict(list)
        if cfg.FPN.FPN_ON:
            # FPN adds per-level RPN losses that are aggregated separately.
            self.inner_loss_rpn_cls = []
            self.inner_loss_rpn_bbox = []
        self.inner_metrics = defaultdict(list)
        self.number_val_iter = number_val_iter
Ejemplo n.º 2
0
    def __init__(self, misc_args, log_period=20, tensorboard_logger=None):
        """Initialize smoothed loss/metric trackers and an iteration timer.

        Args:
            misc_args: namespace of misc options (kept for later logging).
            log_period: output logging period in SGD iterations.
            tensorboard_logger: optional tensorboard writer; keys listed in
                ``tb_ignored_keys`` are never forwarded to it.
        """
        # Output logging period in SGD iterations
        self.misc_args = misc_args
        self.LOG_PERIOD = log_period
        self.tblogger = tensorboard_logger
        self.tb_ignored_keys = ['iter', 'eta']
        self.iter_timer = Timer()
        # Window size for smoothing tracked values (with median filtering)
        self.WIN_SZ = 20

        def create_smoothed_value():
            # Factory so each defaultdict key gets its own smoothing window.
            return SmoothedValue(self.WIN_SZ)

        self.smoothed_losses = defaultdict(create_smoothed_value)
        self.smoothed_metrics = defaultdict(create_smoothed_value)
        self.smoothed_total_loss = SmoothedValue(self.WIN_SZ)
 def __init__(self, metrics, losses, solver_max_iters):
     """Track the given loss/metric keys with median-filtered windows.

     Args:
         metrics: list of metric keys to track.
         losses: list of loss keys to track (these form the total loss).
         solver_max_iters: total SGD iterations (used for ETA computation).
     """
     self.solver_max_iters = solver_max_iters
     # Window size for smoothing tracked values (with median filtering)
     self.win_sz = 20
     # Output logging period in SGD iterations
     self.log_period = 20
     # One smoothing window per tracked loss/metric key.
     self.smoothed_losses_and_metrics = {
         key: SmoothedValue(self.win_sz)
         for key in losses + metrics
     }
     # Most recent raw value per key; 0 until the first update.
     self.losses_and_metrics = {key: 0 for key in losses + metrics}
     self.smoothed_total_loss = SmoothedValue(self.win_sz)
     self.smoothed_mb_qsize = SmoothedValue(self.win_sz)
     # NaN until the first iteration's stats are recorded.
     self.iter_total_loss = np.nan
     self.iter_timer = Timer()
     self.metrics = metrics
     self.losses = losses
Ejemplo n.º 4
0
 def __init__(self, model):
     """Track smoothed training losses/metrics for ``model``.

     Assumes ``model.losses`` and ``model.metrics`` are lists of key
     strings — TODO confirm against the model class definition.
     """
     # Window size for smoothing tracked values (with median filtering)
     self.WIN_SZ = 20
     # Output logging period in SGD iterations
     self.LOG_PERIOD = 20
     # One smoothing window per tracked loss/metric key.
     self.smoothed_losses_and_metrics = {
         key: SmoothedValue(self.WIN_SZ)
         for key in model.losses + model.metrics
     }
     # Most recent raw value per key; 0 until the first update.
     self.losses_and_metrics = {
         key: 0
         for key in model.losses + model.metrics
     }
     self.smoothed_total_loss = SmoothedValue(self.WIN_SZ)
     # Smoothed minibatch-queue size of the model's data loader.
     self.smoothed_mb_qsize = SmoothedValue(self.WIN_SZ)
     self.iter_total_loss = np.nan
     self.iter_timer = Timer()
     self.model = model
Ejemplo n.º 5
0
    def __init__(self, misc_args, tblogger):
        """Initialize loss smoothing, iteration timing and tensorboard output.

        Args:
            misc_args: misc options; ``disp_interval`` sets the log period.
            tblogger: tensorboard writer (keys in ``tb_ignored_keys`` are
                never forwarded to it).
        """
        self.misc_args = misc_args
        # Output logging period in SGD iterations.
        self.LOG_PERIOD = misc_args.disp_interval
        self.tblogger = tblogger
        self.tb_ignored_keys = ['iter', 'eta']
        self.iter_timer = Timer()
        # Median-filter window size for smoothed scalars.
        self.WIN_SZ = 20

        def _fresh_window():
            # Each tracked loss gets its own smoothing window.
            return SmoothedValue(self.WIN_SZ)

        self.smoothed_losses = defaultdict(_fresh_window)
        self.smoothed_total_loss = SmoothedValue(self.WIN_SZ)
        # Per-inner-iteration accumulators for args.iter_size support.
        self.inner_total_loss = []
        self.inner_losses = defaultdict(list)
Ejemplo n.º 6
0
class TrainingStats(object):
    """Track vital training statistics."""

    def __init__(self, model):
        # Window size for smoothing tracked values (with median filtering)
        self.WIN_SZ = 20
        # Output logging period in SGD iterations
        self.LOG_PERIOD = 20
        # One smoothing window per loss/metric blob name on the model.
        self.smoothed_losses_and_metrics = {
            key: SmoothedValue(self.WIN_SZ)
            for key in model.losses + model.metrics
        }
        # Most recent raw value per key; 0 until the first update.
        self.losses_and_metrics = {
            key: 0
            for key in model.losses + model.metrics
        }
        self.smoothed_total_loss = SmoothedValue(self.WIN_SZ)
        # Smoothed minibatch-queue size of the model's ROI data loader.
        self.smoothed_mb_qsize = SmoothedValue(self.WIN_SZ)
        self.iter_total_loss = np.nan
        self.iter_timer = Timer()
        self.model = model

    def IterTic(self):
        """Start timing the current iteration."""
        self.iter_timer.tic()

    def IterToc(self):
        """Stop timing; return this iteration's elapsed time (not averaged)."""
        return self.iter_timer.toc(average=False)

    def ResetIterTimer(self):
        """Clear accumulated timing state."""
        self.iter_timer.reset()

    def UpdateIterStats(self):
        """Update tracked iteration statistics."""
        # Loss blobs are reduced with nu.sum_multi_gpu_blob; metric blobs
        # with nu.average_multi_gpu_blob.
        for k in list(self.losses_and_metrics.keys()):
            if k in self.model.losses:
                self.losses_and_metrics[k] = nu.sum_multi_gpu_blob(k)
            else:
                self.losses_and_metrics[k] = nu.average_multi_gpu_blob(k)
        for k, v in list(self.smoothed_losses_and_metrics.items()):
            v.AddValue(self.losses_and_metrics[k])
        # Total loss sums the loss keys only; metrics are excluded.
        self.iter_total_loss = np.sum(
            np.array([self.losses_and_metrics[k] for k in self.model.losses])
        )
        self.smoothed_total_loss.AddValue(self.iter_total_loss)
        self.smoothed_mb_qsize.AddValue(
            self.model.roi_data_loader._minibatch_queue.qsize()
        )

    def LogIterStats(self, cur_iter, lr):
        """Log the tracked statistics."""
        # Log every LOG_PERIOD iterations and always on the final iteration.
        if (cur_iter % self.LOG_PERIOD == 0 or
                cur_iter == cfg.SOLVER.MAX_ITER - 1):
            stats = self.GetStats(cur_iter, lr)
            log_json_stats(stats)

    def GetStats(self, cur_iter, lr):
        """Assemble the current stats dict (lr, timing, ETA, smoothed values)."""
        eta_seconds = self.iter_timer.average_time * (
            cfg.SOLVER.MAX_ITER - cur_iter
        )
        eta = str(datetime.timedelta(seconds=int(eta_seconds)))
        mem_stats = c2_py_utils.GetGPUMemoryUsageStats()
        mem_usage = np.max(mem_stats['max_by_gpu'][:cfg.NUM_GPUS])
        stats = dict(
            iter=cur_iter,
            lr=float(lr),
            time=self.iter_timer.average_time,
            loss=self.smoothed_total_loss.GetMedianValue(),
            eta=eta,
            mb_qsize=int(
                np.round(self.smoothed_mb_qsize.GetMedianValue())
            ),
            mem=int(np.ceil(mem_usage / 1024 / 1024))  # bytes -> MB
        )
        for k, v in list(self.smoothed_losses_and_metrics.items()):
            stats[k] = v.GetMedianValue()
        return stats
 # NOTE(review): scrape fragment — a module-level def that references
 # `self`, so calling it here raises NameError; presumably this was the
 # nested factory inside a class __init__. Left byte-identical.
 def create_smoothed_value():
     return SmoothedValue(self.WIN_SZ)
class TrainingStats(object):
    """Track vital training statistics."""

    def __init__(self, misc_args, log_period=20, tensorboard_logger=None):
        # Output logging period in SGD iterations
        self.misc_args = misc_args
        self.LOG_PERIOD = log_period
        self.tblogger = tensorboard_logger
        # Stats keys that are never forwarded to tensorboard.
        self.tb_ignored_keys = ['iter', 'eta']
        self.iter_timer = Timer()
        # Window size for smoothing tracked values (with median filtering)
        self.WIN_SZ = 20
        def create_smoothed_value():
            # Factory so each defaultdict key gets its own smoothing window.
            return SmoothedValue(self.WIN_SZ)
        self.smoothed_losses = defaultdict(create_smoothed_value)
        self.smoothed_metrics = defaultdict(create_smoothed_value)
        self.smoothed_total_loss = SmoothedValue(self.WIN_SZ)
        # For the support of args.iter_size
        self.inner_total_loss = []
        self.inner_losses = defaultdict(list)
        if cfg.FPN.FPN_ON:
            # FPN adds per-level RPN losses that are aggregated separately.
            self.inner_loss_rpn_cls = []
            self.inner_loss_rpn_bbox = []
        self.inner_metrics = defaultdict(list)

    def IterTic(self):
        """Start timing the current iteration."""
        self.iter_timer.tic()

    def IterToc(self):
        """Stop timing; return this iteration's elapsed time (not averaged)."""
        return self.iter_timer.toc(average=False)

    def ResetIterTimer(self):
        """Clear accumulated timing state."""
        self.iter_timer.reset()

    def UpdateIterStats(self, model_out, inner_iter=None):
        """Update tracked iteration statistics."""
        if inner_iter is not None and self.misc_args.iter_size > 1:
            # For the case of using args.iter_size > 1
            return self._UpdateIterStats_inner(model_out, inner_iter)

        # Following code is saved for compatability of train_net.py and iter_size==1
        total_loss = 0
        if cfg.FPN.FPN_ON:
            loss_rpn_cls_data = 0
            loss_rpn_bbox_data = 0

        for k, loss in model_out['losses'].items():
            assert loss.shape[0] == cfg.NUM_GPUS
            # Average the per-GPU loss replicas into one value.
            loss = loss.mean(dim=0)
            total_loss += loss
            # NOTE(review): legacy (pre-0.4) PyTorch scalar extraction;
            # modern torch uses loss.item() — confirm target torch version.
            loss_data = loss.data[0]
            # Write the reduced loss back so callers backprop the mean.
            model_out['losses'][k] = loss
            if cfg.FPN.FPN_ON:
                # Aggregate per-FPN-level RPN losses into two totals.
                if k.startswith('loss_rpn_cls_'):
                    loss_rpn_cls_data += loss_data
                elif k.startswith('loss_rpn_bbox_'):
                    loss_rpn_bbox_data += loss_data
            self.smoothed_losses[k].AddValue(loss_data)

        model_out['total_loss'] = total_loss  # Add the total loss for back propagation
        self.smoothed_total_loss.AddValue(total_loss.data[0])
        if cfg.FPN.FPN_ON:
            self.smoothed_losses['loss_rpn_cls'].AddValue(loss_rpn_cls_data)
            self.smoothed_losses['loss_rpn_bbox'].AddValue(loss_rpn_bbox_data)

        for k, metric in model_out['metrics'].items():
            metric = metric.mean(dim=0)
            self.smoothed_metrics[k].AddValue(metric.data[0])

    def _UpdateIterStats_inner(self, model_out, inner_iter):
        """Update tracked iteration statistics for the case of iter_size > 1"""
        assert inner_iter < self.misc_args.iter_size

        total_loss = 0
        if cfg.FPN.FPN_ON:
            loss_rpn_cls_data = 0
            loss_rpn_bbox_data = 0

        # First inner step of an outer iteration: reset all accumulators.
        if inner_iter == 0:
            self.inner_total_loss = []
            for k in model_out['losses']:
                self.inner_losses[k] = []
            if cfg.FPN.FPN_ON:
                self.inner_loss_rpn_cls = []
                self.inner_loss_rpn_bbox = []
            for k in model_out['metrics']:
                self.inner_metrics[k] = []

        for k, loss in model_out['losses'].items():
            assert loss.shape[0] == cfg.NUM_GPUS
            loss = loss.mean(dim=0)
            total_loss += loss
            # NOTE(review): legacy (pre-0.4) PyTorch scalar extraction.
            loss_data = loss.data[0]

            model_out['losses'][k] = loss
            if cfg.FPN.FPN_ON:
                if k.startswith('loss_rpn_cls_'):
                    loss_rpn_cls_data += loss_data
                elif k.startswith('loss_rpn_bbox_'):
                    loss_rpn_bbox_data += loss_data

            self.inner_losses[k].append(loss_data)
            # Last inner step: fold the accumulated values into the
            # smoothed trackers.
            if inner_iter == (self.misc_args.iter_size - 1):
                loss_data = self._mean_and_reset_inner_list('inner_losses', k)
                self.smoothed_losses[k].AddValue(loss_data)

        model_out['total_loss'] = total_loss  # Add the total loss for back propagation
        total_loss_data = total_loss.data[0]
        self.inner_total_loss.append(total_loss_data)
        if cfg.FPN.FPN_ON:
            self.inner_loss_rpn_cls.append(loss_rpn_cls_data)
            self.inner_loss_rpn_bbox.append(loss_rpn_bbox_data)
        if inner_iter == (self.misc_args.iter_size - 1):
            total_loss_data = self._mean_and_reset_inner_list('inner_total_loss')
            self.smoothed_total_loss.AddValue(total_loss_data)
            if cfg.FPN.FPN_ON:
                loss_rpn_cls_data = self._mean_and_reset_inner_list('inner_loss_rpn_cls')
                loss_rpn_bbox_data = self._mean_and_reset_inner_list('inner_loss_rpn_bbox')
                self.smoothed_losses['loss_rpn_cls'].AddValue(loss_rpn_cls_data)
                self.smoothed_losses['loss_rpn_bbox'].AddValue(loss_rpn_bbox_data)

        for k, metric in model_out['metrics'].items():
            metric = metric.mean(dim=0)
            metric_data = metric.data[0]
            self.inner_metrics[k].append(metric_data)
            if inner_iter == (self.misc_args.iter_size - 1):
                metric_data = self._mean_and_reset_inner_list('inner_metrics', k)
                self.smoothed_metrics[k].AddValue(metric_data)

    def _mean_and_reset_inner_list(self, attr_name, key=None):
        """Take the mean and reset list empty"""
        # Mean is taken over iter_size, not len(list): assumes the list got
        # exactly one append per inner iteration.
        if key:
            mean_val = sum(getattr(self, attr_name)[key]) / self.misc_args.iter_size
            getattr(self, attr_name)[key] = []
        else:
            mean_val = sum(getattr(self, attr_name)) / self.misc_args.iter_size
            setattr(self, attr_name, [])
        return mean_val

    def LogIterStats(self, cur_iter, lr):
        """Log the tracked statistics."""
        # Log every LOG_PERIOD iterations and always on the final iteration.
        if (cur_iter % self.LOG_PERIOD == 0 or
                cur_iter == cfg.SOLVER.MAX_ITER - 1):
            stats = self.GetStats(cur_iter, lr)
            log_stats(stats, self.misc_args)
            if self.tblogger:
                self.tb_log_stats(stats, cur_iter)

    def tb_log_stats(self, stats, cur_iter):
        """Log the tracked statistics to tensorboard"""
        for k in stats:
            if k not in self.tb_ignored_keys:
                v = stats[k]
                if isinstance(v, dict):
                    # Nested stats groups are flattened recursively; only
                    # leaf scalars reach the writer.
                    self.tb_log_stats(v, cur_iter)
                else:
                    self.tblogger.add_scalar(k, v, cur_iter)

    def GetStats(self, cur_iter, lr):
        """Assemble the nested stats dict (timings, ETA, grouped losses)."""
        eta_seconds = self.iter_timer.average_time * (
            cfg.SOLVER.MAX_ITER - cur_iter
        )
        eta = str(datetime.timedelta(seconds=int(eta_seconds)))
        stats = OrderedDict(
            iter=cur_iter + 1,  # 1-indexed
            time=self.iter_timer.average_time,
            eta=eta,
            loss=self.smoothed_total_loss.GetMedianValue(),
            lr=lr,
        )
        stats['metrics'] = OrderedDict()
        for k in sorted(self.smoothed_metrics):
            stats['metrics'][k] = self.smoothed_metrics[k].GetMedianValue()

        # Classify loss keys by underscore-token count: 2 -> head losses
        # (e.g. 'loss_cls'), 3 -> RPN ('loss_rpn_cls'), 4 -> per-FPN-level
        # RPN losses; anything else is a programming error.
        head_losses = []
        rpn_losses = []
        rpn_fpn_cls_losses = []
        rpn_fpn_bbox_losses = []
        for k, v in self.smoothed_losses.items():
            toks = k.split('_')
            if len(toks) == 2:
                head_losses.append((k, v.GetMedianValue()))
            elif len(toks) == 3:
                rpn_losses.append((k, v.GetMedianValue()))
            elif len(toks) == 4 and toks[2] == 'cls':
                rpn_fpn_cls_losses.append((k, v.GetMedianValue()))
            elif len(toks) == 4 and toks[2] == 'bbox':
                rpn_fpn_bbox_losses.append((k, v.GetMedianValue()))
            else:
                raise ValueError("Unexpected loss key: %s" % k)
        stats['head_losses'] = OrderedDict(head_losses)
        stats['rpn_losses'] = OrderedDict(rpn_losses)
        stats['rpn_fpn_cls_losses'] = OrderedDict(rpn_fpn_cls_losses)
        stats['rpn_fpn_bbox_losses'] = OrderedDict(rpn_fpn_bbox_losses)

        return stats
Ejemplo n.º 9
0
class TrainingStats(object):
    """Track vital training statistics."""
    def __init__(self, misc_args, log_period=20, tensorboard_logger=None):
        # Output logging period in SGD iterations
        self.misc_args = misc_args
        self.LOG_PERIOD = log_period
        self.tblogger = tensorboard_logger
        # Stats keys that are never forwarded to tensorboard.
        self.tb_ignored_keys = ['iter', 'eta']
        self.iter_timer = Timer()
        # Window size for smoothing tracked values (with median filtering)
        self.WIN_SZ = 20

        def create_smoothed_value():
            # Factory so each defaultdict key gets its own smoothing window.
            return SmoothedValue(self.WIN_SZ)

        self.smoothed_losses = defaultdict(create_smoothed_value)
        self.smoothed_metrics = defaultdict(create_smoothed_value)
        self.smoothed_total_loss = SmoothedValue(self.WIN_SZ)

    def IterTic(self):
        """Start timing the current iteration."""
        self.iter_timer.tic()

    def IterToc(self):
        """Stop timing; return this iteration's elapsed time (not averaged)."""
        return self.iter_timer.toc(average=False)

    def ResetIterTimer(self):
        """Clear accumulated timing state."""
        self.iter_timer.reset()

    def UpdateIterStats(self, model_out):
        """Update tracked iteration statistics.

        Reduces each per-GPU loss/metric tensor to its mean, writes the
        reduced tensors back into ``model_out`` (plus ``'total_loss'`` for
        backprop) and feeds the raw scalars into the smoothing windows.
        """
        total_loss = 0
        # BUGFIX: initialize the RPN accumulators unconditionally. The loop
        # below accumulates into them for any key starting with
        # 'loss_rpn_cls'/'loss_rpn_bbox' regardless of cfg.FPN.FPN_ON, so
        # with FPN off and a plain RPN this previously raised NameError.
        # They are only consumed (AddValue) when FPN is on, so FPN-on
        # behavior is unchanged.
        loss_rpn_cls_value = 0
        loss_rpn_bbox_value = 0

        for k, loss in model_out['losses'].items():
            assert loss.shape[0] == cfg.NUM_GPUS
            # Average the per-GPU loss replicas into one value.
            loss = loss.mean(dim=0)
            total_loss += loss
            # NOTE(review): legacy (pre-0.4) PyTorch scalar extraction;
            # modern torch uses loss.item() — confirm target torch version.
            loss_data = loss.data[0]
            self.smoothed_losses[k].AddValue(loss_data)
            # Write the reduced loss back so callers backprop the mean.
            model_out['losses'][k] = loss
            if k.startswith('loss_rpn_cls'):
                loss_rpn_cls_value += loss_data
            elif k.startswith('loss_rpn_bbox'):
                loss_rpn_bbox_value += loss_data

        self.smoothed_total_loss.AddValue(total_loss.data[0])
        # Total loss is used by the caller for back propagation.
        model_out['total_loss'] = total_loss
        if cfg.FPN.FPN_ON:
            self.smoothed_losses['loss_rpn_cls'].AddValue(loss_rpn_cls_value)
            self.smoothed_losses['loss_rpn_bbox'].AddValue(loss_rpn_bbox_value)

        for k, metric in model_out['metrics'].items():
            metric = metric.mean(dim=0)
            self.smoothed_metrics[k].AddValue(metric.data[0])
            model_out['metrics'][k] = metric

    def LogIterStats(self, cur_iter, lr):
        """Log the tracked statistics."""
        # Log every LOG_PERIOD iterations and always on the final iteration.
        if (cur_iter % self.LOG_PERIOD == 0
                or cur_iter == cfg.SOLVER.MAX_ITER - 1):
            stats = self.GetStats(cur_iter, lr)
            log_stats(stats, self.misc_args)
            if self.tblogger:
                self.tb_log_stats(stats, cur_iter)

    def tb_log_stats(self, stats, cur_iter):
        """Log the tracked statistics to tensorboard"""
        for k in stats:
            if k not in self.tb_ignored_keys:
                v = stats[k]
                if isinstance(v, dict):
                    # Nested stats groups are flattened recursively; only
                    # leaf scalars reach the writer.
                    self.tb_log_stats(v, cur_iter)
                else:
                    self.tblogger.add_scalar(k, v, cur_iter)

    def GetStats(self, cur_iter, lr):
        """Assemble the nested stats dict (timings, ETA, grouped losses)."""
        eta_seconds = self.iter_timer.average_time * (cfg.SOLVER.MAX_ITER -
                                                      cur_iter)
        eta = str(datetime.timedelta(seconds=int(eta_seconds)))
        stats = OrderedDict(
            iter=cur_iter + 1,  # 1-indexed
            time=self.iter_timer.average_time,
            eta=eta,
            loss=self.smoothed_total_loss.GetMedianValue(),
            lr=lr,
        )
        stats['metrics'] = OrderedDict()
        for k in sorted(self.smoothed_metrics):
            stats['metrics'][k] = self.smoothed_metrics[k].GetMedianValue()

        # Classify loss keys by underscore-token count: 2 -> head losses,
        # 3 -> RPN losses, 4 -> per-FPN-level RPN cls/bbox losses.
        head_losses = []
        rpn_losses = []
        rpn_fpn_cls_losses = []
        rpn_fpn_bbox_losses = []
        for k, v in self.smoothed_losses.items():
            toks = k.split('_')
            if len(toks) == 2:
                head_losses.append((k, v.GetMedianValue()))
            elif len(toks) == 3:
                rpn_losses.append((k, v.GetMedianValue()))
            elif len(toks) == 4 and toks[2] == 'cls':
                rpn_fpn_cls_losses.append((k, v.GetMedianValue()))
            elif len(toks) == 4 and toks[2] == 'bbox':
                rpn_fpn_bbox_losses.append((k, v.GetMedianValue()))
            else:
                raise ValueError("Unexpected loss key: %s" % k)
        stats['head_losses'] = OrderedDict(head_losses)
        stats['rpn_losses'] = OrderedDict(rpn_losses)
        stats['rpn_fpn_cls_losses'] = OrderedDict(rpn_fpn_cls_losses)
        stats['rpn_fpn_bbox_losses'] = OrderedDict(rpn_fpn_bbox_losses)

        return stats
class TrainingStats(object):
    """Track vital training statistics.

    Keeps a median-filtered window over every supplied loss and metric key,
    times each iteration, and emits JSON stats every ``log_period``
    iterations (and on the final iteration).
    """

    def __init__(self, metrics, losses, solver_max_iters):
        """Create trackers for all ``losses + metrics`` keys.

        Args:
            metrics: list of metric keys to track.
            losses: list of loss keys (summed to form the total loss).
            solver_max_iters: total SGD iterations (used for ETA/logging).
        """
        self.solver_max_iters = solver_max_iters
        # Median-filter window size for all smoothed scalars.
        self.win_sz = 20
        # Logging period in SGD iterations.
        self.log_period = 20
        tracked = losses + metrics
        self.smoothed_losses_and_metrics = {
            key: SmoothedValue(self.win_sz) for key in tracked
        }
        # Most recent raw value per key; 0 until the first update.
        self.losses_and_metrics = dict.fromkeys(tracked, 0)
        self.smoothed_total_loss = SmoothedValue(self.win_sz)
        self.smoothed_mb_qsize = SmoothedValue(self.win_sz)
        # NaN until the first iteration's stats are recorded.
        self.iter_total_loss = np.nan
        self.iter_timer = Timer()
        self.metrics = metrics
        self.losses = losses

    def IterTic(self):
        """Start the per-iteration timer."""
        self.iter_timer.tic()

    def IterToc(self):
        """Stop the timer; return this iteration's elapsed time."""
        return self.iter_timer.toc(average=False)

    def ResetIterTimer(self):
        """Clear accumulated timing state."""
        self.iter_timer.reset()

    def UpdateIterStats(self, losses_dict, metrics_dict):
        """Record one iteration's raw values and refresh the windows."""
        for key in self.losses_and_metrics:
            source = losses_dict if key in self.losses else metrics_dict
            self.losses_and_metrics[key] = source[key]

        for key, window in self.smoothed_losses_and_metrics.items():
            window.AddValue(self.losses_and_metrics[key])

        # Total loss sums the loss keys only; metrics are excluded.
        self.iter_total_loss = np.sum(
            np.array([self.losses_and_metrics[key] for key in self.losses]))
        self.smoothed_total_loss.AddValue(self.iter_total_loss)
        # Minibatch queue size is a fixed placeholder in this variant.
        self.smoothed_mb_qsize.AddValue(64)

    def LogIterStats(self, cur_iter, lr):
        """Emit JSON stats every log_period iterations and on the last one."""
        last_iter = self.solver_max_iters - 1
        if cur_iter % self.log_period == 0 or cur_iter == last_iter:
            log_json_stats(self.GetStats(cur_iter, lr))

    def GetStats(self, cur_iter, lr):
        """Build the stats dict: iteration, lr, timing, ETA and medians."""
        remaining = self.solver_max_iters - cur_iter
        eta_seconds = self.iter_timer.average_time * remaining
        stats = {
            'iter': cur_iter,
            'lr': "{:.6f}".format(float(lr)),
            'time': "{:.6f}".format(self.iter_timer.average_time),
            'loss': "{:.6f}".format(self.smoothed_total_loss.GetMedianValue()),
            'eta': str(datetime.timedelta(seconds=int(eta_seconds))),
        }
        for key, window in self.smoothed_losses_and_metrics.items():
            stats[key] = "{:.6f}".format(window.GetMedianValue())
        return stats
Ejemplo n.º 11
0
class TrainingStats(object):
    """Track vital training statistics."""
    def __init__(self, misc_args, tblogger):
        # Output logging period in SGD iterations
        self.misc_args = misc_args
        # pause()
        self.LOG_PERIOD = misc_args.disp_interval
        self.tblogger = tblogger
        # Stats keys that are never forwarded to tensorboard.
        self.tb_ignored_keys = ['iter', 'eta']
        self.iter_timer = Timer()
        # Window size for smoothing tracked values (with median filtering)
        self.WIN_SZ = 20

        def create_smoothed_value():
            # Factory so each defaultdict key gets its own smoothing window.
            return SmoothedValue(self.WIN_SZ)

        self.smoothed_losses = defaultdict(create_smoothed_value)
        self.smoothed_total_loss = SmoothedValue(self.WIN_SZ)
        # For the support of args.iter_size
        self.inner_total_loss = []
        self.inner_losses = defaultdict(list)

    def IterTic(self):
        """Start timing the current iteration."""
        self.iter_timer.tic()

    def IterToc(self):
        """Stop timing; return this iteration's elapsed time (not averaged)."""
        return self.iter_timer.toc(average=False)

    def ResetIterTimer(self):
        """Clear accumulated timing state."""
        self.iter_timer.reset()

    def UpdateIterStats(self, model_out, inner_iter=None):
        # Update tracked iteration statistics. Also caches
        # model_out['extras'] (or None) on self; GetStats reads self.extras,
        # so it must not be called before this method has run.

        if 'extras' in model_out:
            self.extras = model_out['extras']
        else:
            self.extras = None
        """Update tracked iteration statistics."""
        if inner_iter is not None and cfg.TRAIN.ITERATION_SIZE > 1:
            # For the case of using args.iter_size > 1
            return self._UpdateIterStats_inner(model_out, inner_iter)

        # Following code is saved for compatability of train_net.py and iter_size==1
        total_loss = 0

        for k, loss in model_out['losses'].items():
            assert loss.shape[0] == cfg.NUM_GPUS
            # Average the per-GPU loss replicas (keepdim preserves a 1-dim).
            loss = loss.mean(dim=0, keepdim=True)
            total_loss += loss
            # NOTE(review): legacy `.data[0]` here vs `.item()` in the inner
            # path below — inconsistent; confirm target torch version.
            loss_data = loss.data[0]
            # Write the reduced loss back so callers backprop the mean.
            model_out['losses'][k] = loss
            self.smoothed_losses[k].AddValue(loss_data)

        model_out[
            'total_loss'] = total_loss  # Add the total loss for back propagation
        self.smoothed_total_loss.AddValue(total_loss.data[0])

    def _UpdateIterStats_inner(self, model_out, inner_iter):
        """Update tracked iteration statistics for the case of iter_size > 1"""
        assert inner_iter < cfg.TRAIN.ITERATION_SIZE

        total_loss = 0

        # First inner step of an outer iteration: reset the accumulators.
        if inner_iter == 0:
            self.inner_total_loss = []
            for k in model_out['losses']:
                self.inner_losses[k] = []

        for k, loss in model_out['losses'].items():
            # Note: unlike the non-inner path, no per-GPU mean is taken here.
            total_loss += loss
            loss_data = loss.item()

            model_out['losses'][k] = loss

            self.inner_losses[k].append(loss_data)
            # Last inner step: fold accumulated values into the trackers.
            if inner_iter == (cfg.TRAIN.ITERATION_SIZE - 1):
                loss_data = self._mean_and_reset_inner_list('inner_losses', k)
                self.smoothed_losses[k].AddValue(loss_data)

        model_out[
            'total_loss'] = total_loss  # Add the total loss for back propagation
        total_loss_data = total_loss.item()
        self.inner_total_loss.append(total_loss_data)
        if inner_iter == (cfg.TRAIN.ITERATION_SIZE - 1):
            total_loss_data = self._mean_and_reset_inner_list(
                'inner_total_loss')
            self.smoothed_total_loss.AddValue(total_loss_data)

    def _mean_and_reset_inner_list(self, attr_name, key=None):
        """Take the mean and reset list empty"""
        # Mean is over ITERATION_SIZE, not len(list): assumes exactly one
        # append per inner iteration.
        if key:
            mean_val = sum(getattr(self,
                                   attr_name)[key]) / cfg.TRAIN.ITERATION_SIZE
            getattr(self, attr_name)[key] = []
        else:
            mean_val = sum(getattr(self, attr_name)) / cfg.TRAIN.ITERATION_SIZE
            setattr(self, attr_name, [])
        return mean_val

    def LogIterStats(self, cur_iter, lr):
        """Log the tracked statistics."""
        # Log every LOG_PERIOD iterations and always on the final iteration.
        if (cur_iter % self.LOG_PERIOD == 0
                or cur_iter == cfg.SOLVER.MAX_ITER - 1):
            stats = self.GetStats(cur_iter, lr)
            log_stats(stats, self.misc_args)
            if self.tblogger:
                self.tb_log_stats(stats, cur_iter)

    def tb_log_stats(self, stats, cur_iter):
        """Log the tracked statistics to tensorboard"""
        for k in stats:
            if k not in self.tb_ignored_keys:
                v = stats[k]
                if isinstance(v, dict):
                    # Nested stats groups are flattened recursively.
                    self.tb_log_stats(v, cur_iter)
                elif v is not None:
                    # None values (e.g. missing extras) are skipped.
                    self.tblogger.add_scalar(k, v, cur_iter)

    def GetStats(self, cur_iter, lr):
        """Assemble the stats dict; requires UpdateIterStats to have run
        at least once (reads self.extras)."""
        eta_seconds = self.iter_timer.average_time * (cfg.SOLVER.MAX_ITER -
                                                      cur_iter)
        eta = str(datetime.timedelta(seconds=int(eta_seconds)))
        stats = OrderedDict(
            iter=cur_iter + 1,  # 1-indexed
            time=self.iter_timer.average_time,
            eta=eta,
            loss=self.smoothed_total_loss.GetMedianValue(),
            lr=lr,
        )

        head_losses = []
        for k, v in self.smoothed_losses.items():
            head_losses.append((k, v.GetMedianValue()))
        stats['head_losses'] = OrderedDict(head_losses)
        stats['extras'] = self.extras
        return stats
Ejemplo n.º 12
0
class TrainingStats(object):
    """Track vital training statistics."""

    def __init__(self, misc_args, log_period=20, max_iter=cfg.GAN.SOLVER.MAX_ITER, tensorboard_logger=None):
        # NOTE(review): the max_iter default is evaluated once at import
        # time from cfg.GAN.SOLVER.MAX_ITER; later cfg changes are not seen.
        # Output logging period in SGD iterations
        self.max_iter =max_iter
        self.misc_args = misc_args
        self.LOG_PERIOD = log_period
        self.tblogger = tensorboard_logger
        # Stats keys that are never forwarded to tensorboard.
        self.tb_ignored_keys = ['iter', 'eta']

        self.iter_timer = Timer()
        # Window size for smoothing tracked values (with median filtering)
        self.WIN_SZ = 10 #20

        def create_smoothed_value():
            # Factory so each defaultdict key gets its own smoothing window.
            return SmoothedValue(self.WIN_SZ)

        self.smoothed_losses = defaultdict(create_smoothed_value)
        self.smoothed_total_loss = SmoothedValue(self.WIN_SZ)
        self.smoothed_metrics = defaultdict(create_smoothed_value)

    def IterTic(self):
        """Start timing the current iteration."""
        self.iter_timer.tic()

    def IterToc(self):
        """Stop timing; return this iteration's elapsed time (not averaged)."""
        return self.iter_timer.toc(average=False)

    def ResetIterTimer(self):
        """Clear accumulated timing state."""
        self.iter_timer.reset()

    def UpdateIterStats(self, out=None, out_add=None):
        """Update tracked iteration statistics."""
        # Single-output path: `out` only (e.g. one discriminator pass).
        # Paired path: `out` and `out_add` (e.g. real + fake passes) — losses
        # are summed into one total, metrics averaged across the pair.
        if out is not None:  # first trained on either real/fake images (then set flag)
            if out_add is None:
                total_loss = 0

                for k, loss in out['losses'].items():
                    assert loss.shape[0] == cfg.NUM_GPUS
                    # Average the per-GPU loss replicas (keepdim keeps 1-dim).
                    loss = loss.mean(dim=0, keepdim=True)
                    total_loss += loss
                    # NOTE(review): legacy (pre-0.4) PyTorch scalar
                    # extraction; modern torch would use loss.item().
                    loss_data = loss.data[0]
                    out['losses'][k] = loss
                    self.smoothed_losses[k].AddValue(loss_data)

                out['total_loss'] = total_loss  # Add the total loss for back propagation
                self.smoothed_total_loss.AddValue(total_loss.data[0])

                for k, metric in out['metrics'].items():
                    metric = metric.mean(dim=0, keepdim=True)
                    self.smoothed_metrics[k].AddValue(metric.data[0])

            else:
                total_loss = 0

                # Paired pass: both outputs must carry the same loss keys.
                for loss_key in out['losses']:

                    loss = out['losses'][loss_key]
                    loss_add = out_add['losses'][loss_key]

                    assert loss.shape[0] == cfg.NUM_GPUS
                    assert loss_add.shape[0] == cfg.NUM_GPUS

                    loss = loss.mean(dim=0, keepdim=True)
                    loss_add = loss_add.mean(dim=0, keepdim=True)

                    total_loss += loss
                    total_loss += loss_add

                    loss_data = loss.data[0]
                    loss_data_add = loss_add.data[0]

                    out['losses'][loss_key] = loss
                    out_add['losses'][loss_key] = loss_add

                    # Smooth the summed pair loss under a single key.
                    self.smoothed_losses[loss_key].AddValue(loss_data + loss_data_add)

                out['total_loss'] = total_loss  # Add the total loss for back propagation
                # Both outputs share the same combined total loss.
                out_add['total_loss'] = total_loss

                self.smoothed_total_loss.AddValue(total_loss.data[0])

                for metric_key in out['metrics']:
                    metric = out['metrics'][metric_key].mean(dim=0, keepdim=True)
                    metric_add = out_add['metrics'][metric_key].mean(dim=0, keepdim=True)

                    # Metrics are averaged (not summed) across the pair.
                    self.smoothed_metrics[metric_key].AddValue(0.5*(metric.data[0] + metric_add.data[0]))

    def LogIterStatsReal(self, cur_iter, lr):
        """Log the tracked statistics."""
        # Log every LOG_PERIOD iterations and always on the final iteration.
        if (cur_iter % self.LOG_PERIOD == 0 or
                cur_iter == self.max_iter - 1):
            stats = self.GetStats(cur_iter, lr)
            log_gan_stats(self.misc_args, self.max_iter, stats_dis=stats)
            if self.tblogger:
                self.tb_log_stats(stats, cur_iter)

    def tb_log_stats(self, stats, cur_iter):
        """Log the tracked statistics to tensorboard"""
        for k in stats:
            if k not in self.tb_ignored_keys:
                v = stats[k]
                if isinstance(v, dict):
                    # Nested stats groups are flattened recursively.
                    self.tb_log_stats(v, cur_iter)
                else:
                    if self.tblogger:
                        self.tblogger.add_scalar(k, v, cur_iter)

    def GetStats(self, cur_iter, lr):
        """Assemble the nested stats dict (timings, ETA, grouped losses)."""
        eta_seconds = self.iter_timer.average_time * (
            self.max_iter - cur_iter
        )
        eta = str(datetime.timedelta(seconds=int(eta_seconds)))
        stats = OrderedDict(
            iter=cur_iter + 1,  # 1-indexed
            time=self.iter_timer.average_time,
            eta=eta,
            loss=self.smoothed_total_loss.GetMedianValue(),
            lr=lr,
        )
        stats['metrics'] = OrderedDict()
        for k in sorted(self.smoothed_metrics):
            stats['metrics'][k] = self.smoothed_metrics[k].GetMedianValue()

        # Loss keys must be exactly two underscore tokens: '*_adv' keys go
        # to adv_loss, every other two-token key to head_losses; anything
        # else is a programming error.
        head_losses = []
        adv_loss = []

        for k, v in self.smoothed_losses.items():
            toks = k.split('_')
            if len(toks) == 2 and toks[1] == 'adv':
                adv_loss.append((k, v.GetMedianValue()))
            elif len(toks) == 2:
                head_losses.append((k, v.GetMedianValue()))
            else:
                raise ValueError("Unexpected loss key: %s" % k)

        stats['head_losses'] = OrderedDict(head_losses)
        stats['adv_loss'] = OrderedDict(adv_loss)

        return stats