# Example #1
# 0
def main():
    """Train the FCRN model on NYU Depth and log progress to TensorBoard.

    Every hyper-parameter is hard-coded in the body.  When ``resume`` is true
    and ``checkpoints/<exp_name>/model.pth`` exists, the model weights, the
    optimizer state and the epoch counter are restored from it.  Every
    ``val_every`` epochs ``validate`` is run on the train and test loaders,
    and every ``save_every`` epochs the checkpoint file is overwritten.
    """
    resume = True
    path = 'data/NYU_DEPTH'
    batch_size = 16
    epochs = 10000
    device = torch.device('cuda:0')
    print_every = 5
    # Other experiment names used with this script:
    # 'resnet18_nodropout_new', 'normal_internel', 'sep'
    exp_name = 'only_depth'
    lr = 1e-5
    weight_decay = 0.0005
    log_dir = os.path.join('logs', exp_name)
    model_dir = os.path.join('checkpoints', exp_name)
    val_every = 16
    save_every = 16

    # tensorboard: a run that does not resume starts from an empty log dir
    if not resume and os.path.exists(log_dir):
        shutil.rmtree(log_dir)
        os.makedirs(log_dir)
    os.makedirs(model_dir, exist_ok=True)
    tb = SummaryWriter(log_dir)
    # One multi-line chart per metric, overlaying the train and test curves.
    metric_names = [
        'thres_1.25',
        'thres_1.25_2',
        'thres_1.25_3',
        'ard',
        'srd',
        'rmse_linear',
        'rmse_log',
        'rmse_log_invariant',
    ]
    tb.add_custom_scalars({
        'metrics': {
            name: ['Multiline', ['{}/train'.format(name), '{}/test'.format(name)]]
            for name in metric_names
        }
    })

    # data loader
    dataset = NYUDepth(path, 'train')
    dataloader = DataLoader(dataset, batch_size, shuffle=True, num_workers=4)

    dataset_test = NYUDepth(path, 'test')
    dataloader_test = DataLoader(dataset_test, batch_size, shuffle=True, num_workers=4)

    # load model
    model = FCRN(True)
    model = model.to(device)

    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    start_epoch = 0
    if resume:
        model_path = os.path.join(model_dir, 'model.pth')
        if os.path.exists(model_path):
            print('Loading checkpoint from {}...'.format(model_path))
            # load model and optimizer
            checkpoint = torch.load(model_path, map_location='cpu')
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # The checkpoint records the index of the epoch that had just
            # finished when it was written, so training continues with the
            # following one instead of repeating it.
            start_epoch = checkpoint['epoch'] + 1
            print('Model loaded.')
        else:
            print('No checkpoint found. Train from scratch')

    # training
    metric_logger = MetricLogger()

    end = time.perf_counter()
    iters_per_epoch = len(dataloader)
    max_iters = epochs * iters_per_epoch

    def normal_loss(pred, normal, conf):
        """
        Confidence-weighted (1 - dot product) between two normal maps.

        :param pred: (B, 3, H, W)
        :param normal: (B, 3, H, W)
        :param conf: (B, 1, H, W), only channel 0 is read
        :return: scalar, mean over the batch
        """
        dot_prod = (pred * normal).sum(dim=1)
        # weighted loss, (B, )
        batch_loss = ((1 - dot_prod) * conf[:, 0]).sum(1).sum(1)
        # normalize by the total confidence of each image, to (B, )
        batch_loss /= conf[:, 0].sum(1).sum(1)
        return batch_loss.mean()

    def consistency_loss(pred, cloud, normal, conf):
        """
        Penalise local depth differences that are not orthogonal to the normal.

        :param pred: (B, 1, H, W)
        :param normal: (B, 3, H, W)
        :param cloud: (B, 3, H, W)
        :param conf: (B, 1, H, W)
        :return: scalar, mean over the batch
        """
        B, _, _, _ = normal.size()
        # no gradient flows into the normal branch from this loss
        normal = normal.detach()
        # work on a copy so the caller's cloud keeps its third channel
        cloud = cloud.clone()
        cloud[:, 2:3, :, :] = pred
        # algorithm: use a kernel (centre 48, the other 48 taps -1, sums to 0)
        kernel = torch.ones((1, 1, 7, 7), device=pred.device)
        kernel = -kernel
        kernel[0, 0, 3, 3] = 48

        cloud_0 = cloud[:, 0:1]
        cloud_1 = cloud[:, 1:2]
        cloud_2 = cloud[:, 2:3]
        # 7x7 kernel with dilation 2 and padding 6 keeps the spatial size
        diff_0 = F.conv2d(cloud_0, kernel, padding=6, dilation=2)
        diff_1 = F.conv2d(cloud_1, kernel, padding=6, dilation=2)
        diff_2 = F.conv2d(cloud_2, kernel, padding=6, dilation=2)
        # (B, 3, H, W)
        diff = torch.cat((diff_0, diff_1, diff_2), dim=1)
        # normalize
        diff = F.normalize(diff, dim=1)
        # (B, 1, H, W)
        dot_prod = (diff * normal).sum(dim=1, keepdim=True)
        # weighted mean over image
        dot_prod = torch.abs(dot_prod.view(B, -1))
        conf = conf.view(B, -1)
        loss = (dot_prod * conf).sum(1) / conf.sum(1)
        # mean over batch
        return loss.mean()

    def criterion(depth_pred, normal_pred, depth, normal, cloud, conf):
        """Return the three loss terms that the training loop sums.

        NOTE(review): in this configuration the consistency and normal
        losses are computed but discarded and the depth MSE is returned
        three times, so the summed loss is 3 * MSE and the logged
        ``consis_loss`` / ``norm_loss`` are copies of ``mse_loss``.
        """
        mse_loss = F.mse_loss(depth_pred, depth)
        consis_loss = consistency_loss(depth_pred, cloud, normal_pred, conf)
        norm_loss = normal_loss(normal_pred, normal, conf)
        consis_loss = torch.zeros_like(norm_loss)

        return mse_loss, mse_loss, mse_loss
        # return mse_loss, consis_loss, norm_loss
        # return norm_loss, norm_loss, norm_loss

    print('Start training')
    for epoch in range(start_epoch, epochs):
        # train
        model.train()
        # iteration index starts from 1
        for i, data in enumerate(dataloader, 1):
            start = end
            data = [x.to(device) for x in data]
            image, depth, normal, conf, cloud = data
            depth_pred, normal_pred = model(image)
            mse_loss, consis_loss, norm_loss = criterion(
                depth_pred, normal_pred, depth, normal, cloud, conf)
            loss = mse_loss + consis_loss + norm_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # bookkeeping
            end = time.perf_counter()
            metric_logger.update(loss=loss.item())
            metric_logger.update(mse_loss=mse_loss.item())
            metric_logger.update(norm_loss=norm_loss.item())
            metric_logger.update(consis_loss=consis_loss.item())
            metric_logger.update(batch_time=end - start)

            if i % print_every == 0:
                # Compute eta. global step: starting from 1
                global_step = epoch * iters_per_epoch + i
                seconds = (max_iters - global_step) * metric_logger['batch_time'].global_avg
                eta = datetime.timedelta(seconds=int(seconds))
                # to display: eta, epoch, iteration, loss, batch_time
                display_dict = {
                    'eta': eta,
                    'epoch': epoch,
                    'iter': i,
                    'loss': metric_logger['loss'].median,
                    'batch_time': metric_logger['batch_time'].median
                }
                display_str = [
                    'eta: {eta}s',
                    'epoch: {epoch}',
                    'iter: {iter}',
                    'loss: {loss:.4f}',
                    'batch_time: {batch_time:.4f}s',
                ]
                print(', '.join(display_str).format(**display_dict))

                # tensorboard: visualise the first sample of the batch.
                # Separate names are used so the batch tensors are not
                # overwritten by their single-image visualisations.
                min_depth = depth[0].min()
                max_depth = depth[0].max() * 1.25
                depth_range = max_depth - min_depth
                depth_vis = (depth[0] - min_depth) / depth_range
                depth_pred_vis = (depth_pred[0].detach() - min_depth) / depth_range
                depth_pred_vis = torch.clamp(depth_pred_vis, min=0.0, max=1.0)
                # map normal components from [-1, 1] to [0, 1] for display
                normal_vis = (normal[0] + 1) / 2
                normal_pred_vis = (normal_pred[0].detach() + 1) / 2
                conf_vis = conf[0]

                tb.add_scalar('train/loss', metric_logger['loss'].median, global_step)
                tb.add_scalar('train/mse_loss', metric_logger['mse_loss'].median, global_step)
                tb.add_scalar('train/consis_loss', metric_logger['consis_loss'].median, global_step)
                tb.add_scalar('train/norm_loss', metric_logger['norm_loss'].median, global_step)

                tb.add_image('train/depth', depth_vis, global_step)
                tb.add_image('train/normal', normal_vis, global_step)
                tb.add_image('train/depth_pred', depth_pred_vis, global_step)
                tb.add_image('train/normal_pred', normal_pred_vis, global_step)
                tb.add_image('train/conf', conf_vis, global_step)
                tb.add_image('train/image', image[0], global_step)

        if epoch % val_every == 0 and epoch != 0:
            # validate on both splits every `val_every` epochs
            validate(dataloader, model, device, tb, epoch, 'train')
            validate(dataloader_test, model, device, tb, epoch, 'test')
        if epoch % save_every == 0 and epoch != 0:
            to_save = {
                'optimizer': optimizer.state_dict(),
                'model': model.state_dict(),
                'epoch': epoch,
            }
            torch.save(to_save, os.path.join(model_dir, 'model.pth'))

    # flush pending events to disk
    tb.close()
# Example #2
# 0
def main():
    """Train FCRN on NYU Depth v2 with the reverse Huber loss.

    The model is initialised from ``load_weights(model, weights_file, dtype)``
    and optionally from ``resume_file``.  Each epoch runs one pass over the
    training loader followed by one pass over the validation loader.  A
    checkpoint is written to ``./model/`` when the validation error beats
    ``best_val_err`` or on the last epoch; on the last epoch preview images
    are also written to ``./result/``.

    Relies on the module-level names ``weights_file``, ``dtype``, ``plot``
    and ``np``.
    """
    batch_size = 16
    data_path = './data/nyu_depth_v2_labeled.mat'
    learning_rate = 1.0e-4
    num_epochs = 100

    # Output locations used further down; create them up front so a long run
    # does not fail at save time because a directory is missing.
    os.makedirs('./model', exist_ok=True)
    os.makedirs('./result', exist_ok=True)

    # 1.Load data
    train_lists, val_lists, test_lists = load_split()
    print("Loading data...")
    train_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, train_lists),
        batch_size=batch_size,
        shuffle=False,
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, val_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, test_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    print(train_loader)

    # 2.Load model
    print("Loading model...")
    model = FCRN(batch_size)
    # load the official parameters, converted from TensorFlow
    model.load_state_dict(load_weights(model, weights_file, dtype))

    # Optionally continue from a saved training checkpoint.  start_epoch is
    # initialised *before* this block so the restored value is kept.
    start_epoch = 0
    resume_from_file = False
    resume_file = './model/model_300.pth'
    if resume_from_file:
        if os.path.isfile(resume_file):
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("loaded checkpoint '{}' (epoch {})".format(
                resume_file, checkpoint['epoch']))
        else:
            print("can not find!")
    model = model.cuda()

    # 3.Loss
    # Alternatives: torch.nn.MSELoss() (built-in MSE), loss_mse() (custom MSE).
    # Used here: the loss from the paper, the reverse Huber.
    loss_fn = loss_huber()
    print("loss_fn set...")

    # 4.Optim
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    print("optimizer set...")

    # 5.Train
    # NOTE(review): with this threshold a checkpoint is only written when the
    # validation error drops below 1e-4 or on the final epoch - confirm this
    # is intended rather than a large initial "best" value.
    best_val_err = 1.0e-4

    for epoch in range(num_epochs):
        epoch_number = start_epoch + epoch + 1
        print('Starting train epoch %d / %d' %
              (epoch_number, num_epochs + start_epoch))
        model.train()
        running_loss = 0.0
        count = 0
        for rgb, depth in train_loader:
            input_var = rgb.type(dtype)
            depth_var = depth.type(dtype)

            output = model(input_var)
            loss = loss_fn(output, depth_var)
            batch_loss = loss.item()
            print('loss: %f' % batch_loss)
            count += 1
            running_loss += batch_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        # validate
        model.eval()
        num_samples = 0
        loss_local = 0.0
        is_last_epoch = epoch == num_epochs - 1
        with torch.no_grad():
            for rgb, depth in val_loader:
                input_var = rgb.type(dtype)
                depth_var = depth.type(dtype)

                output = model(input_var)
                if is_last_epoch:
                    # Save the first sample of the batch; the files are
                    # overwritten by every batch, so the last one remains.
                    # For how to save test images, see the loader code.
                    input_rgb_image = rgb[0].permute(1, 2, 0)
                    input_gt_depth_image = depth_var[0][0].cpu().numpy(
                    ).astype(np.float32)
                    pred_depth_image = output[0].squeeze().cpu().numpy(
                    ).astype(np.float32)

                    # scale each depth map by its own maximum for display
                    input_gt_depth_image /= np.max(input_gt_depth_image)
                    pred_depth_image /= np.max(pred_depth_image)

                    plot.imsave(
                        './result/input_rgb_epoch_{}.png'.format(epoch_number),
                        input_rgb_image)
                    plot.imsave(
                        './result/gt_depth_epoch_{}.png'.format(epoch_number),
                        input_gt_depth_image,
                        cmap="viridis")
                    plot.imsave(
                        './result/pred_depth_epoch_{}.png'.format(epoch_number),
                        pred_depth_image,
                        cmap="viridis")

                loss_local += float(loss_fn(output, depth_var))

                num_samples += 1

        # mean loss per validation batch
        err = loss_local / num_samples
        print('val_error: %f' % err)

        if err < best_val_err or is_last_epoch:
            best_val_err = err
            torch.save(
                {
                    'epoch': epoch_number,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, './model/model_' + str(epoch_number) + '.pth')

        if epoch % 10 == 0:
            learning_rate = learning_rate * 0.8
            # Changing the local variable alone has no effect: the optimizer
            # reads the rate from its parameter groups.
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
def main():
    """Train FCRN on NYU Depth v2 with a loss chosen by ``args.loss_type``.

    Flow: build the loaders and the model, pick the loss, run one validation
    pass before training, optionally resume from ``checkpoint.pth.tar``, then
    train for ``args.epochs`` epochs with Adam and a ``StepLR`` schedule.
    After every epoch the validation set is evaluated, preview images are
    written to ``./result/``, and ``checkpoint.pth.tar`` is overwritten when
    the validation error improves.  Every tenth epoch the bare model weights
    are also written to ``<loss_type>.pth``.

    Relies on the module-level names ``args``, ``dtype``, ``criteria``,
    ``plot`` and ``np``.

    Raises:
        ValueError: if ``args.loss_type`` is not one of the supported names.
    """
    batch_size = args.batch_size
    data_path = 'nyu_depth_v2_labeled.mat'
    learning_rate = args.lr  # e.g. 1.0e-4 or 1.0e-5
    num_epochs = args.epochs
    step_size = args.step_size
    step_gamma = args.step_gamma
    resume_from_file = False
    isDataAug = args.data_aug
    max_depth = 1000

    # Preview images are written here; create it so imsave cannot fail on a
    # missing directory.
    os.makedirs('./result', exist_ok=True)

    # 1.Load data
    train_lists, val_lists, test_lists = load_split()
    print("Loading data......")
    train_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, train_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, val_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, test_lists),
        batch_size=batch_size,
        shuffle=False,
        drop_last=True)
    print(train_loader)

    # 2.set the model
    print("Set the model......")
    model = FCRN(batch_size)
    model = model.cuda()

    # 3.Loss
    if args.loss_type == "berhu":
        loss_fn = criteria.berHuLoss().cuda()
        print("berhu loss_fn set.")
    elif args.loss_type == "L1":
        loss_fn = criteria.MaskedL1Loss().cuda()
        print("L1 loss_fn set.")
    elif args.loss_type == "mse":
        loss_fn = criteria.MaskedMSELoss().cuda()
        print("MSE loss_fn set.")
    elif args.loss_type == "ssim":
        loss_fn = criteria.SsimLoss().cuda()
        print("Ssim loss_fn set.")
    elif args.loss_type == "three":
        loss_fn = criteria.Ssim_grad_L1().cuda()
        print("SSIM+L1+Grad loss_fn set.")
    else:
        # Fail here instead of with a NameError at the first use of loss_fn.
        raise ValueError("unknown loss_type: {!r}".format(args.loss_type))

    def transform_depth(depth):
        """Apply the depth transform used when ``isDataAug`` is set."""
        if not isDataAug:
            return depth
        depth = depth * 1000
        depth = torch.clamp(depth, 10, 1000)
        return max_depth / depth

    def save_previews(input_var, depth_var, output, epoch_label):
        """Write the first sample of a batch to ./result/ as three PNGs."""
        input_rgb_image = input_var[0].permute(
            1, 2, 0).cpu().numpy().astype(np.uint8)
        input_gt_depth_image = depth_var[0][0].cpu().numpy().astype(
            np.float32)
        pred_depth_image = output[0].squeeze().cpu().numpy().astype(
            np.float32)

        # normalization: scale each depth map by its own maximum
        input_gt_depth_image /= np.max(input_gt_depth_image)
        pred_depth_image /= np.max(pred_depth_image)

        plot.imsave(
            './result/input_rgb_epoch_{}.png'.format(epoch_label),
            input_rgb_image)
        plot.imsave(
            './result/gt_depth_epoch_{}.png'.format(epoch_label),
            input_gt_depth_image,
            cmap="viridis")
        plot.imsave(
            './result/pred_depth_epoch_{}.png'.format(epoch_label),
            pred_depth_image,
            cmap="viridis")

    def run_validation(epoch_label, apply_transform):
        """Evaluate on ``val_loader`` and return the mean loss per batch.

        Preview images are rewritten for every batch, so the files left on
        disk come from the last batch of the pass.
        """
        model.eval()
        loss_local = 0.0
        num_samples = 0
        with torch.no_grad():
            for rgb, depth in val_loader:
                if apply_transform:
                    depth = transform_depth(depth)

                input_var = rgb.type(dtype)
                depth_var = depth.type(dtype)

                output = model(input_var)
                save_previews(input_var, depth_var, output, epoch_label)

                loss_local += float(loss_fn(output, depth_var))
                num_samples += 1
        return loss_local / num_samples

    # 5.Train
    best_val_err = 1.0e3

    # Baseline validation before any training.
    # NOTE(review): as before, this pass does not apply the isDataAug depth
    # transform, so with data_aug enabled it is not directly comparable to
    # the per-epoch validation errors - confirm whether that is intended.
    err = run_validation(0, apply_transform=False)
    print('val_error before train:', err)

    start_epoch = 0

    resume_file = 'checkpoint.pth.tar'
    if resume_from_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(resume_file))

    # 4.Optim
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=step_size,
                       gamma=step_gamma)  # may change to other value
    print("optimizer set.")

    for epoch in range(num_epochs):
        epoch_number = start_epoch + epoch + 1
        print('Starting train epoch %d / %d' % (epoch_number, num_epochs))
        model.train()
        running_loss = 0.0
        count = 0

        for rgb, depth in train_loader:
            depth = transform_depth(depth)

            input_var = rgb.type(dtype)
            depth_var = depth.type(dtype)

            output = model(input_var)

            loss = loss_fn(output, depth_var)
            print('loss:', loss.detach().cpu())
            count += 1
            running_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        # validate
        err = run_validation(epoch_number, apply_transform=True)

        # Periodic snapshot of the bare weights.  Written once per qualifying
        # epoch, after validation, rather than once per validation batch.
        if epoch % 10 == 9:
            PATH = args.loss_type + '.pth'
            torch.save(model.state_dict(), PATH)

        print('val_error:', err)

        if err < best_val_err:
            best_val_err = err
            torch.save(
                {
                    'epoch': epoch_number,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, 'checkpoint.pth.tar')

        scheduler.step()
def main():
    """Fine-tune FCRN from ``our_checkpoint.pth.tar`` with an MSE loss.

    Flow: build the loaders, load the model weights from
    ``our_checkpoint.pth.tar``, run one validation pass before training,
    optionally resume from ``checkpoint.pth.tar``, then train for
    ``num_epochs`` epochs with Adam.  After every epoch the validation set
    is evaluated, preview PNGs are written to the working directory, and
    ``checkpoint.pth.tar`` is overwritten when the validation error improves.
    The learning rate is multiplied by 0.6 after epochs 0, 10, 20, ...

    Relies on the module-level names ``dtype``, ``plot`` and ``np``.
    """
    batch_size = 32
    data_path = 'augmented_dataset.mat'
    learning_rate = 1.0e-5
    num_epochs = 50
    resume_from_file = False

    # 1.Load data
    train_lists, val_lists, test_lists = load_split()
    print("Loading data......")
    train_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, train_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, val_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        NyuDepthLoader(data_path, test_lists),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    print(train_loader)

    # 2.Load model
    print("Loading model......")
    model = FCRN(batch_size)
    # The weights come from a previously trained checkpoint.  (The earlier
    # ResNet-50 state dict was loaded from an absolute path but never copied
    # into the model, so that step has been dropped.)
    chkp = torch.load('our_checkpoint.pth.tar')
    model.load_state_dict(chkp['state_dict'])
    model = model.cuda()

    # 3.Loss
    loss_fn = torch.nn.MSELoss().cuda()
    print("loss_fn set.")

    def save_previews(input_var, depth_var, output, epoch_label):
        """Write the first sample of a batch as three PNGs in the cwd."""
        input_rgb_image = input_var[0].permute(
            1, 2, 0).cpu().numpy().astype(np.uint8)
        input_gt_depth_image = depth_var[0][0].cpu().numpy().astype(
            np.float32)
        pred_depth_image = output[0].squeeze().cpu().numpy().astype(
            np.float32)

        # scale each depth map by its own maximum for display
        input_gt_depth_image /= np.max(input_gt_depth_image)
        pred_depth_image /= np.max(pred_depth_image)

        plot.imsave('input_rgb_epoch_{}.png'.format(epoch_label),
                    input_rgb_image)
        plot.imsave('gt_depth_epoch_{}.png'.format(epoch_label),
                    input_gt_depth_image,
                    cmap="viridis")
        plot.imsave('pred_depth_epoch_{}.png'.format(epoch_label),
                    pred_depth_image,
                    cmap="viridis")

    def run_validation(epoch_label):
        """Evaluate on ``val_loader`` and return the mean loss per batch.

        Preview images are rewritten for every batch, so the files left on
        disk come from the last batch of the pass.
        """
        model.eval()
        loss_local = 0.0
        num_samples = 0
        with torch.no_grad():
            for rgb, depth in val_loader:
                input_var = rgb.type(dtype)
                depth_var = depth.type(dtype)

                output = model(input_var)
                save_previews(input_var, depth_var, output, epoch_label)

                loss_local += float(loss_fn(output, depth_var))
                num_samples += 1
        return loss_local / num_samples

    # 5.Train
    best_val_err = 1.0e3

    # Baseline validation before any training.
    err = run_validation(0)
    print('val_error before train:', err)

    start_epoch = 0

    resume_file = 'checkpoint.pth.tar'
    if resume_from_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(resume_file))

    # 4.Optim
    # Built once so Adam's running moment estimates survive across epochs;
    # the decayed learning rate is pushed into the parameter groups below.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    print("optimizer set.")

    for epoch in range(num_epochs):
        epoch_number = start_epoch + epoch + 1
        print('Starting train epoch %d / %d' % (epoch_number, num_epochs))
        model.train()
        running_loss = 0.0
        count = 0

        for rgb, depth in train_loader:
            input_var = rgb.type(dtype)
            depth_var = depth.type(dtype)

            output = model(input_var)
            loss = loss_fn(output, depth_var)
            print('loss:', loss.detach().cpu())
            count += 1
            running_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        # validate
        err = run_validation(epoch_number)
        print('val_error:', err)

        if err < best_val_err:
            best_val_err = err
            torch.save(
                {
                    'epoch': epoch_number,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, 'checkpoint.pth.tar')

        if epoch % 10 == 0:
            learning_rate = learning_rate * 0.6
            # Same schedule as rebuilding the optimizer with the new rate,
            # but without discarding the optimizer state.
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate