def _save_nested_list(path, nested):
    """Save a (possibly ragged) nested list to ``path`` in ``.npy`` format.

    Rectangular data is stored exactly as ``np.save(path, nested)`` would
    store it. Ragged data (sub-lists of differing lengths) is stored as an
    object array, because NumPy >= 1.24 raises ``ValueError`` instead of
    implicitly creating one.
    """
    try:
        arr = np.asarray(nested)
    except ValueError:
        # Inhomogeneous shape: fall back to an explicit object array.
        arr = np.array(nested, dtype=object)
    np.save(path, arr)


def calc_expression(start=0, end=5, k=3, target_set='test'):
    """Generate RSA expressions for images ``start``..``end - 1`` and save them.

    For each image index ``i`` the attribute table ``attr_{i}.tsv`` and the
    label file ``lab_{i}.json`` are read from below ``data_path``, an ``RSA``
    agent is built from the table and the pre-extracted relations, and
    ``full_speaker`` is called for up to ``k`` matched targets of that image.

    Args:
        start: First image index to process (inclusive).
        end: Image index at which to stop (exclusive).
        k: Maximum number of matched targets described per image.
        target_set: Split name interpolated into every input and output path.

    Side effects:
        Writes three ``.npy`` files into
        ``./data/{target_set}/detectron2_with_target/`` (created if missing):
        the generated expressions, the per-target low-confidence flags, and
        the reference sentences of the processed images. Prints a progress
        line for every index divisible by 50.
    """
    matched_label = np.load(f'{target_set}_imgs_label_matching_with_true_target.npy', allow_pickle=True)
    rel_load = np.load(f'./{target_set}_relation_extraction.npy', allow_pickle=True)

    exps = []
    references = []
    # Records, per image and per target, whether RSA reported a low-confidence
    # situation (nothing to say, or no preference: zero probability everywhere).
    low_confidence_track = []

    for i in range(start, end):
        df = pd.read_csv(os.path.join(data_path, f'refCOCO/{target_set}/attr_tables_with_target_box/attr_{i}.tsv'),
                         encoding='utf-8', sep='\t')

        # Collect the reference sentences for the same range as the processed
        # images; each sentence is wrapped in its own single-element list.
        with open(os.path.join(data_path, f'refCOCO/{target_set}/labels/lab_{i}.json')) as json_file:
            label = json.load(json_file)
        references.append([[r] for r in label['ref_sents']])

        # NOTE: an optional LSTM utterance prior (model / word_to_idx /
        # idx_to_word arguments to RSA) was disabled by the original author.
        rsa_agent = RSA(df, generated_relations=rel_load[i])

        # Second field of each of the first k matched entries is the target.
        matches = matched_label[i]
        targets = [matches[j][1] for j in range(min(k, len(matches)))]

        expression = []
        is_low_confidence = []
        for target in targets:
            list_utterances, low_confidence = rsa_agent.full_speaker(target)
            # The utterance list is joined in reverse order.
            expression.append(' '.join(list_utterances[::-1]))
            is_low_confidence.append(low_confidence)

        exps.append(expression)
        low_confidence_track.append(is_low_confidence)
        if i % 50 == 0:
            print(f'finished file {i}')

    out_dir = f'./data/{target_set}/detectron2_with_target'
    # np.save does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    _save_nested_list(f'{out_dir}/top{k}_exps_from_{start}_to_{end}.npy', exps)
    _save_nested_list(f'{out_dir}/top{k}_exps_confidence_record_from_{start}_to_{end}.npy', low_confidence_track)
    _save_nested_list(f'{out_dir}/references_from_{start}_to_{end}.npy', references)
# Example #2
0
# Notebook-style cell: draw one detected box and the labelled box for the
# current image, then print what the RSA speaker says about matched targets.
# Relies on df, image, label, generated_relations and draw_box_obj already
# being defined in the session.
box_columns = ['box_alias', 'x1','y1','w','h']
box_data = df[box_columns]
fig, ax = plt.subplots(1)
img = image

# ax.imshow(img)
rng = list(range(len(box_data)))
for i in (4,):  # rng[:]:
    # Only row 4 is drawn; switch to rng[:] to draw every row.
    row_values = list(box_data.iloc[i, :])
    name, x, y, w, h = row_values
    ax = draw_box_obj(name, x, y, w, h, img, ax)

# Show the reference sentences, then draw the first labelled box captioned
# with the first reference sentence on a fresh figure.
print(label['ref_sents'])
bbox = label['bbox'][0]
sentence = label['ref_sents'][0]
fig, ax_true_label = plt.subplots(1)
ax_true_label.imshow(img)
draw_box_obj(sentence, bbox[0], bbox[1], bbox[2], bbox[3], img, ax_true_label)

rsa_agent = RSA(df, generated_relations=generated_relations)
rsa_agent.objects_by_type  # bare attribute access, as in the original cell
# output = rsa_agent.full_speaker('woman-2')

all_matched_boxes = np.load('train_imgs_label_matching.npy', allow_pickle=True)
matched_boxes = all_matched_boxes[21540]

# print(output)
print("######")
# Each matched entry is a triple; only its middle element is used here.
for _, target, _ in matched_boxes:
    spoken = rsa_agent.full_speaker(target)
    print(target, spoken)
print("$$$$$$")
print(rsa_agent.full_speaker('woman-2'))
import pandas as pd
import json
import numpy as np
from helper import *
import argparse
from rsa import RSA
import matplotlib.pyplot as plt
import os

# Read the dataset root directory from the local JSON configuration file.
with open('config.json') as config_file:
    config = json.load(config_file)
data_path = config['data_path']

# Index of the training sample to inspect (other ids left by the author: 3278, 182).
file_id = 21540

# Attribute table of the selected sample (tab-separated).
attr_table_path = os.path.join(data_path,f'refCOCO/train/attr_tables/attr_{file_id}.tsv')
df = pd.read_csv(attr_table_path, sep='\t', encoding='utf-8')

# Label file of the same sample; every reference sentence is wrapped in its
# own single-element list before printing.
label_path = os.path.join(data_path,f'refCOCO/train/labels/lab_{file_id}.json')
with open(label_path) as json_file:
    label = json.load(json_file)
refs = [[ref_sentence] for ref_sentence in label['ref_sents']]
print(refs)

# The image file is located via the image_id stored in the first table row.
img_id = df['image_id'][0]
filename = os.path.join(data_path, f'refCOCO/train/imgs_by_id/{img_id}.jpg')
print(filename)
image = plt.imread(filename)

# Build the RSA agent from the table alone and print its output for one target.
rsa_agent = RSA(df)
speech = rsa_agent.full_speaker('woman-1')
print(speech)