Example #1
0
def main_worker(gpu, ngpus_per_node, args):
    """Train a BtsModel as one worker of a (possibly distributed) run.

    Sets up the optional distributed process group, builds and wraps the
    model, optionally resumes from ``args.checkpoint_path`` and then runs the
    training loop with a polynomially decaying learning rate, TensorBoard
    logging, periodic checkpointing and optional online evaluation.

    Args:
        gpu: GPU index used by this worker, or None.
        ngpus_per_node: Number of GPUs on this node; used to derive the global
            rank and the per-process batch size in distributed mode.
        args: Parsed option namespace. ``args.gpu``, ``args.rank`` and (in
            distributed mode with a GPU index) ``args.batch_size`` are
            overwritten in place.

    Returns:
        -1 if the logging process observes a NaN loss, otherwise None.
    """
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Turn the node rank into a global process rank.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # Process that prints summaries, writes TensorBoard output and saves
    # periodic checkpoints: every process unless multiprocessing_distributed
    # is set, in which case only local rank 0 of each node. args.rank is not
    # modified after this point.
    is_logging_process = (not args.multiprocessing_distributed
                          or args.rank % ngpus_per_node == 0)

    # Create model
    model = BtsModel(args)
    model.train()
    model.decoder.apply(weights_init_xavier)
    set_misc(model)

    num_params = sum(np.prod(p.size()) for p in model.parameters())
    print("Total number of parameters: {}".format(num_params))

    num_params_update = sum(
        np.prod(p.shape) for p in model.parameters() if p.requires_grad)
    print("Total number of learning parameters: {}".format(num_params_update))

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # Each process handles its share of the per-node batch.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(
                model, find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    if args.distributed:
        print("Model Initialized on GPU: {}".format(args.gpu))
    else:
        print("Model Initialized")

    global_step = 0
    # Best online-eval values so far: metrics 0-5 are "lower is better"
    # (start high), metrics 6-8 are "higher is better" (start at 0).
    best_eval_measures_lower_better = torch.zeros(6).cpu() + 1e3
    best_eval_measures_higher_better = torch.zeros(3).cpu()
    best_eval_steps = np.zeros(9, dtype=np.int32)

    # Training parameters: weight decay is applied to the encoder only.
    optimizer = torch.optim.AdamW(
        [{'params': model.module.encoder.parameters(),
          'weight_decay': args.weight_decay},
         {'params': model.module.decoder.parameters(),
          'weight_decay': 0}],
        lr=args.learning_rate,
        eps=args.adam_eps)

    # Set whenever a checkpoint path is given (even if the file is missing);
    # it suppresses the log / eval blocks for the first training step.
    model_just_loaded = False
    if args.checkpoint_path != '':
        if os.path.isfile(args.checkpoint_path):
            print("Loading checkpoint '{}'".format(args.checkpoint_path))
            if args.gpu is None:
                checkpoint = torch.load(args.checkpoint_path)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.checkpoint_path, map_location=loc)
            global_step = checkpoint['global_step']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            try:
                best_eval_measures_higher_better = checkpoint[
                    'best_eval_measures_higher_better'].cpu()
                best_eval_measures_lower_better = checkpoint[
                    'best_eval_measures_lower_better'].cpu()
                best_eval_steps = checkpoint['best_eval_steps']
            except KeyError:
                # Checkpoints from the periodic (non-eval) save path below do
                # not contain these entries.
                print("Could not load values for online evaluation")

            print("Loaded checkpoint '{}' (global_step {})".format(
                args.checkpoint_path, checkpoint['global_step']))
        else:
            print("No checkpoint found at '{}'".format(args.checkpoint_path))
        model_just_loaded = True

    if args.retrain:
        global_step = 0

    cudnn.benchmark = True

    dataloader = BtsDataLoader(args)
    dataloader_eval = BtsDataLoader(args)

    # Logging
    if is_logging_process:
        writer = SummaryWriter(os.path.join(args.log_directory,
                                            args.model_name, 'summaries'),
                               flush_secs=30)
        if args.do_online_eval:
            if args.eval_summary_directory != '':
                eval_summary_path = os.path.join(args.eval_summary_directory,
                                                 args.model_name)
            else:
                eval_summary_path = os.path.join(args.log_directory, 'eval')
            eval_summary_writer = SummaryWriter(eval_summary_path,
                                                flush_secs=30)

    silog_criterion = silog_loss(variance_focus=args.variance_focus)

    start_time = time.time()
    duration = 0

    num_log_images = args.batch_size
    end_learning_rate = (args.end_learning_rate
                         if args.end_learning_rate != -1
                         else 0.1 * args.learning_rate)

    var_sum = [var.sum() for var in model.parameters() if var.requires_grad]
    var_cnt = len(var_sum)
    var_sum = np.sum(var_sum)

    print("Initial variables' sum: {:.3f}, avg: {:.3f}".format(
        var_sum, var_sum / var_cnt))

    steps_per_epoch = len(dataloader.data)
    num_total_steps = args.num_epochs * steps_per_epoch
    epoch = global_step // steps_per_epoch

    while epoch < args.num_epochs:
        if args.distributed:
            dataloader.train_sampler.set_epoch(epoch)

        for step, sample_batched in enumerate(dataloader.data):
            optimizer.zero_grad()
            before_op_time = time.time()

            mask = sample_batched['mask'].cuda(args.gpu, non_blocking=True)
            image = sample_batched['image'].cuda(args.gpu, non_blocking=True)
            focal = sample_batched['focal'].cuda(args.gpu, non_blocking=True)
            depth_gt = sample_batched['depth'].cuda(args.gpu,
                                                    non_blocking=True)

            lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(image, focal)

            if args.dataset == 'nyu':
                mask = depth_gt > 0.1
            elif args.dataset == 'kitti':
                mask = depth_gt > 1.0

            # Put the mask on the prediction's device rather than a fixed
            # 'cuda:0': in distributed runs each worker computes on its own
            # GPU, and a mismatched device makes the loss fail.
            loss = silog_criterion.forward(depth_est, depth_gt,
                                           mask.to(depth_est.device))
            loss.backward()

            # Polynomial decay (power 0.9) from learning_rate to
            # end_learning_rate over num_total_steps.
            current_lr = (args.learning_rate - end_learning_rate) * (
                1 - global_step / num_total_steps)**0.9 + end_learning_rate
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            optimizer.step()

            if is_logging_process:
                print(
                    '[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'
                    .format(epoch, step, steps_per_epoch, global_step,
                            current_lr, loss))
                if np.isnan(loss.cpu().item()):
                    print('NaN in loss occurred. Aborting training.')
                    return -1

            duration += time.time() - before_op_time
            if global_step and global_step % args.log_freq == 0 and not model_just_loaded:
                var_sum = [
                    var.sum() for var in model.parameters()
                    if var.requires_grad
                ]
                var_cnt = len(var_sum)
                var_sum = np.sum(var_sum)
                examples_per_sec = args.batch_size / duration * args.log_freq
                duration = 0
                time_sofar = (time.time() - start_time) / 3600
                training_time_left = (num_total_steps / global_step -
                                      1.0) * time_sofar
                if is_logging_process:
                    print("{}".format(args.model_name))
                print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
                print(
                    print_string.format(args.gpu, examples_per_sec, loss,
                                        var_sum.item(),
                                        var_sum.item() / var_cnt, time_sofar,
                                        training_time_left))

                if is_logging_process:
                    writer.add_scalar('silog_loss', loss, global_step)
                    writer.add_scalar('learning_rate', current_lr, global_step)
                    writer.add_scalar('var average',
                                      var_sum.item() / var_cnt, global_step)
                    # Replace (near-)zero GT depths before taking 1/depth.
                    depth_gt = torch.where(depth_gt < 1e-3, depth_gt * 0 + 1e3,
                                           depth_gt)
                    inverse_depth_maps = (
                        ('depth_gt', depth_gt),
                        ('depth_est', depth_est),
                        ('reduc1x1', reduc1x1),
                        ('lpg2x2', lpg2x2),
                        ('lpg4x4', lpg4x4),
                        ('lpg8x8', lpg8x8),
                    )
                    # The current batch can be smaller than batch_size (e.g.
                    # the last batch of an epoch); never index past it.
                    for i in range(min(num_log_images, image.shape[0])):
                        for tag, tensor in inverse_depth_maps:
                            writer.add_image(
                                '{}/image/{}'.format(tag, i),
                                normalize_result(1 / tensor[i, :, :, :].data),
                                global_step)
                        writer.add_image('image/image/{}'.format(i),
                                         inv_normalize(image[i, :, :, :]).data,
                                         global_step)
                    writer.flush()

            if not args.do_online_eval and global_step and global_step % args.save_freq == 0:
                if is_logging_process:
                    checkpoint = {
                        'global_step': global_step,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }
                    torch.save(
                        checkpoint, args.log_directory + '/' +
                        args.model_name + '/model-{}'.format(global_step))

            if (args.do_online_eval and global_step
                    and global_step % args.eval_freq == 0
                    and not model_just_loaded):
                time.sleep(0.1)
                model.eval()
                eval_measures = online_eval(model, dataloader_eval, gpu,
                                            ngpus_per_node)
                if eval_measures is not None:
                    for i in range(9):
                        eval_summary_writer.add_scalar(eval_metrics[i],
                                                       eval_measures[i].cpu(),
                                                       int(global_step))
                        measure = eval_measures[i]
                        is_best = False
                        if i < 6 and measure < best_eval_measures_lower_better[i]:
                            old_best = best_eval_measures_lower_better[i].item()
                            best_eval_measures_lower_better[i] = measure.item()
                            is_best = True
                        elif i >= 6 and measure > best_eval_measures_higher_better[i - 6]:
                            old_best = best_eval_measures_higher_better[
                                i - 6].item()
                            best_eval_measures_higher_better[
                                i - 6] = measure.item()
                            is_best = True
                        if is_best:
                            # Drop the previous best checkpoint for this
                            # metric before saving the new one.
                            old_best_step = best_eval_steps[i]
                            old_best_name = '/model-{}-best_{}_{:.5f}'.format(
                                old_best_step, eval_metrics[i], old_best)
                            model_path = (args.log_directory + '/' +
                                          args.model_name + old_best_name)
                            if os.path.exists(model_path):
                                # os.remove avoids spawning a shell on an
                                # unquoted, option-derived path.
                                os.remove(model_path)
                            best_eval_steps[i] = global_step
                            model_save_name = '/model-{}-best_{}_{:.5f}'.format(
                                global_step, eval_metrics[i], measure)
                            print('New best for {}. Saving model: {}'.format(
                                eval_metrics[i], model_save_name))
                            checkpoint = {
                                'global_step': global_step,
                                'model': model.state_dict(),
                                'optimizer': optimizer.state_dict(),
                                'best_eval_measures_higher_better':
                                best_eval_measures_higher_better,
                                'best_eval_measures_lower_better':
                                best_eval_measures_lower_better,
                                'best_eval_steps': best_eval_steps
                            }
                            torch.save(
                                checkpoint, args.log_directory + '/' +
                                args.model_name + model_save_name)
                    eval_summary_writer.flush()
                model.train()
                # set_misc is re-applied after model.train(), presumably to
                # restore per-layer settings that train() resets; its output
                # is silenced via block_print / enable_print.
                block_print()
                set_misc(model)
                enable_print()
            model_just_loaded = False
            global_step += 1

        epoch += 1
Example #2
0
def test_images(params):
    """Predict depth for a list of images and evaluate it at several ranges.

    Each line of ``args.media_path`` names a sample: its first token is a
    relative path and its third token is passed to the model as ``focal``.
    For every sample the function:

    1. predicts depth for ``../train_val_split/val/rgb/<path>``;
    2. writes the prediction as a 16-bit PNG (depth * 256) below
       ``<args.save_name>/raw/``;
    3. compares it with ``../train_val_split/val/depth/<path>`` for each
       maximum distance in ``max_distance_list``.

    Samples without a ground-truth file are reported and skipped.

    Args:
        params: Unused; the function reads the module-level ``args``
            namespace. Kept for call compatibility.

    Returns:
        dict mapping each maximum distance (int) to a dict of mean errors
        with keys 'd1', 'd2', 'd3', 'abs_rel', 'sq_rel', 'rms', 'log_rms',
        'silog', 'log10', averaged over the evaluated samples.
    """
    model = BtsModel(params=args)
    model = torch.nn.DataParallel(model)

    checkpoint = torch.load(args.checkpoint_path)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    model.cuda()

    # Standard ImageNet mean/std normalisation.
    loader_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Predictions are written below <save_name>/raw/ (joined as a path so a
    # missing trailing slash does not create a stray "<save_name>raw").
    save_name = args.save_name
    os.makedirs(os.path.join(save_name, 'raw'), exist_ok=True)

    max_distance_list = [10, 25, 50, 100, 150, 200]
    # Order of the values returned by compute_errors.
    error_names = ('silog', 'log10', 'abs_rel', 'sq_rel', 'rms', 'log_rms',
                   'd1', 'd2', 'd3')
    # Key order of each per-distance summary in the returned dict.
    summary_names = ('d1', 'd2', 'd3', 'abs_rel', 'sq_rel', 'rms', 'log_rms',
                     'silog', 'log10')

    with open(args.media_path, "r") as gt_file:
        sample_lines = gt_file.readlines()
    num_test_samples = len(sample_lines)
    print(num_test_samples)

    # errors[max_dist][name][k] is the error of the k-th evaluated sample.
    errors = {
        max_dist: {name: np.zeros(num_test_samples, np.float32)
                   for name in error_names}
        for max_dist in max_distance_list
    }

    # Next free row; samples without ground truth do not consume a row.
    num_evaluated = 0
    with torch.no_grad():
        for sample_line in sample_lines:
            output = sample_line.split()
            file_sub_path = output[0]
            # NOTE(review): focal is passed to the model as the raw string
            # token from the list file; confirm the model does not use it
            # numerically.
            focal = output[2]
            print(file_sub_path, focal)
            media_path = os.path.join("../train_val_split/val/rgb/",
                                      file_sub_path)

            save_dir = os.path.join(args.save_name, "raw",
                                    os.path.dirname(file_sub_path))
            os.makedirs(save_dir, exist_ok=True)

            image = cv2.imread(media_path)
            height, width, _ = image.shape
            # Round both spatial dimensions up to a multiple of 32.
            adjusted_height = 32 * math.ceil(height / 32)
            adjusted_width = 32 * math.ceil(width / 32)
            image_original = cv2.resize(image,
                                        (adjusted_width, adjusted_height),
                                        interpolation=cv2.INTER_LANCZOS4)

            image = loader_transforms(image_original).float().cuda()
            image = image.unsqueeze(0)
            _, _, _, _, depth_est = model(image, focal)

            depth_est = depth_est.cpu().numpy().squeeze()
            # NOTE(review): output size is fixed at 1920x1200, presumably the
            # ground-truth resolution; confirm.
            depth_est_scaled = cv2.resize(depth_est, (1920, 1200),
                                          interpolation=cv2.INTER_LANCZOS4)

            filename_pred_png = os.path.join(
                save_dir,
                os.path.splitext(os.path.basename(media_path))[0] + ".png")
            pred_depth_scaled = (depth_est_scaled * 256.0).astype(np.uint16)
            cv2.imwrite(filename_pred_png, pred_depth_scaled,
                        [cv2.IMWRITE_PNG_COMPRESSION, 0])

            gt_depth_path = os.path.join("../train_val_split/val/depth/",
                                         file_sub_path)
            depth = cv2.imread(gt_depth_path, -1)
            if depth is None:
                print('Missing: %s ' % gt_depth_path)
                continue
            gt_depth = depth.astype(np.float32) / 256.0

            # Zero the prediction wherever the ground truth is zero.
            pred_depth_masked = np.where(gt_depth != 0, depth_est_scaled, 0)
            gt_depth[np.isinf(gt_depth)] = 0
            gt_depth[np.isnan(gt_depth)] = 0

            for max_dist in max_distance_list:
                # Clip a fresh copy for every range: clipping in place would
                # carry the bound of a smaller range into the larger ones.
                pred_clipped = np.clip(pred_depth_masked, args.min_depth_eval,
                                       max_dist)
                valid_mask = np.logical_and(gt_depth > args.min_depth_eval,
                                            gt_depth < max_dist)
                sample_errors = compute_errors(gt_depth[valid_mask],
                                               pred_clipped[valid_mask])
                for name, value in zip(error_names, sample_errors):
                    errors[max_dist][name][num_evaluated] = value

            largest = errors[max_distance_list[-1]]
            print(*(largest[name][num_evaluated] for name in error_names))
            num_evaluated += 1

    # Average only over the rows that were actually filled.
    eval_dump_dist = {}
    for max_dist in max_distance_list:
        eval_dump_dist[max_dist] = {
            name: float(errors[max_dist][name][:num_evaluated].mean())
            for name in summary_names
        }

    return eval_dump_dist
Example #3
0
def test(params):
    """Predict depth for the test set and save the results as images.

    Runs the model over ``BtsDataLoader(args, 'test')``, collecting the final
    depth map and the intermediate lpg8x8 / lpg4x4 / lpg2x2 / reduc1x1
    outputs. The results are then written under ``result_<model_name>/``:
    ``raw/`` holds the rescaled depth; when ``args.save_lpg`` is set, ``rgb/``
    holds the input image and ``cmap/`` holds log10 colour maps.

    Args:
        params: Unused; the function reads (and sets ``mode`` on) the
            module-level ``args`` namespace. Kept for call compatibility.
    """
    import glob

    args.mode = 'test'
    dataloader = BtsDataLoader(args, 'test')

    model = BtsModel(params=args)
    model = torch.nn.DataParallel(model)

    checkpoint = torch.load(args.checkpoint_path)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    model.cuda()

    num_test_samples = get_num_lines(args.filenames_file)

    # NOTE(review): output names and RGB paths come from this hard-coded glob
    # of *_label.npz files, not from args.filenames_file. glob order is not
    # sorted, so lines[s] is not guaranteed to correspond to the s-th
    # prediction from the dataloader; confirm the intended list and order.
    lines = glob.glob("/home/pebert/dataset/wireframe/train/*_label.npz")

    print('now testing {} files with {}'.format(num_test_samples,
                                                args.checkpoint_path))

    pred_depths = []
    pred_8x8s = []
    pred_4x4s = []
    pred_2x2s = []
    pred_1x1s = []

    start_time = time.time()
    with torch.no_grad():
        for sample in tqdm(dataloader.data):
            image = sample['image'].cuda()
            # No focal length is fed to the model in this script.
            focal = []
            lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(image, focal)
            pred_depths.append(depth_est.cpu().numpy().squeeze())
            # Only the first element of each batch is kept for the
            # intermediate outputs.
            pred_8x8s.append(lpg8x8[0].cpu().numpy().squeeze())
            pred_4x4s.append(lpg4x4[0].cpu().numpy().squeeze())
            pred_2x2s.append(lpg2x2[0].cpu().numpy().squeeze())
            pred_1x1s.append(reduc1x1[0].cpu().numpy().squeeze())

    elapsed_time = time.time() - start_time
    print('Elapesed time: %s' % str(elapsed_time))
    print('Done.')

    save_name = 'result_' + args.model_name

    print('Saving result pngs..')
    # Create each directory independently, so an already existing result
    # directory does not stop the missing sub-directories from being made.
    for sub_dir in ('raw', 'cmap', 'rgb', 'gt'):
        os.makedirs(os.path.join(save_name, sub_dir), exist_ok=True)

    # Ground truth is not loaded in this script, so no GT map is saved.
    # To restore it for NYU, load it per sample, e.g.:
    #     gt_path = os.path.join(args.data_path, './' + line.split()[1])
    #     gt = cv2.imread(gt_path, -1).astype(np.float32)
    #     gt[gt == 0] = np.amax(gt)
    gt = None

    for s in tqdm(range(len(lines))):
        line = lines[s]
        if args.dataset == 'kitti':
            date_drive = line.split('/')[1]
            file_name = line.split()[0].split('/')[-1]
            prefix = date_drive + '_'
            filename_pred_png = (save_name + '/raw/' + prefix +
                                 file_name.replace('.jpg', '.png'))
            filename_cmap_png = (save_name + '/cmap/' + prefix +
                                 file_name.replace('.jpg', '.png'))
            filename_image_png = save_name + '/rgb/' + prefix + file_name
        elif args.dataset == 'kitti_benchmark':
            file_name = line.split()[0].split('/')[-1]
            filename_pred_png = (save_name + '/raw/' +
                                 file_name.replace('.jpg', '.png'))
            filename_cmap_png = (save_name + '/cmap/' +
                                 file_name.replace('.jpg', '.png'))
            filename_image_png = save_name + '/rgb/' + file_name
        else:
            scene_name = line.split('/')[-1].split('_label')[0]
            filename_pred_png = save_name + '/raw/' + scene_name + '_depth.png'
            filename_cmap_png = save_name + '/cmap/' + scene_name + '_depth.png'
            filename_gt_png = save_name + '/gt/' + scene_name + '_depth.png'
            # cv2.imwrite needs an extension to choose an encoder.
            filename_image_png = save_name + '/rgb/' + scene_name + '_depth.png'

        # Strip the 10-character '_label.npz' suffix and the '_a0'/'_a1'
        # tags to get the matching RGB image path.
        rgb_path = (line[:-10].replace("_a0", "").replace("_a1", "") +
                    ".png")

        pred_depth = pred_depths[s]

        # Rescale so the per-image maximum maps to 300, clamp to [1e-3, 300]
        # and replace inf / NaN values.
        pred_depth_scaled = pred_depth / pred_depth.max() * 300.0
        pred_depth_scaled[pred_depth_scaled < 1e-3] = 1e-3
        pred_depth_scaled[pred_depth_scaled > 300] = 300
        pred_depth_scaled[np.isinf(pred_depth_scaled)] = 300
        pred_depth_scaled[np.isnan(pred_depth_scaled)] = 1e-3

        # NOTE(review): OpenCV's PNG encoder targets 8/16-bit data; confirm
        # what is actually written for this float32 array.
        pred_depth_scaled = pred_depth_scaled.astype(np.float32)
        cv2.imwrite(filename_pred_png, pred_depth_scaled,
                    [cv2.IMWRITE_PNG_COMPRESSION, 0])

        if args.save_lpg:
            image = cv2.imread(rgb_path)
            cv2.imwrite(filename_image_png, image[10:-10, 10:-10, :])

            # NYU colour maps are trimmed by 10 px on every side; other
            # datasets are saved whole.
            if args.dataset == 'nyu':
                window = (slice(10, -10), slice(10, -10))
            else:
                window = (slice(None), slice(None))

            if args.dataset == 'nyu' and gt is not None:
                plt.imsave(filename_gt_png, np.log10(gt[window]),
                           cmap='Greys')

            cmap_outputs = [
                (filename_cmap_png, pred_depth),
                (filename_cmap_png.replace('.png', '_8x8.png'), pred_8x8s[s]),
                (filename_cmap_png.replace('.png', '_4x4.png'), pred_4x4s[s]),
                (filename_cmap_png.replace('.png', '_2x2.png'), pred_2x2s[s]),
                (filename_cmap_png.replace('.png', '_1x1.png'), pred_1x1s[s]),
            ]
            for cmap_path, depth_map in cmap_outputs:
                plt.imsave(cmap_path, np.log10(depth_map[window]),
                           cmap='Greys')
Example #4
0
def poly_decay_lr(base_lr, end_lr, step, total_steps, power=0.9):
    """Return the polynomially decayed learning rate for ``step``.

    Decays from ``base_lr`` at step 0 to ``end_lr`` at ``total_steps``
    following ``(1 - step / total_steps) ** power``. The remaining fraction is
    clamped at 0 so that steps past ``total_steps`` stay at ``end_lr``: in
    Python a negative float raised to a fractional power is a *complex*
    number, which would otherwise be written into the optimizer.
    """
    remaining = max(0.0, 1 - step / total_steps)
    return (base_lr - end_lr) * remaining**power + end_lr


def trainable_param_stats(model):
    """Return ``(total, count)`` over the parameters of ``model`` that require grad.

    ``total`` is the sum of every element of every trainable parameter, as a
    numpy float (so ``total.item()`` works); ``count`` is the number of
    trainable parameter tensors. Each parameter is reduced to a Python float
    first, so numpy never has to convert a (possibly CUDA) tensor itself and
    no autograd graph is built.
    """
    per_param = [
        p.detach().sum().item() for p in model.parameters() if p.requires_grad
    ]
    return np.sum(per_param), len(per_param)


def main_worker(gpu, ngpus_per_node, args):
    """Build the BTS model and run the training loop for one worker process.

    Args:
        gpu: GPU index assigned to this worker, or ``None``; stored on
            ``args.gpu``.
        ngpus_per_node: number of GPUs on this node; used to derive the global
            rank and the per-process batch size in distributed mode.
        args: parsed command-line namespace. ``args.gpu``, ``args.rank`` and
            ``args.batch_size`` may be overwritten in place.
    """
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Turn the node rank into a global rank across all GPU processes.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # args.rank is final from here on. Only this process writes summaries,
    # checkpoints and the per-step progress line.
    is_main_process = (not args.multiprocessing_distributed
                       or args.rank % ngpus_per_node == 0)

    # Create model
    model = BtsModel(args)
    model.train()
    model.decoder.apply(weights_init_xavier)
    if args.bn_no_track_stats:
        print("Disabling tracking running stats in batch norm layers")
        model.apply(bn_init_as_tf)

    # Substring patterns matched against encoder parameter names; every
    # matching parameter is frozen below.
    if args.fix_first_conv_blocks:
        if 'resne' in args.encoder:
            fixing_layers = [
                'base_model.conv1', 'base_model.layer1.0',
                'base_model.layer1.1', '.bn'
            ]
        else:
            fixing_layers = [
                'conv0', 'denseblock1.denselayer1', 'denseblock1.denselayer2',
                'norm'
            ]
        print("Fixing first two conv blocks")
    elif args.fix_first_conv_block:
        if 'resne' in args.encoder:
            fixing_layers = ['base_model.conv1', 'base_model.layer1.0', '.bn']
        else:
            fixing_layers = ['conv0', 'denseblock1.denselayer1', 'norm']
        print("Fixing first conv block")
    else:
        if 'resne' in args.encoder:
            fixing_layers = ['base_model.conv1', '.bn']
        else:
            fixing_layers = ['conv0', 'norm']
        print("Fixing first conv layer")

    for name, child in model.named_children():
        if 'encoder' not in name:
            continue
        for name2, parameters in child.named_parameters():
            if any(x in name2 for x in fixing_layers):
                parameters.requires_grad = False

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    print("Total number of parameters: {}".format(num_params))

    num_params_update = sum(
        [np.prod(p.shape) for p in model.parameters() if p.requires_grad])
    print("Total number of learning parameters: {}".format(num_params_update))

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # Each process handles its share of the global batch.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(
                model, find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    if args.distributed:
        print("Model Initialized on GPU: {}".format(args.gpu))
    else:
        print("Model Initialized")

    global_step = 0

    # Training parameters: weight decay on the encoder only.
    optimizer = torch.optim.AdamW([{
        'params': model.module.encoder.parameters(),
        'weight_decay': args.weight_decay
    }, {
        'params': model.module.decoder.parameters(),
        'weight_decay': 0
    }],
                                  lr=args.learning_rate,
                                  eps=1e-6)

    if args.checkpoint_path != '':
        if os.path.isfile(args.checkpoint_path):
            print("Loading checkpoint '{}'".format(args.checkpoint_path))
            if args.gpu is None:
                checkpoint = torch.load(args.checkpoint_path)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.checkpoint_path, map_location=loc)
            global_step = checkpoint['global_step']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("Loaded checkpoint '{}' (global_step {})".format(
                args.checkpoint_path, checkpoint['global_step']))
        else:
            print("No checkpoint found at '{}'".format(args.checkpoint_path))

    if args.retrain:
        global_step = 0

    cudnn.benchmark = True

    dataloader = BtsDataLoader(args)

    # Logging
    if is_main_process:
        writer = SummaryWriter(args.log_directory + '/' + args.model_name +
                               '/summaries',
                               flush_secs=30)

    silog_criterion = silog_loss(variance_focus=args.variance_focus)

    start_time = time.time()
    duration = 0
    log_freq = 100
    save_freq = 500

    num_log_images = args.batch_size
    end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else 0.1 * args.learning_rate

    var_sum, var_cnt = trainable_param_stats(model)

    print("Initial variables' sum: {:.3f}, avg: {:.3f}".format(
        var_sum, var_sum / var_cnt))

    steps_per_epoch = len(dataloader.data)
    num_total_steps = args.num_epochs * steps_per_epoch

    # When resuming mid-epoch the whole epoch is replayed, so global_step can
    # run past num_total_steps; poly_decay_lr clamps the schedule for that.
    epoch = global_step // steps_per_epoch

    while epoch < args.num_epochs:
        if args.distributed:
            dataloader.train_sampler.set_epoch(epoch)

        for step, sample_batched in enumerate(dataloader.data):
            optimizer.zero_grad()
            before_op_time = time.time()

            image = sample_batched['image'].cuda(args.gpu, non_blocking=True)
            focal = sample_batched['focal'].cuda(args.gpu, non_blocking=True)
            depth_gt = sample_batched['depth'].cuda(args.gpu, non_blocking=True)

            lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(image, focal)

            # Minimum valid ground-truth depth differs per dataset.
            if args.dataset == 'nyu':
                mask = depth_gt > 0.1
            else:
                mask = depth_gt > 1.0

            loss = silog_criterion.forward(depth_est, depth_gt, mask)
            loss.backward()

            current_lr = poly_decay_lr(args.learning_rate, end_learning_rate,
                                       global_step, num_total_steps)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            optimizer.step()

            if is_main_process:
                print(
                    '[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'
                    .format(epoch, step, steps_per_epoch, global_step,
                            current_lr, loss))

            duration += time.time() - before_op_time
            if global_step and global_step % log_freq == 0:
                var_sum, var_cnt = trainable_param_stats(model)
                examples_per_sec = args.batch_size / duration * log_freq
                duration = 0
                time_sofar = (time.time() - start_time) / 3600
                training_time_left = (num_total_steps / global_step -
                                      1.0) * time_sofar
                if is_main_process:
                    print("{}".format(args.model_name))
                print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
                print(
                    print_string.format(args.gpu, examples_per_sec, loss,
                                        var_sum.item(),
                                        var_sum.item() / var_cnt, time_sofar,
                                        training_time_left))

                if is_main_process:
                    writer.add_scalar('silog_loss', loss, global_step)
                    writer.add_scalar('learning_rate', current_lr, global_step)
                    writer.add_scalar('var average',
                                      var_sum.item() / var_cnt, global_step)
                    # Replace invalid (near-zero) GT depth before taking 1/depth.
                    depth_gt = torch.where(depth_gt < 1e-3, depth_gt * 0 + 1e3,
                                           depth_gt)
                    depth_maps = (('depth_gt', depth_gt),
                                  ('depth_est', depth_est),
                                  ('reduc1x1', reduc1x1),
                                  ('lpg2x2', lpg2x2),
                                  ('lpg4x4', lpg4x4),
                                  ('lpg8x8', lpg8x8))
                    # The loader may yield a short final batch; never index
                    # past the samples actually present in this batch.
                    for i in range(min(num_log_images, image.size(0))):
                        for tag, depth_map in depth_maps:
                            writer.add_image(
                                '{}/image/{}'.format(tag, i),
                                normalize_result(1 / depth_map[i, :, :, :].data),
                                global_step)
                        writer.add_image('image/image/{}'.format(i),
                                         inv_normalize(image[i, :, :, :]).data,
                                         global_step)
                    writer.flush()

            if global_step and global_step % save_freq == 0:
                if is_main_process:
                    checkpoint = {
                        'global_step': global_step,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }
                    torch.save(
                        checkpoint, args.log_directory + '/' +
                        args.model_name + '/model-{}'.format(global_step))

            global_step += 1

        epoch += 1