Example #1
def test_myimagefloder():
    show_images = True
    scale_factor = 1.0
    dataset_folder = '/datasets/hrvs/carla-highres/trainingF'
    all_left_img, all_right_img, all_left_disp, all_right_disp = ls.dataloader(
        dataset_folder)
    loader_eth3 = DA.myImageFloder(all_left_img,
                                   all_right_img,
                                   all_left_disp,
                                   right_disparity=all_right_disp,
                                   rand_scale=[0.225, 0.6 * scale_factor],
                                   rand_bright=[0.8, 1.2],
                                   order=2)
    inv_t = get_inv_transform()
    for left_img, right_img, disp in loader_eth3:
        if show_images:
            left_img_np = np.array(inv_t(left_img)).astype(np.uint8)
            right_img_np = np.array(inv_t(right_img)).astype(np.uint8)
            disp_img = cv2.normalize(disp,
                                     None,
                                     alpha=0,
                                     beta=1,
                                     norm_type=cv2.NORM_MINMAX,
                                     dtype=cv2.CV_32F)
            cv2.imshow('left_img', left_img_np[:, :, ::-1])
            cv2.imshow('right_img', right_img_np[:, :, ::-1])
            cv2.imshow('disp_img', disp_img)
            cv2.waitKey(0)
        assert left_img.shape == (3, 512, 768)
        assert right_img.shape == (3, 512, 768)
        assert disp.shape == (512, 768)
        break
Example #2
def test_with_torch_dataloader():
    show_images = True
    all_left_img, all_right_img, all_left_disp = dataloader(
        '/datasets/lidar_dataset/')
    lidar_loader = DA.myImageFloder(all_left_img,
                                    all_right_img,
                                    all_left_disp,
                                    rand_scale=[0.9, 1.1],
                                    rand_bright=[0.8, 1.2],
                                    order=2)
    inv_t = get_inv_transform()
    for left_img, right_img, disp in lidar_loader:
        if show_images:
            left_img_np = np.array(inv_t(left_img)).astype(np.uint8)
            right_img_np = np.array(inv_t(right_img)).astype(np.uint8)
            disp_img = cv2.normalize(disp,
                                     None,
                                     alpha=0,
                                     beta=1,
                                     norm_type=cv2.NORM_MINMAX,
                                     dtype=cv2.CV_32F)

            rectified_pair = np.concatenate((left_img_np, right_img_np),
                                            axis=1)
            h, w, _ = rectified_pair.shape
            for i in range(10, h, 30):
                rectified_pair = cv2.line(rectified_pair, (0, i), (w, i),
                                          (0, 0, 255))
            cv2.imshow('rectified', rectified_pair[:, :, ::-1])
            cv2.imshow('disp_img', disp_img)
            cv2.waitKey(0)
        assert left_img.shape == (3, 512, 768)
        assert right_img.shape == (3, 512, 768)
        assert disp.shape == (512, 768)
        break
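Note: despite its name, the test above iterates the dataset object directly. A minimal sketch of actually batching it through torch.utils.data.DataLoader (the batch size and worker count are illustrative assumptions, not from the source):

import torch

# Hypothetical wrapper: batch the myImageFloder dataset the way the training
# scripts below do; batch_size/num_workers here are arbitrary choices.
batched_loader = torch.utils.data.DataLoader(lidar_loader,
                                             batch_size=2,
                                             shuffle=True,
                                             num_workers=2,
                                             drop_last=True)
for left_batch, right_batch, disp_batch in batched_loader:
    # per-sample shape (3, 512, 768) becomes (N, 3, 512, 768) after batching
    assert left_batch.shape[1:] == (3, 512, 768)
    break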
Example #3
def init_dataloader(input_args):
    batch_size = input_args.batchsize
    scale_factor = input_args.maxdisp / 384.  # controls training resolution

    hrvs_folder = '%s/hrvs/carla-highres/trainingF' % input_args.database
    all_left_img, all_right_img, all_left_disp, all_right_disp = ls.dataloader(
        hrvs_folder)
    loader_carla = DA.myImageFloder(all_left_img,
                                    all_right_img,
                                    all_left_disp,
                                    right_disparity=all_right_disp,
                                    rand_scale=[0.225, 0.6 * scale_factor],
                                    rand_bright=[0.8, 1.2],
                                    order=2)

    middlebury_folder = '%s/middlebury/mb-ex-training/trainingF' % input_args.database
    all_left_img, all_right_img, all_left_disp, all_right_disp = ls.dataloader(
        middlebury_folder)
    loader_mb = DA.myImageFloder(all_left_img,
                                 all_right_img,
                                 all_left_disp,
                                 right_disparity=all_right_disp,
                                 rand_scale=[0.225, 0.6 * scale_factor],
                                 rand_bright=[0.8, 1.2],
                                 order=0)

    rand_scale = [0.9, 2.4 * scale_factor]
    all_left_img, all_right_img, all_left_disp, all_right_disp = lt.dataloader(
        '%s/sceneflow/' % input_args.database)
    loader_scene = DA.myImageFloder(all_left_img,
                                    all_right_img,
                                    all_left_disp,
                                    right_disparity=all_right_disp,
                                    rand_scale=rand_scale,
                                    order=2)

    # change to trainval when finetuning on KITTI
    all_left_img, all_right_img, all_left_disp, _, _, _ = lk15.dataloader(
        '%s/kitti15/training/' % input_args.database, split='train')
    loader_kitti15 = DA.myImageFloder(all_left_img,
                                      all_right_img,
                                      all_left_disp,
                                      rand_scale=rand_scale,
                                      order=0)

    all_left_img, all_right_img, all_left_disp = lk12.dataloader(
        '%s/kitti12/training/' % input_args.database)
    loader_kitti12 = DA.myImageFloder(all_left_img,
                                      all_right_img,
                                      all_left_disp,
                                      rand_scale=rand_scale,
                                      order=0)

    all_left_img, all_right_img, all_left_disp, _ = ls.dataloader(
        '%s/eth3d/' % input_args.database)
    loader_eth3d = DA.myImageFloder(all_left_img,
                                    all_right_img,
                                    all_left_disp,
                                    rand_scale=rand_scale,
                                    order=0)

    all_left_img, all_right_img, all_left_disp, all_right_disp = lidar_dataloader(
        '%s/lidar-hdsm-dataset/' % input_args.database)
    loader_lidar = DA.myImageFloder(all_left_img,
                                    all_right_img,
                                    all_left_disp,
                                    right_disparity=all_right_disp,
                                    rand_scale=[0.5, 1.1 * scale_factor],
                                    rand_bright=[0.8, 1.2],
                                    order=2,
                                    flip_disp_ud=True,
                                    occlusion_size=[10, 25])

    data_inuse = torch.utils.data.ConcatDataset(
        [loader_carla] * 10 + [loader_mb] * 150 +  # 71 pairs
        [loader_scene] +  # 39K pairs 960x540
        [loader_kitti15] + [loader_kitti12] * 24 + [loader_eth3d] * 300 +
        [loader_lidar])  # 25K pairs
    # airsim ~750
    train_dataloader = torch.utils.data.DataLoader(data_inuse,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=batch_size,
                                                   drop_last=True,
                                                   worker_init_fn=_init_fn)
    print('%d batches per epoch' % (len(data_inuse) // batch_size))
    return train_dataloader
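The scripts in these examples pass worker_init_fn=_init_fn without showing its definition. A minimal sketch of what such an initializer typically looks like (the reseeding strategy is an assumption; the stray indented seed calls at the top of Example #5 suggest the same shape):

import random
import numpy as np

def _init_fn(worker_id):
    # Hypothetical worker initializer: reseed NumPy and random per worker so
    # random augmentations are not identical across DataLoader worker processes.
    np.random.seed()
    random.seed()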
Example #4
                    help='enables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

if args.dataset == 'MiddleburyLoader':
    from dataloader import MiddleburyLoader as DA
else:
    from dataloader import LibraryLoader as DA

test_left_img, test_right_img = DA.dataloader(args.datapath)

if args.model == 'basic2':
    model = basic2(args.maxdisp)
elif args.model == 'basic':
    model = basic(args.maxdisp)
else:
    print('no model')

model.cuda()

if args.loadmodel is not None:
    state_dict = torch.load(args.loadmodel)
    keys = list(state_dict['state_dict'].keys())
    values = list(state_dict['state_dict'].values())
    state_dict2 = {}
Example #5
    np.random.seed()
    random.seed()
torch.manual_seed(args.seed)  # set again
torch.cuda.manual_seed(args.seed)


from dataloader import listfiles as ls
from dataloader import listsceneflow as lt
from dataloader import KITTIloader2015 as lk15
from dataloader import KITTIloader2012 as lk12
from dataloader import MiddleburyLoader as DA

batch_size = args.batchsize
scale_factor = args.maxdisp / 384. # controls training resolution
all_left_img, all_right_img, all_left_disp, all_right_disp = ls.dataloader('%s/carla-highres/trainingF'%args.database)
loader_carla = DA.myImageFloder(all_left_img,all_right_img,all_left_disp,right_disparity=all_right_disp, rand_scale=[0.225,0.6*scale_factor], order=2)

all_left_img, all_right_img, all_left_disp, all_right_disp = ls.dataloader('%s/mb-ex-training/trainingF'%args.database)  # mb-ex
loader_mb = DA.myImageFloder(all_left_img,all_right_img,all_left_disp,right_disparity=all_right_disp, rand_scale=[0.225,0.6*scale_factor], order=0)

all_left_img, all_right_img, all_left_disp, all_right_disp = lt.dataloader('%s/sceneflow/'%args.database)
loader_scene = DA.myImageFloder(all_left_img,all_right_img,all_left_disp,right_disparity=all_right_disp, rand_scale=[0.9,2.4*scale_factor], order=2)

all_left_img, all_right_img, all_left_disp,_,_,_ = lk15.dataloader('%s/kitti_scene/training/'%args.database,typ='train') # trainval
loader_kitti15 = DA.myImageFloder(all_left_img,all_right_img,all_left_disp, rand_scale=[0.9,2.4*scale_factor], order=0)
all_left_img, all_right_img, all_left_disp = lk12.dataloader('%s/data_stereo_flow/training/'%args.database)
loader_kitti12 = DA.myImageFloder(all_left_img,all_right_img,all_left_disp, rand_scale=[0.9,2.4*scale_factor], order=0)

all_left_img, all_right_img, all_left_disp, _ = ls.dataloader('%s/eth3d/'%args.database)
loader_eth3d = DA.myImageFloder(all_left_img,all_right_img,all_left_disp, rand_scale=[0.9,2.4*scale_factor],order=0)
def main():
    parser = argparse.ArgumentParser(description='HSM-Net')
    parser.add_argument('--maxdisp',
                        type=int,
                        default=384,
                        help='maximum disparity')
    parser.add_argument('--name', default='name')
    parser.add_argument('--database',
                        default='/data/private',
                        help='data path')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        help='number of epochs to train')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=16,
        # when maxdisp is 768, 18 is the most you can fit in 2 V100s (with syncBN on)
        help='samples per batch')
    parser.add_argument(
        '--val_batch_size',
        type=int,
        default=4,
        # when maxdisp is 768, 18 is the most you can fit in 2 V100s (with syncBN on)
        help='samples per batch')
    parser.add_argument('--loadmodel', default=None, help='weights path')
    parser.add_argument('--log_dir',
                        default="/data/private/logs/high-res-stereo")
    # parser.add_argument('--savemodel', default=os.path.join(os.getcwd(),'/trained_model'),
    #                     help='save path')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--val_epoch', type=int, default=4)
    parser.add_argument('--save_epoch', type=int, default=10)
    parser.add_argument("--val", action="store_true", default=False)
    parser.add_argument("--save_numpy", action="store_true", default=False)
    parser.add_argument("--testres", type=float, default=1.8)
    parser.add_argument("--threshold", type=float, default=0.7)
    parser.add_argument("--use_pseudoGT", default=False, action="store_true")
    parser.add_argument("--lr", default=1e-3, type=float)
    parser.add_argument("--lr_decay", default=2, type=int)
    parser.add_argument("--gpu", default=[0], nargs="+")
    parser.add_argument("--no_aug", default=False, action="store_true")

    args = parser.parse_args()
    torch.manual_seed(args.seed)
    torch.manual_seed(args.seed)  # set again
    torch.cuda.manual_seed(args.seed)
    batch_size = args.batch_size
    scale_factor = args.maxdisp / 384.  # controls training resolution
    args.name = args.name + "_" + time.strftime('%l:%M%p_%Y%b%d').strip(" ")
    gpu = []
    for i in args.gpu:
        gpu.append(int(i))
    args.gpu = gpu

    root_dir = "/data/private/KITTI_raw/2011_09_26/2011_09_26_drive_0013_sync"
    disp_dir = "final-768px_testres-3.3/disp"
    entp_dir = "final-768px_testres-3.3/entropy"
    mode = "image"
    image_name = "0000000040.npy"  #* this is the 4th image in the validation set
    train_left, train_right, train_disp, train_entp = kitti_raw_loader(
        root_dir, disp_dir, entp_dir, mode=mode, image_name=image_name)
    train_left = train_left * args.batch_size * 16
    train_right = train_right * args.batch_size * 16
    train_disp = train_disp * args.batch_size * 16
    train_entp = train_entp * args.batch_size * 16

    all_left_img, all_right_img, all_left_disp, left_val, right_val, disp_val_L = lk15.dataloader(
        '%s/KITTI2015/data_scene_flow/training/' % args.database, val=args.val)

    left_val = [left_val[3]]
    right_val = [right_val[3]]
    disp_val_L = [disp_val_L[3]]

    loader_kitti15 = DA.myImageFloder(train_left,
                                      train_right,
                                      train_disp,
                                      rand_scale=[0.9, 2.4 * scale_factor],
                                      order=0,
                                      use_pseudoGT=args.use_pseudoGT,
                                      entropy_threshold=args.threshold,
                                      left_entropy=train_entp,
                                      no_aug=args.no_aug)
    val_loader_kitti15 = DA.myImageFloder(left_val,
                                          right_val,
                                          disp_val_L,
                                          is_validation=True,
                                          testres=args.testres)

    train_data_inuse = loader_kitti15
    val_data_inuse = val_loader_kitti15

    # ! Due to an internal PyTorch bug, if you set num_workers > 0 in one dataloader, it must also be
    # ! set to n > 0 for the other dataloader as well (e.g. 1 for ValImgLoader and 10 for TrainImgLoader)
    ValImgLoader = torch.utils.data.DataLoader(
        val_data_inuse,
        drop_last=False,
        batch_size=args.val_batch_size,
        shuffle=False,
        worker_init_fn=_init_fn,
        num_workers=args.val_batch_size)

    TrainImgLoader = torch.utils.data.DataLoader(
        train_data_inuse,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        worker_init_fn=_init_fn,
        num_workers=args.batch_size)
    print('%d batches per epoch' % (len(train_data_inuse) // batch_size))

    model = hsm(args.maxdisp, clean=False, level=1)

    if len(args.gpu) > 1:
        from sync_batchnorm.sync_batchnorm import convert_model
        model = nn.DataParallel(model, device_ids=args.gpu)
        model = convert_model(model)
    else:
        model = nn.DataParallel(model, device_ids=args.gpu)

    model.cuda()

    # load model
    if args.loadmodel is not None:
        print("loading pretrained model: " + str(args.loadmodel))
        pretrained_dict = torch.load(args.loadmodel)
        pretrained_dict['state_dict'] = {
            k: v
            for k, v in pretrained_dict['state_dict'].items()
            if ('disp' not in k)
        }
        model.load_state_dict(pretrained_dict['state_dict'], strict=False)

    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))

    log = logger.Logger(args.log_dir, args.name, save_numpy=args.save_numpy)
    total_iters = 0
    val_sample_count = 0
    val_batch_count = 0
    save_path = os.path.join(args.log_dir,
                             os.path.join(args.name, "saved_model"))
    os.makedirs(save_path, exist_ok=True)

    for epoch in range(1, args.epochs + 1):
        total_train_loss = 0
        train_score_accum_dict = {
        }  # accumulates scores throughout a batch to get average score
        train_score_accum_dict["num_scored"] = 0
        adjust_learning_rate(optimizer,
                             args.lr,
                             args.lr_decay,
                             epoch,
                             args.epochs,
                             decay_rate=0.1)

        print('Epoch %d / %d' % (epoch, args.epochs))

        # SAVE
        if epoch != 1 and epoch % args.save_epoch == 0:
            print("saving weights at epoch: " + str(epoch))
            savefilename = os.path.join(save_path,
                                        'ckpt_' + str(total_iters) + '.tar')

            torch.save(
                {
                    'iters': total_iters,
                    'state_dict': model.state_dict(),
                    'train_loss': total_train_loss / len(TrainImgLoader),
                    "optimizer": optimizer.state_dict()
                }, savefilename)

        ## val ##

        if epoch == 1 or epoch % args.val_epoch == 0:
            print("validating at epoch: " + str(epoch))
            val_score_accum_dict = {}
            val_img_idx = 0
            for batch_idx, (imgL_crop, imgR_crop,
                            disp_crop_L) in enumerate(ValImgLoader):

                vis, scores_list, err_map_list = val_step(
                    model, imgL_crop, imgR_crop, disp_crop_L, args.maxdisp,
                    args.testres)

                for score, err_map in zip(scores_list, err_map_list):
                    for (score_tag,
                         score_val), (map_tag,
                                      map_val) in zip(score.items(),
                                                      err_map.items()):
                        log.scalar_summary(
                            "val/im_" + str(val_img_idx) + "/" + score_tag,
                            score_val, val_sample_count)
                        log.image_summary("val/" + map_tag, map_val,
                                          val_sample_count)

                        if score_tag not in val_score_accum_dict.keys():
                            val_score_accum_dict[score_tag] = 0
                        val_score_accum_dict[score_tag] += score_val
                    val_img_idx += 1
                    val_sample_count += 1

                log.image_summary('val/left', imgL_crop[0:1], val_sample_count)
                log.image_summary('val/right', imgR_crop[0:1],
                                  val_sample_count)
                log.disp_summary('val/gt0', disp_crop_L[0:1],
                                 val_sample_count)  # <-- GT disp
                log.entp_summary('val/entropy', vis['entropy'],
                                 val_sample_count)
                log.disp_summary('val/output3', vis['output3'][0],
                                 val_sample_count)

            for score_tag, score_val in val_score_accum_dict.items():
                log.scalar_summary("val/" + score_tag + "_batch_avg",
                                   score_val, epoch)

        ## training ##
        for batch_idx, (imgL_crop, imgR_crop,
                        disp_crop_L) in enumerate(TrainImgLoader):
            print("training at epoch: " + str(epoch))

            is_scoring = total_iters % 10 == 0

            loss, vis, scores_list, maps = train_step(model,
                                                      optimizer,
                                                      imgL_crop,
                                                      imgR_crop,
                                                      disp_crop_L,
                                                      args.maxdisp,
                                                      is_scoring=is_scoring)

            total_train_loss += loss

            if is_scoring:
                log.scalar_summary('train/loss_batch', loss, total_iters)
                for score in scores_list:
                    for tag, val in score.items():
                        log.scalar_summary("train/" + tag + "_batch", val,
                                           total_iters)

                        if tag not in train_score_accum_dict.keys():
                            train_score_accum_dict[tag] = 0
                        train_score_accum_dict[tag] += val
                        train_score_accum_dict[
                            "num_scored"] += imgL_crop.shape[0]

                for tag, err_map in maps[0].items():
                    log.image_summary("train/" + tag, err_map, total_iters)

            if total_iters % 10 == 0:
                log.image_summary('train/left', imgL_crop[0:1], total_iters)
                log.image_summary('train/right', imgR_crop[0:1], total_iters)
                log.disp_summary('train/gt0', disp_crop_L[0:1],
                                 total_iters)  # <-- GT disp
                log.entp_summary('train/entropy', vis['entropy'][0:1],
                                 total_iters)
                log.disp_summary('train/output3', vis['output3'][0:1],
                                 total_iters)

            total_iters += 1

        log.scalar_summary('train/loss',
                           total_train_loss / len(TrainImgLoader), epoch)
        for tag, val in train_score_accum_dict.items():
            log.scalar_summary("train/" + tag + "_avg",
                               val / train_score_accum_dict["num_scored"],
                               epoch)

        torch.cuda.empty_cache()
    # Save final checkpoint
    print("Finished training!\n Saving the last checkpoint...")
    savefilename = os.path.join(save_path, 'final' + '.tar')

    torch.save(
        {
            'iters': total_iters,
            'state_dict': model.state_dict(),
            'train_loss': total_train_loss / len(TrainImgLoader),
            "optimizer": optimizer.state_dict()
        }, savefilename)
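adjust_learning_rate is called above but not defined in this snippet. A minimal sketch consistent with the call site (the step schedule itself is an assumption, not the repository's implementation):

def adjust_learning_rate(optimizer, base_lr, lr_decay, epoch, total_epochs,
                         decay_rate=0.1):
    # Hypothetical schedule: keep base_lr for the first part of training and
    # multiply by decay_rate for the final 1/lr_decay of the epochs
    # (e.g. lr_decay=2 decays at the halfway point).
    lr = base_lr
    if epoch > total_epochs - total_epochs // lr_decay:
        lr = base_lr * decay_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr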
Example #7
def main():
    parser = argparse.ArgumentParser(description='HSM-Net')
    parser.add_argument('--maxdisp', type=int, default=384,
                        help='maximum disparity')
    parser.add_argument('--name', default='name')
    parser.add_argument('--database', default='/data/private',
                        help='data path')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--val_batch_size', type=int, default=1,
                        help='samples per batch')
    parser.add_argument('--loadmodel', default=None,
                        help='weights path')
    parser.add_argument('--log_dir', default="/data/private/logs/high-res-stereo")
    parser.add_argument("--testres", default=[0], nargs="+")
    parser.add_argument("--no_aug",default=False, action="store_true")

    args = parser.parse_args()
    torch.manual_seed(args.seed)
    torch.manual_seed(args.seed)  # set again
    torch.cuda.manual_seed(args.seed)
    args.name = args.name + "_" + time.strftime('%l:%M%p_%Y%b%d').strip(" ")
    testres = []
    for i in args.testres:
        testres.append(float(i))
    args.testres = testres

    all_left_img, all_right_img, all_left_disp, left_val, right_val, disp_val_L = lk15.dataloader(
        '%s/KITTI2015/data_scene_flow/training/' % args.database, val=True)

    left_val = [left_val[3]]
    right_val = [right_val[3]]
    disp_val_L = [disp_val_L[3]]

    # all_l = all_left_disp + left_val
    # all_r = all_right_img + right_val
    # all_d = all_left_disp + disp_val_L

    # correct_shape = (1242, 375)
    # for i in range(len(all_l)):
    #     l = np.array(Image.open(all_l[i]).convert("RGB"))
    #     r = np.array(Image.open(all_r[i]).convert("RGB"))
    #     d = Image.open(all_d[i])
    #     if l.shape != (375, 1242, 3):
    #
    #         l2 = cv2.resize(l, correct_shape, interpolation=cv2.INTER_CUBIC)
    #         r2 = cv2.resize(r, correct_shape, interpolation=cv2.INTER_CUBIC)
    #         d2 = np.array(torchvision.transforms.functional.resize(d, [375, 1242]))
    #         # d = np.stack([d, d, d], axis=-1)
    #         # d2 = cv2.resize(d.astype("uint16"), correct_shape)
    #
    #         cv2.imwrite(all_l[i], cv2.cvtColor(l2, cv2.COLOR_RGB2BGR))
    #         cv2.imwrite(all_r[i], cv2.cvtColor(r2, cv2.COLOR_RGB2BGR))
    #         cv2.imwrite(all_d[i], d2)


        # cv2.resize(l,())
    model = hsm(args.maxdisp, clean=False, level=1)
    model.cuda()

    # load model
    print("loading pretrained model: " + str(args.loadmodel))
    pretrained_dict = torch.load(args.loadmodel)
    pretrained_dict['state_dict'] = {k: v for k, v in pretrained_dict['state_dict'].items() if ('disp' not in k)}
    model = nn.DataParallel(model, device_ids=[0])
    model.load_state_dict(pretrained_dict['state_dict'], strict=False)

    name = "val_at_many_res" + "_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    log = logger.Logger(args.log_dir, name)
    val_sample_count = 0
    for res in args.testres:

        val_loader_kitti15 = DA.myImageFloder(left_val, right_val, disp_val_L, is_validation=True, testres=res)
        ValImgLoader = torch.utils.data.DataLoader(val_loader_kitti15, drop_last=False, batch_size=args.val_batch_size,
                                                   shuffle=False, worker_init_fn=_init_fn,
                                                   num_workers=0)
        print("================ res: " + str(res) + " ============================")
        ## val ##
        val_score_accum_dict = {}
        val_img_idx = 0
        for batch_idx, (imgL_crop, imgR_crop, disp_crop_L) in enumerate(ValImgLoader):
            vis, scores_list, err_map_list = val_step(model, imgL_crop, imgR_crop, disp_crop_L, args.maxdisp, res)

            for score, err_map in zip(scores_list, err_map_list):
                for (score_tag, score_val), (map_tag, map_val) in zip(score.items(), err_map.items()):
                    log.scalar_summary("val/im_" + str(val_img_idx) + "/" + str(res) + "/"+ score_tag, score_val, val_sample_count)
                    log.image_summary("val/" + str(res) + "/"+ map_tag, map_val, val_sample_count)

                    if score_tag not in val_score_accum_dict.keys():
                        val_score_accum_dict[score_tag] = 0
                    val_score_accum_dict[score_tag]+=score_val
                    print("res: " + str(res) + " " + score_tag + ": " + str(score_val))

                val_img_idx+=1
                val_sample_count += 1

                log.image_summary('val/left', imgL_crop[0:1], val_sample_count)
                # log.image_summary('val/right', imgR_crop[0:1], val_sample_count)
                log.disp_summary('val/gt0', disp_crop_L[0:1], val_sample_count)  # <-- GT disp
                log.entp_summary('val/entropy', vis['entropy'], val_sample_count)
                log.disp_summary('val/output3', vis['output3'][0], val_sample_count)
Example #8
def get_training_dataloader(maxdisp, dataset_folder):
    scale_factor = maxdisp / 384.

    all_left_img, all_right_img, all_left_disp, all_right_disp = ls.dataloader(
        '%s/hrvs/carla-highres/trainingF' % dataset_folder)
    loader_carla = DA.myImageFloder(all_left_img,
                                    all_right_img,
                                    all_left_disp,
                                    right_disparity=all_right_disp,
                                    rand_scale=[0.225, 0.6 * scale_factor],
                                    rand_bright=[0.8, 1.2],
                                    order=2)

    all_left_img, all_right_img, all_left_disp, all_right_disp = ls.dataloader(
        '%s/middlebury/mb-ex-training/trainingF' % dataset_folder)  # mb-ex
    loader_mb = DA.myImageFloder(all_left_img,
                                 all_right_img,
                                 all_left_disp,
                                 right_disparity=all_right_disp,
                                 rand_scale=[0.225, 0.6 * scale_factor],
                                 rand_bright=[0.8, 1.2],
                                 order=0)

    all_left_img, all_right_img, all_left_disp, all_right_disp = lt.dataloader(
        '%s/sceneflow/' % dataset_folder)
    loader_scene = DA.myImageFloder(all_left_img,
                                    all_right_img,
                                    all_left_disp,
                                    right_disparity=all_right_disp,
                                    rand_scale=[0.9, 2.4 * scale_factor],
                                    order=2)

    all_left_img, all_right_img, all_left_disp, _, _, _ = lk15.dataloader(
        '%s/kitti15/training/' % dataset_folder,
        typ='train')  # change to trainval when finetuning on KITTI
    loader_kitti15 = DA.myImageFloder(all_left_img,
                                      all_right_img,
                                      all_left_disp,
                                      rand_scale=[0.9, 2.4 * scale_factor],
                                      order=0)
    all_left_img, all_right_img, all_left_disp = lk12.dataloader(
        '%s/kitti12/training/' % dataset_folder)
    loader_kitti12 = DA.myImageFloder(all_left_img,
                                      all_right_img,
                                      all_left_disp,
                                      rand_scale=[0.9, 2.4 * scale_factor],
                                      order=0)

    all_left_img, all_right_img, all_left_disp, _ = ls.dataloader(
        '%s/eth3d/' % dataset_folder)
    loader_eth3d = DA.myImageFloder(all_left_img,
                                    all_right_img,
                                    all_left_disp,
                                    rand_scale=[0.9, 2.4 * scale_factor],
                                    order=0)

    all_left_img, all_right_img, all_left_disp = lld.dataloader(
        '%s/lidar_dataset/train' % dataset_folder)
    loader_lidar = DA.myImageFloder(all_left_img,
                                    all_right_img,
                                    all_left_disp,
                                    rand_scale=[0.5, 1.25 * scale_factor],
                                    rand_bright=[0.8, 1.2],
                                    order=0)
    all_dataloaders = [{
        'name': 'lidar',
        'dl': loader_lidar,
        'count': 1
    }, {
        'name': 'hrvs',
        'dl': loader_carla,
        'count': 1
    }, {
        'name': 'middlebury',
        'dl': loader_mb,
        'count': 1
    }, {
        'name': 'sceneflow',
        'dl': loader_scene,
        'count': 1
    }, {
        'name': 'kitti12',
        'dl': loader_kitti12,
        'count': 1
    }, {
        'name': 'kitti15',
        'dl': loader_kitti15,
        'count': 1
    }, {
        'name': 'eth3d',
        'dl': loader_eth3d,
        'count': 1
    }]
    max_count = 0
    for dataloader in all_dataloaders:
        max_count = max(max_count, len(dataloader['dl']))

    print('=' * 80)
    concat_dataloaders = []
    for dataloader in all_dataloaders:
        dataloader['count'] = max(1, max_count // len(dataloader['dl']))
        concat_dataloaders += [dataloader['dl']] * dataloader['count']
        print('{name}: {size} (x{count})'.format(name=dataloader['name'],
                                                 size=len(dataloader['dl']),
                                                 count=dataloader['count']))
    data_inuse = torch.utils.data.ConcatDataset(concat_dataloaders)
    print('Total dataset size: {}'.format(len(data_inuse)))
    print('=' * 80)
    return data_inuse
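For intuition, the count-balancing loop above oversamples each dataset so that every source contributes roughly max_count samples per epoch. A standalone illustration with made-up sizes (the numbers below are assumptions, not the real dataset counts):

sizes = {'middlebury': 71, 'sceneflow': 39000, 'eth3d': 130}
max_count = max(sizes.values())  # 39000
for name, size in sizes.items():
    count = max(1, max_count // size)
    # middlebury: x549, eth3d: x300, sceneflow: x1 -> each ends up near 39000
    print('{name}: {size} (x{count})'.format(name=name, size=size, count=count))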
def main():
    parser = argparse.ArgumentParser(description='HSM')
    parser.add_argument(
        '--datapath',
        default="/home/isaac/rvc_devkit/stereo/datasets_middlebury2014",
        help='test data path')
    parser.add_argument('--loadmodel', default=None, help='model path')
    parser.add_argument('--name',
                        default='rvc_highres_output',
                        help='output dir')
    parser.add_argument('--clean',
                        type=float,
                        default=-1,
                        help='clean up output using entropy estimation')
    parser.add_argument(
        '--testres',
        type=float,
        default=0.5,  # default used to be 0.5
        help='test time resolution ratio 0-x')
    parser.add_argument('--max_disp',
                        type=float,
                        default=-1,
                        help='maximum disparity to search for')
    parser.add_argument(
        '--level',
        type=int,
        default=1,
        help='output level, default is level 1 (stage 3); \
                              can also use level 2 (stage 2) or level 3 (stage 1)'
    )
    parser.add_argument('--debug_image', type=str, default=None)
    parser.add_argument("--eth_testres", type=float, default=3.5)
    parser.add_argument("--score_results", action="store_true", default=False)
    parser.add_argument("--save_weights", action="store_true", default=False)
    parser.add_argument("--kitti", action="store_true", default=False)
    parser.add_argument("--eth", action="store_true", default=False)
    parser.add_argument("--mb", action="store_true", default=False)
    parser.add_argument("--all_data", action="store_true", default=False)
    parser.add_argument("--eval_train_only",
                        action="store_true",
                        default=False)
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--batchsize", type=int, default=16)
    parser.add_argument("--prepare_kitti", action="store_true", default=False)

    args = parser.parse_args()

    # wandb.init(name=args.name, project="high-res-stereo", save_code=True, magic=True, config=args)

    if not os.path.exists("output"):
        os.mkdir("output")

    kitti_metrics = {}
    eth_metrics = {}
    mb_metrics = {}

    # construct model
    model = hsm(128, args.clean, level=args.level)
    model = convert_model(model)
    # wandb.watch(model)
    model = nn.DataParallel(model, device_ids=[0])
    model.cuda()

    if args.loadmodel is not None:
        pretrained_dict = torch.load(args.loadmodel)
        pretrained_dict['state_dict'] = {
            k: v
            for k, v in pretrained_dict['state_dict'].items()
            if 'disp' not in k
        }
        model.load_state_dict(pretrained_dict['state_dict'], strict=False)
    else:
        print('run with random init')
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    model.eval()

    if not args.prepare_kitti:
        dataset = RVCDataset(args)
    if args.prepare_kitti:
        _, _, _, left_val, right_val, disp_val_L = lk15.dataloader(
            '/data/private/KITTI2015/data_scene_flow/training/',
            val=True)  # change to trainval when finetuning on KITTI

        dataset = DA.myImageFloder(left_val,
                                   right_val,
                                   disp_val_L,
                                   rand_scale=[1, 1],
                                   order=0)

    dataloader = DataLoader(dataset,
                            batch_size=args.batchsize,
                            shuffle=False,
                            num_workers=0)

    steps = 0
    max_disp = None
    original_image_size = None
    top_pad = None
    left_pad = None
    testres = [args.testres]
    dataset_type = None
    data_path = [args.datapath]
    # for (imgL, imgR, gt_disp_raw, max_disp, original_image_size, top_pad, left_pad, testres, dataset_type, data_path) in dataloader:
    for (imgL, imgR, gt_disp_raw) in dataloader:
        # Todo: this is a hot fix. Must be fixed to handle batchsize greater than 1
        data_path = data_path[0]
        img_name = os.path.basename(os.path.normpath(data_path))
        testres = float(testres[0])
        gt_disp_raw = gt_disp_raw[0]

        cum_metrics = None
        if dataset_type == 0:
            cum_metrics = mb_metrics

        elif dataset_type == 1:
            cum_metrics = eth_metrics

        elif dataset_type == 2:
            cum_metrics = kitti_metrics

        print(img_name)

        if args.max_disp > 0:
            max_disp = int(args.max_disp)

        ## change max disp
        tmpdisp = int(max_disp * testres // 64 * 64)
        if (max_disp * testres / 64 * 64) > tmpdisp:
            model.module.maxdisp = tmpdisp + 64
        else:
            model.module.maxdisp = tmpdisp
        if model.module.maxdisp == 64: model.module.maxdisp = 128
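        # Worked example (values assumed, not from the source): with max_disp=384
        # and testres=1.8, max_disp * testres = 691.2; flooring to a multiple of
        # 64 gives tmpdisp = 640, and since 691.2 > 640 the model rounds up to
        # maxdisp = 704, so the search range covers the scaled disparities.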
        model.module.disp_reg8 = disparityregression(model.module.maxdisp,
                                                     16).cuda()
        model.module.disp_reg16 = disparityregression(model.module.maxdisp,
                                                      16).cuda()
        model.module.disp_reg32 = disparityregression(model.module.maxdisp,
                                                      32).cuda()
        model.module.disp_reg64 = disparityregression(model.module.maxdisp,
                                                      64).cuda()
        print("    max disparity = " + str(model.module.maxdisp))

        # wandb.log({"imgL": wandb.Image(imgL, caption=img_name + ", " + str(tuple(imgL.shape))),
        #            "imgR": wandb.Image(imgR, caption=img_name + ", " + str(tuple(imgR.shape)))}, step=steps)

        with torch.no_grad():
            torch.cuda.synchronize()
            start_time = time.time()

            # * output dimensions same as input dimensions
            # * (ex: imgL[1, 3, 704, 2240] then pred_disp[1, 704, 2240])
            pred_disp, entropy = model(imgL, imgR)

            torch.cuda.synchronize()
            ttime = (time.time() - start_time)

            print('    time = %.2f ms' % (ttime * 1000))

        # * squeeze (remove dimensions with size 1) (ex: pred_disp[1, 704, 2240] ->[704, 2240])
        pred_disp = torch.squeeze(pred_disp).data.cpu().numpy()

        top_pad = int(top_pad[0])
        left_pad = int(left_pad[0])
        entropy = entropy[top_pad:, :pred_disp.shape[1] -
                          left_pad].cpu().numpy()
        pred_disp = pred_disp[top_pad:, :pred_disp.shape[1] - left_pad]

        # save predictions
        idxname = img_name

        if not os.path.exists('output/%s/%s' % (args.name, idxname)):
            os.makedirs('output/%s/%s' % (args.name, idxname))

        idxname = '%s/disp0%s' % (idxname, args.name)

        # * shrink image back to the GT size (ex: pred_disp[675, 2236] -> [375, 1242])
        # ! we element-wise divide pred_disp by testres because the image is shrinking,
        # ! so the distance between pixels should also shrink by the same factor
        pred_disp_raw = cv2.resize(
            pred_disp / testres,
            (original_image_size[1], original_image_size[0]),
            interpolation=cv2.INTER_LINEAR)
        pred_disp = pred_disp_raw  # raw is to use for scoring

        gt_disp = gt_disp_raw.numpy()

        # * clip while keeping inf
        # * `pred_disp != pred_disp` is a NaN test (NaN compares unequal to itself)
        # ? `pred_disp[pred_invalid] = np.inf` why do this?
        pred_invalid = np.logical_or(pred_disp == np.inf,
                                     pred_disp != pred_disp)
        pred_disp[pred_invalid] = np.inf

        pred_disp_png = (pred_disp * 256).astype("uint16")

        gt_invalid = np.logical_or(gt_disp == np.inf, gt_disp != gt_disp)
        gt_disp[gt_invalid] = 0
        gt_disp_png = (gt_disp * 256).astype("uint16")
        entropy_png = (entropy * 256).astype('uint16')

        # ! raw output to png
        pred_disp_path = 'output/%s/%s/disp.png' % (args.name,
                                                    idxname.split('/')[0])
        gt_disp_path = 'output/%s/%s/gt_disp.png' % (args.name,
                                                     idxname.split('/')[0])
        assert (cv2.imwrite(pred_disp_path, pred_disp_png))
        assert (cv2.imwrite(gt_disp_path, gt_disp_png))
        assert (cv2.imwrite(
            'output/%s/%s/ent.png' % (args.name, idxname.split('/')[0]),
            entropy_png))

        # ! Experimental color maps
        gt_disp_color_path = 'output/%s/%s/gt_disp_color.png' % (
            args.name, idxname.split('/')[0])
        pred_disp_color_path = 'output/%s/%s/disp_color.png' % (
            args.name, idxname.split('/')[0])

        gt_colormap = convert_to_colormap(gt_disp_png)
        pred_colormap = convert_to_colormap(pred_disp_png)
        entropy_colormap = convert_to_colormap(entropy_png)
        assert (cv2.imwrite(gt_disp_color_path, gt_colormap))
        assert (cv2.imwrite(pred_disp_color_path, pred_colormap))

        # ! diff colormaps
        diff_colormap_path = 'output/%s/%s/diff_color.png' % (
            args.name, idxname.split('/')[0])
        false_positive_path = 'output/%s/%s/false_positive_color.png' % (
            args.name, idxname.split('/')[0])
        false_negative_path = 'output/%s/%s/false_negative_color.png' % (
            args.name, idxname.split('/')[0])
        gt_disp_png[gt_invalid] = pred_disp_png[gt_invalid]
        gt_disp_png = gt_disp_png.astype("int32")
        pred_disp_png = pred_disp_png.astype("int32")

        diff_colormap = convert_to_colormap(np.abs(gt_disp_png -
                                                   pred_disp_png))
        false_positive_colormap = convert_to_colormap(
            np.abs(np.clip(gt_disp_png - pred_disp_png, None, 0)))
        false_negative_colormap = convert_to_colormap(
            np.abs(np.clip(gt_disp_png - pred_disp_png, 0, None)))
        assert (cv2.imwrite(diff_colormap_path, diff_colormap))
        assert (cv2.imwrite(false_positive_path, false_positive_colormap))
        assert (cv2.imwrite(false_negative_path, false_negative_colormap))

        out_pfm_path = 'output/%s/%s.pfm' % (args.name, idxname)
        with open(out_pfm_path, 'w') as f:
            save_pfm(f, pred_disp[::-1, :])
        with open(
                'output/%s/%s/time_%s.txt' %
            (args.name, idxname.split('/')[0], args.name), 'w') as f:
            f.write(str(ttime))
        print("    output = " + out_pfm_path)

        caption = img_name + ", " + str(
            tuple(pred_disp_png.shape)) + ", max disparity = " + str(
                int(max_disp[0])) + ", time = " + str(ttime)

        # read GT depthmap and upload as jpg

        # wandb.log({"disparity": wandb.Image(pred_colormap, caption=caption) , "gt": wandb.Image(gt_colormap), "entropy": wandb.Image(entropy_colormap, caption= str(entorpy_png.shape)),
        #            "diff":wandb.Image(diff_colormap), "false_positive":wandb.Image(false_positive_colormap), "false_negative":wandb.Image(false_negative_colormap)}, step=steps)

        torch.cuda.empty_cache()
        steps += 1

        # Todo: find out what mask0nocc does. It's probably not the same as KITTI's object map
        if dataset_type == 2:
            obj_map_path = os.path.join(data_path, "obj_map.png")
        else:
            obj_map_path = None

        if args.score_results:
            if pred_disp_raw.shape != gt_disp_raw.shape:  # pred_disp_raw[375 x 1242] gt_disp_raw[675 x 2236]
                ratio = float(gt_disp_raw.shape[1]) / pred_disp_raw.shape[1]
                disp_resized = cv2.resize(
                    pred_disp_raw,
                    (gt_disp_raw.shape[1], gt_disp_raw.shape[0])) * ratio
                pred_disp_raw = disp_resized  # [675 x 2236]
            # if args.debug:
            #     out_resized_pfm_path = 'output/%s/%s/pred_scored.pfm' % (args.name, img_name)
            #     with open(out_resized_pfm_path, 'w') as f:
            #         save_pfm(f, pred_disp_raw)

            #     out_resized_gt_path = 'output/%s/%s/gt_scored.pfm' % (args.name, img_name)
            #     with open(out_resized_gt_path, 'w') as f:
            #         save_pfm(f, gt_disp_raw.numpy())

            metrics = score_rvc.get_metrics(
                pred_disp_raw,
                gt_disp_raw,
                int(max_disp[0]),
                dataset_type,
                ('output/%s/%s' % (args.name, idxname.split('/')[0])),
                disp_path=pred_disp_path,
                gt_path=gt_disp_path,
                obj_map_path=obj_map_path,
                debug=args.debug)

            avg_metrics = {}
            for (key, val) in metrics.items():
                if cum_metrics.get(key) is None:
                    cum_metrics[key] = []
                cum_metrics[key].append(val)
                avg_metrics["avg_" + key] = sum(cum_metrics[key]) / len(
                    cum_metrics[key])

            # wandb.log(metrics, step=steps)
            # wandb.log(avg_metrics, step=steps)

    # if args.save_weights and os.path.exists(args.loadmodel):
    #     wandb.save(args.loadmodel)

    if args.prepare_kitti and (args.all_data or args.kitti):
        in_path = 'output/%s' % (args.name)
        out_path = "/home/isaac/high-res-stereo/kitti_submission_output"
        out_path = prepare_kitti(in_path, out_path)
        subprocess.run(
            ["/home/isaac/KITTI2015_devkit/cpp/eval_scene_flow", out_path])
        print("KITTI submission evaluation saved to: " + out_path)
from dataloader import listfiles as ls
from dataloader import listsceneflow as lt
from dataloader import KITTIloader2015 as lk15
from dataloader import KITTIloader2012 as lk12
from dataloader import MiddleburyLoader as DA

batch_size = args.batchsize
scale_factor = args.maxdisp / 384.  # controls training resolution

# * Gengshan told me to set the scale s.t. the mean of the scale would be same as testres (kitti = 1.8, eth = 3.5 (?), MB = 1)
kitti_scale_range = [1.4, 2.2]  # ? multiply scale_factor or not? Since default maxdisp is 384, scale_factor is 1 by default. (Asked gengshan via FB messenger)

all_left_img, all_right_img, all_left_disp, all_left_entropy = unlabeled_loader()


loader = DA.myImageFloder(all_left_img, all_right_img, all_left_disp, left_entropy=all_left_entropy,
                          rand_scale=kitti_scale_range, order=0, entropy_threshold=args.threshold)  # or 0.95


TrainImgLoader = torch.utils.data.DataLoader(
    loader,
    batch_size=batch_size, shuffle=False, num_workers=batch_size, drop_last=True, worker_init_fn=_init_fn)

print('%d batches per epoch' % (len(loader) // batch_size))


def train(imgL, imgR, disp_L):
    model.train()
    imgL = torch.FloatTensor(imgL)
    imgR = torch.FloatTensor(imgR)
    disp_L = torch.FloatTensor(disp_L)
Example #11
def main():
    parser = argparse.ArgumentParser(description='HSM-Net')
    parser.add_argument('--maxdisp',
                        type=int,
                        default=384,
                        help='maximum disparity')
    parser.add_argument('--name', default='name')
    parser.add_argument('--database',
                        default='/data/private',
                        help='data path')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        help='number of epochs to train')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=18,
        # when maxdisp is 768, 18 is the most you can fit in 2 V100s (with syncBN on)
        help='samples per batch')
    parser.add_argument('--val_batch_size',
                        type=int,
                        default=2,
                        help='validation samples per batch')
    parser.add_argument('--loadmodel', default=None, help='weights path')
    parser.add_argument('--log_dir',
                        default="/data/private/logs/high-res-stereo")
    # parser.add_argument('--savemodel', default=os.path.join(os.getcwd(),'/trained_model'),
    #                     help='save path')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--val_epoch', type=int, default=2)
    parser.add_argument('--save_epoch', type=int, default=1)
    parser.add_argument("--val", action="store_true", default=False)
    parser.add_argument("--save_numpy", action="store_true", default=False)
    parser.add_argument("--testres", type=float, default=1.8)
    parser.add_argument("--threshold", type=float, default=0.7)
    parser.add_argument("--use_pseudoGT", default=False, action="store_true")
    parser.add_argument("--lr", default=1e-3, type=float)
    parser.add_argument("--lr_decay", default=2, type=int)
    parser.add_argument("--gpu", default=[0], nargs="+")

    args = parser.parse_args()
    torch.manual_seed(args.seed)
    torch.manual_seed(args.seed)  # set again
    torch.cuda.manual_seed(args.seed)
    scale_factor = args.maxdisp / 384.  # controls training resolution
    args.name = args.name + "_" + time.strftime('%l:%M%p_%Y%b%d').strip(" ")
    gpu = []
    for i in args.gpu:
        gpu.append(int(i))
    args.gpu = gpu

    all_left_img = [
        "/data/private/Middlebury/mb-ex/trainingF/Cable-perfect/im0.png"
    ] * args.batch_size * 16
    all_right_img = [
        "/data/private/Middlebury/mb-ex/trainingF/Cable-perfect/im1.png"
    ] * args.batch_size * 16
    all_left_disp = [
        "/data/private/Middlebury/kitti_testres1.15_maxdisp384/disp/Cable-perfect.npy"
    ] * args.batch_size * 16
    all_left_entp = [
        "/data/private/Middlebury/kitti_testres1.15_maxdisp384/entropy/Cable-perfect.npy"
    ] * args.batch_size * 16

    loader_mb = DA.myImageFloder(all_left_img,
                                 all_right_img,
                                 all_left_disp,
                                 rand_scale=[0.225, 0.6 * scale_factor],
                                 order=0,
                                 use_pseudoGT=args.use_pseudoGT,
                                 entropy_threshold=args.threshold,
                                 left_entropy=all_left_entp)

    val_left_img = [
        "/data/private/Middlebury/mb-ex/trainingF/Cable-perfect/im0.png"
    ]
    val_right_img = [
        "/data/private/Middlebury/mb-ex/trainingF/Cable-perfect/im1.png"
    ]
    val_disp = [
        "/data/private/Middlebury/mb-ex/trainingF/Cable-perfect/disp0GT.pfm"
    ]
    val_loader_mb = DA.myImageFloder(val_left_img,
                                     val_right_img,
                                     val_disp,
                                     is_validation=True,
                                     testres=args.testres)

    TrainImgLoader = torch.utils.data.DataLoader(
        loader_mb,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        worker_init_fn=_init_fn,
        num_workers=args.batch_size)

    ValImgLoader = torch.utils.data.DataLoader(val_loader_mb,
                                               batch_size=1,
                                               shuffle=False,
                                               drop_last=False,
                                               worker_init_fn=_init_fn,
                                               num_workers=1)

    print('%d batches per epoch' % (len(loader_mb) // args.batch_size))

    model = hsm(args.maxdisp, clean=False, level=1)

    gpus = [0, 1]
    if len(gpus) > 1:
        from sync_batchnorm.sync_batchnorm import convert_model
        model = nn.DataParallel(model, device_ids=gpus)
        model = convert_model(model)
    else:
        model = nn.DataParallel(model, device_ids=gpus)

    model.cuda()

    # load model
    if args.loadmodel is not None:
        print("loading pretrained model: " + str(args.loadmodel))
        pretrained_dict = torch.load(args.loadmodel)
        pretrained_dict['state_dict'] = {
            k: v
            for k, v in pretrained_dict['state_dict'].items()
            if ('disp' not in k)
        }
        model.load_state_dict(pretrained_dict['state_dict'], strict=False)

    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))

    log = logger.Logger(args.log_dir, args.name, save_numpy=args.save_numpy)
    total_iters = 0
    val_sample_count = 0
    val_batch_count = 0

    save_path = os.path.join(args.log_dir,
                             os.path.join(args.name, "saved_model"))
    os.makedirs(save_path, exist_ok=True)

    for epoch in range(1, args.epochs + 1):
        total_train_loss = 0
        train_score_accum_dict = {
        }  # accumulates scores throughout a batch to get average score
        train_score_accum_dict["num_scored"] = 0
        adjust_learning_rate(optimizer,
                             args.lr,
                             args.lr_decay,
                             epoch,
                             args.epochs,
                             decay_rate=0.1)

        print('Epoch %d / %d' % (epoch, args.epochs))

        # SAVE
        if epoch != 1 and epoch % args.save_epoch == 0:
            print("saving weights at epoch: " + str(epoch))
            savefilename = os.path.join(save_path,
                                        'ckpt_' + str(total_iters) + '.tar')

            torch.save(
                {
                    'iters': total_iters,
                    'state_dict': model.state_dict(),
                    'train_loss': total_train_loss / len(TrainImgLoader),
                    "optimizer": optimizer.state_dict()
                }, savefilename)

        ## val ##
        if epoch % args.val_epoch == 0:
            print("validating at epoch: " + str(epoch))
            val_score_accum_dict = {
            }  # accumulates scores throughout a batch to get average score
            for batch_idx, (imgL_crop, imgR_crop,
                            disp_crop_L) in enumerate(ValImgLoader):

                vis, scores_list, err_map_list = val_step(
                    model, imgL_crop, imgR_crop, disp_crop_L, args.maxdisp,
                    args.testres)

                for score, err_map in zip(scores_list, err_map_list):
                    for (score_tag,
                         score_val), (map_tag,
                                      map_val) in zip(score.items(),
                                                      err_map.items()):
                        log.scalar_summary("val/" + score_tag, score_val,
                                           val_sample_count)
                        log.image_summary("val/" + map_tag, map_val,
                                          val_sample_count)

                        if score_tag not in val_score_accum_dict.keys():
                            val_score_accum_dict[score_tag] = 0
                        val_score_accum_dict[score_tag] += score_val
                    val_sample_count += 1

                log.image_summary('val/left', imgL_crop[0:1], val_sample_count)
                log.image_summary('val/right', imgR_crop[0:1],
                                  val_sample_count)
                log.disp_summary('val/gt0', disp_crop_L[0:1],
                                 val_sample_count)  # <-- GT disp
                log.entp_summary('val/entropy', vis['entropy'],
                                 val_sample_count)
                log.disp_summary('val/output3', vis['output3'][0],
                                 val_sample_count)

                for score_tag, score_val in val_score_accum_dict.items():
                    log.scalar_summary("val/" + score_tag + "_batch_avg",
                                       score_val, val_batch_count)
                val_batch_count += 1

        ## training ##
        for batch_idx, (imgL_crop, imgR_crop,
                        disp_crop_L) in enumerate(TrainImgLoader):
            print("training at epoch: " + str(epoch))

            is_scoring = total_iters % 10 == 0

            loss, vis, scores_list, maps = train_step(model,
                                                      optimizer,
                                                      imgL_crop,
                                                      imgR_crop,
                                                      disp_crop_L,
                                                      args.maxdisp,
                                                      is_scoring=is_scoring)

            total_train_loss += loss

            if is_scoring:
                log.scalar_summary('train/loss_batch', loss, total_iters)
                for score in scores_list:
                    for tag, val in score.items():
                        log.scalar_summary("train/" + tag + "_batch", val,
                                           total_iters)

                        if tag not in train_score_accum_dict.keys():
                            train_score_accum_dict[tag] = 0
                        train_score_accum_dict[tag] += val
                        train_score_accum_dict[
                            "num_scored"] += imgL_crop.shape[0]

                for tag, err_map in maps[0].items():
                    log.image_summary("train/" + tag, err_map, total_iters)

            if total_iters % 10 == 0:
                log.image_summary('train/left', imgL_crop[0:1], total_iters)
                log.image_summary('train/right', imgR_crop[0:1], total_iters)
                log.disp_summary('train/gt0', disp_crop_L[0:1],
                                 total_iters)  # <-- GT disp
                log.entp_summary('train/entropy', vis['entropy'][0:1],
                                 total_iters)
                log.disp_summary('train/output3', vis['output3'][0:1],
                                 total_iters)

            total_iters += 1

        log.scalar_summary('train/loss',
                           total_train_loss / len(TrainImgLoader), epoch)
        for tag, val in train_score_accum_dict.items():
            log.scalar_summary("train/" + tag + "_avg",
                               val / train_score_accum_dict["num_scored"],
                               epoch)

        torch.cuda.empty_cache()
    # Save final checkpoint
    print("Finished training!\n Saving the last checkpoint...")
    savefilename = os.path.join(save_path, 'final' + '.tar')

    torch.save(
        {
            'iters': total_iters,
            'state_dict': model.state_dict(),
            'train_loss': total_train_loss / len(TrainImgLoader),
            "optimizer": optimizer.state_dict()
        }, savefilename)
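The checkpoints saved above store 'iters', 'state_dict', 'train_loss', and 'optimizer'. A minimal sketch of resuming from one (the path is a placeholder):

# Hypothetical resume snippet: keys match the dicts passed to torch.save above.
checkpoint = torch.load('path/to/ckpt_1000.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
total_iters = checkpoint['iters']
print('resumed at iter %d (train loss %.3f)' % (total_iters, checkpoint['train_loss']))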