예제 #1
0
    def source_finder(self, image=None, thresh=None, prefix=None,
                      noise=None, output=None, savemask=None, **kw):
        """Run ``utils.sources_extraction`` on ``image``.

        When ``self.smoothing`` is set:

        * ``utils.thresh_mask`` writes a smoothed, sigma-thresholded copy of
          ``image`` to a temporary file in the current directory.
        * That file is passed to the extractor as ``detection_image``, with
          ``blank_limit = self.noise / 1e5``.
        * Extraction itself still runs on ``image``.

        Args:
            image: Input image path.
            thresh: Threshold forwarded to ``utils.thresh_mask``.
            prefix: Forwarded to ``utils.thresh_mask`` only; extraction uses
                ``self.prefix``.
            noise: Accepted for interface compatibility but unused; the noise
                level forwarded is always ``self.noise``.
            output: Output name forwarded to ``utils.sources_extraction``.
            savemask: Forwarded to ``utils.thresh_mask``.
            **kw: Extra keyword arguments for ``utils.sources_extraction``.
                When smoothing, ``detection_image`` and ``blank_limit`` are
                overwritten.
        """
        # Edge trimming is disabled. A previous version used
        # (b, NAXIS1 - b, b, NAXIS1 - b) with b = max(self.locstep, self.cfstep).
        trim_box = None
        tpos = None
        try:
            # data smoothing
            if self.smoothing:
                ext = utils.fits_ext(image)
                tpos = tempfile.NamedTemporaryFile(suffix="." + ext, dir=".")

                utils.thresh_mask(image, tpos.name,
                                  thresh=thresh, noise=self.noise,
                                  sigma=True, smooth=True, prefix=prefix,
                                  savemask=savemask)

                # using the masked image for forming islands
                kw["detection_image"] = tpos.name
                kw["blank_limit"] = self.noise / 1.0e5

            # source extraction
            utils.sources_extraction(
                image=image, output=output,
                sourcefinder_name=self.sourcefinder_name,
                trim_box=trim_box,
                prefix=self.prefix, **kw)
        finally:
            # Release (and, with the default delete=True, remove) the temporary
            # detection image even when thresholding or extraction raised.
            if tpos is not None:
                tpos.close()
예제 #2
0
    def source_finder(self, image=None, thresh=None,
                      noise=None, lsmname=None, **kw):
        """Threshold-mask an image into a temporary file and extract sources.

        ``utils.thresh_mask`` (sigma thresholding with smoothing) writes the
        masked image to a temporary file in the current directory.
        ``utils.sources_extraction`` then runs on that temporary file with
        ``blank_limit = self.noise / 100``.

        Args:
            image: Input image path; defaults to ``self.imagename``.
            thresh: Threshold for ``utils.thresh_mask``; defaults to
                ``self.pos_smooth``.
            noise: Accepted for interface compatibility but unused; the noise
                level forwarded is always ``self.noise``.
            lsmname: Output name for ``utils.sources_extraction``; defaults
                to ``self.poslsm``.
            **kw: Extra keyword arguments for ``utils.sources_extraction``.
        """
        # TODO: look for other source finders and how they operate
        thresh = thresh or self.pos_smooth
        image = image or self.imagename

        ext = utils.fits_ext(image)
        # The context manager closes (and, with the default delete=True,
        # removes) the temporary file even if masking or extraction raises.
        with tempfile.NamedTemporaryFile(suffix="." + ext, dir=".") as tpos:
            # data smoothing
            utils.thresh_mask(
                image, tpos.name,
                thresh=thresh, noise=self.noise,
                sigma=True, smooth=True)

            lsmname = lsmname or self.poslsm
            # source extraction on the masked image
            utils.sources_extraction(
                image=tpos.name, output=lsmname,
                sourcefinder_name=self.sourcefinder_name,
                blank_limit=self.noise / 100.0, prefix=self.prefix,
                **kw)
예제 #3
0
    def source_finder(self,
                      image=None,
                      thresh=None,
                      prefix=None,
                      noise=None,
                      output=None,
                      savemask=None,
                      **kw):
        """Run ``utils.sources_extraction`` on ``image``.

        When ``self.smoothing`` is set:

        * ``utils.thresh_mask`` writes a smoothed, sigma-thresholded copy of
          ``image`` to a temporary file in the current directory.
        * That file is passed to the extractor as ``detection_image``, with
          ``blank_limit = self.noise / 1e5``.
        * Extraction itself still runs on ``image``.

        Args:
            image: Input image path.
            thresh: Threshold forwarded to ``utils.thresh_mask``.
            prefix: Forwarded to ``utils.thresh_mask`` only; extraction uses
                ``self.prefix``.
            noise: Accepted for interface compatibility but unused; the noise
                level forwarded is always ``self.noise``.
            output: Output name forwarded to ``utils.sources_extraction``.
            savemask: Forwarded to ``utils.thresh_mask``.
            **kw: Extra keyword arguments for ``utils.sources_extraction``.
                When smoothing, ``detection_image`` and ``blank_limit`` are
                overwritten.
        """
        # Edge trimming is disabled. A previous version used
        # (b, NAXIS1 - b, b, NAXIS1 - b) with b = max(self.locstep, self.cfstep).
        trim_box = None
        tpos = None
        try:
            # data smoothing
            if self.smoothing:
                ext = utils.fits_ext(image)
                tpos = tempfile.NamedTemporaryFile(suffix="." + ext, dir=".")

                utils.thresh_mask(image,
                                  tpos.name,
                                  thresh=thresh,
                                  noise=self.noise,
                                  sigma=True,
                                  smooth=True,
                                  prefix=prefix,
                                  savemask=savemask)

                # using the masked image for forming islands
                kw["detection_image"] = tpos.name
                kw["blank_limit"] = self.noise / 1.0e5

            # source extraction
            utils.sources_extraction(image=image,
                                     output=output,
                                     sourcefinder_name=self.sourcefinder_name,
                                     trim_box=trim_box,
                                     prefix=self.prefix,
                                     **kw)
        finally:
            # Release (and, with the default delete=True, remove) the temporary
            # detection image even when thresholding or extraction raised.
            if tpos is not None:
                tpos.close()
예제 #4
0
def _format_eval_row(label, err, err_raw):
    """Return one formatted row of the evaluation table printed by test_model.

    Args:
        label: Row prefix emitted verbatim (sample index or "total").
        err: Metrics of the model prediction as returned by
            ``compute_errors``; indices 1-6 are printed (the table header
            labels them a1, a2, a3, absrel, rmse, log10).
        err_raw: Metrics of the raw input depth, same layout as ``err``.

    The last six columns compare the two results. Indices 1-3 use
    ``(1 - err_raw) / (1 - err) - 1`` and indices 4-6 use
    ``err_raw / err - 1``.
    """
    return label + ('{e[1]:.4f}\t'
                    '{e[2]:.4f}\t'
                    '{e[3]:.4f}\t'
                    '{e[4]:.4f}\t'
                    '{e[5]:.3f}\t'
                    '{e[6]:.4f} | \t'
                    '{f1[1]:+.3f}\t'
                    '{f1[2]:+.3f}\t'
                    '{f1[3]:+.3f}\t'
                    '{f2[4]:+.3f}\t'
                    '{f2[5]:+.3f}\t'
                    '{f2[6]:+.3f}').format(e=err,
                                           f1=(1 - err_raw) / (1 - err) - 1,
                                           f2=err_raw / err - 1)


def test_model(save_dir, save_img=False, evaluate=True):
    """Evaluate a saved ``Model_rgbd`` checkpoint on the test loaders.

    Each iteration runs two forward passes:

    1. Pretext task: NYU test depth, normalized and multiplied by a
       translucent-dataset mask, is fed with the NYU image. The model's
       first output is used.
    2. Main task: a translucent-dataset image and its raw depth are fed in.
       The model's second output is used.

    Args:
        save_dir: Directory containing ``model_1up_TXrefixed/epoch-5.pth``;
            images are written to ``<save_dir>/testimg``.
        save_img: If true, save input/output/error images for each iteration.
        evaluate: If true, collect main-task results. Per-sample and averaged
            ``compute_errors`` metrics are then printed for the model output
            and for the raw depth, both against ground truth and restricted
            to ``mask * object_mask``.
    """
    os.makedirs('%s/testimg' % save_dir, exist_ok=True)

    # load saved model
    model = Model_rgbd().cuda()
    model.load_state_dict(
        torch.load(os.path.join(save_dir, 'model_1up_TXrefixed/epoch-5.pth')))
    model.eval()
    print('model loaded for evaluation.')

    # Load data (the translucent training loader is not needed here)
    test_loader = getTestingDataOnly(batch_size=1)
    _, test_loader_l = getTranslucentData(batch_size=1)

    # Main-task results, one array per iteration. They are concatenated once
    # after the loop instead of re-copying with np.append on every step.
    true_list, raw_list, pred_list, mask_list, objmask_list = [], [], [], [], []

    # Inference only: no_grad avoids building autograd graphs per forward pass.
    with torch.cuda.device(0), torch.no_grad():
        tot_len = len(test_loader_l)  # min(len(test_loader), len(test_loader_l))
        testiter = iter(test_loader)
        testiter_l = iter(test_loader_l)

        for i in range(tot_len):
            try:
                sample_batched = next(testiter)
                sample_batched_l = next(testiter_l)
            except StopIteration:
                print('  (almost) end of iteration: %d.' % i)
                break
            print('/=/=/=/=/=/ iter %02d /=/=/=/=/' % i)

            # (1) Pretext task : test and save
            image_nyu = torch.autograd.Variable(sample_batched['image'].cuda())
            depth_nyu = torch.autograd.Variable(
                sample_batched['depth'].cuda(non_blocking=True))

            mask_raw = torch.autograd.Variable(sample_batched_l['mask'].cuda())

            depth_nyu_n = DepthNorm(depth_nyu)

            # Apply the first N translucent masks (deterministic, not random).
            # NOTE: NYU test batch size shouldn't be bigger than lucent's.
            ordered_index = list(range(depth_nyu.shape[0]))
            mask_new = mask_raw[ordered_index, :, :, :]
            depth_nyu_masked = resize2d(depth_nyu_n, (480, 640)) * mask_new

            # Predict
            (htped_out_t1, _) = model(image_nyu, depth_nyu_masked)
            depth_out_t1 = DepthNorm(htped_out_t1)

            dn_resized = resize2d(depth_nyu, (240, 320))

            if save_img:
                # Save image
                vutils.save_image(depth_out_t1,
                                  '%s/testimg/1out_%02d.png' % (save_dir, i),
                                  normalize=True,
                                  range=(0, 1000))
                # NOTE(review): this checks '1in_000000_%02d.png' but writes
                # '1in_%02d.png', so the skip never triggers unless such
                # legacy-named files exist. Confirm intent before changing.
                if not os.path.exists('%s/testimg/1in_000000_%02d.png' %
                                      (save_dir, i)):
                    vutils.save_image(depth_nyu_masked,
                                      '%s/testimg/1in_%02d.png' %
                                      (save_dir, i),
                                      normalize=True,
                                      range=(0, 1000))
                save_error_image(depth_out_t1 - dn_resized,
                                 '%s/testimg/1diff_%02d.png' % (save_dir, i),
                                 normalize=True,
                                 range=(-300, 300))

            del image_nyu, depth_nyu, htped_out_t1, depth_out_t1, dn_resized

            # (2) Main task : test and save
            image = torch.autograd.Variable(sample_batched_l['image'].cuda())
            depth_in = torch.autograd.Variable(
                sample_batched_l['depth_raw'].cuda())
            htped_in = DepthNorm(depth_in)

            depth_gt = torch.autograd.Variable(
                sample_batched_l['depth_truth'].cuda(non_blocking=True))

            (_, htped_out_t2) = model(image, htped_in)
            depth_out_t2 = DepthNorm(htped_out_t2)

            mask_small = resize2dmask(mask_raw, (240, 320))
            obj_mask = thresh_mask(depth_gt, resize2d(depth_in, (240, 320)))

            if evaluate:
                true_list.append(depth_gt.cpu().numpy())
                raw_list.append(resize2d(depth_in, (240, 320)).cpu().numpy())
                pred_list.append(depth_out_t2.detach().cpu().numpy())
                mask_list.append(mask_small.cpu().numpy())
                objmask_list.append(obj_mask.cpu().numpy())

            if save_img:
                # NOTE(review): checks '2truth_000000_%02d.png' but writes
                # '2truth_%02d.png'; see the matching note above.
                if not os.path.exists('%s/testimg/2truth_000000_%02d.png' %
                                      (save_dir, i)):
                    vutils.save_image(depth_in,
                                      '%s/testimg/2in_%02d.png' %
                                      (save_dir, i),
                                      normalize=True,
                                      range=(0, 500))
                    vutils.save_image(resize2d(depth_in, (240, 320)),
                                      '%s/testimg/2in_s_%02d.png' %
                                      (save_dir, i),
                                      normalize=True,
                                      range=(0, 500))
                    vutils.save_image(depth_gt,
                                      '%s/testimg/2truth_%02d.png' %
                                      (save_dir, i),
                                      normalize=True,
                                      range=(0, 500))
                vutils.save_image(depth_out_t2,
                                  '%s/testimg/2out_%02d.png' % (save_dir, i),
                                  normalize=True,
                                  range=(0, 500))
                save_error_image(resize2d(depth_out_t2, (480, 640)) - depth_in,
                                 '%s/testimg/2corr_%02d.png' % (save_dir, i),
                                 normalize=True,
                                 range=(-50, 50),
                                 mask=mask_raw)
                save_error_image(depth_out_t2 - depth_gt,
                                 '%s/testimg/2diff_%02d.png' % (save_dir, i),
                                 normalize=True,
                                 range=(-50, 50),
                                 mask=mask_small)
                vutils.save_image(mask_small,
                                  '%s/testimg/2_mask_%02d.png' % (save_dir, i),
                                  normalize=True,
                                  range=(-0.5, 1.5))
                vutils.save_image(obj_mask,
                                  '%s/testimg/2_objmask_%02d.png' %
                                  (save_dir, i),
                                  normalize=True,
                                  range=(-0.5, 1.5))
            del image, htped_in, depth_in, depth_gt, depth_out_t2, mask_raw, mask_small

    if not evaluate:
        return
    if not true_list:
        # Without this guard the metrics below would fail on missing data
        # (undefined arrays / division by zero) when no sample was processed.
        print('No test samples were evaluated.')
        return

    # astype(float) keeps the float64 dtype of the original np.empty(..., float)
    # accumulators.
    true_y = np.concatenate(true_list, axis=0).astype(float)
    raw_y = np.concatenate(raw_list, axis=0).astype(float)
    pred_y = np.concatenate(pred_list, axis=0).astype(float)
    mask_y = np.concatenate(mask_list, axis=0).astype(float)
    objmask_y = np.concatenate(objmask_list, axis=0).astype(float)

    eo = eo_r = 0
    print(
        '#    \ta1    \ta2    \ta3    \tabsrel\trmse  \tlog10 | \timprovements--> '
    )
    samples = zip(true_y, raw_y, pred_y, mask_y, objmask_y)
    for j, (truth, raw, pred, mask, objmask) in enumerate(samples):
        region = mask * objmask
        errors_object = compute_errors(truth, pred, region)
        errors_object_r = compute_errors(truth, raw, region)

        eo = eo + errors_object
        eo_r = eo_r + errors_object_r

        print(_format_eval_row('{:2d} | \t'.format(j), errors_object,
                               errors_object_r))

    n_samples = len(true_y)
    eo = eo / n_samples
    eo_r = eo_r / n_samples
    print(_format_eval_row('\ntotal \t', eo, eo_r))
예제 #5
0
파일: train.py 프로젝트: wtre/lucents
def main():
    """Train ``Model_rgbd`` jointly on NYU depth data and translucent-object data.

    Each iteration minimizes a weighted sum of three losses:

    * (1) pretext task: completion of NYU depth multiplied by a randomly
      chosen translucent-dataset mask (model's first output);
    * (x) transfer task: reconstruction from raw depth blended with ground
      truth via ``thresh_mask``/``blend_depth`` (model's first output);
    * (2) main task: ground-truth depth from raw translucent depth (model's
      second output).

    Per-epoch weights come from the hand-crafted ``weight_*`` schedules.

    Outputs:

    * images in ``<SAVE_DIR>/img``;
    * a rolling checkpoint in ``<SAVE_DIR>/model_overtraining.pth``;
    * per-epoch checkpoints in ``<SAVE_DIR>/epoch-<n>.pth``.

    NOTE(review): the epoch count is fixed by the schedules' length;
    ``--epochs`` only appears in the TensorBoard comment. ``--bs`` only
    affects the run prefix, because both loaders use ``batch_size=1``.

    Raises:
        RuntimeError: if either training loader is empty.
    """
    # Arguments
    parser = argparse.ArgumentParser(description='High Quality Monocular Depth Estimation via Transfer Learning')
    parser.add_argument('--epochs', default=20, type=int, help='number of total epochs to run')
    parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float, help='initial learning rate')
    parser.add_argument('--bs', default=4, type=int, help='batch size')
    args = parser.parse_args()
    SAVE_DIR = 'models/191107_mod15'
    ifcrop = True

    if ifcrop:
        HEIGHT = 256
        WIDTH = 320
    else:
        HEIGHT = 480
        WIDTH = 640

    with torch.cuda.device(0):
        # Create RGB-D model
        model = Model_rgbd().cuda()
        print('Model created.')

        # Training parameters
        optimizer = torch.optim.Adam(model.parameters(), args.lr, amsgrad=True)
        batch_size = args.bs
        prefix = 'densenet_' + str(batch_size)

        # Load data
        train_loader, test_loader = getTrainingTestingData(batch_size=1, crop_halfsize=ifcrop)
        train_loader_l, test_loader_l = getTranslucentData(batch_size=1, crop_halfsize=ifcrop)
        # Test batch is manually enlarged! See getTranslucentData's return.

        # Both training loaders are consumed in lockstep, so an epoch runs for
        # the shorter one. This count also drives ETA, progress and log steps.
        tot_len = min(len(train_loader), len(train_loader_l))
        if tot_len == 0:
            raise RuntimeError('Training loaders are empty; nothing to train on.')

        # Logging
        writer = SummaryWriter(comment='{}-lr{}-e{}-bs{}'.format(prefix, args.lr, args.epochs, args.bs), flush_secs=30)

        # Loss
        l1_criterion = nn.L1Loss()
        l1_criterion_masked = MaskedL1()
        grad_l1_criterion_masked = MaskedL1Grad()

        # Hand-crafted per-epoch loss weights (one entry per epoch)
        interval1 = 1
        interval2 = 2
        num_epochs = 10 * interval1 + interval2
        weight_t1loss = [1] * (10*interval1) + [0] * interval2
        weight_txloss = [.0317] * interval1 + [.1] * interval1 + \
                        [.316] * interval1 + [1] * interval1 + \
                        [3.16] * interval1 + [10] * interval1 + \
                        [10] * interval1 + [5.62] * interval1 + \
                        [3.16] * interval1 + [1.78] * interval1 + \
                        [0] * interval2
        weight_t2loss = [.001] * interval1 + [.00316] * interval1 + \
                        [.01] * interval1 + [.0316] * interval1 + \
                        [.1] * interval1 + [.316] * interval1 + \
                        [1.0] * interval1 + [3.16] * interval1 + \
                        [10.0] * interval1 + [31.6] * interval1 + \
                        [100.0] * interval2

        os.makedirs('%s/img' % SAVE_DIR, exist_ok=True)

        # Start training...
        for epoch in range(num_epochs):
            batch_time = AverageMeter()
            losses_nyu = AverageMeter()
            losses_lucent = AverageMeter()
            losses_hole = AverageMeter()
            losses = AverageMeter()

            # Switch to train mode
            model.train()

            end = time.time()
            # Defined up front so epoch-end logging works even if the very
            # first fetch fails.
            niter = epoch * tot_len

            trainiter = iter(train_loader)
            trainiter_l = iter(train_loader_l)

            for i in range(tot_len):
                try:
                    sample_batched = next(trainiter)
                    sample_batched_l = next(trainiter_l)
                except StopIteration:
                    # An exhausted iterator keeps raising, so end the epoch
                    # rather than spinning through the remaining indices.
                    print('  (almost) end of iteration.')
                    break

                # Prepare sample and target
                image_nyu = torch.autograd.Variable(sample_batched['image'].cuda())
                depth_nyu = torch.autograd.Variable(sample_batched['depth'].cuda(non_blocking=True))

                image_raw = torch.autograd.Variable(sample_batched_l['image'].cuda())
                mask_raw = torch.autograd.Variable(sample_batched_l['mask'].cuda())
                depth_raw = torch.autograd.Variable(sample_batched_l['depth_raw'].cuda())
                depth_gt = torch.autograd.Variable(sample_batched_l['depth_truth'].cuda(non_blocking=True))

                N1 = image_nyu.shape[0]
                N2 = image_raw.shape[0]

                ###########################
                # (1) Pretext task: depth completion

                # Normalize depth
                depth_nyu_n = DepthNorm(depth_nyu)

                # Apply a randomly chosen translucent-dataset mask to each NYU sample
                rand_index = [random.randint(0, N2-1) for _ in range(N1)]
                mask_new = mask_raw[rand_index, :, :, :]
                depth_nyu_masked = resize2d(depth_nyu_n, (HEIGHT, WIDTH)) * mask_new

                # Predict
                (output_t1, _) = model(image_nyu, depth_nyu_masked)

                # Calculate loss
                l_depth_t1 = l1_criterion(output_t1, depth_nyu_n)
                l_grad_t1 = l1_criterion(grad_x(output_t1), grad_x(depth_nyu_n)) + l1_criterion(grad_y(output_t1), grad_y(depth_nyu_n))
                l_ssim_t1 = torch.clamp((1 - ssim(output_t1, depth_nyu_n, val_range=1000.0 / 10.0)) * 0.5, 0, 1)
                loss_nyu = (0.1 * l_depth_t1) + (1.0 * l_grad_t1) + (1.0 * l_ssim_t1)

                if i % 150 == 0 or i < 2:
                    vutils.save_image(DepthNorm(depth_nyu_masked), '%s/img/A_masked_%06d.png' % (SAVE_DIR, epoch*10000+i), normalize=True)
                    vutils.save_image(DepthNorm(output_t1), '%s/img/A_out_%06d.png' % (SAVE_DIR, epoch*10000+i), normalize=True)
                    save_error_image(DepthNorm(output_t1) - depth_nyu, '%s/img/A_diff_%06d.png' % (SAVE_DIR, epoch * 10000 + i), normalize=True, range=(-500, 500))

                torch.cuda.empty_cache()

                ###########################
                # (x) Transfer task: reconstruct pseudo-translucent object

                depth_gt_n = DepthNorm(depth_gt)
                depth_raw_n = DepthNorm(depth_raw)

                # Blend raw and ground-truth depth using the object mask
                depth_gt_large = resize2d(depth_gt, (HEIGHT, WIDTH))
                object_mask = thresh_mask(depth_gt_large, depth_raw)
                depth_holed = blend_depth(depth_raw, depth_gt_large, object_mask)

                (output_tx, _) = model(image_raw, DepthNorm(depth_holed))
                output_tx_n = DepthNorm(output_tx)

                # Calculate loss, restricted to the translucent mask at half resolution
                mask_post = resize2dmask(mask_raw, (int(HEIGHT/2), int(WIDTH/2)))
                l_depth_tx = l1_criterion_masked(output_tx, depth_gt_n, mask_post)
                l_grad_tx = grad_l1_criterion_masked(output_tx, depth_gt_n, mask_post)
                loss_hole = (0.1 * l_depth_tx) + (1.0 * l_grad_tx)

                if i % 150 == 0 or i < 2:
                    vutils.save_image(DepthNorm(depth_holed), '%s/img/C_in_%06d.png' % (SAVE_DIR, epoch * 10000 + i),
                                      normalize=True, range=(0, 500))
                    vutils.save_image(object_mask, '%s/img/C_mask_%06d.png' % (SAVE_DIR, epoch * 10000 + i),
                                      normalize=True, range=(0, 1.5))
                    vutils.save_image(output_tx_n, '%s/img/C_out_%06d.png' % (SAVE_DIR, epoch * 10000 + i),
                                      normalize=True, range=(0, 500))
                    save_error_image(output_tx_n - depth_gt, '%s/img/C_zdiff_%06d.png' % (SAVE_DIR, epoch * 10000 + i),
                                     normalize=True, range=(-500, 500))
                torch.cuda.empty_cache()

                ###########################
                # (2) Main task: Undistort translucent object

                # Predict
                (_, output_t2) = model(image_raw, depth_raw_n)
                output_t2_n = DepthNorm(output_t2)

                # Calculate loss
                l_depth_t2 = l1_criterion_masked(output_t2, depth_gt_n, mask_post)
                l_grad_t2 = grad_l1_criterion_masked(output_t2, depth_gt_n, mask_post)
                loss_lucent = (0.1 * l_depth_t2) + (1.0 * l_grad_t2)

                if i % 150 == 0 or i < 2:
                    vutils.save_image(depth_raw, '%s/img/B_ln_%06d.png' % (SAVE_DIR, epoch*10000+i), normalize=True, range=(0, 500))
                    vutils.save_image(depth_gt, '%s/img/B_gt_%06d.png' % (SAVE_DIR, epoch*10000+i), normalize=True, range=(0, 500))
                    vutils.save_image(output_t2_n, '%s/img/B_out_%06d.png' % (SAVE_DIR, epoch*10000+i), normalize=True, range=(0, 500))
                    save_error_image(output_t2_n-depth_gt, '%s/img/B_zdiff_%06d.png' % (SAVE_DIR, epoch*10000+i), normalize=True, range=(-500, 500))

                if i % 150 == 0:
                    o2 = output_t2.cpu().detach().numpy()
                    o3 = output_t2_n.cpu().detach().numpy()
                    og = depth_gt.cpu().numpy()
                    nm = DepthNorm(depth_nyu_masked).cpu().numpy()
                    ng = depth_nyu.cpu().numpy()
                    print('> ========')
                    print("> Output_t2:" + str(np.min(o2)) + "~" + str(np.max(o2)) + " (" + str(np.mean(o2)) +
                          ") // Converted to depth: " + str(np.min(o3)) + "~" + str(np.max(o3)) + " (" + str(np.mean(o3)) + ")")
                    print("> GT depth : " + str(np.min(og)) + "~" + str(np.max(og)) +
                          " // NYU GT depth from 0.0~" + str(np.max(nm)) + " to " + str(np.min(ng)) + "~" + str(np.max(ng)) + " (" + str(np.mean(ng)) + ")")

                ###########################
                # (3) Update the network parameters
                if i % 150 == 0 or i < 1:
                    vutils.save_image(mask_post, '%s/img/_mask_%06d.png' % (SAVE_DIR, epoch * 10000 + i), normalize=True)

                loss = (weight_t1loss[epoch] * loss_nyu) + (weight_t2loss[epoch] * loss_lucent) + (weight_txloss[epoch] * loss_hole)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Log losses
                losses_nyu.update(loss_nyu.detach().item(), image_nyu.size(0))
                losses_lucent.update(loss_lucent.detach().item(), image_raw.size(0))
                losses_hole.update(loss_hole.detach().item(), image_raw.size(0))
                losses.update(loss.detach().item(), image_nyu.size(0) + image_raw.size(0))

                # Measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                eta = str(datetime.timedelta(seconds=int(batch_time.val*(tot_len - i))))

                # Log progress
                niter = epoch*tot_len+i
                if i % 15 == 0:
                    # Print to console
                    print('Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.sum:.3f})\t'
                    'ETA {eta}\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f}) ||\t'
                    'NYU {l1.val:.4f} ({l1.avg:.4f}) [{l1d:.4f} | {l1g:.4f} | {l1s:.4f}]\t'
                    'LUC {l2.val:.4f} ({l2.avg:.4f}) [{l2d:.4f} | {l2g:.4f}]\t'
                    'TX {lx.val:.4f} ({lx.avg:.4f}) [{lxd:.4f} | {lxg:.4f}]'
                    .format(epoch, i, tot_len, batch_time=batch_time, loss=losses, l1=losses_nyu, l1d=l_depth_t1, l1g=l_grad_t1, l1s=l_ssim_t1,
                            l2=losses_lucent, l2d=l_depth_t2, l2g=l_grad_t2, lx=losses_hole, lxd=l_depth_tx, lxg=l_grad_tx, eta=eta))
                    # Note that the numbers displayed are pre-weighted.

                    # Log to tensorboard
                    writer.add_scalar('Train/Loss', losses.val, niter)

                if i % 750 == 0:
                    LogProgress(model, writer, test_loader, test_loader_l, niter, epoch*10000+i, SAVE_DIR, HEIGHT, WIDTH)
                    path = os.path.join(SAVE_DIR, 'model_overtraining.pth')
                    torch.save(model.cpu().state_dict(), path)  # saving model
                    model.cuda()  # moving model to GPU for further training

                del image_nyu, depth_nyu_masked, output_t1, image_raw, depth_raw_n, output_t2
                torch.cuda.empty_cache()

            # Record epoch's intermediate results
            LogProgress(model, writer, test_loader, test_loader_l, niter, epoch*10000+i, SAVE_DIR, HEIGHT, WIDTH)
            writer.add_scalar('Train/Loss.avg', losses.avg, epoch)
            # all the saves come from https://discuss.pytorch.org/t/how-to-save-a-model-from-a-previous-epoch/20252

            torch.save(model.state_dict(), os.path.join(SAVE_DIR, 'epoch-{}.pth'.format(epoch)))

        # Flush pending TensorBoard events to disk.
        writer.close()

    print('Program terminated.')