示例#1
0
def iterate(mode, args, loader, model, optimizer, logger, epoch):
    """Run one epoch of `mode` over `loader`, logging metrics via `logger`.

    This variant never backpropagates: the forward pass is wrapped in
    `torch.no_grad()` and all losses are left at 0.

    Returns:
        (avg, is_best): epoch-averaged metrics and whether they are the
        best seen so far for this mode.
    """
    block_average_meter = AverageMeter()
    average_meter = AverageMeter()
    meters = [block_average_meter, average_meter]

    # switch to appropriate mode
    assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
        "unsupported mode: {}".format(mode)
    if mode == 'train':
        model.train()
        lr = helper.adjust_learning_rate(args.lr, optimizer, epoch)
    else:
        model.eval()
        lr = 0

    for i, batch_data in enumerate(loader):
        start = time.time()
        batch_data = {
            key: val.to(device)
            for key, val in batch_data.items() if val is not None
        }
        # Ground truth is unavailable in the test_* modes.
        gt = batch_data[
            'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None
        data_time = time.time() - start

        start = time.time()
        # NOTE(review): no_grad here (original comment: "added by myself")
        # disables autograd for the forward pass, so this variant cannot
        # train even when mode == 'train'.
        with torch.no_grad():
            pred = model(batch_data)
        depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None

        gpu_time = time.time() - start

        # measure accuracy and record loss
        with torch.no_grad():
            mini_batch_size = next(iter(batch_data.values())).size(0)
            result = Result()
            if mode != 'test_prediction' and mode != 'test_completion':
                # evaluation is done on CPU copies of the tensors
                result.evaluate(pred.data.cpu(), gt.data.cpu(),
                                photometric_loss)
            # plain loop instead of a side-effect list comprehension
            for m in meters:
                m.update(result, gpu_time, data_time, mini_batch_size)
            logger.conditional_print(mode, i, epoch, lr, len(loader),
                                     block_average_meter, average_meter)
            logger.conditional_save_img_comparison(mode, i, batch_data, pred,
                                                   epoch)
            logger.conditional_save_pred(mode, i, pred, epoch)

    avg = logger.conditional_save_info(mode, average_meter, epoch)
    is_best = logger.rank_conditional_save_best(mode, avg, epoch)
    if is_best and not (mode == "train"):
        logger.save_img_comparison_as_best(mode, epoch)
    logger.conditional_summarize(mode, avg, is_best)

    return avg, is_best
示例#2
0
def iterate(mode, args, loader, model, optimizer, logger, epoch):
    """Run one epoch of `mode` over `loader`.

    In 'train' mode this computes the depth, photometric and smoothness
    losses and takes one optimizer step per batch; in the other modes the
    model is only evaluated.  Optionally saves depth predictions to disk
    when `args.save_images` is set.

    Returns:
        (avg, is_best): epoch-averaged metrics and whether they are the
        best seen so far for this mode.
    """
    block_average_meter = AverageMeter()
    average_meter = AverageMeter()
    meters = [block_average_meter, average_meter]

    # switch to appropriate mode
    assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
        "unsupported mode: {}".format(mode)
    if mode == 'train':
        model.train()
        lr = helper.adjust_learning_rate(args.lr, optimizer, epoch)
    else:
        # batchnorm or dropout layers will work in eval mode instead of
        # training mode
        model.eval()
        lr = 0

    # batch_data keys: 'd' (depth), 'gt' (ground truth), 'g' (gray)
    for i, batch_data in enumerate(loader):
        start = time.time()
        batch_data = {
            key: val.to(device)
            for key, val in batch_data.items() if val is not None
        }

        # ground truth is unavailable in the test_* modes
        gt = batch_data[
            'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None
        data_time = time.time() - start

        start = time.time()

        pred = model(batch_data)
        if args.save_images:  # save depth predictions
            # most recently modified test-output directory
            pred_out_dir = max(glob.glob('../outputs/var_final_NN/var.test*'),
                               key=os.path.getmtime) + '/dense_depth_images'
            pred1 = pred.cpu().detach().numpy()[:, 0, :, :]
            # NOTE(review): the output path is derived from the batch index
            # `i` only, so this assumes batch_size == 1; with larger batches
            # every image in the batch would overwrite the same file.
            # TODO: confirm against the data loader configuration.
            pred_out_dir1 = os.path.abspath(pred_out_dir)
            cur_path = os.path.abspath((loader.dataset.paths)['d'][i])
            basename = os.path.basename(cur_path)
            cur_dir = os.path.abspath(os.path.dirname(cur_path))
            cur_dir = cur_dir.split('var_final_NN/')[1]
            new_dir = os.path.abspath(pred_out_dir1 + '/' + cur_dir)
            new_path = os.path.abspath(new_dir + '/' + basename)
            os.makedirs(new_dir, exist_ok=True)
            for pred_im in pred1:
                depth_write(new_path, pred_im)

        depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None
        if mode == 'train':
            # Loss 1: the direct depth supervision from ground truth label
            # mask=1 indicates that a pixel does not have ground truth labels
            if 'sparse' in args.train_mode:
                depth_loss = depth_criterion(pred, batch_data['d'])
                mask = (batch_data['d'] < 1e-3).float()
            elif 'dense' in args.train_mode:
                depth_loss = depth_criterion(pred, gt)
                mask = (gt < 1e-3).float()

            # Loss 2: the self-supervised photometric loss
            if args.use_pose:
                # create multi-scale pyramids
                pred_array = helper.multiscale(pred)
                rgb_curr_array = helper.multiscale(batch_data['rgb'])
                rgb_near_array = helper.multiscale(batch_data['rgb_near'])
                if mask is not None:
                    mask_array = helper.multiscale(mask)
                num_scales = len(pred_array)

                # compute photometric loss at multiple scales
                for scale in range(len(pred_array)):
                    pred_ = pred_array[scale]
                    rgb_curr_ = rgb_curr_array[scale]
                    rgb_near_ = rgb_near_array[scale]
                    mask_ = None
                    if mask is not None:
                        mask_ = mask_array[scale]

                    # compute the corresponding intrinsic parameters
                    height_, width_ = pred_.size(2), pred_.size(3)
                    intrinsics_ = kitti_intrinsics.scale(height_, width_)

                    # inverse warp from a nearby frame to the current frame
                    warped_ = homography_from(rgb_near_, pred_,
                                              batch_data['r_mat'],
                                              batch_data['t_vec'], intrinsics_)
                    # coarser scales are down-weighted by powers of two
                    photometric_loss += photometric_criterion(
                        rgb_curr_, warped_, mask_) * (2**(scale - num_scales))

            # Loss 3: the depth smoothness loss
            smooth_loss = smoothness_criterion(pred) if args.w2 > 0 else 0

            # backprop
            loss = depth_loss + args.w1 * photometric_loss + args.w2 * smooth_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        gpu_time = time.time() - start

        # measure accuracy and record loss of each batch
        with torch.no_grad():  # metrics need no autograd bookkeeping
            mini_batch_size = next(iter(batch_data.values())).size(0)
            result = Result()  # metrics
            if mode != 'test_prediction' and mode != 'test_completion':
                result.evaluate(pred.data, gt.data, photometric_loss)
            # plain loop instead of a side-effect list comprehension
            for m in meters:
                m.update(result, gpu_time, data_time, mini_batch_size)
            logger.conditional_print(mode, i, epoch, args.epochs, lr,
                                     len(loader), block_average_meter,
                                     average_meter)
            logger.conditional_save_img_comparison(mode, i, batch_data, pred,
                                                   epoch)
            logger.conditional_save_pred(mode, i, pred, epoch)
        del pred  # release the prediction tensor before the next batch

    # take the avg of all the batches, to get the epoch metrics
    avg = logger.conditional_save_info(mode, average_meter, epoch)
    is_best = logger.rank_conditional_save_best(mode, avg, epoch, args.epochs)
    if is_best and not (mode == "train"):
        logger.save_img_comparison_as_best(mode, epoch)
    logger.conditional_summarize(mode, avg, is_best)

    return avg, is_best
def iterate(mode, args, loader, model, optimizer, logger, epoch):
    """Run one epoch of `mode` over `loader`, printing per-batch debug
    statistics (min/max/shape/dtype) for the inputs and the prediction.

    Returns:
        (avg, is_best): epoch-averaged metrics and whether they are the
        best seen so far for this mode.
    """
    block_average_meter = AverageMeter()
    average_meter = AverageMeter()
    meters = [block_average_meter, average_meter]

    # switch to appropriate mode
    assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
        "unsupported mode: {}".format(mode)
    if mode == 'train':
        model.train()
        lr = helper.adjust_learning_rate(args.lr, optimizer, epoch)
    else:
        model.eval()
        lr = 0

    for i, batch_data in enumerate(loader):
        print("The batch data keys are {}".format(batch_data.keys()))

        start = time.time()
        batch_data = {
            key: val.to(device)
            for key, val in batch_data.items() if val is not None
        }
        # ground truth is unavailable in the test_* modes
        gt = batch_data[
            'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None
        data_time = time.time() - start

        start = time.time()

        # NOTE(review): assumes the 'd', 'gt' and 'g' keys are always present
        # in the batch -- verify against the dataset configuration.
        temp_d = batch_data['d']
        temp_gt = batch_data['gt']
        temp_g = batch_data['g']

        print("The depth min:{}, max:{}, shape:{}, dtype:{}".format(
            torch.min(temp_d), torch.max(temp_d), temp_d.shape, temp_d.dtype))
        print("The groundtruth min:{}, max:{}, shape:{}, dtype:{}".format(
            torch.min(temp_gt), torch.max(temp_gt), temp_gt.shape,
            temp_gt.dtype))
        print("The greyscale min:{}, max:{}, shape:{}, dtype:{}".format(
            torch.min(temp_g), torch.max(temp_g), temp_g.shape, temp_g.dtype))

        pred = model(batch_data)

        temp_out = pred.detach().cpu()
        print("The output min:{}, max:{}, shape:{}, dtype:{}".format(
            torch.min(temp_out), torch.max(temp_out), temp_out.shape,
            temp_out.dtype))

        depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None
        if mode == 'train':
            # Loss 1: the direct depth supervision from ground truth label
            # mask=1 indicates that a pixel does not have ground truth labels
            if 'sparse' in args.train_mode:
                depth_loss = depth_criterion(pred, batch_data['d'])
                mask = (batch_data['d'] < 1e-3).float()
            elif 'dense' in args.train_mode:
                depth_loss = depth_criterion(pred, gt)
                mask = (gt < 1e-3).float()

            # Loss 2: the self-supervised photometric loss
            if args.use_pose:
                # create multi-scale pyramids
                pred_array = helper.multiscale(pred)
                rgb_curr_array = helper.multiscale(batch_data['rgb'])
                rgb_near_array = helper.multiscale(batch_data['rgb_near'])
                if mask is not None:
                    mask_array = helper.multiscale(mask)
                num_scales = len(pred_array)

                # compute photometric loss at multiple scales
                for scale in range(len(pred_array)):
                    pred_ = pred_array[scale]
                    rgb_curr_ = rgb_curr_array[scale]
                    rgb_near_ = rgb_near_array[scale]
                    mask_ = None
                    if mask is not None:
                        mask_ = mask_array[scale]

                    # compute the corresponding intrinsic parameters
                    height_, width_ = pred_.size(2), pred_.size(3)
                    intrinsics_ = kitti_intrinsics.scale(height_, width_)

                    # inverse warp from a nearby frame to the current frame
                    warped_ = homography_from(rgb_near_, pred_,
                                              batch_data['r_mat'],
                                              batch_data['t_vec'], intrinsics_)
                    # coarser scales are down-weighted by powers of two
                    photometric_loss += photometric_criterion(
                        rgb_curr_, warped_, mask_) * (2**(scale - num_scales))

            # Loss 3: the depth smoothness loss
            smooth_loss = smoothness_criterion(pred) if args.w2 > 0 else 0

            # backprop
            loss = depth_loss + args.w1 * photometric_loss + args.w2 * smooth_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        gpu_time = time.time() - start

        # measure accuracy and record loss
        with torch.no_grad():
            mini_batch_size = next(iter(batch_data.values())).size(0)
            result = Result()
            if mode != 'test_prediction' and mode != 'test_completion':
                result.evaluate(pred.data, gt.data, photometric_loss)
            # plain loop instead of a side-effect list comprehension
            for m in meters:
                m.update(result, gpu_time, data_time, mini_batch_size)
            logger.conditional_print(mode, i, epoch, lr, len(loader),
                                     block_average_meter, average_meter)
            logger.conditional_save_img_comparison(mode, i, batch_data, pred,
                                                   epoch)
            logger.conditional_save_pred(mode, i, pred, epoch)

    avg = logger.conditional_save_info(mode, average_meter, epoch)
    is_best = logger.rank_conditional_save_best(mode, avg, epoch)
    if is_best and not (mode == "train"):
        logger.save_img_comparison_as_best(mode, epoch)
    logger.conditional_summarize(mode, avg, is_best)

    return avg, is_best
def iterate(mode, args, loader, model, optimizer, logger, epoch):
    """Run one epoch of `mode` over `loader`, optionally adjusting the
    sparse input depth via `depth_adjustment` / `depth_adjustment_lines`
    before the forward pass, and periodically saving logs and checkpoints.

    Returns:
        (avg, is_best) from the most recent periodic save, or
        (None, False) if the loader was too short to trigger one.
    """
    block_average_meter = AverageMeter()
    average_meter = AverageMeter()
    meters = [block_average_meter, average_meter]

    # switch to appropriate mode
    assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
        "unsupported mode: {}".format(mode)
    if mode == 'train':
        model.train()
        lr = helper.adjust_learning_rate(args.lr, optimizer, epoch)
    else:
        model.eval()
        lr = 0

    torch.set_printoptions(profile="full")
    # Initialised up front so the final `return avg, is_best` cannot raise
    # NameError when the loop never reaches a periodic-save batch.
    avg, is_best = None, False
    for i, batch_data in enumerate(loader):

        sparse_depth_pathname = batch_data['d_path'][0]
        print(sparse_depth_pathname)
        # the path entry is a string, not a tensor; drop it before .to(device)
        del batch_data['d_path']
        print("i: ", i)
        start = time.time()
        batch_data = {
            key: val.to(device)
            for key, val in batch_data.items() if val is not None
        }
        # ground truth is unavailable in the test_* modes
        gt = batch_data[
            'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None

        # adjust depth for features
        depth_adjust = args.depth_adjust
        adjust_features = False

        if depth_adjust and args.use_d:
            # NOTE(review): depth_new is only bound for type_feature in
            # {"sq", "lines"}; any other value would raise NameError below.
            if args.type_feature == "sq":
                if args.use_rgb:
                    depth_new, alg_mode, feat_mode, features, shape = depth_adjustment(
                        batch_data['d'], args.test_mode, args.feature_mode,
                        args.feature_num, args.rank_file_global_sq,
                        adjust_features, i, model_orig, args.seed,
                        batch_data['rgb'])
                else:
                    depth_new, alg_mode, feat_mode, features, shape = depth_adjustment(
                        batch_data['d'], args.test_mode, args.feature_mode,
                        args.feature_num, args.rank_file_global_sq,
                        adjust_features, i, model_orig, args.seed)
            elif args.type_feature == "lines":
                depth_new, alg_mode, feat_mode, features = depth_adjustment_lines(
                    batch_data['d'], args.test_mode, args.feature_mode,
                    args.feature_num, args.rank_file_global_sq, i, model_orig,
                    args.seed)

            # back to a (1, 1, H, W) tensor on the compute device
            batch_data['d'] = torch.Tensor(depth_new).unsqueeze(0).unsqueeze(
                1).to(device)
        data_time = time.time() - start
        start = time.time()
        if mode == "train":
            pred = model(batch_data)
        else:
            with torch.no_grad():
                pred = model(batch_data)

        # compute loss
        depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None
        if mode == 'train':
            # Loss 1: the direct depth supervision from ground truth label
            # mask=1 indicates that a pixel does not have ground truth labels
            if 'sparse' in args.train_mode:
                depth_loss = depth_criterion(pred, batch_data['d'])
                mask = (batch_data['d'] < 1e-3).float()
            elif 'dense' in args.train_mode:
                depth_loss = depth_criterion(pred, gt)
                mask = (gt < 1e-3).float()
            # Loss 2: the self-supervised photometric loss
            if args.use_pose:
                # create multi-scale pyramids
                pred_array = helper.multiscale(pred)
                rgb_curr_array = helper.multiscale(batch_data['rgb'])
                rgb_near_array = helper.multiscale(batch_data['rgb_near'])
                if mask is not None:
                    mask_array = helper.multiscale(mask)
                num_scales = len(pred_array)
                # compute photometric loss at multiple scales
                for scale in range(len(pred_array)):
                    pred_ = pred_array[scale]
                    rgb_curr_ = rgb_curr_array[scale]
                    rgb_near_ = rgb_near_array[scale]
                    mask_ = None
                    if mask is not None:
                        mask_ = mask_array[scale]
                    # compute the corresponding intrinsic parameters
                    height_, width_ = pred_.size(2), pred_.size(3)
                    intrinsics_ = kitti_intrinsics.scale(height_, width_)
                    # inverse warp from a nearby frame to the current frame
                    warped_ = homography_from(rgb_near_, pred_,
                                              batch_data['r_mat'],
                                              batch_data['t_vec'], intrinsics_)
                    # coarser scales are down-weighted by powers of two
                    photometric_loss += photometric_criterion(
                        rgb_curr_, warped_, mask_) * (2**(scale - num_scales))
            # Loss 3: the depth smoothness loss
            smooth_loss = smoothness_criterion(pred) if args.w2 > 0 else 0

            # backprop
            loss = depth_loss + args.w1 * photometric_loss + args.w2 * smooth_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        gpu_time = time.time() - start

        # measure accuracy and record loss
        with torch.no_grad():
            mini_batch_size = next(iter(batch_data.values())).size(0)
            result = Result()
            if mode != 'test_prediction' and mode != 'test_completion':
                result.evaluate(pred.data, gt.data, photometric_loss)
            # plain loop instead of a side-effect list comprehension
            for m in meters:
                m.update(result, gpu_time, data_time, mini_batch_size)
            print(f"rmse: {result.rmse:,}")
            if result.rmse < 6000:
                print("good rmse")
            elif result.rmse > 12000:
                print("bad rmse")
            logger.conditional_print(mode, i, epoch, lr, len(loader),
                                     block_average_meter, average_meter)
            logger.conditional_save_img_comparison(mode, i, batch_data, pred,
                                                   epoch)
            logger.conditional_save_pred(mode, i, pred, epoch)

        # save log and checkpoint periodically (less often during 'val')
        every = 999 if mode == "val" else 200
        if i % every == 0 and i != 0:

            print(
                f"test settings (main_orig eval): {args.type_feature} {args.test_mode} {args.feature_mode} {args.feature_num}"
            )
            avg = logger.conditional_save_info(mode, average_meter, epoch)
            is_best = logger.rank_conditional_save_best(mode, avg, epoch)
            if is_best and not (mode == "train"):
                logger.save_img_comparison_as_best(mode, epoch)
            logger.conditional_summarize(mode, avg, is_best)

            if mode != "val":
                helper.save_checkpoint({  # save checkpoint
                    'epoch': epoch,
                    'model': model.module.state_dict(),
                    'best_result': logger.best_result,
                    'optimizer': optimizer.state_dict(),
                    'args': args,
                }, is_best, epoch, logger.output_directory, args.type_feature, args.test_mode, args.feature_num, args.feature_mode, args.depth_adjust, i, every, "scratch")

    return avg, is_best
def iterate(mode, args, loader, model, optimizer, logger, epoch):
    """Run one epoch of `mode` over `loader` for the "switches" variant.

    Besides the usual loss/metric bookkeeping, this variant periodically
    extracts a switch-importance distribution S from `model.module.phi`
    (softplus-normalised) and saves it as global or instance-wise ranks
    via numpy, and saves checkpoints every `args.every` batches.

    Returns (avg, is_best) from the most recent periodic save.
    NOTE(review): both are only assigned inside `if i % every == 0`; since
    that condition holds at i == 0, they are set on the first batch unless
    the loader is empty.
    """
    block_average_meter = AverageMeter()
    average_meter = AverageMeter()
    meters = [block_average_meter, average_meter]

    # switch to appropriate mode
    assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
        "unsupported mode: {}".format(mode)
    if mode == 'train':
        model.train()
        lr = helper.adjust_learning_rate(args.lr, optimizer, epoch)
    else:
        model.eval()
        lr = 0

    torch.set_printoptions(profile="full")
    for i, batch_data in enumerate(loader):

        # 'name' is a non-tensor entry; remove it before moving to device
        name = batch_data['name'][0]
        print(name)
        del batch_data['name']
        print("i: ", i)
        # each batch data is 1 and has three keys d, gt, g and dim [1, 352, 1216]
        start = time.time()
        batch_data = {
            key: val.to(device)
            for key, val in batch_data.items() if val is not None
        }
        # ground truth is unavailable in the test_* modes
        gt = batch_data[
            'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None

        # if args.type_feature=="sq":
        #     depth_adjustment(gt, False)

        data_time = time.time() - start

        start = time.time()
        if mode == "train":
            pred = model(batch_data)
        else:
            # no gradients needed outside of training
            with torch.no_grad():
                pred = model(batch_data)

        # debug visualisation of the ground-truth depth map (disabled)
        vis=False
        if vis:
            im = batch_data['gt'].detach().cpu().numpy()
            im_sq = im.squeeze()
            plt.figure()
            plt.imshow(im_sq)
            plt.show()
            # for i in range(im_sq.shape[0]):
            #     print(f"{i} - {np.sum(im_sq[i])}")

        depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None
        if mode == 'train':
            # Loss 1: the direct depth supervision from ground truth label
            # mask=1 indicates that a pixel does not ground truth labels
            if 'sparse' in args.train_mode:
                depth_loss = depth_criterion(pred, batch_data['d'])
                print("d pts: ", len(torch.where(batch_data['d']>0)[0]))
                mask = (batch_data['d'] < 1e-3).float()
            elif 'dense' in args.train_mode:
                depth_loss = depth_criterion(pred, gt)
                mask = (gt < 1e-3).float()

            # Loss 2: the self-supervised photometric loss
            if args.use_pose:
                # create multi-scale pyramids
                pred_array = helper.multiscale(pred)
                rgb_curr_array = helper.multiscale(batch_data['rgb'])
                rgb_near_array = helper.multiscale(batch_data['rgb_near'])
                if mask is not None:
                    mask_array = helper.multiscale(mask)
                num_scales = len(pred_array)

                # compute photometric loss at multiple scales
                for scale in range(len(pred_array)):
                    pred_ = pred_array[scale]
                    rgb_curr_ = rgb_curr_array[scale]
                    rgb_near_ = rgb_near_array[scale]
                    mask_ = None
                    if mask is not None:
                        mask_ = mask_array[scale]

                    # compute the corresponding intrinsic parameters
                    height_, width_ = pred_.size(2), pred_.size(3)
                    intrinsics_ = kitti_intrinsics.scale(height_, width_)

                    # inverse warp from a nearby frame to the current frame
                    warped_ = homography_from(rgb_near_, pred_,
                                              batch_data['r_mat'],
                                              batch_data['t_vec'], intrinsics_)
                    photometric_loss += photometric_criterion(
                        rgb_curr_, warped_, mask_) * (2**(scale - num_scales))

            # Loss 3: the depth smoothness loss
            smooth_loss = smoothness_criterion(pred) if args.w2 > 0 else 0

            # backprop



            loss = depth_loss + args.w1 * photometric_loss + args.w2 * smooth_loss
            optimizer.zero_grad()
            loss.backward()
            # zero_params presumably masks out frozen parameters before the
            # step -- TODO(review): confirm its contract
            zero_params(model)
            optimizer.step()



        gpu_time = time.time() - start

        # counting pixels in each bin
        #binned_pixels = np.load("value.npy", allow_pickle=True)
        #print(len(binned_pixels))

        # Rank extraction: every batch when evaluating instance-wise, or
        # every `args.every` batches when training a global ranking.
        if (i % 1 == 0 and args.evaluate and args.instancewise) or\
                (i % args.every == 0 and not args.evaluate and not args.instancewise): # global training
            #    print(model.module.conv4[5].conv1.weight[0])
            # print(model.conv4.5.bn2.weight)
            # print(model.module.parameter.grad)
            #print("*************swiches:")
            torch.set_printoptions(precision=7, sci_mode=False)

            # switch importances: softplus of (scaled) phi, normalised to
            # sum to 1
            if model.module.phi is not None:
                mmp = 1000 * model.module.phi
                phi = F.softplus(mmp)

                S = phi / torch.sum(phi)
                #print("S", S[1, -10:])
                S_numpy= S.detach().cpu().numpy()

            # accumulate per-instance rankings in a module-level global
            if args.instancewise:

                global Ss
                if "Ss" not in globals():
                    Ss = []
                    Ss.append(S_numpy)
                else:
                    Ss.append(S_numpy)

            # GLOBAL
            if (i % args.every ==0  and not args.evaluate and not args.instancewise and model.module.phi is not None):

                np.set_printoptions(precision=4)

                switches_2d_argsort = np.argsort(S_numpy, None) # 2d to 1d sort torch.Size([9, 31])
                switches_2d_sort = np.sort(S_numpy, None)
                print("Switches: ")
                print(switches_2d_argsort[:10])
                print(switches_2d_sort[:10])
                print("and")
                print(switches_2d_argsort[-10:])
                print(switches_2d_sort[-10:])

                ##### saving global ranks
                global_ranks_path = lambda \
                    ii: f"ranks/{args.type_feature}/global/{folder_and_name[0]}/Ss_val_{folder_and_name[1]}_iter_{ii}.npy"
                # NOTE(review): on a fresh call to iterate() with old_i
                # already in globals, global_ranks_path(old_i) is evaluated
                # before folder_and_name is assigned below -- this looks
                # like an UnboundLocalError waiting to happen; confirm.
                global old_i
                if ("old_i" in globals()):
                    print("old_i")
                    if os.path.isfile(global_ranks_path(old_i)):
                        os.remove(global_ranks_path(old_i))

                folder_and_name = args.resume.split(os.sep)[-2:]
                os.makedirs(f"ranks/{args.type_feature}/global/{folder_and_name[0]}", exist_ok=True)
                np.save(global_ranks_path(i), S_numpy)
                old_i = i
                print("saving ranks")

                if args.type_feature == "sq":

                    # recover 2-D (row, col) positions from the flat argsort
                    hor = switches_2d_argsort % S_numpy.shape[1]
                    ver = np.floor(switches_2d_argsort // S_numpy.shape[1])
                    print(ver[:10],hor[:10])
                    print("and")
                    print(ver[-10:], hor[-10:])


        # measure accuracy and record loss
        with torch.no_grad():
            mini_batch_size = next(iter(batch_data.values())).size(0)
            result = Result()
            if mode != 'test_prediction' and mode != 'test_completion':
                result.evaluate(pred.data, gt.data, photometric_loss)
            [
                m.update(result, gpu_time, data_time, mini_batch_size)
                for m in meters
            ]
            logger.conditional_print(mode, i, epoch, lr, len(loader),
                                     block_average_meter, average_meter)
            logger.conditional_save_img_comparison(mode, i, batch_data, pred,
                                                   epoch)
            logger.conditional_save_pred(mode, i, pred, epoch)

        # debug visualisation of the top-ranked features (disabled)
        # NOTE(review): when enabled with type_feature == "sq", hor/ver are
        # only defined if the rank-extraction branch above ran -- confirm.
        draw=False
        if draw:
            ma = batch_data['rgb'].detach().cpu().numpy().squeeze()
            ma  = np.transpose(ma, axes=[1, 2, 0])
           # ma = np.uint8(ma)
            #ma2 = Image.fromarray(ma)
            ma2 = Image.fromarray(np.uint8(ma)).convert('RGB')
            # create rectangle image
            img1 = ImageDraw.Draw(ma2)

            if args.type_feature == "sq":
                size=40
                print_square_num = 20
                for ii in range(print_square_num):
                    s_hor=hor[-ii].detach().cpu().numpy()
                    s_ver=ver[-ii].detach().cpu().numpy()
                    shape = [(s_hor * size, s_ver * size), ((s_hor + 1) * size, (s_ver + 1) * size)]
                    img1.rectangle(shape, outline="red")

                    tim = time.time()
                    lala = ma2.save(f"switches_photos/squares/squares_{tim}.jpg")
                    print("saving")
            elif args.type_feature == "lines":
                print_square_num = 20
                r=1
                parameter_mask = np.load("../kitti_pixels_to_lines.npy", allow_pickle=True)

                # for m in range(10,50):
                #     im = Image.fromarray(parameter_mask[m]*155)
                #     im = im.convert('1')  # convert image to black and white
                #     im.save(f"switches_photos/lala_{m}.jpg")


                for ii in range(print_square_num):
                     points = parameter_mask[ii]
                     y = np.where(points==1)[0]
                     x = np.where(points == 1)[1]

                     for p in range(len(x)):
                         img1.ellipse((x[p] - r, y[p] - r, x[p] + r, y[p] + r), fill=(255, 0, 0, 0))

                tim = time.time()
                lala = ma2.save(f"switches_photos/lines/lines_{tim}.jpg")
                print("saving")

        # periodic logging / checkpointing (fires at i == 0 as well)
        every = args.every
        if i % every ==0:

            print("saving")
            avg = logger.conditional_save_info(mode, average_meter, epoch)
            is_best = logger.rank_conditional_save_best(mode, avg, epoch)
            #is_best = True #saving all the checkpoints
            if is_best and not (mode == "train"):
                logger.save_img_comparison_as_best(mode, epoch)
            logger.conditional_summarize(mode, avg, is_best)

            if mode != "val":
                helper.save_checkpoint({  # save checkpoint
                    'epoch': epoch,
                    'model': model.module.state_dict(),
                    'best_result': logger.best_result,
                    'optimizer': optimizer.state_dict(),
                    'args': args,
                }, is_best, epoch, logger.output_directory, args.type_feature, i, every, qnet=True)

    # dump the accumulated instance-wise rankings at the end of the epoch
    if args.evaluate and args.instancewise:
        #filename = os.path.split(args.evaluate)[1]
        Ss_numpy = np.array(Ss)
        folder_and_name = args.evaluate.split(os.sep)[-3:]
        os.makedirs(f"ranks/instance/{folder_and_name[0]}", exist_ok=True)
        os.makedirs(f"ranks/instance/{folder_and_name[0]}/{folder_and_name[1]}", exist_ok=True)
        np.save(f"ranks/instance/{folder_and_name[0]}/{folder_and_name[1]}/Ss_val_{folder_and_name[2]}.npy", Ss)

    return avg, is_best
def iterate(mode, args, loader, model, optimizer, logger, epoch):
    """Run one epoch of training or evaluation for an encoder/decoder pair.

    Args:
        mode: one of "train", "val", "eval", "test_prediction",
            "test_completion".
        args: parsed CLI namespace (uses ``lr`` and ``train_mode``).
        loader: DataLoader yielding dicts of tensors ('rgb', 'd', 'gt', ...).
        model: sequence of two modules: ``model[0]`` is the encoder,
            ``model[1]`` the multi-scale disparity decoder.
        optimizer: optimizer for the trainable parameters.
        logger: project logger used for printing/saving metrics.
        epoch: current epoch index.

    Returns:
        (avg, is_best): the most recent averaged metrics and best-result
        flag from the logger, or (None, None) if ``loader`` is empty.

    Depends on module-level globals: ``device``, ``depth_criterion``,
    ``AverageMeter``, ``Result``, ``helper``.
    """
    block_average_meter = AverageMeter()
    average_meter = AverageMeter()
    meters = [block_average_meter, average_meter]

    # switch to appropriate mode
    assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
        "unsupported mode: {}".format(mode)

    encoder = model[0]
    decoder = model[1]
    if mode == 'train':
        encoder.train()
        decoder.train()
        lr = helper.adjust_learning_rate(args.lr, optimizer, epoch)
    else:
        # BUGFIX: evaluation modes previously left both networks in train()
        # mode, keeping dropout active and updating BatchNorm running stats
        # during validation. Switch to eval() for deterministic inference.
        encoder.eval()
        decoder.eval()
        lr = 0

    torch.set_printoptions(profile="full")
    # Guard against an empty loader: previously these names were only bound
    # inside the loop body, so the final return could raise UnboundLocalError.
    avg, is_best = None, None
    for i, batch_data in enumerate(loader):

        start = time.time()
        # Move every tensor of the mini-batch to the target device.
        batch_data = {
            key: val.to(device)
            for key, val in batch_data.items() if val is not None
        }
        # Ground truth is unavailable for the test splits.
        gt = batch_data[
            'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None
        data_time = time.time() - start

        start = time.time()

        # Forward pass: encoder features -> multi-scale decoder outputs;
        # ('disp', 0) is the full-resolution disparity prediction.
        features = encoder(batch_data['rgb'])
        outputs = decoder(features)
        pred = outputs[('disp', 0)]

        depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None
        if mode == 'train':
            # Direct depth supervision from the ground-truth label;
            # mask=1 marks pixels that have no ground-truth depth.
            if 'sparse' in args.train_mode:
                depth_loss = depth_criterion(pred, batch_data['d'])
                mask = (batch_data['d'] < 1e-3).float()
            elif 'dense' in args.train_mode:
                depth_loss = depth_criterion(pred, gt)
                mask = (gt < 1e-3).float()

            # backprop (only the depth term is optimized here)
            loss = depth_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        gpu_time = time.time() - start

        if i % 50 == 0:
            print(i)

        # measure accuracy and record loss
        with torch.no_grad():
            mini_batch_size = next(iter(batch_data.values())).size(0)
            result = Result()
            if mode != 'test_prediction' and mode != 'test_completion':
                result.evaluate(pred.data, gt.data, photometric_loss)
            for m in meters:
                m.update(result, gpu_time, data_time, mini_batch_size)
            logger.conditional_print(mode, i, epoch, lr, len(loader),
                                     block_average_meter, average_meter)
            logger.conditional_save_img_comparison(mode, i, batch_data, pred,
                                                   epoch)
            logger.conditional_save_pred(mode, i, pred, epoch)

        # Periodic snapshot of metrics and best-result bookkeeping.
        if i % 200 == 0:

            print(gpu_time)

            print("saving")
            avg = logger.conditional_save_info(mode, average_meter, epoch)
            is_best = logger.rank_conditional_save_best(mode, avg, epoch)
            if is_best and not (mode == "train"):
                logger.save_img_comparison_as_best(mode, epoch)
            logger.conditional_summarize(mode, avg, is_best)

    return avg, is_best
示例#7
0
def main():
    """Entry point: build the chosen architecture, optionally resume from a
    checkpoint, then either evaluate once or run the train/validate loop.

    Uses the module-level globals ``args`` and ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        print("=> creating model '{}'".format(args.arch))

    # Dispatch table replaces the long if/elif chain: each entry maps an
    # architecture name to its constructor (all imported at module level).
    arch_factories = {
        'alexnet': alexnet,
        'squeezenet1_0': squeezenet1_0,
        'squeezenet1_1': squeezenet1_1,
        'densenet121': densenet121,
        'densenet169': densenet169,
        'densenet201': densenet201,
        'densenet161': densenet161,
        'vgg11': vgg11,
        'vgg13': vgg13,
        'vgg16': vgg16,
        'vgg19': vgg19,
        'vgg11_bn': vgg11_bn,
        'vgg13_bn': vgg13_bn,
        'vgg16_bn': vgg16_bn,
        'vgg19_bn': vgg19_bn,
        'resnet18': resnet18,
        'resnet34': resnet34,
        'resnet50': resnet50,
        'resnet101': resnet101,
        'resnet152': resnet152,
    }
    if args.arch not in arch_factories:
        raise NotImplementedError
    model = arch_factories[args.arch](pretrained=args.pretrained)

    # use cuda
    model.cuda()
    # model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            # BUGFIX: this message belongs to the inner isfile() check.
            # Previously the `else` was attached to `if args.resume:`, so a
            # missing checkpoint file was silently ignored while runs
            # without --resume printed "no checkpoint found at 'None'".
            print("=> no checkpoint found at '{}'".format(args.resume))

    # cudnn.benchmark = True

    # Data loading
    train_loader, val_loader = data_loader(args.data, args.batch_size,
                                           args.workers, args.pin_memory)

    if args.evaluate:
        validate(val_loader, model, criterion, args.print_freq)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch,
              args.print_freq)

        # evaluate on validation set
        prec1, prec5 = validate(val_loader, model, criterion, args.print_freq)

        # remember the best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict()
            }, is_best, args.arch + '.pth')
示例#8
0
def train_model(args,
                model,
                dataset,
                writer=None,
                n_rounds=1,
                lth_pruner=None):
    """Train `model` for `n_rounds` rounds of `args.epochs` epochs each.

    Persists the untouched initial weights, a per-round best checkpoint
    (selected by validation top-1), a rolling "current" checkpoint every
    epoch, and snapshots of epochs 0-3 for late resetting.
    """
    root = args.exp_name + '/checkpoints/'

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Save the starting state before any optimization happens.
    init_weight_filename = args.exp_name + '/checkpoints/' + 'initial_state.pth.tar'
    helper.save_checkpoint(args, model, optimizer, init_weight_filename)

    for cur_round in range(n_rounds):

        best_acc = 0
        for epoch in range(args.start_epoch, args.epochs):

            helper.adjust_learning_rate(optimizer, epoch, args)

            train_top1, train_top5, train_loss, model = trainer.train(
                dataset.train_loader, model, criterion, optimizer, epoch,
                args, lth_pruner, cur_round, mask_applied=args.mask_applied)

            val_top1, val_top5, val_loss = trainer.validate(
                dataset.test_loader, model, criterion, args)

            if writer is not None:
                # Mirror the original logging order: all train scalars
                # first, then all validation scalars.
                scalar_log = (
                    ("loss/train/", train_loss),
                    ("top1/train/", train_top1),
                    ("top5/train/", train_top5),
                    ("loss/val/", val_loss),
                    ("top1/val/", val_top1),
                    ("top5/val/", val_top5),
                )
                for tag_prefix, value in scalar_log:
                    writer.add_scalar(tag_prefix + str(cur_round), value,
                                      epoch)

            # Track the per-round best validation accuracy.
            if val_top1 >= best_acc:
                best_acc = val_top1
                is_best = True
                helper.save_checkpoint(
                    args, model, optimizer,
                    root + str(cur_round) + '_model_best.pth')

            helper.save_checkpoint(args,
                                   model,
                                   optimizer,
                                   root + str(cur_round) + '_current.pth',
                                   epoch=epoch)
            filename = root + str(cur_round) + '_mask.pkl'

            if epoch in [0, 1, 2, 3]:
                # Keep the early epochs around for late resetting.
                helper.save_checkpoint(
                    args, model, optimizer,
                    root + 'epoch_' + str(epoch) + '_model.pth',
                    epoch=epoch)
示例#9
0
def iterate(mode, args, loader, model, optimizer, logger, epoch):
    """Run one epoch of training/evaluation for a depth-completion model
    with optional multi-stage supervision and gradient accumulation.

    Depends on module-level globals: ``device``, ``depth_criterion``,
    ``multi_batch_size`` (gradient-accumulation window), ``vis_utils``,
    ``AverageMeter``, ``Result`` and ``helper``.

    Returns (avg, is_best) from the logger's end-of-epoch summary.
    """
    # Shift the epoch index so LR scheduling / loss weighting continue
    # seamlessly when resuming from a checkpoint with a start-epoch bias.
    actual_epoch = epoch - args.start_epoch + args.start_epoch_bias

    block_average_meter = AverageMeter()
    block_average_meter.reset(False)
    average_meter = AverageMeter()
    meters = [block_average_meter, average_meter]

    # switch to appropriate mode
    assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
        "unsupported mode: {}".format(mode)
    if mode == 'train':
        model.train()
        lr = helper.adjust_learning_rate(args.lr, optimizer, actual_epoch,
                                         args)
    else:
        model.eval()
        lr = 0

    torch.cuda.empty_cache()
    for i, batch_data in enumerate(loader):
        dstart = time.time()
        # Move every tensor of the mini-batch to the target device.
        batch_data = {
            key: val.to(device)
            for key, val in batch_data.items() if val is not None
        }

        # Ground truth is unavailable for the test splits.
        gt = batch_data[
            'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None
        data_time = time.time() - dstart

        pred = None
        start = None
        gpu_time = 0

        #start = time.time()
        #pred = model(batch_data)
        #gpu_time = time.time() - start

        #'''
        # Network model 'e' returns two intermediate stage predictions in
        # addition to the final one; all three are supervised below.
        if (args.network_model == 'e'):
            start = time.time()
            st1_pred, st2_pred, pred = model(batch_data)
        else:
            start = time.time()
            pred = model(batch_data)

        # Timing: in evaluate mode only the forward pass is measured here;
        # otherwise gpu_time is taken after the backward pass (further down).
        if (args.evaluate):
            gpu_time = time.time() - start
        #'''

        depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None

        # inter loss_param
        # Intermediate-stage loss weights decay over training: 0.2 while
        # actual_epoch <= round1, 0.05 up to round2, then 0 (final stage only).
        st1_loss, st2_loss, loss = 0, 0, 0
        w_st1, w_st2 = 0, 0
        round1, round2, round3 = 1, 3, None
        if (actual_epoch <= round1):
            w_st1, w_st2 = 0.2, 0.2
        elif (actual_epoch <= round2):
            w_st1, w_st2 = 0.05, 0.05
        else:
            w_st1, w_st2 = 0, 0

        if mode == 'train':
            # Loss 1: the direct depth supervision from ground truth label
            # mask=1 indicates that a pixel does not ground truth labels
            depth_loss = depth_criterion(pred, gt)

            if args.network_model == 'e':
                # Weighted sum of final and intermediate-stage depth losses.
                st1_loss = depth_criterion(st1_pred, gt)
                st2_loss = depth_criterion(st2_pred, gt)
                loss = (1 - w_st1 - w_st2
                        ) * depth_loss + w_st1 * st1_loss + w_st2 * st2_loss
            else:
                loss = depth_loss

            # Gradient accumulation: zero grads at the start of each window
            # of `multi_batch_size` batches; step at the window's end (or at
            # the last batch of the epoch).
            if i % multi_batch_size == 0:
                optimizer.zero_grad()
            loss.backward()

            if i % multi_batch_size == (multi_batch_size -
                                        1) or i == (len(loader) - 1):
                optimizer.step()
            print("loss:", loss, " epoch:", epoch, " ", i, "/", len(loader))

        if mode == "test_completion":
            # Save each prediction as a 16-bit PNG named by its zero-padded
            # batch index (upload format per the helper's name).
            str_i = str(i)
            path_i = str_i.zfill(10) + '.png'
            path = os.path.join(args.data_folder_save, path_i)
            vis_utils.save_depth_as_uint16png_upload(pred, path)

        # Non-evaluate runs include the backward pass in the timing.
        if (not args.evaluate):
            gpu_time = time.time() - start
        # measure accuracy and record loss
        with torch.no_grad():
            mini_batch_size = next(iter(batch_data.values())).size(0)
            result = Result()
            if mode != 'test_prediction' and mode != 'test_completion':
                result.evaluate(pred.data, gt.data, photometric_loss)
                [
                    m.update(result, gpu_time, data_time, mini_batch_size)
                    for m in meters
                ]

                # Per-iteration console output is suppressed during training.
                if mode != 'train':
                    logger.conditional_print(mode, i, epoch, lr, len(loader),
                                             block_average_meter,
                                             average_meter)
                logger.conditional_save_img_comparison(mode, i, batch_data,
                                                       pred, epoch)
                logger.conditional_save_pred(mode, i, pred, epoch)

    # End-of-epoch summary: persist averaged metrics and track the best run.
    avg = logger.conditional_save_info(mode, average_meter, epoch)
    is_best = logger.rank_conditional_save_best(mode, avg, epoch)
    if is_best and not (mode == "train"):
        logger.save_img_comparison_as_best(mode, epoch)
    logger.conditional_summarize(mode, avg, is_best)

    return avg, is_best
示例#10
0
def iterate(mode, args, loader, model, optimizer, logger, epoch):
    """Run one epoch of training or evaluation.

    Combines up to three training losses: direct depth supervision,
    a multi-scale self-supervised photometric loss (when ``args.use_pose``),
    and a smoothness term (when ``args.w2 > 0``).

    Depends on module-level globals: ``device``, ``depth_criterion``,
    ``photometric_criterion``, ``smoothness_criterion``,
    ``kitti_intrinsics``, ``homography_from``, ``helper``, ``AverageMeter``,
    ``Result``.

    Returns:
        (avg, is_best): most recent averaged metrics and best-result flag,
        or (None, None) if ``loader`` is empty.
    """
    block_average_meter = AverageMeter()
    average_meter = AverageMeter()
    meters = [block_average_meter, average_meter]

    # switch to appropriate mode
    assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
        "unsupported mode: {}".format(mode)
    if mode == 'train':
        model.train()
        lr = helper.adjust_learning_rate(args.lr, optimizer, epoch)
    else:
        model.eval()
        lr = 0

    # BUGFIX: previously `avg`/`is_best` were bound only inside the loop, so
    # an empty loader raised UnboundLocalError at the return statement.
    avg, is_best = None, None

    for i, batch_data in enumerate(loader):
        start = time.time()
        # Move every tensor of the mini-batch to the target device.
        batch_data = {
            key: val.to(device)
            for key, val in batch_data.items() if val is not None
        }
        # Ground truth is unavailable for the test splits.
        gt = batch_data[
            'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None
        data_time = time.time() - start

        start = time.time()
        pred = model(batch_data)

        depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None
        if mode == 'train':
            # Loss 1: direct depth supervision from the ground-truth label;
            # mask=1 marks pixels that have no ground-truth depth.
            if 'sparse' in args.train_mode:
                depth_loss = depth_criterion(pred, batch_data['d'])
                mask = (batch_data['d'] < 1e-3).float()
            elif 'dense' in args.train_mode:
                depth_loss = depth_criterion(pred, gt)
                mask = (gt < 1e-3).float()

            # Loss 2: the self-supervised photometric loss
            if args.use_pose:
                # create multi-scale pyramids
                pred_array = helper.multiscale(pred)
                rgb_curr_array = helper.multiscale(batch_data['rgb'])
                rgb_near_array = helper.multiscale(batch_data['rgb_near'])
                if mask is not None:
                    mask_array = helper.multiscale(mask)
                num_scales = len(pred_array)

                # compute photometric loss at multiple scales
                for scale in range(len(pred_array)):
                    pred_ = pred_array[scale]
                    rgb_curr_ = rgb_curr_array[scale]
                    rgb_near_ = rgb_near_array[scale]
                    mask_ = None
                    if mask is not None:
                        mask_ = mask_array[scale]

                    # compute the corresponding intrinsic parameters
                    height_, width_ = pred_.size(2), pred_.size(3)
                    intrinsics_ = kitti_intrinsics.scale(height_, width_)

                    # inverse warp from a nearby frame to the current frame
                    warped_ = homography_from(rgb_near_, pred_,
                                              batch_data['r_mat'],
                                              batch_data['t_vec'], intrinsics_)
                    # Coarser scales contribute with geometrically
                    # decreasing weight.
                    photometric_loss += photometric_criterion(
                        rgb_curr_, warped_, mask_) * (2**(scale - num_scales))

            # Loss 3: the depth smoothness loss
            smooth_loss = smoothness_criterion(pred) if args.w2 > 0 else 0

            # backprop
            loss = depth_loss + args.w1 * photometric_loss + args.w2 * smooth_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        gpu_time = time.time() - start

        if i % 50 == 0:
            # debug: inspect the learned parameter vector and its ranking
            print(model.module.parameter)
            print(torch.argsort(model.module.parameter))

        # measure accuracy and record loss
        with torch.no_grad():
            mini_batch_size = next(iter(batch_data.values())).size(0)
            result = Result()
            if mode != 'test_prediction' and mode != 'test_completion':
                result.evaluate(pred.data, gt.data, photometric_loss)
            for m in meters:
                m.update(result, gpu_time, data_time, mini_batch_size)
            logger.conditional_print(mode, i, epoch, lr, len(loader),
                                     block_average_meter, average_meter)
            logger.conditional_save_img_comparison(mode, i, batch_data, pred,
                                                   epoch)
            logger.conditional_save_pred(mode, i, pred, epoch)

        # Periodic snapshot: metrics, best-result bookkeeping, checkpoint.
        if i % 20 == 0:
            avg = logger.conditional_save_info(mode, average_meter, epoch)
            is_best = logger.rank_conditional_save_best(mode, avg, epoch)
            if is_best and not (mode == "train"):
                logger.save_img_comparison_as_best(mode, epoch)
            logger.conditional_summarize(mode, avg, is_best)

            helper.save_checkpoint({  # save checkpoint
                'epoch': epoch,
                'model': model.module.state_dict(),
                'best_result': logger.best_result,
                'optimizer': optimizer.state_dict(),
                'args': args,
            }, is_best, epoch, logger.output_directory)

    return avg, is_best
示例#11
0
def iterate(mode, args, loader, model, optimizer, logger, epoch):
    """Run one epoch for a model that jointly predicts depth and two
    intensity maps ("intensity" and "pure" intensity).

    Depends on module-level globals: ``batch_num`` (global step counter),
    ``depth_criterion``, ``intensity_criterion``,
    ``intensity_pure_criterion``, ``AverageMeter``, ``AverageIntensity``,
    ``Result``, ``Result_intensity``, ``helper`` and, when
    ``args.use_pose`` is set, the photometric helpers.

    Returns (avg, avg_intensity, is_best).
    """
    # Depth meters: per-block and whole-epoch averages.
    block_average_meter = AverageMeter()
    average_meter = AverageMeter()
    meters = [block_average_meter, average_meter]

    # Same meter pair for the intensity head...
    block_average_meter_intensity = AverageIntensity()
    average_meter_intensity = AverageIntensity()
    meters_intensity = [block_average_meter_intensity, average_meter_intensity]

    # ...and for the "pure" intensity head.
    block_average_meter_intensity_pure = AverageIntensity()
    average_meter_intensity_pure = AverageIntensity()
    meters_intensity_pure = [
        block_average_meter_intensity_pure, average_meter_intensity_pure
    ]

    # switch to appropriate mode
    assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
        "unsupported mode: {}".format(mode)
    if mode == 'train':
        model.train()
        # lr = args.lr #helper.adjust_learning_rate(args.lr, optimizer, epoch)
        # lradj > 0 enables LR scheduling; otherwise the LR stays fixed.
        if args.lradj > 0:
            lr = helper.adjust_learning_rate(args.lr, optimizer, epoch,
                                             args.lradj)
        else:
            lr = args.lr
    else:
        model.eval()
        lr = 0

    # batch_num counts optimizer steps across epochs for TensorBoard x-axes.
    global batch_num
    for i, batch_data in enumerate(loader):
        if mode == "train":
            batch_num += 1
        start = time.time()
        # Keep the filename aside: it is not a tensor and stays on the CPU.
        file_name = batch_data['filename']

        batch_data = {
            key: val.cuda()
            for key, val in batch_data.items()
            if (key != "filename" and val is not None)
        }
        #    batch_data = {key:val.cuda() for key,val in batch_data.items() if val is not None}
        # Ground-truth targets are unavailable in the test modes.
        gt = batch_data[
            'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None
        #   Ireal as gt_intensity
        gt_intensity = batch_data[
            'gt_intensity'] if mode != 'test_prediction' and mode != 'test_completion' else None
        gt_intensity_pure = batch_data[
            'gt_intensity_pure'] if mode != 'test_prediction' and mode != 'test_completion' else None
        data_time = time.time() - start

        start = time.time()
        pred = model(batch_data)

        # The model returns a pair: pred[0] packs depth (channel 0) and pure
        # intensity (channel 1); pred[1] is the intensity prediction.
        pred_intensity_pure = pred[0][:, 1, :, :].unsqueeze(1)
        pred_intensity = pred[1]
        pred = pred[0][:, 0, :, :].unsqueeze(1)

        depth_loss, intensity_loss, photometric_loss, smooth_loss, pure_loss, mask = 0, 0, 0, 0, 0, None
        if mode == 'train':
            # Loss 1: the direct depth supervision from ground truth label
            # mask=1 indicates that a pixel does not ground truth labels
            if 'sparse' in args.train_mode:
                depth_loss = depth_criterion(pred, batch_data['d'])
                mask = (batch_data['d'] < 1e-3).float()
            elif 'dense' in args.train_mode:
                # Dense mode additionally supervises both intensity heads.
                depth_loss = depth_criterion(pred, gt)
                pure_loss = intensity_pure_criterion(pred_intensity_pure,
                                                     gt_intensity_pure)
                intensity_loss = intensity_criterion(pred_intensity,
                                                     gt_intensity)
                mask = (gt < 1e-3).float()

            # Loss 2: the self-supervised photometric loss
            if args.use_pose:
                # create multi-scale pyramids
                pred_array = helper.multiscale(pred)
                rgb_curr_array = helper.multiscale(batch_data['rgb'])
                rgb_near_array = helper.multiscale(batch_data['rgb_near'])
                if mask is not None:
                    mask_array = helper.multiscale(mask)
                num_scales = len(pred_array)

                # compute photometric loss at multiple scales
                for scale in range(len(pred_array)):
                    pred_ = pred_array[scale]
                    rgb_curr_ = rgb_curr_array[scale]
                    rgb_near_ = rgb_near_array[scale]
                    mask_ = None
                    if mask is not None:
                        mask_ = mask_array[scale]

                    # compute the corresponding intrinsic parameters
                    height_, width_ = pred_.size(2), pred_.size(3)
                    intrinsics_ = kitti_intrinsics.scale(height_, width_)

                    # inverse warp from a nearby frame to the current frame
                    warped_ = homography_from(rgb_near_, pred_,
                                              batch_data['r_mat'],
                                              batch_data['t_vec'], intrinsics_)
                    photometric_loss += photometric_criterion(
                        rgb_curr_, warped_, mask_) * (2**(scale - num_scales))

            # Loss 3: the depth smoothness loss
            smooth_loss = smoothness_criterion(pred) if args.w2 > 0 else 0

            # backprop
            # NOTE(review): the photometric and smoothness terms computed
            # above are not part of the optimized total (see commented line).
            loss = depth_loss + args.wi * intensity_loss + args.wpure * pure_loss
            # loss = depth_loss + wi * intensity_loss + args.w1*photometric_loss + args.w2*smooth_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            #log train
            # TensorBoard scalars every 25 global steps.
            if batch_num % 25 == 0 and mode == "train":
                logger.writer.add_scalar('train/loss_total', loss, batch_num)
                logger.writer.add_scalar('train/loss_depth', depth_loss,
                                         batch_num)
                logger.writer.add_scalar('train/loss_intensity',
                                         intensity_loss, batch_num)
                logger.writer.add_scalar('train/loss_pure', pure_loss,
                                         batch_num)
                logger.writer.add_scalar('train/lr', lr, batch_num)
                #if batch_num % 200 == 0:
                #    logger.writer.add_image('train/Ipure', pred_intensity_pure[0], batch_num)

        gpu_time = time.time() - start

        # measure accuracy and record loss
        with torch.no_grad():
            mini_batch_size = next(iter(batch_data.values())).size(0)
            result = Result()
            # result_intensity = Result()
            result_intensity_pure = Result_intensity()
            result_intensity = Result_intensity()
            if mode != 'test_prediction' and mode != 'test_completion':
                # result.evaluate(pred.data, gt.data, photometric_loss)
                # One evaluation per head: depth, pure intensity, intensity.
                result.evaluate(pred.data, gt.data, photometric_loss)
                result_intensity_pure.evaluate(pred_intensity_pure.data,
                                               gt_intensity_pure.data)
                result_intensity.evaluate(pred_intensity.data,
                                          gt_intensity.data)

            [
                m.update(result, gpu_time, data_time, mini_batch_size)
                for m in meters
            ]
            [
                m_intensity.update(result_intensity, gpu_time, data_time,
                                   mini_batch_size)
                for m_intensity in meters_intensity
            ]
            [
                m_intensity_pure.update(result_intensity_pure, gpu_time,
                                        data_time, mini_batch_size)
                for m_intensity_pure in meters_intensity_pure
            ]

            # Console output: one line per head (depth, Ipure, Intensity).
            logger.conditional_print(mode, i, epoch, lr, len(loader),
                                     block_average_meter, average_meter)
            # logger.conditional_save_img_comparison(mode, i, batch_data, pred, epoch)
            logger.conditional_save_pred_named(mode, file_name[0], pred, epoch)

            logger.conditional_print(mode, i, epoch, lr, len(loader),
                                     block_average_meter_intensity_pure,
                                     average_meter_intensity_pure, "Ipure")
            logger.conditional_print(mode, i, epoch, lr, len(loader),
                                     block_average_meter_intensity,
                                     average_meter_intensity, "Intensity")
            # logger.conditional_save_img_comparison_with_intensity(mode, i, batch_data, pred, pred_intensity, epoch)
            logger.conditional_save_img_comparison_with_intensity2(
                mode, i, batch_data, pred, pred_intensity_pure, pred_intensity,
                epoch)
            # run when eval, bt always = 1
            #logger.conditional_save_pred_named_with_intensity(mode, file_name[0], pred, pred_intensity, epoch)

    # End-of-epoch summary: best-result ranking uses depth and intensity
    # averages jointly.
    avg = logger.conditional_save_info(mode, average_meter, epoch)
    avg_intensity = logger.conditional_save_info_intensity(
        mode, average_meter_intensity, epoch)
    # is_best = logger.rank_conditional_save_best(mode, avg, epoch)
    is_best = logger.rank_conditional_save_best_with_intensity(
        mode, avg, avg_intensity, epoch)
    if is_best and not (mode == "train"):
        logger.save_img_comparison_as_best(mode, epoch)
    logger.conditional_summarize(mode, avg, is_best)

    logger.conditional_summarize_intensity(mode, avg_intensity)

    return avg, avg_intensity, is_best
def main():
    """Entry point: build model, criterion and optimizer for args.mode
    ('baseline_train', 'pretrain' or 'finetune'), optionally resume from a
    checkpoint, then run the corresponding training loop.

    Relies on module-level globals: args, best_prec1, cur_itrs.
    """
    global args, best_prec1
    global cur_itrs
    args = parser.parse_args()
    print(args.mode)

    # STEP1: model
    if args.mode == 'baseline_train':
        model = initialize_model(use_resnet=True, pretrained=False, nclasses=200)
    elif args.mode == 'pretrain':
        model = deeplab_network.deeplabv3_resnet50(num_classes=args.num_classes, output_stride=args.output_stride, pretrained_backbone=False)
        set_bn_momentum(model.backbone, momentum=0.01)
    elif args.mode == 'finetune':
        model = initialize_model(use_resnet=True, pretrained=False, nclasses=3)
        # load the pretrained weights
        if args.pretrained_model:
            if os.path.isfile(args.pretrained_model):
                print("=> loading pretrained model '{}'".format(args.pretrained_model))
                checkpoint = torch.load(args.pretrained_model)
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                # BUG FIX: the original also called
                # optimizer.load_state_dict(checkpoint['optimizer']) here, but
                # the optimizer is only created in STEP2 below, so this branch
                # always raised a NameError.  The optimizer state is simply not
                # restored when starting from a pretrained model.
                # BUG FIX: the success message printed args.resume instead of
                # the pretrained-model path.
                print("=> loaded pretrained model '{}' (epoch {})".format(args.pretrained_model, checkpoint['epoch']))
    # BUG FIX: torch.cuda.is_available is a function; the bare attribute is
    # always truthy, which would call .cuda() even on CPU-only machines.
    if torch.cuda.is_available():
        model = model.cuda()

    # STEP2: criterion and optimizer
    if args.mode in ['baseline_train', 'finetune']:
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.mode == 'pretrain':
        criterion = nn.MSELoss()
        # the backbone gets a 10x smaller learning rate than the classifier head
        optimizer = torch.optim.SGD(params=[
            {'params': model.backbone.parameters(), 'lr': 0.1 * args.lr},
            {'params': model.classifier.parameters(), 'lr': args.lr},
        ], lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
        scheduler = PolyLR(optimizer, args.total_itrs, power=0.9)

    # STEP3: loss/prec record
    if args.mode in ['baseline_train', 'finetune']:
        train_losses = []
        train_top1s = []
        train_top5s = []

        test_losses = []
        test_top1s = []
        test_top5s = []
    elif args.mode == 'pretrain':
        train_losses = []
        test_losses = []

    # STEP4: optionally resume from a checkpoint
    if args.resume:
        print('resume')
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.mode in ['baseline_train', 'finetune']:
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                # the metric history is stored next to the checkpoint as .npz
                datafile = args.resume.split('.pth')[0] + '.npz'
                load_data = np.load(datafile)
                train_losses = list(load_data['train_losses'])
                train_top1s = list(load_data['train_top1s'])
                train_top5s = list(load_data['train_top5s'])
                test_losses = list(load_data['test_losses'])
                test_top1s = list(load_data['test_top1s'])
                test_top5s = list(load_data['test_top5s'])
            elif args.mode == 'pretrain':
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                cur_itrs = checkpoint['cur_itrs']
                datafile = args.resume.split('.pth')[0] + '.npz'
                load_data = np.load(datafile)
                train_losses = list(load_data['train_losses'])
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # STEP5: train!
    if args.mode in ['baseline_train', 'finetune']:
        # data
        from utils import TinyImageNet_data_loader
        # BUG FIX: the original printed the undefined name color_distortion.
        print('color_distortion:', args.color_distortion)
        train_loader, val_loader = TinyImageNet_data_loader(args.dataset, args.batch_size, color_distortion=args.color_distortion)

        # if evaluate the model
        if args.evaluate:
            print('evaluate this model on validation dataset')
            validate(val_loader, model, criterion, args.print_freq)
            return

        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr)
            time1 = time.time()  # timekeeping

            # train for one epoch
            model.train()
            loss, top1, top5 = train(train_loader, model, criterion, optimizer, epoch, args.print_freq)
            train_losses.append(loss)
            train_top1s.append(top1)
            train_top5s.append(top5)

            # evaluate on validation set
            model.eval()
            loss, prec1, prec5 = validate(val_loader, model, criterion, args.print_freq)
            test_losses.append(loss)
            test_top1s.append(prec1)
            test_top5s.append(prec5)

            # remember the best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            save_checkpoint({
                'epoch': epoch + 1,
                'mode': args.mode,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict()
            }, is_best, args.mode + '_' + args.dataset + '.pth')

            np.savez(args.mode + '_' + args.dataset + '.npz', train_losses=train_losses, train_top1s=train_top1s, train_top5s=train_top5s, test_losses=test_losses, test_top1s=test_top1s, test_top5s=test_top5s)
            time2 = time.time()  # timekeeping
            print('Elapsed time for epoch:', time2 - time1, 's')
            print('ETA of completion:', (time2 - time1) * (args.epochs - epoch - 1) / 60, 'minutes')
            print()
    elif args.mode == 'pretrain':
        # data
        from utils import TinyImageNet_data_loader
        args.batch_size = 16
        train_loader, val_loader = TinyImageNet_data_loader(args.dataset, args.batch_size, col=True)

        # if evaluate the model, show some results
        if args.evaluate:
            print('evaluate this model on validation dataset')
            visulization(val_loader, model, args.start_epoch)
            return

        epoch = 0
        while True:
            # cur_itrs is advanced by the (currently disabled) train() call;
            # stop once the iteration budget is exhausted.
            if cur_itrs >= args.total_itrs:
                return
            time1 = time.time()  # timekeeping

            model.train()
            # NOTE(review): the actual train/validate calls were commented out
            # in the original; as written this loop only checkpoints and, since
            # cur_itrs never advances, never terminates on its own — confirm
            # whether training is meant to be re-enabled here.

            save_checkpoint({
                'epoch': epoch + 1,
                'mode': args.mode,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                "cur_itrs": cur_itrs
            }, True, args.mode + '_' + args.dataset + '.pth')

            np.savez(args.mode + '_' + args.dataset + '.npz', train_losses=train_losses)
            time2 = time.time()  # timekeeping
            print('Elapsed time for epoch:', time2 - time1, 's')
            print('ETA of completion:', (time2 - time1) * (args.total_itrs - cur_itrs - 1) / 60, 'minutes')
            print()
            epoch += 1
示例#13
0
def main(args):
    """Train a (non-ARMA) DeepLabV3-ResNet50 segmenter on Pascal VOC.

    Creates the experiment directory layout and TensorBoard writer, builds
    the model and optimizer, optionally resumes from a checkpoint, then runs
    the train/validate loop while checkpointing every epoch and tracking the
    best mean IoU.
    """
    if args.debug:
        import pdb
        pdb.set_trace()

    tb_dir = args.exp_name + '/tb_logs/'
    ckpt_dir = args.exp_name + '/checkpoints/'

    # create the experiment directory tree on first run only
    if not os.path.exists(args.exp_name):
        os.mkdir(args.exp_name)
        os.mkdir(tb_dir)
        os.mkdir(ckpt_dir)

    writer = SummaryWriter(tb_dir, flush_secs=10)

    # create model
    print("=> creating model: ")
    os.system('nvidia-smi')
    print(args.no_pre_train, ' pretrain')

    model_map = {
        'deeplabv3_resnet18': arma_network.deeplabv3_resnet18,
        'deeplabv3_resnet50': arma_network.deeplabv3_resnet50,
        'fcn_resnet18': arma_network.fcn_resnet18,
    }

    model = model_map['deeplabv3_resnet50'](arma=False, num_classes=args.n_classes)

    model = model.cuda()
    model = nn.DataParallel(model)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            model, optimizer, args = helper.load_checkpoint(args, model, optimizer)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # cudnn.benchmark pays off only when input sizes are fixed: autotuning is
    # slow at first but picks the fastest kernels afterwards.
    cudnn.benchmark = True

    # Load dataloaders
    augmentations = aug.Compose([aug.RandomCrop(512), aug.RandomHorizontallyFlip(5),
                                 aug.RandomRotate(30), aug.RandomSizedCrop(512)])

    my_dataset = pascalVOCLoader(args=args, root=args.data, sbd_path=args.data,
                                 augmentations=augmentations)

    my_dataset.get_loaders()

    # snapshot the untrained weights so runs can be reproduced from scratch
    init_weight_filename = 'initial_state.pth.tar'
    helper.save_checkpoint(args, model, optimizer, custom_name=init_weight_filename)

    with open(args.exp_name + '/' + 'args.pkl', 'wb') as fout:
        pickle.dump(args, fout)

    # BUG FIX: ngpus_per_node was referenced in the rank check below but never
    # defined in this function, raising a NameError in the distributed branch.
    ngpus_per_node = torch.cuda.device_count()

    best_iou = -100.0
    for epoch in range(args.start_epoch, args.epochs):

        helper.adjust_learning_rate(optimizer, epoch, args)

        train_loss = trainer.train(my_dataset.train_loader, model, optimizer, epoch, args, writer)
        val_loss, scores, class_iou, running_metrics_val = trainer.validate(my_dataset.val_loader, model, epoch, args, writer)

        # BUG FIX: is_best was only ever set to True and never reset, so every
        # epoch after the first improvement was also saved as "best".
        is_best = scores["Mean IoU : \t"] >= best_iou
        if is_best:
            best_iou = scores["Mean IoU : \t"]

        # only the master rank writes checkpoints
        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):

            # keep a snapshot of each of the first few epochs
            if epoch in [0, 1, 2, 3, 4, 5, 6, 7, 8]:
                helper.save_checkpoint(args, model, optimizer, epoch, custom_name=str(epoch) + '.pth')

            if args.save_freq is None:
                helper.save_checkpoint(args, model, optimizer, epoch, is_best=is_best, periodic=False)
            else:
                helper.save_checkpoint(args, model, optimizer, epoch, is_best=is_best, periodic=True)

    with open(args.exp_name + '/running_metric.pkl', 'wb') as fout:
        pickle.dump(running_metrics_val, fout)
示例#14
0
def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point for distributed training of a dilated
    ResNet18 segmenter on Pascal VOC.

    `gpu` is this process's local GPU index; `ngpus_per_node` the number of
    GPUs on this node.  Only the master rank creates the TensorBoard writer
    and saves checkpoints; non-master ranks have printing suppressed.
    """
    args.gpu = gpu

    if args.debug:
        import pdb
        pdb.set_trace()

    # BUG FIX: writer was only defined on the master rank but is passed to
    # trainer.train/validate below on every rank, raising a NameError on
    # non-master ranks.  Default it to None.
    # NOTE(review): confirm trainer.train/validate handle writer=None.
    writer = None
    if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):

        tb_dir = args.exp_name + '/tb_logs/'
        ckpt_dir = args.exp_name + '/checkpoints/'

        if not os.path.exists(args.exp_name):
            os.mkdir(args.exp_name)
            os.mkdir(tb_dir)
            os.mkdir(ckpt_dir)

        print("writing to : ", tb_dir + '{}'.format(args.exp_name), args.rank, ngpus_per_node)

        writer = SummaryWriter(tb_dir, flush_secs=10)

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*args):
            pass
        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # global rank = node rank * gpus-per-node + local gpu index
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # create model
    print("=> creating model: ")
    model = resnet_dilated.Resnet18_32s(num_classes=21)

    if args.distributed:
        print("distributed")
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # split the per-node batch and worker budget evenly across GPUs
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        raise NotImplementedError("Only DistributedDataParallel is supported.")

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            model, optimizer, args = helper.load_checkpoint(args, model, optimizer)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # cudnn.benchmark pays off only when input sizes are fixed
    cudnn.benchmark = True

    # Load dataloaders
    augmentations = aug.Compose([aug.RandomCrop(512), aug.RandomHorizontallyFlip(5),
                                 aug.RandomRotate(30), aug.RandomSizedCrop(512)])
    my_dataset = pascalVOCLoader(args=args, root='/scratch0/shishira/pascal_voc/',
                                 sbd_path='/scratch0/shishira/pascal_voc/',
                                 augmentations=augmentations)
    my_dataset.get_loaders()

    # snapshot the untrained weights for reproducibility
    init_weight_filename = 'initial_state.pth.tar'
    helper.save_checkpoint(args, model, optimizer, custom_name=init_weight_filename)

    with open(args.exp_name + '/' + 'args.pkl', 'wb') as fout:
        pickle.dump(args, fout)

    best_iou = -100.0
    for epoch in range(args.start_epoch, args.epochs):

        if args.distributed:
            # reshuffle the shards differently each epoch
            my_dataset.train_sampler.set_epoch(epoch)

        helper.adjust_learning_rate(optimizer, epoch, args)

        train_loss = trainer.train(my_dataset.train_loader, model, optimizer, epoch, args, writer)
        val_loss, scores, class_iou = trainer.validate(my_dataset.val_loader, model, epoch, args, writer)

        # BUG FIX: is_best was only ever set to True and never reset, so every
        # epoch after the first improvement was also saved as "best".
        is_best = scores["Mean IoU : \t"] >= best_iou
        if is_best:
            best_iou = scores["Mean IoU : \t"]

        # only the master rank writes checkpoints
        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):

            # keep a snapshot of each of the first few epochs
            if epoch in [0, 1, 2, 3, 4, 5, 6, 7, 8]:
                helper.save_checkpoint(args, model, optimizer, epoch, custom_name=str(epoch) + '.pth')

            if args.save_freq is None:
                helper.save_checkpoint(args, model, optimizer, epoch, is_best=is_best, periodic=False)
            else:
                helper.save_checkpoint(args, model, optimizer, epoch, is_best=is_best, periodic=True)
示例#15
0
def main():
    """Build the requested architecture, optionally resume from a checkpoint,
    and train on an ImageNet-style dataset, checkpointing the best top-1 model.

    Relies on module-level globals: args, best_prec1.
    """
    global args, best_prec1
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Log to stdout and to a timestamped file under ./log
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}{}{:02}{}'.format(args.arch,
                                                    local_time.tm_year % 2000,
                                                    local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        print("=> creating model '{}'".format(args.arch))

    if args.arch == 'alexnet':
        model = alexnet(pretrained=args.pretrained)
    elif args.arch == 'alexnet_s':
        model = alexnet_s(pretrained=args.pretrained)
    elif args.arch == 'alexnet_s_addchannel':
        model = alexnet_s_addchannel(pretrained=args.pretrained)
    elif args.arch == 'alexnet_s_addchannel_fullshuffle':
        model = alexnet_s_addchannel_fullshuffle(pretrained=args.pretrained)
    elif args.arch == 'alexnet_s_addchannelx2_fullshuffle':
        model = alexnet_s_addchannelx2_fullshuffle(pretrained=args.pretrained)
    elif args.arch == 'alexnet_s1_2x':
        model = alexnet_s1_2x(pretrained=args.pretrained)
    elif args.arch == 'squeezenet1_0':
        model = squeezenet1_0(pretrained=args.pretrained)
    elif args.arch == 'squeezenet1_1':
        model = squeezenet1_1(pretrained=args.pretrained)
    elif args.arch == 'densenet121':
        model = densenet121(pretrained=args.pretrained)
    elif args.arch == 'densenet169':
        model = densenet169(pretrained=args.pretrained)
    elif args.arch == 'densenet201':
        model = densenet201(pretrained=args.pretrained)
    elif args.arch == 'densenet161':
        model = densenet161(pretrained=args.pretrained)
    elif args.arch == 'vgg11':
        model = vgg11(pretrained=args.pretrained)
    elif args.arch == 'vgg13':
        model = vgg13(pretrained=args.pretrained)
    elif args.arch == 'vgg16':
        model = vgg16(pretrained=args.pretrained)
    elif args.arch == 'vgg19':
        model = vgg19(pretrained=args.pretrained)
    elif args.arch == 'vgg11_bn':
        model = vgg11_bn(pretrained=args.pretrained)
    elif args.arch == 'vgg13_bn':
        model = vgg13_bn(pretrained=args.pretrained)
    elif args.arch == 'vgg16_bn':
        model = vgg16_bn(pretrained=args.pretrained)
    elif args.arch == 'vgg19_bn':
        model = vgg19_bn(pretrained=args.pretrained)
    elif args.arch == 'vgg16_bn_s3_addchannel':
        model = vgg16_bn_s3_addchannel(pretrained=args.pretrained)
    elif args.arch == 'vgg16_bn_s3_addchannel_v2':
        model = vgg16_bn_s3_addchannel_v2(pretrained=args.pretrained)
    elif args.arch == 'resnet18':
        model = resnet18(pretrained=args.pretrained)
    elif args.arch == 'resnet34':
        model = resnet34(pretrained=args.pretrained)
    elif args.arch == 'resnet34_s2_addchannel':
        model = resnet34_s2_addchannel(pretrained=args.pretrained)
    elif args.arch == 'resnet34_s3_x15':
        model = resnet34_s3_x15(pretrained=args.pretrained)
    elif args.arch == 'resnet34_s3_x175':
        model = resnet34_s3_x175(pretrained=args.pretrained)
    elif args.arch == 'resnet50':
        model = resnet50(pretrained=args.pretrained)
    elif args.arch == 'resnet50_s':
        model = resnet50_s(pretrained=args.pretrained)
    elif args.arch == 'resnet50_block_s':
        model = resnet50_block_s(pretrained=args.pretrained)
    elif args.arch == 'resnet50_block_s_addchannel':
        model = resnet50_block_s_addchannel(pretrained=args.pretrained)
    elif args.arch == 'resnet50_s3_addchannel':
        model = resnet50_s3_addchannel(pretrained=args.pretrained)
    elif args.arch == 'resnet50_s3_addchannel175':
        model = resnet50_s3_addchannel175(pretrained=args.pretrained)
    elif args.arch == 'resnet50_justaddgroup':
        model = resnet50_justaddgroup(pretrained=args.pretrained)
    elif args.arch == 'resnet101':
        model = resnet101(pretrained=args.pretrained)
    elif args.arch == 'resnet152':
        model = resnet152(pretrained=args.pretrained)
    else:
        raise NotImplementedError

    # use cuda
    model.cuda()

    # define loss and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            # BUG FIX: this else was originally paired with `if args.resume:`,
            # so a missing checkpoint file was silently ignored and the
            # "no checkpoint found" message fired (with path 'None') only when
            # --resume was not given at all.
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading
    train_loader, val_loader = data_loader(args.data, args.batch_size,
                                           args.workers, args.pin_memory)

    if args.evaluate:
        validate(val_loader, model, criterion, args.print_freq)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch,
              args.print_freq)

        # evaluate on validation set
        prec1, prec5 = validate(val_loader, model, criterion, args.print_freq)

        # remember the best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict()
            }, is_best, args.arch + '.pth')
def iterate(mode, args, loader, model, optimizer, logger, epoch):
    block_average_meter = AverageMeter()
    average_meter = AverageMeter()
    meters = [block_average_meter, average_meter]

    # switch to appropriate mode
    assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
        "unsupported mode: {}".format(mode)
    if mode == 'train':
        model.train()
        lr = helper.adjust_learning_rate(args.lr, optimizer, epoch)
    else:
        model.eval()
        lr = 0

    print("\nTraining")
    prune_type = "sq"  # sq, vlines, nothing
    square_choice = args.training_sparse_opt
    if prune_type == "sq":
        print(f"Features: squares\n Square choice: {square_choice}")

    for i, batch_data in enumerate(loader):
        start = time.time()
        batch_data = {
            key: val.to(device)
            for key, val in batch_data.items() if val is not None
        }
        gt = batch_data[
            'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None
        data_time = time.time() - start
        start = time.time()

        if prune_type == "vlines":
            np.random.seed(10)
            lines_unmasked = np.random.choice(352, 20, replace=False)
            lines_unmasked = np.arange(352)
            lines_all = np.arange(352)
            lines_masked = [x for x in lines_all if x not in lines_unmasked]
            batch_data['d'][:, :, lines_masked] = 0
            print(batch_data['d'].shape)
            print("lines unmasked", lines_unmasked)

        elif prune_type == "sq":
            A = np.load("ranks/switches_2D_equal_iter_390.npy",
                        allow_pickle=True)
            # with np.printoptions(precision=5):
            #     print("switches", A)
            # get the ver and hor coordinates of the most important squares
            A_2d_argsort = np.argsort(A, None)[::-1]
            if square_choice == "most":
                squares_top_file = "ranks/sq/global/squares_most.npy"
                A_2d_argsort = np.load(squares_top_file)[::-1]
            ver = np.floor(A_2d_argsort // A.shape[1])
            hor = A_2d_argsort % A.shape[1]
            A_list = np.stack([ver, hor]).transpose()
            square_size = 40
            squares_top_num = 50

            if square_choice == "full":
                squares_top = A_list

            if square_choice == "most":
                squares_top = A_list[:squares_top_num]

            if square_choice == "best_sw":
                squares_top = A_list[:squares_top_num]

            if square_choice == "latin_sw":
                # creating latin grid (with big squares/blocks)
                hor_large = np.linspace(0, 30, 7)
                ver_larger = np.arange(10)
                all_squares = np.arange(len(A_list))
                bins_2d_latin = binned_statistic_2d(
                    ver, hor, all_squares, 'min', bins=[ver_larger, hor_large])
                bins_2d_latin.statistic
                best_latin = bins_2d_latin.statistic[-3:].flatten().astype(int)
                best_latin_coors = list(A_list[best_latin])
                for i1 in A_list:
                    elem_in = False  # check if the block already contains a small square
                    for i2 in best_latin_coors:
                        if i1[0] == i2[0] and i1[1] == i2[1]:
                            elem_in = True
                    if not elem_in:
                        best_latin_coors.append(i1)
                    if len(best_latin_coors) == squares_top_num:
                        break
                squares_top = np.array(best_latin_coors)

            elif square_choice == "latin":
                np.random.seed(12)
                squares_latin_evenlyspaced = []
                # create blocks, choose k random blocks and have fixed first block
                hor_large = np.linspace(0, 30, 7)
                ver_large = np.arange(10)
                # random sample from positive blocks
                hor_large_rand = np.random.choice(len(hor_large),
                                                  squares_top_num)
                ver_large_rand = np.random.choice([6, 7, 8], squares_top_num)
                # selecting a small square from A_list with given corrdinates within a block
                for j in range(len(hor_large_rand)):
                    elem = \
                    np.where((A_list[:, 0] == ver_large_rand[j]) & (A_list[:, 1] == hor_large[hor_large_rand[j]]))[0][0]
                    squares_latin_evenlyspaced.append(elem)
                squares_top = A_list[squares_latin_evenlyspaced]

            elif square_choice == "random_all":
                np.random.seed(12)
                rand_idx = np.random.choice(len(A_list), squares_top_num)
                print(rand_idx)
                squares_top = A_list[rand_idx]

            elif square_choice == "random_pos":  # from squares which include depth points
                np.random.seed(12)
                # choose from the squares which have roughly positive number of depth points
                rand_idx = np.random.choice(len(A_list[:93]), squares_top_num)
                print(rand_idx)
                squares_top = A_list[rand_idx]

            # after selecting indices of the squares save in squares_top
            squares_top_scaled = np.array(squares_top) * square_size
            mask = np.zeros((352, 1216))
            bin_ver = np.arange(0, 352, square_size)
            bin_ver = np.append(bin_ver, oheight)
            bin_hor = np.arange(0, 1216, square_size)
            bin_hor = np.append(bin_hor, owidth)
            # filling in the mask with selected squares up to squares_top_num (e.g. 20)
            # print("Number of squares selected: ", len(squares_top))
            # print(squares_top)
            for it in range(
                    len(squares_top)
            ):  # in all but full should be equal to squares_top_num
                ver = int(squares_top[it][0])
                hor = int(squares_top[it][1])
                # print("ver", bin_ver[ver], bin_ver[ver+1], "hor", bin_hor[hor], bin_hor[hor+1] )
                mask[bin_ver[ver]:bin_ver[ver + 1],
                     bin_hor[hor]:bin_hor[hor + 1]] = 1

            aaa1 = batch_data['d'].detach().cpu().numpy()
            batch_data['d'] = torch.einsum(
                "abcd, cd->abcd",
                [batch_data['d'],
                 torch.Tensor(mask).to(device)])
            aaa2 = batch_data['d'].detach().cpu().numpy()
            #
            # # from PIL import Image
            # # img = Image.fromarray(aaa[0, :, :, :], 'RGB')
            # # #img.save('my.png')
            # # img.show()

        # Forward pass: the model consumes the whole batch dict and returns a
        # dense depth prediction (4-D; dims 2/3 are H/W, per the
        # pred_.size(2)/size(3) calls below).
        pred = model(batch_data)
        # im = batch_data['d'].detach().cpu().numpy()
        # im_sq = im.squeeze()
        # plt.figure()
        # plt.imshow(im_sq)
        # plt.show()
        # for i in range(im_sq.shape[0]):
        #    print(f"{i} - {np.sum(im_sq[i])}")

        # Default all loss terms to 0 / None so the metric section below can
        # run unchanged in non-train modes.
        depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None
        if mode == 'train':
            # Loss 1: the direct depth supervision from ground truth label.
            # mask=1 marks pixels WITHOUT a valid depth value (depth < 1e-3),
            # i.e. pixels excluded from depth supervision.
            if 'sparse' in args.train_mode:
                # Supervise against batch_data['d'] — presumably the sparse
                # depth input channel; confirm against the data loader.
                depth_loss = depth_criterion(pred, batch_data['d'])
                mask = (batch_data['d'] < 1e-3).float()
            elif 'dense' in args.train_mode:
                # Supervise against the (semi-)dense ground-truth map.
                depth_loss = depth_criterion(pred, gt)
                mask = (gt < 1e-3).float()

            # Loss 2: the self-supervised photometric loss
            if args.use_pose:
                # create multi-scale pyramids
                pred_array = helper.multiscale(pred)
                rgb_curr_array = helper.multiscale(batch_data['rgb'])
                rgb_near_array = helper.multiscale(batch_data['rgb_near'])
                if mask is not None:
                    mask_array = helper.multiscale(mask)
                num_scales = len(pred_array)

                # compute photometric loss at multiple scales
                for scale in range(len(pred_array)):
                    pred_ = pred_array[scale]
                    rgb_curr_ = rgb_curr_array[scale]
                    rgb_near_ = rgb_near_array[scale]
                    mask_ = None
                    if mask is not None:
                        mask_ = mask_array[scale]

                    # compute the corresponding intrinsic parameters
                    # (camera intrinsics rescaled to this pyramid level)
                    height_, width_ = pred_.size(2), pred_.size(3)
                    intrinsics_ = kitti_intrinsics.scale(height_, width_)

                    # inverse warp from a nearby frame to the current frame
                    warped_ = homography_from(rgb_near_, pred_,
                                              batch_data['r_mat'],
                                              batch_data['t_vec'], intrinsics_)
                    # Coarser pyramid levels get exponentially smaller
                    # weights: 2**(scale - num_scales) ranges over
                    # 2**-num_scales ... 2**-1.
                    photometric_loss += photometric_criterion(
                        rgb_curr_, warped_, mask_) * (2**(scale - num_scales))

            # Loss 3: the depth smoothness loss
            # (skipped entirely when its weight args.w2 is zero)
            smooth_loss = smoothness_criterion(pred) if args.w2 > 0 else 0

            # backprop on the weighted sum of the three loss terms
            # (w1 scales the photometric loss, w2 the smoothness loss)
            loss = depth_loss + args.w1 * photometric_loss + args.w2 * smooth_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Wall-clock time for forward (+ backward in train mode); the timer
        # `start` is reset just before the forward pass, outside this view.
        gpu_time = time.time() - start

        # measure accuracy and record loss
        with torch.no_grad():
            # Batch size is taken from an arbitrary tensor in the batch dict
            # (all entries share the same leading batch dimension).
            mini_batch_size = next(iter(batch_data.values())).size(0)
            result = Result()
            if mode != 'test_prediction' and mode != 'test_completion':
                # NOTE(review): pred/gt are evaluated on whatever device they
                # live on; an earlier copy of this function in this file moves
                # both to CPU first (.data.cpu()) — confirm which is intended.
                result.evaluate(pred.data, gt.data, photometric_loss)
            # Fold this batch's result into both the per-block and the
            # running average meters (list comprehension used for side
            # effects only; the produced list is discarded).
            [
                m.update(result, gpu_time, data_time, mini_batch_size)
                for m in meters
            ]
            logger.conditional_print(mode, i, epoch, lr, len(loader),
                                     block_average_meter, average_meter)
            logger.conditional_save_img_comparison(mode, i, batch_data, pred,
                                                   epoch)
            logger.conditional_save_pred(mode, i, pred, epoch)

        every = 100
        # NOTE(review): the trailing comment and `every` both say 100, but
        # the condition tests i % 500 — the actual save cadence and the
        # `every` value forwarded to save_checkpoint disagree; confirm which
        # is intended. Also note this fires at i == 0 (the first batch).
        if i % 500 == 0:  # every 100 batches/images (before it was after the entire dataset - two tabs/on if statement)
            avg = logger.conditional_save_info(mode, average_meter, epoch)
            is_best = logger.rank_conditional_save_best(mode, avg, epoch)
            if is_best and not (mode == "train"):
                logger.save_img_comparison_as_best(mode, epoch)
            logger.conditional_summarize(mode, avg, is_best)

            helper.save_checkpoint({  # save checkpoint
                'epoch': epoch,
                # model is assumed to be wrapped (e.g. nn.DataParallel), so
                # .module unwraps it before state_dict() — TODO confirm.
                'model': model.module.state_dict(),
                'best_result': logger.best_result,
                'optimizer': optimizer.state_dict(),
                'args': args,
            }, is_best, epoch, logger.output_directory, args.type_feature, i, every, "scratch")

    # NOTE(review): avg and is_best are only bound inside the loop body; an
    # empty loader would make this raise UnboundLocalError.
    return avg, is_best