Example #1
def main():
    global best_error, n_iter, device
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0

    train_set = BioData(args.root, seed=args.seed, train=True)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    # val_loader = torch.utils.data.DataLoader(
    #     val_set, batch_size=args.batch_size, shuffle=False,
    #     num_workers=args.workers, pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    bio_net = models.BioModelCnn().to(device)
    bio_net.init_weights()

    cudnn.benchmark = True
    bio_net = torch.nn.DataParallel(bio_net)

    print('=> setting adam solver')

    optim_params = [
        {'params': bio_net.parameters(), 'lr': args.lr},
    ]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=0)  # no validation: val_loader is commented out above
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, bio_net, optimizer, args.epoch_size, logger)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))
        logger.reset_valid_bar()

    logger.epoch_bar.finish()
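The train helper called in this example is not shown here. Below is a minimal sketch of a loop matching that call signature, assuming the loader yields (input, target) pairs; the mse_loss criterion and the train_bar update are assumptions, not taken from the snippet.

import torch.nn.functional as F

def train(args, train_loader, model, optimizer, epoch_size, logger):
    # sketch: one pass over at most epoch_size batches, returning the average loss
    model.train()
    device = next(model.parameters()).device
    total_loss, n_batches = 0.0, 0
    for i, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        loss = F.mse_loss(model(inputs), targets)  # stand-in criterion (assumption)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
        logger.train_bar.update(i + 1)  # assumes TermLogger exposes a train_bar
        if i + 1 >= epoch_size:
            break
    return total_loss / max(n_batches, 1)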
Example #2
def main():
    global global_vars_dict
    args = global_vars_dict['args']
    best_error = -1  # used to pick the best model

    #mkdir
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")

    args.save_path = Path('checkpoints') / Path(args.data_dir).stem / timestamp

    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])
    #mk writers
    tb_writer = SummaryWriter(args.save_path)

    # Data loading code and transpose

    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data_dir))

    train_transform = custom_transforms.Compose([
        #custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    # train set: only one loader is built
    from datasets.sequence_mc import SequenceFolder
    train_set = SequenceFolder(  # mc data folder
        args.data_dir,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,  # 5
        target_transform=None,
        depth_format='png')

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # val sets/loaders are built one by one
    # if args.val_with_depth_gt:
    from datasets.validation_folders2 import ValidationSet

    val_set_with_depth_gt = ValidationSet(args.data_dir,
                                          transform=valid_transform,
                                          depth_format='png')

    val_loader_depth = torch.utils.data.DataLoader(val_set_with_depth_gt,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   drop_last=True)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))

    #1 create model
    print("=> creating model")
    #1.1 disp_net
    disp_net = getattr(models, args.dispnet)().cuda()
    output_exp = True  #args.mask_loss_weight > 0

    if args.pretrained_disp:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_disp))
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    if args.resume:
        print("=> resuming from checkpoint")
        dispnet_weights = torch.load(args.save_path /
                                     'dispnet_checkpoint.pth.tar')
        disp_net.load_state_dict(dispnet_weights['state_dict'])

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters())

    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    if args.resume and (args.save_path /
                        'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path /
                                       'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    #
    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader_depth))
        logger.reset_epoch_bar()
    else:
        logger = None


    # evaluate before the training loop starts
    criterion_train = MaskedL1Loss().to(device)  # L1 loss is easy to optimize
    criterion_val = ComputeErrors().to(device)

    #depth_error_names,depth_errors = validate_depth_with_gt(val_loader_depth, disp_net,criterion=criterion_val, epoch=0, logger=logger,tb_writer=tb_writer,global_vars_dict=global_vars_dict)

    #logger.reset_epoch_bar()
    #    logger.epoch_logger_update(epoch=0,time=0,names=depth_error_names,values=depth_errors)
    epoch_time = AverageMeter()
    end = time.time()
    #3. main cycle
    for epoch in range(1, args.epochs):  # epoch 0 was already evaluated before entering the loop

        if args.log_terminal:
            logger.reset_train_bar()
            logger.reset_valid_bar()

        errors = [0]
        error_names = ['no error names depth']

        #3.2 train for one epoch---------
        loss_names, losses = train_depth_gt(train_loader=train_loader,
                                            disp_net=disp_net,
                                            optimizer=optimizer,
                                            criterion=criterion_train,
                                            logger=logger,
                                            train_writer=tb_writer,
                                            global_vars_dict=global_vars_dict)

        #3.3 evaluate on validation set-----
        depth_error_names, depth_errors = validate_depth_with_gt(
            val_loader=val_loader_depth,
            disp_net=disp_net,
            criterion=criterion_val,
            epoch=epoch,
            logger=logger,
            tb_writer=tb_writer,
            global_vars_dict=global_vars_dict)

        epoch_time.update(time.time() - end)
        end = time.time()

        #3.5 log_terminal
        if args.log_terminal:
            logger.epoch_logger_update(epoch=epoch,
                                       time=epoch_time,
                                       names=depth_error_names,
                                       values=depth_errors)

        # TensorBoard scalars
        # train loss
        for loss_name, loss in zip(loss_names, losses.avg):
            tb_writer.add_scalar('train/' + loss_name, loss, epoch)

        #val_with_gt loss
        for name, error in zip(depth_error_names, depth_errors.avg):
            tb_writer.add_scalar('val/' + name, error, epoch)

        #3.6 save model and remember lowest error and save checkpoint
        total_loss = losses.avg[0]
        if best_error < 0:
            best_error = total_loss

        is_best = total_loss <= best_error
        best_error = min(best_error, total_loss)
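        # only disp_net's weights are saved; the remaining three checkpoint slots are None placeholders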
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, is_best)

    if args.log_terminal:
        logger.epoch_bar.finish()
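The epoch_time and losses objects above follow the usual AverageMeter pattern. A typical scalar implementation is sketched below for reference; this repository's version evidently tracks several values at once, since losses.avg is iterated over.

class AverageMeter(object):
    # keeps a running sum and count, exposing the latest value and the average
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count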
Example #3
def main():
    print('=> number of GPU: ', args.gpu_num)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print("=> information will be saved in {}".format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    img_H = args.height
    img_W = args.width
    if args.evaluate:
        args.epochs = 0
    training_writer = SummaryWriter(args.save_path)

    ########################################################################
    ######################   Data loading part    ##########################

    ## normalize -1 to 1 func
    normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    if args.dataset == "NYU":
        valid_transform = Compose([
            CenterCrop(size=(img_H, img_W)),
            ArrayToTensor(height=img_H, width=img_W), normalize
        ])  ### NYU valid transform ###
    else:
        valid_transform = Compose(
            [ArrayToTensor(height=img_H, width=img_W),
             normalize])  ### KITTI valid transform ###
    print("=> fetching scenes in '{}'".format(args.data))
    print("=> Dataset: ", args.dataset)

    if args.dataset == 'KITTI':
        train_transform = Compose([
            RandomHorizontalFlip(),
            RandomScaleCrop(),
            ArrayToTensor(height=img_H, width=img_W), normalize
        ])
        train_set = SequenceFolder(args.data,
                                   args=args,
                                   transform=train_transform,
                                   seed=args.seed,
                                   train=True,
                                   mode=args.mode)
        if not args.real_test:
            print("=> test on validation set")
            '''
            val_set = SequenceFolder(
                args.data, args = args, transform=valid_transform,
                seed=args.seed, train=False, mode = args.mode)
            '''
            val_set = TestFolder(args.data,
                                 args=args,
                                 transform=valid_transform,
                                 seed=args.seed,
                                 train=False,
                                 mode=args.mode)
        else:
            print("=> test on Eigen test split")
            val_set = TestFolder(args.data,
                                 args=args,
                                 transform=valid_transform,
                                 seed=args.seed,
                                 train=False,
                                 mode=args.mode)
    elif args.dataset == 'Make3D':
        train_transform = Compose([
            RandomHorizontalFlip(),
            RandomScaleCrop(),
            ArrayToTensor(height=img_H, width=img_W), normalize
        ])
        train_set = Make3DFolder(args.data,
                                 args=args,
                                 transform=train_transform,
                                 seed=args.seed,
                                 train=True,
                                 mode=args.mode)
        val_set = Make3DFolder(args.data,
                               args=args,
                               transform=valid_transform,
                               seed=args.seed,
                               train=False,
                               mode=args.mode)
    elif args.dataset == 'NYU':
        if args.mode == 'RtoD':
            print('RtoD transform created')
            train_transform = EnhancedCompose([
                Merge(),
                RandomCropNumpy(size=(251, 340)),
                RandomRotate(angle_range=(-5, 5), mode='constant'),
                Split([0, 3], [3, 4])
            ])
            train_transform_2 = EnhancedCompose([
                CenterCrop(size=(img_H, img_W)),
                RandomHorizontalFlip(),
                [RandomColor(multiplier_range=(0.8, 1.2)), None],
                ArrayToTensor(height=img_H, width=img_W), normalize
            ])
        elif args.mode == 'DtoD':
            print('DtoD transform created')
            train_transform = EnhancedCompose([
                Merge(),
                RandomCropNumpy(size=(251, 340)),
                RandomRotate(angle_range=(-4, 4), mode='constant'),
                Split([0, 1])
            ])
            train_transform_2 = EnhancedCompose([
                CenterCrop(size=(img_H, img_W)),
                RandomHorizontalFlip(),
                ArrayToTensor(height=img_H, width=img_W), normalize
            ])
        train_set = NYUdataset(args.data,
                               args=args,
                               transform=train_transform,
                               transform_2=train_transform_2,
                               seed=args.seed,
                               train=True,
                               mode=args.mode)
        val_set = NYUdataset(args.data,
                             args=args,
                             transform=valid_transform,
                             seed=args.seed,
                             train=False,
                             mode=args.mode)
    #print('samples_num: {}  train scenes: {}'.format(len(train_set), len(train_set.scenes)))
    print('=> samples_num: {} - train'.format(len(train_set)))
    print('=> samples_num: {} - test'.format(len(val_set)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)
    cudnn.benchmark = True
    ###########################################################################
    ###########################################################################

    ################################################################################
    ###################### Setting Network, Loss, Optimizer part ###################

    print("=> creating model")
    if args.mode == 'DtoD':
        print('- DtoD train')
        AE_DtoD = AutoEncoder_DtoD(norm=args.norm,
                                   input_dim=1,
                                   height=img_H,
                                   width=img_W)
        AE_DtoD = nn.DataParallel(AE_DtoD)
        AE_DtoD = AE_DtoD.cuda()
        #AE_DtoD.load_state_dict(torch.load(args.model_dir))
        print('- DtoD model is created')
        optimizer_AE = optim.Adam(AE_DtoD.parameters(),
                                  args.lr, [args.momentum, args.beta],
                                  eps=1e-08,
                                  weight_decay=5e-4)
        criterion_L2 = nn.MSELoss()
        criterion_L1 = nn.L1Loss()
    elif args.mode == 'RtoD':
        print('- RtoD train')
        AE_DtoD = AutoEncoder_DtoD(norm=args.norm,
                                   input_dim=1,
                                   height=img_H,
                                   width=img_W)
        AE_DtoD = nn.DataParallel(AE_DtoD)
        AE_DtoD = AE_DtoD.cuda()
        AE_DtoD.load_state_dict(torch.load(args.model_dir))
        AE_DtoD.eval()
        print('- pretrained DtoD model is created')
        AE_RtoD = AutoEncoder_2(norm=args.norm,
                                input_dim=3,
                                height=img_H,
                                width=img_W)
        AE_RtoD = nn.DataParallel(AE_RtoD)
        AE_RtoD = AE_RtoD.cuda()
        #AE_RtoD.load_state_dict(torch.load(args.RtoD_model_dir))
        print('- RtoD model is created')
        optimizer_AE = optim.Adam(AE_RtoD.parameters(),
                                  args.lr, [args.momentum, args.beta],
                                  eps=1e-08,
                                  weight_decay=5e-4)
        criterion_L2 = nn.MSELoss()
        criterion_L1 = nn.L1Loss()
    elif args.mode == 'RtoD_single':
        print('- RtoD single train')
        AE_DtoD = None
        AE_RtoD = AutoEncoder_2(norm=args.norm,
                                input_dim=3,
                                height=img_H,
                                width=img_W)
        AE_RtoD = nn.DataParallel(AE_RtoD)
        AE_RtoD = AE_RtoD.cuda()
        #AE_RtoD.load_state_dict(torch.load(args.RtoD_model_dir))
        print('- RtoD model is created')
        optimizer_AE = optim.Adam(AE_RtoD.parameters(),
                                  args.lr, [args.momentum, args.beta],
                                  eps=1e-08,
                                  weight_decay=5e-4)
        criterion_L2 = nn.MSELoss()
        criterion_L1 = nn.L1Loss()
    elif args.mode == 'DtoD_test':
        print('- DtoD test')
        AE_DtoD = AutoEncoder_DtoD(norm=args.norm,
                                   input_dim=1,
                                   height=img_H,
                                   width=img_W)
        AE_DtoD = nn.DataParallel(AE_DtoD)
        AE_DtoD = AE_DtoD.cuda()
        AE_DtoD.load_state_dict(torch.load(args.model_dir))
        print('- pretrained DtoD model is created')
    elif args.mode == 'RtoD_test':
        print('- RtoD test')
        AE_RtoD = AutoEncoder(norm=args.norm, height=img_H, width=img_W)
        #AE_RtoD = AutoEncoder_2(norm=args.norm,input_dim=3,height=img_H,width=img_W)
        AE_RtoD = nn.DataParallel(AE_RtoD)
        AE_RtoD = AE_RtoD.cuda()
        AE_RtoD.load_state_dict(torch.load(args.RtoD_model_dir))
        print('- pretrained RtoD model is created')

    #############################################################################
    #############################################################################

    ############################ data log #######################################
    if args.evaluate:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader))
        logger.epoch_bar.start()
    else:
        logger = None

    #logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=len(val_loader))
    #logger.epoch_bar.start()
    if logger is not None:
        with open(args.save_path / args.log_summary, 'w') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow(['train_loss', 'validation_loss'])

        with open(args.save_path / args.log_full, 'w') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow(['train_loss_sum', 'output_loss', 'latent_loss'])

    #############################################################################

    ############################ Training part ##################################
    if args.mode == 'DtoD':
        loss = train_AE_DtoD(args, AE_DtoD, criterion_L2, criterion_L1,
                             optimizer_AE, train_loader, val_loader,
                             args.batch_size, args.epochs, args.lr, logger,
                             training_writer)
        print('Final loss:', loss.item())
    elif args.mode == 'RtoD' or args.mode == 'RtoD_single':
        loss, output_loss, latent_loss = train_AE_RtoD(
            args, AE_RtoD, AE_DtoD, criterion_L2, criterion_L1, optimizer_AE,
            train_loader, val_loader, args.batch_size, args.epochs, args.lr,
            logger, training_writer)

    ########################### Evaluating part #################################
    if args.mode == 'DtoD_test':
        test_model = AE_DtoD
        print("DtoD_test - switch model to eval mode")
    elif args.mode == 'RtoD_test':
        test_model = AE_RtoD
        print("RtoD_test - switch model to eval mode")
    elif args.mode == 'DtoD':
        test_model = AE_DtoD  # evaluate the freshly trained DtoD model
    else:  # 'RtoD' or 'RtoD_single'
        test_model = AE_RtoD  # evaluate the freshly trained RtoD model
    test_model.eval()
    if (logger is not None) and args.evaluate:
        if args.dataset == 'KITTI':
            logger.reset_valid_bar()
            errors, min_errors, error_names = validate(args, val_loader,
                                                       test_model, 0, logger,
                                                       args.mode)
            error_length = 8
        elif args.dataset == 'Make3D':
            logger.reset_valid_bar()
            errors, min_errors, error_names = validate_Make3D(
                args, val_loader, test_model, 0, logger, args.mode)
            error_length = 4
        elif args.dataset == 'NYU':
            logger.reset_valid_bar()
            errors, min_errors, error_names = validate_NYU(
                args, val_loader, test_model, 0, logger, args.mode)
            error_length = 8
        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, 0)
        error_string = ', '.join(
            '{} : {:.3f}'.format(name, error) for name, error in zip(
                error_names[0:error_length], errors[0:error_length]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))
        print("")
        error_string = ', '.join(
            '{} : {:.3f}'.format(name, error) for name, error in zip(
                error_names[0:error_length], min_errors[0:error_length]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))
        logger.valid_bar.finish()
        print(args.dataset, "validation finished")

    ##  Test

    if args.img_save is False:
        print("--only Test mode finish--")
        return

    k = 0

    for gt_data, rgb_data, _ in val_loader:
        if args.mode == 'RtoD' or args.mode == 'RtoD_test':
            gt_data = Variable(gt_data.cuda())
            final_AE_in = rgb_data.cuda()
        elif args.mode == 'DtoD' or args.mode == 'DtoD_test':
            rgb_data = Variable(rgb_data.cuda())
            final_AE_in = gt_data.cuda()
        final_AE_in = Variable(final_AE_in)
        with torch.no_grad():
            final_AE_depth = test_model(final_AE_in, istrain=False)
        img_arr = [final_AE_depth, gt_data, rgb_data]
        folder_name_list = ['/output_depth', '/ground_truth', '/input_rgb']
        img_name_list = ['/final_AE_depth_', '/final_AE_gt_', '/final_AE_rgb_']
        folder_iter = cycle(folder_name_list)
        img_name_iter = cycle(img_name_list)
        for img in img_arr:
            img_org = img.cpu().detach().numpy()
            folder_name = next(folder_iter)
            img_name = next(img_name_iter)
            result_dir = args.result_dir + folder_name
            if not os.path.exists(result_dir):
                os.makedirs(result_dir)
            for t in range(img_org.shape[0]):
                img = img_org[t]
                if img.shape[0] == 3:  # CHW -> HWC for RGB images
                    img_ = np.empty([img_H, img_W, 3])
                    img_[:, :, 0] = img[0, :, :]
                    img_[:, :, 1] = img[1, :, :]
                    img_[:, :, 2] = img[2, :, :]
                elif img.shape[0] == 1:  # single-channel depth map
                    img_ = np.empty([img_H, img_W])
                    img_[:, :] = img[0, :, :]
                # scipy.misc.imsave was removed in SciPy >= 1.2; an
                # imageio-based replacement is sketched after this example
                scipy.misc.imsave(result_dir + img_name + '%05d.jpg' % (k + t),
                                  img_)
        k += img_org.shape[0]
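scipy.misc.imsave, used at the end of this example, was removed in SciPy 1.2. A drop-in replacement built on imageio is sketched below, assuming float images scaled to [0, 1].

import imageio
import numpy as np

def imsave(path, img):
    # scipy.misc.imsave rescaled float arrays automatically; imageio expects
    # uint8, so clip and rescale first (assumes float values in [0, 1])
    img = np.asarray(img)
    if img.dtype != np.uint8:
        img = (255 * np.clip(img, 0.0, 1.0)).astype(np.uint8)
    imageio.imwrite(path, img)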
Example #4
def main():
    global args, best_photo_loss, n_iter
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    save_path = Path('{}epochs{},seq{},b{},lr{},p{},m{},s{}'.format(
        args.epochs,
        ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '',
        args.sequence_length, args.batch_size, args.lr, args.photo_loss_weight,
        args.mask_loss_weight, args.smooth_loss_weight))
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    args.save_path = 'checkpoints' / save_path / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    train_writer = SummaryWriter(args.save_path / 'train')
    valid_writer = SummaryWriter(args.save_path / 'valid')
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    input_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=input_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)
    val_set = SequenceFolder(args.data,
                             transform=custom_transforms.Compose([
                                 custom_transforms.ArrayToTensor(), normalize
                             ]),
                             seed=args.seed,
                             train=False,
                             sequence_length=args.sequence_length)
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    disp_net = models.DispNetS().cuda()
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=args.mask_loss_weight > 0).cuda()

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters(), pose_exp_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(train_loader, disp_net, pose_exp_net, optimizer,
                           args.epoch_size, logger, train_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        valid_photo_loss, valid_exp_loss, valid_total_loss = validate(
            val_loader, disp_net, pose_exp_net, epoch, logger, output_writers)
        logger.valid_writer.write(
            ' * Avg Photo Loss : {:.3f}, Exp Loss : {:.3f}, Total Loss : {:.3f}'
            .format(valid_photo_loss, valid_exp_loss, valid_total_loss))
        valid_writer.add_scalar(
            'photometric_error', valid_photo_loss * 4, n_iter
        )  # Loss is multiplied by 4 because it's only one scale, instead of 4 during training
        valid_writer.add_scalar('explainability_loss', valid_exp_loss * 4,
                                n_iter)
        valid_writer.add_scalar('total_loss', valid_total_loss * 4, n_iter)

        if best_photo_loss < 0:
            best_photo_loss = valid_photo_loss

        # remember lowest error and save checkpoint
        is_best = valid_photo_loss < best_photo_loss
        best_photo_loss = min(valid_photo_loss, best_photo_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, valid_total_loss])
    logger.epoch_bar.finish()
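save_checkpoint is called here with one state dict per network plus an is_best flag. A matching implementation in the style of SfmLearner-type training scripts is sketched below; the file names are assumptions, and save_path is assumed to support the / operator (path.py or pathlib).

import shutil
import torch

def save_checkpoint(save_path, dispnet_state, exp_pose_state, is_best,
                    filename='checkpoint.pth.tar'):
    # save one file per network, and copy to *_model_best on improvement
    prefixes = ['dispnet', 'exp_pose']
    for prefix, state in zip(prefixes, [dispnet_state, exp_pose_state]):
        torch.save(state, save_path / '{}_{}'.format(prefix, filename))
    if is_best:
        for prefix in prefixes:
            shutil.copyfile(save_path / '{}_{}'.format(prefix, filename),
                            save_path / '{}_model_best.pth.tar'.format(prefix))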
Example #5
def main():
    best_error = -1
    n_iter = 0
    torch_device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")

    # parse training arguments
    args = parse_training_args()
    args.training_output_freq = 100  # resetting the training output frequency here.

    # create a folder to save the output of training
    save_path = make_save_path(args)
    args.save_path = save_path
    # save the current configuration to a pickle file
    dump_config(save_path, args)

    print('=> Saving checkpoints to {}'.format(save_path))
    # set manual seed for reproducibility
    torch.manual_seed(args.seed)
    # tensorboard summary
    tb_writer = SummaryWriter(save_path)

    # Data preprocessing
    train_transform = valid_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=0.5, std=0.5)
    ])

    # Load datasets
    print("=> Fetching scenes in '{}'".format(args.data))

    train_set = val_set = None
    if args.lfformat == 'focalstack':
        train_set, val_set = get_focal_stack_loaders(args, train_transform,
                                                     valid_transform)
    elif args.lfformat == 'stack':
        train_set, val_set = get_stacked_lf_loaders(args, train_transform,
                                                    valid_transform)

    print('=> {} samples found in {} train scenes'.format(
        len(train_set), len(train_set.scenes)))
    print('=> {} samples found in {} valid scenes'.format(
        len(val_set), len(val_set.scenes)))

    # Create batch loader
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # Pull first example from dataset to check number of channels
    input_channels = train_set[0][1].shape[0]
    args.epoch_size = len(train_loader)
    print("=> Using {} input channels, {} total batches".format(
        input_channels, args.epoch_size))

    # create model
    print("=> Creating models")
    disp_net = models.LFDispNet(in_channels=input_channels).to(torch_device)
    output_exp = args.mask_loss_weight > 0
    pose_exp_net = models.LFPoseNet(in_channels=input_channels,
                                    nb_ref_imgs=args.sequence_length -
                                    1).to(torch_device)

    # Load or initialize weights
    if args.pretrained_exp_pose:
        print("=> Using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> Using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    # Set some torch flags
    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    # Define optimizer
    print('=> Setting adam solver')
    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_exp_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # Logging
    with open(os.path.join(save_path, args.log_summary), 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(os.path.join(save_path, args.log_full), 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    # train the network
    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net,
                           optimizer, args.epoch_size, logger, tb_writer,
                           n_iter, torch_device)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        errors, error_names = validate_without_gt(args, val_loader, disp_net,
                                                  pose_exp_net, epoch, logger,
                                                  tb_writer, torch_device)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        # tensorboard logging
        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(name, error, epoch)

        # It is up to you to choose the most relevant error to measure your model's
        # performance; careful, some measures should be maximized (such as a1, a2, a3)
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(os.path.join(save_path, args.log_summary), 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
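Note that this example originally compared strings with "is" (args.lfformat is 'focalstack', corrected to == above). "is" tests object identity, not value, and only appears to work when CPython happens to intern both strings:

a = ''.join(['focal', 'stack'])  # built at runtime, so not interned
print(a == 'focalstack')  # True: value comparison, always correct
print(a is 'focalstack')  # False here, and a SyntaxWarning on Python >= 3.8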
Example #6
def main():
    global best_error, n_iter, device
    args = parser.parse_args()

    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    cudnn.deterministic = True
    cudnn.benchmark = True

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.45, 0.45, 0.45],
                                            std=[0.225, 0.225, 0.225])

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    # if no ground truth is available, the validation set is the same type as the training set, to measure photometric loss from warping
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(args.data,
                                 transform=valid_transform,
                                 seed=args.seed,
                                 train=False,
                                 sequence_length=args.sequence_length)
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    disp_net = models.DispResNet(args.resnet_layers,
                                 args.with_pretrain).to(device)
    pose_net = models.PoseResNet(18, args.with_pretrain).to(device)

    # load parameters
    if args.pretrained_disp:
        print("=> using pre-trained weights for DispResNet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'], strict=False)

    if args.pretrained_pose:
        print("=> using pre-trained weights for PoseResNet")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'], strict=False)

    disp_net = torch.nn.DataParallel(disp_net)
    pose_net = torch.nn.DataParallel(pose_net)

    print('=> setting adam solver')
    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'photo_loss', 'smooth_loss',
            'geometry_consistency_loss'
        ])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_net, optimizer,
                           args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger,
                                                   output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_net,
                                                      epoch, logger,
                                                      output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # It is up to you to choose the most relevant error to measure your model's performance; careful, some measures should be maximized (such as a1, a2, a3)
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
Example #7
def main():
    global args, best_error, n_iter, device
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints_shifted' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = ShiftedSequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,
        target_displacement=args.target_displacement)

    # if no ground truth is available, the validation set is the same type as the training set, to measure photometric loss from warping
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    adjust_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )  # num_workers is 0 so the dataset's shifts are not modified by several worker processes at once
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    train.args = args
    # create model
    print("=> creating model")

    disp_net = models.DispNetS().cuda()
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=args.mask_loss_weight > 0).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters(), pose_exp_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net,
                           optimizer, args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        if (epoch + 1) % 5 == 0:
            train_set.adjust = True
            logger.reset_train_bar(len(adjust_loader))
            average_shifts = adjust_shifts(args, train_set, adjust_loader,
                                           pose_exp_net, epoch, logger,
                                           training_writer)
            shifts_string = ' '.join(
                ['{:.3f}'.format(s) for s in average_shifts])
            logger.train_writer.write(
                ' * adjusted shifts, average shifts are now : {}'.format(
                    shifts_string))
            for i, shift in enumerate(average_shifts):
                training_writer.add_scalar('shifts{}'.format(i), shift, epoch)
            train_set.adjust = False

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger,
                                                   output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      epoch, logger,
                                                      output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # It is up to you to choose the most relevant error to measure your model's performance; careful, some measures should be maximized (such as a1, a2, a3)
        decisive_error = errors[0]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
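Examples #4 through #7 all rely on custom_transforms.Normalize with per-channel mean and std. A minimal sketch of such a transform is given below, assuming it receives a list of CHW tensors; the repository's version may also pass camera intrinsics through unchanged.

class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, images):
        # normalize each CHW tensor channel-wise, in place
        for tensor in images:
            for t, m, s in zip(tensor, self.mean, self.std):
                t.sub_(m).div_(s)
        return images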
Example #8
def prepare_environment():
    env = {}
    args = parser.parse_args()
    if args.dataset_format == 'KITTI':
        from datasets.shifted_sequence_folders import ShiftedSequenceFolder
    elif args.dataset_format == 'StillBox':
        from datasets.shifted_sequence_folders import StillBox as ShiftedSequenceFolder
    elif args.dataset_format == 'TUM':
        from datasets.shifted_sequence_folders import TUM as ShiftedSequenceFolder
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    args.test_batch_size = 4 * args.batch_size
    if args.evaluate:
        args.epochs = 0

    env['training_writer'] = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))
    env['output_writers'] = output_writers

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        # custom_transforms.RandomHorizontalFlip(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = ShiftedSequenceFolder(args.data,
                                      transform=train_transform,
                                      seed=args.seed,
                                      train=True,
                                      with_depth_gt=False,
                                      with_pose_gt=args.supervise_pose,
                                      sequence_length=args.sequence_length)
    val_set = ShiftedSequenceFolder(args.data,
                                    transform=valid_transform,
                                    seed=args.seed,
                                    train=False,
                                    sequence_length=args.sequence_length,
                                    with_depth_gt=args.with_gt,
                                    with_pose_gt=args.with_gt)
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=4 * args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    env['train_set'] = train_set
    env['val_set'] = val_set
    env['train_loader'] = train_loader
    env['val_loader'] = val_loader

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)
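        # (epoch_size == 0 is a sentinel: train over the full loader each epoch)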

    # create model
    print("=> creating model")
    pose_net = models.PoseNet(seq_length=args.sequence_length,
                              batch_norm=args.bn in ['pose',
                                                     'both']).to(device)

    if args.pretrained_pose:
        print("=> using pre-trained weights for pose net")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'], strict=False)

    depth_net = models.DepthNet(depth_activation="elu",
                                batch_norm=args.bn in ['depth',
                                                       'both']).to(device)

    if args.pretrained_depth:
        print("=> using pre-trained DepthNet model")
        data = torch.load(args.pretrained_depth)
        depth_net.load_state_dict(data['state_dict'])

    cudnn.benchmark = True
    depth_net = torch.nn.DataParallel(depth_net)
    pose_net = torch.nn.DataParallel(pose_net)

    env['depth_net'] = depth_net
    env['pose_net'] = pose_net

    print('=> setting adam solver')

    optim_params = [{
        'params': depth_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_net.parameters(),
        'lr': args.lr
    }]
    # parameters = chain(depth_net.parameters(), pose_exp_net.parameters())
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                args.lr_decay_frequency,
                                                gamma=0.5)
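    # StepLR halves the learning rate every lr_decay_frequency epochs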
    env['optimizer'] = optimizer
    env['scheduler'] = scheduler

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()
    env['logger'] = logger

    env['args'] = args

    return env
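
A minimal sketch of how the returned env might drive a training loop; the train function and its signature are assumptions, not part of the snippet above:

def run(env):
    args, logger = env['args'], env['logger']
    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)
        logger.reset_train_bar()
        # hypothetical train() consuming the entries prepared above
        train_loss = train(env['train_loader'], env['depth_net'],
                           env['pose_net'], env['optimizer'], logger)
        env['scheduler'].step()
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))
        logger.reset_valid_bar()
    logger.epoch_bar.finish()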
Example #9
def main():
    args = parser.parse_args()
    print("=> No Distributed Training")
    print('=> Index of using GPU: ', args.gpu_num)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    torch.manual_seed(args.seed)

    if args.evaluate is True:
        save_path = save_path_formatter(args, parser)
        args.save_path = 'checkpoints' / save_path
        print("=> information will be saved in {}".format(args.save_path))
        args.save_path.makedirs_p()
        training_writer = SummaryWriter(args.save_path)

    ######################   Data loading part    ##########################
    if args.dataset == 'KITTI':
        args.max_depth = 80.0
    elif args.dataset == 'NYU':
        args.max_depth = 10.0

    if args.result_dir == '':
        args.result_dir = './' + args.dataset + '_Eval_results'
    args.log_metric = args.dataset + '_' + args.encoder + args.log_metric

    test_set = MyDataset(args, train=False)
    print("=> Dataset: ", args.dataset)
    print("=> Data height: {}, width: {} ".format(args.height, args.width))
    print('=> test samples: {}'.format(len(test_set)))

    test_sampler = None

    val_loader = torch.utils.data.DataLoader(test_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=test_sampler)

    cudnn.benchmark = True
    ###########################################################################

    ###################### setting model list #################################
    if args.multi_test is True:
        print("=> all models will be tested")
        models_list_dir = Path(args.models_list_dir)
        models_list = sorted(models_list_dir.files('*.pkl'))
    else:
        print("=> only one model will be tested")
        models_list = [args.model_dir]

    ###################### setting Network part ###################
    print("=> creating model")
    Model = LDRN(args)

    num_params_encoder = 0
    num_params_decoder = 0
    for p in Model.encoder.parameters():
        num_params_encoder += p.numel()
    for p in Model.decoder.parameters():
        num_params_decoder += p.numel()
    print("===============================================")
    print("model encoder parameters: ", num_params_encoder)
    print("model decoder parameters: ", num_params_decoder)
    print("Total parameters: {}".format(num_params_encoder +
                                        num_params_decoder))
    print("===============================================")
    Model = Model.cuda()
    Model = torch.nn.DataParallel(Model)

    if args.evaluate is True:
        ############################ data log #######################################
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(val_loader), args.epoch_size),
                            valid_size=len(val_loader))
        with open(args.save_path / args.log_metric, 'w') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            if args.dataset == 'KITTI':
                writer.writerow([
                    'Filename', 'Abs_diff', 'Abs_rel', 'Sq_rel', 'a1', 'a2',
                    'a3', 'RMSE', 'RMSE_log'
                ])
            elif args.dataset == 'Make3D':
                writer.writerow(
                    ['Filename', 'Abs_diff', 'Abs_rel', 'log10', 'rmse'])
            elif args.dataset == 'NYU':
                writer.writerow([
                    'Filename', 'Abs_diff', 'Abs_rel', 'log10', 'a1', 'a2',
                    'a3', 'RMSE', 'RMSE_log'
                ])
        ########################### Evaluating part #################################
        test_model = Model

        print("Model Initialized")

        test_len = len(models_list)
        print("=> Length of model list: ", test_len)

        for i in range(test_len):
            filename = models_list[i].split('/')[-1]
            logger.reset_valid_bar()
            test_model.load_state_dict(
                torch.load(models_list[i], map_location='cuda:0'))
            #test_model.load_state_dict(torch.load(models_list[i]))
            test_model.eval()
            if args.dataset == 'KITTI':
                errors, error_names = validate(args, val_loader, test_model,
                                               logger, 'KITTI')
            elif args.dataset == 'NYU':
                errors, error_names = validate(args, val_loader, test_model,
                                               logger, 'NYU')
            for error, name in zip(errors, error_names):
                training_writer.add_scalar(name, error, 0)
            logger.valid_writer.write(' * model: {}'.format(models_list[i]))
            print("")
            error_string = ', '.join(
                '{} : {:.3f}'.format(name, error) for name, error in zip(
                    error_names[0:len(error_names)], errors[0:len(errors)]))
            logger.valid_writer.write(' * Avg {}'.format(error_string))
            print("")
            logger.valid_bar.finish()
            with open(args.save_path / args.log_metric, 'a') as csvfile:
                writer = csv.writer(csvfile, delimiter='\t')
                writer.writerow(
                    ['%s' % filename] +
                    ['%.4f' % (errors[k]) for k in range(len(errors))])

        print(args.dataset, " validation finished")
        ##  Test

        if args.img_save is False:
            print("--only Test mode finish--")
            return
    else:
        test_model = Model
        test_model.load_state_dict(
            torch.load(models_list[0], map_location='cuda:0'))
        #test_model.load_state_dict(torch.load(models_list[0]))
        test_model.eval()
        print("=> No validation")

    test_set = MyDataset(args, train=False, return_filename=True)
    test_sampler = None
    val_loader = torch.utils.data.DataLoader(test_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=test_sampler)

    if args.img_save is True:
        cmap = plt.cm.jet
        print("=> img save start")
        for idx, (rgb_data, gt_data, gt_dense,
                  filename) in enumerate(val_loader):
            if gt_data.ndim != 4 and gt_data[0] == False:
                continue
            img_H = gt_data.shape[2]
            img_W = gt_data.shape[3]
            gt_data = Variable(gt_data.cuda())
            input_img = Variable(rgb_data.cuda())
            gt_data = gt_data.clamp(0, args.max_depth)
            if args.use_dense_depth is True:
                gt_dense = Variable(gt_dense.cuda())
                gt_dense = gt_dense.clamp(0, args.max_depth)

            input_img_flip = torch.flip(input_img, [3])
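            # test-time augmentation: also predict on the flipped image,
            # flip that prediction back, and average the two depth maps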
            with torch.no_grad():
                _, final_depth = test_model(input_img)
                _, final_depth_flip = test_model(input_img_flip)
            final_depth_flip = torch.flip(final_depth_flip, [3])
            final_depth = 0.5 * (final_depth + final_depth_flip)

            final_depth = final_depth.clamp(0, args.max_depth)
            d_min = min(final_depth.min(), gt_data.min())
            d_max = max(final_depth.max(), gt_data.max())

            d_min = d_min.cpu().detach().numpy().astype(np.float32)
            d_max = d_max.cpu().detach().numpy().astype(np.float32)

            filename = filename[0]
            img_arr = [
                final_depth, final_depth, final_depth, gt_data, rgb_data,
                gt_dense, gt_dense, gt_dense
            ]
            folder_name_list = [
                '/output_depth', '/output_depth_cmap_gray',
                '/output_depth_cmap_jet', '/ground_truth', '/input_rgb',
                '/dense_gt', '/dense_gt_cmap_gray', '/dense_gt_cmap_jet'
            ]
            img_name_list = [
                '/' + filename, '/cmap_gray_' + filename,
                '/cmap_jet_' + filename, '/gt_' + filename, '/rgb_' + filename,
                '/gt_dense_' + filename, '/gt_dense_cmap_gray_' + filename,
                '/gt_dense_cmap_jet_' + filename
            ]
            if args.use_dense_depth is False:
                img_arr = img_arr[:5]
                folder_name_list = folder_name_list[:5]
                img_name_list = img_name_list[:5]

            folder_iter = cycle(folder_name_list)
            img_name_iter = cycle(img_name_list)
            for img in img_arr:
                folder_name = next(folder_iter)
                img_name = next(img_name_iter)
                if folder_name == '/output_depth_cmap_gray' or folder_name == '/dense_gt_cmap_gray':
                    if args.dataset == 'NYU':
                        img = img * 1000.0
                        img = img.cpu().detach().numpy().astype(np.uint16)
                        img_org = img.copy()
                    else:
                        img = img * 256.0
                        img = img.cpu().detach().numpy().astype(np.uint16)
                        img_org = img.copy()
                elif folder_name == '/output_depth_cmap_jet' or folder_name == '/dense_gt_cmap_jet':
                    img_org = img
                else:
                    img = (img / img.max()) * 255.0
                    img_org = img.cpu().detach().numpy().astype(np.float32)
                result_dir = args.result_dir + folder_name
                for t in range(img_org.shape[0]):
                    img = img_org[t]
                    if folder_name == '/output_depth_cmap_jet' or folder_name == '/dense_gt_cmap_jet':
                        img_ = np.squeeze(img.cpu().numpy().astype(np.float32))
                        img_ = ((img_ - d_min) / (d_max - d_min))
                        img_ = cmap(img_)[:, :, :3] * 255
                    else:
                        if img.shape[0] == 3:
                            img_ = np.empty([img_H, img_W,
                                             3]).astype(img.dtype)
                            '''
                            img_[:,:,2] = img[0,:,:]
                            img_[:,:,1] = img[1,:,:]
                            img_[:,:,0] = img[2,:,:]        # for BGR
                            '''
                            img_ = img.transpose(1, 2, 0)  # for RGB
                        elif img.shape[0] == 1:
                            img_ = np.ones([img_H, img_W]).astype(img.dtype)
                            img_[:, :] = img[0, :, :]
                    if not os.path.exists(result_dir):
                        os.makedirs(result_dir)
                    if folder_name == '/output_depth_cmap_gray' or folder_name == '/dense_gt_cmap_gray':
                        plt.imsave(result_dir + img_name,
                                   np.log10(img_),
                                   cmap='Greys')
                    elif folder_name == '/output_depth_cmap_jet' or folder_name == '/dense_gt_cmap_jet':
                        img_ = Image.fromarray(img_.astype('uint8'))
                        img_.save(result_dir + img_name)
                    else:
                        imageio.imwrite(result_dir + img_name, img_)
            if (idx + 1) % 10 == 0:
                print("{} images processed..".format(idx + 1))
        print("--Test image save finish--")
    return
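
The flip-averaging above is a standard test-time augmentation; factored out as a sketch (the helper name is illustrative, and the model is assumed to return a (something, depth) pair as it does above):

import torch

def predict_depth_tta(model, img):
    # average the depth predicted on the image and on its horizontal flip
    with torch.no_grad():
        _, depth = model(img)
        _, depth_flip = model(torch.flip(img, [3]))
    return 0.5 * (depth + torch.flip(depth_flip, [3]))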
Example #10
def main():
    global args, best_error, n_iter, device
    args = parser.parse_args()
    from dataset_loader import SequenceFolder

    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    train_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
    ])

    valid_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    val_set = SequenceFolder(
        args.data,
        transform=valid_transform,
        seed=args.seed,
        train=False,
        sequence_length=args.sequence_length,
    )
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    encoder = enCoder()
    decoder = deCoder()
    gplayer = GPlayer()

    if args.pretrained_dict:
        print("=> using pre-trained weights")
        weights = torch.load(args.pretrained_dict)
        pretrained_dict = weights['state_dict']

        encoder_dict = encoder.state_dict()
        pretrained_dict_encoder = {
            k: v
            for k, v in pretrained_dict.items() if k in encoder_dict
        }
        encoder_dict.update(pretrained_dict_encoder)
        encoder.load_state_dict(encoder_dict)  # load the merged dict, not just the filtered subset

        decoder_dict = decoder.state_dict()
        pretrained_dict_decoder = {
            k: v
            for k, v in pretrained_dict.items() if k in decoder_dict
        }
        decoder_dict.update(pretrained_dict_decoder)
        decoder.load_state_dict(decoder_dict)  # load the merged dict, not just the filtered subset

    encoder = encoder.to(device)
    decoder = decoder.to(device)
    gplayer = gplayer.to(device)  # its parameters are optimized below, so move it to the device too

    cudnn.benchmark = True
    encoder = torch.nn.DataParallel(encoder)
    decoder = torch.nn.DataParallel(decoder)

    parameters = chain(encoder.parameters(), gplayer.parameters(),
                       decoder.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))

    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()

        train_loss = train(train_loader, encoder, gplayer, decoder, optimizer,
                           args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()

        errors, error_names = validate(val_loader, encoder, gplayer, decoder,
                                       epoch, logger, output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        decisive_error = errors[-1]
        if best_error < 0:
            best_error = decisive_error

        # save best checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': encoder.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': gplayer.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': decoder.state_dict()
        }, is_best)

    logger.epoch_bar.finish()
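
The encoder/decoder loading above filters a checkpoint down to the keys the target network actually owns; the same pattern as a reusable sketch (the helper name is illustrative):

def load_matching_weights(net, checkpoint_state):
    # keep only checkpoint entries whose keys exist in the target network,
    # then load the merged dict so missing keys keep their current values
    own_state = net.state_dict()
    own_state.update({k: v for k, v in checkpoint_state.items() if k in own_state})
    net.load_state_dict(own_state)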
Example #11
File: main2.py Project: cudnn/cc-1
def main():
    global global_vars_dict
    args = global_vars_dict['args']
    best_error = -1  # for best-model selection

    #mkdir
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")

    args.save_path = Path('checkpoints') / Path(args.data_dir).stem / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])
    #mk writers
    tb_writer = SummaryWriter(args.save_path)

    # Data loading code
    flow_loader_h, flow_loader_w = 256, 832

    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()

    if args.fix_flownet:
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(), normalize
        ])
    else:
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomRotate(),
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(), normalize
        ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])

    print("=> fetching scenes in '{}'".format(args.data_dir))

    # train set: only one loader is built
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
        train_set = SequenceFolder(  #mc data folder
            args.data_dir,
            transform=train_transform,
            seed=args.seed,
            train=True,
            sequence_length=args.sequence_length,  #5
            target_transform=None)
    elif args.dataset_format == 'sequential_with_gt':  # with all possible gt
        from datasets.sequence_mc import SequenceFolder
        train_set = SequenceFolder(  # mc data folder
            args.data_dir,
            transform=train_transform,
            seed=args.seed,
            train=True,
            sequence_length=args.sequence_length,  # 5
            target_transform=None)
    else:
        return

    if args.DEBUG:
        train_set.__len__ = 32
        train_set.samples = train_set.samples[:32]
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

# val sets and loaders are built one by one

# if no ground truth is available, the validation set is the same type as the training set, to measure photometric loss from warping
    if args.val_without_gt:
        from datasets.sequence_folders2 import SequenceFolder  # just one more directory level than sequence_folders
        val_set_without_gt = SequenceFolder(  # images only, no ground-truth labels
            args.data_dir,
            transform=valid_transform,
            seed=None,
            train=False,
            sequence_length=args.sequence_length,
            target_transform=None)
        val_loader = torch.utils.data.DataLoader(val_set_without_gt,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True,
                                                 drop_last=True)

    if args.val_with_depth_gt:
        from datasets.validation_folders2 import ValidationSet

        val_set_with_depth_gt = ValidationSet(args.data_dir,
                                              transform=valid_transform)

        val_loader_depth = torch.utils.data.DataLoader(
            val_set_with_depth_gt,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)

    if args.val_with_flow_gt:  # not available yet
        from datasets.validation_flow import ValidationFlow
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=args.sequence_length,
                                      transform=valid_flow_transform)
        val_flow_loader = torch.utils.data.DataLoader(
            val_flow_set,
            batch_size=1,
            # batch size is 1 since images in kitti have different sizes
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    if args.val_without_gt:
        print('{} samples found in {} valid scenes'.format(
            len(val_set_without_gt), len(val_set_without_gt.scenes)))

#1 create model
    print("=> creating model")
    #1.1 disp_net
    disp_net = getattr(models, args.dispnet)().cuda()
    output_exp = True  #args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    #1.2 pose_net
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=args.sequence_length -
                                             1).cuda()

    #1.3.flow_net
    if args.flownet == 'SpyNet':
        flow_net = getattr(models,
                           args.flownet)(nlevels=args.nlevels,
                                         pre_normalization=normalize).cuda()
    elif args.flownet == 'FlowNetC6':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()
    elif args.flownet == 'FlowNetS':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()
    elif args.flownet == 'Back2Future':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    # 1.4 mask_net
    mask_net = getattr(models,
                       args.masknet)(nb_ref_imgs=args.sequence_length - 1,
                                     output_exp=True).cuda()

    #2 load parameters
    #2.1 pose
    if args.pretrained_pose:
        print("=> using pre-trained weights for pose net")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'])
    else:
        pose_net.init_weights()

    if args.pretrained_mask:
        print("=> using pre-trained weights for mask net")
        weights = torch.load(args.pretrained_mask)
        mask_net.load_state_dict(weights['state_dict'])
    else:
        mask_net.init_weights()

    # import ipdb; ipdb.set_trace()
    if args.pretrained_disp:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_disp))
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    if args.pretrained_flow:
        print("=> using pre-trained weights for FlowNet")
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    if args.resume:
        print("=> resuming from checkpoint")
        dispnet_weights = torch.load(args.save_path /
                                     'dispnet_checkpoint.pth.tar')
        posenet_weights = torch.load(args.save_path /
                                     'posenet_checkpoint.pth.tar')
        masknet_weights = torch.load(args.save_path /
                                     'masknet_checkpoint.pth.tar')
        flownet_weights = torch.load(args.save_path /
                                     'flownet_checkpoint.pth.tar')
        disp_net.load_state_dict(dispnet_weights['state_dict'])
        pose_net.load_state_dict(posenet_weights['state_dict'])
        flow_net.load_state_dict(flownet_weights['state_dict'])
        mask_net.load_state_dict(masknet_weights['state_dict'])

    # import ipdb; ipdb.set_trace()
    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_net = torch.nn.DataParallel(pose_net)
    mask_net = torch.nn.DataParallel(mask_net)
    flow_net = torch.nn.DataParallel(flow_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters(), pose_net.parameters(),
                       mask_net.parameters(), flow_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    if args.resume and (args.save_path /
                        'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path /
                                       'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'photo_cam_loss', 'photo_flow_loss',
            'explainability_loss', 'smooth_loss'
        ])

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader_depth))
        logger.epoch_bar.start()
    else:
        logger = None

# pre-evaluate once before the main loop

    if args.pretrained_disp or args.evaluate:
        if logger is not None:  # logger is None when log_terminal is off
            logger.reset_valid_bar()
        if args.val_without_gt:
            pass
            #val_loss = validate_without_gt(val_loader,disp_net,pose_net,mask_net,flow_net,epoch=0, logger=logger, tb_writer=tb_writer,nb_writers=3,global_vars_dict = global_vars_dict)
            #val_loss =0

        if args.val_with_depth_gt:
            depth_errors, depth_error_names = validate_depth_with_gt(
                val_loader_depth,
                disp_net,
                epoch=0,
                logger=logger,
                tb_writer=tb_writer,
                global_vars_dict=global_vars_dict)


#3. main cycle
    for epoch in range(1, args.epochs):  # epoch 0 was already evaluated above, before the loop
        #3.1 choose which of the four sub-networks to train
        if args.fix_flownet:
            for fparams in flow_net.parameters():
                fparams.requires_grad = False

        if args.fix_masknet:
            for fparams in mask_net.parameters():
                fparams.requires_grad = False

        if args.fix_posenet:
            for fparams in pose_net.parameters():
                fparams.requires_grad = False

        if args.fix_dispnet:
            for fparams in disp_net.parameters():
                fparams.requires_grad = False

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()
        #validation data
        flow_error_names = ['no']
        flow_errors = [0]
        errors = [0]
        error_names = ['no error names depth']
        print('\nepoch [{}/{}]\n'.format(epoch, args.epochs))  # epoch starts at 1 here
        #3.2 train for one epoch---------
        #train_loss=0
        train_loss = train_gt(train_loader, disp_net, pose_net, mask_net,
                              flow_net, optimizer, logger, tb_writer,
                              global_vars_dict)

        #3.3 evaluate on validation set-----

        if args.val_without_gt:
            val_loss = validate_without_gt(val_loader,
                                           disp_net,
                                           pose_net,
                                           mask_net,
                                           flow_net,
                                           epoch=0,
                                           logger=logger,
                                           tb_writer=tb_writer,
                                           nb_writers=3,
                                           global_vars_dict=global_vars_dict)

        if args.val_with_depth_gt:
            depth_errors, depth_error_names = validate_depth_with_gt(
                val_loader_depth,
                disp_net,
                epoch=epoch,
                logger=logger,
                tb_writer=tb_writer,
                global_vars_dict=global_vars_dict)

        if args.val_with_flow_gt:
            pass
            #flow_errors, flow_error_names = validate_flow_with_gt(val_flow_loader, disp_net, pose_net, mask_net, flow_net, epoch, logger, tb_writer)

            #for error, name in zip(flow_errors, flow_error_names):
            #    training_writer.add_scalar(name, error, epoch)

        #----------------------

        #3.4 Choose the error most relevant to your model's performance; careful, some measures are to be maximized (such as a1, a2, a3)

        if not args.fix_posenet:
            decisive_error = 0  # flow_errors[-2]    # epe_rigid_with_gt_mask
        elif not args.fix_dispnet:
            decisive_error = 0  # errors[0]      #depth abs_diff
        elif not args.fix_flownet:
            decisive_error = 0  # flow_errors[-1]    #epe_non_rigid_with_gt_mask
        elif not args.fix_masknet:
            decisive_error = 0  #flow_errors[3]     # percent outliers

        #3.5 log
        if args.log_terminal:
            logger.train_writer.write(
                ' * Avg Loss : {:.3f}'.format(train_loss))
            logger.reset_valid_bar()
        # epoch data logged to tensorboard
        #train loss
        tb_writer.add_scalar('epoch/train_loss', train_loss, epoch)
        #val_without_gt loss
        if args.val_without_gt:
            tb_writer.add_scalar('epoch/val_loss', val_loss, epoch)

        if args.val_with_depth_gt:
            #val with depth gt
            for error, name in zip(depth_errors, depth_error_names):
                tb_writer.add_scalar('epoch/' + name, error, epoch)

        #3.6 save model and remember lowest error and save checkpoint

        if best_error < 0:
            best_error = train_loss

        is_best = train_loss <= best_error
        best_error = min(best_error, train_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': mask_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': flow_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    if args.log_terminal:
        logger.epoch_bar.finish()
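
The four freeze blocks inside the epoch loop repeat one pattern; a minimal sketch (the helper name is illustrative):

def set_requires_grad(net, flag):
    # freeze (flag=False) or unfreeze (flag=True) every parameter of net
    for p in net.parameters():
        p.requires_grad = flag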
Example #12
def main():
    global args, best_error, n_iter
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    save_path = Path(args.name)
    args.save_path = 'checkpoints'/save_path #/timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False,False,True])

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path/'valid'/str(i)))

    # Data loading code
    flow_loader_h, flow_loader_w = 256, 832

    if args.data_normalization =='global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization =='local':
        normalize = custom_transforms.NormalizeLocally()


    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])
 

    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])

    valid_flow_transform = custom_transforms.Compose([custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
                            custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length
    )

    # if no ground truth is available, the validation set is the same type as the training set, to measure photometric loss from warping
    
    val_set = SequenceFolder(
        args.data,
        transform=valid_transform,
        seed=args.seed,
        train=False,
        sequence_length=args.sequence_length,
    )

    if args.with_flow_gt:
        from datasets.validation_flow import ValidationFlow
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                        sequence_length=args.sequence_length, transform=valid_flow_transform)

    if args.DEBUG:
        train_set.__len__ = 32
        train_set.samples = train_set.samples[:32]

    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    if args.with_flow_gt:
        val_flow_loader = torch.utils.data.DataLoader(val_flow_set, batch_size=1,               # batch size is 1 since images in kitti have different sizes
                        shuffle=False, num_workers=args.workers, pin_memory=True, drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    
    if args.flownet=='SpyNet':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels, pre_normalization=normalize).cuda()
    else:
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    # load pre-trained weights

    if args.pretrained_flow:
        print("=> using pre-trained weights for FlowNet")
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])
    # else:
        #flow_net.init_weights()


    if args.resume:
        print("=> resuming from checkpoint")  
        flownet_weights = torch.load(args.save_path/'flownet_checkpoint.pth.tar')
        flow_net.load_state_dict(flownet_weights['state_dict'])


    # import ipdb; ipdb.set_trace()
    cudnn.benchmark = True
    flow_net = torch.nn.DataParallel(flow_net)

    print('=> setting adam solver')
    parameters = chain(flow_net.parameters())
    optimizer = torch.optim.Adam(parameters, args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    milestones = [300]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)
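    # MultiStepLR multiplies the learning rate by 0.1 once epoch 300 is reached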

    if args.min:
        print("using min method")

    if args.resume and (args.save_path/'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path/'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    with open(args.save_path/args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path/args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_cam_loss', 'photo_flow_loss', 'explainability_loss', 'smooth_loss'])

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=len(val_loader))
        logger.epoch_bar.start()
    else:
        logger=None

    for epoch in range(args.epochs):
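        # NOTE: since PyTorch 1.1 the recommended order is optimizer.step()
        # during training, then scheduler.step() at the end of the epoch;
        # stepping the scheduler first can skip the first value of the schedule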
        scheduler.step()

        if args.fix_flownet:
            for fparams in flow_net.parameters():
                fparams.requires_grad = False

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()

        # train for one epoch
        train_loss = train(train_loader, flow_net, optimizer, args.epoch_size, logger, training_writer)

        if args.log_terminal:
            logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))
            logger.reset_valid_bar()


        if args.with_flow_gt:
            flow_errors, flow_error_names = validate_flow_with_gt(val_flow_loader, flow_net, epoch, logger, output_writers)

            error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(flow_error_names, flow_errors))

            if args.log_terminal:
                logger.valid_writer.write(' * Avg {}'.format(error_string))
            else:
                print('Epoch {} completed'.format(epoch))

            for error, name in zip(flow_errors, flow_error_names):
                training_writer.add_scalar(name, error, epoch)

        
        decisive_error = flow_errors[0] if args.with_flow_gt else train_loss  # fall back to train loss when no flow ground truth is available
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error <= best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(
            args.save_path, {
                'epoch': epoch + 1,
                'state_dict': flow_net.module.state_dict()
            }, {
                'epoch': epoch + 1,
                'state_dict': optimizer.state_dict()
            },
            is_best)

        with open(args.save_path/args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    if args.log_terminal:
        logger.epoch_bar.finish()
Example #13
def main():
    args = parser.parse_args()
    print('=> number of GPU: ', args.gpu_num)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print("=> information will be saved in {}".format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    img_H = args.height
    img_W = args.width

    training_writer = SummaryWriter(args.save_path)

    ########################################################################
    ######################   Data loading part    ##########################

    ## normalize to [-1, 1]
    normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    if args.dataset == 'NYU':
        valid_transform = Compose([
            CenterCrop(size=(img_H, img_W)),
            ArrayToTensor(height=img_H, width=img_W), normalize
        ])  ### NYU valid transform ###
    elif args.dataset == 'KITTI':
        valid_transform = Compose(
            [ArrayToTensor(height=img_H, width=img_W),
             normalize])  ### KITTI valid transform ###
    print("=> fetching scenes in '{}'".format(args.data))
    print("=> Dataset: ", args.dataset)

    if args.dataset == 'KITTI':
        print("=> test on Eigen test split")
        val_set = TestFolder(args.data,
                             args=args,
                             transform=valid_transform,
                             seed=args.seed,
                             train=False,
                             mode=args.mode)
    elif args.dataset == 'Make3D':
        val_set = Make3DFolder(args.data,
                               args=args,
                               transform=valid_transform,
                               seed=args.seed,
                               train=False,
                               mode=args.mode)
    elif args.dataset == 'NYU':
        val_set = NYUdataset(args.data,
                             args=args,
                             transform=valid_transform,
                             seed=args.seed,
                             train=False,
                             mode=args.mode)
    print('=> test samples: {}'.format(len(val_set)))
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    cudnn.benchmark = True
    ###########################################################################

    ###################### setting model list #################################
    if args.multi_test is True:
        print("=> all models will be tested")
        models_list_dir = Path(args.models_list_dir)
        models_list = sorted(models_list_dir.files('*.pkl'))
    else:
        print("=> only one model will be tested")
        models_list = [args.model_dir]

    ###################### setting Network part ###################

    print("=> creating base model")
    if args.mode == 'DtoD_test':
        print('- DtoD test')
        AE_DtoD = AutoEncoder_DtoD(norm=args.norm,
                                   input_dim=1,
                                   height=img_H,
                                   width=img_W)
        AE_DtoD = nn.DataParallel(AE_DtoD)
        AE_DtoD = AE_DtoD.cuda()
    elif args.mode == 'RtoD_test':
        print('- RtoD test')
        #AE_RtoD = AutoEncoder_Unet(norm=args.norm,height=img_H,width=img_W) #previous gradloss_mask model
        #AE_RtoD = AutoEncoder_2(norm=args.norm,input_dim=3,height=img_H,width=img_W) #current autoencoder_2 model
        AE_RtoD = AutoEncoder(norm=args.norm, height=img_H, width=img_W)
        AE_RtoD = nn.DataParallel(AE_RtoD)
        AE_RtoD = AE_RtoD.cuda()
    #############################################################################

    if args.evaluate is True:
        ############################ data log #######################################
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(val_loader), args.epoch_size),
                            valid_size=len(val_loader))
        #logger.epoch_bar.start()
        with open(args.save_path / args.log_metric, 'w') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            if args.dataset == 'KITTI':
                writer.writerow([
                    'Epoch', 'Abs_diff', 'Abs_rel', 'Sq_rel', 'a1', 'a2', 'a3',
                    'RMSE', 'RMSE_log'
                ])
            elif args.dataset == 'Make3D':
                writer.writerow(
                    ['Epoch', 'Abs_diff', 'Abs_rel', 'log10', 'rmse'])
            elif args.dataset == 'NYU':
                writer.writerow([
                    'Epoch', 'Abs_diff', 'Abs_rel', 'log10', 'a1', 'a2', 'a3',
                    'RMSE', 'RMSE_log'
                ])
        ########################### Evaluating part #################################
        if args.mode == 'DtoD_test':
            test_model = AE_DtoD
            print("DtoD_test - set to eval mode")
        elif args.mode == 'RtoD_test':
            test_model = AE_RtoD
            print("RtoD_test - set to eval mode")

        test_len = len(models_list)
        print("=> Length of model list: ", test_len)

        for i in range(test_len):
            logger.reset_valid_bar()
            test_model.load_state_dict(torch.load(models_list[i]))
            test_model.eval()
            if args.dataset == 'KITTI':
                errors, min_errors, error_names = validate(
                    args, val_loader, test_model, 0, logger, args.mode)
            elif args.dataset == 'Make3D':
                errors, min_errors, error_names = validate_Make3D(
                    args, val_loader, test_model, 0, logger, args.mode)
            elif args.dataset == 'NYU':
                errors, min_errors, error_names = validate_NYU(
                    args, val_loader, test_model, 0, logger, args.mode)
            for error, name in zip(errors, error_names):
                training_writer.add_scalar(name, error, 0)
            logger.valid_writer.write(' * RtoD_model: {}'.format(
                models_list[i]))
            #error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names[0:len(error_names)], errors[0:len(errors)]))
            error_string = ', '.join(
                '{} : {:.3f}'.format(name, error)
                for name, error in zip(error_names[0:len(error_names)],
                                       min_errors[0:len(errors)]))
            logger.valid_writer.write(' * Avg {}'.format(error_string))
            print("")
            #error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names[0:8], min_errors[0:8]))
            #logger.valid_writer.write(' * Avg {}'.format(error_string))
            logger.valid_bar.finish()
            with open(args.save_path / args.log_metric, 'a') as csvfile:
                writer = csv.writer(csvfile, delimiter='\t')
                writer.writerow(
                    ['%02d' % i] +
                    ['%.4f' % (min_errors[k]) for k in range(len(min_errors))])

        print(args.dataset, " validation finished")
        ##  Test

        if args.img_save is False:
            print("--only Test mode finish--")
            return
    else:
        if args.mode == 'DtoD_test':
            test_model = AE_DtoD
            print("DtoD_test - set to eval mode")
        elif args.mode == 'RtoD_test':
            test_model = AE_RtoD
            print("RtoD_test - set to eval mode")
        test_model.load_state_dict(torch.load(models_list[0]))
        test_model.eval()
        print("=> No validation")

    k = 0

    print("=> img save start")
    resize_ = Resize()
    for gt_data, rgb_data, filename in val_loader:
        if args.mode == 'RtoD' or args.mode == 'RtoD_test':
            gt_data = Variable(gt_data.cuda())
            final_AE_in = rgb_data.cuda()
        elif args.mode == 'DtoD' or args.mode == 'DtoD_test':
            rgb_data = Variable(rgb_data.cuda())
            final_AE_in = gt_data.cuda()
        final_AE_in = Variable(final_AE_in)
        with torch.no_grad():
            final_AE_depth = test_model(final_AE_in, istrain=False)
        img_arr = [final_AE_depth, gt_data, rgb_data]
        folder_name_list = ['/output_depth', '/ground_truth', '/input_rgb']
        img_name_list = ['/final_AE_depth_', '/final_AE_gt_', '/final_AE_rgb_']
        folder_iter = cycle(folder_name_list)
        img_name_iter = cycle(img_name_list)
        for img in img_arr:
            img_org = img.cpu().detach().numpy()
            folder_name = next(folder_iter)
            img_name = next(img_name_iter)
            result_dir = args.result_dir + folder_name
            for t in range(img_org.shape[0]):
                filename_ = filename[t]
                img = img_org[t]
                if img.shape[0] == 3:
                    img_ = np.empty([img_H, img_W, 3])
                    img_[:, :, 0] = img[0, :, :]
                    img_[:, :, 1] = img[1, :, :]
                    img_[:, :, 2] = img[2, :, :]
                    if args.resize is True:
                        img_ = resize_(img_, (384, 1248), 'rgb')
                elif img.shape[0] == 1:
                    img_ = np.empty([img_H, img_W])
                    img_[:, :] = img[0, :, :]
                    if args.resize is True:
                        img_ = resize_(img_, (384, 1248), 'depth')
                        img_ = img_[:, :, 0]
                if not os.path.exists(result_dir):
                    os.makedirs(result_dir)
                # NOTE: scipy.misc.imsave was removed in SciPy 1.3;
                # imageio.imwrite is the usual replacement
                scipy.misc.imsave(result_dir + img_name + '%05d.jpg' % (k + t),
                                  img_)
                #print(img_.shape)
                #print(filename_)
                #print(result_dir)
                #print(result_dir+filename_)
                #scipy.misc.imsave(result_dir + filename_ ,img_)
        k += img_org.shape[0]
    print("--Test image save finish--")
    return
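
Both image-saving loops in this document hand-roll the same CHW-to-HWC conversion; a compact sketch of that step (the helper name is illustrative):

import numpy as np

def chw_to_hwc(img):
    # (C, H, W) -> (H, W, C); single-channel images collapse to (H, W)
    if img.shape[0] == 1:
        return img[0]
    return img.transpose(1, 2, 0)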
Example #14
def main():
    global best_error, n_iter, device
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()  # create if missing, otherwise do nothing (path.py helper)
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0
#tensorboard SummaryWriter
    training_writer = SummaryWriter(args.save_path)  #for tensorboard

    output_writers = []  #list
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))
# Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])
    # validation transform
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(
        args.data,  #processed_data_train_sets
        transform=train_transform,  # pass in the composed transform functions
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length)
    # if no ground truth is available, the validation set is
    # the same type as training set to measure photometric loss from warping
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(
        len(train_set), len(train_set.scenes)))  # the training set is sequences only; no stereo pairs needed
    print('{} samples found in {} valid scenes'.format(
        len(val_set), len(val_set.scenes)))  # the validation set is sequences as well
    train_loader = torch.utils.data.DataLoader(  #data(list): [tensor(B,3,H,W),list(B),(B,H,W),(b,h,w)]
        dataset=train_set,  #sequenceFolder
        batch_size=args.batch_size,
        shuffle=True,  # shuffle training data
        num_workers=args.workers,  # multi-process data loading
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        dataset=val_set,
        batch_size=args.batch_size,
        shuffle=False,  # no shuffling for validation
        num_workers=args.workers,
        pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

# create model
    print("=> creating model")
    #disp
    disp_net = models.DispNetS().to(device)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    #pose
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=args.mask_loss_weight > 0).to(device)

    #init posenet
    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    #init dispNet

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> setting adam solver')
    # note that both networks are optimised jointly
    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_exp_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)
    # write training results to csv
    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    n_epochs = args.epochs
    train_size = min(len(train_loader), args.epoch_size)
    valid_size = len(val_loader)
    logger = TermLogger(n_epochs=n_epochs,
                        train_size=train_size,
                        valid_size=valid_size)
    logger.epoch_bar.start()

    if args.pretrained_disp or args.evaluate:
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   0, logger, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      0, logger,
                                                      output_writers)

        # at validation time, log each epoch-level metric
        # (['Total loss', 'Photo loss', 'Exp loss']) to tensorboard
        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, 0)
        error_string = ', '.join(
            '{} : {:.3f}'.format(name, error)
            for name, error in zip(error_names[2:9], errors[2:9]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))


    # main training loop
    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        logger.reset_train_bar()
        #1. train for one epoch
        train_loss = train(args, train_loader, disp_net, pose_exp_net,
                           optimizer, args.epoch_size, logger, training_writer)
        # other arguments are self-explanatory; logger is a custom TermLogger

        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        logger.reset_valid_bar()

        # 2. validate on validation set
        if args.with_gt:  # error_names: ['Total loss', 'Photo loss', 'Exp loss']
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger,
                                                   output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      epoch, logger,
                                                      output_writers)

        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)  # log each epoch-level metric

        # Up to you to choose the most relevant error to measure
        # your model's performance; careful, some measures are to be maximized (such as a1, a2, a3)

        # 3. remember lowest error and save checkpoint
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)

        # save models
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary,
                  'a') as csvfile:  # append results every epoch
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss,
                             decisive_error])  # the second entry is the validation epoch metric
    logger.epoch_bar.finish()
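
The save_checkpoint helper called above is not shown in this snippet. In SfMLearner-style codebases it saves one *_checkpoint.pth.tar file per network and copies it to *_model_best.pth.tar when is_best is set; a minimal sketch of that assumed behaviour:

from shutil import copyfile

import torch

def save_checkpoint(save_path, dispnet_state, exp_pose_state, is_best,
                    filename='checkpoint.pth.tar'):
    # save one file per network
    file_prefixes = ['dispnet', 'exp_pose']
    for prefix, state in zip(file_prefixes, [dispnet_state, exp_pose_state]):
        torch.save(state, save_path / '{}_{}'.format(prefix, filename))
    if is_best:
        # keep a copy of the best-so-far weights
        for prefix in file_prefixes:
            copyfile(save_path / '{}_{}'.format(prefix, filename),
                     save_path / '{}_model_best.pth.tar'.format(prefix))
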
Example #15
0
def main():
    global args, best_error, n_iter
    args = parser.parse_args()

    save_path = Path(args.name)
    args.data = Path(args.data)

    args.save_path = 'checkpoints' / save_path  #/timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writer = SummaryWriter(args.save_path / 'valid')

    print("=> fetching dataset")
    mnist_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307, ), (0.3081, ))
    ])
    trainset_mnist = torchvision.datasets.MNIST(args.data / 'mnist',
                                                train=True,
                                                transform=mnist_transform,
                                                target_transform=None,
                                                download=True)
    valset_mnist = torchvision.datasets.MNIST(args.data / 'mnist',
                                              train=False,
                                              transform=mnist_transform,
                                              target_transform=None,
                                              download=True)

    svhn_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(size=(28, 28)),
        torchvision.transforms.Grayscale(),
        torchvision.transforms.ToTensor()
    ])
    trainset_svhn = torchvision.datasets.SVHN(args.data / 'svhn',
                                              split='train',
                                              transform=svhn_transform,
                                              target_transform=None,
                                              download=True)
    valset_svhn = torchvision.datasets.SVHN(args.data / 'svhn',
                                            split='test',
                                            transform=svhn_transform,
                                            target_transform=None,
                                            download=True)

    if args.dataset == 'mnist':
        print("Training only on MNIST")
        train_set, val_set = trainset_mnist, valset_mnist
    elif args.dataset == 'svhn':
        print("Training only on SVHN")
        train_set, val_set = trainset_svhn, valset_svhn
    else:
        print("Training on both MNIST and SVHN")
        train_set = torch.utils.data.ConcatDataset(
            [trainset_mnist, trainset_svhn])
        val_set = torch.utils.data.ConcatDataset([valset_mnist, valset_svhn])

    print('{} Train samples and {} test samples found in MNIST'.format(
        len(trainset_mnist), len(valset_mnist)))
    print('{} Train samples and {} test samples found in SVHN'.format(
        len(trainset_svhn), len(valset_svhn)))

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             drop_last=False)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    alice_net = LeNet()
    bob_net = LeNet()
    mod_net = LeNet(nout=1)

    if args.pretrained_alice:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_alice))
        weights = torch.load(args.pretrained_alice)
        alice_net.load_state_dict(weights['state_dict'], strict=False)

    if args.pretrained_bob:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_bob))
        weights = torch.load(args.pretrained_bob)
        bob_net.load_state_dict(weights['state_dict'], strict=False)

    if args.pretrained_mod:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_mod))
        weights = torch.load(args.pretrained_mod)
        mod_net.load_state_dict(weights['state_dict'], strict=False)

    if args.resume:
        print("=> resuming from checkpoint")
        alice_weights = torch.load(args.save_path /
                                   'alicenet_checkpoint.pth.tar')
        bob_weights = torch.load(args.save_path / 'bobnet_checkpoint.pth.tar')
        mod_weights = torch.load(args.save_path / 'modnet_checkpoint.pth.tar')

        alice_net.load_state_dict(alice_weights['state_dict'])
        bob_net.load_state_dict(bob_weights['state_dict'])
        mod_net.load_state_dict(mod_weights['state_dict'])

    cudnn.benchmark = True
    alice_net = alice_net.cuda()
    bob_net = bob_net.cuda()
    mod_net = mod_net.cuda()

    print('=> setting adam solver')

    parameters = chain(alice_net.parameters(), bob_net.parameters(),
                       mod_net.parameters())
    optimizer_compete = torch.optim.Adam(parameters,
                                         args.lr,
                                         betas=(args.momentum, args.beta),
                                         weight_decay=args.weight_decay)

    optimizer_collaborate = torch.optim.Adam(mod_net.parameters(),
                                             args.lr,
                                             betas=(args.momentum, args.beta),
                                             weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['val_loss_full', 'val_loss_alice', 'val_loss_bob'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss_full', 'train_loss_alice', 'train_loss_bob'])

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader))
        logger.epoch_bar.start()
    else:
        logger = None

    for epoch in range(args.epochs):
        mode = 'compete' if (epoch % 2) == 0 else 'collaborate'

        if args.fix_alice:
            for fparams in alice_net.parameters():
                fparams.requires_grad = False

        if args.fix_bob:
            for fparams in bob_net.parameters():
                fparams.requires_grad = False

        if args.fix_mod:
            mode = 'compete'
            for fparams in mod_net.parameters():
                fparams.requires_grad = False

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()

        # train for one epoch; the optimizer alternates with the mode
        optimizer = optimizer_compete if mode == 'compete' else optimizer_collaborate
        train_loss = train(train_loader,
                           alice_net,
                           bob_net,
                           mod_net,
                           optimizer,
                           args.epoch_size,
                           logger,
                           training_writer,
                           mode=mode)

        if args.log_terminal:
            logger.train_writer.write(
                ' * Avg Loss : {:.3f}'.format(train_loss))
            logger.reset_valid_bar()

        if epoch % 1 == 0:  # validate every epoch

            # evaluate on validation set
            errors, error_names = validate(val_loader, alice_net, bob_net,
                                           mod_net, epoch, logger,
                                           output_writer)

            error_string = ', '.join(
                '{} : {:.3f}'.format(name, error)
                for name, error in zip(error_names, errors))

            if args.log_terminal:
                logger.valid_writer.write(' * Avg {}'.format(error_string))
            else:
                print('Epoch {} completed'.format(epoch))

            for error, name in zip(errors, error_names):
                training_writer.add_scalar(name, error, epoch)

            # Up to you to choose the most relevant error to measure your model's performance; careful, some measures are to be maximized (such as a1, a2, a3)

            if args.fix_alice:
                decisive_error = errors[2]
            elif args.fix_bob:
                decisive_error = errors[1]
            else:
                decisive_error = errors[0]  # epe_total
            if best_error < 0:
                best_error = decisive_error

            # remember lowest error and save checkpoint
            is_best = decisive_error <= best_error
            best_error = min(best_error, decisive_error)
            save_alice_bob_mod(args.save_path, {
                'epoch': epoch + 1,
                'state_dict': alice_net.state_dict()
            }, {
                'epoch': epoch + 1,
                'state_dict': bob_net.state_dict()
            }, {
                'epoch': epoch + 1,
                'state_dict': mod_net.state_dict()
            }, is_best)

            with open(args.save_path / args.log_summary, 'a') as csvfile:
                writer = csv.writer(csvfile, delimiter='\t')
                writer.writerow([train_loss, decisive_error])

    if args.log_terminal:
        logger.epoch_bar.finish()
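
A note on the fix_alice / fix_bob / fix_mod flags above: the frozen parameters stay registered in the optimizer, but their .grad remains None after backward(), and Adam skips parameters without gradients. A standalone sketch of the mechanism, with toy stand-in modules (not from the source):

import torch

net_a = torch.nn.Linear(4, 2)  # stands in for alice_net
net_b = torch.nn.Linear(4, 2)  # stands in for bob_net
for p in net_a.parameters():   # analogous to fix_alice
    p.requires_grad = False

opt = torch.optim.Adam(list(net_a.parameters()) + list(net_b.parameters()),
                       lr=1e-4)
loss = net_b(torch.randn(8, 4)).sum()
loss.backward()
opt.step()  # net_a is untouched: its grads stay None, so Adam skips them
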
Example #16
0
def main():
    global best_error, n_iter, device
    args = parser.parse_args()

    save_path = make_save_path(args)
    args.save_path = save_path
    dump_config(save_path, args)
    print('=> Saving checkpoints to {}'.format(save_path))
    torch.manual_seed(args.seed)
    tb_writer = SummaryWriter(save_path)

    # Data preprocessing
    train_transform = valid_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Create dataloader
    print("=> Fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(
        args.data,
        gray=args.gray,
        cameras=args.cameras,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length
    )
    
    val_set = SequenceFolder(
        args.data,
        gray=args.gray,
        cameras=args.cameras,
        transform=valid_transform,
        seed=args.seed,
        train=False,
        sequence_length=args.sequence_length,
        shuffle=False
    )

    print('=> {} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('=> {} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))

    # Create batch loader
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

    # Pull first example from dataset to check number of channels
    input_channels = train_set[0][1].shape[0]   
    args.epoch_size = len(train_loader)
    print("=> Using {} input channels, {} total batches".format(input_channels, args.epoch_size))
    
    # create model
    print("=> Creating models")
    pose_exp_net = models.LFPoseNet(in_channels=input_channels, nb_ref_imgs=args.sequence_length - 1, output_exp=args.mask_loss_weight > 0).to(device)

    if args.pretrained_exp_pose:
        print("=> Using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    cudnn.benchmark = True
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> Setting adam solver')

    optim_params = [
        {'params': pose_exp_net.parameters(), 'lr': args.lr}
    ]

    optimizer = torch.optim.Adam(optim_params, betas=(args.momentum, args.beta), weight_decay=args.weight_decay)


    with open(save_path + "/" + args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, pose_exp_net, optimizer, args.epoch_size, logger, tb_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))
        
        # evaluate on validation set
        logger.reset_valid_bar()
        valid_loss = validate(args, val_loader, pose_exp_net, logger, tb_writer)

        checkpoint = {
            "epoch": epoch + 1,
            "state_dict": pose_exp_net.module.state_dict()
        }
        if valid_loss < best_error or best_error < 0:
            best_error = valid_loss
            torch.save(checkpoint, save_path + "/" + 'posenet_best.pth.tar')
        # always save the latest weights as well
        torch.save(checkpoint, save_path + "/" + 'posenet_checkpoint.pth.tar')

    logger.epoch_bar.finish()
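
A hypothetical usage sketch for the checkpoints written above: reloading the best PoseNet for evaluation, where save_path, models, input_channels and args mirror the training run:

import torch

checkpoint = torch.load(save_path + "/posenet_best.pth.tar")
pose_net = models.LFPoseNet(in_channels=input_channels,
                            nb_ref_imgs=args.sequence_length - 1,
                            output_exp=args.mask_loss_weight > 0)
# the state dict was saved from .module, so there is no DataParallel prefix
pose_net.load_state_dict(checkpoint['state_dict'])
pose_net.eval()
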
Example #17
0
def main():
    global best_error, n_iter, device
    args = parse_multiwarp_training_args()
    # Some non-optional parameters for training
    args.training_output_freq = 100
    args.tilesize = 8

    save_path = make_save_path(args)
    args.save_path = save_path

    print("Using device: {}".format(device))

    dump_config(save_path, args)
    print('\n\n=> Saving checkpoints to {}'.format(save_path))

    torch.manual_seed(args.seed)                # setting a manual seed for reproducibility
    tb_writer = SummaryWriter(save_path)        # tensorboard summary writer

    # Data pre-processing - Just convert arrays to tensor and normalize the data to be largely between 0 and 1
    train_transform = valid_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=0.5, std=0.5)
    ])

    # Create data loader based on the format of the light field
    print("=> Fetching scenes in '{}'".format(args.data))
    train_set, val_set = None, None
    if args.lfformat == 'focalstack':
        train_set, val_set = get_focal_stack_loaders(args, train_transform, valid_transform)
    elif args.lfformat == 'stack':
        is_monocular = False
        if len(args.cameras) == 1 and args.cameras[0] == 8 and args.cameras_stacked == "input":
            is_monocular = True
        train_set, val_set = get_stacked_lf_loaders(args, train_transform, valid_transform, is_monocular=is_monocular)
    elif args.lfformat == 'epi':
        train_set, val_set = get_epi_loaders(args, train_transform, valid_transform)

    print('=> {} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('=> {} samples found in {} validation scenes'.format(len(val_set), len(val_set.scenes)))

    print('=> Multi-warp training, warping {} sub-apertures'.format(len(args.cameras)))

    # Create batch loader
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    output_channels = len(args.cameras) # for multi-warp photometric loss, we request as many depth values as the cameras used
    args.epoch_size = len(train_loader)
    
    # Create models
    print("=> Creating models")

    if args.lfformat == "epi":
        print("=> Using EPI encoders")
        if args.cameras_epi == "vertical":
            disp_encoder = models.EpiEncoder('vertical', args.tilesize).to(device)
            pose_encoder = models.RelativeEpiEncoder('vertical', args.tilesize).to(device)
            dispnet_input_channels = 16 + len(args.cameras)     # 16 is the number of output channels of the encoder
            posenet_input_channels = 16 + len(args.cameras)     # 16 is the number of output channels of the encoder
        elif args.cameras_epi == "horizontal":
            disp_encoder = models.EpiEncoder('horizontal', args.tilesize).to(device)
            pose_encoder = models.RelativeEpiEncoder('horizontal', args.tilesize).to(device)
            dispnet_input_channels = 16 + len(args.cameras)  # 16 is the number of output channels of the encoder
            posenet_input_channels = 16 + len(args.cameras)  # 16 is the number of output channels of the encoder
        elif args.cameras_epi == "full":
            disp_encoder = models.EpiEncoder('full', args.tilesize).to(device)
            pose_encoder = models.RelativeEpiEncoder('full', args.tilesize).to(device)
            if args.without_disp_stack:
                dispnet_input_channels = 32  # 16 is the number of output channels of each encoder
            else:
                dispnet_input_channels = 32 + 5  # 16 is the number of output channels of each encoder, 5 from stack
            posenet_input_channels = 32 + 5  # 16 is the number of output channels of each encoder
        else:
            raise ValueError("Incorrect cameras epi format")
    else:
        disp_encoder = None
        pose_encoder = None
        # for stack lfformat channels = num_cameras * num_colour_channels
        # for focalstack lfformat channels = num_focal_planes * num_colour_channels
        dispnet_input_channels = posenet_input_channels = train_set[0]['tgt_lf_formatted'].shape[0]
    
    disp_net = models.LFDispNet(in_channels=dispnet_input_channels,
                                out_channels=output_channels, encoder=disp_encoder).to(device)
    pose_net = models.LFPoseNet(in_channels=posenet_input_channels,
                                nb_ref_imgs=args.sequence_length - 1, encoder=pose_encoder).to(device)

    print("=> [DispNet] Using {} input channels, {} output channels".format(dispnet_input_channels, output_channels))
    print("=> [PoseNet] Using {} input channels".format(posenet_input_channels))

    if args.pretrained_exp_pose:
        print("=> [PoseNet] Using pre-trained weights for pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        print("=> [PoseNet] training from scratch")
        pose_net.init_weights()

    if args.pretrained_disp:
        print("=> [DispNet] Using pre-trained weights for DispNet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        print("=> [DispNet] training from scratch")
        disp_net.init_weights()

    # this flag tells CUDNN to find the optimal set of algorithms for this specific input data size, which improves
    # runtime efficiency, but takes a while to load in the beginning.
    cudnn.benchmark = True
    # disp_net = torch.nn.DataParallel(disp_net)
    # pose_net = torch.nn.DataParallel(pose_net)

    print('=> Setting adam solver')

    optim_params = [
        {'params': disp_net.parameters(), 'lr': args.lr}, 
        {'params': pose_net.parameters(), 'lr': args.lr}
    ]

    optimizer = torch.optim.Adam(optim_params, betas=(args.momentum, args.beta), weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    with open(save_path + "/" + args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(save_path + "/" + args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_loss', 'smooth_loss', 'pose_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    w1 = torch.tensor(args.photo_loss_weight, dtype=torch.float32, device=device, requires_grad=True)
    # w2 = torch.tensor(args.mask_loss_weight, dtype=torch.float32, device=device, requires_grad=True)
    w3 = torch.tensor(args.smooth_loss_weight, dtype=torch.float32, device=device, requires_grad=True)
    # w4 = torch.tensor(args.gt_pose_loss_weight, dtype=torch.float32, device=device, requires_grad=True)

    # add some constant parameters to the log for easy visualization
    tb_writer.add_scalar(tag="batch_size", scalar_value=args.batch_size)

    # tb_writer.add_scalar(tag="mask_loss_weight", scalar_value=args.mask_loss_weight)    # this is not used

    # tb_writer.add_scalar(tag="gt_pose_loss_weight", scalar_value=args.gt_pose_loss_weight)

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_net, optimizer, args.epoch_size, logger, tb_writer, w1, w3)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_net, epoch, logger, tb_writer, w1, w3)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        # update the learning rate (annealing)
        lr_scheduler.step()

        # add the learning rate to the tensorboard logging
        tb_writer.add_scalar(tag="learning_rate", scalar_value=lr_scheduler.get_last_lr()[0], global_step=epoch)

        tb_writer.add_scalar(tag="photometric_loss_weight", scalar_value=w1, global_step=epoch)
        tb_writer.add_scalar(tag="smooth_loss_weight", scalar_value=w3, global_step=epoch)

        # add validation errors to the tensorboard logging
        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(tag=name, scalar_value=error, global_step=epoch)

        decisive_error = errors[2]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(save_path, {'epoch': epoch + 1, 'state_dict': disp_net.state_dict()},
                        {'epoch': epoch + 1, 'state_dict': pose_net.state_dict()}, is_best)

        # save a checkpoint every 20 epochs anyway
        if epoch % 20 == 0:
            save_checkpoint_current(save_path, {'epoch': epoch + 1, 'state_dict': disp_net.state_dict()},
                                    {'epoch': epoch + 1, 'state_dict': pose_net.state_dict()}, epoch)

        with open(save_path + "/" + args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
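
The StepLR scheduler above halves the learning rate every 50 epochs. A self-contained sketch of the resulting schedule, assuming a starting learning rate of 2e-4:

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=2e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

for epoch in range(150):
    optimizer.step()   # training step elided
    scheduler.step()
# lr: 2e-4 for epochs 0-49, 1e-4 for 50-99, 5e-5 for 100-149
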
Example #18
0
def main():
    global best_error, n_iter, device
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0

    tb_writer = SummaryWriter(args.save_path)
    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    # if no ground truth is available, the validation set is the same type as the training set, to measure photometric loss from warping
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    disp_net = models.DispNetS().to(device)
    seg_net = DeepLab(num_classes=args.nclass,
                      backbone=args.backbone,
                      output_stride=args.out_stride,
                      sync_bn=args.sync_bn,
                      freeze_bn=args.freeze_bn).to(device)
    if args.pretrained_seg:
        print("=> using pre-trained weights for seg net")
        weights = torch.load(args.pretrained_seg)
        seg_net.load_state_dict(weights, strict=False)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=args.mask_loss_weight > 0).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)
    seg_net = torch.nn.DataParallel(seg_net)

    print('=> setting adam solver')

    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_exp_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    if args.pretrained_disp or args.evaluate:
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   0, logger, tb_writer)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      0, logger, tb_writer)
        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(name, error, 0)
        error_string = ', '.join(
            '{} : {:.3f}'.format(name, error)
            for name, error in zip(error_names[2:9], errors[2:9]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net, seg_net,
                           optimizer, args.epoch_size, logger, tb_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   seg_net, epoch, logger,
                                                   tb_writer)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      epoch, logger, tb_writer)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's performance; careful, some measures are to be maximized (such as a1, a2, a3)
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
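
Note that seg_net above is deliberately absent from optim_params, so DeepLab acts as a fixed auxiliary network while only disp_net and pose_exp_net are optimised. A quick sanity check along these lines (a sketch, using the variables defined in main):

optimised = {id(p) for group in optimizer.param_groups
             for p in group['params']}
assert not any(id(p) in optimised for p in seg_net.parameters())
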
Example #19
0
File: main.py Project: zhuoliny/flowattack
def main():
    global args, best_error, n_iter
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path  #/timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writer = SummaryWriter(args.save_path / 'valid')

    # Data loading code
    flow_loader_h, flow_loader_w = 384, 1280

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(h=256, w=256),
        custom_transforms.ArrayToTensor(),
    ])

    valid_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor()
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=3)

    if args.valset == "kitti2015":
        from datasets.validation_flow import ValidationFlowKitti2015
        val_set = ValidationFlowKitti2015(root=args.kitti_data,
                                          transform=valid_transform)
    elif args.valset == "kitti2012":
        from datasets.validation_flow import ValidationFlowKitti2012
        val_set = ValidationFlowKitti2012(root=args.kitti_data,
                                          transform=valid_transform)

    if args.DEBUG:
        # truncate the dataset for quick debugging runs
        train_set.samples = train_set.samples[:32]

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in valid scenes'.format(len(val_set)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=1,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=1,  # batch size is 1 since images in KITTI have different sizes
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    if args.flownet == 'SpyNet':
        flow_net = getattr(models, args.flownet)(nlevels=6, pretrained=True)
    elif args.flownet == 'Back2Future':
        flow_net = getattr(
            models, args.flownet)(pretrained='pretrained/b2f_rm_hard.pth.tar')
    elif args.flownet == 'PWCNet':
        flow_net = models.pwc_dc_net(
            'pretrained/pwc_net_chairs.pth.tar')  # pwc_net.pth.tar')
    else:
        flow_net = getattr(models, args.flownet)()

    if args.flownet in ['SpyNet', 'Back2Future', 'PWCNet']:
        print("=> using pre-trained weights for " + args.flownet)
    elif args.flownet in ['FlowNetC']:
        print("=> using pre-trained weights for FlowNetC")
        weights = torch.load('pretrained/FlowNet2-C_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet in ['FlowNetS']:
        print("=> using pre-trained weights for FlowNetS")
        weights = torch.load('pretrained/flownets.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet in ['FlowNet2']:
        print("=> using pre-trained weights for FlowNet2")
        weights = torch.load('pretrained/FlowNet2_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    pytorch_total_params = sum(p.numel() for p in flow_net.parameters())
    print("Number of model paramters: " + str(pytorch_total_params))

    flow_net = flow_net.cuda()

    cudnn.benchmark = True
    if args.patch_type == 'circle':
        patch, mask, patch_shape = init_patch_circle(args.image_size,
                                                     args.patch_size)
        patch_init = patch.copy()
    elif args.patch_type == 'square':
        patch, patch_shape = init_patch_square(args.image_size,
                                               args.patch_size)
        patch_init = patch.copy()
        mask = np.ones(patch_shape)
    else:
        sys.exit("Please choose a square or circle patch")

    if args.patch_path:
        patch, mask, patch_shape = init_patch_from_image(
            args.patch_path, args.mask_path, args.image_size, args.patch_size)
        patch_init = patch.copy()

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader),
                            attack_size=args.max_count)
        logger.epoch_bar.start()
    else:
        logger = None

    for epoch in range(args.epochs):

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()

        # train for one epoch
        patch, mask, patch_init, patch_shape = train(patch, mask, patch_init,
                                                     patch_shape, train_loader,
                                                     flow_net, epoch, logger,
                                                     training_writer)

        # Validate
        errors, error_names = validate_flow_with_gt(patch, mask, patch_shape,
                                                    val_loader, flow_net,
                                                    epoch, logger,
                                                    output_writer)

        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        if args.log_terminal:
            logger.valid_writer.write(' * Avg {}'.format(error_string))
        else:
            print('Epoch {} completed'.format(epoch))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        torch.save(patch, args.save_path / 'patch_epoch_{}'.format(str(epoch)))

    if args.log_terminal:
        logger.epoch_bar.finish()
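
The patch initialisers (init_patch_circle, init_patch_square, init_patch_from_image) come from the flowattack project and are not shown here. A plausible minimal version of the square case, assuming image_size is the image side length and patch_size the fraction of image area to cover (hypothetical, not the project's exact code):

import numpy as np

def init_patch_square_sketch(image_size, patch_frac):
    # cover patch_frac of the image area with random RGB noise
    side = int((image_size ** 2 * patch_frac) ** 0.5)
    patch = np.random.rand(1, 3, side, side)
    return patch, patch.shape
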
Example #20
0
def main():
    global best_error, n_iter, device
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints'/save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path/'valid'/str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])  # needed by valid_transform below; values assumed, as in the other examples
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor()
    ])

    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = KITTIDataset(
        args.data,
        sequences,  # sequence list, assumed to be provided by the surrounding script
        max_distance=args.max_distance,
        transform=train_transform
    )

    val_set = KITTIDataset(
        args.data,
        sequences,
        max_distance=args.max_distance,
        transform=valid_transform
    )
    print('{} samples found in train split'.format(len(train_set)))
    print('{} samples found in valid split'.format(len(val_set)))

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    config_file = "./configs/e2e_mask_rcnn_R_50_FPN_1x.yaml"
    cfg.merge_from_file(config_file)
    cfg.freeze()
    pretrained_model_path = "./e2e_mask_rcnn_R_50_FPN_1x.pth"
    disvo = DISVO(cfg, pretrained_model_path).cuda()

    if args.pretrained_disvo:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disvo)
        disvo.load_state_dict(weights['state_dict'])
    else:
        disvo.init_weights()

    cudnn.benchmark = True

    print('=> setting adam solver')

    optim_params = [
        {'params': disvo.parameters(), 'lr': args.lr}
    ]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path/args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path/args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss'])

    logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=len(val_loader))
    logger.epoch_bar.start()

    if args.pretrained_disvo or args.evaluate:
        logger.reset_valid_bar()
        errors, error_names = validate(args, val_loader, disvo, 0, logger, output_writers)
        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, 0)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names[2:9], errors[2:9]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disvo, optimizer, args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        errors, error_names = validate(args, val_loader, disvo, epoch, logger, output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to chose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3)
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(
            args.save_path, {
                'epoch': epoch + 1,
                'state_dict': disvo.state_dict()  # disvo is not wrapped in DataParallel
            },
            is_best)

        with open(args.save_path/args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
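
Since each example appends [train_loss, decisive_error] to a tab-separated summary file, the logs can be read back for plotting with the plain csv module; a small sketch, assuming the layout written above:

import csv

with open(args.save_path / args.log_summary) as f:
    rows = list(csv.reader(f, delimiter='\t'))

header, data = rows[0], rows[1:]  # e.g. ['train_loss', 'validation_loss']
train_loss = [float(r[0]) for r in data]
valid_loss = [float(r[1]) for r in data]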