Example #1
0
    def _kmeans_finetune(self,
                         train_loader,
                         model,
                         idx,
                         centroid_label_dict,
                         epochs=1,
                         verbose=True):
        """Fine-tune a k-means-quantized model while keeping it quantized.

        After every optimizer step the weights are re-projected onto their
        cluster centroids via ``kmeans_update_model``, so the model remains
        quantized throughout the fine-tuning run.

        Args:
            train_loader: iterable of ``(inputs, targets)`` batches.
            model: network to fine-tune; batches are moved to CUDA here.
            idx: NOTE(review): unused — ``self.quantizable_idx`` is used
                instead; confirm which one is intended.
            centroid_label_dict: centroid/label assignments consumed by
                ``kmeans_update_model``.
            epochs: number of passes over ``train_loader``.
            verbose: if True, print a one-line summary at the end.

        Returns:
            Best average accuracy seen across epochs (top-5 when
            ``self.use_top5`` is set, otherwise top-1).
        """
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        best_acc = 0.

        # switch to train mode
        model.train()
        end = time.time()
        t1 = time.time()
        for epoch in range(epochs):
            # BUGFIX: the bar used to be created once outside this loop but
            # finished inside it, so with epochs > 1 `bar.next()` was called
            # on an already-finished bar. Create a fresh bar per epoch.
            bar = Bar('train:', max=len(train_loader))
            for i, (inputs, targets) in enumerate(train_loader):
                input_var, target_var = inputs.cuda(), targets.cuda()

                # measure data loading time
                data_time.update(time.time() - end)

                # compute output
                output = model(input_var)
                loss = self.criterion(output, target_var)

                # measure accuracy and record loss
                prec1, prec5 = accuracy(output.data, target_var, topk=(1, 5))
                losses.update(loss.item(), inputs.size(0))
                top1.update(prec1.item(), inputs.size(0))
                top5.update(prec5.item(), inputs.size(0))

                # compute gradient
                self.optimizer.zero_grad()
                loss.backward()

                # do SGD step
                self.optimizer.step()

                # re-quantize: snap the updated weights back onto centroids
                kmeans_update_model(model,
                                    self.quantizable_idx,
                                    centroid_label_dict,
                                    free_high_bit=True)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                if i % 1 == 0:
                    bar.suffix = \
                        '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                        'Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                            batch=i + 1,
                            size=len(train_loader),
                            data=data_time.val,
                            bt=batch_time.val,
                            total=bar.elapsed_td,
                            eta=bar.eta_td,
                            loss=losses.avg,
                            top1=top1.avg,
                            top5=top5.avg,
                        )
                    bar.next()
            bar.finish()

            # track the best accuracy over epochs
            if self.use_top5:
                if top5.avg > best_acc:
                    best_acc = top5.avg
            else:
                if top1.avg > best_acc:
                    best_acc = top1.avg
            self.adjust_learning_rate()
        t2 = time.time()
        if verbose:
            print('* Test loss: %.3f  top1: %.3f  top5: %.3f  time: %.3f' %
                  (losses.avg, top1.avg, top5.avg, t2 - t1))
        return best_acc
Example #2
0
def train(train_loader, model, criterion, optimizer, epoch, use_cuda):
    """Run one training epoch and return (average loss, average top-1).

    Keeps the model quantized by snapping weights back to their k-means
    centroids after each optimizer step unless linear quantization is on.
    """
    # switch to train mode
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    bar = Bar('Processing', max=len(train_loader))
    for step, (inputs, targets) in enumerate(train_loader):
        # time spent waiting on the data loader
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs = torch.autograd.Variable(inputs)
        targets = torch.autograd.Variable(targets)

        # forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # record loss and top-1/top-5 accuracy
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        batch_size = inputs.size(0)
        losses.update(loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        # backward pass, optionally through apex mixed-precision scaling
        optimizer.zero_grad()
        if args.half:
            with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        # do SGD step
        optimizer.step()

        # keep the weights on their k-means centroids
        if not args.linear_quantization:
            kmeans_update_model(model,
                                quantizable_idx,
                                centroid_label_dict,
                                free_high_bit=args.free_high_bit)

        # per-batch timing
        batch_time.update(time.time() - end)
        end = time.time()

        # progress display
        if step % 1 == 0:
            bar.suffix = \
                '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                'Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                    batch=step + 1,
                    size=len(train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,
                    top5=top5.avg,
                )
            bar.next()
    bar.finish()
    return losses.avg, top1.avg
Example #3
0
def test(val_loader, model, criterion, epoch, use_cuda):
    """Evaluate `model` on `val_loader`.

    Args:
        val_loader: iterable of (inputs, targets) batches.
        model: network to evaluate.
        criterion: loss applied to (outputs, targets).
        epoch: current epoch index (unused here, kept for interface parity).
        use_cuda: move batches to the GPU when True.

    Returns:
        Tuple of (average loss, average top-1 accuracy).
    """
    global best_acc

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    with torch.no_grad():
        # switch to evaluate mode
        model.eval()

        end = time.time()
        bar = Bar('Processing', max=len(val_loader))
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            # BUGFIX: dropped `volatile=True` — it was removed in PyTorch 0.4
            # and is redundant inside `torch.no_grad()`.
            inputs, targets = torch.autograd.Variable(
                inputs), torch.autograd.Variable(targets)

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            if batch_idx % 1 == 0:
                bar.suffix = \
                    '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                    'Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                        batch=batch_idx + 1,
                        size=len(val_loader),
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        top1=top1.avg,
                        top5=top5.avg,
                        )
                bar.next()
        bar.finish()
    return losses.avg, top1.avg
Example #4
0
    def train(self):
        """Run one epoch of adversarial training.

        Updates the generator every iteration and the motion discriminator
        every ``self.dis_motion_update_steps`` iterations, drawing batches
        from the 2D, 3D and discriminator data loaders and logging losses
        and per-phase timings to TensorBoard.
        """
        # Single epoch training routine

        losses = AverageMeter()

        # wall-clock breakdown of one iteration, filled in below
        timer = {
            'data': 0,
            'forward': 0,
            'loss': 0,
            'backward': 0,
            'batch': 0,
        }

        self.generator.train()
        self.motion_discriminator.train()

        start = time.time()

        summary_string = ''

        bar = Bar(f'Epoch {self.epoch + 1}/{self.end_epoch}',
                  fill='#',
                  max=self.num_iters_per_epoch)

        for i in range(self.num_iters_per_epoch):
            # Dirty solution to reset an iterator
            target_2d = target_3d = None
            if self.train_2d_iter:
                try:
                    target_2d = next(self.train_2d_iter)
                except StopIteration:
                    # loader exhausted mid-epoch: restart it
                    self.train_2d_iter = iter(self.train_2d_loader)
                    target_2d = next(self.train_2d_iter)

                move_dict_to_device(target_2d, self.device)

            if self.train_3d_iter:
                try:
                    target_3d = next(self.train_3d_iter)
                except StopIteration:
                    self.train_3d_iter = iter(self.train_3d_loader)
                    target_3d = next(self.train_3d_iter)

                move_dict_to_device(target_3d, self.device)

            # real motion samples feed the motion discriminator
            real_body_samples = real_motion_samples = None

            try:
                real_motion_samples = next(self.disc_motion_iter)
            except StopIteration:
                self.disc_motion_iter = iter(self.disc_motion_loader)
                real_motion_samples = next(self.disc_motion_iter)

            move_dict_to_device(real_motion_samples, self.device)

            # <======= Feedforward generator and discriminator
            # concatenate 2D and 3D features when both loaders are active
            if target_2d and target_3d:
                inp = torch.cat((target_2d['features'], target_3d['features']),
                                dim=0).to(self.device)
            elif target_3d:
                inp = target_3d['features'].to(self.device)
            else:
                inp = target_2d['features'].to(self.device)

            timer['data'] = time.time() - start
            start = time.time()

            preds = self.generator(inp)

            timer['forward'] = time.time() - start
            start = time.time()

            gen_loss, motion_dis_loss, loss_dict = self.criterion(
                generator_outputs=preds,
                data_2d=target_2d,
                data_3d=target_3d,
                data_body_mosh=real_body_samples,
                data_motion_mosh=real_motion_samples,
                motion_discriminator=self.motion_discriminator,
            )
            # =======>

            timer['loss'] = time.time() - start
            start = time.time()

            # <======= Backprop generator and discriminator
            self.gen_optimizer.zero_grad()
            gen_loss.backward()
            self.gen_optimizer.step()

            # discriminator is updated less frequently than the generator
            if self.train_global_step % self.dis_motion_update_steps == 0:
                self.dis_motion_optimizer.zero_grad()
                motion_dis_loss.backward()
                self.dis_motion_optimizer.step()
            # =======>

            # <======= Log training info
            total_loss = gen_loss + motion_dis_loss

            losses.update(total_loss.item(), inp.size(0))

            timer['backward'] = time.time() - start
            timer['batch'] = timer['data'] + timer['forward'] + timer[
                'loss'] + timer['backward']
            start = time.time()

            summary_string = f'({i + 1}/{self.num_iters_per_epoch}) | Total: {bar.elapsed_td} | ' \
                             f'ETA: {bar.eta_td:} | loss: {losses.avg:.4f}'

            for k, v in loss_dict.items():
                summary_string += f' | {k}: {v:.2f}'
                self.writer.add_scalar('train_loss/' + k,
                                       v,
                                       global_step=self.train_global_step)

            for k, v in timer.items():
                summary_string += f' | {k}: {v:.2f}'

            self.writer.add_scalar('train_loss/loss',
                                   total_loss.item(),
                                   global_step=self.train_global_step)

            # NOTE(review): debug visualization reads target_3d['video'] — it
            # will fail if debugging with only the 2D loader active; confirm.
            if self.debug:
                print('==== Visualize ====')
                from lib.utils.vis import batch_visualize_vid_preds
                video = target_3d['video']
                dataset = 'spin'
                vid_tensor = batch_visualize_vid_preds(video,
                                                       preds[-1],
                                                       target_3d.copy(),
                                                       vis_hmr=False,
                                                       dataset=dataset)
                self.writer.add_video('train-video',
                                      vid_tensor,
                                      global_step=self.train_global_step,
                                      fps=10)

            self.train_global_step += 1
            bar.suffix = summary_string
            bar.next()

            # abort early on divergence
            if torch.isnan(total_loss):
                exit('Nan value in loss, exiting!...')
            # =======>

        bar.finish()

        logger.info(summary_string)
Example #5
0
    def run_epoch(self, phase, epoch, data_loader):
        """Run one epoch of training or evaluation.

        Args:
            phase: 'train' optimizes the model; any other value evaluates.
            epoch: epoch number, used only in the progress display.
            data_loader: yields batch dicts of tensors (plus a 'meta' key
                that is left on the CPU).

        Returns:
            Tuple of (dict of averaged loss stats plus elapsed minutes,
            results dict filled by ``save_result`` when ``opt.test`` is set).
        """
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            # multi-GPU wrapper: unwrap the module before switching to eval
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()

        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        # loss accumulator; with the current `iter_id % 1 == 0` condition it
        # is flushed (backward + reset) on every training iteration
        loss_temp = 0
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)

            # move every tensor in the batch to the target device
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)

            output, loss, loss_stats = model_with_loss(batch)
            loss = loss.mean()
            # NOTE(review): outside the train phase loss_temp accumulates
            # loss tensors (and any attached graphs) without ever being
            # reset — confirm whether eval should skip or detach this.
            loss_temp += loss
            if phase == 'train' and iter_id % 1 == 0:
                self.optimizer.zero_grad()
                loss_temp.backward()
                self.optimizer.step()
                loss_temp = 0
            batch_time.update(time.time() - end)
            end = time.time()

            Bar.suffix = '\r{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                iter_id,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                                          '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            # either print at a fixed interval or advance the progress bar
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('\r{}/{}| {}'.format(opt.task, opt.exp_id,
                                               Bar.suffix),
                          end="",
                          flush=True)
            else:
                bar.next()

            if opt.test:
                self.save_result(output, batch, results)
            del output, loss, loss_stats, batch

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results
    def run_epoch(self, phase, epoch, data_loader):
        """Run one epoch of training or evaluation with multi-scale input.

        Args:
            phase: 'train' optimizes the model; any other value evaluates.
            epoch: epoch number, used only in the progress display.
            data_loader: loader whose dataset supports multi-scale batching
                (``batch_i_to_scale_i``, ``input_multi_scales``,
                ``rand_scale``, ``shuffle``).

        Returns:
            Tuple of (dict of averaged loss stats plus elapsed minutes,
            results dict filled by ``save_result`` when ``opt.test`` is set).
        """
        model_with_loss = self.model_with_loss

        if phase == 'train':
            model_with_loss.train()  # train phase
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module

            model_with_loss.eval()  # test phase
            torch.cuda.empty_cache()

        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()

        # train each batch
        for batch_i, batch in enumerate(data_loader):
            if batch_i >= num_iters:
                break

            data_time.update(time.time() - end)

            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)

            # Forward
            output, loss, loss_stats = model_with_loss.forward(batch)

            # Backwards
            loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()  # clear accumulated gradients
                loss.backward()  # back-propagate the loss
                self.optimizer.step()  # update weights from the gradients

            batch_time.update(time.time() - end)
            end = time.time()

            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                batch_i,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            for l in avg_loss_stats:
                # BUGFIX: was a bare `except:` which also swallowed
                # KeyboardInterrupt/SystemExit; narrowed to Exception.
                try:
                    avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                             batch['input'].size(0))
                except Exception:
                    print(
                        "\n>>BUG loss_stats base_traimer.py float instead of narray NC UPDATE: {} \n"
                        .format(loss_stats[l]))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)

            # multi-scale img_size display
            scale_idx = data_loader.dataset.batch_i_to_scale_i[batch_i]
            if data_loader.dataset.input_multi_scales is None:
                img_size = Input_WHs[scale_idx]
            else:
                img_size = data_loader.dataset.input_multi_scales[scale_idx]
            Bar.suffix = Bar.suffix + '|Img_size(wh) {:d}×{:d}'.format(
                img_size[0], img_size[1])

            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                                          '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if batch_i % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()

            if opt.test:
                self.save_result(output, batch, results)
            del output, loss, loss_stats, batch

        # randomly do multi-scaling for dataset every epoch
        data_loader.dataset.rand_scale()  # re-assign scale for each batch

        # shuffle the dataset every epoch
        data_loader.dataset.shuffle()  # re-assign file id for each idx

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.0

        return ret, results
Example #7
0
def train(cfg, model, train_loader, epoch):
    """Train `model` for one epoch of midline segmentation.

    Args:
        cfg: config carrying the optimizer, print frequency and
            ``training_setting_param`` (epochs) used for LR scheduling.
        model: network to train (inputs are moved to CUDA here).
        train_loader: yields (img, gt_mid, gt_mid_3, gt_x_coords,
            pos_weight, wce_weight, heatmap) tuples.
        epoch: current epoch index, used for LR adjustment and logging.

    Returns:
        Tuple of (average total loss, average dice).
        NOTE(review): ``dice`` is never updated in this loop, so the second
        value is whatever a fresh AverageMeter reports — confirm intent.
    """
    # set the AverageMeter
    Batch_time = AverageMeter()
    Eval_time = AverageMeter()
    # BUGFIX: `dice` was instantiated twice; keep a single meter.
    dice = AverageMeter()
    loss_1 = AverageMeter()
    loss_2 = AverageMeter()
    loss_3 = AverageMeter()
    losses = AverageMeter()
    torch.set_grad_enabled(True)
    # switch to train mode
    model.train()
    start = time.time()
    # loop in dataloader (`batch_idx` instead of `iter`, which shadowed the
    # builtin)
    for batch_idx, (img, gt_mid, gt_mid_3, gt_x_coords, pos_weight,
                    wce_weight, heatmap) in enumerate(train_loader):
        curr_iter = batch_idx + 1 + epoch * len(train_loader)
        total_iter = cfg.training_setting_param.epochs * len(train_loader)
        lr = adjust_learning_rate(cfg, cfg.training_setting_param, epoch,
                                  curr_iter, total_iter)

        input_var = Variable(img).cuda()
        gt_mid_3 = Variable(gt_mid_3).cuda()
        gt_x_coords = Variable(gt_x_coords).cuda()
        pos_weight = Variable(pos_weight).cuda()
        wce_weight = Variable(wce_weight).cuda()
        heatmap = Variable(heatmap).cuda()
        gt_mid_3[gt_mid_3 != 0] = 1  # binarize the midline mask
        # forward
        pred = model(input_var)
        # loss
        targets = (gt_x_coords, gt_mid_3, pos_weight, wce_weight, heatmap)
        loss, loss1, loss2, loss3 = get_total_loss(cfg, pred, targets, epoch)
        loss_1.update(loss1.item(), input_var.size(0))
        loss_2.update(loss2.item(), input_var.size(0))
        loss_3.update(loss3.item(), input_var.size(0))
        losses.update(loss.item(), input_var.size(0))
        # mid
        end = time.time()
        # metrics: dice for ven
        Eval_time.update(time.time() - end)
        cfg.optimizer.zero_grad()
        loss.backward()
        cfg.optimizer.step()
        # measure elapsed time
        Batch_time.update(time.time() - start)
        start = time.time()

        if batch_idx % cfg.print_freq == 0:
            logger_vis.info(
                'Epoch-Iter: [{0}/{1}]-[{2}/{3}]\t'
                'LR: {4:.6f}\t'
                'Batch_Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Eval_Time {eval_time.val:.3f} ({eval_time.avg:.3f})\t'
                'Loss1 {loss_1.val:.3f} ({loss_1.avg:.3f})\t'
                'Loss2 {loss_2.val:.3f} ({loss_2.avg:.3f})\t'
                'Loss3 {loss_3.val:.3f} ({loss_3.avg:.3f})\t'
                'Loss {loss.val:.3f} ({loss.avg:.3f})\t'.format(
                    epoch,
                    cfg.training_setting_param.epochs,
                    batch_idx,
                    len(train_loader),
                    lr,
                    batch_time=Batch_time,
                    eval_time=Eval_time,
                    loss_1=loss_1,
                    loss_2=loss_2,
                    loss_3=loss_3,
                    loss=losses))

    return losses.avg, dice.avg
Example #8
0
def test(cfg, model, test_data_loader, logger):
    """Evaluate midline prediction on `test_data_loader`.

    Args:
        cfg: config with model/vis/test settings and ``_get_shift_list()``.
        model: network to evaluate (inputs are moved to CUDA here).
        test_data_loader: yields (input, target, _, gt_curves, ori_img, path).
        logger: object with ``write`` used for the final per-metric report.

    Returns:
        Negative mean ASSD (so that a larger value means a better model).
    """
    model.eval()
    torch.set_grad_enabled(False)
    # Indicator to log
    batch_time = AverageMeter()
    shift_list = cfg._get_shift_list()
    metric_class = Metric()

    # The method to predict midline: Left, Right, Max, DP(dynamic programming)

    end = time.time()
    print('# ===== TEST ===== #')
    # `inputs` instead of `input`, which shadowed the builtin
    for step, (inputs, target, _, gt_curves, ori_img,
               path) in enumerate(test_data_loader):
        pid = path[0].split('/')[-2]
        target[target != 0] = 1  # binarize the ground-truth midline
        gt_midline = target.numpy().astype(np.uint8)
        img = ori_img.numpy().astype(np.uint8)
        # Variable (removed unused `target_var = target.cuda()`)
        input_var = Variable(inputs).cuda()

        with torch.no_grad():
            # forward
            pred = model(input_var)
            pred_midline_real_limit, pred_midline_gt_limit = pred2midline(
                pred, gt_curves, model_name=cfg.model_param.model_name)
            metric_class.process(pred_midline_real_limit,
                                 pred_midline_gt_limit, gt_midline)
            batch_time.update(time.time() - end)

        # save vis png
        if cfg.test_setting_param.save_vis:
            pid_save_name = pid + '.shift' if pid in shift_list else pid
            save_dir_width_1 = osp.join('{}.{}'.format(cfg.vis_dir, '1'),
                                        pid_save_name)
            save_dir_width_3 = osp.join('{}.{}'.format(cfg.vis_dir, '3'),
                                        pid_save_name)
            ensure_dir(save_dir_width_1)
            ensure_dir(save_dir_width_3)
            vis_boundary(img,
                         gt_midline,
                         pred_midline_real_limit,
                         save_dir_width_1,
                         is_background=True)
            vis_boundary(img,
                         midline_expand(gt_midline),
                         midline_expand(pred_midline_real_limit),
                         save_dir_width_3,
                         is_background=True)
            print('save vis, finished!')

        end = time.time()
        logger_vis.info(
            'Eval: [{0}/{1}]\t'
            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(
                step, len(test_data_loader), batch_time=batch_time))

    final_dict = metric_class.get_final_dict()
    assd_mean = final_dict['assd'][0]
    logger.write('\n# ===== final stat ===== #')
    for key, value in final_dict.items():
        logger.write('{}: {}({})'.format(key, value[0], value[1]))

    return -assd_mean
Example #9
0
    def _validate(self, val_loader, model, verbose=False):
        """Evaluate `model` on `val_loader`.

        Args:
            val_loader: iterable of (inputs, targets) batches.
            model: network to evaluate; batches are moved to CUDA here.
            verbose: if True, print a one-line summary at the end.

        Returns:
            Dict with the average top-1 and top-5 accuracies.
        """
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        t1 = time.time()
        with torch.no_grad():
            model.eval()  # evaluation mode: no dropout / BN updates

            tic = time.time()
            bar = Bar('valid:', max=len(val_loader))
            for batch_idx, (inputs, targets) in enumerate(val_loader):
                # time spent waiting on the data loader
                data_time.update(time.time() - tic)

                input_var = inputs.cuda()
                target_var = targets.cuda()

                # forward pass and loss
                output = model(input_var)
                loss = self.criterion(output, target_var)

                # record loss and top-1/top-5 accuracy
                prec1, prec5 = accuracy(output.data, target_var, topk=(1, 5))
                n = inputs.size(0)
                losses.update(loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)

                # per-batch timing
                batch_time.update(time.time() - tic)
                tic = time.time()

                # progress display
                if batch_idx % 1 == 0:
                    bar.suffix = \
                        '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                        'Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                            batch=batch_idx + 1,
                            size=len(val_loader),
                            data=data_time.avg,
                            bt=batch_time.avg,
                            total=bar.elapsed_td,
                            eta=bar.eta_td,
                            loss=losses.avg,
                            top1=top1.avg,
                            top5=top5.avg,
                        )
                    bar.next()
            bar.finish()
        t2 = time.time()
        if verbose:
            print('* Test loss: %.3f  top1: %.3f  top5: %.3f  time: %.3f' %
                  (losses.avg, top1.avg, top5.avg, t2 - t1))

        return {"top1": top1.avg, "top5": top5.avg}
Example #10
0
def do_test(test_loader, model, cfg, visualize, writer_dict, final_output_dir):
    """Run the test loop, logging per-batch loss and, when enabled, volume
    RMSE/SSIM metrics to TensorBoard.

    Args:
        test_loader: yields dataset-specific batches consumed by
            ``model.set_dataset``.
        model: wrapper exposing ``set_dataset``/``forward``/``output``/
            ``target`` and a pixel-wise criterion.
        cfg: experiment config (normalization ranges, print frequency,
            ``TEST.IS_VOLUME_VISUALIZER`` flag).
        visualize: callback used to dump qualitative results.
        writer_dict: dict with the TensorBoard 'writer' and the mutable
            'test_global_steps' counter.
        final_output_dir: root directory for visualization output.

    Returns:
        The last per-batch loss value (``losses.val``).
    """
    model.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    SSIM = AverageMeter()
    RMSE = AverageMeter()
    end = time.time()

    writer = writer_dict['writer']

    is_volume_visualizer = cfg.TEST.IS_VOLUME_VISUALIZER
    if is_volume_visualizer:
        from lib.analyze.utilis import VolumeVisualization
        from lib.analyze.utilis import visualize as volume_visualizer_function
        volume_visualizer = VolumeVisualization(
            data_range=[
                cfg.DATASET.NORMALIZATION.AFTER_MIN,
                cfg.DATASET.NORMALIZATION.AFTER_MAX
            ],
            before_min=cfg.DATASET.NORMALIZATION.BEFORE_MIN,
            before_max=cfg.DATASET.NORMALIZATION.BEFORE_MAX,
            after_min=cfg.DATASET.NORMALIZATION.AFTER_MIN,
            after_max=cfg.DATASET.NORMALIZATION.AFTER_MAX,
            threshold=200)

    for i, current_data in enumerate(test_loader):
        data_time.update(time.time() - end)

        model.set_dataset(current_data)
        with torch.no_grad():
            model.forward()
            model.target = model.target[0]

        current_loss = model.criterion_pixel_wise_loss(model.output,
                                                       model.target)
        losses.update(current_loss.item())

        if is_volume_visualizer:
            current_rmse, current_ssim = volume_visualizer.update(
                current_data, model)
            RMSE.update(current_rmse)
            SSIM.update(current_ssim)

        batch_time.update(time.time() - end)
        end = time.time()

        if i % cfg.VAL.PRINT_FREQUENCY == 0:
            msg = 'Test: [{0}/{1}]\t' \
                  'Loss {losses.val:.5f} ({losses.avg:.5f})\t' \
                  'RMSE {RMSE.val:.5f}({RMSE.avg:.5f})\t' \
                  'SSIM {SSIM.val:.5f}({SSIM.avg:.5f})'.format(
                i, len(test_loader),
                losses=losses,
                RMSE=RMSE,
                SSIM=SSIM)

            logger.info(msg)

        # re-wrap output/target as lists for the visualization callback
        model.output = [model.output]
        model.target = [model.target]
        visualize(model, writer_dict['test_global_steps'],
                  os.path.join(final_output_dir, "test"), 1,
                  cfg.DATASET.NORMALIZATION.BEFORE_MIN,
                  cfg.DATASET.NORMALIZATION.BEFORE_MAX,
                  cfg.DATASET.NORMALIZATION.AFTER_MIN,
                  cfg.DATASET.NORMALIZATION.AFTER_MAX)
        writer.add_scalar('test_loss', losses.val,
                          writer_dict['test_global_steps'])
        writer.add_scalar('RMSE', RMSE.val, writer_dict['test_global_steps'])
        writer.add_scalar('SSIM', SSIM.val, writer_dict['test_global_steps'])
        writer_dict['test_global_steps'] += 1

    # log the slice-wise ssim and rmse
    # BUGFIX: volume_visualizer only exists when IS_VOLUME_VISUALIZER is
    # enabled; guard this loop to avoid a NameError otherwise.
    if is_volume_visualizer:
        for current_data_path, current_rmse, current_ssim in zip(
                volume_visualizer.data_path, volume_visualizer.rmse,
                volume_visualizer.ssim):
            logger.info('{}\tSSIM:{}\tRMSE:{}'.format(
                os.path.basename(current_data_path), current_rmse,
                current_ssim))
    logger.info('* 2D Test: \t Average RMSE {}\t SSIM {}\n'.format(
        RMSE.avg, SSIM.avg))

    if is_volume_visualizer:
        rmse, ssim = volume_visualizer_function(
            volume_visualizer, model, writer_dict['test_global_steps'],
            os.path.join(final_output_dir, "volume"), 1)
        msg = '* 3D Test: \t RMSE {}\t SSIM {}'.format(rmse, ssim)
        logger.info(msg)

    return losses.val