# Example #1 (score: 0)
# File: train.py — project: zhenzhenxiang/hd3
def main():
    """Train an HD3 model for optical flow or stereo.

    Parses command-line arguments, builds the model and optimizer, loads
    optional pretrained weights, sets up the data loaders, then runs the
    epoch loop: train, (optionally) validate, and checkpoint to
    ``args.save_path`` (latest / best / periodic snapshots).
    """
    # Shared with the module-level train()/validate() helpers.
    global args, logger, writer
    args = get_parser().parse_args()
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    logger.info(args)
    logger.info("=> creating model ...")

    ### model ###
    # One correlation search range per pyramid level; the flow task uses
    # only five levels.
    corr_range = [4, 4, 4, 4, 4, 4]
    if args.task == 'flow':
        corr_range = corr_range[:5]
    model = models.HD3Model(args.task, args.encoder, args.decoder, corr_range,
                            args.context).cuda()

    logger.info(model)
    # The optimizer is created before the DataParallel wrap; this is safe
    # because DataParallel reuses the same Parameter objects.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.base_lr,
                                 weight_decay=args.weight_decay)
    model = nn.DataParallel(model).cuda()

    cudnn.enabled = True
    cudnn.benchmark = True
    best_epe_all = 1e9  # best validation end-point error seen so far

    if args.pretrain:
        # Warm-start from a full HD3 checkpoint (keys must already carry
        # the DataParallel 'module.' prefix).
        ckpt_name = args.pretrain
        if os.path.isfile(ckpt_name):
            logger.info("=> loading checkpoint '{}'".format(ckpt_name))
            checkpoint = torch.load(ckpt_name)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded checkpoint '{}'".format(ckpt_name))
        else:
            logger.info("=> no checkpoint found at '{}'".format(ckpt_name))
    elif args.pretrain_base:
        # Load only the encoder backbone; keys are remapped via prefix.
        logger.info("=> loading pretrained base model '{}'".format(
            args.pretrain_base))
        base_prefix = "module.hd3net.encoder." if args.encoder!='dlaup' \
                      else "module.hd3net.encoder.base."
        load_module_state_dict(model,
                               torch.load(args.pretrain_base),
                               add=base_prefix)
        logger.info("=> loaded pretrained base model '{}'".format(
            args.pretrain_base))

    ### data loader ###
    train_transform, val_transform = datasets.get_transform(
        args.dataset_name, args.task, args.evaluate)
    train_data = datasets.HD3Data(mode=args.task,
                                  data_root=args.train_root,
                                  data_list=args.train_list,
                                  label_num=1,
                                  transform=train_transform)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    if args.evaluate:
        val_data = datasets.HD3Data(mode=args.task,
                                    data_root=args.val_root,
                                    data_list=args.val_list,
                                    label_num=1,
                                    transform=val_transform)
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)

    ### Go! ###
    scheduler = get_lr_scheduler(optimizer, args.dataset_name)
    for epoch in range(1, args.epochs + 1):
        # NOTE(review): scheduler.step() before train() follows the
        # pre-1.1 PyTorch convention; on torch >= 1.1 this shifts the
        # schedule by one epoch — confirm against the installed version.
        if scheduler is not None:
            scheduler.step()
        loss_train = train(train_loader, model, optimizer, epoch,
                           args.batch_size)
        writer.add_scalar('loss_train', loss_train, epoch)

        is_best = False
        if args.evaluate:
            torch.cuda.empty_cache()
            loss_val, epe_val = validate(val_loader, model)
            writer.add_scalar('loss_val', loss_val, epoch)
            writer.add_scalar('epe_val', epe_val, epoch)
            is_best = epe_val < best_epe_all
            best_epe_all = min(epe_val, best_epe_all)

        # Checkpoint with CPU tensors so the file loads without a GPU;
        # the model is moved back to CUDA immediately below.
        filename = os.path.join(args.save_path, 'model_latest.pth')
        torch.save(
            {
                'epoch': epoch,
                'state_dict': model.cpu().state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_epe_all': best_epe_all
            }, filename)
        model.cuda()

        if is_best:
            shutil.copyfile(filename,
                            os.path.join(args.save_path, 'model_best.pth'))

        # Periodic snapshot every args.save_step epochs.
        if epoch % args.save_step == 0:
            shutil.copyfile(
                filename,
                args.save_path + '/train_epoch_' + str(epoch) + '.pth')
# Example #2 (score: 0)
def main():
    """Test an HD3 model and dump multi-scale predictions.

    Loads the checkpoint at ``args.model_path``, runs inference over the
    image pairs listed in ``args.data_list``, and writes per sample: a
    point estimate (flow or disparity) for every pyramid level under the
    vec1..vec4/vec folders, plus the probability map under prob. With
    ``args.evaluate`` set, also logs the average end-point error.
    """
    # Shared with other module-level helpers.
    global args, logger
    args = get_parser().parse_args()
    logger = get_logger()
    logger.info(args)
    logger.info("=> creating model ...")

    # Each line of data_list holds: image_0 image_1 [ground-truth].
    with open(args.data_list, 'r') as f:
        fnames = f.readlines()
        assert len(fnames[0].strip().split(' ')) == 2 + args.evaluate
        # Output file stems and sub-folders mirror the input layout.
        names = [l.strip().split(' ')[0].split('/')[-1] for l in fnames]
        sub_folders = [
            l.strip().split(' ')[0][:-len(names[i])]
            for i, l in enumerate(fnames)
        ]
        names = [l.split('.')[0] for l in names]
        # Probe the first image for the native input resolution.
        input_size = cv2.imread(join(args.data_root,
                                     fnames[0].split(' ')[0])).shape

    # ImageNet mean/std normalization, as used during training.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    th, tw = get_target_size(input_size[0], input_size[1])
    val_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean=mean, std=std)])
    val_data = datasets.HD3Data(mode=args.task,
                                data_root=args.data_root,
                                data_list=args.data_list,
                                label_num=args.evaluate,
                                transform=val_transform,
                                out_size=True)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # One correlation search range per pyramid level; flow uses five.
    corr_range = [4, 4, 4, 4, 4, 4]
    if args.task == 'flow':
        corr_range = corr_range[:5]
    model = models.HD3Model(args.task, args.encoder, args.decoder, corr_range,
                            args.context).cuda()
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.enabled = True
    cudnn.benchmark = True

    if os.path.isfile(args.model_path):
        logger.info("=> loading checkpoint '{}'".format(args.model_path))
        checkpoint = torch.load(args.model_path)
        model.load_state_dict(checkpoint['state_dict'], strict=True)
        logger.info("=> loaded checkpoint '{}'".format(args.model_path))
    else:
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))

    # Output folders, ordered coarse -> fine (vec1 .. vec4, then vec).
    vis_folder = os.path.join(args.save_folder, 'vis')
    vec_folder = os.path.join(args.save_folder, 'vec')
    vec_folder_4 = os.path.join(args.save_folder, 'vec4')
    vec_folder_3 = os.path.join(args.save_folder, 'vec3')
    vec_folder_2 = os.path.join(args.save_folder, 'vec2')
    vec_folder_1 = os.path.join(args.save_folder, 'vec1')
    vec_folder_list = [
        vec_folder_1, vec_folder_2, vec_folder_3, vec_folder_4, vec_folder
    ]
    check_makedirs(vis_folder)
    for folder in vec_folder_list:
        check_makedirs(folder)

    # Probability-map folder.
    prob_folder = os.path.join(args.save_folder, 'prob')
    check_makedirs(prob_folder)

    # Start testing.
    logger.info('>>>>>>>>>>>>>>>> Start Test >>>>>>>>>>>>>>>>')
    data_time = AverageMeter()
    batch_time = AverageMeter()
    avg_epe = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (img_list, label_list, img_size) in enumerate(val_loader):
            data_time.update(time.time() - end)

            img_size = img_size.cpu().numpy()
            img_list = [img.to(torch.device("cuda")) for img in img_list]
            label_list = [
                label.to(torch.device("cuda")) for label in label_list
            ]

            # Resize inputs to the network's working resolution.
            resized_img_list = [
                F.interpolate(img, (th, tw),
                              mode='bilinear',
                              align_corners=True) for img in img_list
            ]
            output = model(img_list=resized_img_list,
                           label_list=label_list,
                           get_vect=True,
                           get_prob=True,
                           get_epe=args.evaluate)

            # Rescale every pyramid level's vectors to the original image
            # size; level_i counts coarse -> fine.
            for level_i in range(len(corr_range)):
                scale_factor = 1 / 2**(7 - level_i - 1)
                output['vect'][level_i] = resize_dense_vector(
                    output['vect'][level_i] * scale_factor, img_size[0, 1],
                    img_size[0, 0])

            output['prob'] = output['prob'].data.cpu().numpy()
            if args.evaluate:
                avg_epe.update(output['epe'].mean().data, img_list[0].size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % 10 == 0:
                logger.info(
                    'Test: [{}/{}] '
                    'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                    'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}).'.
                    format(i + 1,
                           len(val_loader),
                           data_time=data_time,
                           batch_time=batch_time))

            # (batch, H, W, C) numpy arrays, one entry per pyramid level.
            pred_vect_list = [
                np.transpose(pred_v.data.cpu().numpy(), (0, 2, 3, 1))
                for pred_v in output['vect']
            ]
            curr_bs = pred_vect_list[0].shape[0]

            for idx in range(curr_bs):
                curr_idx = i * args.batch_size + idx
                curr_vect_list = [pred_v[idx] for pred_v in pred_vect_list]

                # Per-sample output folders mirroring the input layout.
                vis_sub_folder = join(vis_folder, sub_folders[curr_idx])
                vec_sub_folder_list = [
                    join(folder, sub_folders[curr_idx])
                    for folder in vec_folder_list
                ]
                prob_sub_folder = join(prob_folder, sub_folders[curr_idx])
                check_makedirs(prob_sub_folder)
                check_makedirs(vis_sub_folder)
                for folder in vec_sub_folder_list:
                    check_makedirs(folder)

                # Point-estimate file names, one per pyramid level.
                fn_suffix = 'png'
                if args.task == 'flow':
                    fn_suffix = args.flow_format
                vect_fn_list = [
                    join(folder, names[curr_idx] + '.' + fn_suffix)
                    for folder in vec_sub_folder_list
                ]
                prob_fn = join(prob_sub_folder, names[curr_idx] + '.npy')
                # BUGFIX: save only this sample's probability map; the
                # original saved the whole batch under every sample name.
                np.save(prob_fn, output['prob'][idx])

                if args.task == 'flow':
                    if fn_suffix == 'png':
                        # KITTI 16-bit PNG flow with an all-valid mask.
                        mask_blob = np.ones(
                            (img_size[idx][1], img_size[idx][0]),
                            dtype=np.uint16)
                        for curr_vect, vect_fn in zip(curr_vect_list,
                                                      vect_fn_list):
                            fl.write_kitti_png_file(vect_fn, curr_vect,
                                                    mask_blob)
                    else:
                        # Middlebury .flo format flow.
                        for curr_vect, vect_fn in zip(curr_vect_list,
                                                      vect_fn_list):
                            fl.write_flow(curr_vect, vect_fn)
                else:
                    # Disparity map as 16-bit PNG, scaled by 256.
                    # BUGFIX: the original referenced undefined names
                    # ``vect_fn``/``curr_vect`` here (NameError); write
                    # the full-resolution (finest, last) level instead.
                    cv2.imwrite(
                        vect_fn_list[-1],
                        np.uint16(-curr_vect_list[-1][:, :, 0] * 256.0))

    if args.evaluate:
        logger.info('Average End Point Error {avg_epe.avg:.2f}'.format(
            avg_epe=avg_epe))

    logger.info('<<<<<<<<<<<<<<<<< End Test <<<<<<<<<<<<<<<<<')
def main():
    """Refine HD3 optical flow with a PPAC network over a data list.

    Loads both the HD3 and refinement checkpoints, then for each batch:
    predicts HD3 flow, upsamples the vectors and probabilities to the
    input resolution, refines the flow, and — depending on flags —
    evaluates, visualizes, saves raw HD3 inputs, or saves the refined
    flow under ``args.save_folder``.
    """
    global args
    args = get_parser().parse_args()
    LOGGER.info(args)

    # Get input image size and save name list.
    # Each line of data_list should contain
    # image_0, image_1, (optional) ground truth, (optional) ground truth mask.
    with open(args.data_list, 'r') as file_list:
        fnames = file_list.readlines()
        assert len(
            fnames[0].strip().split(' ')
        ) == 2 + args.evaluate + args.evaluate * args.additional_flow_masks
        # Probe the first image for the native input resolution.
        input_size = cv2.imread(
            os.path.join(args.data_root, fnames[0].split(' ')[0])).shape
        if args.visualize or args.save_inputs or args.save_refined:
            # Output stems and sub-folders mirror the input layout.
            names = [l.strip().split(' ')[0].split('/')[-1] for l in fnames]
            sub_folders = [
                l.strip().split(' ')[0][:-len(names[i])]
                for i, l in enumerate(fnames)
            ]
            names = [l.split('.')[0] for l in names]

    # Prepare data.
    # ImageNet mean/std normalization, matching training.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    target_height, target_width = get_target_size(input_size[0], input_size[1])
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean=mean, std=std)])
    data = hd3data.HD3Data(
        mode='flow',
        data_root=args.data_root,
        data_list=args.data_list,
        label_num=args.evaluate + args.evaluate * args.additional_flow_masks,
        transform=transform,
        out_size=True)
    data_loader = torch.utils.data.DataLoader(
        data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # Setup models.
    model_hd3 = hd3model.HD3Model('flow', args.encoder, args.decoder,
                                  [4, 4, 4, 4, 4], args.context).cuda()
    model_hd3 = torch.nn.DataParallel(model_hd3).cuda()
    model_hd3.eval()

    refinement_network = PPacNet(
        args.kernel_size_preprocessing, args.kernel_size_joint,
        args.conv_specification, args.shared_filters, args.depth_layers_prob,
        args.depth_layers_guidance, args.depth_layers_joint)
    model_refine = refinement_models.EpeNet(refinement_network).cuda()
    model_refine = torch.nn.DataParallel(model_refine).cuda()
    model_refine.eval()

    # Load indicated models.
    # NOTE(review): a missing checkpoint is only logged, not raised — the
    # run then proceeds with uninitialized weights; confirm intended.
    name_hd3_model = args.model_hd3_path
    if os.path.isfile(name_hd3_model):
        checkpoint = torch.load(name_hd3_model)
        model_hd3.load_state_dict(checkpoint['state_dict'])
        LOGGER.info("Loaded HD3 checkpoint '{}'".format(name_hd3_model))
    else:
        LOGGER.info("No checkpoint found at '{}'".format(name_hd3_model))

    name_refinement_model = args.model_refine_path
    if os.path.isfile(name_refinement_model):
        checkpoint = torch.load(name_refinement_model)
        model_refine.load_state_dict(checkpoint['state_dict'])
        LOGGER.info(
            "Loaded refinement checkpoint '{}'".format(name_refinement_model))
    else:
        LOGGER.info(
            "No checkpoint found at '{}'".format(name_refinement_model))

    if args.evaluate:
        # Running averages of end-point error / outlier rate for both the
        # raw HD3 flow and the refined flow.
        epe_hd3 = utils.AverageMeter()
        outliers_hd3 = utils.AverageMeter()
        epe_refined = utils.AverageMeter()
        outliers_refined = utils.AverageMeter()

    if args.visualize:
        visualization_folder = os.path.join(args.save_folder, 'visualizations')
        utils.check_makedirs(visualization_folder)

    if args.save_inputs:
        input_folder = os.path.join(args.save_folder, 'hd3_inputs')
        utils.check_makedirs(input_folder)

    if args.save_refined:
        refined_folder = os.path.join(args.save_folder, 'refined_flow')
        utils.check_makedirs(refined_folder)

    # Start inference.
    with torch.no_grad():
        for i, (img_list, label_list, img_size) in enumerate(data_loader):
            if i % 10 == 0:
                LOGGER.info('Done with {}/{} samples'.format(
                    i, len(data_loader)))

            img_size = img_size.cpu().numpy()
            img_list = [img.to(torch.device("cuda")) for img in img_list]
            label_list = [
                label.to(torch.device("cuda")) for label in label_list
            ]

            # Resize input images to the network's working resolution.
            resized_img_list = [
                torch.nn.functional.interpolate(
                    img, (target_height, target_width),
                    mode='bilinear',
                    align_corners=True) for img in img_list
            ]

            # Get HD3 flow.
            output = model_hd3(
                img_list=resized_img_list,
                label_list=label_list,
                get_full_vect=True,
                get_full_prob=True,
                get_epe=args.evaluate)

            # Upscale flow to full resolution (levels ordered coarse ->
            # fine; the last entry is the finest prediction).
            for level, level_flow in enumerate(output['full_vect']):
                scale_factor = 1 / 2**(6 - level)
                output['full_vect'][level] = resize_dense_vector(
                    level_flow * scale_factor, img_size[0, 1], img_size[0, 0])
            hd3_flow = output['full_vect'][-1]

            # Evaluate HD3 output if required.
            if args.evaluate:
                epe_hd3.update(
                    losses.endpoint_error(hd3_flow, label_list[0]).mean().data,
                    hd3_flow.size(0))
                outliers_hd3.update(
                    losses.outlier_rate(hd3_flow, label_list[0]).mean().data,
                    hd3_flow.size(0))

            # Upscale and interpolate flow probabilities.
            probabilities = prob_utils.get_upsampled_probabilities_hd3(
                output['full_vect'], output['full_prob'])

            if args.save_inputs:
                save_hd3_inputs(
                    hd3_flow, probabilities, input_folder,
                    sub_folders[i * args.batch_size:(i + 1) * args.batch_size],
                    names[i * args.batch_size:(i + 1) * args.batch_size])
                # Inputs saved; skip refinement/evaluation for this batch.
                continue

            # Refine flow with PPAC network.
            log_probabilities = prob_utils.safe_log(probabilities)
            output_refine = model_refine(
                hd3_flow,
                log_probabilities,
                img_list[0],
                label_list=label_list,
                get_loss=args.evaluate,
                get_epe=args.evaluate,
                get_outliers=args.evaluate)

            # Evaluate refined output if required
            if args.evaluate:
                epe_refined.update(output_refine['epe'].mean().data,
                                   hd3_flow.size(0))
                outliers_refined.update(output_refine['outliers'].mean().data,
                                        hd3_flow.size(0))

            # Save visualizations of optical flow if required.
            if args.visualize:
                refined_flow = output_refine['flow']
                ground_truth = None
                if args.evaluate:
                    # First two channels of the label are the flow vectors.
                    ground_truth = label_list[0][:, :2]
                save_visualizations(
                    hd3_flow, refined_flow, ground_truth, visualization_folder,
                    sub_folders[i * args.batch_size:(i + 1) * args.batch_size],
                    names[i * args.batch_size:(i + 1) * args.batch_size])

            # Save refined optical flow if required.
            if args.save_refined:
                refined_flow = output_refine['flow']
                save_refined_flow(
                    refined_flow, refined_folder,
                    sub_folders[i * args.batch_size:(i + 1) * args.batch_size],
                    names[i * args.batch_size:(i + 1) * args.batch_size])

    if args.evaluate:
        LOGGER.info(
            'Accuracy of HD3 optical flow:      '
            'AEE={epe_hd3.avg:.4f}, Outliers={outliers_hd3.avg:.4f}'.format(
                epe_hd3=epe_hd3, outliers_hd3=outliers_hd3))
        if not args.save_inputs:
            LOGGER.info(
                'Accuracy of refined optical flow:  '
                'AEE={epe_refined.avg:.4f}, Outliers={outliers_refined.avg:.4f}'
                .format(
                    epe_refined=epe_refined,
                    outliers_refined=outliers_refined))