Example #1
    def _validate(self, val_loader, model, verbose=False):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        t1 = time.time()
        with torch.no_grad():
            # switch to evaluate mode
            model.eval()

            end = time.time()
            bar = Bar('valid:', max=len(val_loader))
            for i, (inputs, targets) in enumerate(val_loader):
                # measure data loading time
                data_time.update(time.time() - end)

                input_var, target_var = inputs.cuda(), targets.cuda()

                # compute output
                output = model(input_var)
                loss = self.criterion(output, target_var)

                # measure accuracy and record loss
                prec1, prec5 = accuracy(output.data, target_var, topk=(1, 5))
                losses.update(loss.item(), inputs.size(0))
                top1.update(prec1.item(), inputs.size(0))
                top5.update(prec5.item(), inputs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                # plot progress
                if i % 1 == 0:
                    bar.suffix = \
                        '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                        'Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                            batch=i + 1,
                            size=len(val_loader),
                            data=data_time.avg,
                            bt=batch_time.avg,
                            total=bar.elapsed_td,
                            eta=bar.eta_td,
                            loss=losses.avg,
                            top1=top1.avg,
                            top5=top5.avg,
                        )
                    bar.next()
            bar.finish()
        t2 = time.time()
        if verbose:
            print('* Test loss: %.3f  top1: %.3f  top5: %.3f  time: %.3f' %
                  (losses.avg, top1.avg, top5.avg, t2 - t1))
        if self.use_top5:
            return top5.avg
        else:
            return top1.avg
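None of the examples on this page defines AverageMeter itself. The calls they make (.update(value, n), .val, .avg, and the .average() accessor seen in Examples #4 and #7) are all satisfied by the small running-average helper from the PyTorch ImageNet reference. The sketch below is an assumption about what these repositories ship, not a copy of any of them.

class AverageMeter(object):
    """Tracks the latest value and a running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # val is the per-batch statistic, n the number of samples it covers
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def average(self):
        # some of the snippets call .average() instead of reading .avg
        return self.avg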
Example #2
def test(val_loader, model, criterion, epoch, use_cuda):
    global best_acc

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    with torch.no_grad():
        # switch to evaluate mode
        model.eval()

        end = time.time()
        bar = Bar('Processing', max=len(val_loader))
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            # Variable(..., volatile=True) is deprecated; torch.no_grad()
            # above already disables autograd, so the tensors are used as-is.

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            if batch_idx % 1 == 0:
                bar.suffix  = \
                    '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                    'Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                        batch=batch_idx + 1,
                        size=len(val_loader),
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        top1=top1.avg,
                        top5=top5.avg,
                        )
                bar.next()
        bar.finish()
    return losses.avg, top1.avg
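The accuracy helper behind the top-1/top-5 numbers is also undefined here. Below is a sketch of the conventional top-k precision helper (again modeled on the PyTorch ImageNet reference, and only assumed to match these repositories):

import torch

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # indices of the maxk largest logits, shape (maxk, batch_size)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # a sample counts as correct@k if the label is among its top k
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res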
Example #3
def train_integral(config, train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()

    for i, data in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        batch_data, batch_label, batch_label_weight, meta = data

        batch_data = batch_data.cuda()
        batch_label = batch_label.cuda()
        batch_label_weight = batch_label_weight.cuda()

        batch_size = batch_data.size(0)
        # compute output
        preds = model(batch_data)

        loss = criterion(preds, batch_label, batch_label_weight)
        del batch_data, batch_label, batch_label_weight, preds

        # compute gradient and do update step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # record loss
        losses.update(loss.item(), batch_size)
        del loss
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % config.PRINT_FREQ == 0:
            msg = 'Epoch: [{0}][{1}/{2}]\t' \
                  'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                  'Speed {speed:.1f} samples/s\t' \
                  'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
                  'Loss {loss.val:.5f} ({loss.avg:.5f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      speed=batch_size/batch_time.val,
                      data_time=data_time, loss=losses)
            logger.info(msg)
Example #4
def validate(config, testloader, model, writer_dict):
    model.eval()
    ave_loss = AverageMeter()
    confusion_matrix = np.zeros(
        (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES))
    with torch.no_grad():
        for _, batch in enumerate(testloader):
            image, label, _, _ = batch
            size = label.size()
            label = label.long().cuda()

            losses, pred = model(image, label)
            # F.upsample is deprecated in favor of F.interpolate
            pred = F.interpolate(input=pred,
                                 size=(size[-2], size[-1]),
                                 mode='bilinear')
            loss = losses.mean()
            ave_loss.update(loss.item())

            confusion_matrix += get_confusion_matrix(
                label, pred, size, config.DATASET.NUM_CLASSES,
                config.TRAIN.IGNORE_LABEL)

    pos = confusion_matrix.sum(1)
    res = confusion_matrix.sum(0)
    tp = np.diag(confusion_matrix)
    IoU_array = (tp / np.maximum(1.0, pos + res - tp))
    mean_IoU = IoU_array.mean()

    writer = writer_dict['writer']
    global_steps = writer_dict['valid_global_steps']
    writer.add_scalar('valid_loss', ave_loss.average(), global_steps)
    writer.add_scalar('valid_mIoU', mean_IoU, global_steps)
    writer_dict['valid_global_steps'] = global_steps + 1
    return ave_loss.average(), mean_IoU, IoU_array
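Example #4 leans on get_confusion_matrix to accumulate per-class counts, from which IoU falls out as tp / (pos + res - tp). A minimal sketch consistent with how it is called here (logit tensor pred of shape N x C x H x W, integer label map, an ignore index); the real HRNet helper may differ in detail:

import numpy as np

def get_confusion_matrix(label, pred, size, num_class, ignore=-1):
    """Batch confusion matrix: rows are ground truth, columns predictions."""
    seg_pred = pred.argmax(dim=1).cpu().numpy()
    seg_gt = label.cpu().numpy()[:, :size[-2], :size[-1]]

    mask = seg_gt != ignore                  # drop ignored pixels
    seg_gt, seg_pred = seg_gt[mask], seg_pred[mask]

    # encode each (gt, pred) pair as one index and histogram the pairs
    index = (seg_gt * num_class + seg_pred).astype('int32')
    counts = np.bincount(index, minlength=num_class * num_class)
    return counts.reshape(num_class, num_class).astype(np.float64)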
Example #5
def do_validate(val_loader, model, cfg, visualize, writer_dict,
                final_output_dir):
    batch_time = AverageMeter()
    end = time.time()
    model.eval()

    selected_visualized_data = random.randint(0, len(val_loader) - 1)
    for i, current_data in enumerate(val_loader):
        model.set_dataset(current_data)

        with torch.no_grad():
            model.forward()
            model.loss_calculation()

        batch_time.update(time.time() - end)
        end = time.time()

        performance = model.record_information(
            current_iteration=i,
            data_loader_size=len(val_loader),
            writer_dict=writer_dict,
            phase='val')
        if i == selected_visualized_data and cfg.IS_VISUALIZE:
            visualize(model, writer_dict['val_global_steps'],
                      os.path.join(final_output_dir, "val"), 1)

    return performance
Example #6
def train(args, train_loader, model, criterion, optimizer, epoch):
    losses = AverageMeter()
    ac_scores = AverageMeter()

    model.train()

    for i, (input, target) in tqdm(enumerate(train_loader),
                                   total=len(train_loader)):
        input = input.cuda()
        target = target.cuda()

        output = model(input)
        # only the GCN head returns an extra adjacency matrix
        if args.mode == 'gcn':
            output, adj = output

        if args.pred_type == 'classification':
            loss = criterion(output, target)
        elif args.pred_type == 'regression':
            loss = criterion(output.view(-1), target.float())
        elif args.pred_type == 'multitask':
            loss = args.reg_coef * criterion['regression'](output[:, 0], target.float()) + \
                args.cls_coef * \
                criterion['classification'](output[:, 1:], target)
            output = output[:, 0].unsqueeze(1)

        # compute gradient and do optimizing step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        ac_score = compute_accuracy(output, target)

        losses.update(loss.item(), input.size(0))
        ac_scores.update(ac_score, input.size(0))
    if args.mode == 'gcn':
        print(torch.max(adj))
    return losses.avg, ac_scores.avg
Example #7
def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
          trainloader, optimizer, model, writer_dict):
    # Training
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    tic = time.time()
    cur_iters = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']
    for i_iter, batch in enumerate(trainloader, 0):
        images, labels = batch
        images = images.cuda()
        labels = labels.long().cuda()

        losses, _ = model(images, labels)
        loss = losses.mean()

        model.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(loss.item())

        lr = adjust_learning_rate(optimizer, base_lr, num_iters,
                                  i_iter + cur_iters)

        if i_iter % config.PRINT_FREQ == 0:
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}' .format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      batch_time.average(), lr, ave_loss.average())
            logging.info(msg)

    writer.add_scalar('train_loss', ave_loss.average(), global_steps)
    writer_dict['train_global_steps'] = global_steps + 1
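adjust_learning_rate in Example #7 is called with the base rate, the total iteration budget, and the current global iteration, which is the shape of the polynomial ("poly") decay that HRNet-style training typically uses. A sketch under that assumption; the power of 0.9 is a conventional choice, not taken from this code:

def adjust_learning_rate(optimizer, base_lr, max_iters, cur_iters, power=0.9):
    """Poly schedule: lr = base_lr * (1 - t/T)^power; returns the new lr."""
    lr = base_lr * ((1.0 - float(cur_iters) / max_iters) ** power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr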
Example #8
def validate(args, val_loader, model, criterion):
    losses = AverageMeter()
    ac_scores = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for i, (input, target) in tqdm(enumerate(val_loader),
                                       total=len(val_loader)):
            input = input.cuda()
            target = target.cuda()

            output = model(input)
            # only the GCN head returns an extra adjacency matrix
            if args.mode == 'gcn':
                output, adj = output

            if args.pred_type == 'classification':
                loss = criterion(output, target)
            elif args.pred_type == 'regression':
                loss = criterion(output.view(-1), target.float())
            elif args.pred_type == 'multitask':
                loss = args.reg_coef * criterion['regression'](output[:, 0], target.float()) + \
                    args.cls_coef * \
                    criterion['classification'](output[:, 1:], target)
                output = output[:, 0].unsqueeze(1)

            ac_score = compute_accuracy(output, target)

            losses.update(loss.item(), input.size(0))
            ac_scores.update(ac_score, input.size(0))

    return losses.avg, ac_scores.avg
Example #9
def test(opt):
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str

    Dataset = dataset_factory[opt.dataset]
    opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
    print(opt)
    Logger(opt)
    opt.debug = max(opt.debug, 1)
    Detector = detector_factory[opt.task]

    split = 'val' if opt.trainval else 'test'
    dataset = Dataset(opt, split)
    detector = Detector(opt)

    results = {}
    num_iters = len(dataset)
    bar = Bar('{}'.format(opt.exp_id), max=num_iters)
    time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge']
    avg_time_stats = {t: AverageMeter() for t in time_stats}
    for ind in range(num_iters):
        img_id = dataset.images[ind]
        img_info = dataset.coco.loadImgs(ids=[img_id])[0]
        img_path = os.path.join(dataset.img_dir, img_info['file_name'])

        if opt.task == 'ddd':
            ret = detector.run(img_path, img_info['calib'])
        else:
            ret = detector.run(img_path)

        results[img_id] = ret['results']

        Bar.suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
            ind, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
        for t in avg_time_stats:
            avg_time_stats[t].update(ret[t])
            Bar.suffix = Bar.suffix + '|{} {:.3f} '.format(
                t, avg_time_stats[t].avg)
        bar.next()
    bar.finish()
    dataset.run_eval(results, opt.save_dir)
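The Bar progress meter threaded through these examples behaves like the one in the third-party progress package: suffix is text appended to the bar, next()/finish() advance and close it, and elapsed_td/eta_td expose timedeltas. Note that Example #9 assigns Bar.suffix on the class rather than the instance; since the instance never sets its own suffix, attribute lookup falls back to the class and the text still renders. A minimal standalone usage sketch, assuming that package:

import time
from progress.bar import Bar

bar = Bar('Processing', max=100)
for i in range(100):
    time.sleep(0.01)  # stand-in for real work
    bar.suffix = '({}/{}) Total: {} | ETA: {}'.format(
        i + 1, 100, bar.elapsed_td, bar.eta_td)
    bar.next()
bar.finish()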
Example #10
    def record_information(self, current_iteration=None, data_loader_size=None, batch_time=None, data_time=None,
                           indicator_dict=None, writer_dict=None, phase='train'):
        writer = writer_dict['writer']
        msg = None
        if phase == 'train':
            self.losses_train.update(self.loss.item())
            indicator_dict['current_iteration'] += 1
            global_steps = writer_dict['train_global_steps']
            writer.add_scalar('train_loss', self.loss.item(), global_steps)
            writer_dict['train_global_steps'] = global_steps + 1
            if current_iteration % self.cfg.TRAIN.PRINT_FREQUENCY == 0:
                msg = 'Iteration: [{0}/{1}]\t' \
                      'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                      'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
                      'LR: {LR:.6f}\t' \
                      'Loss {losses.val:.5f} ({losses.avg:.5f})'.format(
                    current_iteration, data_loader_size,
                    batch_time=batch_time,
                    data_time=data_time,
                    LR=self.schedulers[0].get_last_lr()[0],
                    losses=self.losses_train)
        elif phase == 'val':
            if current_iteration == 0:
                self.losses_val = AverageMeter()
            self.losses_val.update(self.loss.item())

            if current_iteration == data_loader_size - 1:
                global_steps = writer_dict['val_global_steps']
                writer.add_scalar('val_loss', self.loss.item(), global_steps)
                writer_dict['val_global_steps'] = global_steps + 1

            if current_iteration % self.cfg.VAL.PRINT_FREQUENCY == 0:
                msg = 'Val: [{0}/{1}]\t' \
                      'Loss {losses.val:.5f} ({losses.avg:.5f})'.format(
                    current_iteration, data_loader_size,
                    losses=self.losses_val)
        else:
            raise ValueError('Unknown operation in information recording!')
        if msg is not None:
            logger.info(msg)
        return self.losses_val.avg if phase == 'val' else self.losses_train.avg
Example #11
def prefetch_test(opt):
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str

    Dataset = dataset_factory[opt.dataset]
    opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
    print(opt)
    Logger(opt)
    Detector = detector_factory[opt.task]

    split = 'val' if not opt.trainval else 'test'
    dataset = Dataset(opt, split)
    detector = Detector(opt)

    data_loader = torch.utils.data.DataLoader(PrefetchDataset(
        opt, dataset, detector.pre_process),
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=1,
                                              pin_memory=True)

    results = {}
    num_iters = len(dataset)
    bar = Bar('{}'.format(opt.exp_id), max=num_iters)
    time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge']
    avg_time_stats = {t: AverageMeter() for t in time_stats}
    for ind, (img_id, pre_processed_images) in enumerate(data_loader):
        ret = detector.run(pre_processed_images)
        img_id = np.array(img_id)
        results[img_id.astype(np.int32)[0]] = ret['results']
        Bar.suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
            ind, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
        for t in avg_time_stats:
            avg_time_stats[t].update(ret[t])
            Bar.suffix = Bar.suffix + '|{} {tm.val:.3f}s ({tm.avg:.3f}s) '.format(
                t, tm=avg_time_stats[t])
        bar.next()
    bar.finish()
    dataset.run_eval(results, opt.save_dir)
Example #12
def val(cfg, model, val_data_loader):
    model.eval()
    torch.set_grad_enabled(False)
    # Indicator to log
    batch_time = AverageMeter()
    HD = []

    # The method to predict midline: Left, Right, Max, DP(dynamic programming)

    end = time.time()
    print('# ===== Validation ===== #')
    for step, (input, target, _, gt_curves, _,
               _) in enumerate(val_data_loader):
        target[target != 0] = 1
        label = target.numpy().astype('uint8')
        # Variable is deprecated; move the input to the GPU directly
        input_var = input.cuda()

        with torch.no_grad():
            # forward
            pred = model(input_var)
            midline = pred2midline(pred,
                                   gt_curves,
                                   model_name=cfg.model_param.model_name)[0]
            hd_info = compute_assd(midline, label)
            HD.extend(hd_info)
            batch_time.update(time.time() - end)

        end = time.time()
        logger_vis.info(
            'Eval: [{0}/{1}]\t'
            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(
                step, len(val_data_loader), batch_time=batch_time))
    hd_mean = np.mean(HD)

    return -hd_mean
Example #13
    def train(self):
        # Single epoch training routine

        losses = AverageMeter()

        timer = {
            'data': 0,
            'forward': 0,
            'loss': 0,
            'backward': 0,
            'batch': 0,
        }

        self.generator.train()
        self.motion_discriminator.train()

        start = time.time()

        summary_string = ''

        bar = Bar(f'Epoch {self.epoch + 1}/{self.end_epoch}',
                  fill='#',
                  max=self.num_iters_per_epoch)

        for i in range(self.num_iters_per_epoch):
            # Dirty solution to reset an iterator
            target_2d = target_3d = None
            if self.train_2d_iter:
                try:
                    target_2d = next(self.train_2d_iter)
                except StopIteration:
                    self.train_2d_iter = iter(self.train_2d_loader)
                    target_2d = next(self.train_2d_iter)

                move_dict_to_device(target_2d, self.device)

            if self.train_3d_iter:
                try:
                    target_3d = next(self.train_3d_iter)
                except StopIteration:
                    self.train_3d_iter = iter(self.train_3d_loader)
                    target_3d = next(self.train_3d_iter)

                move_dict_to_device(target_3d, self.device)

            real_body_samples = real_motion_samples = None

            try:
                real_motion_samples = next(self.disc_motion_iter)
            except StopIteration:
                self.disc_motion_iter = iter(self.disc_motion_loader)
                real_motion_samples = next(self.disc_motion_iter)

            move_dict_to_device(real_motion_samples, self.device)

            # <======= Feedforward generator and discriminator
            if target_2d and target_3d:
                inp = torch.cat((target_2d['features'], target_3d['features']),
                                dim=0).to(self.device)
            elif target_3d:
                inp = target_3d['features'].to(self.device)
            else:
                inp = target_2d['features'].to(self.device)

            timer['data'] = time.time() - start
            start = time.time()

            preds = self.generator(inp)

            timer['forward'] = time.time() - start
            start = time.time()

            gen_loss, motion_dis_loss, loss_dict = self.criterion(
                generator_outputs=preds,
                data_2d=target_2d,
                data_3d=target_3d,
                data_body_mosh=real_body_samples,
                data_motion_mosh=real_motion_samples,
                motion_discriminator=self.motion_discriminator,
            )
            # =======>

            timer['loss'] = time.time() - start
            start = time.time()

            # <======= Backprop generator and discriminator
            self.gen_optimizer.zero_grad()
            gen_loss.backward()
            self.gen_optimizer.step()

            if self.train_global_step % self.dis_motion_update_steps == 0:
                self.dis_motion_optimizer.zero_grad()
                motion_dis_loss.backward()
                self.dis_motion_optimizer.step()
            # =======>

            # <======= Log training info
            total_loss = gen_loss + motion_dis_loss

            losses.update(total_loss.item(), inp.size(0))

            timer['backward'] = time.time() - start
            timer['batch'] = timer['data'] + timer['forward'] + timer[
                'loss'] + timer['backward']
            start = time.time()

            summary_string = f'({i + 1}/{self.num_iters_per_epoch}) | Total: {bar.elapsed_td} | ' \
                             f'ETA: {bar.eta_td:} | loss: {losses.avg:.4f}'

            for k, v in loss_dict.items():
                summary_string += f' | {k}: {v:.2f}'
                self.writer.add_scalar('train_loss/' + k,
                                       v,
                                       global_step=self.train_global_step)

            for k, v in timer.items():
                summary_string += f' | {k}: {v:.2f}'

            self.writer.add_scalar('train_loss/loss',
                                   total_loss.item(),
                                   global_step=self.train_global_step)

            if self.debug:
                print('==== Visualize ====')
                from lib.utils.vis import batch_visualize_vid_preds
                video = target_3d['video']
                dataset = 'spin'
                vid_tensor = batch_visualize_vid_preds(video,
                                                       preds[-1],
                                                       target_3d.copy(),
                                                       vis_hmr=False,
                                                       dataset=dataset)
                self.writer.add_video('train-video',
                                      vid_tensor,
                                      global_step=self.train_global_step,
                                      fps=10)

            self.train_global_step += 1
            bar.suffix = summary_string
            bar.next()

            if torch.isnan(total_loss):
                exit('Nan value in loss, exiting!...')
            # =======>

        bar.finish()

        logger.info(summary_string)
Example #14
    def run_epoch(self, phase, epoch, data_loader):
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()

        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)

            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)
            output, loss, loss_stats = model_with_loss(batch)
            loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()

            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                iter_id,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                  '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()

            if opt.debug > 0:
                self.debug(batch, output, iter_id)

            if opt.test:
                self.save_result(output, batch, results)
            del output, loss, loss_stats

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results
Example #15
def test(cfg, model, test_data_loader, logger):
    model.eval()
    torch.set_grad_enabled(False)
    # Indicator to log
    batch_time = AverageMeter()
    shift_list = cfg._get_shift_list()
    metric_class = Metric()

    # The method to predict midline: Left, Right, Max, DP(dynamic programming)

    end = time.time()
    print('# ===== TEST ===== #')
    for step, (input, target, _, gt_curves, ori_img,
               path) in enumerate(test_data_loader):
        pid = path[0].split('/')[-2]
        target[target != 0] = 1
        gt_midline = target.numpy().astype(np.uint8)
        img = ori_img.numpy().astype(np.uint8)
        # Variable is deprecated; move the input to the GPU directly
        input_var = input.cuda()

        with torch.no_grad():
            # forward
            pred = model(input_var)
            pred_midline_real_limit, pred_midline_gt_limit = pred2midline(
                pred, gt_curves, model_name=cfg.model_param.model_name)
            metric_class.process(pred_midline_real_limit,
                                 pred_midline_gt_limit, gt_midline)
            batch_time.update(time.time() - end)

        # save vis png
        if cfg.test_setting_param.save_vis:
            pid_save_name = pid + '.shift' if pid in shift_list else pid
            save_dir_width_1 = osp.join('{}.{}'.format(cfg.vis_dir, '1'),
                                        pid_save_name)
            save_dir_width_3 = osp.join('{}.{}'.format(cfg.vis_dir, '3'),
                                        pid_save_name)
            ensure_dir(save_dir_width_1)
            ensure_dir(save_dir_width_3)
            vis_boundary(img,
                         gt_midline,
                         pred_midline_real_limit,
                         save_dir_width_1,
                         is_background=True)
            vis_boundary(img,
                         midline_expand(gt_midline),
                         midline_expand(pred_midline_real_limit),
                         save_dir_width_3,
                         is_background=True)
            print('save vis, finished!')

        end = time.time()
        logger_vis.info(
            'Eval: [{0}/{1}]\t'
            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(
                step, len(test_data_loader), batch_time=batch_time))

    final_dict = metric_class.get_final_dict()
    assd_mean = final_dict['assd'][0]
    logger.write('\n# ===== final stat ===== #')
    for key, value in final_dict.items():
        logger.write('{}: {}({})'.format(key, value[0], value[1]))

    return -assd_mean
Example #16
    def _validate(self, val_loader, model, verbose=False):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()

        t1 = time.time()
        with torch.no_grad():
            # switch to evaluate mode
            model.eval()

            end = time.time()
            bar = Bar('valid:', max=len(val_loader))
            for i, batch in enumerate(val_loader):
                data_time.update(time.time() - end)
                batch = tuple(t.cuda() for t in batch)

                input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
                outputs = model(input_ids=input_ids,
                                token_type_ids=segment_ids,
                                attention_mask=input_mask,
                                masked_lm_labels=lm_label_ids,
                                next_sentence_label=is_next)
                loss = outputs[0]
                if self.n_gpu > 1:
                    loss = loss.mean()

                # record loss
                losses.update(loss.item(), input_ids.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                # plot progress

                if i % 1 == 0:
                    bar.suffix = \
                        '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                        'Loss: {loss:.4f}'.format(
                            batch=i + 1,
                            size=len(val_loader),
                            data=data_time.avg,
                            bt=batch_time.avg,
                            total=bar.elapsed_td,
                            eta=bar.eta_td,
                            loss=losses.avg,
                        )
                    bar.next()
            bar.finish()
        t2 = time.time()
        if verbose:
            print('* Test loss: %.3f  time: %.3f' % (losses.avg, t2 - t1))

        return losses.avg
Example #17
def do_train(train_loader, val_loader, model, indicator_dict, cfg, writer_dict,
             final_output_dir, log_dir, visualize):
    batch_time = AverageMeter()
    data_time = AverageMeter()

    end = time.time()
    for i, current_data in enumerate(
            train_loader, start=indicator_dict['current_iteration']):
        data_time.update(time.time() - end)

        if i > indicator_dict['total_iteration']:
            return

        # validation
        if indicator_dict[
                'current_iteration'] % cfg.VAL.EVALUATION_FREQUENCY == 0:
            indicator_dict['current_performance'] = do_validate(
                val_loader, model, cfg, visualize, writer_dict,
                final_output_dir)
            indicator_dict['is_best'] = False
            if indicator_dict['current_performance'] < indicator_dict[
                    'best_performance']:
                indicator_dict['best_performance'] = indicator_dict[
                    'current_performance']
                indicator_dict['is_best'] = True

            # save checkpoint
            output_dictionary = {
                'indicator_dict': indicator_dict,
                'writer_dict_train_global_steps':
                writer_dict['train_global_steps'],
                'writer_dict_val_global_steps':
                writer_dict['val_global_steps'],
                'tb_log_dir': log_dir
            }

            if hasattr(model, 'generator'):
                output_dictionary['generator'] = model.generator.state_dict()
                output_dictionary[
                    'optimizer_generator'] = model.optimizer_generator.state_dict(
                    )

            if hasattr(model, 'discriminator'):
                output_dictionary[
                    'discriminator'] = model.discriminator.state_dict()
                output_dictionary[
                    'optimizer_discriminator'] = model.optimizer_discriminator.state_dict(
                    )

            save_checkpoint(output_dictionary, indicator_dict,
                            final_output_dir)
            model.train()

        # train
        model.set_dataset(current_data)
        model.optimize_parameters()

        # visualize
        if indicator_dict[
                'current_iteration'] % cfg.TRAIN.DISPLAY_FREQUENCY == 0 and cfg.IS_VISUALIZE:
            visualize(model, indicator_dict['current_iteration'],
                      os.path.join(final_output_dir, "train"),
                      cfg.TRAIN.DISPLAY_FREQUENCY)

        # update learning rate
        for current_scheduler in model.schedulers:
            current_scheduler.step()

        batch_time.update(time.time() - end)
        end = time.time()
        model.record_information(i,
                                 len(train_loader),
                                 batch_time,
                                 data_time,
                                 indicator_dict,
                                 writer_dict,
                                 phase='train')
Example #18
def train(args, train_loader, index_loader, val_loader, model, optimizer,
          scheduler, criterion, hash_center):

    print("\n\n\n Start Train! \n\n\n")

    valid_result = {}
    for epc in range(args.epoch):
        model.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()

        start = time.time()
        pbar = tqdm.tqdm(enumerate(train_loader), desc="Epoch : %d" % epc)

        for batch_i, (data, label) in pbar:
            center = hash_center_multilables(label, hash_center)

            data = data.float().cuda()
            label = label.cuda()
            center = center.cuda()

            # torch.autograd.Variable is deprecated; tensors carry their own
            # autograd state, so the inputs are used directly
            input_var = data
            center_var = center

            data_time.update(time.time() - start)
            start = time.time()

            output = model(input_var)

            loss = criterion(0.5 * (output + 1), 0.5 * (center_var + 1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.update(loss.item())
            batch_time.update(time.time() - start)
            start = time.time()
            if args.wandb:
                wandb.log({"Train/Batch Loss": losses.avg})

            state_msg = (
                'Epoch: {:4d}; Loss: {:0.5f}; Data time: {:0.5f}; Batch time: {:0.5f};'
                .format(epc, losses.avg, data_time.avg, batch_time.avg))

            pbar.set_description(state_msg)

        scheduler.step()
        if args.wandb:
            wandb.log({"Train/Epoch Loss": losses.avg})

        if epc % args.save_frequency == 0:
            mAP = evaluation(args, index_loader, val_loader, model)
            valid_result.update({epc: mAP})
            if args.wandb:
                wandb.log({"Valid/Epoch mAP@{}".format(args.R): mAP})

        if epc % args.save_frequency == 0:
            ckpt_path = os.path.join(args.model_dir, "ckpt",
                                     "ckpt_{:05d}.pth".format(epc))
            torch.save(
                {
                    'epoch': epc,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'valid_result': valid_result
                }, ckpt_path)
Example #19
def train(cfg, model, train_loader, epoch):
    # set the AverageMeter
    Batch_time = AverageMeter()
    Eval_time = AverageMeter()
    dice = AverageMeter()
    loss_1 = AverageMeter()
    loss_2 = AverageMeter()
    loss_3 = AverageMeter()
    losses = AverageMeter()
    torch.set_grad_enabled(True)
    # switch to train mode
    model.train()
    start = time.time()
    # loop in dataloader
    for iter, (img, gt_mid, gt_mid_3, gt_x_coords, pos_weight, wce_weight,
               heatmap) in enumerate(train_loader):
        curr_iter = iter + 1 + epoch * len(train_loader)
        total_iter = cfg.training_setting_param.epochs * len(train_loader)
        lr = adjust_learning_rate(cfg, cfg.training_setting_param, epoch,
                                  curr_iter, total_iter)

        # Variable is deprecated; move the tensors to the GPU directly
        input_var = img.cuda()
        gt_mid_3 = gt_mid_3.cuda()
        gt_x_coords = gt_x_coords.cuda()
        pos_weight = pos_weight.cuda()
        wce_weight = wce_weight.cuda()
        heatmap = heatmap.cuda()
        gt_mid_3[gt_mid_3 != 0] = 1
        # forward
        pred = model(input_var)
        # loss
        targets = (gt_x_coords, gt_mid_3, pos_weight, wce_weight, heatmap)
        loss, loss1, loss2, loss3 = get_total_loss(cfg, pred, targets, epoch)
        loss_1.update(loss1.item(), input_var.size(0))
        loss_2.update(loss2.item(), input_var.size(0))
        loss_3.update(loss3.item(), input_var.size(0))
        losses.update(loss.item(), input_var.size(0))
        # mid
        end = time.time()
        # metrics: dice for ven
        Eval_time.update(time.time() - end)
        cfg.optimizer.zero_grad()
        loss.backward()
        cfg.optimizer.step()
        # measure elapsed time
        Batch_time.update(time.time() - start)
        start = time.time()

        if iter % cfg.print_freq == 0:
            logger_vis.info(
                'Epoch-Iter: [{0}/{1}]-[{2}/{3}]\t'
                'LR: {4:.6f}\t'
                'Batch_Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Eval_Time {eval_time.val:.3f} ({eval_time.avg:.3f})\t'
                'Loss1 {loss_1.val:.3f} ({loss_1.avg:.3f})\t'
                'Loss2 {loss_2.val:.3f} ({loss_2.avg:.3f})\t'
                'Loss3 {loss_3.val:.3f} ({loss_3.avg:.3f})\t'
                'Loss {loss.val:.3f} ({loss.avg:.3f})\t'.format(
                    epoch,
                    cfg.training_setting_param.epochs,
                    iter,
                    len(train_loader),
                    lr,
                    batch_time=Batch_time,
                    eval_time=Eval_time,
                    loss_1=loss_1,
                    loss_2=loss_2,
                    loss_3=loss_3,
                    loss=losses))

    return losses.avg, dice.avg
Example #20
    def __init__(self, is_train=True):
        super(BaseModel, self).__init__()
        self.is_train = is_train
        self.losses_train = AverageMeter()
Example #21
def train(train_loader, model, criterion, optimizer, epoch, use_cuda):
    # switch to train mode
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    bar = Bar('Processing', max=len(train_loader))
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        # Variable is deprecated; the tensors are used directly

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # compute gradient
        optimizer.zero_grad()
        loss.backward()
        # do SGD step
        optimizer.step()

        kmeans_update_model(model,
                            quantizable_idx,
                            centroid_label_dict,
                            free_high_bit=args.free_high_bit)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        if batch_idx % 1 == 0:
            bar.suffix = \
                '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                'Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                    batch=batch_idx + 1,
                    size=len(train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,
                    top5=top5.avg,
                )
            bar.next()
    bar.finish()
    return losses.avg, top1.avg
Example #22
    def _finetune(self, train_loader, model, epochs=1, verbose=True):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        best_acc = 0.

        # switch to train mode
        model.train()
        end = time.time()
        t1 = time.time()
        for epoch in range(epochs):
            # a fresh bar per epoch, since bar.finish() is called every epoch
            bar = Bar('train:', max=len(train_loader))
            for i, (inputs, targets) in enumerate(train_loader):
                input_var, target_var = inputs.cuda(), targets.cuda()

                # measure data loading time
                data_time.update(time.time() - end)

                # compute output
                output = model(input_var)
                loss = self.criterion(output, target_var)

                # measure accuracy and record loss
                prec1, prec5 = accuracy(output.data, target_var, topk=(1, 5))
                losses.update(loss.item(), inputs.size(0))
                top1.update(prec1.item(), inputs.size(0))
                top5.update(prec5.item(), inputs.size(0))

                # compute gradient
                self.optimizer.zero_grad()
                loss.backward()

                # do SGD step
                self.optimizer.step()

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                if i % 1 == 0:
                    bar.suffix = \
                        '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                        'Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                            batch=i + 1,
                            size=len(train_loader),
                            data=data_time.val,
                            bt=batch_time.val,
                            total=bar.elapsed_td,
                            eta=bar.eta_td,
                            loss=losses.avg,
                            top1=top1.avg,
                            top5=top5.avg,
                        )
                    bar.next()
            bar.finish()

            if self.use_top5:
                if top5.avg > best_acc:
                    best_acc = top5.avg
            else:
                if top1.avg > best_acc:
                    best_acc = top1.avg
            self.adjust_learning_rate()
        t2 = time.time()
        if verbose:
            print('* Train loss: %.3f  top1: %.3f  top5: %.3f  time: %.3f' %
                  (losses.avg, top1.avg, top5.avg, t2 - t1))
        return best_acc
Example #23
    def run_epoch(self, phase, epoch, data_loader):
        """
        :param phase:
        :param epoch:
        :param data_loader:
        :return:
        """
        model_with_loss = self.model_with_loss

        if phase == 'train':
            model_with_loss.train()  # train phase
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module

            model_with_loss.eval()  # test phase
            torch.cuda.empty_cache()

        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()

        # train each batch
        # print('Total {} batches in an epoch.'.format(len(data_loader) + 1))
        for batch_i, batch in enumerate(data_loader):
            if batch_i >= num_iters:
                break

            data_time.update(time.time() - end)

            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)

            # Forward
            output, loss, loss_stats = model_with_loss.forward(batch)

            # Backwards
            loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()  # clear accumulated gradients
                loss.backward()  # back-propagate the gradients
                self.optimizer.step()  # update weights from the gradients

            batch_time.update(time.time() - end)
            end = time.time()

            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                batch_i,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            for l in avg_loss_stats:
                try:
                    avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                             batch['input'].size(0))
                except AttributeError:
                    # loss_stats[l] can arrive as a plain float instead of a
                    # tensor; leave the meter unchanged in that case
                    print('\n>> loss_stats[{}] is not a tensor: {}\n'.format(
                        l, loss_stats[l]))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)

            # multi-scale img_size display
            scale_idx = data_loader.dataset.batch_i_to_scale_i[batch_i]
            if data_loader.dataset.input_multi_scales is None:
                img_size = Input_WHs[scale_idx]
            else:
                img_size = data_loader.dataset.input_multi_scales[scale_idx]
            Bar.suffix = Bar.suffix + '|Img_size(wh) {:d}×{:d}'.format(
                img_size[0], img_size[1])

            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                                          '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if batch_i % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()

            if opt.test:
                self.save_result(output, batch, results)
            del output, loss, loss_stats, batch

        # randomly do multi-scaling for dataset every epoch
        data_loader.dataset.rand_scale()  # re-assign scale for each batch

        # shuffule the dataset every epoch
        data_loader.dataset.shuffle()  # re-assign file id for each idx

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.0

        return ret, results
Example #24
    def train(self):
        # Single epoch training routine

        losses = AverageMeter()

        timer = {
            'data': 0,
            'forward': 0,
            'loss': 0,
            'backward': 0,
            'batch': 0,
        }

        self.generator.train()
        start = time.time()
        summary_string = ''
        bar = Bar(f'Epoch {self.epoch + 1}/{self.end_epoch}',
                  fill='#',
                  max=self.num_iters_per_epoch)

        for i in range(self.num_iters_per_epoch):
            # Dirty solution to reset an iterator
            target_3dpw = None
            if self.train_3dpw_iter:
                try:
                    target_3dpw = next(self.train_3dpw_iter)
                except StopIteration:
                    self.train_3dpw_iter = iter(self.valid_loader)
                    target_3dpw = next(self.train_3dpw_iter)
            move_dict_to_device(target_3dpw, self.device)
            timer['data'] = time.time() - start
            start = time.time()
            inp = target_3dpw['features']

            preds = self.generator(inp)  # dict of outputs keyed by 'pose', etc.

            timer['forward'] = time.time() - start
            start = time.time()

            gen_loss = self.criterion(generator_outputs=preds,
                                      real_pose=target_3dpw['pose'],
                                      real_shape=target_3dpw['shape'],
                                      real_joints3D=target_3dpw['joints3D'])
            # =======>

            timer['loss'] = time.time() - start
            start = time.time()

            # <======= Backprop generator and discriminator
            self.gen_optimizer.zero_grad()
            gen_loss.backward()
            self.gen_optimizer.step()

            # <======= Log training info
            total_loss = gen_loss

            losses.update(total_loss.item(), inp.size(0))

            timer['backward'] = time.time() - start

            timer['batch'] = timer['data'] + timer['forward'] + timer[
                'loss'] + timer['backward']
            start = time.time()

            summary_string = f'({i + 1}/{self.num_iters_per_epoch}) | Total: {bar.elapsed_td} | ' \
                             f'ETA: {bar.eta_td:} | loss: {losses.avg:.4f}'

            for k, v in timer.items():
                summary_string += f' | {k}: {v:.2f}'

            bar.suffix = summary_string
            bar.next()

            if torch.isnan(total_loss):
                exit('Nan value in loss, exiting!...')
            # =======>
        bar.finish()
        logger.info(summary_string)
Example #25
def do_test(test_loader, model, cfg, visualize, writer_dict, final_output_dir):
    model.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    SSIM = AverageMeter()
    RMSE = AverageMeter()
    end = time.time()

    writer = writer_dict['writer']

    is_volume_visualizer = cfg.TEST.IS_VOLUME_VISUALIZER
    if is_volume_visualizer:
        from lib.analyze.utilis import VolumeVisualization
        from lib.analyze.utilis import visualize as volume_visualizer_function
        volume_visualizer = VolumeVisualization(
            data_range=[
                cfg.DATASET.NORMALIZATION.AFTER_MIN,
                cfg.DATASET.NORMALIZATION.AFTER_MAX
            ],
            before_min=cfg.DATASET.NORMALIZATION.BEFORE_MIN,
            before_max=cfg.DATASET.NORMALIZATION.BEFORE_MAX,
            after_min=cfg.DATASET.NORMALIZATION.AFTER_MIN,
            after_max=cfg.DATASET.NORMALIZATION.AFTER_MAX,
            threshold=200)

    for i, current_data in enumerate(test_loader):
        data_time.update(time.time() - end)

        model.set_dataset(current_data)
        with torch.no_grad():
            model.forward()
            #model.output = model.output[0]
            model.target = model.target[0]

        current_loss = model.criterion_pixel_wise_loss(model.output,
                                                       model.target)
        losses.update(current_loss.item())

        if is_volume_visualizer:
            current_rmse, current_ssim = volume_visualizer.update(
                current_data, model)
            RMSE.update(current_rmse)
            SSIM.update(current_ssim)

        batch_time.update(time.time() - end)
        end = time.time()

        if i % cfg.VAL.PRINT_FREQUENCY == 0:
            msg = 'Test: [{0}/{1}]\t' \
                  'Loss {losses.val:.5f} ({losses.avg:.5f})\t' \
                  'RMSE {RMSE.val:.5f}({RMSE.avg:.5f})\t' \
                  'SSIM {SSIM.val:.5f}({SSIM.avg:.5f})'.format(
                i, len(test_loader),
                losses=losses,
                RMSE=RMSE,
                SSIM=SSIM)

            logger.info(msg)

        model.output = [model.output]
        model.target = [model.target]
        visualize(model, writer_dict['test_global_steps'],
                  os.path.join(final_output_dir, "test"), 1,
                  cfg.DATASET.NORMALIZATION.BEFORE_MIN,
                  cfg.DATASET.NORMALIZATION.BEFORE_MAX,
                  cfg.DATASET.NORMALIZATION.AFTER_MIN,
                  cfg.DATASET.NORMALIZATION.AFTER_MAX)
        writer.add_scalar('test_loss', losses.val,
                          writer_dict['test_global_steps'])
        writer.add_scalar('RMSE', RMSE.val, writer_dict['test_global_steps'])
        writer.add_scalar('SSIM', SSIM.val, writer_dict['test_global_steps'])
        writer_dict['test_global_steps'] += 1

    if is_volume_visualizer:
        # log the slice-wise ssim and rmse
        for current_data_path, current_rmse, current_ssim in zip(
                volume_visualizer.data_path, volume_visualizer.rmse,
                volume_visualizer.ssim):
            logger.info('{}\tSSIM:{}\tRMSE:{}'.format(
                os.path.basename(current_data_path), current_ssim,
                current_rmse))
    logger.info('* 2D Test: \t Average RMSE {}\t SSIM {}\n'.format(
        RMSE.avg, SSIM.avg))

    if is_volume_visualizer:
        rmse, ssim = volume_visualizer_function(
            volume_visualizer, model, writer_dict['test_global_steps'],
            os.path.join(final_output_dir, "volume"), 1)
        msg = '* 3D Test: \t RMSE {}\t SSIM {}'.format(rmse, ssim)
        logger.info(msg)

    return losses.val