Ejemplo n.º 1
0
    def _train_one_batch(self, x, y, optimizer, lr_scheduler, meters, criterions, end):
        """Run a single optimisation step on one mini-batch.

        Steps the LR schedule, runs forward/backward on ``x``/``y``, records
        the cross-rank loss, top-1 and top-5 accuracy in ``meters`` and
        applies ``optimizer`` after ``dist.average_gradient`` has processed
        the gradients.

        Args:
            x: input batch.
            y: target labels for ``x``.
            optimizer: optimizer stepped at the end of the call.
            lr_scheduler: scheduler stepped with ``self.cur_step``.
            meters: ``(top1_meter, top5_meter, loss_meter, data_time)``.
            criterions: sequence whose first element is the training loss.
            end: timestamp taken when the previous batch finished; used to
                measure data-loading time.
        """
        acc1_meter, acc5_meter, loss_meter, data_time = meters
        loss_fn = criterions[0]
        n_ranks = dist.get_world_size()

        # Step the schedule with the current index, then advance it.
        lr_scheduler.step(self.cur_step)
        self.cur_step += 1
        data_time.update(time.time() - end)

        self.model.zero_grad()
        logits = self.model(x)
        # Pre-divide by the world size so that an all_reduce that sums
        # across ranks yields the mean loss.
        loss = loss_fn(logits, y) / n_ranks

        acc1, acc5 = accuracy(logits, y, top_k=(1, 5))
        # Collectives must be issued in the same order on every rank.
        global_loss = dist.all_reduce(loss.clone())
        global_acc1 = dist.all_reduce(acc1.clone(), div=True)
        global_acc5 = dist.all_reduce(acc5.clone(), div=True)

        for meter, value in ((loss_meter, global_loss),
                             (acc1_meter, global_acc1),
                             (acc5_meter, global_acc5)):
            meter.update(value.item())

        loss.backward()
        dist.average_gradient(self.model.parameters())
        optimizer.step()
Ejemplo n.º 2
0
    def validate(self, val_loader, tb_logger=None):
        """Evaluate the model on ``val_loader`` and aggregate metrics over all ranks.

        Each rank accumulates batch-size-weighted loss / top-1 / top-5 over
        the batches it sees. The weighted sums and the sample counts are then
        all-reduced, so the reported numbers are averages over every
        evaluated sample. On the master rank, ``self.best_top1`` is updated
        and, if ``tb_logger`` is given, the metrics are written at
        ``self.cur_step``.

        Args:
            val_loader: iterable of ``(x, y)`` batches.
            tb_logger: optional logger exposing ``add_scalar(tag, value, step)``.
        """
        batch_time = AverageMeter(0)
        loss_meter = AverageMeter(0)
        top1_meter = AverageMeter(0)
        top5_meter = AverageMeter(0)

        self.model.eval()
        criterion = nn.CrossEntropyLoss()
        end = time.time()

        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(val_loader):
                x, y = x.cuda(), y.cuda()
                num = x.size(0)

                out = self.model(x)
                loss = criterion(out, y)
                top1, top5 = accuracy(out, y, top_k=(1, 5))

                # The batch size is passed as the weight; this assumes
                # AverageMeter.update(val, n) keeps an n-weighted running mean.
                loss_meter.update(loss.item(), num)
                top1_meter.update(top1.item(), num)
                top5_meter.update(top5.item(), num)

                batch_time.update(time.time() - end)
                end = time.time()

                if batch_idx % self.config.logging.print_freq == 0:
                    self._info(
                        'Test: [{0}/{1}]\tTime {batch_time.val:.3f} ({batch_time.avg:.3f})'
                        .format(batch_idx,
                                len(val_loader),
                                batch_time=batch_time))

        # Turn this rank's running means back into sums (mean * count) so they
        # can be summed across ranks. The results are read back from the same
        # tensors, i.e. dist.all_reduce is relied on to reduce in place.
        total_num = torch.tensor([loss_meter.count]).cuda()
        loss_sum = torch.tensor([loss_meter.avg * loss_meter.count]).cuda()
        top1_sum = torch.tensor([top1_meter.avg * top1_meter.count]).cuda()
        top5_sum = torch.tensor([top5_meter.avg * top5_meter.count]).cuda()

        dist.all_reduce(total_num)
        dist.all_reduce(loss_sum)
        dist.all_reduce(top1_sum)
        dist.all_reduce(top5_sum)

        global_num = total_num.item()
        val_loss = loss_sum.item() / global_num
        val_top1 = top1_sum.item() / global_num
        val_top5 = top5_sum.item() / global_num

        # Report the number of samples the metrics were actually computed
        # over (all ranks), not only this rank's share.
        self._info(
            'Prec@1 {:.3f}\tPrec@5 {:.3f}\tLoss {:.3f}\ttotal_num={}'.format(
                val_top1, val_top5, val_loss, global_num))

        if dist.is_master():
            if val_top1 > self.best_top1:
                self.best_top1 = val_top1

            if tb_logger is not None:
                tb_logger.add_scalar('loss_val', val_loss, self.cur_step)
                tb_logger.add_scalar('acc1_val', val_top1, self.cur_step)
                tb_logger.add_scalar('acc5_val', val_top5, self.cur_step)
Ejemplo n.º 3
0
def eval_epoch(val_loader, model, epoch, cfg):
    '''Evaluate the model on the val set.

    The pass runs under ``torch.no_grad()``. Evaluation never back-propagates,
    so recording the autograd graph would only waste GPU memory. When
    ``cfg.NUM_GPUS > 1`` the loss is all-reduced and the outputs/labels are
    all-gathered, so the logged accuracy covers every device's samples.
    Running loss and accuracy are logged on the master process only.

    Args:
      val_loader (loader): data loader to provide validation data.
      model (model): model to evaluate the performance.
      epoch (int): number of the current epoch of training (not used here).
      cfg (CfgNode): configs. Details can be found in config/defaults.py
    '''
    import torch  # function-scope import: only needed for ``torch.no_grad``.

    if is_master_proc():
        log.info('Testing..')

    model.eval()
    test_loss = 0.0
    correct = total = 0.0
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(val_loader):
            inputs, labels = inputs.cuda(non_blocking=True), labels.cuda()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels, reduction='mean')

            # Gather all predictions across all devices.
            if cfg.NUM_GPUS > 1:
                loss = all_reduce([loss])[0]
                outputs, labels = all_gather([outputs, labels])

            # Accuracy.
            batch_correct = topks_correct(outputs, labels, (1, ))[0]
            correct += batch_correct.item()
            total += labels.size(0)

            if is_master_proc():
                test_loss += loss.item()
                test_acc = correct / total
                log.info('Loss: %.3f | Acc: %.3f' % (test_loss /
                                                     (batch_idx + 1), test_acc))
Ejemplo n.º 4
0
def train(args, model, device, train_loader, optimizer, epoch):
    """Train ``model`` for one epoch on ``train_loader``.

    Only rank 0 records and prints losses. With ``args.dry_run`` set, every
    rank stops after the first logged batch. Rank 0 must not leave the loop
    on its own, or the remaining ranks would block in the next collective.

    Args:
        args: namespace providing ``log_interval``, ``gpus`` and ``dry_run``.
        model: network to train; its output is fed to ``F.nll_loss``, which
            expects log-probabilities.
        device: unused; batches are moved with ``.cuda()``.
        train_loader: iterable of ``(data, target)`` batches with a
            ``dataset`` attribute (used for progress reporting).
        optimizer: optimizer stepped once per batch.
        epoch: current epoch number, used only for logging.
    """
    model.train()
    loss_meter = ScalarMeter(args.log_interval)
    is_master = dist.get_rank() == 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.cuda()
        target = target.cuda()

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if args.gpus > 1:
            # Combine the (already back-propagated) loss across ranks for logging.
            [loss] = du.all_reduce([loss])

        if is_master:
            loss_meter.add_value(loss.item())

        if batch_idx % args.log_interval == 0:
            if is_master:
                # With several GPUs, report the windowed median kept by the
                # meter (window size ``log_interval``); otherwise this
                # batch's loss.
                shown_loss = loss_meter.get_win_median() if args.gpus > 1 else loss
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch,
                    batch_idx * len(data) * args.gpus, len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), shown_loss.item()))
            # Evaluated on every rank so all processes leave the loop together.
            if args.dry_run:
                break
Ejemplo n.º 5
0
    def _train_one_batch(self, x, y, optimizer, lr_scheduler, meters,
                         criterions, end):
        """Train one mini-batch with the sandwich rule.

        ``config.training.sandwich.num_sample`` sub-networks are run in turn.
        For sample ``idx``, ``self._set_width(idx, ...)`` selects the width
        and returns that width's meters. Gradients of all samples accumulate
        and the optimizer is stepped once at the end.

        When ``config.training.distillation.enable`` is set, the detached
        output of sample 0 is the soft target for the later samples. Their
        loss is ``loss_weight * distill_loss(out, target)``, plus the
        hard-label criterion when ``hard_label`` is set.

        Args:
            x: input batch.
            y: target labels.
            optimizer: optimizer stepped once after all samples.
            lr_scheduler: stepped with ``self.cur_step`` before the forward passes.
            meters: ``(top1_meter, top5_meter, loss_meter, data_time)``.
            criterions: ``(criterion, distill_loss)``.
            end: timestamp of the end of the previous batch (for data time).
        """
        top1_meter, top5_meter, loss_meter, data_time = meters
        criterion, distill_loss = criterions
        world_size = dist.get_world_size()

        lr_scheduler.step(self.cur_step)
        self.cur_step += 1
        data_time.update(time.time() - end)

        self.model.zero_grad()

        # Resolve the distillation sub-config once; its fields are still read
        # on every iteration.
        distill_cfg = self.config.training.distillation
        max_pred = None
        for idx in range(self.config.training.sandwich.num_sample):
            # sandwich rule: select the width for this sample
            top1_m, top5_m, loss_m = self._set_width(idx, top1_meter,
                                                     top5_meter, loss_meter)

            out = self.model(x)
            if distill_cfg.enable:
                if idx == 0:
                    # Sample 0 is the teacher (presumably the widest net under
                    # the sandwich rule). It is detached so the soft-target
                    # terms do not back-propagate into it.
                    max_pred = out.detach()
                    loss = criterion(out, y)
                else:
                    loss = distill_cfg.loss_weight * distill_loss(out, max_pred)
                    if distill_cfg.hard_label:
                        loss += criterion(out, y)
            else:
                loss = criterion(out, y)
            # Pre-scale so a summing all_reduce over ranks yields the mean loss.
            loss /= world_size

            top1, top5 = accuracy(out, y, top_k=(1, 5))
            reduced_loss = dist.all_reduce(loss.clone())
            reduced_top1 = dist.all_reduce(top1.clone(), div=True)
            reduced_top5 = dist.all_reduce(top5.clone(), div=True)

            loss_m.update(reduced_loss.item())
            top1_m.update(reduced_top1.item())
            top5_m.update(reduced_top5.item())

            # Gradients accumulate across samples; a single step follows below.
            loss.backward()

        dist.average_gradient(self.model.parameters())
        optimizer.step()
Ejemplo n.º 6
0
    def _train_one_batch(self, x, y, optimizer, lr_scheduler, meters,
                         criterions, end):
        """One DMCP step: a weight update, optionally followed by an alpha update.

        ``optimizer`` and ``lr_scheduler`` are ``(weight, arch)`` pairs. The
        weight update is delegated to the parent runner with alpha training
        switched off. Once ``self.cur_step`` has reached
        ``config.arch.start_train``, every ``config.arch.train_freq`` steps
        the architecture parameters are also trained on the same batch. That
        update uses the task loss plus the loss returned by ``flop_loss``.

        Args:
            x: input batch.
            y: target labels.
            optimizer: ``(weight_optimizer, arch_optimizer)``.
            lr_scheduler: ``(weight_scheduler, arch_scheduler)``.
            meters: ``(top1, top5, loss, arch_loss, floss, eflops, arch_top1,
                data_time)`` meters.
            criterions: ``(criterion, _)``; only the first is used here.
            end: timestamp of the end of the previous batch.
        """
        weight_sched, arch_sched = lr_scheduler
        weight_opt, arch_opt = optimizer
        (top1_meter, top5_meter, loss_meter, arch_loss_meter,
         floss_meter, eflops_meter, arch_top1_meter, data_time) = meters
        task_criterion, _ = criterions

        # Phase 1: ordinary weight update with alpha training disabled.
        self.model.module.set_alpha_training(False)
        weight_meters = [top1_meter, top5_meter, loss_meter, data_time]
        super(DMCPRunner, self)._train_one_batch(
            x, y, weight_opt, weight_sched, weight_meters, criterions, end)

        arch_sched.step(self.cur_step)
        n_ranks = dist.get_world_size()

        step = self.cur_step
        arch_cfg = self.config.arch
        if step < arch_cfg.start_train or step % arch_cfg.train_freq != 0:
            return

        # Phase 2: architecture (alpha) update on the same batch.
        self._set_width(0, top1_meter, top5_meter, loss_meter)
        self.model.module.set_alpha_training(True)

        self.model.zero_grad()
        logits = self.model(x)
        arch_loss = task_criterion(logits, y) / n_ranks
        flops_loss, expected_flops = flop_loss(self.config, self.model)
        flops_loss /= n_ranks

        arch_top1 = accuracy(logits, y, top_k=(1, ))[0]
        # Tuple elements are evaluated left to right, keeping the collective
        # order identical on every rank.
        synced = (
            dist.all_reduce(arch_loss.clone()),
            dist.all_reduce(flops_loss.clone()),
            dist.all_reduce(expected_flops.clone(), div=True),
            dist.all_reduce(arch_top1.clone(), div=True),
        )
        targets = (arch_loss_meter, floss_meter, eflops_meter, arch_top1_meter)
        for meter, value in zip(targets, synced):
            meter.update(value.item())

        flops_loss.backward()
        arch_loss.backward()
        dist.average_gradient(self.model.module.arch_parameters())
        arch_opt.step()
Ejemplo n.º 7
0
def train_epoch(train_loader, model, optimizer, epoch, cfg):
    '''Run one training epoch and log running loss/accuracy on the master process.

    The learning rate is set before every iteration from the fractional
    epoch position via ``get_epoch_lr``. When ``cfg.NUM_GPUS > 1`` the loss
    is all-reduced and outputs/labels are all-gathered after the optimizer
    step, so the logged accuracy covers every device's samples.

    Args:
      train_loader (DataLoader): training data loader.
      model (model): the video model to train.
      optimizer (optim): the optimizer to perform optimization on the model's parameters.
      epoch (int): current epoch of training.
      cfg (CfgNode): configs. Details can be found in config/defaults.py
    '''
    if is_master_proc():
        log.info('Epoch: %d' % epoch)

    model.train()
    iters_per_epoch = len(train_loader)
    loss_total = 0.0
    n_correct = n_seen = 0.0
    for it, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.cuda(non_blocking=True)
        labels = labels.cuda()

        # Learning rate for the current (fractional) epoch position.
        lr = get_epoch_lr(cfg, epoch + float(it) / iters_per_epoch)
        set_lr(optimizer, lr)

        logits = model(inputs)
        loss = F.cross_entropy(logits, labels, reduction='mean')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if cfg.NUM_GPUS > 1:
            # Metrics only: combine the loss and collect every device's predictions.
            loss = all_reduce([loss])[0]
            logits, labels = all_gather([logits, labels])

        n_correct += topks_correct(logits, labels, (1, ))[0].item()
        n_seen += labels.size(0)

        if not is_master_proc():
            continue
        loss_total += loss.item()
        log.info('Loss: %.3f | Acc: %.3f | LR: %.3f' %
                 (loss_total / (it + 1), n_correct / n_seen, lr))
Ejemplo n.º 8
0
 def update_stats(self, info_dict):
     """
     Update the running scalar meters with the current mini-batch stats.

     Args:
         info_dict (dict): maps a stat name (e.g. loss, psnr, lr) to a value
             supporting ``.item()`` (e.g. a scalar tensor). On the first call
             the meters are created from its keys. The caller's dict is not
             modified.
     """
     # Lazily create one meter per stat on the first call.
     if self.infos is None:
         self.init(info_dict.keys())
     names = list(info_dict)
     values = [info_dict[name] for name in names]
     # Reduce from all GPUs in a single call. du.all_reduce takes and returns
     # a *list* of tensors, in the same order.
     if self._cfg.NUM_GPUS > 1:
         values = du.all_reduce(values)
     # Synchronize from GPU to CPU.
     scalars = {name: value.item() for name, value in zip(names, values)}
     # Log each value into its scalar meter.
     for name, scalar in scalars.items():
         self.infos[name].add_value(scalar)
Ejemplo n.º 9
0
def eval_epoch(val_loader, model, val_meter, cur_epoch, cfg):
    """Evaluate ``model`` on ``val_loader`` and return the epoch statistics.

    The loop runs under ``torch.no_grad()``. Evaluation never back-propagates,
    so recording the autograd graph would only waste GPU memory.

    Args:
        val_loader: yields ``(inputs, labels, _)``; ``inputs`` is a tensor or
            a list of tensors.
        model: model to evaluate.
        val_meter: meter that receives predictions (multi-label) or
            top-1/top-5 errors (single-label).
        cur_epoch (int): epoch index, used for logging.
        cfg: config; reads ``cfg.DATA.MULTI_LABEL`` and ``cfg.NUM_GPUS``.

    Returns:
        The value returned by ``val_meter.log_epoch_stats(cur_epoch)``.
    """
    import torch  # function-scope import: only needed for ``torch.no_grad``.

    model.eval()
    val_meter.iter_tic()

    with torch.no_grad():
        for cur_step, (inputs, labels, _) in enumerate(val_loader):
            # Transfer the data to the current GPU device.
            if isinstance(inputs, (list,)):
                inputs = [x.cuda(non_blocking=True) for x in inputs]
            else:
                inputs = inputs.cuda(non_blocking=True)
            labels = labels.cuda()

            preds = model(inputs)
            if cfg.DATA.MULTI_LABEL:
                if cfg.NUM_GPUS > 1:
                    preds, labels = du.all_gather([preds, labels])
                val_meter.iter_toc()
                val_meter.update_predictions(preds, labels)
            else:
                top1_err, top5_err = metrics.topk_errors(preds, labels, (1, 5))
                if cfg.NUM_GPUS > 1:
                    top1_err, top5_err = du.all_reduce([top1_err, top5_err])
                top1_err, top5_err = top1_err.item(), top5_err.item()

                val_meter.iter_toc()
                val_meter.update_stats(
                    top1_err, top5_err, labels.size(0) * cfg.NUM_GPUS
                )
            val_meter.log_iter_stats(cur_epoch, cur_step)
            val_meter.iter_tic()

    stats = val_meter.log_epoch_stats(cur_epoch)
    val_meter.reset()

    return stats
Ejemplo n.º 10
0
def train_epoch(
    train_loader, model, optimizer, train_meter, cur_epoch, cfg, writer=None
):
    """
    Perform the video training for one epoch.
    Args:
        train_loader (loader): video training loader.
        model (model): the video model to train.
        optimizer (optim): the optimizer to perform optimization on the model's
            parameters.
        train_meter (TrainMeter): training meters to log the training performance.
        cur_epoch (int): current epoch of training.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        writer (TensorboardWriter, optional): TensorboardWriter object
            to writer Tensorboard log.
    """
    # Enable train mode.
    model.train()
    train_meter.iter_tic()
    data_size = len(train_loader)
    # The loss function does not depend on the batch; build it once.
    # Explicitly declare reduction to mean.
    loss_fun = losses.get_loss_func(cfg.MODEL.LOSS_FUNC)(reduction="mean")

    for cur_iter, (inputs, labels, _, meta) in enumerate(train_loader):
        # Transfer the data to the current GPU device.
        if isinstance(inputs, (list,)):
            for i in range(len(inputs)):
                inputs[i] = inputs[i].cuda(non_blocking=True)
        else:
            inputs = inputs.cuda(non_blocking=True)
        labels = labels.cuda()
        for key, val in meta.items():
            if isinstance(val, (list,)):
                for i in range(len(val)):
                    val[i] = val[i].cuda(non_blocking=True)
            else:
                meta[key] = val.cuda(non_blocking=True)

        # Update the learning rate.
        lr = optim.get_epoch_lr(cur_epoch + float(cur_iter) / data_size, cfg)
        optim.set_lr(optimizer, lr)

        if cfg.DETECTION.ENABLE:
            # Compute the predictions.
            preds = model(inputs, meta["boxes"])
        else:
            # Perform the forward pass.
            preds = model(inputs)

        # Compute the loss.
        loss = loss_fun(preds, labels)

        # check Nan Loss.
        misc.check_nan_losses(loss)

        # Perform the backward pass.
        optimizer.zero_grad()
        loss.backward()
        # Update the parameters.
        optimizer.step()

        if cfg.DETECTION.ENABLE:
            if cfg.NUM_GPUS > 1:
                loss = du.all_reduce([loss])[0]
            loss = loss.item()

            train_meter.iter_toc()
            # Update and log stats. The detection meter takes
            # (preds, ori_boxes, metadata, loss, lr); only loss and lr are
            # reported during training.
            train_meter.update_stats(None, None, None, loss, lr)
            # write to tensorboard format if available.
            if writer is not None:
                writer.add_scalars(
                    {"Train/loss": loss, "Train/lr": lr},
                    global_step=data_size * cur_epoch + cur_iter,
                )

        else:
            top1_err, top5_err = None, None
            if cfg.DATA.MULTI_LABEL:
                # Average the loss across devices; no top-k errors for
                # multi-label data.
                if cfg.NUM_GPUS > 1:
                    [loss] = du.all_reduce([loss])
                loss = loss.item()
            else:
                # Compute the errors.
                num_topks_correct = metrics.topks_correct(
                    preds, labels, (1, 5))
                top1_err, top5_err = [
                    (1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct
                ]

                # Combine the loss and errors across all the devices.
                if cfg.NUM_GPUS > 1:
                    loss, top1_err, top5_err = du.all_reduce(
                        [loss, top1_err, top5_err]
                    )

                # Copy the stats from GPU to CPU (sync point).
                loss, top1_err, top5_err = (
                    loss.item(),
                    top1_err.item(),
                    top5_err.item(),
                )

            train_meter.iter_toc()
            # Update and log stats. ``labels`` is always batch-first, whereas
            # ``inputs[0]`` is the batch only when ``inputs`` is a list of
            # pathway tensors (for a plain tensor it is a single sample).
            train_meter.update_stats(
                top1_err, top5_err, loss, lr, labels.size(0) * cfg.NUM_GPUS
            )
            # write to tensorboard format if available.
            if writer is not None:
                writer.add_scalars(
                    {
                        "Train/loss": loss,
                        "Train/lr": lr,
                        "Train/Top1_err": top1_err,
                        "Train/Top5_err": top5_err,
                    },
                    global_step=data_size * cur_epoch + cur_iter,
                )

        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()

    # Log epoch stats.
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()
Ejemplo n.º 11
0
def eval_epoch(val_loader, model, val_meter, cur_epoch, cfg, writer=None):
    """
    Evaluate the model on the val set.

    The evaluation loop runs under ``torch.no_grad()``. Nothing is
    back-propagated, so recording the autograd graph would only waste GPU
    memory.

    Args:
        val_loader (loader): data loader to provide validation data.
        model (model): model to evaluate the performance.
        val_meter (ValMeter): meter instance to record and calculate the metrics.
        cur_epoch (int): number of the current epoch of training.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        writer (TensorboardWriter, optional): TensorboardWriter object
            to writer Tensorboard log.
    """

    # Evaluation mode enabled. The running stats would not be updated.
    model.eval()
    val_meter.iter_tic()

    with torch.no_grad():
        for cur_iter, (inputs, labels, _, meta) in enumerate(val_loader):
            # Transfer the data to the current GPU device.
            if isinstance(inputs, (list,)):
                for i in range(len(inputs)):
                    inputs[i] = inputs[i].cuda(non_blocking=True)
            else:
                inputs = inputs.cuda(non_blocking=True)
            labels = labels.cuda()
            for key, val in meta.items():
                if isinstance(val, (list,)):
                    for i in range(len(val)):
                        val[i] = val[i].cuda(non_blocking=True)
                else:
                    meta[key] = val.cuda(non_blocking=True)

            if cfg.DETECTION.ENABLE:
                # Compute the predictions.
                preds = model(inputs, meta["boxes"])

                preds = preds.cpu()
                ori_boxes = meta["ori_boxes"].cpu()
                metadata = meta["metadata"].cpu()

                if cfg.NUM_GPUS > 1:
                    preds = torch.cat(du.all_gather_unaligned(preds), dim=0)
                    ori_boxes = torch.cat(
                        du.all_gather_unaligned(ori_boxes), dim=0)
                    metadata = torch.cat(
                        du.all_gather_unaligned(metadata), dim=0)

                val_meter.iter_toc()
                # Update and log stats.
                val_meter.update_stats(
                    preds.cpu(), ori_boxes.cpu(), metadata.cpu())

            else:
                preds = model(inputs)

                if cfg.DATA.MULTI_LABEL:
                    if cfg.NUM_GPUS > 1:
                        preds, labels = du.all_gather([preds, labels])
                else:
                    # Compute the errors.
                    num_topks_correct = metrics.topks_correct(
                        preds, labels, (1, 5))

                    # Combine the errors across the GPUs.
                    top1_err, top5_err = [
                        (1.0 - x / preds.size(0)) * 100.0
                        for x in num_topks_correct
                    ]
                    if cfg.NUM_GPUS > 1:
                        top1_err, top5_err = du.all_reduce(
                            [top1_err, top5_err])

                    # Copy the errors from GPU to CPU (sync point).
                    top1_err, top5_err = top1_err.item(), top5_err.item()

                    val_meter.iter_toc()
                    # Update and log stats. ``labels`` is always batch-first;
                    # ``inputs[0]`` is the batch only when ``inputs`` is a
                    # list of pathway tensors.
                    val_meter.update_stats(
                        top1_err, top5_err, labels.size(0) * cfg.NUM_GPUS
                    )
                    # write to tensorboard format if available.
                    if writer is not None:
                        writer.add_scalars(
                            {"Val/Top1_err": top1_err,
                             "Val/Top5_err": top5_err},
                            global_step=len(val_loader) * cur_epoch + cur_iter,
                        )

                val_meter.update_predictions(preds, labels)

            val_meter.log_iter_stats(cur_epoch, cur_iter)
            val_meter.iter_tic()

    # Log epoch stats.
    val_meter.log_epoch_stats(cur_epoch)
    # write to tensorboard format if available.
    if writer is not None:
        if cfg.DETECTION.ENABLE:
            writer.add_scalars(
                {"Val/mAP": val_meter.full_map}, global_step=cur_epoch
            )
        else:
            # Labels are only fed to the meter (update_predictions) on the
            # classification path, so the prediction/label plot is only
            # produced there.
            all_preds_cpu = [pred.clone().detach().cpu()
                             for pred in val_meter.all_preds]
            all_labels_cpu = [label.clone().detach().cpu()
                              for label in val_meter.all_labels]
            writer.plot_eval(
                preds=all_preds_cpu,
                labels=all_labels_cpu,
                global_step=cur_epoch,
            )

    val_meter.reset()
Ejemplo n.º 12
0
def train_epoch(
    train_loader,
    model,
    optimizer,
    train_meter,
    cur_epoch,
    global_step,
    num_steps,
    cfg
):
    """Train for one epoch with optional gradient accumulation.

    Gradients of ``cfg.SOLVER.GRADIENT_ACCUMULATION_STEPS`` consecutive
    mini-batches are accumulated before each optimizer step. Right before
    every optimizer step the learning rate is recomputed and applied, and
    gradients are clipped if ``cfg.SOLVER.GRADIENT_CLIPPING`` is set. The
    logged loss is the unscaled per-batch loss.

    Args:
        train_loader: yields ``(inputs, labels, _)``; ``inputs`` is a tensor
            or a list of tensors.
        model: the model to train.
        optimizer: optimizer stepped once per accumulation window.
        train_meter: meter used to record and log training stats.
        cur_epoch (int): current epoch index.
        global_step (int): number of mini-batches processed before this call.
        num_steps (int): mini-batch count at which training stops. If it is
            reached before the loader is exhausted, the function returns
            immediately without logging/resetting the epoch stats.
        cfg: config node.

    Returns:
        int: ``global_step`` advanced by the number of mini-batches consumed.
    """
    model.train()
    train_meter.iter_tic()

    accum_steps = cfg.SOLVER.GRADIENT_ACCUMULATION_STEPS
    # Number of optimizer steps per epoch (may be fractional).
    data_size = len(train_loader) / accum_steps
    # The loss function does not change between batches; build it once.
    loss_fun = losses.get_loss_func(cfg.MODEL.LOSS_FUNC)(reduction='mean')

    epoch_step = 0
    _global_step = global_step // accum_steps
    # Used only for logging until the first optimizer step recomputes and
    # applies the learning rate.
    lr = optim.get_epoch_lr(cur_epoch + float(epoch_step) / data_size, _global_step, cfg)
    for cur_step, (inputs, labels, _) in enumerate(train_loader):
        global_step += 1
        # Transfer the data to the current GPU device.
        if isinstance(inputs, (list,)):
            for i in range(len(inputs)):
                inputs[i] = inputs[i].cuda(non_blocking=True)
        else:
            inputs = inputs.cuda(non_blocking=True)
        labels = labels.cuda()

        preds = model(inputs)
        loss = loss_fun(preds, labels)

        # Check Nan Loss.
        misc.check_nan_losses(loss)

        # Scale only the back-propagated loss, so the accumulated gradient is
        # the mean over the accumulation window. ``loss`` itself stays
        # unscaled, so the logged value is the real per-batch loss.
        if accum_steps > 1:
            (loss / accum_steps).backward()
        else:
            loss.backward()

        if global_step % accum_steps == 0:
            epoch_step += 1
            _global_step = global_step // accum_steps
            lr = optim.get_epoch_lr(cur_epoch + float(epoch_step) / data_size, _global_step, cfg)
            optim.set_lr(optimizer, lr)
            if cfg.SOLVER.GRADIENT_CLIPPING:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), cfg.SOLVER.MAX_GRAD_NORM
                )
            optimizer.step()
            optimizer.zero_grad()

        if cfg.DATA.MULTI_LABEL:
            if cfg.NUM_GPUS > 1:
                [loss] = du.all_reduce([loss])

            loss = loss.item()
            top1_err, top5_err = None, None
        else:
            top1_err, top5_err = metrics.topk_errors(preds, labels, (1, 5))

            # Combine the loss and errors across all the devices.
            if cfg.NUM_GPUS > 1:
                loss, top1_err, top5_err = du.all_reduce([loss, top1_err, top5_err])

            loss, top1_err, top5_err = (
                loss.item(),
                top1_err.item(),
                top5_err.item(),
            )

        train_meter.iter_toc()
        # Update and log stats.
        train_meter.update_stats(
            top1_err,
            top5_err,
            loss,
            lr,
            labels.size(0) * cfg.NUM_GPUS
        )

        train_meter.log_iter_stats(cur_epoch, cur_step, global_step)

        if global_step == num_steps and (cur_step + 1) != len(train_loader):
            return global_step

        train_meter.iter_tic()

    # Log epoch stats.
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()

    return global_step
Ejemplo n.º 13
0
    def validate(self,
                 val_loader,
                 train_loader=None,
                 val_width=None,
                 tb_logger=None):
        """Evaluate the model at several widths and aggregate metrics over all ranks.

        For each width in ``val_width``:
        1. The model is switched to that width via ``self._set_width``.
        2. It is re-calibrated with ``self.calibrate(train_loader)``.
        3. It is evaluated on ``val_loader``.
        Per-rank batch-size-weighted sums are all-reduced, so the reported
        metrics cover every evaluated sample. On the master rank, results are
        written to ``tb_logger`` (if given), with the width appended to each
        tag.

        Args:
            val_loader: iterable of ``(x, y)`` validation batches.
            train_loader: loader passed to ``self.calibrate``; required.
            val_width: sequence of widths to evaluate; required.
            tb_logger: optional logger exposing ``add_scalar(tag, value, step)``.
        """
        assert train_loader is not None
        assert val_width is not None

        batch_time = AverageMeter(0)
        loss_meter = [AverageMeter(0) for _ in range(len(val_width))]
        top1_meter = [AverageMeter(0) for _ in range(len(val_width))]
        top5_meter = [AverageMeter(0) for _ in range(len(val_width))]
        val_loss, val_top1, val_top5 = [], [], []

        # switch to evaluate mode
        self.model.eval()

        criterion = nn.CrossEntropyLoss()

        with torch.no_grad():
            for idx, width in enumerate(val_width):
                top1_m, top5_m, loss_m = self._set_width(idx,
                                                         top1_meter,
                                                         top5_meter,
                                                         loss_meter,
                                                         width=width)

                self._info('-' * 80)
                self._info('Evaluating [{}/{}]@{}'.format(
                    idx + 1, len(val_width), width))

                self.calibrate(train_loader)
                # Restart the batch timer so the first batch's time does not
                # include calibration.
                end = time.time()
                for j, (x, y) in enumerate(val_loader):
                    x, y = x.cuda(), y.cuda()
                    num = x.size(0)

                    out = self.model(x)
                    loss = criterion(out, y)
                    top1, top5 = accuracy(out.data, y, top_k=(1, 5))

                    loss_m.update(loss.item(), num)
                    top1_m.update(top1.item(), num)
                    top5_m.update(top5.item(), num)

                    # measure elapsed time
                    batch_time.update(time.time() - end)
                    end = time.time()

                    if j % self.config.logging.print_freq == 0:
                        self._info(
                            'Test: [{0}/{1}]\tTime {batch_time.val:.3f} ({batch_time.avg:.3f})'
                            .format(j, len(val_loader), batch_time=batch_time))

                # Turn this rank's running means back into sums (mean * count)
                # so they can be summed across ranks. The results are read back
                # from the same tensors, i.e. dist.all_reduce is relied on to
                # reduce in place.
                total_num = torch.tensor([loss_m.count]).cuda()
                loss_sum = torch.tensor([loss_m.avg * loss_m.count]).cuda()
                top1_sum = torch.tensor([top1_m.avg * top1_m.count]).cuda()
                top5_sum = torch.tensor([top5_m.avg * top5_m.count]).cuda()

                dist.all_reduce(total_num)
                dist.all_reduce(loss_sum)
                dist.all_reduce(top1_sum)
                dist.all_reduce(top5_sum)

                global_num = total_num.item()
                val_loss.append(loss_sum.item() / global_num)
                val_top1.append(top1_sum.item() / global_num)
                val_top5.append(top5_sum.item() / global_num)

                # Report the sample count over all ranks, matching the metrics.
                self._info(
                    'Prec@1 {:.3f}\tPrec@5 {:.3f}\tLoss {:.3f}\ttotal_num={}'.
                    format(val_top1[-1], val_top5[-1], val_loss[-1],
                           global_num))

            if dist.is_master() and tb_logger is not None:
                for width, loss_v, top1_v, top5_v in zip(
                        val_width, val_loss, val_top1, val_top5):
                    tb_logger.add_scalar('loss_val@{}'.format(width),
                                         loss_v, self.cur_step)
                    tb_logger.add_scalar('acc1_val@{}'.format(width),
                                         top1_v, self.cur_step)
                    tb_logger.add_scalar('acc5_val@{}'.format(width),
                                         top5_v, self.cur_step)