def validate(data_loader, encoder, decoder, word_map, print_freq):
    """
    Perform validation of one training epoch.

    """
    decoder.eval()
    if encoder:
        encoder.eval()

    target_captions = []
    generated_captions = []
    coco_ids = []

    # Loop over batches
    for i, (images, all_captions_for_image, _,
            coco_id) in enumerate(data_loader):
        images = images.to(device)

        # Forward propagation
        if encoder:
            images = encoder(images)
        scores, decode_lengths, alphas = decoder(images)

        if i % print_freq == 0:
            logging.info("Validation: [Batch {0}/{1}]".format(
                i, len(data_loader)))

        # Target captions
        for j in range(all_captions_for_image.shape[0]):
            img_captions = [
                get_caption_without_special_tokens(caption, word_map)
                for caption in all_captions_for_image[j].tolist()
            ]
            target_captions.append(img_captions)

        # Generated captions
        _, captions = torch.max(scores, dim=2)
        captions = [
            get_caption_without_special_tokens(caption, word_map)
            for caption in captions.tolist()
        ]
        generated_captions.extend(captions)

        coco_ids.append(coco_id[0])

        assert len(target_captions) == len(generated_captions)

    bleu4 = corpus_bleu(target_captions, generated_captions)

    logging.info("\n * BLEU-4 - {bleu}\n".format(bleu=bleu4))

    return bleu4
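# Hedged usage sketch (not part of the source, assuming `corpus_bleu` above is
# nltk's `nltk.translate.bleu_score.corpus_bleu`): each reference entry is a
# list of tokenized captions and each hypothesis is a single tokenized caption.
# The token ids below are made up purely to illustrate the expected shapes.
def _bleu4_shape_example():
    from nltk.translate.bleu_score import corpus_bleu

    references = [[[9, 4, 27, 3], [9, 4, 31, 3]]]  # two reference captions for one image
    hypotheses = [[9, 4, 27, 3]]                   # one generated caption for the same image
    return corpus_bleu(references, hypotheses)     # 1.0 for an exact match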
def calc_recall(
    generated_captions,
    test_indices,
    word_map,
    nouns,
    other,
    occurrences_data,
    contains_pair_function,
    nlp_pipeline,
):
    # Hits and totals, bucketed by the pair occurrence count N recorded for each image
    true_positives = dict.fromkeys(["N=1", "N=2", "N=3", "N=4", "N=5"], 0)
    numbers = dict.fromkeys(["N=1", "N=2", "N=3", "N=4", "N=5"], 0)
    adjective_frequencies = Counter()
    verb_frequencies = Counter()
    for coco_id in test_indices:
        top_k_captions = generated_captions[coco_id]
        count = occurrences_data[OCCURRENCE_DATA][coco_id][PAIR_OCCURENCES]

        # An image counts as a hit if any of its top-k captions contains the pair
        hit = False
        for caption in top_k_captions:
            caption = " ".join(
                decode_caption(
                    get_caption_without_special_tokens(caption, word_map), word_map
                )
            )
            pos_tagged_caption = nlp_pipeline(caption).sentences[0]
            _, _, contains_pair = contains_pair_function(
                pos_tagged_caption, nouns, other
            )
            if contains_pair:
                hit = True

            noun_is_present = False
            for word in pos_tagged_caption.words:
                if word.lemma in nouns:
                    noun_is_present = True
            if noun_is_present:
                adjectives = get_adjectives_for_noun(pos_tagged_caption, nouns)
                if len(adjectives) == 0:
                    adjective_frequencies["No adjective"] += 1
                adjective_frequencies.update(adjectives)

                verbs = get_verbs_for_noun(pos_tagged_caption, nouns)
                if len(verbs) == 0:
                    verb_frequencies["No verb"] += 1
                verb_frequencies.update(verbs)

        if hit:
            true_positives["N={}".format(count)] += 1
        numbers["N={}".format(count)] += 1

    recall_score = {
        "true_positives": true_positives,
        "numbers": numbers,
        "adjective_frequencies": adjective_frequencies,
        "verb_frequencies": verb_frequencies,
    }
    return recall_score
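# Hedged follow-up sketch (not part of the source): turning the dictionary
# returned by calc_recall into per-bucket recall values; the aggregation rule
# (hits divided by totals per N bucket) is inferred from the counters above.
def recall_per_bucket(recall_score):
    true_positives = recall_score["true_positives"]
    numbers = recall_score["numbers"]
    return {
        bucket: true_positives[bucket] / numbers[bucket]
        for bucket in numbers
        if numbers[bucket] > 0
    }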
def generate_captions(checkpoint, data_folder, image_path, beam_size, print_beam):
    # Load model
    checkpoint = torch.load(checkpoint, map_location=device)
    model_name = checkpoint["model_name"]
    decoder = checkpoint["decoder"]
    decoder = decoder.to(device)
    decoder.eval()

    if model_name != MODEL_SHOW_ATTEND_TELL:
        raise NotImplementedError(
            "Caption generation is only implemented for the Show, Attend and Tell model"
        )

    encoder = checkpoint["encoder"]
    encoder = encoder.to(device)
    encoder.eval()

    # Load word map
    word_map_path = os.path.join(data_folder, WORD_MAP_FILENAME)
    with open(word_map_path, "r") as json_file:
        word_map = json.load(json_file)

    # Read the image and scale pixel values to [0, 1]
    image_data = read_image(image_path)
    image_data = image_data / 255.0

    # Add a batch dimension and encode the image
    image_data = torch.FloatTensor(image_data)
    image_features = image_data.unsqueeze(0)
    image_features = image_features.to(device)
    image_features = encoder(image_features)
    generated_sequences, alphas, beam = decoder.beam_search(
        image_features, beam_size, store_alphas=True, print_beam=print_beam
    )

    for seq in generated_sequences:
        print(
            " ".join(
                decode_caption(
                    get_caption_without_special_tokens(seq, word_map), word_map
                )
            )
        )
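# Hedged usage sketch (all paths and the beam size below are placeholders, not
# from the source): generating captions for one image from a trained
# Show, Attend and Tell checkpoint.
if __name__ == "__main__":
    generate_captions(
        checkpoint="checkpoints/show_attend_tell.pth.tar",  # hypothetical path
        data_folder="data/preprocessed",                    # hypothetical path
        image_path="example.jpg",                           # hypothetical path
        beam_size=5,
        print_beam=False,
    )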
def calculate_metric(
    metric_name,
    target_captions,
    generated_captions,
    generated_beams,
    word_map,
    heldout_pairs,
    beam_size,
    output_file_name,
):
    if metric_name == METRIC_BLEU:
        generated_captions = [
            get_caption_without_special_tokens(top_k_captions[0], word_map)
            for top_k_captions in generated_captions.values()
        ]
        target_captions = target_captions.values()
        bleu_1 = corpus_bleu(target_captions,
                             generated_captions,
                             weights=(1, 0, 0, 0))
        bleu_2 = corpus_bleu(target_captions,
                             generated_captions,
                             weights=(0.5, 0.5, 0, 0))
        bleu_3 = corpus_bleu(target_captions,
                             generated_captions,
                             weights=(0.33, 0.33, 0.33, 0))
        bleu_4 = corpus_bleu(target_captions,
                             generated_captions,
                             weights=(0.25, 0.25, 0.25, 0.25))
        bleu_scores = [bleu_1, bleu_2, bleu_3, bleu_4]
        bleu_scores = [round(score, 2) for score in bleu_scores]
        logging.info("\nBLEU score @ beam size {} is {}".format(
            beam_size, bleu_scores))
    elif metric_name == METRIC_RECALL:
        recall_pairs(generated_captions, word_map, heldout_pairs,
                     output_file_name)
    elif metric_name == METRIC_BEAM_OCCURRENCES:
        beam_occurrences_score = beam_occurrences(generated_beams, beam_size,
                                                  word_map, heldout_pairs)
        logging.info("\nBeam occurrences score @ beam size {} is {}".format(
            beam_size, beam_occurrences_score))
def re_rank_beam(
    decoder,
    top_k_generated_captions,
    encoded_features,
    word_map,
    coco_id,
    print_captions,
):
    if print_captions:
        logging.info("COCO ID: {}".format(coco_id))
        logging.info("Before re-ranking:")
        for caption in top_k_generated_captions[:5]:
            logging.info(" ".join(
                decode_caption(
                    get_caption_without_special_tokens(caption, word_map),
                    word_map)))

    # Pad all candidate captions to the same length so they can be scored as a batch
    lengths = [len(caption) - 1 for caption in top_k_generated_captions]
    top_k_generated_captions = torch.tensor(
        [
            caption + [word_map[TOKEN_PADDING]] *
            (max(lengths) + 1 - len(caption))
            for caption in top_k_generated_captions
        ],
        device=device,
    )
    # Embed the image and the candidate captions in the joint ranking space,
    # then keep the captions ranked closest to the image embedding
    image_embedded, image_captions_embedded = decoder.forward_ranking(
        encoded_features, top_k_generated_captions,
        torch.tensor(lengths, device=device))
    image_embedded = image_embedded.detach().cpu().numpy()[0]
    image_captions_embedded = image_captions_embedded.detach().cpu().numpy()

    indices = get_top_ranked_captions_indices(image_embedded,
                                              image_captions_embedded)
    top_k_generated_captions = [top_k_generated_captions[i] for i in indices]

    return [caption.cpu().numpy() for caption in top_k_generated_captions]
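# Hedged sketch (an assumption, not the repository's implementation): one way
# get_top_ranked_captions_indices could rank the caption embeddings, namely by
# cosine similarity to the image embedding, most similar first.
import numpy as np

def cosine_ranked_indices(image_embedded, image_captions_embedded):
    # image_embedded: (embed_dim,); image_captions_embedded: (num_captions, embed_dim)
    image_norm = image_embedded / np.linalg.norm(image_embedded)
    caption_norms = image_captions_embedded / np.linalg.norm(
        image_captions_embedded, axis=1, keepdims=True)
    similarities = caption_norms @ image_norm
    return np.argsort(-similarities)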
def count_adjective_noun_pairs(preprocessed_data_folder):
    nlp_pipeline = stanfordnlp.Pipeline()

    with open(os.path.join(preprocessed_data_folder, IMAGES_META_FILENAME),
              "r") as json_file:
        images_meta = json.load(json_file)

    word_map_path = os.path.join(preprocessed_data_folder, WORD_MAP_FILENAME)
    with open(word_map_path, "r") as json_file:
        word_map = json.load(json_file)

    data = {}

    for coco_id, image_meta in tqdm(images_meta.items()):
        encoded_captions = image_meta[DATA_CAPTIONS]

        decoded_captions = [
            " ".join(
                decode_caption(
                    get_caption_without_special_tokens(caption, word_map),
                    word_map)) for caption in encoded_captions
        ]

        data[coco_id] = {}
        data[coco_id][DATA_COCO_SPLIT] = image_meta[DATA_COCO_SPLIT]
        data[coco_id]["pos_tagged_captions"] = []

        for caption in decoded_captions:
            doc = nlp_pipeline(caption)
            sentence = doc.sentences[0]
            data[coco_id]["pos_tagged_captions"].append(sentence)

    data_path = os.path.join(preprocessed_data_folder,
                             POS_TAGGED_CAPTIONS_FILENAME)
    print("\nSaving results to {}".format(data_path))
    with open(data_path, "wb") as pickle_file:
        pickle.dump(data, pickle_file)
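# Hedged follow-up sketch (not part of the source): reading the pickled
# POS-tagged captions back for downstream analysis, using the same
# POS_TAGGED_CAPTIONS_FILENAME constant as above.
def load_pos_tagged_captions(preprocessed_data_folder):
    data_path = os.path.join(preprocessed_data_folder,
                             POS_TAGGED_CAPTIONS_FILENAME)
    with open(data_path, "rb") as pickle_file:
        return pickle.load(pickle_file)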
def show_images(data_folder, pair):
    image_features = h5py.File(os.path.join(data_folder, IMAGES_FILENAME), "r")

    with open(os.path.join(data_folder, IMAGES_META_FILENAME),
              "r") as json_file:
        images_meta = json.load(json_file)

    word_map_file = os.path.join(data_folder, WORD_MAP_FILENAME)
    with open(word_map_file, "r") as json_file:
        word_map = json.load(json_file)

    _, _, test_images_split = get_splits_from_occurrences_data([pair])

    for coco_id in test_images_split:
        image_data = image_features[coco_id][()]

        print("COCO ID: ", coco_id)
        for caption in images_meta[coco_id][DATA_CAPTIONS]:
            print(" ".join(
                decode_caption(
                    get_caption_without_special_tokens(caption, word_map),
                    word_map)))
        show_img(image_data)
        print("")
def evaluate(
    data_folder,
    dataset_splits,
    checkpoint_path,
    metrics,
    beam_size,
    eval_beam_size,
    re_ranking,
    nucleus_sampling,
    visualize,
    print_beam,
    print_captions,
):
    # Load model
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model_name = checkpoint["model_name"]
    logging.info("Model: {}".format(model_name))

    encoder = checkpoint["encoder"]
    if encoder:
        encoder = encoder.to(device)
        encoder.eval()

    decoder = checkpoint["decoder"]
    decoder = decoder.to(device)
    word_map = decoder.word_map
    decoder.eval()

    logging.info("Decoder params: {}".format(decoder.params))

    # Get the dataset splits
    with open(dataset_splits, "r") as json_file:
        dataset_splits_dict = json.load(json_file)
    test_images_split = dataset_splits_dict["test_images_split"]

    if model_name == MODEL_SHOW_ATTEND_TELL:
        # Normalization
        normalize = transforms.Normalize(mean=IMAGENET_IMAGES_MEAN,
                                         std=IMAGENET_IMAGES_STD)

        # DataLoader
        data_loader = torch.utils.data.DataLoader(
            CaptionTestDataset(
                data_folder,
                IMAGES_FILENAME,
                test_images_split,
                transforms.Compose([normalize]),
                features_scale_factor=1 / 255.0,
            ),
            batch_size=1,
            shuffle=True,
            num_workers=1,
            pin_memory=True,
        )
    elif model_name in (MODEL_BOTTOM_UP_TOP_DOWN,
                        MODEL_BOTTOM_UP_TOP_DOWN_RANKING):
        data_loader = torch.utils.data.DataLoader(
            CaptionTestDataset(data_folder, BOTTOM_UP_FEATURES_FILENAME,
                               test_images_split),
            batch_size=1,
            shuffle=True,
            num_workers=1,
            pin_memory=True,
        )
    else:
        raise RuntimeError("Unknown model name: {}".format(model_name))

    # Lists for target captions and generated captions for each image
    target_captions = {}
    generated_captions = {}
    generated_beams = {}

    for image_features, all_captions_for_image, caption_lengths, coco_id in tqdm(
            data_loader, desc="Evaluate with beam size " + str(beam_size)):
        coco_id = coco_id[0]

        # Target captions
        target_captions[coco_id] = [
            get_caption_without_special_tokens(caption, word_map)
            for caption in all_captions_for_image[0].tolist()
        ]

        # Generate captions
        encoded_features = image_features.to(device)
        if encoder:
            encoded_features = encoder(encoded_features)

        store_beam = METRIC_BEAM_OCCURRENCES in metrics

        # `nucleus_sampling` doubles as the top-p value; falsy values fall back to beam search
        if nucleus_sampling:
            top_k_generated_captions, alphas, beam = decoder.nucleus_sampling(
                encoded_features,
                beam_size,
                top_p=nucleus_sampling,
                print_beam=print_beam,
            )
        else:
            top_k_generated_captions, alphas, beam = decoder.beam_search(
                encoded_features,
                beam_size,
                store_alphas=visualize,
                store_beam=store_beam,
                print_beam=print_beam,
            )

        if visualize:
            logging.info("Image COCO ID: {}".format(coco_id))
            for caption, alpha in zip(top_k_generated_captions, alphas):
                visualize_attention(image_features.squeeze(0),
                                    caption,
                                    alpha,
                                    word_map,
                                    smoothen=True)

        if re_ranking:
            top_k_generated_captions = re_rank_beam(
                decoder,
                top_k_generated_captions,
                encoded_features,
                word_map,
                coco_id,
                print_captions,
            )

        generated_captions[coco_id] = top_k_generated_captions[:eval_beam_size]
        if print_captions:
            logging.info("COCO ID: {}".format(coco_id))
            for caption in generated_captions[coco_id]:
                logging.info(" ".join(
                    decode_caption(
                        get_caption_without_special_tokens(caption, word_map),
                        word_map,
                    )))
        if store_beam:
            generated_beams[coco_id] = beam

        assert len(target_captions) == len(generated_captions)

    # Save results
    name = os.path.basename(checkpoint_path).split(".")[0]
    if re_ranking:
        name += "_re_ranking"
    if nucleus_sampling:
        name += "_nucleus_sampling_p_" + str(nucleus_sampling)
    results_output_file_name = "results_" + name + ".json"

    results = []
    for coco_id, top_k_captions in generated_captions.items():
        caption = " ".join(
            decode_caption(
                get_caption_without_special_tokens(top_k_captions[0],
                                                   word_map),
                word_map,
            ))
        results.append({"image_id": int(coco_id), "caption": caption})
    with open(results_output_file_name, "w") as json_file:
        json.dump(results, json_file)

    # Calculate metric scores
    eval_output_file_name = "eval_" + name + ".json"
    for metric in metrics:
        calculate_metric(
            metric,
            target_captions,
            generated_captions,
            generated_beams,
            word_map,
            dataset_splits_dict["heldout_pairs"],
            beam_size,
            eval_output_file_name,
        )
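# Hedged usage sketch (all argument values below are placeholders, not from the
# source): evaluating a checkpoint with BLEU and recall. Beam search is used
# unless `nucleus_sampling` is set to a non-zero top-p value.
if __name__ == "__main__":
    evaluate(
        data_folder="data/preprocessed",             # hypothetical path
        dataset_splits="data/dataset_splits.json",   # hypothetical path
        checkpoint_path="checkpoints/best.pth.tar",  # hypothetical path
        metrics=[METRIC_BLEU, METRIC_RECALL],
        beam_size=5,
        eval_beam_size=5,
        re_ranking=False,
        nucleus_sampling=0,
        visualize=False,
        print_beam=False,
        print_captions=False,
    )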