def main():
    """Fuse several models' test submissions with ``wpf`` into one submission.

    Reads one CSV per entry of the comma-separated ``models`` option from
    ``outputs/submissions/test/``, replaces every detection's score by its
    normalized within-image rank, merges the per-model detections of each
    image with ``wpf`` and writes the result to
    ``outputs/submissions/test/<name>.csv``. The resolved configuration is
    stored in ``models/detection/<name>/config.yml``.

    Rows are matched by position, so every model CSV is assumed to list the
    images in the same order as ``inputs/sample_submission.csv``.
    """
    config = vars(parse_args())

    if config['name'] is None:
        config['name'] = 'wpf_%s' % datetime.now().strftime('%m%d%H')

    config['models'] = config['models'].split(',')

    if config['weights'] is not None:
        config['weights'] = [float(s) for s in config['weights'].split(',')]

    model_dir = 'models/detection/%s' % config['name']
    os.makedirs(model_dir, exist_ok=True)

    with open('%s/config.yml' % model_dir, 'w') as f:
        yaml.dump(config, f)

    print('-' * 20)
    for key in config.keys():
        print('%s: %s' % (key, str(config[key])))
    print('-' * 20)

    df_list = [
        pd.read_csv('outputs/submissions/test/%s.csv' % p).fillna('')
        for p in config['models']
    ]
    new_df = pd.read_csv('inputs/sample_submission.csv')
    img_paths = np.array('inputs/test_images/' + new_df['ImageId'].values +
                         '.jpg')

    cnt = 0
    for i in tqdm(range(len(new_df))):
        dets_list = []
        for df in df_list:
            # Each detection is serialized as 7 whitespace-separated numbers;
            # the last one is the score.
            dets = np.array(df.loc[i, 'PredictionString'].split()).reshape(
                [-1, 7]).astype('float')
            # Replace raw scores by their rank (1 = lowest score). A single
            # argsort yields the sorting permutation, which equals the rank
            # only when the rows are already ordered by score; the double
            # argsort gives the rank for any row order.
            dets[..., -1] = np.argsort(np.argsort(dets[..., -1])) + 1
            # Normalize so that the ranks of one model sum to 1 per image.
            dets[..., -1] /= np.sum(dets[..., -1])
            dets_list.append(dets)
        dets = wpf(dets_list,
                   dist_th=config['dist_th'],
                   skip_det_th=config['skip_det_th'],
                   weights=config['weights'])
        dets = dets[dets[:, 6] > config['score_th']]
        cnt += len(dets)

        if config['show']:
            img = cv2.imread(img_paths[i])
            img_pred = visualize(img, dets)
            # Reverse the channel axis: OpenCV loads BGR, matplotlib shows RGB.
            plt.imshow(img_pred[..., ::-1])
            plt.show()

        new_df.loc[i, 'PredictionString'] = convert_labels_to_str(dets)

    print('Number of cars: %d' % cnt)

    new_df.to_csv('outputs/submissions/test/%s.csv' % config['name'],
                  index=False)
示例#2
0
def main():
    """Draw the detections of a saved test submission on their images.

    Loads ``outputs/submissions/test/<args.name>.csv`` and, for every row,
    parses the prediction string into an ``(N, 7)`` float array and renders
    it with ``visualize``. The result is shown with matplotlib, or written
    to ``tmp/<args.name>/`` when ``args.write`` is set. With
    ``args.uncropped`` the images listed in
    ``inputs/testset_cropped_imageids.csv`` are read from
    ``inputs/test_images_uncropped/`` instead.
    """
    args = parse_args()

    submission = pd.read_csv(
        'outputs/submissions/test/%s.csv' % args.name).fillna('')
    image_ids = submission['ImageId'].values
    image_paths = np.array(
        'inputs/test_images/' + submission['ImageId'].values + '.jpg')

    if args.uncropped:
        cropped_ids = pd.read_csv(
            'inputs/testset_cropped_imageids.csv')['ImageId'].values
        for idx, image_id in enumerate(image_ids):
            if image_id not in cropped_ids:
                continue
            image_paths[idx] = (
                'inputs/test_images_uncropped/' + image_id + '.jpg')

    out_dir = os.path.join('tmp', args.name)
    os.makedirs(out_dir, exist_ok=True)

    for idx in tqdm(range(len(submission))):
        tokens = submission.loc[idx, 'PredictionString'].split()
        dets = np.array(tokens).reshape([-1, 7]).astype('float')

        image = cv2.imread(image_paths[idx])
        rendered = visualize(image, dets)

        if args.write:
            target = os.path.join(out_dir, os.path.basename(image_paths[idx]))
            cv2.imwrite(target, rendered)
        else:
            # Reverse the channel axis: OpenCV loads BGR, matplotlib shows RGB.
            plt.imshow(rendered[..., ::-1])
            plt.show()
def main():
    """Average the raw validation outputs of several models and decode them.

    For every cross-validation fold, the per-image raw head outputs saved by
    each model listed in ``config['models']``
    (``outputs/raw/val/<model>_<fold>.pth``) are averaged, cached as
    ``outputs/raw/val/<name>_<fold>.pth`` and decoded into detections.
    The decoded detections are written to ``outputs/decoded/val/<name>.json``
    and the filtered predictions to ``outputs/submissions/val/<name>.csv``.

    If ``models/detection/<name>/config.yml`` already exists it replaces the
    command-line configuration entirely. The folds are re-created with the
    first model's ``n_splits`` and a fixed ``random_state``; this is assumed
    to match the split used when the per-model outputs were produced.
    """
    config = vars(parse_args())

    if config['name'] is None:
        config['name'] = 'ensemble_%s' % datetime.now().strftime('%m%d%H')

    if os.path.exists('models/detection/%s/config.yml' % config['name']):
        with open('models/detection/%s/config.yml' % config['name'], 'r') as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    else:
        config['models'] = config['models'].split(',')

    os.makedirs('models/detection/%s' % config['name'], exist_ok=True)

    with open('models/detection/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    print('-'*20)
    for key in config.keys():
        print('%s: %s' % (key, str(config[key])))
    print('-'*20)

    # The head layout (rotation encoding, optional 'wh' head, fold count) is
    # taken from the first model; all ensembled models are assumed to share it.
    with open('models/detection/%s/config.yml' % config['models'][0], 'r') as f:
        model_config = yaml.load(f, Loader=yaml.FullLoader)

    df = pd.read_csv('inputs/train.csv')
    img_paths = np.array('inputs/train_images/' + df['ImageId'].values + '.jpg')
    img_ids = df['ImageId'].values

    # Only the heads that are actually present are accumulated; the other
    # entries of the merged dict stay ``None``.
    head_keys = ['hm', 'reg', 'depth']
    if model_config['rot'] in ('eular', 'trig', 'quat'):
        head_keys.append(model_config['rot'])
    if model_config['wh']:
        head_keys.append('wh')
    head_keys.append('mask')

    num_models = len(config['models'])

    dets = {}
    kf = KFold(n_splits=model_config['n_splits'], shuffle=True, random_state=41)
    for fold, (train_idx, val_idx) in enumerate(kf.split(img_paths)):
        val_img_ids = img_ids[val_idx]

        # The cache is stored per fold, so it must also be looked up per fold.
        merged_path = 'outputs/raw/val/%s_%d.pth' % (config['name'], fold + 1)
        if os.path.exists(merged_path):
            merged_outputs = torch.load(merged_path)

        else:
            merged_outputs = {}
            for img_id in tqdm(val_img_ids, total=len(val_img_ids)):
                output = {
                    'hm': 0,
                    'reg': 0,
                    'depth': 0,
                    'eular': 0 if model_config['rot'] == 'eular' else None,
                    'trig': 0 if model_config['rot'] == 'trig' else None,
                    'quat': 0 if model_config['rot'] == 'quat' else None,
                    'wh': 0 if model_config['wh'] else None,
                    'mask': 0,
                }

                merged_outputs[img_id] = output

            for model_name in config['models']:
                outputs = torch.load('outputs/raw/val/%s_%d.pth' %(model_name, fold + 1))

                for img_id in tqdm(val_img_ids, total=len(val_img_ids)):
                    output = outputs[img_id]
                    merged = merged_outputs[img_id]

                    # Running mean over models for every active head.
                    for key in head_keys:
                        merged[key] += output[key] / num_models

            torch.save(merged_outputs, merged_path)

        # decode
        for img_id in tqdm(val_img_ids, total=len(val_img_ids)):
            output = merged_outputs[img_id]

            det = decode(
                model_config,
                output['hm'],
                output['reg'],
                output['depth'],
                eular=output['eular'] if model_config['rot'] == 'eular' else None,
                trig=output['trig'] if model_config['rot'] == 'trig' else None,
                quat=output['quat'] if model_config['rot'] == 'quat' else None,
                wh=output['wh'] if model_config['wh'] else None,
                mask=output['mask'],
            )
            det = det.numpy()[0]

            # The JSON dump keeps the unfiltered detections (before NMS and
            # score thresholding).
            dets[img_id] = det.tolist()

            if config['nms']:
                det = nms(det, dist_th=config['nms_th'])

            # Keep everything above the score threshold, but never fewer than
            # the first ``min_samples`` rows.
            if np.sum(det[:, 6] > config['score_th']) >= config['min_samples']:
                det = det[det[:, 6] > config['score_th']]
            else:
                det = det[:config['min_samples']]

            if config['show']:
                img = cv2.imread('inputs/train_images/%s.jpg' %img_id)
                img_pred = visualize(img, det)
                plt.imshow(img_pred[..., ::-1])
                plt.show()

            df.loc[df.ImageId == img_id, 'PredictionString'] = convert_labels_to_str(det[:, :7])

    with open('outputs/decoded/val/%s.json' %config['name'], 'w') as f:
        json.dump(dets, f)

    df.to_csv('outputs/submissions/val/%s.csv' %config['name'], index=False)
示例#4
0
def main():
    """Run test-set inference for a detection model and write a submission.

    Loads ``models/detection/<args.name>/config.yml``, runs every fold model
    whose weights exist over the test images (optionally with horizontal-flip
    test-time augmentation), averages the raw head outputs across those
    folds, caches them in ``outputs/raw/test/<name>.pth``, decodes them and
    writes ``outputs/decoded/test/<name>.json`` plus a submission CSV under
    ``outputs/submissions/test/``.

    When ``config['cv']`` is false only the first fold with weights is run,
    its predictions are written directly and the function returns early.

    Raises:
        NotImplementedError: if ``config['rot']`` is not one of ``'eular'``,
            ``'trig'`` or ``'quat'``.
        FileNotFoundError: if no fold weights exist for the model.
    """
    args = parse_args()
    # NOTE(review): this unconditionally overrides the command-line value, so
    # every ``if not args.uncropped`` branch below is currently unreachable.
    args.uncropped = True

    with open('models/detection/%s/config.yml' % args.name, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    print('-'*20)
    for key in config.keys():
        print('%s: %s' % (key, str(config[key])))
    print('-'*20)

    cudnn.benchmark = False

    df = pd.read_csv('inputs/sample_submission.csv')
    img_ids = df['ImageId'].values
    img_paths = np.array('inputs/test_images/' + df['ImageId'].values + '.jpg')
    mask_paths = np.array('inputs/test_masks/' + df['ImageId'].values + '.jpg')
    labels = np.array([convert_str_to_labels(s, names=['yaw', 'pitch', 'roll',
                       'x', 'y', 'z', 'score']) for s in df['PredictionString']])

    if not args.uncropped:
        cropped_img_ids = pd.read_csv('inputs/testset_cropped_imageids.csv')['ImageId'].values
        for i, img_id in enumerate(img_ids):
            if img_id in cropped_img_ids:
                img_paths[i] = 'inputs/test_images_uncropped/' + img_id + '.jpg'
                mask_paths[i] = 'inputs/test_masks_uncropped/' + img_id + '.jpg'

    test_set = Dataset(
        img_paths,
        mask_paths,
        labels,
        input_w=config['input_w'],
        input_h=config['input_h'],
        transform=None,
        test=True,
        lhalf=config['lhalf'])
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=16,
        shuffle=False,
        num_workers=0,
    )

    # Output heads of the network: name -> number of channels.
    heads = OrderedDict([
        ('hm', 1),
        ('reg', 2),
        ('depth', 1),
    ])

    if config['rot'] == 'eular':
        heads['eular'] = 3
    elif config['rot'] == 'trig':
        heads['trig'] = 6
    elif config['rot'] == 'quat':
        heads['quat'] = 4
    else:
        raise NotImplementedError

    if config['wh']:
        heads['wh'] = 2

    if config['tvec']:
        heads['tvec'] = 3

    # Name of the raw-output cache; it encodes the options that change the
    # raw outputs.
    name = args.name
    if args.uncropped:
        name += '_uncropped'
    if args.hflip:
        name += '_hf'

    if os.path.exists('outputs/raw/test/%s.pth' %name):
        merged_outputs = torch.load('outputs/raw/test/%s.pth' %name)

    else:
        merged_outputs = {}
        for i in tqdm(range(len(df))):
            img_id = df.loc[i, 'ImageId']

            # 'mask' is not accumulated; it is assigned per image below.
            output = {
                'hm': 0,
                'reg': 0,
                'depth': 0,
                'eular': 0 if config['rot'] == 'eular' else None,
                'trig': 0 if config['rot'] == 'trig' else None,
                'quat': 0 if config['rot'] == 'quat' else None,
                'wh': 0 if config['wh'] else None,
                'tvec': 0 if config['tvec'] else None,
            }

            merged_outputs[img_id] = output

        model_paths = [
            'models/detection/%s/model_%d.pth' % (config['name'], fold + 1)
            for fold in range(config['n_splits'])
        ]
        # Average over the folds that actually contribute. Dividing by
        # ``n_splits`` while skipping folds would scale every head down.
        num_folds = sum(os.path.exists(p) for p in model_paths)
        if num_folds == 0:
            # Fail before a useless cache file is written.
            raise FileNotFoundError(
                'no fold weights found under models/detection/%s'
                % config['name'])

        for fold, model_path in enumerate(model_paths):
            print('Fold [%d/%d]' %(fold + 1, config['n_splits']))

            # Check for the weights first so that no model is built and moved
            # to the GPU for a fold that is going to be skipped.
            if not os.path.exists(model_path):
                print('%s is not exists.' %model_path)
                continue

            model = get_model(config['arch'], heads=heads,
                              head_conv=config['head_conv'],
                              num_filters=config['num_filters'],
                              dcn=config['dcn'],
                              gn=config['gn'], ws=config['ws'],
                              freeze_bn=config['freeze_bn'])
            model = model.cuda()
            model.load_state_dict(torch.load(model_path))

            model.eval()

            preds_fold = []
            outputs_fold = {}
            with torch.no_grad():
                pbar = tqdm(total=len(test_loader))
                for i, batch in enumerate(test_loader):
                    inputs = batch['input'].cuda()
                    mask = batch['mask'].cuda()

                    output = model(inputs)

                    if args.hflip:
                        # Horizontal-flip TTA: predict on the mirrored image,
                        # mirror the maps back and average with the plain
                        # prediction.
                        # NOTE(review): only the 'trig' rotation head is
                        # un-mirrored and averaged; 'eular' and 'quat' keep
                        # the unflipped prediction.
                        output_hf = model(torch.flip(inputs, (-1,)))
                        output_hf['hm'] = torch.flip(output_hf['hm'], (-1,))
                        output_hf['reg'] = torch.flip(output_hf['reg'], (-1,))
                        output_hf['reg'][:, 0] = 1 - output_hf['reg'][:, 0]
                        output_hf['depth'] = torch.flip(output_hf['depth'], (-1,))
                        if config['rot'] == 'trig':
                            output_hf['trig'] = torch.flip(output_hf['trig'], (-1,))
                            # Channels 0/1 are read as (cos, sin) of yaw and
                            # channels 4/5 as (cos, sin) of roll.
                            yaw = torch.atan2(output_hf['trig'][:, 1], output_hf['trig'][:, 0])
                            yaw *= -1.0
                            output_hf['trig'][:, 0] = torch.cos(yaw)
                            output_hf['trig'][:, 1] = torch.sin(yaw)
                            roll = torch.atan2(output_hf['trig'][:, 5], output_hf['trig'][:, 4])
                            roll = rotate(roll, -np.pi)
                            roll *= -1.0
                            roll = rotate(roll, np.pi)
                            output_hf['trig'][:, 4] = torch.cos(roll)
                            output_hf['trig'][:, 5] = torch.sin(roll)

                        if config['wh']:
                            output_hf['wh'] = torch.flip(output_hf['wh'], (-1,))

                        if config['tvec']:
                            output_hf['tvec'] = torch.flip(output_hf['tvec'], (-1,))
                            output_hf['tvec'][:, 0] *= -1.0

                        output['hm'] = (output['hm'] + output_hf['hm']) / 2
                        output['reg'] = (output['reg'] + output_hf['reg']) / 2
                        output['depth'] = (output['depth'] + output_hf['depth']) / 2
                        if config['rot'] == 'trig':
                            output['trig'] = (output['trig'] + output_hf['trig']) / 2
                        if config['wh']:
                            output['wh'] = (output['wh'] + output_hf['wh']) / 2
                        if config['tvec']:
                            output['tvec'] = (output['tvec'] + output_hf['tvec']) / 2

                    for b in range(len(batch['img_path'])):
                        img_id = os.path.splitext(os.path.basename(batch['img_path'][b]))[0]

                        # Slice with b:b+1 to keep the batch dimension.
                        outputs_fold[img_id] = {
                            'hm': output['hm'][b:b+1].cpu(),
                            'reg': output['reg'][b:b+1].cpu(),
                            'depth': output['depth'][b:b+1].cpu(),
                            'eular': output['eular'][b:b+1].cpu() if config['rot'] == 'eular' else None,
                            'trig': output['trig'][b:b+1].cpu() if config['rot'] == 'trig' else None,
                            'quat': output['quat'][b:b+1].cpu() if config['rot'] == 'quat' else None,
                            'wh': output['wh'][b:b+1].cpu() if config['wh'] else None,
                            'tvec': output['tvec'][b:b+1].cpu() if config['tvec'] else None,
                            'mask': mask[b:b+1].cpu(),
                        }

                        merged_outputs[img_id]['hm'] += outputs_fold[img_id]['hm'] / num_folds
                        merged_outputs[img_id]['reg'] += outputs_fold[img_id]['reg'] / num_folds
                        merged_outputs[img_id]['depth'] += outputs_fold[img_id]['depth'] / num_folds
                        if config['rot'] == 'eular':
                            merged_outputs[img_id]['eular'] += outputs_fold[img_id]['eular'] / num_folds
                        if config['rot'] == 'trig':
                            merged_outputs[img_id]['trig'] += outputs_fold[img_id]['trig'] / num_folds
                        if config['rot'] == 'quat':
                            merged_outputs[img_id]['quat'] += outputs_fold[img_id]['quat'] / num_folds
                        if config['wh']:
                            merged_outputs[img_id]['wh'] += outputs_fold[img_id]['wh'] / num_folds
                        if config['tvec']:
                            merged_outputs[img_id]['tvec'] += outputs_fold[img_id]['tvec'] / num_folds
                        merged_outputs[img_id]['mask'] = outputs_fold[img_id]['mask']

                    batch_det = decode(
                        config,
                        output['hm'],
                        output['reg'],
                        output['depth'],
                        eular=output['eular'] if config['rot'] == 'eular' else None,
                        trig=output['trig'] if config['rot'] == 'trig' else None,
                        quat=output['quat'] if config['rot'] == 'quat' else None,
                        wh=output['wh'] if config['wh'] else None,
                        tvec=output['tvec'] if config['tvec'] else None,
                        mask=mask,
                    )
                    batch_det = batch_det.cpu().numpy()

                    for k, det in enumerate(batch_det):
                        if args.nms:
                            det = nms(det, dist_th=args.nms_th)
                        preds_fold.append(convert_labels_to_str(det[det[:, 6] > args.score_th, :7]))

                        if args.show and not config['cv']:
                            img = cv2.imread(batch['img_path'][k])
                            img_pred = visualize(img, det[det[:, 6] > args.score_th])
                            plt.imshow(img_pred[..., ::-1])
                            plt.show()

                    pbar.update(1)
                pbar.close()

            if not config['cv']:
                # Single-model run: write this fold's predictions and stop.
                df['PredictionString'] = preds_fold
                name = '%s_1_%.2f' %(args.name, args.score_th)
                if args.uncropped:
                    name += '_uncropped'
                if args.nms:
                    name += '_nms%.2f' %args.nms_th
                df.to_csv('outputs/submissions/test/%s.csv' %name, index=False)
                return

        if not args.uncropped:
            # ensemble duplicate images
            dup_df = pd.read_csv('processed/test_image_hash.csv')
            dups = dup_df.hash.value_counts()
            dups = dups.loc[dups>1]

            for i in range(len(dups)):
                dup_img_ids = dup_df[dup_df.hash == dups.index[i]].ImageId

                output = {
                    'hm': 0,
                    'reg': 0,
                    'depth': 0,
                    'eular': 0 if config['rot'] == 'eular' else None,
                    'trig': 0 if config['rot'] == 'trig' else None,
                    'quat': 0 if config['rot'] == 'quat' else None,
                    'wh': 0 if config['wh'] else None,
                    'tvec': 0 if config['tvec'] else None,
                    'mask': 0,
                }
                for img_id in dup_img_ids:
                    if img_id in cropped_img_ids:
                        # Debug marker: a duplicated image is also a cropped one.
                        print('fooo')
                    output['hm'] += merged_outputs[img_id]['hm'] / len(dup_img_ids)
                    output['reg'] += merged_outputs[img_id]['reg'] / len(dup_img_ids)
                    output['depth'] += merged_outputs[img_id]['depth'] / len(dup_img_ids)
                    if config['rot'] == 'eular':
                        output['eular'] += merged_outputs[img_id]['eular'] / len(dup_img_ids)
                    if config['rot'] == 'trig':
                        output['trig'] += merged_outputs[img_id]['trig'] / len(dup_img_ids)
                    if config['rot'] == 'quat':
                        output['quat'] += merged_outputs[img_id]['quat'] / len(dup_img_ids)
                    if config['wh']:
                        output['wh'] += merged_outputs[img_id]['wh'] / len(dup_img_ids)
                    if config['tvec']:
                        output['tvec'] += merged_outputs[img_id]['tvec'] / len(dup_img_ids)
                    output['mask'] += merged_outputs[img_id]['mask'] / len(dup_img_ids)

                # All duplicates share the same averaged output.
                for img_id in dup_img_ids:
                    merged_outputs[img_id] = output

        torch.save(merged_outputs, 'outputs/raw/test/%s.pth' %name)

    # decode
    dets = {}
    for i in tqdm(range(len(df))):
        img_id = df.loc[i, 'ImageId']

        output = merged_outputs[img_id]

        det = decode(
            config,
            output['hm'],
            output['reg'],
            output['depth'],
            eular=output['eular'] if config['rot'] == 'eular' else None,
            trig=output['trig'] if config['rot'] == 'trig' else None,
            quat=output['quat'] if config['rot'] == 'quat' else None,
            wh=output['wh'] if config['wh'] else None,
            tvec=output['tvec'] if config['tvec'] else None,
            mask=output['mask'],
        )
        det = det.numpy()[0]

        # The JSON dump keeps the unfiltered detections (before NMS and score
        # thresholding).
        dets[img_id] = det.tolist()

        if args.nms:
            det = nms(det, dist_th=args.nms_th)

        # Keep everything above the score threshold, but never fewer than the
        # first ``min_samples`` rows.
        if np.sum(det[:, 6] > args.score_th) >= args.min_samples:
            det = det[det[:, 6] > args.score_th]
        else:
            det = det[:args.min_samples]

        if args.show:
            img = cv2.imread('inputs/test_images/%s.jpg' %img_id)
            img_pred = visualize(img, det)
            plt.imshow(img_pred[..., ::-1])
            plt.show()

        df.loc[i, 'PredictionString'] = convert_labels_to_str(det[:, :7])

    with open('outputs/decoded/test/%s.json' %name, 'w') as f:
        json.dump(dets, f)

    # The submission file name encodes the post-processing options.
    name = '%s_%.2f' %(args.name, args.score_th)
    if args.uncropped:
        name += '_uncropped'
    if args.nms:
        name += '_nms%.2f' %args.nms_th
    if args.hflip:
        name += '_hf'
    if args.min_samples > 0:
        name += '_min%d' %args.min_samples
    df.to_csv('outputs/submissions/test/%s.csv' %name, index=False)
示例#5
0
def main():
    """Run out-of-fold validation inference for a detection model.

    Re-creates the K-fold split of ``inputs/train.csv`` and, for each fold
    whose weights exist, predicts on that fold's validation images
    (optionally with horizontal-flip test-time augmentation). Per fold the
    raw head outputs are saved to ``outputs/raw/val/<name>_<fold>.pth``; at
    the end the decoded detections go to ``outputs/decoded/val/<name>.json``
    and the thresholded predictions to a CSV under
    ``outputs/submissions/val/``.

    When ``config['cv']`` is false the loop stops after the first fold that
    was actually evaluated.

    Raises:
        NotImplementedError: if ``config['rot']`` is not one of ``'eular'``,
            ``'trig'`` or ``'quat'``.
    """
    args = parse_args()

    with open('models/detection/%s/config.yml' % args.name, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    print('-' * 20)
    for key in config.keys():
        print('%s: %s' % (key, str(config[key])))
    print('-' * 20)

    cudnn.benchmark = True

    df = pd.read_csv('inputs/train.csv')
    img_paths = np.array('inputs/train_images/' + df['ImageId'].values +
                         '.jpg')
    mask_paths = np.array('inputs/train_masks/' + df['ImageId'].values +
                          '.jpg')
    labels = np.array(
        [convert_str_to_labels(s) for s in df['PredictionString']])

    # Output heads of the network: name -> number of channels.
    heads = OrderedDict([
        ('hm', 1),
        ('reg', 2),
        ('depth', 1),
    ])

    if config['rot'] == 'eular':
        heads['eular'] = 3
    elif config['rot'] == 'trig':
        heads['trig'] = 6
    elif config['rot'] == 'quat':
        heads['quat'] = 4
    else:
        raise NotImplementedError

    if config['wh']:
        heads['wh'] = 2

    # Predictions start out as NaN and are filled in image by image.
    pred_df = df.copy()
    pred_df['PredictionString'] = np.nan

    dets = {}
    kf = KFold(n_splits=config['n_splits'], shuffle=True, random_state=41)
    for fold, (train_idx, val_idx) in enumerate(kf.split(img_paths)):
        print('Fold [%d/%d]' % (fold + 1, config['n_splits']))

        # Check for the weights first so that neither the dataset nor a GPU
        # model is built for a fold that is going to be skipped.
        model_path = 'models/detection/%s/model_%d.pth' % (config['name'],
                                                           fold + 1)
        if not os.path.exists(model_path):
            print('%s is not exists.' % model_path)
            continue

        val_set = Dataset(img_paths[val_idx],
                          mask_paths[val_idx],
                          labels[val_idx],
                          input_w=config['input_w'],
                          input_h=config['input_h'],
                          transform=None,
                          lhalf=config['lhalf'])
        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_workers'],
        )

        model = get_model(config['arch'],
                          heads=heads,
                          head_conv=config['head_conv'],
                          num_filters=config['num_filters'],
                          dcn=config['dcn'],
                          gn=config['gn'],
                          ws=config['ws'],
                          freeze_bn=config['freeze_bn'])
        model = model.cuda()
        model.load_state_dict(torch.load(model_path))

        model.eval()

        outputs = {}

        with torch.no_grad():
            pbar = tqdm(total=len(val_loader))
            for i, batch in enumerate(val_loader):
                # Only the image and its mask are needed for inference; the
                # training targets in the batch stay on the CPU.
                inputs = batch['input'].cuda()
                mask = batch['mask'].cuda()

                output = model(inputs)

                if args.hflip:
                    # Horizontal-flip TTA: predict on the mirrored image,
                    # mirror the maps back and average with the plain
                    # prediction.
                    # NOTE(review): only the 'trig' rotation head is
                    # un-mirrored and averaged; 'eular' and 'quat' keep the
                    # unflipped prediction.
                    output_hf = model(torch.flip(inputs, (-1, )))
                    output_hf['hm'] = torch.flip(output_hf['hm'], (-1, ))
                    output_hf['reg'] = torch.flip(output_hf['reg'], (-1, ))
                    output_hf['reg'][:, 0] = 1 - output_hf['reg'][:, 0]
                    output_hf['depth'] = torch.flip(output_hf['depth'], (-1, ))
                    if config['rot'] == 'trig':
                        output_hf['trig'] = torch.flip(output_hf['trig'],
                                                       (-1, ))
                        # Channels 0/1 are read as (cos, sin) of yaw and
                        # channels 4/5 as (cos, sin) of roll.
                        yaw = torch.atan2(output_hf['trig'][:, 1],
                                          output_hf['trig'][:, 0])
                        yaw *= -1.0
                        output_hf['trig'][:, 0] = torch.cos(yaw)
                        output_hf['trig'][:, 1] = torch.sin(yaw)
                        roll = torch.atan2(output_hf['trig'][:, 5],
                                           output_hf['trig'][:, 4])
                        roll = rotate(roll, -np.pi)
                        roll *= -1.0
                        roll = rotate(roll, np.pi)
                        output_hf['trig'][:, 4] = torch.cos(roll)
                        output_hf['trig'][:, 5] = torch.sin(roll)

                    if config['wh']:
                        output_hf['wh'] = torch.flip(output_hf['wh'], (-1, ))

                    output['hm'] = (output['hm'] + output_hf['hm']) / 2
                    output['reg'] = (output['reg'] + output_hf['reg']) / 2
                    output['depth'] = (output['depth'] +
                                       output_hf['depth']) / 2
                    if config['rot'] == 'trig':
                        output['trig'] = (output['trig'] +
                                          output_hf['trig']) / 2
                    if config['wh']:
                        output['wh'] = (output['wh'] + output_hf['wh']) / 2

                batch_det = decode(
                    config,
                    output['hm'],
                    output['reg'],
                    output['depth'],
                    eular=output['eular']
                    if config['rot'] == 'eular' else None,
                    trig=output['trig'] if config['rot'] == 'trig' else None,
                    quat=output['quat'] if config['rot'] == 'quat' else None,
                    wh=output['wh'] if config['wh'] else None,
                    mask=mask,
                )
                batch_det = batch_det.cpu().numpy()

                for k, det in enumerate(batch_det):
                    img_id = os.path.splitext(
                        os.path.basename(batch['img_path'][k]))[0]

                    # Slice with k:k + 1 to keep the batch dimension.
                    outputs[img_id] = {
                        'hm':
                        output['hm'][k:k + 1].cpu(),
                        'reg':
                        output['reg'][k:k + 1].cpu(),
                        'depth':
                        output['depth'][k:k + 1].cpu(),
                        'eular':
                        output['eular'][k:k + 1].cpu()
                        if config['rot'] == 'eular' else None,
                        'trig':
                        output['trig'][k:k + 1].cpu()
                        if config['rot'] == 'trig' else None,
                        'quat':
                        output['quat'][k:k + 1].cpu()
                        if config['rot'] == 'quat' else None,
                        'wh':
                        output['wh'][k:k + 1].cpu() if config['wh'] else None,
                        'mask':
                        mask[k:k + 1].cpu(),
                    }

                    # The JSON dump keeps the unfiltered detections (before
                    # NMS and score thresholding).
                    dets[img_id] = det.tolist()
                    if args.nms:
                        det = nms(det, dist_th=args.nms_th)
                    pred_df.loc[pred_df.ImageId == img_id,
                                'PredictionString'] = convert_labels_to_str(
                                    det[det[:, 6] > args.score_th, :7])

                    if args.show:
                        gt = batch['gt'].numpy()[k]

                        img = cv2.imread(batch['img_path'][k])
                        img_gt = visualize(img, gt[gt[:, -1] > 0])
                        img_pred = visualize(img,
                                             det[det[:, 6] > args.score_th])

                        # Ground truth on the left, prediction on the right.
                        plt.subplot(121)
                        plt.imshow(img_gt[..., ::-1])
                        plt.subplot(122)
                        plt.imshow(img_pred[..., ::-1])
                        plt.show()

                pbar.update(1)
            pbar.close()

        torch.save(outputs,
                   'outputs/raw/val/%s_%d.pth' % (args.name, fold + 1))

        torch.cuda.empty_cache()

        if not config['cv']:
            break

    with open('outputs/decoded/val/%s.json' % args.name, 'w') as f:
        json.dump(dets, f)

    # The submission file name encodes the post-processing options.
    name = '%s_%.2f' % (args.name, args.score_th)
    if args.nms:
        name += '_nms%.2f' % args.nms_th
    if args.hflip:
        name += '_hf'
    pred_df.to_csv('outputs/submissions/val/%s.csv' % name, index=False)
    print(pred_df.head())
def main():
    """Re-estimate the rotation of validation detections with a pose model.

    Steps, as performed below:

    1. Load the pose-model config from ``models/pose/<pose_name>/config.yml``
       and the decoded detections from
       ``outputs/decoded/val/<det_name>.json``.
    2. For every validation image of a fold, keep detections whose score
       exceeds ``args.score_th`` (optionally applying ``nms``), crop a box
       around each detection and write the crop to
       ``processed/pose_images/val/<name>/``.
    3. Run the fold's pose model on the crops and overwrite the first three
       columns (pitch, yaw, roll) of every detection that has a valid crop.
    4. Write the result to ``outputs/submissions/val/<name>_<pose_name>.csv``.

    Raises:
        NotImplementedError: if ``config['rot']`` is not a supported mode.
            Only ``'trig'`` outputs can be decoded into angles here.
        ValueError: if no fold checkpoint was found, so nothing was evaluated.
    """
    args = parse_args()

    with open('models/pose/%s/config.yml' % args.pose_name, 'r') as f:
        # NOTE(review): the config is assumed to be a trusted file written by
        # the training script; do not point this at untrusted YAML.
        config = yaml.load(f, Loader=yaml.FullLoader)

    print('-' * 20)
    for key, value in config.items():
        print('%s: %s' % (key, str(value)))
    print('-' * 20)

    cudnn.benchmark = True

    train_df = pd.read_csv('inputs/train.csv')
    img_ids = train_df['ImageId'].values
    img_paths = np.array('inputs/train_images/' + train_df['ImageId'].values +
                         '.jpg')
    with open('outputs/decoded/val/%s.json' % args.det_name, 'r') as f:
        dets = json.load(f)

    if config['rot'] == 'eular':
        num_outputs = 3
    elif config['rot'] == 'trig':
        num_outputs = 6
    elif config['rot'] == 'quat':
        num_outputs = 4
    else:
        raise NotImplementedError

    test_transform = Compose([
        transforms.Resize(config['input_w'], config['input_h']),
        transforms.Normalize(),
        ToTensor(),
    ])

    name = '%s_%.2f' % (args.det_name, args.score_th)
    if args.nms:
        name += '_nms%.2f' % args.nms_th

    output_dir = 'processed/pose_images/val/%s' % name
    os.makedirs(output_dir, exist_ok=True)

    fold_results = []
    kf = KFold(n_splits=config['n_splits'], shuffle=True, random_state=41)
    for fold, (_, val_idx) in enumerate(kf.split(img_paths)):
        print('Fold [%d/%d]' % (fold + 1, config['n_splits']))

        # Check for the checkpoint first so no model is built (and moved to
        # the GPU) for a fold that is going to be skipped anyway.
        model_path = 'models/pose/%s/model_%d.pth' % (config['name'], fold + 1)
        if not os.path.exists(model_path):
            print('%s is not exists.' % model_path)
            continue

        # create model
        model = get_pose_model(config['arch'],
                               num_outputs=num_outputs,
                               freeze_bn=config['freeze_bn'])
        model = model.cuda()
        model.load_state_dict(torch.load(model_path))
        model.eval()

        val_img_ids = img_ids[val_idx]
        val_img_paths = img_paths[val_idx]

        fold_det_df = {
            'ImageId': [],
            'img_path': [],
            'det': [],
            'mask': [],
        }

        for img_id, img_path in tqdm(zip(val_img_ids, val_img_paths),
                                     total=len(val_img_ids)):
            img = cv2.imread(img_path)
            height, width = img.shape[:2]

            # Each detection row is unpacked into 9 values below; reshaping
            # keeps an empty detection list 2-D so column indexing works.
            det = np.array(dets[img_id]).reshape(-1, 9)
            det = det[det[:, 6] > args.score_th]
            if args.nms:
                det = nms(det, dist_th=args.nms_th)

            for k in range(len(det)):
                _, _, _, x, y, z, _, w, h = det[k]

                fold_det_df['ImageId'].append(img_id)
                fold_det_df['det'].append(det[k])
                output_path = '%s_%d.jpg' % (img_id, k)
                fold_det_df['img_path'].append(output_path)

                x, y = convert_3d_to_2d(x, y, z)
                # Enlarge the box by 10% before cropping.
                w *= 1.1
                h *= 1.1
                # Clamp to the image: a negative slice bound would wrap
                # around and yield an empty or unrelated crop for boxes that
                # extend past the top/left edge.
                xmin = min(max(int(round(x - w / 2)), 0), width)
                xmax = min(max(int(round(x + w / 2)), 0), width)
                ymin = min(max(int(round(y - h / 2)), 0), height)
                ymax = min(max(int(round(y + h / 2)), 0), height)

                cropped_img = img[ymin:ymax, xmin:xmax]
                if cropped_img.shape[0] > 0 and cropped_img.shape[1] > 0:
                    cv2.imwrite(os.path.join(output_dir, output_path),
                                cropped_img)
                    fold_det_df['mask'].append(1)
                else:
                    # No usable crop: the detection keeps its original pose.
                    fold_det_df['mask'].append(0)

        fold_det_df = pd.DataFrame(fold_det_df)

        test_set = PoseDataset(output_dir + '/' +
                               fold_det_df['img_path'].values,
                               fold_det_df['det'].values,
                               transform=test_transform,
                               masks=fold_det_df['mask'].values)
        test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_workers'],
            # pin_memory=True,
        )

        fold_dets = []
        with torch.no_grad():
            for inputs, batch_det, mask in tqdm(test_loader,
                                                total=len(test_loader)):
                inputs = inputs.cuda()
                batch_det = batch_det.numpy()
                # The mask is built from 0/1 values above; cast to bool so it
                # selects rows instead of being used as integer row indices.
                # NOTE(review): assumes PoseDataset yields one mask value per
                # sample -- confirm against the dataset implementation.
                mask = mask.numpy().astype(bool)

                output = model(inputs)
                output = output.cpu()

                if config['rot'] == 'trig':
                    yaw = torch.atan2(output[..., 1:2], output[..., 0:1])
                    pitch = torch.atan2(output[..., 3:4], output[..., 2:3])
                    roll = torch.atan2(output[..., 5:6], output[..., 4:5])
                    roll = rotate(roll, -np.pi)
                else:
                    raise NotImplementedError(
                        "decoding of rot=%r is not implemented" %
                        config['rot'])

                pitch = pitch.cpu().numpy()[:, 0]
                yaw = yaw.cpu().numpy()[:, 0]
                roll = roll.cpu().numpy()[:, 0]

                batch_det[mask, 0] = pitch[mask]
                batch_det[mask, 1] = yaw[mask]
                batch_det[mask, 2] = roll[mask]

                fold_dets.append(batch_det)

        fold_dets = np.vstack(fold_dets)

        # Regroup the per-detection rows into one list of detections per image.
        fold_det_df['det'] = fold_dets.tolist()
        fold_det_df = fold_det_df.groupby('ImageId')['det'].apply(list)
        fold_det_df = pd.DataFrame({
            'ImageId': fold_det_df.index.values,
            'PredictionString': fold_det_df.values,
        })

        fold_results.append(fold_det_df)
        # NOTE(review): only the first fold with an existing checkpoint is
        # evaluated -- confirm this is intended and not a debugging leftover.
        break

    if not fold_results:
        raise ValueError('no pose model checkpoint found under models/pose/%s'
                         % config['name'])
    df = pd.concat(fold_results).reset_index(drop=True)

    for i in tqdm(range(len(df))):
        img_id = df.loc[i, 'ImageId']
        det = np.array(df.loc[i, 'PredictionString'])

        if args.show:
            img = cv2.imread('inputs/train_images/%s.jpg' % img_id)
            img_pred = visualize(img, det)
            plt.imshow(img_pred[..., ::-1])
            plt.show()

        # Only the first 7 columns are written to the submission string.
        df.loc[i, 'PredictionString'] = convert_labels_to_str(det[:, :7])

    name += '_%s' % args.pose_name

    df.to_csv('outputs/submissions/val/%s.csv' % name, index=False)