Example #1
0
    def load_dataset(self, split, **kwargs):
        """Load image features and indexed captions for *split* and register
        the combined dataset under ``self.datasets[split]``.

        Raises:
            ValueError: if ``self.args.features`` is neither 'grid' nor 'obj'.
        """
        args = self.args
        features_dir = os.path.join(
            args.features_dir, f'{split}-features-{args.features}')

        captions_file = os.path.join(
            args.captions_dir, f'{split}-captions.{args.captions_lang}')
        captions_ds = data_utils.load_indexed_dataset(
            captions_file, self.captions_dict)

        ids_path = os.path.join(args.captions_dir, f'{split}-ids.txt')
        image_ids = data.read_image_ids(ids_path)

        if args.features == 'grid':
            image_ds = data.GridFeaturesDataset(
                features_dir, image_ids, grid_shape=(8, 8))
        elif args.features == 'obj':
            metadata = data.read_image_metadata(
                os.path.join(features_dir, 'metadata.csv'))
            image_ds = data.ObjectFeaturesDataset(
                features_dir, image_ids, metadata)
        else:
            raise ValueError(f'Invalid --features option: {args.features}')

        self.datasets[split] = data.ImageCaptionDataset(
            image_ds, captions_ds, self.captions_dict, shuffle=True)
Example #2
0
def tokenize_captions(output_dir, split, coco):
    """Collect and PTB-tokenize the ground-truth captions for *split*.

    Reads the de-duplicated image-id list from *output_dir*, looks up each
    image's annotations in *coco*, and returns the tokenized result.
    """
    ids_file = os.path.join(output_dir, f'{split}-ids.txt')
    image_ids = data.read_image_ids(ids_file, non_redundant=True)

    gts = {image_id: coco.imgToAnns[image_id] for image_id in image_ids}

    return PTBTokenizer().tokenize(gts)
Example #3
0
    def load_dataset(self, split, **kwargs):
        """Load image features and captions for *split*, honoring SCST mode,
        and register the result under ``self.datasets[split]``.

        Raises:
            ValueError: if ``self.args.features`` is neither 'grid' nor 'obj'.
        """
        args = self.args
        features_dir = os.path.join(
            args.features_dir, f'{split}-features-{args.features}')

        ids_path = os.path.join(args.captions_dir, f'{split}-ids.txt')
        image_ids = data.read_image_ids(ids_path, non_redundant=self.scst)

        # In SCST mode the validation id list is capped at a configured size.
        if self.scst and split == 'valid':
            image_ids = image_ids[:args.scst_validation_set_size]

        if self.scst:
            # SCST consumes raw tokenized JSON captions rather than an
            # indexed binary dataset.
            captions_ds = data.CaptionsDataset(
                os.path.join(args.captions_dir,
                             f'{split}-captions.tok.json'),
                image_ids)
        else:
            captions_ds = data_utils.load_indexed_dataset(
                os.path.join(args.captions_dir,
                             f'{split}-captions.{args.captions_lang}'),
                self.captions_dict)

        if args.features == 'grid':
            image_ds = data.GridFeaturesDataset(
                features_dir, image_ids, grid_shape=(8, 8))
        elif args.features == 'obj':
            metadata = data.read_image_metadata(
                os.path.join(features_dir, 'metadata.csv'))
            image_ds = data.ObjectFeaturesDataset(
                features_dir, image_ids, metadata)
        else:
            raise ValueError(f'Invalid --features option: {args.features}')

        self.datasets[split] = data.ImageCaptionDataset(
            img_ds=image_ds,
            cap_ds=captions_ds,
            cap_dict=self.captions_dict,
            scst=self.scst,
            shuffle=True)
def predict(image_id_path: str, grid_features_path: str,
            obj_features_path: str, obj_features_meta_path: str,
            model_args) -> pd.DataFrame:
    """Generate a caption for every sample id listed in ``model_args.input``.

    Args:
        image_id_path: file listing all image ids present in the feature set.
        grid_features_path: directory of grid features (used when
            ``model_args.features == 'grid'``).
        obj_features_path: directory of object features (used when
            ``model_args.features == 'obj'``).
        obj_features_meta_path: metadata CSV accompanying the object features.
        model_args: parsed task/model arguments (checkpoint path(s), beam
            settings, tokenizer/BPE config, device selection, ...).

    Returns:
        DataFrame with columns 'image_id' and 'caption', one row per sample id.

    Raises:
        ValueError: if ``model_args.features`` is neither 'grid' nor 'obj'.
    """
    print(model_args)
    use_cuda = torch.cuda.is_available() and not model_args.cpu

    task = tasks.setup_task(model_args)
    captions_dict = task.target_dictionary

    # ':'-separated checkpoint paths load as an ensemble.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        model_args.path.split(':'), task=task)

    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None
            if model_args.no_beamable_mm else model_args.beam,
            need_attn=model_args.print_alignment,
        )

        # Fix: reuse the precomputed flag instead of re-evaluating
        # `torch.cuda.is_available() and not model_args.cpu` a second time.
        if use_cuda:
            model.cuda()

    generator = task.build_generator(model_args)
    tokenizer = encoders.build_tokenizer(model_args)
    bpe = encoders.build_bpe(model_args)

    def decode(x):
        # Undo BPE first, then detokenize, mirroring the encoding order.
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    sample_ids = data.read_image_ids(model_args.input, non_redundant=True)
    image_ids = data.read_image_ids(image_id_path)

    # Every requested sample id must exist in the available feature set.
    assert_sample_id_validity(sample_ids, image_ids)

    if model_args.features == 'grid':
        image_ds = data.GridFeaturesDataset(grid_features_path, image_ids)
    elif model_args.features == 'obj':
        image_md = data.read_image_metadata(obj_features_meta_path)
        image_ds = data.ObjectFeaturesDataset(obj_features_path, image_ids,
                                              image_md)
    else:
        raise ValueError(f'Invalid --features option: {model_args.features}')

    prediction_ids = []
    prediction_results = []

    for sample_id in tqdm(sample_ids):
        features, locations = image_ds.read_data(sample_id)
        length = features.shape[0]

        if use_cuda:
            features = features.cuda()
            locations = locations.cuda()

        # Batch of one: generation is driven per sample id.
        sample = {
            'net_input': {
                'src_tokens': features.unsqueeze(0),
                'src_locations': locations.unsqueeze(0),
                'src_lengths': [length]
            }
        }

        translations = task.inference_step(generator, models, sample)
        # Take the top (best-scoring) hypothesis for the single sample.
        prediction = decode(captions_dict.string(translations[0][0]['tokens']))

        prediction_ids.append(sample_id)
        prediction_results.append(prediction)

    return pd.DataFrame.from_dict(data={
        'image_id': prediction_ids,
        'caption': prediction_results
    })