示例#1
0
def validate_overview(val_dataloader):
    """Validate a saved model and write its prediction images to disk.

    A fresh validation output directory is created, the checkpoint located by
    ``fetch_model_path()`` is loaded, and ``validate`` is run with image
    saving enabled.

    Args:
        val_dataloader: validation loader, passed straight through to
            ``validate``.
    """
    image_dir = create_val_output_dir()
    model_path = fetch_model_path()

    # The checkpoint is a dict whose 'model' entry holds the whole pickled
    # model object.
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    state_dict = torch.load(model_path)
    model = state_dict['model']

    # No loss criterion is needed: ``validate`` is not given one here.
    validate(val_dataloader, model, save_image=True, image_dir=image_dir)
示例#2
0
def train_overview(train_dataloader, val_dataloader):
    """Train a ResNet encoder/decoder model and checkpoint it every epoch.

    Side effects: creates a training output directory and starts ``train.csv``
    and ``val.csv`` in it with a header row built from ``fieldnames``. Sets
    the module globals ``output_train`` / ``output_val`` to those paths, and
    calls ``utils.save_checkpoint`` once per epoch.

    Args:
        train_dataloader: loader passed to ``train`` each epoch.
        val_dataloader: loader passed to ``validate`` each epoch.

    Raises:
        ValueError: if ``args.criterion`` is neither ``'l2'`` nor ``'l1'``.
    """
    global output_train, output_val

    output_dir = create_train_output_dir()
    output_train = os.path.join(output_dir, 'train.csv')
    output_val = os.path.join(output_dir, 'val.csv')

    # Start both CSV logs with just the header row.
    header = '{}\n'.format(','.join(fieldnames))
    for csv_path in (output_train, output_val):
        with open(csv_path, 'w') as csv_file:
            csv_file.write(header)

    print('Creating output dir {}'.format(output_dir))

    print("=> creating Model ({}-{}) ...".format(args.encoder, args.decoder))
    model = ResNet(args.encoder,
                   args.decoder,
                   args.dims,
                   args.output_size,
                   pre_trained=True)

    print("=> model created.")
    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # For multi-GPU training wrap with torch.nn.DataParallel(model).cuda().
    model = model.cuda()

    # Define the loss. Fail fast on an unknown name rather than hitting an
    # UnboundLocalError on the first call to train().
    if args.criterion == 'l2':
        criterion = criteria.MaskedL2Loss().cuda()
    elif args.criterion == 'l1':
        criterion = criteria.MaskedL1Loss().cuda()
    else:
        raise ValueError(
            "unsupported criterion '{}' (expected 'l2' or 'l1')".format(
                args.criterion))

    for epoch in range(args.n_epochs):
        utils.modify_learning_rate(optimizer, epoch, args.learning_rate)
        train(train_dataloader, model, criterion, optimizer,
              epoch)  # train for one epoch
        validate(val_dataloader, model)  # evaluate on validation set

        utils.save_checkpoint(
            {
                'args': args,
                'epoch': epoch,
                'encoder': args.encoder,
                'model': model,
                'optimizer': optimizer,
            }, epoch, output_dir)
示例#3
0
def train_overview(train_dataloader, val_dataloader):
    """Train a CNN / recurrent model and checkpoint it after every epoch.

    The model is ``DenseSLAMNet(timespan=args.stack_size)`` when
    ``args.recurrent == 'true'``. Otherwise it is ``CNN_Single`` for
    ``args.cnn_type == 'single'`` and ``CNN_Stack`` for anything else.
    Side effects: creates a training output directory, starts ``train.csv``
    and ``val.csv`` with a header row from ``fieldnames``, sets the globals
    ``output_train`` / ``output_val``, and calls ``utils.save_checkpoint``
    once per epoch.

    Args:
        train_dataloader: loader passed to ``train`` each epoch.
        val_dataloader: loader passed to ``validate`` each epoch.

    Raises:
        ValueError: if ``args.criterion`` is not ``'l2'``, ``'l1'`` or
            ``'l1sum'``.
    """
    global output_train, output_val

    output_dir = create_train_output_dir()
    output_train = os.path.join(output_dir, 'train.csv')
    output_val = os.path.join(output_dir, 'val.csv')

    # Start both CSV logs with just the header row.
    header = '{}\n'.format(','.join(fieldnames))
    for csv_path in (output_train, output_val):
        with open(csv_path, 'w') as csv_file:
            csv_file.write(header)

    print('Creating output dir {}'.format(output_dir))

    print("=> creating Model ...")
    if args.recurrent == 'true':
        model = DenseSLAMNet(timespan=args.stack_size)
    elif args.cnn_type == 'single':
        model = CNN_Single()
    else:
        model = CNN_Stack()

    print("=> model created.")
    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)

    model = model.cuda()

    # Define the loss. Fail fast on an unknown name rather than hitting an
    # UnboundLocalError on the first call to train().
    if args.criterion == 'l2':
        criterion = criteria.MaskedL2Loss().cuda()
    elif args.criterion == 'l1':
        criterion = criteria.MaskedL1Loss().cuda()
    elif args.criterion == 'l1sum':
        criterion = criteria.L1LossSum()
    else:
        raise ValueError(
            "unsupported criterion '{}' (expected 'l2', 'l1' or 'l1sum')"
            .format(args.criterion))

    for epoch in range(args.n_epochs):
        train(train_dataloader, model, criterion, optimizer,
              epoch)  # train for one epoch
        validate(val_dataloader, model)  # evaluate on validation set

        utils.save_checkpoint(
            {
                'args': args,
                'epoch': epoch,
                'model': model,
                'optimizer': optimizer,
            }, epoch, output_dir)
示例#4
0
def main():
    """Evaluate a saved model, or train ``MobileNetSkipAdd`` from scratch.

    With ``args.evaluate`` set to a checkpoint path, the NYU Depth v2
    validation split is loaded, the model is validated once and the function
    returns. Otherwise a loss is selected, the output directory and CSV logs
    are prepared, and (when ``args.train`` is set) the model is trained for
    ``args.epochs`` epochs. Each epoch updates ``best.txt``, the best
    comparison image and a checkpoint.

    Relies on the module globals ``args``, ``best_result`` and ``fieldnames``
    being initialised by the caller.

    Raises:
        RuntimeError: in evaluation mode, if ``args.data`` is not
            ``'nyudepthv2'``.
        ValueError: if ``args.criterion`` is neither ``'l2'`` nor ``'l1'``.
    """
    global args, best_result, output_directory, train_csv, test_csv

    # evaluation mode
    if args.evaluate:

        # Data loading code
        print("=> creating data loaders...")
        valdir = os.path.join('..', 'data', args.data, 'val')

        if args.data == 'nyudepthv2':
            from dataloaders.nyu import NYUDataset
            val_dataset = NYUDataset(valdir,
                                     split='val',
                                     modality=args.modality)
        else:
            raise RuntimeError('Dataset not found.')

        # set batch size to be 1 for validation
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
        print("=> data loaders created.")

        assert os.path.isfile(args.evaluate), \
            "=> no model found at '{}'".format(args.evaluate)
        print("=> loading model '{}'".format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        # A checkpoint is either a dict written by utils.save_checkpoint or a
        # bare pickled model.
        if type(checkpoint) is dict:
            args.start_epoch = checkpoint['epoch']
            best_result = checkpoint['best_result']
            model = checkpoint['model']
            print("=> loaded best model (epoch {})".format(
                checkpoint['epoch']))
        else:
            model = checkpoint
            args.start_epoch = 0
        output_directory = os.path.dirname(args.evaluate)
        validate(val_loader, model, args.start_epoch, write_to_file=False)
        return

    start_epoch = 0
    # NOTE(review): train_loader, model and optimizer are only created when
    # args.train is set, yet the training loop below uses them regardless.
    if args.train:
        train_loader, val_loader = create_data_loaders(args)
        print("=> creating Model ({}-{}) ...".format(args.arch, args.decoder))

        model = models.MobileNetSkipAdd(
            output_size=train_loader.dataset.output_size)
        print("=> model created.")
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        # model = torch.nn.DataParallel(model).cuda() # for multi-gpu training
        model = model.cuda()

    # define loss function (criterion); reject unknown names up front instead
    # of failing later with an UnboundLocalError
    if args.criterion == 'l2':
        criterion = criteria.MaskedMSELoss().cuda()
    elif args.criterion == 'l1':
        criterion = criteria.MaskedL1Loss().cuda()
    else:
        raise ValueError(
            "unsupported criterion '{}' (expected 'l2' or 'l1')".format(
                args.criterion))

    # create results folder, if not already exists
    output_directory = utils.get_output_directory(args)
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    train_csv = os.path.join(output_directory, 'train.csv')
    test_csv = os.path.join(output_directory, 'test.csv')
    best_txt = os.path.join(output_directory, 'best.txt')

    # create new csv files with only header
    # NOTE(review): the training loop is nested inside this ``if``, so nothing
    # is trained when args.resume is truthy -- confirm this is intended.
    if not args.resume:
        with open(train_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
        with open(test_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()

        for epoch in range(start_epoch, args.epochs):
            utils.adjust_learning_rate(optimizer, epoch, args.lr)
            train(train_loader, model, criterion, optimizer,
                  epoch)  # train for one epoch
            result, img_merge = validate(val_loader, model,
                                         epoch)  # evaluate on validation set

            # remember best rmse and save checkpoint
            is_best = result.rmse < best_result.rmse
            if is_best:
                best_result = result
                with open(best_txt, 'w') as txtfile:
                    txtfile.write(
                        "epoch={}\nmse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n"
                        .format(epoch, result.mse, result.rmse, result.absrel,
                                result.lg10, result.mae, result.delta1,
                                result.gpu_time))
                if img_merge is not None:
                    img_filename = output_directory + '/comparison_best.png'
                    utils.save_image(img_merge, img_filename)

            utils.save_checkpoint(
                {
                    'args': args,
                    'epoch': epoch,
                    'arch': args.arch,
                    'model': model,
                    'best_result': best_result,
                    'optimizer': optimizer,
                }, is_best, epoch, output_directory)
示例#5
0
def main():
    """Train (or resume training) an FCRN ResNet depth model.

    Builds the data loaders, then either restores model/optimizer/epoch from
    ``args.resume`` or creates a new ``FCRN.ResNet``. The new model uses two
    SGD parameter groups, the second at 10x the base learning rate. Training
    runs for ``args.epochs`` epochs with a ``ReduceLROnPlateau`` scheduler
    driven by the validation ``absrel`` metric. Per-epoch learning rates go
    to a TensorBoard ``SummaryWriter``; ``best.txt`` and the best comparison
    image are updated on RMSE improvement, and a checkpoint is saved every
    epoch.

    Relies on the module globals ``args`` and ``best_result`` (the latter is
    only assigned here when resuming).
    """
    global args, best_result, output_directory

    # set random seed
    torch.manual_seed(args.manual_seed)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        args.batch_size = args.batch_size * torch.cuda.device_count()
    else:
        print("Let's use GPU ", torch.cuda.current_device())

    train_loader, val_loader = create_loader(args)

    if args.resume:
        assert os.path.isfile(args.resume), \
            "=> no checkpoint found at '{}'".format(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)

        start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']
        optimizer = checkpoint['optimizer']

        # The whole pickled model is reused directly (instead of rebuilding
        # FCRN.ResNet and loading checkpoint['model'].module.state_dict()),
        # per the original author, to avoid running out of memory.
        model = checkpoint['model']

        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

        # clear memory
        del checkpoint
        torch.cuda.empty_cache()
    else:
        print("=> creating Model")
        model = FCRN.ResNet(output_size=train_loader.dataset.output_size)
        print("=> model created.")
        start_epoch = 0

        # different modules have different learning rate
        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        optimizer = torch.optim.SGD(train_params,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        # You can use DataParallel() whether you use Multi-GPUs or not
        model = nn.DataParallel(model).cuda()

    # when training, use reduceLROnPlateau to reduce learning rate
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'min',
                                               patience=args.lr_patience)

    # loss function
    criterion = criteria.MaskedL1Loss()

    # create directory path
    output_directory = utils.get_output_directory(args)
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    best_txt = os.path.join(output_directory, 'best.txt')
    config_txt = os.path.join(output_directory, 'config.txt')

    # write training parameters to config file (only if none exists yet)
    if not os.path.exists(config_txt):
        with open(config_txt, 'w') as txtfile:
            txtfile.write(''.join(
                str(k) + ':' + str(v) + ',\t\n'
                for k, v in vars(args).items()))

    # create log
    log_path = os.path.join(
        output_directory, 'logs',
        datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    if os.path.isdir(log_path):
        shutil.rmtree(log_path)
    os.makedirs(log_path)
    logger = SummaryWriter(log_path)

    # Close the writer even if training raises, so buffered events are kept.
    try:
        for epoch in range(start_epoch, args.epochs):

            # remember change of the learning rate
            for i, param_group in enumerate(optimizer.param_groups):
                old_lr = float(param_group['lr'])
                logger.add_scalar('Lr/lr_' + str(i), old_lr, epoch)

            train(train_loader, model, criterion, optimizer, epoch,
                  logger)  # train for one epoch
            result, img_merge = validate(val_loader, model, epoch,
                                         logger)  # evaluate on validation set

            # remember best rmse and save checkpoint
            is_best = result.rmse < best_result.rmse
            if is_best:
                best_result = result
                with open(best_txt, 'w') as txtfile:
                    txtfile.write(
                        "epoch={}, rmse={:.3f}, rml={:.3f}, log10={:.3f}, d1={:.3f}, d2={:.3f}, d3={:.3f}, "
                        "t_gpu={:.4f}".format(epoch, result.rmse, result.absrel,
                                              result.lg10, result.delta1,
                                              result.delta2, result.delta3,
                                              result.gpu_time))
                if img_merge is not None:
                    img_filename = output_directory + '/comparison_best.png'
                    utils.save_image(img_merge, img_filename)

            # save checkpoint for each epoch
            utils.save_checkpoint(
                {
                    'args': args,
                    'epoch': epoch,
                    'model': model,
                    'best_result': best_result,
                    'optimizer': optimizer,
                }, is_best, epoch, output_directory)

            # when rml doesn't fall, reduce learning rate
            scheduler.step(result.absrel)
    finally:
        logger.close()
示例#6
0
文件: main.py 项目: LeonSun0101/YT
def _make_aux_criterion(loss_name):
    """Build the CUDA auxiliary loss selected by ``config.LOSS_1``.

    Only the requested class is looked up on ``criteria``, so an absent loss
    class matters only when it is actually selected.

    Raises:
        ValueError: if ``loss_name`` is not a supported auxiliary loss.
    """
    class_names = {
        'l2': 'MaskedMSELoss',  # mean squared error
        'l1': 'MaskedL1Loss',
        'l1_canny': 'MaskedL1_cannyLoss',
        'l1_from_rgb_sobel': 'MaskedL1_from_rgb_sobel_Loss',
        'l1_canny_from_GT_canny': 'MaskedL1_canny_from_GT_Loss',
        'l1_from_GT_sobel': 'MaskedL1_from_GT_sobel_Loss',
        'l2_from_GT_sobel_Loss': 'MaskedL2_from_GT_sobel_Loss',
    }
    try:
        class_name = class_names[loss_name]
    except KeyError:
        raise ValueError(
            "unsupported config.LOSS_1 '{}'".format(loss_name)) from None
    return getattr(criteria, class_name)().cuda()


def train(train_loader, val_loader, model, criterion, optimizer, epoch, lr):
    """Train ``model`` for one pass over ``train_loader``.

    Every ``config.save_fc`` global batches (counted in the module global
    ``batch_num``), this function:

    * adjusts the learning rate;
    * appends a row to ``train_csv``;
    * validates on ``val_loader``;
    * rewrites ``best.txt`` and the comparison image when validation RMSE
      improves;
    * saves a checkpoint.

    Args:
        train_loader: yields ``(Y, Y_1_2, Y_1_4, Y_1_8, LR, LR_8, HR, name)``
            batches.
        val_loader: validation loader passed to ``validate``.
        model: network being trained; how it is called depends on
            ``args.arch``.
        criterion: main loss, called as ``criterion(pred_HR, HR, Y)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: index of the current pass over the dataset (logging only).
        lr: base learning rate handed to ``utils.adjust_learning_rate``.

    Raises:
        ValueError: if the two-loss path is taken and ``config.LOSS_1`` names
            an unsupported loss.
    """
    global batch_num, best_result
    average_meter = AverageMeter()
    model.train()  # switch to train mode
    aux_criterion = None  # built on first use by the two-loss branch
    end = time.time()
    # every batch
    for i, (Y, Y_1_2, Y_1_4, Y_1_8, LR, LR_8, HR,
            _name) in enumerate(train_loader):
        batch_num = batch_num + 1
        Y, Y_1_2, Y_1_4, Y_1_8, LR, LR_8, HR = (
            t.cuda() for t in (Y, Y_1_2, Y_1_4, Y_1_8, LR, LR_8, HR))
        torch.cuda.synchronize()
        data_time = time.time() - end

        end = time.time()

        # Membership tests replace the former ``x == 'a' or 'b'`` conditions,
        # which were always true and made every later branch unreachable.
        if args.arch == 'VDSR_16':
            pred_HR = model(LR)
            loss = criterion(pred_HR, HR, Y)
        elif args.arch == 'VDSR_16_2':
            pred_HR = model(Y, LR)
            loss = criterion(pred_HR, HR, Y)
        elif args.arch in ('VDSR', 'ResNet_bicubic'):
            pred_HR, _ = model(LR_8, Y)  # second output (residual) unused
            loss = criterion(pred_HR, HR, Y)
        elif args.arch in ('resnet50_15_6', 'resnet50_15_11',
                           'resnet50_15_12'):
            pred_HR = model(Y_1_2, LR)
            loss = criterion(pred_HR, HR, Y)
        elif args.arch in ('resnet50_15_2', 'resnet50_15_3', 'resnet50_15_5',
                           'resnet50_15_8', 'resnet50_15_9'):
            pred_HR, _ = model(Y, LR, LR_8)  # second output (residual) unused
            loss = criterion(pred_HR, HR, Y)
        else:
            two_losses = config.loss_num == 2
            if two_losses and aux_criterion is None:
                aux_criterion = _make_aux_criterion(config.LOSS_1)

            if config.use_different_size_Y == 1:
                pred_HR, pred_thermal0 = model(Y, Y_1_2, Y_1_4, Y_1_8, LR)
            else:
                pred_HR, pred_thermal0 = model(Y, LR)

            # final loss
            loss = criterion(pred_HR, HR, Y)
            if two_losses:
                # weighted sum of final loss and thermal upsample loss
                loss1 = aux_criterion(pred_thermal0, HR, Y)
                loss = config.LOSS_0_weight * loss + config.LOSS_1_weight * loss1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        torch.cuda.synchronize()
        gpu_time = time.time() - end

        # measure accuracy and record loss
        result = Result()
        result.evaluate(pred_HR, HR, loss.cpu().detach().numpy())
        average_meter.update(result, gpu_time, data_time, Y.size(0))
        end = time.time()

        if (i + 1) % args.print_freq == 0:

            print('=> output: {}'.format(output_directory))
            print('Dataset Epoch: {0} [{1}/{2}]\t'
                  'Batch Epoch: {3} \t'
                  't_Data={data_time:.3f}({average.data_time:.3f}) '
                  't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n'
                  'PSNR={result.psnr:.5f}({average.psnr:.5f}) '
                  'MSE={result.mse:.3f}({average.mse:.3f}) '
                  'RMSE={result.rmse:.3f}({average.rmse:.3f}) '
                  'MAE={result.mae:.3f}({average.mae:.3f}) '
                  'Delta1={result.delta1:.4f}({average.delta1:.4f}) '
                  'REL={result.absrel:.4f}({average.absrel:.4f}) '
                  'Lg10={result.lg10:.4f}({average.lg10:.4f}) '
                  'Loss={result.loss:}({average.loss:}) '.format(
                      epoch,
                      i + 1,
                      len(train_loader),
                      batch_num,
                      data_time=data_time,
                      gpu_time=gpu_time,
                      result=result,
                      average=average_meter.average()))

        if (batch_num + 1) % config.save_fc == 0:

            print("==============Time to evaluate=================")
            utils.adjust_learning_rate(optimizer, batch_num, lr)
            print("==============SAVE_MODEL=================")
            # Average over the batches since the last save, then restart the
            # running window.
            avg = average_meter.average()
            average_meter = AverageMeter()
            # NOTE(review): 'psnr' is computed from the window average while
            # the remaining columns come from the last batch's ``result`` --
            # confirm which one is intended.
            with open(train_csv, 'a') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writerow({
                    'dataset epoch': epoch,
                    'batch epoch': batch_num + 1,
                    'psnr': 10 * math.log(1 / (avg.mse), 10),
                    'mse': result.mse,
                    'rmse': result.rmse,
                    'absrel': result.absrel,
                    'lg10': result.lg10,
                    'mae': result.mae,
                    'delta1': result.delta1,
                    'delta2': result.delta2,
                    'delta3': result.delta3,
                    'gpu_time': result.gpu_time,
                    'data_time': result.data_time,
                    'loss': result.loss
                })

            #------------------#
            #    VALIDATION    #
            #------------------#
            # evaluate on the validation set after every save interval
            result_val, img_merge = validate(val_loader, model, epoch,
                                             batch_num)
            #------------------#
            # SAVE BEST MODEL  #
            #------------------#
            is_best = result_val.rmse < best_result.rmse
            if is_best:
                best_result = result_val
                with open(best_txt, 'w') as txtfile:
                    txtfile.write(
                        "dataset epoch={}\nbatch epoch={}\npsnr={:.5f}\nmse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n"
                        .format(epoch, batch_num + 1,
                                10 * math.log(1 / (best_result.mse), 10),
                                best_result.mse, best_result.rmse,
                                best_result.absrel, best_result.lg10,
                                best_result.mae, best_result.delta1,
                                best_result.gpu_time))
                if img_merge is not None:
                    img_filename = output_directory + '/comparison_best.png'
                    utils.save_image(img_merge, img_filename)

            utils.save_checkpoint(
                {
                    'args': args,
                    'epoch': epoch,
                    'batch_epoch': batch_num,
                    'arch': args.arch,
                    'model': model,
                    'best_result': best_result,
                    'optimizer': optimizer,
                }, is_best, epoch, batch_num, output_directory)
示例#7
0
def main():
    """Fine-tune a saved depth model and keep the best-RMSE weights.

    The model is loaded from ``args.evaluate``, which may be a checkpoint dict
    or a bare pickled model. It is then trained for ``args.epochs`` epochs on
    the chosen dataset (``'nyudepthv2'`` or ``'rgbd'``) and validated each
    epoch. Outputs written next to the checkpoint:

    * ``best.txt`` and the best comparison image;
    * a checkpoint per epoch;
    * ``loss.txt``, from the global ``history_loss``;
    * ``best_model.pkl``, holding the state_dict of the best epoch.

    Raises:
        RuntimeError: if ``args.data`` is not a supported dataset.
        ValueError: if ``args.criterion`` is neither ``'l2'`` nor ``'l1'``.
    """
    global args, best_result, output_directory, train_csv, test_csv
    import copy  # used to snapshot the best weights

    # Data loading code
    print("=> creating data loaders...")

    data_dir = '/p300/dataset'
    train_dir = os.path.join(data_dir, 'data', args.data, 'train')
    val_dir = os.path.join(data_dir, 'data', args.data, 'val')

    if args.data == 'nyudepthv2':
        from dataloaders.nyu import NYUDataset
        train_dataset = NYUDataset(train_dir,
                                   split='train',
                                   modality=args.modality)
        # validate on the validation directory, not the training one
        val_dataset = NYUDataset(val_dir,
                                 split='val',
                                 modality=args.modality)
    elif args.data == 'rgbd':
        from dataloaders.sist import RGBDDataset
        train_dataset = RGBDDataset(train_dir,
                                    split='train',
                                    modality=args.modality)
        val_dataset = RGBDDataset(val_dir, split='val', modality=args.modality)
    else:
        raise RuntimeError('Dataset not found.')

    # set batch size to be 1 for validation (the training loader also uses 1)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=1,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    print("=> data loaders created.")

    ############################## Resume Mode ##############################
    # loading pretrained model
    print("=> loading model '{}'".format(args.evaluate))
    args.start_epoch = 0
    checkpoint = torch.load(args.evaluate)
    if type(checkpoint) is dict:
        # loading pretrained model
        model = checkpoint['model']
        print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
    else:
        model = checkpoint

    ############################## Training Setting ##############################
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # define loss function (criterion); an unknown name would otherwise
    # reach train() as None
    if args.criterion == 'l2':
        criterion = criteria.MaskedMSELoss().cuda()
    elif args.criterion == 'l1':
        criterion = criteria.MaskedL1Loss().cuda()
    else:
        raise ValueError(
            "unsupported criterion '{}' (expected 'l2' or 'l1')".format(
                args.criterion))

    output_directory = os.path.dirname(args.evaluate)
    best_txt = os.path.join(output_directory, 'best.txt')

    ############################## Training ##############################
    # Deep copy of the best epoch's weights. Keeping a reference to ``model``
    # would alias the object that keeps training, so the final save would
    # hold the last epoch's weights rather than the best ones.
    best_state = None
    for epoch in range(args.epochs):
        utils.adjust_learning_rate(optimizer, epoch, args.lr)
        train(train_loader, model, criterion, optimizer, epoch)
        result, img_merge = validate(val_loader, model, epoch)

        # remember best rmse and save checkpoint
        is_best = result.rmse < best_result.rmse
        if is_best:
            best_result = result
            best_state = copy.deepcopy(model.state_dict())
            with open(best_txt, 'w') as txtfile:
                txtfile.write(
                    "epoch={}\nmse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n"
                    .format(epoch, result.mse, result.rmse, result.absrel,
                            result.lg10, result.mae, result.delta1,
                            result.gpu_time))
            if img_merge is not None:
                img_filename = output_directory + '/comparison_best.png'
                utils.save_image(img_merge, img_filename)

        utils.save_checkpoint(
            {
                'args': args,
                'epoch': epoch,
                'arch': args.arch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch, output_directory)

    # save loss file
    loss_file = np.array(history_loss)
    np.savetxt(output_directory + '/loss.txt', loss_file)

    # Nothing to save if no epoch beat the initial best_result.
    if best_state is not None:
        torch.save(best_state, output_directory + '/best_model.pkl')
示例#8
0
else:
    args.w1, args.w2 = 0, 0

# handling GPU/CPU
cuda = torch.cuda.is_available() and not args.cpu
if cuda:
    import torch.backends.cudnn as cudnn
    cudnn.benchmark = True
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("\n" + "=> using '{}' for computation.".format(device))

# define loss functions
depth_criterion = criteria.MaskedMSELoss() if (
    args.criterion == 'l2') else criteria.MaskedL1Loss()
photometric_criterion = criteria.PhotometricLoss()
smoothness_criterion = criteria.SmoothnessLoss()

if args.use_pose:
    # hard-coded KITTI camera intrinsics
    K = load_calib()
    fu, fv = float(K[0, 0]), float(K[1, 1])
    cu, cv = float(K[0, 2]), float(K[1, 2])
    kitti_intrinsics = Intrinsics(owidth, oheight, fu, fv, cu, cv)
    if cuda:
        kitti_intrinsics = kitti_intrinsics.cuda()


def iterate(mode, args, loader, model, optimizer, logger, epoch):
    block_average_meter = AverageMeter()
示例#9
0
def main():
    """Evaluate, resume, or start training a ResNet depth model.

    Modes, chosen by ``args``:

    * ``args.evaluate``: load that checkpoint, adopt its saved args, validate
      once and return.
    * ``args.resume``: load that checkpoint (args, model, optimizer,
      best_result, epoch) and continue training in the checkpoint's
      directory.
    * Otherwise: build a new torchvision-style ``resnet50``/``resnet18`` and
      train from epoch 0 in ``utils.get_output_directory(args)``.

    Training logs to TensorBoard, updates ``best.txt`` and the comparison
    image on RMSE improvement, and checkpoints every epoch.

    Raises:
        ValueError: if ``args.arch`` or ``args.criterion`` is unsupported.
    """
    global args, best_result, output_directory, train_csv, test_csv

    # Use all GPUs when there are several; scale the batch size accordingly.
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        args.batch_size = args.batch_size * torch.cuda.device_count()
    else:
        print("Let's use", torch.cuda.current_device())

    start_epoch = 0
    resumed = False

    # evaluation mode
    if args.evaluate:
        assert os.path.isfile(args.evaluate), \
            "=> no best model found at '{}'".format(args.evaluate)
        print("=> loading best model '{}'".format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        output_directory = os.path.dirname(args.evaluate)
        args = checkpoint['args']
        start_epoch = checkpoint['epoch'] + 1

        best_result = checkpoint['best_result']
        model = checkpoint['model']
        print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
        _, val_loader = create_data_loaders(args)
        args.evaluate = True
        validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
        return

    # optionally resume from a checkpoint
    elif args.resume:
        # Capture the path before ``args`` is replaced by the checkpoint's
        # saved args, whose ``resume`` field is unrelated to this file.
        resume_path = args.resume
        assert os.path.isfile(resume_path), \
            "=> no checkpoint found at '{}'".format(resume_path)
        print("=> loading checkpoint '{}'".format(resume_path))
        checkpoint = torch.load(resume_path)
        args = checkpoint['args']
        start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']
        output_directory = os.path.dirname(os.path.abspath(resume_path))
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
        train_loader, val_loader = create_data_loaders(args)
        args.resume = True
        resumed = True

    # create new model
    else:
        train_loader, val_loader = create_data_loaders(args)
        print("=> creating Model ({})".format(args.arch))
        if args.arch == 'resnet50':
            model = models.resnet50(pretrained=True)
        elif args.arch == 'resnet18':
            model = models.resnet18(pretrained=True)
        else:
            raise ValueError(
                "unsupported arch '{}' (expected 'resnet50' or 'resnet18')"
                .format(args.arch))
        print("=> model created.")
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        # for multi-gpu training
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model).cuda()
        else:
            model = model.cuda()

    # define loss function (criterion)
    if args.criterion == 'l2':
        criterion = criteria.MaskedMSELoss().cuda()
    elif args.criterion == 'l1':
        criterion = criteria.MaskedL1Loss().cuda()
    elif args.criterion == 'berHu':
        criterion = criteria.berHuLoss().cuda()
    else:
        raise ValueError(
            "unsupported criterion '{}' (expected 'l2', 'l1' or 'berHu')"
            .format(args.criterion))

    # create results folder, if not already exists; a resumed run keeps
    # writing next to its checkpoint
    if not resumed:
        output_directory = utils.get_output_directory(args)
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    best_txt = os.path.join(output_directory, 'best.txt')

    log_path = os.path.join(
        output_directory, 'logs',
        datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    if os.path.isdir(log_path):
        shutil.rmtree(log_path)
    os.makedirs(log_path)
    logger = SummaryWriter(log_path)

    for epoch in range(start_epoch, args.epochs):
        utils.adjust_learning_rate(optimizer, epoch, args.lr)
        train(train_loader, model, criterion, optimizer, epoch,
              logger)  # train for one epoch
        result, img_merge = validate(val_loader, model, epoch,
                                     logger)  # evaluate on validation set

        # remember best rmse and save checkpoint
        is_best = result.rmse < best_result.rmse
        if is_best:
            best_result = result
            with open(best_txt, 'w') as txtfile:
                txtfile.write(
                    "epoch={}\nrmse={:.3f}\nrml={:.3f}\nlog10={:.3f}\nDelta1={:.3f}\nDelta2={:.3f}\nDelta3={:.3f}\nt_gpu={:.4f}\n"
                    .format(epoch, result.rmse, result.absrel, result.lg10,
                            result.delta1, result.delta2, result.delta3,
                            result.gpu_time))
            if img_merge is not None:
                img_filename = output_directory + '/comparison_best.png'
                utils.save_image(img_merge, img_filename)

        utils.save_checkpoint(
            {
                'args': args,
                'epoch': epoch,
                'arch': args.arch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch, output_directory)
示例#10
0
def main():
    """Train (or evaluate) a depth-estimation model on the NYU / UOW data.

    Depending on ``args``:

    * ``args.evaluate`` -- load the saved model at that path and either run
      ``predict`` (``args.predict``) or ``validate`` on the validation set,
      then return;
    * ``args.resume`` -- load model, best result and optimizer state from the
      checkpoint at that path and continue training;
    * otherwise -- build a new model selected by ``args.arch``.

    Training runs from the start epoch up to ``args.epochs``, validating after
    every epoch, tracking the best RMSE in ``best.txt`` and saving a
    checkpoint each epoch.

    Raises:
        RuntimeError: if ``args.data`` is not ``'nyu'`` or ``'uow_dataset'``.
        FileNotFoundError: if the ``args.evaluate`` / ``args.resume`` path
            does not exist.
        Exception: if ``args.arch`` is not a supported architecture.
        ValueError: if ``args.criterion`` is not ``'l2'`` or ``'l1'``
            (only checked when training).
    """
    global args, best_result, output_directory, train_csv, test_csv
    # Fixed seed + deterministic cuDNN kernels for reproducible runs.
    torch.manual_seed(16)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Data loading code
    print("=> creating data loaders...")
    data_dir = '/media/vasp/Data2/Users/vmhp806/depth-estimation'
    valdir = os.path.join(data_dir, 'data', args.data, 'val')
    traindir = os.path.join(data_dir, 'data', args.data, 'train')

    if args.data == 'nyu' or args.data == 'uow_dataset':
        from dataloaders.nyu import NYUDataset
        val_dataset = NYUDataset(valdir, split='val', modality=args.modality)
        train_dataset = NYUDataset(traindir,
                                   split='train',
                                   modality=args.modality)
    else:
        raise RuntimeError('Dataset not found.')

    # set batch size to be 1 for validation
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             collate_fn=my_collate)
    if not args.evaluate:
        # NOTE(review): training data is not shuffled -- confirm intended.
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   collate_fn=my_collate)
    print("=> data loaders created.")

    # evaluation mode
    if args.evaluate:
        # Explicit check instead of `assert`, which is stripped under -O.
        if not os.path.isfile(args.evaluate):
            raise FileNotFoundError(
                "=> no model found at '{}'".format(args.evaluate))
        print("=> loading model '{}'".format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        if type(checkpoint) is dict:
            # full training checkpoint
            args.start_epoch = checkpoint['epoch']
            best_result = checkpoint['best_result']
            model = checkpoint['model']
            print("=> loaded best model (epoch {})".format(
                checkpoint['epoch']))
        else:
            # bare pickled model
            model = checkpoint
            args.start_epoch = 0
        output_directory = os.path.dirname(args.evaluate)
        if args.predict:
            predict(val_loader, model, output_directory)
        else:
            validate(val_loader, model, args.start_epoch, write_to_file=False)
        return
    elif args.resume:
        # optionally resume from a checkpoint
        chkpt_path = args.resume
        if not os.path.isfile(chkpt_path):
            raise FileNotFoundError(
                "=> no checkpoint found at '{}'".format(chkpt_path))
        print("=> loading checkpoint " "'{}'".format(chkpt_path))
        checkpoint = torch.load(chkpt_path)
        start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        # The lr here is a placeholder: load_state_dict below restores the
        # saved hyper-parameters (the checkpoint stores optimizer.state_dict()).
        optimizer = torch.optim.SGD(model.parameters(), lr=0.9)
        optimizer.load_state_dict(checkpoint['optimizer'])
        output_directory = os.path.dirname(os.path.abspath(chkpt_path))
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
        args.resume = True
    else:
        print("=> creating Model ({} - {}) ...".format(args.arch,
                                                       args.decoder))
        if args.arch == 'mobilenet-skipconcat':
            model = models.MobileNetSkipConcat(
                decoder=args.decoder,
                output_size=train_loader.dataset.output_size)
        elif args.arch == 'mobilenet-skipadd':
            model = models.MobileNetSkipAdd(
                decoder=args.decoder,
                output_size=train_loader.dataset.output_size)
        elif args.arch == 'resnet18-skipconcat':
            model = models.ResNetSkipConcat(
                layers=18,
                decoder=args.decoder,
                output_size=train_loader.dataset.output_size)
        elif args.arch == 'resnet18-skipadd':
            model = models.ResNetSkipAdd(
                layers=18, output_size=train_loader.dataset.output_size)
        else:
            raise Exception('Invalid architecture')
        print("=> model created.")
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        # model = torch.nn.DataParallel(model).cuda() # for multi-gpu training
        model = model.cuda()
        start_epoch = 0

    # define loss function (criterion); fail fast on an unknown choice rather
    # than hitting an unbound `criterion` inside the training loop.
    if args.criterion == 'l2':
        criterion = criteria.MaskedMSELoss().cuda()
    elif args.criterion == 'l1':
        criterion = criteria.MaskedL1Loss().cuda()
    else:
        raise ValueError("Unsupported criterion '{}'".format(args.criterion))

    # create results folder, if not already exists
    output_directory = utils.get_output_directory(args)
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    train_csv = os.path.join(output_directory, 'train.csv')
    test_csv = os.path.join(output_directory, 'test.csv')
    best_txt = os.path.join(output_directory, 'best.txt')

    # create new csv files with only header (kept as-is when resuming)
    if not args.resume:
        with open(train_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
        with open(test_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()

    for epoch in range(start_epoch, args.epochs):
        utils.adjust_learning_rate(optimizer, epoch, args.lr)
        train(train_loader, model, criterion, optimizer,
              epoch)  # train for one epoch
        result, img_merge = validate(
            val_loader, model, epoch,
            write_to_file=True)  # evaluate on validation set

        # remember best rmse and save checkpoint
        is_best = result.rmse < best_result.rmse
        if is_best:
            best_result = result
            with open(best_txt, 'w') as txtfile:
                txtfile.write(
                    "epoch={}\nmse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n"
                    .format(epoch, result.mse, result.rmse, result.absrel,
                            result.lg10, result.mae, result.delta1,
                            result.gpu_time))
            if img_merge is not None:
                img_filename = output_directory + '/comparison_best.png'
                utils.save_image(img_merge, img_filename)

        utils.save_checkpoint(
            {
                'args': args,
                'epoch': epoch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            epoch,
            output_directory)
def _fcrn_invert_depth(depth, max_depth):
    """Return the inverse-depth training target used when ``args.data_aug`` is on.

    The depth is scaled by 1000, clamped to ``[10, 1000]`` and inverted as
    ``max_depth / depth``. Shared by training and validation so both see the
    same target space.
    """
    depth = depth * 1000
    depth = torch.clamp(depth, 10, 1000)
    return max_depth / depth


def _fcrn_validate(model, val_loader, loss_fn, epoch_tag, invert_depth,
                   max_depth):
    """Evaluate ``model`` on ``val_loader`` and return the mean per-batch loss.

    For every batch, the first sample's RGB input, ground-truth depth and
    predicted depth (both depth maps divided by their max) are written to
    ``./result/{input_rgb,gt_depth,pred_depth}_epoch_{epoch_tag}.png``; the
    file names depend only on ``epoch_tag``, so later batches overwrite
    earlier ones.

    Args:
        model: network to evaluate (switched to eval mode here).
        val_loader: iterable of ``(input, depth)`` batches.
        loss_fn: criterion called as ``loss_fn(output, depth)``.
        epoch_tag: value substituted into the image file names.
        invert_depth: apply ``_fcrn_invert_depth`` to targets first.
        max_depth: passed to ``_fcrn_invert_depth``.

    Returns:
        float: summed loss divided by the number of batches (raises
        ZeroDivisionError if ``val_loader`` yields nothing).
    """
    model.eval()
    num_batches = 0
    loss_local = 0
    with torch.no_grad():
        for input, depth in val_loader:
            if invert_depth:
                depth = _fcrn_invert_depth(depth, max_depth)

            input_var = Variable(input.type(dtype))
            depth_var = Variable(depth.type(dtype))

            output = model(input_var)

            input_rgb_image = input_var[0].data.permute(
                1, 2, 0).cpu().numpy().astype(np.uint8)
            input_gt_depth_image = depth_var[0][0].data.cpu().numpy().astype(
                np.float32)
            pred_depth_image = output[0].data.squeeze().cpu().numpy().astype(
                np.float32)

            # normalization for visualisation only
            input_gt_depth_image /= np.max(input_gt_depth_image)
            pred_depth_image /= np.max(pred_depth_image)

            plot.imsave('./result/input_rgb_epoch_{}.png'.format(epoch_tag),
                        input_rgb_image)
            plot.imsave('./result/gt_depth_epoch_{}.png'.format(epoch_tag),
                        input_gt_depth_image,
                        cmap="viridis")
            plot.imsave('./result/pred_depth_epoch_{}.png'.format(epoch_tag),
                        pred_depth_image,
                        cmap="viridis")

            loss_local += loss_fn(output, depth_var)
            num_batches += 1

    return float(loss_local) / num_batches


def main():
    """Train an FCRN depth model on ``nyu_depth_v2_labeled.mat``.

    Hyper-parameters come from the module-level ``args`` (``batch_size``,
    ``lr``, ``epochs``, ``step_size``, ``step_gamma``, ``data_aug``,
    ``loss_type``). The model is validated once before training and after
    every epoch. The best model (lowest validation error) is written to
    ``checkpoint.pth.tar``, and every 10th epoch a ``<loss_type>.pth``
    snapshot is saved.

    Raises:
        ValueError: if ``args.loss_type`` is not one of ``berhu``, ``L1``,
            ``mse``, ``ssim``, ``three``.
    """
    batch_size = args.batch_size
    data_path = 'nyu_depth_v2_labeled.mat'
    learning_rate = args.lr  # 1.0e-4 / 1.0e-5 were also tried
    num_epochs = args.epochs
    step_size = args.step_size
    step_gamma = args.step_gamma
    resume_from_file = False  # set True to resume from checkpoint.pth.tar
    isDataAug = args.data_aug
    max_depth = 1000

    # 1. Load data
    train_lists, val_lists, test_lists = load_split()
    print("Loading data......")
    train_loader = torch.utils.data.DataLoader(NyuDepthLoader(
        data_path, train_lists),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(NyuDepthLoader(
        data_path, val_lists),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             drop_last=True)
    # NOTE(review): test_loader is built but not used below.
    test_loader = torch.utils.data.DataLoader(NyuDepthLoader(
        data_path, test_lists),
                                              batch_size=batch_size,
                                              shuffle=False,
                                              drop_last=True)
    print(train_loader)

    # 2. Set the model
    print("Set the model......")
    model = FCRN(batch_size)
    model = model.cuda()

    # 3. Loss -- fail fast on an unknown choice instead of an unbound loss_fn
    if args.loss_type == "berhu":
        loss_fn = criteria.berHuLoss().cuda()
        print("berhu loss_fn set.")
    elif args.loss_type == "L1":
        loss_fn = criteria.MaskedL1Loss().cuda()
        print("L1 loss_fn set.")
    elif args.loss_type == "mse":
        loss_fn = criteria.MaskedMSELoss().cuda()
        print("MSE loss_fn set.")
    elif args.loss_type == "ssim":
        loss_fn = criteria.SsimLoss().cuda()
        print("Ssim loss_fn set.")
    elif args.loss_type == "three":
        loss_fn = criteria.Ssim_grad_L1().cuda()
        print("SSIM+L1+Grad loss_fn set.")
    else:
        raise ValueError("Unsupported loss_type '{}'".format(args.loss_type))

    # 4. Validate once before training, in the same target space as training
    best_val_err = 1.0e3
    err = _fcrn_validate(model, val_loader, loss_fn, 0, isDataAug, max_depth)
    print('val_error before train:', err)

    start_epoch = 0
    resume_checkpoint = None
    resume_file = 'checkpoint.pth.tar'
    if resume_from_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            resume_checkpoint = torch.load(resume_file)
            start_epoch = resume_checkpoint['epoch']
            model.load_state_dict(resume_checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_file, resume_checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(resume_file))

    # 5. Optimizer / scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    if resume_checkpoint is not None and 'optimizer' in resume_checkpoint:
        # the checkpoint written below stores the optimizer state; restore it
        optimizer.load_state_dict(resume_checkpoint['optimizer'])
    scheduler = StepLR(optimizer, step_size=step_size,
                       gamma=step_gamma)  # may change to other value
    print("optimizer set.")

    # 6. Train
    for epoch in range(num_epochs):
        epoch_tag = start_epoch + epoch + 1
        print('Starting train epoch %d / %d' % (epoch_tag, num_epochs))
        model.train()
        running_loss = 0
        count = 0

        for input, depth in train_loader:
            if isDataAug:
                depth = _fcrn_invert_depth(depth, max_depth)

            input_var = Variable(input.type(dtype))
            depth_var = Variable(depth.type(dtype))

            output = model(input_var)

            loss = loss_fn(output, depth_var)
            print('loss:', loss.data.cpu())
            count += 1
            running_loss += loss.data.cpu().numpy()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        err = _fcrn_validate(model, val_loader, loss_fn, epoch_tag, isDataAug,
                             max_depth)
        print('val_error:', err)

        # Periodic snapshot: once per 10 epochs (not once per val batch).
        if epoch % 10 == 9:
            PATH = args.loss_type + '.pth'
            torch.save(model.state_dict(), PATH)

        if err < best_val_err:
            best_val_err = err
            torch.save(
                {
                    'epoch': epoch_tag,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, 'checkpoint.pth.tar')

        scheduler.step()
示例#12
0
文件: main.py 项目: LeonSun0101/CD-SD
def _seed_numpy_for_worker(worker_id):
    """Seed NumPy's global RNG inside a DataLoader worker process.

    Inside a worker ``torch.initial_seed()`` equals the loader's base seed
    (drawn afresh each time an iterator is created, i.e. each epoch) plus
    ``worker_id``. Every worker therefore gets a distinct NumPy stream, and
    so does every epoch. Seeding with ``worker_id`` alone would replay
    identical NumPy-based random sampling every epoch. A module-level
    function, unlike a lambda, can also be pickled for spawn-based workers.
    ``worker_id`` is accepted because DataLoader passes it.
    """
    np.random.seed(torch.initial_seed() % 2**32)


def main():
    """Train (or evaluate) a ResNet depth model on NYU Depth v2 or KITTI.

    The loss is chosen by ``args.criterion``. An optional sparse-depth
    sampler is chosen by ``args.sparsifier``. The run then either evaluates
    ``model_best.pth.tar`` from the output directory (``args.evaluate``),
    resumes from ``args.resume``, or builds a new ResNet-18/50. Training
    runs from ``args.start_epoch`` to ``args.epochs``, validating every
    epoch and tracking the best RMSE in ``best.txt``.

    Raises:
        ValueError: unsupported ``args.criterion`` or ``args.arch``.
        RuntimeError: unsupported ``args.data``.
        FileNotFoundError: missing best-model / checkpoint file.
    """
    global args, best_result, output_directory, train_csv, test_csv

    # create results folder, if not already exists
    output_directory = utils.get_output_directory(args)
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    # paths of the CSV logs and the best-result summary
    train_csv = os.path.join(output_directory, 'train.csv')
    test_csv = os.path.join(output_directory, 'test.csv')
    best_txt = os.path.join(output_directory, 'best.txt')

    # define loss function (criterion); fail fast on an unknown choice
    if args.criterion == 'l2':
        # masked mean squared error
        criterion = criteria.MaskedMSELoss().cuda()
    elif args.criterion == 'l1':
        criterion = criteria.MaskedL1Loss().cuda()
    else:
        raise ValueError("Unsupported criterion '{}'".format(args.criterion))

    # sparsifier is a class for generating random sparse depth input from the ground truth
    sparsifier = None
    max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
    if args.sparsifier == UniformSampling.name:
        sparsifier = UniformSampling(num_samples=args.num_samples,
                                     max_depth=max_depth)
    elif args.sparsifier == SimulatedStereo.name:
        sparsifier = SimulatedStereo(num_samples=args.num_samples,
                                     max_depth=max_depth)

    # Data loading code
    print("=> creating data loaders ...")
    traindir = os.path.join('data', args.data, 'train')
    valdir = os.path.join('data', args.data, 'val')

    if args.data == 'nyudepthv2':
        # imported lazily, only when this dataset is selected
        from dataloaders.nyu_dataloader import NYUDataset
        train_dataset = NYUDataset(traindir,
                                   type='train',
                                   modality=args.modality,
                                   sparsifier=sparsifier)
        val_dataset = NYUDataset(valdir,
                                 type='val',
                                 modality=args.modality,
                                 sparsifier=sparsifier)

    elif args.data == 'kitti':
        from dataloaders.kitti_dataloader import KITTIDataset
        train_dataset = KITTIDataset(traindir,
                                     type='train',
                                     modality=args.modality,
                                     sparsifier=sparsifier)
        val_dataset = KITTIDataset(valdir,
                                   type='val',
                                   modality=args.modality,
                                   sparsifier=sparsifier)

    else:
        raise RuntimeError(
            'Dataset not found.' +
            'The dataset must be either of nyudepthv2 or kitti.')

    # worker_init_fn gives each loading worker (and each epoch) its own
    # NumPy seed, so random sampling patterns differ between threads.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        sampler=None,
        worker_init_fn=_seed_numpy_for_worker)

    # set batch size to be 1 for validation
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    print("=> data loaders created.")

    # evaluation mode: test the best saved model
    if args.evaluate:
        best_model_filename = os.path.join(output_directory,
                                           'model_best.pth.tar')
        # Explicit check instead of `assert`, which is stripped under -O.
        if not os.path.isfile(best_model_filename):
            raise FileNotFoundError(
                "=> no best model found at '{}'".format(best_model_filename))
        print("=> loading best model '{}'".format(best_model_filename))
        checkpoint = torch.load(best_model_filename)
        args.start_epoch = checkpoint['epoch']
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
        validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
        return

    # optionally resume from a checkpoint
    elif args.resume:
        if not os.path.isfile(args.resume):
            raise FileNotFoundError(
                "=> no checkpoint found at '{}'".format(args.resume))
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        # the checkpoint stores the whole optimizer object (see save below)
        optimizer = checkpoint['optimizer']
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

    # create a new model and train it
    else:
        # define model
        print("=> creating Model ({}-{}) ...".format(args.arch, args.decoder))
        # one input channel per modality character, e.g. 4 for 'rgbd'
        in_channels = len(args.modality)
        # only ResNet-50 and ResNet-18 backbones are supported
        if args.arch == 'resnet50':
            model = ResNet(layers=50,
                           decoder=args.decoder,
                           output_size=train_dataset.output_size,
                           in_channels=in_channels,
                           pretrained=args.pretrained)
        elif args.arch == 'resnet18':
            model = ResNet(layers=18,
                           decoder=args.decoder,
                           output_size=train_dataset.output_size,
                           in_channels=in_channels,
                           pretrained=args.pretrained)
        else:
            raise ValueError("Unsupported architecture '{}'".format(args.arch))
        print("=> model created.")

        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        # create new csv files with only header
        with open(train_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
        with open(test_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()

    # model = torch.nn.DataParallel(model).cuda() # for multi-gpu training
    model = model.cuda()
    print("=> model transferred to GPU.")

    for epoch in range(args.start_epoch, args.epochs):
        utils.adjust_learning_rate(optimizer, epoch, args.lr)
        train(train_loader, model, criterion, optimizer,
              epoch)  # train for one epoch
        # validate after every training epoch
        result, img_merge = validate(val_loader, model, epoch)

        # remember best rmse and save checkpoint
        is_best = result.rmse < best_result.rmse
        if is_best:
            best_result = result
            with open(best_txt, 'w') as txtfile:
                txtfile.write(
                    "epoch={}\nmse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n"
                    .format(epoch, result.mse, result.rmse, result.absrel,
                            result.lg10, result.mae, result.delta1,
                            result.gpu_time))
            if img_merge is not None:
                img_filename = output_directory + '/comparison_best.png'
                utils.save_image(img_merge, img_filename)

        utils.save_checkpoint(
            {
                'args': args,
                'epoch': epoch,
                'arch': args.arch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch, output_directory)
示例#13
0
def _head_cls_loss(c1, c2, c3, idx, label_j, hot):
    """Return the summed ``cross_entropy_loss1d`` of the three classifier heads.

    Sample ``idx`` of head ``hot`` (0, 1 or 2) is given target 1 and the other
    heads target 0; ``hot=None`` makes all three targets 0. Targets are built
    with ``ones_like``/``zeros_like`` on ``label_j`` so they share its device,
    then cast to float.
    """
    targets = [(torch.ones_like(label_j) if k == hot else
                torch.zeros_like(label_j)).float() for k in range(3)]
    return (cross_entropy_loss1d(c1[idx, 0], targets[0]) +
            cross_entropy_loss1d(c2[idx, 0], targets[1]) +
            cross_entropy_loss1d(c3[idx, 0], targets[2]))


def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch of the three-head depth model.

    Each batch yields ``(input, target, label, mask)``. The model returns
    ``pred`` (indexed by channel 0..2), ``pred_mask`` and three classifier
    outputs ``c1, c2, c3``. For every sample:

    * ``label == -1``: ``criterion`` is averaged over all three ``pred``
      channels, and every classifier head is pushed towards 0;
    * ``label == 0`` / ``1`` / anything else: ``criterion`` on channel
      0 / 1 / 2, plus a ``cross_entropy_loss2d_2`` mask loss, and the matching
      classifier head is pushed towards 1 (the others towards 0).

    The final loss is ``mean depth loss + 1e-5 * mean mask loss +
    0.01 * mean classifier loss``. The mask term is 0 when no sample in the
    batch has ``label != -1``. Progress is printed every ``args.print_freq``
    batches.
    """
    average_meter = AverageMeter()
    model.train()  # switch to train mode
    end = time.time()

    for i, (input, target, label, mask) in enumerate(train_loader):
        input, target = input.cuda(), target.cuda()
        label = label.cuda()
        mask = mask.cuda()
        torch.cuda.synchronize()
        data_time = time.time() - end

        # compute pred
        end = time.time()

        pred, pred_mask, c1, c2, c3 = model(input)
        target = torch.squeeze(target, 1)

        loss = 0.0   # depth loss, summed over samples
        lossM = 0.0  # mask loss, summed over samples with label != -1
        lossC = 0.0  # classifier-head loss, summed over samples
        count = 0    # samples processed (also the current sample index)
        countM = 0   # samples that contributed to lossM

        for j in range(len(label)):
            if label[j] == -1:
                loss += (criterion(pred[count, 0, :, :], target[count, :, :]) +
                         criterion(pred[count, 1, :, :], target[count, :, :]) +
                         criterion(pred[count, 2, :, :],
                                   target[count, :, :])) * 1.0 / 3
                lossC += _head_cls_loss(c1, c2, c3, count, label[j], None)
            else:
                # label 1 -> channel/head 1, label 0 -> 0, anything else -> 2
                if label[j] == 1:
                    channel = 1
                elif label[j] == 0:
                    channel = 0
                else:
                    channel = 2
                loss += criterion(pred[count, channel, :, :],
                                  target[count, :, :])
                lossM += cross_entropy_loss2d_2(pred_mask[count, :, :, :],
                                                mask[count, :, :])
                lossC += _head_cls_loss(c1, c2, c3, count, label[j], channel)
                countM += 1
            count += 1

        # A batch where every label is -1 has no mask term; dividing by
        # countM == 0 used to raise ZeroDivisionError (and a float lossm
        # would break `.item()` below), so use a zero tensor instead.
        if countM > 0:
            lossm = 0.00001 * lossM / countM
        else:
            lossm = loss.new_zeros(())
        lossC = 0.01 * lossC / count
        loss = loss * 1.0 / count + lossm + lossC

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        torch.cuda.synchronize()
        gpu_time = time.time() - end

        # measure accuracy and record loss
        result = Result()
        result.evaluate(pred, target, label)
        average_meter.update(result, gpu_time, data_time, input.size(0))
        end = time.time()

        if (i + 1) % args.print_freq == 0:
            print('=> output: {}'.format(output_directory))
            print('Train Epoch: {0} [{1}/{2}]\t'
                  't_Data={data_time:.3f}({average.data_time:.3f}) '
                  't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                  'Loss={Loss:.5f} '
                  'LossM={LossM:.5f} '
                  'LossC={LossC:.5f} '
                  'MAE={result.mae:.2f}({average.mae:.2f}) '.format(
                      epoch,
                      i + 1,
                      len(train_loader),
                      data_time=data_time,
                      gpu_time=gpu_time,
                      Loss=loss.item(),
                      LossM=lossm.item(),
                      LossC=lossC.item(),
                      result=result,
                      average=average_meter.average()))
示例#14
0
def main():
    """Entry point: parse CLI args, build NYU loaders, then evaluate/resume/train.

    Sets the module globals ``args``, ``best_result``, ``output_directory``,
    ``train_csv`` and ``test_csv``; other functions in this file presumably
    read them (verify against callers).

    The mode is selected from ``args``:

    * ``args.evaluate`` -- load ``<output_directory>/model_best.pth.tar``, run
      one validation pass with ``write_to_file=False`` and return.
    * ``args.resume`` -- restore model, optimizer, epoch and best result from
      the given checkpoint file and continue training.
    * otherwise -- build a new ResNet plus SGD optimizer and write fresh CSV
      headers.

    Returns ``None`` early (after printing a message) when the requested
    checkpoint file does not exist, instead of failing later with a
    ``NameError`` on the never-assigned ``model``.

    Raises:
        ValueError: if a new model is requested for an unsupported
            ``args.arch``.
    """
    global args, best_result, output_directory, train_csv, test_csv
    args = parser.parse_args()
    args.data = os.path.join('data', args.data)

    # With RGB-only input there is no sparse-depth input, so the number of
    # sparse depth samples is forced to 0.
    if args.modality == 'rgb' and args.num_samples != 0:
        print("number of samples is forced to be 0 when input modality is rgb")
        args.num_samples = 0

    # create results folder, if not already exists
    output_directory = os.path.join(
        'results',
        'NYUDataset.modality={}.nsample={}.arch={}.decoder={}.criterion={}.lr={}.bs={}'
        .format(args.modality, args.num_samples, args.arch, args.decoder,
                args.criterion, args.lr, args.batch_size))
    os.makedirs(output_directory, exist_ok=True)

    train_csv = os.path.join(output_directory, 'train.csv')
    test_csv = os.path.join(output_directory, 'test.csv')
    best_txt = os.path.join(output_directory, 'best.txt')

    # define loss function (criterion); the network predicts a single channel
    if args.criterion == 'l2':
        criterion = criteria.MaskedMSELoss().cuda()
    elif args.criterion == 'l1':
        criterion = criteria.MaskedL1Loss().cuda()
    out_channels = 1

    # Data loading code
    print("=> creating data loaders ...")
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    train_dataset = NYUDataset(traindir,
                               type='train',
                               modality=args.modality,
                               num_samples=args.num_samples)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None)

    # set batch size to be 1 for validation
    val_dataset = NYUDataset(valdir,
                             type='val',
                             modality=args.modality,
                             num_samples=args.num_samples)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    print("=> data loaders created.")

    # evaluation mode
    if args.evaluate:
        best_model_filename = os.path.join(output_directory,
                                           'model_best.pth.tar')
        if not os.path.isfile(best_model_filename):
            # Without a checkpoint there is no model to validate.
            print("=> no best model found at '{}'".format(best_model_filename))
            return
        print("=> loading best model '{}'".format(best_model_filename))
        checkpoint = torch.load(best_model_filename)
        args.start_epoch = checkpoint['epoch']
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
        validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
        return

    # optionally resume from a checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            # Without a checkpoint there is no model/optimizer to train.
            print("=> no checkpoint found at '{}'".format(args.resume))
            return
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

    # create new model
    else:
        print("=> creating Model ({}-{}) ...".format(args.arch, args.decoder))
        # number of characters in the modality string, e.g. 'rgb' -> 3
        in_channels = len(args.modality)
        if args.arch == 'resnet50':
            n_layers = 50
        elif args.arch == 'resnet18':
            n_layers = 18
        else:
            raise ValueError("unsupported arch: {}".format(args.arch))
        model = ResNet(layers=n_layers,
                       decoder=args.decoder,
                       in_channels=in_channels,
                       out_channels=out_channels,
                       pretrained=args.pretrained)
        print("=> model created.")

        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        # create new csv files with only header
        with open(train_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
        with open(test_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()

    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()
    print(model)
    print("=> model transferred to GPU.")

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        result, img_merge = validate(val_loader, model, epoch)

        # remember best rmse and save checkpoint
        is_best = result.rmse < best_result.rmse
        if is_best:
            best_result = result
            with open(best_txt, 'w') as txtfile:
                txtfile.write(
                    "epoch={}\nmse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n"
                    .format(epoch, result.mse, result.rmse, result.absrel,
                            result.lg10, result.mae, result.delta1,
                            result.gpu_time))
            # validate() may return no comparison image
            if img_merge is not None:
                img_filename = output_directory + '/comparison_best.png'
                utils.save_image(img_merge, img_filename)

        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.arch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch)
示例#15
0
def _write_per_sample_metrics(output_directory, results):
    """Dump per-sample validation metrics, one CSV file per metric and region.

    For every metric in ``rmse, delta1, delta2, delta3, mae, absrel`` and every
    region suffix (``""``, ``"_inside"``, ``"_outside"``) this writes
    ``<output_directory>/validation_<metric>s<suffix>.csv``. Each file holds a
    single fully-quoted row with that metric's value for every entry of
    ``results``, in order.

    ``results`` must be re-iterable (it is traversed once per file). Each entry
    must expose ``result``, ``result_inside`` and ``result_outside`` objects
    carrying the metric attributes, as the evaluation-mode ``validate`` call
    in ``main`` returns.
    """
    metrics = ("rmse", "delta1", "delta2", "delta3", "mae", "absrel")
    regions = (("", "result"), ("_inside", "result_inside"),
               ("_outside", "result_outside"))
    for suffix, region_attr in regions:
        for metric in metrics:
            values = [getattr(getattr(res, region_attr), metric)
                      for res in results]
            csv_path = os.path.join(output_directory,
                                    f"validation_{metric}s{suffix}.csv")
            # newline="" is what the csv module expects for files it writes.
            with open(csv_path, "w", newline="") as csv_file:
                csv.writer(csv_file, quoting=csv.QUOTE_ALL).writerow(values)


def main() -> int:
    """Train, resume, transfer-learn or evaluate a depth-prediction ResNet.

    Configuration comes from ``parser`` (command line) and the ``DATASET_DIR``
    environment variable, which is prepended to ``args.data``.

    The mode is selected from ``args``:

    * ``args.evaluate`` -- load ``<output_directory>/model_best.pth.tar``,
      validate it, and write ``best.txt``, per-sample metric CSVs and
      ``best.png``.
    * ``args.resume`` -- continue training from the given checkpoint file.
    * ``args.transfer_from`` -- start a new run from a checkpoint's model,
      optionally optimizing only ``conv3`` and ``decoder.layer4``
      (``args.train_top_only``).
    * otherwise -- build a new ResNet and train it from scratch.

    Training runs until ``args.epochs``, or stops early once validation RMSE
    has not improved for more than ``args.early_stop_epochs`` epochs. With an
    SGD optimizer, the LR is reduced on validation-RMSE plateaus.

    Returns:
        Process exit code: 0 on success, 1 when a requested checkpoint file
        does not exist.

    Raises:
        Exception: if the output directory already exists for a fresh
            (non-evaluate, non-resume) run, or ``args.optimizer`` is unknown.
        ValueError: if a new model is requested for an unsupported
            ``args.arch``.
    """
    best_result = Result()
    best_result.set_to_worst()
    args: Any
    args = parser.parse_args()
    dataset = args.data
    if args.modality == 'rgb' and args.num_samples != 0:
        print("number of samples is forced to be 0 when input modality is rgb")
        args.num_samples = 0
    image_shape = (192, 256)  # if "my" in args.arch else (228, 304)

    # create results folder, if not already exists
    if args.transfer_from:
        output_directory = f"{args.transfer_from}_transfer"
    else:
        output_directory = utils.get_output_dir(args)
    args.data = os.path.join(os.environ["DATASET_DIR"], args.data)
    print("output directory :", output_directory)
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    elif not (args.evaluate or args.resume):
        # Evaluating and resuming both reuse an existing run directory (resume
        # does not rewrite the CSV headers); only a fresh run must not
        # clobber one.
        raise Exception("output directory already exists")

    train_csv = os.path.join(output_directory, 'train.csv')
    test_csv = os.path.join(output_directory, 'test.csv')
    best_txt = os.path.join(output_directory, 'best.txt')

    # define loss function (criterion)
    if args.criterion == 'l2':
        criterion = criteria.MaskedMSELoss().cuda()
    elif args.criterion == 'l1':
        criterion = criteria.MaskedL1Loss().cuda()
    out_channels = 1

    # Data loading code
    print("=> creating data loaders ...")
    traindir = os.path.join(args.data, 'train')
    # SUNRGBD has no separate val split here, so validation reuses 'train'.
    valdir = traindir if dataset == "SUNRGBD" else os.path.join(
        args.data, 'val')
    DatasetType = choose_dataset_type(dataset)
    train_dataset = DatasetType(traindir,
                                phase='train',
                                modality=args.modality,
                                num_samples=args.num_samples,
                                square_width=args.square_width,
                                output_shape=image_shape,
                                depth_type=args.depth_type)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None)

    print("=> training examples:", len(train_dataset))

    val_dataset = DatasetType(valdir,
                              phase='val',
                              modality=args.modality,
                              num_samples=args.num_samples,
                              square_width=args.square_width,
                              output_shape=image_shape,
                              depth_type=args.depth_type)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    print("=> validation examples:", len(val_dataset))

    print("=> data loaders created.")

    # evaluation mode
    if args.evaluate:
        best_model_filename = os.path.join(output_directory,
                                           'model_best.pth.tar')
        if not os.path.isfile(best_model_filename):
            # Without a checkpoint there is no model to evaluate.
            print("=> no best model found at '{}'".format(best_model_filename))
            return 1
        print("=> loading best model '{}'".format(best_model_filename))
        checkpoint = torch.load(best_model_filename)
        args.start_epoch = checkpoint['epoch']
        model = checkpoint['model']
        print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
        avg_result, avg_result_inside, avg_result_outside, _, results, evaluator = validate(
            val_loader,
            args.square_width,
            args.modality,
            output_directory,
            args.print_freq,
            test_csv,
            model,
            checkpoint['epoch'],
            write_to_file=False)
        write_results(best_txt, avg_result, avg_result_inside,
                      avg_result_outside, checkpoint['epoch'])
        _write_per_sample_metrics(output_directory, results)
        evaluator.save_plot(os.path.join(output_directory, "best.png"))
        return 0

    # optionally resume from a checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            print("=> no checkpoint found at '{}'".format(args.resume))
            return 1
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    # create new model
    else:
        if args.transfer_from:
            if not os.path.isfile(args.transfer_from):
                print(f"=> no checkpoint found at '{args.transfer_from}'")
                return 1
            print(f"=> loading checkpoint '{args.transfer_from}'")
            checkpoint = torch.load(args.transfer_from)
            args.start_epoch = 0
            model = checkpoint['model']
            print("=> loaded checkpoint")
            if args.train_top_only:
                # Only these two modules' parameters are handed to the
                # optimizer, so only they get updated.
                train_params = (list(model.conv3.parameters()) +
                                list(model.decoder.layer4.parameters()))
            else:
                train_params = model.parameters()
        else:
            # define model
            print("=> creating Model ({}-{}) ...".format(
                args.arch, args.decoder))
            in_channels = len(args.modality)
            if args.arch == 'resnet50':
                n_layers = 50
            elif args.arch == 'resnet18':
                n_layers = 18
            else:
                raise ValueError(f"unsupported arch: {args.arch}")
            model = ResNet(layers=n_layers,
                           decoder=args.decoder,
                           in_channels=in_channels,
                           out_channels=out_channels,
                           pretrained=args.pretrained,
                           image_shape=image_shape,
                           skip_type=args.skip_type)
            print("=> model created.")
            train_params = model.parameters()

        if args.optimizer == "sgd":
            optimizer = torch.optim.SGD(train_params,
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        elif args.optimizer == "adam":
            # NOTE(review): args.lr is not passed, so Adam runs at PyTorch's
            # default learning rate (1e-3) -- confirm this is intended.
            optimizer = torch.optim.Adam(train_params,
                                         weight_decay=args.weight_decay)
        else:
            raise Exception("We should never be here")

        # create new csv files with only header
        with open(train_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
        with open(test_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()

    # Reduce the LR on validation-RMSE plateaus, but only for SGD. Deciding
    # from the optimizer object (not args.optimizer) also covers resumed runs,
    # whose optimizer is restored from the checkpoint.
    scheduler = None
    if isinstance(optimizer, torch.optim.SGD):
        print("=> Learning rate adjustment enabled.")
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=args.adjust_lr_ep, verbose=True)

    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()
    print(model)
    print("=> model transferred to GPU.")
    epochs_since_best = 0
    train_results = []
    val_results = []
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        res_train, res_train_inside, res_train_outside = train(
            train_loader, model, criterion, optimizer, epoch, args.print_freq,
            train_csv)
        train_results.append((res_train, res_train_inside, res_train_outside))
        # evaluate on validation set
        res_val, res_val_inside, res_val_outside, img_merge, _, _ = validate(
            val_loader, args.square_width, args.modality, output_directory,
            args.print_freq, test_csv, model, epoch, True)
        val_results.append((res_val, res_val_inside, res_val_outside))
        # remember best rmse and save checkpoint
        is_best = res_val.rmse < best_result.rmse
        if is_best:
            epochs_since_best = 0
            best_result = res_val
            write_results(best_txt, res_val, res_val_inside, res_val_outside,
                          epoch)
            if img_merge is not None:
                img_filename = output_directory + '/comparison_best.png'
                utils.save_image(img_merge, img_filename)
        else:
            epochs_since_best += 1

        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.arch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch, output_directory)

        plot_progress(train_results, val_results, epoch, output_directory)

        if epochs_since_best > args.early_stop_epochs:
            print("early stopping")
            break
        if scheduler is not None:
            scheduler.step(res_val.rmse)
    return 0
示例#16
0
def main():
    """Entry point: parse args, build sparsified NYU loaders, then run.

    Sets the module globals ``args``, ``best_result``, ``output_directory``,
    ``train_csv`` and ``test_csv``; other functions in this file presumably
    read them (verify against callers).

    Sparse depth input comes from ``UniformSampling`` or ``SimulatedStereo``
    according to ``args.sparsifier`` (``None`` for any other value). Data is
    read from, and results are written under, the hard-coded storage root
    ``/media/kuowei/8EB89C8DB89C7585``.

    The mode is selected from ``args``:

    * ``args.evaluate`` -- validate ``model_best.pth.tar`` once and return.
    * ``args.resume`` -- continue training from a checkpoint.
    * otherwise -- train a new ResNet-18/50/152 from scratch.

    Returns ``None`` early (after printing a message) when a requested
    checkpoint file does not exist.

    Raises:
        ValueError: if a new model is requested for an unsupported
            ``args.arch``.
    """
    global args, best_result, output_directory, train_csv, test_csv
    args = parser.parse_args()
    if args.modality == 'rgb' and args.num_samples != 0:
        print("number of samples is forced to be 0 when input modality is rgb")
        args.num_samples = 0
    if args.modality == 'rgb' and args.max_depth != 0.0:
        print("max depth is forced to be 0.0 when input modality is rgb/rgbd")
        args.max_depth = 0.0

    # A negative max depth means "no limit".
    max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
    sparsifier = None
    if args.sparsifier == UniformSampling.name:
        sparsifier = UniformSampling(num_samples=args.num_samples,
                                     max_depth=max_depth)
    elif args.sparsifier == SimulatedStereo.name:
        sparsifier = SimulatedStereo(num_samples=args.num_samples,
                                     max_depth=max_depth)

    # Machine-specific storage root shared by the dataset and results trees.
    storage_root = '/media/kuowei/8EB89C8DB89C7585'

    # create results folder, if not already exists
    output_directory = os.path.join(
        storage_root, 'results_CS',
        '{}'.format(args.outputdir),
        '{}.sparsifier={}.modality={}.arch={}.decoder={}.criterion={}.lr={}.bs={}'
        .format(args.data, sparsifier, args.modality, args.arch, args.decoder,
                args.criterion, args.lr, args.batch_size))
    os.makedirs(output_directory, exist_ok=True)
    train_csv = os.path.join(output_directory, 'train.csv')
    test_csv = os.path.join(output_directory, 'test.csv')
    best_txt = os.path.join(output_directory, 'best.txt')

    # define loss function (criterion)
    if args.criterion == 'l2':
        criterion = criteria.MaskedMSELoss().cuda()
    elif args.criterion == 'l1':
        criterion = criteria.MaskedL1Loss().cuda()
    out_channels = 1

    # Data loading code
    print("=> creating data loaders ...")
    traindir = os.path.join(storage_root, 'data', args.data, 'train')
    valdir = os.path.join(storage_root, 'data', args.data, 'val')

    train_dataset = NYUDataset(traindir,
                               type='train',
                               modality=args.modality,
                               sparsifier=sparsifier)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None)

    # set batch size to be 1 for validation
    val_dataset = NYUDataset(valdir,
                             type='val',
                             modality=args.modality,
                             sparsifier=sparsifier)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    print("=> data loaders created.")

    # evaluation mode
    if args.evaluate:
        best_model_filename = os.path.join(output_directory,
                                           'model_best.pth.tar')
        if not os.path.isfile(best_model_filename):
            # Without a checkpoint there is no model to validate.
            print("=> no best model found at '{}'".format(best_model_filename))
            return
        print("=> loading best model '{}'".format(best_model_filename))
        checkpoint = torch.load(best_model_filename)
        args.start_epoch = checkpoint['epoch']
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
        validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
        return

    # optionally resume from a checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            print("=> no checkpoint found at '{}'".format(args.resume))
            return
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

    # create new model
    else:
        print("=> creating Model ({}-{}) ...".format(args.arch, args.decoder))
        in_channels = len(args.modality)
        layers_by_arch = {'resnet18': 18, 'resnet50': 50, 'resnet152': 152}
        if args.arch not in layers_by_arch:
            raise ValueError("unsupported arch: {}".format(args.arch))
        model = ResNet(layers=layers_by_arch[args.arch],
                       decoder=args.decoder,
                       in_channels=in_channels,
                       out_channels=out_channels,
                       pretrained=args.pretrained)
        print("=> model created.")

        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        # create new csv files with only header
        with open(train_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
        with open(test_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()

    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()
    print(model)
    print("=> model transferred to GPU.")

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        result, img_merge = validate(val_loader, model, epoch)

        # remember best rmse and save checkpoint
        is_best = result.rmse < best_result.rmse
        if is_best:
            best_result = result
            with open(best_txt, 'w') as txtfile:
                txtfile.write(
                    "epoch={}\nmse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n"
                    .format(epoch, result.mse, result.rmse, result.absrel,
                            result.lg10, result.mae, result.delta1,
                            result.gpu_time))
            if img_merge is not None:
                img_filename = output_directory + '/comparison_best.png'
                utils.save_image(img_merge, img_filename)

        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.arch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch)
示例#17
0
文件: main.py 项目: JennyGao00/depth
def main():
    """Train a UNet depth model on NYU Depth v2, tracking the best RMSE.

    Reads the module global ``opt`` (parsed options) and reassigns the module
    global ``best_result``. Each epoch it trains, validates, rewrites
    ``best.txt`` when validation RMSE improves, saves a checkpoint, and steps
    a ``ReduceLROnPlateau`` scheduler on the validation ``absrel``.

    Files under ``utils.utils.get_output_dir(opt.output_dir)``:

    * ``config.txt`` -- ``key:value`` dump of ``opt``, written only if absent;
    * ``best.txt`` -- metrics of the best epoch so far;
    * checkpoints written by ``utils.utils.save_checkpoint``;
    * ``logs/<%b%d_%H-%M-%S>_<hostname>`` -- a freshly created directory for
      the ``SummaryWriter``.
    """
    global opt, best_result
    trainData = nyudata.getTrainingData_NYUDV2(opt.batch_size, opt.train_phase, opt.data_root)
    valData = nyudata.getTestingData_NYUDV2(opt.batch_size, opt.val_phase, opt.data_root)
    print("load data finished!")

    print('create the model')
    # Move the model to the GPU *before* building the optimizer, as the
    # PyTorch optim docs recommend, so the optimizer is built over the CUDA
    # parameters.
    model = UNet().cuda()
    optimizer = utils.utils.build_optimizer(model=model,
                                            learning_rate=opt.lr,
                                            optimizer_name=opt.optimizer_name,
                                            weight_decay=opt.weight_decay,
                                            epsilon=opt.epsilon,
                                            momentum=opt.momentum)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=opt.lr_patience)
    # training loss
    crite = criteria.MaskedL1Loss()

    # create directory path
    output_directory = utils.utils.get_output_dir(opt.output_dir)
    os.makedirs(output_directory, exist_ok=True)
    best_txt = os.path.join(output_directory, 'best.txt')
    config_txt = os.path.join(output_directory, 'config.txt')

    # write training parameters to config file (never overwrites an existing one)
    if not os.path.exists(config_txt):
        with open(config_txt, 'w') as txtfile:
            txtfile.write(''.join('{!s}:{!s},\t\n'.format(k, v)
                                  for k, v in vars(opt).items()))

    # create log; a same-second rerun on this host replaces the old log dir
    log_path = os.path.join(output_directory, 'logs',
                            datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    if os.path.isdir(log_path):
        shutil.rmtree(log_path)
    os.makedirs(log_path)
    logger = SummaryWriter(log_path)

    # Close the writer even if training is interrupted or crashes, so pending
    # events are written out.
    try:
        for epoch in range(opt.epoch):
            train(epoch, trainData, model, crite, optimizer, logger)
            result = validate(epoch, valData, model, logger)

            # remember best rmse and save checkpoint
            is_best = result.rmse < best_result.rmse
            if is_best:
                best_result = result
                with open(best_txt, 'w') as txtfile:
                    txtfile.write(
                        "epoch={}, rmse={:.3f}, rml={:.3f}, log10={:.3f}, d1={:.3f}, d2={:.3f}, d3={:.3f}, "
                        "t_gpu={:.4f}".format(epoch, result.rmse, result.absrel,
                                              result.lg10, result.delta1,
                                              result.delta2, result.delta3,
                                              result.gpu_time))

            # save checkpoint for each epoch
            utils.utils.save_checkpoint({
                'args': opt,
                'epoch': epoch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch, output_directory)

            # reduce the learning rate when absrel ("rml") stops falling
            scheduler.step(result.absrel)
    finally:
        logger.close()
示例#18
0
def validate(val_loader, model, epoch, write_to_file=True):
    """Run one validation pass using PnP-Depth refinement of the prediction.

    For every batch, ``model.pnp_forward_front`` maps the input to an
    intermediate tensor ``z``. ``z`` is then refined for ``pnp_iters`` steps of
    a sign-gradient (iFGM) update that lowers ``criteria.MaskedL1Loss`` between
    the prediction and the last input channel. Finally
    ``model.pnp_forward_rear`` decodes the refined ``z`` into the prediction
    that is scored.

    Args:
        val_loader: iterable of ``(input, target)`` batches.
        model: network exposing ``pnp_forward_front`` / ``pnp_forward_rear``.
        epoch: epoch index; only used to name the saved comparison image.
        write_to_file: when True, append the averaged metrics as one row to
            the CSV file at module-level ``test_csv``.

    Returns:
        ``(avg, img_merge)``: ``avg`` is ``average_meter.average()``.
        ``img_merge`` is the visualisation built from sampled batches. It is
        ``None`` when ``args.modality == 'd'`` or when ``val_loader`` yields
        no batches.

    NOTE(review): reads the module-level ``args``, ``output_directory``,
    ``test_csv`` and ``fieldnames``. The PnP step treats the last input
    channel as the sparse depth, so it assumes rgbd input (see NOTE below).
    """
    average_meter = AverageMeter()
    model.eval()  # switch to evaluate mode

    # PnP-Depth settings do not change between batches, so build them once.
    criterion = criteria.MaskedL1Loss().cuda()
    pnp_iters = 5  # number of iterations
    pnp_alpha = 0.01  # update/learning rate

    # Initialised up front so the return below is valid even when
    # val_loader yields no batches.
    img_merge = None
    skip = 50  # every `skip`-th batch contributes a row (8 rows in total)

    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        input, target = input.cuda(), target.cuda()
        torch.cuda.synchronize()
        data_time = time.time() - end

        # compute output
        end = time.time()
        ##############################################################
        ##             Start of PnP-Depth modification              ##
        ##############################################################
        # Original inference: the front half runs first, then the rear half,
        # matching the order used by the PnP loop below.
        # NOTE(review): ori_pred is not used afterwards; it only adds to
        # the measured gpu_time.
        with torch.no_grad():
            ori_pred = model.pnp_forward_rear(model.pnp_forward_front(
                input))  # equivalent to `ori_pred = model(input)`

        # Inference with PnP
        sparse_target = input[:, -1:]  # NOTE: written for rgbd input
        pnp_z = model.pnp_forward_front(input)
        for pnp_i in range(pnp_iters):
            if pnp_i != 0:
                # Step against the sign of the gradient computed in the
                # previous iteration.
                pnp_z = pnp_z - pnp_alpha * torch.sign(pnp_z_grad)  # iFGM
            pnp_z = Variable(pnp_z, requires_grad=True)
            pred = model.pnp_forward_rear(pnp_z)
            # The last iteration's prediction is final; no gradient needed.
            if pnp_i < pnp_iters - 1:
                pnp_loss = criterion(pred, sparse_target)
                pnp_z_grad = Grad([pnp_loss], [pnp_z], create_graph=True)[0]
        ##############################################################
        ##              End of PnP-Depth modification               ##
        ##############################################################
        torch.cuda.synchronize()
        gpu_time = time.time() - end

        # measure accuracy and record loss
        result = Result()
        result.evaluate(pred.data, target.data)
        average_meter.update(result, gpu_time, data_time, input.size(0))
        end = time.time()

        # save 8 images for visualization
        if args.modality == 'd':
            img_merge = None
        else:
            if args.modality == 'rgb':
                rgb = input
            elif args.modality == 'rgbd':
                rgb = input[:, :3, :, :]
                depth = input[:, 3:, :, :]

            if i == 0:
                if args.modality == 'rgbd':
                    img_merge = utils.merge_into_row_with_gt(
                        rgb, depth, target, pred)
                else:
                    img_merge = utils.merge_into_row(rgb, target, pred)
            elif (i < 8 * skip) and (i % skip == 0):
                if args.modality == 'rgbd':
                    row = utils.merge_into_row_with_gt(rgb, depth, target,
                                                       pred)
                else:
                    row = utils.merge_into_row(rgb, target, pred)
                img_merge = utils.add_row(img_merge, row)
            elif i == 8 * skip:
                filename = output_directory + '/comparison_' + str(
                    epoch) + '.png'
                utils.save_image(img_merge, filename)

        if (i + 1) % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                  'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
                  'MAE={result.mae:.2f}({average.mae:.2f}) '
                  'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
                  'REL={result.absrel:.3f}({average.absrel:.3f}) '
                  'Lg10={result.lg10:.3f}({average.lg10:.3f}) '.format(
                      i + 1,
                      len(val_loader),
                      gpu_time=gpu_time,
                      result=result,
                      average=average_meter.average()))

    avg = average_meter.average()

    print('\n*\n'
          'RMSE={average.rmse:.3f}\n'
          'MAE={average.mae:.3f}\n'
          'Delta1={average.delta1:.3f}\n'
          'REL={average.absrel:.3f}\n'
          'Lg10={average.lg10:.3f}\n'
          't_GPU={time:.3f}\n'.format(average=avg, time=avg.gpu_time))

    if write_to_file:
        with open(test_csv, 'a') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow({
                'mse': avg.mse,
                'rmse': avg.rmse,
                'absrel': avg.absrel,
                'lg10': avg.lg10,
                'mae': avg.mae,
                'delta1': avg.delta1,
                'delta2': avg.delta2,
                'delta3': avg.delta3,
                'data_time': avg.data_time,
                'gpu_time': avg.gpu_time
            })
    return avg, img_merge
示例#19
0
def main():
    """Evaluate, cross-train, resume, or train a depth model from scratch.

    The mode is chosen from the module-level ``args``:

    * ``args.evaluate``: checkpoint path. Load it, run one validation pass
      without writing CSV output, and return.
    * ``args.crossTrain``: checkpoint path. Take its model and re-train it
      with a new SGD optimizer built from the current ``args``.
    * ``args.resume``: checkpoint path. Restore ``args``, model, optimizer,
      best result and epoch counter, then continue training.
    * otherwise: build a new ``ResNet`` for ``args.arch`` and train it.

    Rebinds the globals ``args``, ``best_result``, ``output_directory``,
    ``train_csv`` and ``test_csv``.

    NOTE(review): the cross-train and from-scratch paths read
    ``best_result`` without assigning it here, so they rely on a
    module-level initialisation elsewhere in the file. Confirm.

    Raises:
        ValueError: if ``args.arch`` (new-model path) or ``args.criterion``
            is not supported.
    """
    global args, best_result, output_directory, train_csv, test_csv

    start_epoch = 0
    # evaluation mode
    if args.evaluate:
        assert os.path.isfile(args.evaluate), \
            "=> no best model found at '{}'".format(args.evaluate)
        print("=> loading best model '{}'".format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        output_directory = os.path.dirname(args.evaluate)
        args = checkpoint['args']
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
        _, val_loader = create_data_loaders(args)
        args.evaluate = True
        validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
        return
    elif args.crossTrain:
        print("Retraining loaded model on current input parameters")
        train_loader, val_loader = create_data_loaders(args)
        checkpoint = torch.load(args.crossTrain)
        model = checkpoint['model']
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        model = model.cuda()

    # optionally resume from a checkpoint
    elif args.resume:
        chkpt_path = args.resume
        assert os.path.isfile(chkpt_path), \
            "=> no checkpoint found at '{}'".format(chkpt_path)
        print("=> loading checkpoint '{}'".format(chkpt_path))
        checkpoint = torch.load(chkpt_path)
        args = checkpoint['args']
        start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']
        # NOTE(review): overwritten below by utils.get_output_directory(args).
        output_directory = os.path.dirname(os.path.abspath(chkpt_path))
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
        train_loader, val_loader = create_data_loaders(args)
        # Mark as resumed so the existing CSV logs are not truncated below.
        args.resume = True

    # create new model
    else:
        train_loader, val_loader = create_data_loaders(args)
        print("=> creating Model ({}-{}) ...".format(args.arch, args.decoder))
        in_channels = len(args.modality)  # one channel per modality letter
        layers_by_arch = {'resnet50': 50, 'resnet18': 18}
        if args.arch not in layers_by_arch:
            raise ValueError("unsupported arch '{}'".format(args.arch))
        model = ResNet(layers=layers_by_arch[args.arch],
                       decoder=args.decoder,
                       output_size=train_loader.dataset.output_size,
                       in_channels=in_channels,
                       pretrained=args.pretrained)
        print("=> model created.")
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        # model = torch.nn.DataParallel(model).cuda() # for multi-gpu training
        model = model.cuda()

    # define loss function (criterion); fail early on an unknown name
    if args.criterion == 'l2':
        criterion = criteria.MaskedMSELoss().cuda()
    elif args.criterion == 'l1':
        criterion = criteria.MaskedL1Loss().cuda()
    else:
        raise ValueError("unsupported criterion '{}'".format(args.criterion))

    # create results folder, if not already exists
    output_directory = utils.get_output_directory(args)
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    train_csv = os.path.join(output_directory, 'train.csv')
    test_csv = os.path.join(output_directory, 'test.csv')
    best_txt = os.path.join(output_directory, 'best.txt')

    # create new csv files with only header
    if not args.resume:
        with open(train_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
        with open(test_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()

    for epoch in range(start_epoch, args.epochs):
        utils.adjust_learning_rate(optimizer, epoch, args.lr)
        train(train_loader, model, criterion, optimizer,
              epoch)  # train for one epoch
        result, img_merge = validate(val_loader, model,
                                     epoch)  # evaluate on validation set

        # remember best rmse and save checkpoint
        is_best = result.rmse < best_result.rmse
        if is_best:
            best_result = result
            with open(best_txt, 'w') as txtfile:
                txtfile.write(
                    "epoch={}\nmse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n"
                    .format(epoch, result.mse, result.rmse, result.absrel,
                            result.lg10, result.mae, result.delta1,
                            result.gpu_time))
            if img_merge is not None:
                img_filename = output_directory + '/comparison_best.png'
                utils.save_image(img_merge, img_filename)

        utils.save_checkpoint(
            {
                'args': args,
                'epoch': epoch,
                'arch': args.arch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch, output_directory)
示例#20
0
文件: main.py 项目: LeonSun0101/YT
def main():
    torch.cuda.set_device(config.cuda_id)
    global args, best_result, output_directory, train_csv, test_csv, batch_num, best_txt
    best_result = Result()
    best_result.set_to_worst()
    batch_num = 0
    output_directory = utils.get_output_directory(args)

    #-----------------#
    # pytorch version #
    #-----------------#

    try:
        torch._utils._rebuild_tensor_v2
    except AttributeError:

        def _rebuild_tensor_v2(storage, storage_offset, size, stride,
                               requires_grad, backward_hooks):
            tensor = torch._utils._rebuild_tensor(storage, storage_offset,
                                                  size, stride)
            tensor.requires_grad = requires_grad
            tensor._backward_hooks = backward_hooks
            return tensor

        torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2

    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    train_csv = os.path.join(output_directory, 'train.csv')
    test_csv = os.path.join(output_directory, 'test.csv')
    best_txt = os.path.join(output_directory, 'best.txt')

    nowTime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    file = open(namefile, 'a+')
    file.writelines(
        str("====================================================") +
        str(nowTime) + '\n')
    file.writelines(str("Cuda_id: ") + str(config.cuda_id) + '\n')
    file.writelines(str("NAME: ") + str(config.name) + '\n')
    file.writelines(str("Description: ") + str(config.description) + '\n')
    file.writelines(
        str("model: ") + str(args.arch) + '\n' + str("loss_final: ") +
        str(args.criterion) + '\n' + str("loss_1: ") + str(config.LOSS_1) +
        '\n' + str("batch_size:") + str(args.batch_size) + '\n')
    file.writelines(str("zoom_scale: ") + str(config.zoom_scale) + '\n')
    file.writelines(str("------------------------") + '\n')
    file.writelines(str("Train_dataste: ") + str(config.train_dir) + '\n')
    file.writelines(str("Validation_dataste: ") + str(config.val_dir) + '\n')
    file.writelines(str("------------------------") + '\n')
    file.writelines(str("Input_type: ") + str(config.input) + '\n')
    file.writelines(str("target_type: ") + str(config.target) + '\n')
    file.writelines(str("LOSS--------------------") + '\n')
    file.writelines(str("Loss_num: ") + str(config.loss_num) + '\n')
    file.writelines(
        str("loss_final: ") + str(args.criterion) + '\n' + str("loss_1: ") +
        str(config.LOSS_1) + '\n')
    file.writelines(
        str("loss_0_weight: ") + str(config.LOSS_0_weight) + '\n' +
        str("loss_1_weight: ") + str(config.LOSS_1_weight) + '\n')
    file.writelines(
        str("weight_GT_canny: ") + str(config.weight_GT_canny_loss) + '\n' +
        str("weight_GT_sobel: ") + str(config.weight_GT_sobel_loss) + '\n' +
        str("weight_rgb_sobel: ") + str(config.weight_rgb_sobel_loss) + '\n')
    file.writelines(str("------------------------") + '\n')
    file.writelines(str("target: ") + str(config.target) + '\n')
    file.writelines(str("data_loader_type: ") + str(config.data_loader) + '\n')
    file.writelines(str("lr: ") + str(config.Init_lr) + '\n')
    file.writelines(str("save_fc: ") + str(config.save_fc) + '\n')
    file.writelines(str("Max epoch: ") + str(config.epoch) + '\n')
    file.close()

    # define loss function (criterion) and optimizer,定义误差函数和优化器
    if args.criterion == 'l2':
        criterion = criteria.MaskedMSELoss().cuda()
    elif args.criterion == 'l1':
        criterion = criteria.MaskedL1Loss().cuda()
    elif args.criterion == 'l1_canny':
        criterion = criteria.MaskedL1_cannyLoss().cuda()
    #SOBEL
    elif args.criterion == 'l1_from_rgb_sobel':
        criterion = criteria.MaskedL1_from_rgb_sobel_Loss().cuda()
    elif args.criterion == 'l1_from_GT_rgb_sobel':
        criterion = criteria.MaskedL1_from_GT_rgb_sobel_Loss().cuda()
    elif args.criterion == 'l1_from_GT_sobel':
        criterion = criteria.MaskedL1_from_GT_sobel_Loss().cuda()
    elif args.criterion == 'l2_from_GT_sobel_Loss':
        criterion = criteria.MaskedL2_from_GT_sobel_Loss().cuda()
    #CANNY
    elif args.criterion == 'l1_canny_from_GT_canny':
        criterion = criteria.MaskedL1_canny_from_GT_Loss().cuda()

    # Data loading code
    print("=> creating data loaders ...")
    train_dir = config.train_dir
    val_dir = config.val_dir
    train_dataset = YT_dataset(train_dir, config, is_train_set=True)
    val_dataset = YT_dataset(val_dir, config, is_train_set=False)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        sampler=None,
        worker_init_fn=lambda work_id: np.random.seed(work_id))
    # worker_init_fn ensures different sampling patterns for each data loading thread

    # set batch size to be 1 for validation
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    print("=> data loaders created.")

    if args.evaluate:
        best_model_filename = os.path.join(output_directory,
                                           'model_best.pth.tar')
        assert os.path.isfile(best_model_filename), \
        "=> no best model found at '{}'".format(best_model_filename)
        print("=> loading best model '{}'".format(best_model_filename))
        checkpoint = torch.load(best_model_filename)
        args.start_epoch = checkpoint['epoch']
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
        validate(val_loader,
                 model,
                 checkpoint['epoch'],
                 1,
                 write_to_file=False)
        return

    elif args.test:
        print("testing...")
        best_model_filename = best_model_dir
        assert os.path.isfile(best_model_filename), \
            "=> no best model found at '{}'".format(best_model_filename)
        print("=> loading best model '{}'".format(best_model_filename))
        checkpoint = torch.load(best_model_filename)
        args.start_epoch = checkpoint['epoch']
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
        optimizer = checkpoint['optimizer']
        for state in optimizer.state.values():
            for k, v in state.items():
                print(type(v))
                if torch.is_tensor(v):
                    state[k] = v.cuda()

        #test(val_loader, model, checkpoint['epoch'], write_to_file=False)
        test(model)
        return

    elif args.resume:
        assert os.path.isfile(config.resume_model_dir), \
            "=> no checkpoint found at '{}'".format(config.resume_model_dir)
        print("=> loading checkpoint '{}'".format(config.resume_model_dir))
        best_model_filename = config.resume_model_dir
        checkpoint = torch.load(best_model_filename)
        args.start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']
        for state in optimizer.state.values():
            for k, v in state.items():
                #print(type(v))
                if torch.is_tensor(v):
                    state[k] = v.cuda(config.cuda_id)

        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

    else:
        print("=> creating Model ({}-{}) ...".format(args.arch, args.decoder))
        if config.input == 'RGBT':
            in_channels = 4
        elif config.input == 'YT':
            in_channels = 2
        else:
            print("Input type is wrong !")
            return 0
        if args.arch == 'resnet50':  #调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet(layers=50,
                           decoder=args.decoder,
                           output_size=train_dataset.output_size,
                           in_channels=in_channels,
                           pretrained=args.pretrained)
        elif args.arch == 'resnet50_deconv1_loss0':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_with_deconv(layers=50,
                                       decoder=args.decoder,
                                       output_size=train_dataset.output_size,
                                       in_channels=in_channels,
                                       pretrained=args.pretrained)
        elif args.arch == 'resnet50_deconv1_loss1':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_with_deconv_loss(
                layers=50,
                decoder=args.decoder,
                output_size=train_dataset.output_size,
                in_channels=in_channels,
                pretrained=args.pretrained)
        elif args.arch == 'resnet50_direct_deconv1_loss1':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_with_direct_deconv(
                layers=50,
                decoder=args.decoder,
                output_size=train_dataset.output_size,
                in_channels=in_channels,
                pretrained=args.pretrained)
        elif args.arch == 'resnet50_1':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_1(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_2':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_2(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_3':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_3(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_3_1':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_3_1(layers=50,
                               decoder=args.decoder,
                               output_size=train_dataset.output_size,
                               in_channels=in_channels,
                               pretrained=args.pretrained)
        elif args.arch == 'resnet50_3_2':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_3_2(layers=50,
                               decoder=args.decoder,
                               output_size=train_dataset.output_size,
                               in_channels=in_channels,
                               pretrained=args.pretrained)
        elif args.arch == 'resnet50_3_3':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_3_3(layers=50,
                               decoder=args.decoder,
                               output_size=train_dataset.output_size,
                               in_channels=in_channels,
                               pretrained=args.pretrained)
        elif args.arch == 'resnet50_4':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_4(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_5':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_5(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_7':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_7(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_8':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_8(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_9':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_9(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_10':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_10(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_11':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_11(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_11_1':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_11_1(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_11_without_pretrain':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_11_without_pretrain(
                layers=50,
                decoder=args.decoder,
                output_size=train_dataset.output_size,
                in_channels=in_channels,
                pretrained=args.pretrained)
        elif args.arch == 'resnet50_12':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_12(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_13':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_13(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_14':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_14(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_15':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_15(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_16':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_16(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_17':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_17(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_18':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet50_18(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_30':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_30(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_31':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_31(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_32':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_32(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_33':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_33(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_40':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_40(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_1':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_15_1(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_2':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_15_2(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_3':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_15_3(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_4':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_15_4(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_5':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_15_5(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_6':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_15_6(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_8':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_15_8(layers=34,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_9':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_15_9(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_10':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_15_10(layers=50,
                                 decoder=args.decoder,
                                 output_size=train_dataset.output_size,
                                 in_channels=in_channels,
                                 pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_11':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_15_11(layers=50,
                                 decoder=args.decoder,
                                 output_size=train_dataset.output_size,
                                 in_channels=in_channels,
                                 pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_12':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_15_12(layers=50,
                                 decoder=args.decoder,
                                 output_size=train_dataset.output_size,
                                 in_channels=in_channels,
                                 pretrained=args.pretrained)
        elif args.arch == 'resnet18':
            model = ResNet(layers=18,
                           decoder=args.decoder,
                           output_size=train_dataset.output_size,
                           in_channels=in_channels,
                           pretrained=args.pretrained)
        elif args.arch == 'resnet50_20':
            model = ResNet50_20(Bottleneck, [3, 4, 6, 3])
        elif args.arch == 'UNet':
            model = UNet()
        elif args.arch == 'UP_only':
            model = UP_only()
        elif args.arch == 'ResNet_bicubic':  # 调用ResNet的定义实例化model,这里的in_channels是
            model = ResNet_bicubic(layers=50,
                                   decoder=args.decoder,
                                   output_size=train_dataset.output_size,
                                   in_channels=in_channels,
                                   pretrained=args.pretrained)
        elif args.arch == 'VDSR':
            model = VDSR()
        elif args.arch == 'VDSR_without_res':
            model = VDSR_without_res()
        elif args.arch == 'VDSR_16':
            model = VDSR_16()
        elif args.arch == 'VDSR_16_2':
            model = VDSR_16_2()
        elif args.arch == 'Leon_resnet50':
            model = Leon_resnet50()
        elif args.arch == 'Leon_resnet101':
            model = Leon_resnet101()
        elif args.arch == 'Leon_resnet18':
            model = Leon_resnet18()
        elif args.arch == 'Double_resnet50':
            model = Double_resnet50()
        print("=> model created.")

        if args.finetune:
            print("===============loading finetune model=====================")
            assert os.path.isfile(config.fitune_model_dir), \
            "=> no checkpoint found at '{}'".format(config.fitune_model_dir)
            print("=> loading checkpoint '{}'".format(config.fitune_model_dir))
            best_model_filename = config.fitune_model_dir
            checkpoint = torch.load(best_model_filename)
            args.start_epoch = checkpoint['epoch'] + 1
            #best_result = checkpoint['best_result']
            model_fitune = checkpoint['model']
            model_fitune_dict = model_fitune.state_dict()
            model_dict = model.state_dict()
            for k in model_fitune_dict:
                if k in model_dict:
                    #print("There is model k: ",k)
                    model_dict[k] = model_fitune_dict[k]
            #model_dict={k:v for k,v in model_fitune_dict.items() if k in model_dict}
            model_dict.update(model_fitune_dict)
            model.load_state_dict(model_dict)

            #optimizer = checkpoint['optimizer']
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))

        #optimizer = torch.optim.SGD(model.parameters(), args.lr,momentum=args.momentum, weight_decay=args.weight_decay)
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     amsgrad=True,
                                     weight_decay=args.weight_decay)
        '''
        optimizer = torch.optim.Adam(
        [
            #{'params':model.base.parameters()}, 3
            {'params': model.re_conv_Y_1.parameters(),'lr':0.0001},
            {'params': model.re_conv_Y_2.parameters(), 'lr': 0.0001},
            {'params': model.re_conv_Y_3.parameters(), 'lr': 0.0001},
            #3
            {'params': model.re_deconv_up0.parameters(), 'lr': 0.0001},
            {'params': model.re_deconv_up1.parameters(), 'lr': 0.0001},
            {'params': model.re_deconv_up2.parameters(), 'lr': 0.0001},
            #3
            {'params': model.re_conv1.parameters(), 'lr': 0.0001},
            {'params': model.re_bn1.parameters(), 'lr': 0.0001},
            {'params': model.re_conv4.parameters(), 'lr': 0.0001},
            #5
            {'params': model.re_ResNet50_layer1.parameters(), 'lr': 0.0001},
            {'params': model.re_ResNet50_layer2.parameters(), 'lr': 0.0001},
            {'params': model.re_ResNet50_layer3.parameters(), 'lr': 0.0001},
            {'params': model.re_ResNet50_layer4.parameters(), 'lr': 0.0001},

            {'params': model.re_bn2.parameters(), 'lr': 0.0001},
            #5
            {'params': model.re_deconcv_res_up1.parameters(), 'lr': 0.0001},
            {'params': model.re_deconcv_res_up2.parameters(), 'lr': 0.0001},
            {'params': model.re_deconcv_res_up3.parameters(), 'lr': 0.0001},
            {'params': model.re_deconcv_res_up4.parameters(), 'lr': 0.0001},

            {'params': model.re_deconv_last.parameters(), 'lr': 0.0001},
            #denoise net 3
            {'params': model.conv_denoise_1.parameters(), 'lr': 0},
            {'params': model.conv_denoise_2.parameters(), 'lr': 0},
            {'params': model.conv_denoise_3.parameters(), 'lr': 0}
        ]
        , lr=args.lr, amsgrad=True, weight_decay=args.weight_decay)
        '''
        for state in optimizer.state.values():
            for k, v in state.items():
                print(type(v))
                if torch.is_tensor(v):
                    state[k] = v.cuda(config.cuda_id)
        print(optimizer)

        # create new csv files with only header
        with open(train_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
        with open(test_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()


#    writer = SummaryWriter(log_dir='logs')

    model = model.cuda(config.cuda_id)
    #torch.save(model, './net1.pkl')
    for state in optimizer.state.values():
        for k, v in state.items():
            print(type(v))
            if torch.is_tensor(v):
                state[k] = v.cuda()

    print("=> model transferred to GPU.")

    for epoch in range(args.start_epoch, args.epochs):
        train(train_loader, val_loader, model, criterion, optimizer, epoch,
              args.lr)  # train for one epoch
def main():
    """Train (or evaluate) a ResNet depth model on the NYU dataset.

    All configuration is read from the module-level ``args`` namespace; the
    function rebinds the module globals ``args``, ``best_result``,
    ``output_directory``, ``train_csv`` and ``test_csv``.

    Modes, selected by ``args``:
      * ``args.evaluate`` -- load ``model_best.pth.tar`` from the output
        directory, run one validation pass (without writing to file) and
        return.
      * ``args.resume`` -- restore epoch, best result, model and optimizer
        from the checkpoint path ``args.resume`` and continue training.
      * otherwise -- build a fresh ResNet + SGD optimizer and start new
        train/test CSV logs containing only the header row.

    During training, after every epoch the model is validated; when the
    validation RMSE improves, ``best.txt`` (and ``comparison_best.png`` if a
    comparison image is returned) is written, and a checkpoint is saved via
    ``utils.save_checkpoint`` every epoch.

    Raises:
        ValueError: if ``args.criterion`` is not 'l2'/'l1' (training modes
            only) or, when building a new model, ``args.arch`` is not a
            supported ResNet variant.
        AssertionError: if the checkpoint file to load does not exist.
    """
    global args, best_result, output_directory, train_csv, test_csv

    # Optional sparse-depth sampler handed to both datasets; None = no sampler.
    sparsifier = None
    max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
    if args.sparsifier == UniformSampling.name:
        sparsifier = UniformSampling(num_samples=args.num_samples,
                                     max_depth=max_depth)
    elif args.sparsifier == SimulatedStereo.name:
        sparsifier = SimulatedStereo(num_samples=args.num_samples,
                                     max_depth=max_depth)

    # create results folder, if not already exists
    output_directory = utils.get_output_directory(args)
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    train_csv = os.path.join(output_directory, 'train.csv')
    test_csv = os.path.join(output_directory, 'test.csv')
    best_txt = os.path.join(output_directory, 'best.txt')

    # define loss function (criterion)
    if args.criterion == 'l2':
        criterion = criteria.MaskedMSELoss().cuda()
    elif args.criterion == 'l1':
        criterion = criteria.MaskedL1Loss().cuda()
    else:
        criterion = None
    # Evaluation never uses the criterion, so only training needs a valid one.
    # Fail fast here rather than with a NameError after data/model setup.
    if criterion is None and not args.evaluate:
        raise ValueError("unsupported criterion '{}' (expected 'l2' or 'l1')"
                         .format(args.criterion))
    out_channels = 1

    # Data loading code
    print("=> creating data loaders ...")
    traindir = os.path.join('data', args.data, 'train')
    valdir = os.path.join('data', args.data, 'val')

    train_dataset = NYUDataset(traindir,
                               type='train',
                               modality=args.modality,
                               sparsifier=sparsifier)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None)

    # set batch size to be 1 for validation
    val_dataset = NYUDataset(valdir,
                             type='val',
                             modality=args.modality,
                             sparsifier=sparsifier)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    print("=> data loaders created.")

    # evaluation mode
    if args.evaluate:
        best_model_filename = os.path.join(output_directory,
                                           'model_best.pth.tar')
        assert os.path.isfile(best_model_filename), \
        "=> no best model found at '{}'".format(best_model_filename)
        print("=> loading best model '{}'".format(best_model_filename))
        checkpoint = torch.load(best_model_filename)
        args.start_epoch = checkpoint['epoch']
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
        validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
        return

    # optionally resume from a checkpoint
    elif args.resume:
        assert os.path.isfile(args.resume), \
            "=> no checkpoint found at '{}'".format(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        # Continue with the epoch after the one stored in the checkpoint.
        args.start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

    # create new model
    else:
        # define model
        print("=> creating Model ({}-{}) ...".format(args.arch, args.decoder))
        in_channels = len(args.modality)
        # Supported architectures and the ResNet depth each one maps to.
        resnet_layers = {'resnet50': 50, 'resnet18': 18}
        if args.arch not in resnet_layers:
            raise ValueError("unsupported arch '{}' (expected one of {})"
                             .format(args.arch, sorted(resnet_layers)))
        model = ResNet(layers=resnet_layers[args.arch],
                       decoder=args.decoder,
                       in_channels=in_channels,
                       out_channels=out_channels,
                       pretrained=args.pretrained)
        print("=> model created.")

        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        # create new csv files with only header
        for csv_path in (train_csv, test_csv):
            with open(csv_path, 'w') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()

    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()
    print(model)
    print("=> model transferred to GPU.")

    for epoch in range(args.start_epoch, args.epochs):
        utils.adjust_learning_rate(optimizer, epoch, args.lr)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        result, img_merge = validate(val_loader, model, epoch)

        # remember best rmse and save checkpoint
        is_best = result.rmse < best_result.rmse
        if is_best:
            best_result = result
            with open(best_txt, 'w') as txtfile:
                txtfile.write(
                    "epoch={}\nmse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n"
                    .format(epoch, result.mse, result.rmse, result.absrel,
                            result.lg10, result.mae, result.delta1,
                            result.gpu_time))
            if img_merge is not None:
                img_filename = os.path.join(output_directory,
                                            'comparison_best.png')
                utils.save_image(img_merge, img_filename)

        utils.save_checkpoint(
            {
                'args': args,
                'epoch': epoch,
                'arch': args.arch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch, output_directory)
# Example #22
# 0
# ---- command-line configuration ------------------------------------------
args = parser.parse_args()
# Pose-based (photometric) training is switched on by the training mode name.
args.use_pose = "photo" in args.train_mode
# args.pretrained = not args.no_pretrained
args.result = os.path.join('..', 'results')
# RGB is needed when it is an explicit input or when pose training is on.
args.use_rgb = ('rgb' in args.input) or args.use_pose
args.use_d = 'd' in args.input
args.use_g = 'g' in args.input
# Weights w1/w2 are 0.1 with pose training and 0 otherwise
# (presumably the photometric/smoothness loss weights -- confirm in iterate()).
args.w1, args.w2 = (0.1, 0.1) if args.use_pose else (0, 0)
print(args)

# ---- loss functions ------------------------------------------------------
if args.criterion == 'l2':
    depth_criterion = criteria.MaskedMSELoss()
else:
    depth_criterion = criteria.MaskedL1Loss()
photometric_criterion = criteria.PhotometricLoss()
smoothness_criterion = criteria.SmoothnessLoss()

if args.use_pose:
    # KITTI camera intrinsics, read via load_calib(); only needed with pose.
    K = load_calib()
    fu = float(K[0, 0])
    fv = float(K[1, 1])
    cu = float(K[0, 2])
    cv = float(K[1, 2])
    kitti_intrinsics = Intrinsics(owidth, oheight, fu, fv, cu, cv).cuda()


def iterate(mode, args, loader, model, optimizer, logger, epoch):
    block_average_meter = AverageMeter()
    average_meter = AverageMeter()
    meters = [block_average_meter, average_meter]