# Example No. 1
# 0
def test(params):
    """Run BTS inference on the online-eval split and write result images.

    For every sample yielded by ``BtsDataLoader(args, 'online_eval')`` this:

    * runs the model and keeps the final depth plus the LPG 8x8/4x4/2x2 and
      reduction 1x1 outputs (as squeezed numpy arrays),
    * paints a normal/difference comparison of estimate vs. ground truth
      into ``result_<model_name>/nd/``.

    Afterwards it writes, per line of ``args.filenames_file``:

    * the predicted depth as a 16-bit PNG in ``raw/`` (x256 for KITTI,
      x1000 otherwise),
    * when ``args.save_lpg`` is set: a cropped RGB image in ``rgb/``, log10
      colour maps of the depth and LPG outputs in ``cmap/`` and, for nyu,
      a log10 GT map in ``gt/``.

    NOTE(review): ``params`` is accepted but never used; the body reads the
    module-level ``args`` and ``nd_model``, which must exist when this is
    called. Presumably callers pass that same ``args`` object -- confirm.
    """
    args.mode = 'online_eval'
    dataloader = BtsDataLoader(args, 'online_eval')

    model = BtsModel(params=args)
    model = torch.nn.DataParallel(model)

    checkpoint = torch.load(args.checkpoint_path)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    model.cuda()

    num_test_samples = get_num_lines(args.filenames_file)

    with open(args.filenames_file) as f:
        lines = f.readlines()

    print('now testing {} files with {}'.format(num_test_samples,
                                                args.checkpoint_path))

    pred_depths = []
    pred_8x8s = []
    pred_4x4s = []
    pred_2x2s = []
    pred_1x1s = []

    save_name = 'result_' + args.model_name
    _make_result_dirs(save_name)

    start_time = time.time()
    with torch.no_grad():
        for s, sample in enumerate(tqdm(dataloader.data)):
            image = sample['image'].cuda()
            focal = sample['focal'].cuda()
            depth_gt = sample['depth'].cuda()

            # Move the last axis to position 1 (presumably NHWC -> NCHW) so
            # the GT layout matches the model output.
            depth_gt = depth_gt.transpose(2, 3).transpose(1, 2)

            # Predict
            lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(image, focal)
            pred_depths.append(depth_est.cpu().numpy().squeeze())
            pred_8x8s.append(lpg8x8[0].cpu().numpy().squeeze())
            pred_4x4s.append(lpg4x4[0].cpu().numpy().squeeze())
            pred_2x2s.append(lpg2x2[0].cpu().numpy().squeeze())
            pred_1x1s.append(reduc1x1[0].cpu().numpy().squeeze())

            # NOTE(review): zeroes the top-left pixel of the estimate before
            # the normal/diff pass (after it was stored above); intent is not
            # visible here -- confirm.
            depth_est[:, :, 0, 0] = 0.
            nd_gt, diff_gt, invd_bmask = nd_model(depth_gt)
            nd_est, diff_est, _ = nd_model(depth_est)

            # Normal/difference visualisation. The name always follows the
            # "<scene>/<frame>" (nyu-style) layout of the first token.
            path_parts = lines[s].split()[0].split('/')
            filename_nd_png = (save_name + '/nd/' + path_parts[0] + '_' +
                               path_parts[1].replace('.jpg', '.png'))

            paint_multiple(image[0].cpu().detach(),
                           depth_est[0].cpu().detach(),
                           depth_gt[0].cpu().detach(),
                           None,
                           nd_est[0].cpu().detach(),
                           nd_gt[0].cpu().detach(),
                           None,
                           diff_est[0].cpu().detach(),
                           diff_gt[0].cpu().detach(),
                           images_per_row=3,
                           to_screen=False,
                           to_file=filename_nd_png)

    elapsed_time = time.time() - start_time
    print('Elapesed time: %s' % str(elapsed_time))
    print('Done.')

    print('Saving result pngs..')

    for s in tqdm(range(num_test_samples)):
        line = lines[s]
        (filename_pred_png, filename_cmap_png, filename_gt_png,
         filename_image_png) = _result_paths(save_name, line, args.dataset)

        rgb_path = os.path.join(args.data_path, './' + line.split()[0])
        image = cv2.imread(rgb_path)
        if args.dataset == 'nyu':
            gt_path = os.path.join(args.data_path, './' + line.split()[1])
            gt = cv2.imread(gt_path, -1).astype(
                np.float32) / 1000.0  # Visualization purpose only
            # Fill missing GT (0) with the max so log10 below stays finite.
            gt[gt == 0] = np.amax(gt)

        pred_depth = pred_depths[s]

        if args.dataset == 'kitti' or args.dataset == 'kitti_benchmark':
            pred_depth_scaled = pred_depth * 256.0
        else:
            pred_depth_scaled = pred_depth * 1000.0

        pred_depth_scaled = pred_depth_scaled.astype(np.uint16)
        cv2.imwrite(filename_pred_png, pred_depth_scaled,
                    [cv2.IMWRITE_PNG_COMPRESSION, 0])

        if args.save_lpg:
            cv2.imwrite(filename_image_png, image[10:-1 - 9, 10:-1 - 9, :])
            is_nyu = args.dataset == 'nyu'
            if is_nyu:
                plt.imsave(filename_gt_png,
                           np.log10(gt[10:-1 - 9, 10:-1 - 9]),
                           cmap='Greys')
            # (file-name suffix, map) pairs, written in this order; the empty
            # suffix is the final depth map itself.
            depth_maps = (('', pred_depth), ('_8x8', pred_8x8s[s]),
                          ('_4x4', pred_4x4s[s]), ('_2x2', pred_2x2s[s]),
                          ('_1x1', pred_1x1s[s]))
            for suffix, depth_map in depth_maps:
                if is_nyu:
                    # Same 10-pixel border crop as the RGB/GT images above.
                    depth_map = depth_map[10:-1 - 9, 10:-1 - 9]
                plt.imsave(filename_cmap_png.replace('.png', suffix + '.png'),
                           np.log10(depth_map),
                           cmap='Greys')

    return


def _make_result_dirs(save_name):
    """Create ``save_name`` and its ``raw/cmap/rgb/gt/nd`` sub-directories.

    Each sub-directory is created independently (``exist_ok=True``), so a
    result folder left over from an earlier run still gets any missing one.
    """
    for sub_dir in ('raw', 'cmap', 'rgb', 'gt', 'nd'):
        os.makedirs(os.path.join(save_name, sub_dir), exist_ok=True)


def _result_paths(save_name, line, dataset):
    """Return ``(pred, cmap, gt, rgb)`` output paths for one filenames line.

    ``gt`` is ``None`` for ``'kitti'`` and ``'kitti_benchmark'``, which never
    write a GT image. Only the frame file name has ``.jpg`` swapped for
    ``.png``; the RGB path keeps the original extension.
    """
    first_token = line.split()[0]
    if dataset == 'kitti':
        prefix = line.split('/')[1] + '_'  # date/drive directory component
        base = first_token.split('/')[-1]
    elif dataset == 'kitti_benchmark':
        prefix = ''
        base = first_token.split('/')[-1]
    else:  # nyu
        parts = first_token.split('/')
        prefix = parts[0] + '_'  # scene name
        base = parts[1]
    png_name = prefix + base.replace('.jpg', '.png')
    gt_png = None
    if dataset not in ('kitti', 'kitti_benchmark'):
        gt_png = save_name + '/gt/' + png_name
    return (save_name + '/raw/' + png_name,
            save_name + '/cmap/' + png_name,
            gt_png,
            save_name + '/rgb/' + prefix + base)
# Example No. 2
# 0
def main_worker(gpu, ngpus_per_node, args):
    """Train ``BtsModel`` with an extra ``NormDiff`` normal/difference loss.

    Besides the SILog depth loss, the disparities (1/depth) of the estimate
    and of the ground truth are passed through ``NormDiff``. Their normal
    (``nd``) and difference (``diff``) outputs are matched with MSE on the
    pixels not flagged by ``NormDiff``'s mask, scaled by a coefficient that
    decays from 0.9 to 0.5 over training.

    Args:
        gpu: GPU index for this process, or ``None``.
        ngpus_per_node: GPUs per node; used for the global rank and to split
            ``args.batch_size`` under distributed training.
        args: parsed options. ``args.gpu``, ``args.rank`` and (distributed
            with a GPU) ``args.batch_size`` are modified in place.

    Returns:
        ``-1`` if the logging process sees a NaN loss, otherwise ``None``.
    """
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Global rank = node rank * GPUs per node + local GPU index.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # args.rank is final from here on, so decide once whether this process
    # prints progress, writes summaries and saves checkpoints.
    is_logging_rank = _is_logging_rank(args, ngpus_per_node)

    # Create model
    model = BtsModel(args)
    nd_model = NormDiff(5e-4)
    model.train()
    model.decoder.apply(weights_init_xavier)
    set_misc(model)

    num_params = sum(np.prod(p.size()) for p in model.parameters())
    print("Total number of parameters: {}".format(num_params))

    num_params_update = sum(
        np.prod(p.shape) for p in model.parameters() if p.requires_grad)
    print("Total number of learning parameters: {}".format(num_params_update))

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)

            model.cuda(args.gpu)
            nd_model.cuda(args.gpu)

            # The configured batch size is split across this node's GPUs.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
            nd_model = torch.nn.parallel.DistributedDataParallel(
                nd_model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            # nd_model is fed GPU tensors below, so it must live on the GPU
            # in this branch too (it used to be left on the CPU).
            nd_model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(
                model, find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

        nd_model = torch.nn.DataParallel(nd_model)
        nd_model.cuda()

    if args.distributed:
        print("Model Initialized on GPU: {}".format(args.gpu))
    else:
        print("Model Initialized")

    global_step = 0
    # Metrics 0-5 are "lower is better" (start at a large sentinel), 6-8 are
    # "higher is better" (start at 0); one best-step slot per metric.
    best_eval_measures_lower_better = torch.zeros(6).cpu() + 1e3
    best_eval_measures_higher_better = torch.zeros(3).cpu()
    best_eval_steps = np.zeros(9, dtype=np.int32)

    # Training parameters: weight decay on the encoder only.
    optimizer = torch.optim.AdamW([{
        'params': model.module.encoder.parameters(),
        'weight_decay': args.weight_decay
    }, {
        'params': model.module.decoder.parameters(),
        'weight_decay': 0
    }],
                                  lr=args.learning_rate,
                                  eps=args.adam_eps)

    model_just_loaded = False
    if args.checkpoint_path != '':
        if os.path.isfile(args.checkpoint_path):
            print("Loading checkpoint '{}'".format(args.checkpoint_path))
            if args.gpu is None:
                checkpoint = torch.load(args.checkpoint_path)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.checkpoint_path, map_location=loc)
            global_step = checkpoint['global_step']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            try:
                best_eval_measures_higher_better = checkpoint[
                    'best_eval_measures_higher_better'].cpu()
                best_eval_measures_lower_better = checkpoint[
                    'best_eval_measures_lower_better'].cpu()
                best_eval_steps = checkpoint['best_eval_steps']
            except KeyError:
                print("Could not load values for online evaluation")

            print("Loaded checkpoint '{}' (global_step {})".format(
                args.checkpoint_path, checkpoint['global_step']))
        else:
            print("No checkpoint found at '{}'".format(args.checkpoint_path))
        model_just_loaded = True

    if args.retrain:
        global_step = 0

    cudnn.benchmark = True

    dataloader = BtsDataLoader(args, 'train')
    dataloader_eval = BtsDataLoader(args, 'online_eval')

    # Logging: summary writers exist only on the logging process.
    writer = None
    eval_summary_writer = None
    if is_logging_rank:
        writer = SummaryWriter(args.log_directory + '/' + args.model_name +
                               '/summaries',
                               flush_secs=30)
        if args.do_online_eval:
            if args.eval_summary_directory != '':
                eval_summary_path = os.path.join(args.eval_summary_directory,
                                                 args.model_name)
            else:
                eval_summary_path = os.path.join(args.log_directory, 'eval')
            eval_summary_writer = SummaryWriter(eval_summary_path,
                                                flush_secs=30)

    silog_criterion = silog_loss(variance_focus=args.variance_focus)
    mse_criterion = nn.MSELoss()

    start_time = time.time()
    duration = 0

    num_log_images = args.batch_size
    end_learning_rate = (args.end_learning_rate
                         if args.end_learning_rate != -1 else
                         0.1 * args.learning_rate)

    var_sum = [var.sum() for var in model.parameters() if var.requires_grad]
    var_cnt = len(var_sum)
    var_sum = np.sum(var_sum)

    print("Initial variables' sum: {:.3f}, avg: {:.3f}".format(
        var_sum, var_sum / var_cnt))

    steps_per_epoch = len(dataloader.data)
    num_total_steps = args.num_epochs * steps_per_epoch
    epoch = global_step // steps_per_epoch

    # Periodic paint_multiple snapshots go here. Create the directory up front
    # in case paint_multiple does not (not visible from here).
    images_save_dir = (args.log_directory + '/' + args.model_name +
                       '/images_save/')
    os.makedirs(images_save_dir, exist_ok=True)

    try:
        while epoch < args.num_epochs:
            if args.distributed:
                dataloader.train_sampler.set_epoch(epoch)

            for step, sample_batched in enumerate(dataloader.data):
                optimizer.zero_grad()
                before_op_time = time.time()

                image = sample_batched['image'].cuda(args.gpu,
                                                     non_blocking=True)
                focal = sample_batched['focal'].cuda(args.gpu,
                                                     non_blocking=True)
                depth_gt = sample_batched['depth'].cuda(args.gpu,
                                                        non_blocking=True)

                lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(
                    image, focal)

                # Minimum valid GT depth depends on the dataset.
                if args.dataset == 'nyu':
                    mask = depth_gt > 0.1
                else:
                    mask = depth_gt > 1.0

                valid_bmask_gt = mask & (depth_gt != 0)

                # GT disparity with invalid pixels forced to 0.
                disp_gt = 1 / depth_gt
                disp_gt[~valid_bmask_gt] = 0.

                # Invert only strictly positive estimates (avoid 1/0).
                disp_est = torch.zeros_like(depth_est, device=depth_est.device)
                disp_est[depth_est > 0] = (1 / depth_est)[depth_est > 0]

                loss_silog = silog_criterion.forward(depth_est, depth_gt,
                                                     mask.to(torch.bool))

                nd_gt, diff_gt, invd_bmask = nd_model(disp_gt)
                nd_est, diff_est, _ = nd_model(disp_est)

                # Decays from 0.9 (step 0) to 0.5 (last step), power 0.9.
                current_coef = (1 - .6) * (
                    1 - global_step / num_total_steps)**.9 + .5

                # invd_bmask has one channel (it is expanded to 3 and 2);
                # only pixels where it is False enter the MSE terms.
                nd_keep = ~invd_bmask.expand(-1, 3, -1, -1)
                diff_keep = ~invd_bmask.expand(-1, 2, -1, -1)
                loss_nd = current_coef * 10 * mse_criterion(
                    nd_est[nd_keep], nd_gt[nd_keep])
                loss_diff = current_coef * 1e3 * mse_criterion(
                    diff_gt[diff_keep], diff_est[diff_keep])

                loss = loss_silog + loss_nd + loss_diff
                loss.backward()

                if global_step % 200 == 0:
                    paint_multiple(
                        image[0].cpu().detach(),
                        depth_est[0].cpu().detach(),
                        depth_gt[0].cpu().detach(),
                        None,
                        nd_est[0].cpu().detach(),
                        nd_gt[0].cpu().detach(),
                        None,
                        diff_est[0].cpu().detach(),
                        diff_gt[0].cpu().detach(),
                        images_per_row=3,
                        to_screen=False,
                        to_file=images_save_dir + 'img_%d.png' % global_step)

                # Polynomial (power 0.9) decay towards end_learning_rate.
                current_lr = (args.learning_rate - end_learning_rate) * (
                    1 - global_step / num_total_steps)**0.9 + end_learning_rate
                for param_group in optimizer.param_groups:
                    param_group['lr'] = current_lr

                optimizer.step()

                if is_logging_rank:
                    print(
                        '[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss_silog: {:.8f}, loss_nd: {:.8f}, loss_diff: {:.8f}, cur_coef: {:.8f}'
                        .format(epoch, step, steps_per_epoch, global_step,
                                current_lr, loss_silog, loss_nd, loss_diff,
                                current_coef))
                    if np.isnan(loss.cpu().item()):
                        print('NaN in loss occurred. Aborting training.')
                        return -1

                duration += time.time() - before_op_time
                if global_step and global_step % args.log_freq == 0 and not model_just_loaded:
                    var_sum = [
                        var.sum() for var in model.parameters()
                        if var.requires_grad
                    ]
                    var_cnt = len(var_sum)
                    var_sum = np.sum(var_sum)
                    examples_per_sec = args.batch_size / duration * args.log_freq
                    duration = 0
                    time_sofar = (time.time() - start_time) / 3600
                    training_time_left = (num_total_steps / global_step -
                                          1.0) * time_sofar
                    if is_logging_rank:
                        print("{}".format(args.model_name))
                    print_string = 'GPU: {} | examples/s: {:4.2f} | loss_silog: {:.5f} | loss_nd: {:.5f} | loss_diff: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
                    print(
                        print_string.format(args.gpu, examples_per_sec,
                                            loss_silog, loss_nd, loss_diff,
                                            var_sum.item(),
                                            var_sum.item() / var_cnt,
                                            time_sofar, training_time_left))

                    if is_logging_rank:
                        writer.add_scalar('silog_loss', loss_silog,
                                          global_step)
                        writer.add_scalar('nd_loss', loss_nd, global_step)
                        writer.add_scalar('diff_loss', loss_diff, global_step)
                        writer.add_scalar('learning_rate', current_lr,
                                          global_step)
                        writer.add_scalar('var average',
                                          var_sum.item() / var_cnt,
                                          global_step)
                        # Replace (near-)zero GT depth before inverting it.
                        depth_gt = torch.where(depth_gt < 1e-3,
                                               depth_gt * 0 + 1e3, depth_gt)
                        inverse_depth_maps = (('depth_gt', depth_gt),
                                              ('depth_est', depth_est),
                                              ('reduc1x1', reduc1x1),
                                              ('lpg2x2', lpg2x2),
                                              ('lpg4x4', lpg4x4),
                                              ('lpg8x8', lpg8x8))
                        # The current batch may hold fewer than
                        # args.batch_size samples (e.g. a short last batch).
                        for i in range(min(num_log_images, image.shape[0])):
                            for tag, depth_map in inverse_depth_maps:
                                writer.add_image(
                                    '{}/image/{}'.format(tag, i),
                                    normalize_result(
                                        1 / depth_map[i, :, :, :].data),
                                    global_step)
                            writer.add_image('image/image/{}'.format(i),
                                             inv_normalize(
                                                 image[i, :, :, :]).data,
                                             global_step)
                        writer.flush()

                if not args.do_online_eval and global_step and global_step % args.save_freq == 0:
                    if is_logging_rank:
                        checkpoint = {
                            'global_step': global_step,
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        }
                        torch.save(
                            checkpoint, args.log_directory + '/' +
                            args.model_name + '/model-{}'.format(global_step))

                if args.do_online_eval and global_step and global_step % args.eval_freq == 0 and not model_just_loaded:
                    time.sleep(0.1)
                    model.eval()
                    eval_measures = online_eval(model, dataloader_eval, gpu,
                                                ngpus_per_node)
                    if eval_measures is not None:
                        for i in range(9):
                            eval_summary_writer.add_scalar(
                                eval_metrics[i], eval_measures[i].cpu(),
                                int(global_step))
                            measure = eval_measures[i]
                            is_best = False
                            if i < 6 and measure < best_eval_measures_lower_better[i]:
                                old_best = best_eval_measures_lower_better[
                                    i].item()
                                best_eval_measures_lower_better[
                                    i] = measure.item()
                                is_best = True
                            elif i >= 6 and measure > best_eval_measures_higher_better[i - 6]:
                                old_best = best_eval_measures_higher_better[
                                    i - 6].item()
                                best_eval_measures_higher_better[
                                    i - 6] = measure.item()
                                is_best = True
                            if is_best:
                                old_best_step = best_eval_steps[i]
                                old_best_name = '/model-{}-best_{}_{:.5f}'.format(
                                    old_best_step, eval_metrics[i], old_best)
                                model_path = args.log_directory + '/' + args.model_name + old_best_name
                                # Drop the checkpoint this new best supersedes
                                # (no shell involved in the deletion).
                                if os.path.exists(model_path):
                                    os.remove(model_path)
                                best_eval_steps[i] = global_step
                                model_save_name = '/model-{}-best_{}_{:.5f}'.format(
                                    global_step, eval_metrics[i], measure)
                                print('New best for {}. Saving model: {}'.
                                      format(eval_metrics[i], model_save_name))
                                checkpoint = {
                                    'global_step': global_step,
                                    'model': model.state_dict(),
                                    'optimizer': optimizer.state_dict(),
                                    'best_eval_measures_higher_better':
                                    best_eval_measures_higher_better,
                                    'best_eval_measures_lower_better':
                                    best_eval_measures_lower_better,
                                    'best_eval_steps': best_eval_steps
                                }
                                torch.save(
                                    checkpoint, args.log_directory + '/' +
                                    args.model_name + model_save_name)
                        eval_summary_writer.flush()
                    model.train()
                    block_print()
                    set_misc(model)
                    enable_print()

                model_just_loaded = False
                global_step += 1

            epoch += 1
    finally:
        # Flush and release the summary files on normal exit, on the NaN
        # early return, and on errors.
        if writer is not None:
            writer.close()
        if eval_summary_writer is not None:
            eval_summary_writer.close()


def _is_logging_rank(args, ngpus_per_node):
    """Return True if this process should print, log and checkpoint.

    Without multiprocessing-distributed training every process qualifies;
    with it, only processes whose ``args.rank`` is a multiple of
    ``ngpus_per_node``.
    """
    return (not args.multiprocessing_distributed
            or args.rank % ngpus_per_node == 0)
# Example No. 3
# 0
def main_worker(gpu, ngpus_per_node, args):
    """Visualise ``NormDiff`` outputs and normal-class confidences on GT depth.

    Debug worker: no network is trained or evaluated. For each sample of the
    ``'online_eval'`` split, repeated ``args.num_epochs`` times, it:

    * converts the ground-truth depth to disparity (invalid pixels -> 0),
    * runs ``NormDiff`` on it and paints image, depth and normal channels,
    * prints the max and mean (over pixels not flagged by ``NormDiff``'s
      mask) of the per-pixel length of the two ``diff`` channels,
    * applies ``ChannelWiseSoftmax`` to the normals and paints thresholded
      class maps.

    Args:
        gpu: device index batches are copied to (``None`` -> current device).
        ngpus_per_node: unused; kept for signature parity with the training
            ``main_worker``.
        args: options; ``args.gpu`` is set to ``gpu``.
    """
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # Create model
    nd_model = NormDiff(2e-4).cuda()
    cs_model = ChannelWiseSoftmax().cuda()

    cudnn.benchmark = True

    dataloader = BtsDataLoader(args, 'online_eval')

    # Softmax confidence threshold; 1 - up_thresh is the "surely not" side.
    up_thresh = .98
    # (nd_cls channel, cls_map channels set to 1 where it is > up_thresh)
    class_colors = ((0, (0, 1)), (1, (0, 2)), (2, (1, 2)), (3, (0,)), (4, (1,)))

    # Inspection only: no gradients are needed anywhere below.
    with torch.no_grad():
        for _ in range(args.num_epochs):
            for sample_batched in dataloader.data:
                image = sample_batched['image'].cuda(args.gpu,
                                                     non_blocking=True)
                depth_gt = sample_batched['depth'].cuda(args.gpu,
                                                        non_blocking=True)

                # Move the last axis to position 1 (presumably NHWC -> NCHW).
                depth_gt = depth_gt.transpose(3, 2).transpose(2, 1)

                print('depth_gt shape =', depth_gt.shape)

                if args.dataset == 'nyu':
                    mask = depth_gt > 0.1
                else:
                    mask = depth_gt > 1.0

                real_bmask = mask & (depth_gt != 0)
                disp_gt = 1 / depth_gt
                disp_gt[~real_bmask] = 0.

                nd_gt, diff_gt, invd_bmask = nd_model(disp_gt)

                paint_multiple(image[0].cpu().detach(),
                               depth_gt[0].cpu().detach(),
                               nd_gt[0].cpu().detach(),
                               nd_gt[0, 0:2].cpu().detach(),
                               nd_gt[0, 1:3].cpu().detach(),
                               torch.cat((nd_gt[0, 0:1].cpu().detach(),
                                          nd_gt[0, 2:3].cpu().detach()),
                                         dim=0),
                               images_per_row=2,
                               to_screen=True)

                diff_xy_len = (diff_gt[:, 0:1, :, :].cpu()**2 +
                               diff_gt[:, 1:2, :, :].cpu()**2).sqrt()
                # diff_xy_len lives on the CPU; the mask comes from nd_model
                # (on the GPU) and must be moved before it can index it.
                flagged = invd_bmask[:, 0:1].cpu()
                diff_xy_len[flagged] = 0.
                mean_val = diff_xy_len[~flagged].mean()

                print('max = %.8f' % diff_xy_len.reshape(-1).max())
                print('mean = %.8f' % mean_val)

                nd_cls = cs_model(nd_gt, dim=1, scaling=10)

                N, _, H, W = nd_cls.shape
                cls_map = torch.zeros(N, 3, H, W, device=nd_cls.device)

                # all classes together: each confident nd_cls channel lights
                # a fixed combination of cls_map channels.
                for src, dsts in class_colors:
                    confident = nd_cls[:, src:src + 1] > up_thresh
                    for dst in dsts:
                        cls_map[:, dst:dst + 1][confident] = 1.

                # different planes: one sure/not-sure/surely-not map per
                # nd_cls channel 0..4.
                cls_r = _confidence_map(nd_cls, 0, up_thresh)
                cls_l = _confidence_map(nd_cls, 1, up_thresh)
                cls_u = _confidence_map(nd_cls, 2, up_thresh)
                cls_d = _confidence_map(nd_cls, 3, up_thresh)
                cls_b = _confidence_map(nd_cls, 4, up_thresh)

                # CPU tensors, matching what the first paint_multiple call
                # receives.
                paint_multiple(nd_gt[0].cpu(),
                               cls_map[0].cpu(),
                               cls_l[0].cpu(),
                               cls_r[0].cpu(),
                               depth_gt[0].cpu(),
                               cls_u[0].cpu(),
                               cls_d[0].cpu(),
                               cls_b[0].cpu(),
                               images_per_row=4)


def _confidence_map(nd_cls, channel, up_thresh):
    """Return an (N, 3, H, W) float map for one channel of ``nd_cls``.

    Output channel 0 ("sure") is 1 where the channel is ``> up_thresh``.
    Output channel 2 ("surely not") is 1 where it is ``< 1 - up_thresh``.
    Output channel 1 ("not sure") stays 0.
    """
    n, _, h, w = nd_cls.shape
    out = torch.zeros(n, 3, h, w, device=nd_cls.device)
    prob = nd_cls[:, channel:channel + 1]
    out[:, 0:1][prob > up_thresh] = 1.
    out[:, 2:3][prob < 1 - up_thresh] = 1.
    return out