Esempio n. 1
0
    def generate_utterance(self, speaker_type, as_string = True) :
        """
        Generate an utterance for the current target image.

        'S0' returns the top beam-search sample from S0_sample. The other
        speaker types decode the top-k beam candidates, score each with a
        utility, and return the highest-utility candidate.

        Args:
          speaker_type: 'S0', 'S0_with_cost', or 'S1'
            - 'S0_with_cost': utility = beam score - length cost
            - 'S1': utility = L0 listener score of the target image
                    within self.context - length cost
            The length cost is (caption length in tokens) * self.cost_weight.
          as_string: boolean indicating whether to convert to readable string;
            if False the chosen caption is returned as a list of word ids.

        Returns:
          The selected caption (string or list of word ids).

        Raises:
          ValueError: for an unknown speaker_type (a subclass of Exception,
            so existing `except Exception` handlers still catch it).
        """
        if speaker_type == 'S0' :
            return self.S0_sample(self.image_tensor, as_string, beam_sample=True)

        # Fail fast, before running the encoder/decoder.
        if speaker_type not in ('S0_with_cost', 'S1') :
            raise ValueError('unknown speaker_type', speaker_type)

        feature = self.generating_encoder(self.image_tensor)
        topk_batch, topk_scores = self.decoder.beam_sample(feature, self.topk)

        # Truncate each candidate just after its '<end>' token. Every
        # candidate must add exactly one entry to both lists so that row i of
        # the utility tensor below still refers to candidate i. A candidate
        # that never emits '<end>' is kept at full length instead of being
        # dropped; dropping it would misalign costs, scores and captions.
        topk_batch_strings = []
        lengths = []
        for caption in topk_batch.cpu().numpy() :
            out = []
            for word_id in caption :
                out.append(word_id)
                if self.vocab.idx2word[word_id] == '<end>' :
                    break
            lengths.append(len(out))
            topk_batch_strings.append(utils.ids_to_words(out, self.vocab)
                                      if as_string else out)

        # One length-penalty row per candidate: shape (k, 1).
        uttCost = torch.tensor(lengths).to(device).float().unsqueeze(1) * self.cost_weight
        if speaker_type == 'S0_with_cost' :
            utility = topk_scores.unsqueeze(1) - uttCost
        else :  # 'S1'
            target_idx = torch.LongTensor([self.context.index(self.raw_image)]).to(device)
            listener_scores = self.L0_score(topk_batch, self.context)
            # Column of the listener scores that belongs to the target image.
            informativity = torch.index_select(listener_scores, 1, target_idx)
            utility = informativity - uttCost

        # Stable ascending sort then reverse: highest utility first; among
        # equal utilities the later candidate wins (unchanged behavior).
        reranked = sorted(zip(list(utility.data.cpu().numpy()),
                              topk_batch_strings),
                          key = operator.itemgetter(0))[::-1]
        return reranked[0][1]
Esempio n. 2
0
def main(args):
    """
    Replay recorded speaker data for every context in the lesion grid.

    For each context from construct_context_grid(args), the speaker is reset
    and configured (loss, dataset/context type, cost weight). The context's
    'speaker_data' trials are then processed in their given order. For each
    trial a caption is generated for the target, the speaker is updated when
    the recorded trial was correct, and one row is written to
    ../data/model_output/speaker_lesions.csv.

    Args:
      args: parsed command-line arguments passed to EfficiencyWriter,
        AdaptiveAgent and construct_context_grid.
    """
    path = '../data/model_output/speaker_lesions.csv'
    writer = EfficiencyWriter(args, path)
    speaker = AdaptiveAgent(args)
    grid = construct_context_grid(args)

    for ctx in grid:
        print("\n------gameid: {}, sample_num: {}, loss: {}, ds_type: {}, speaker_model: {}, cost_weight: {}"
            .format(ctx['gameid'], ctx['sample_num'], ctx['loss'], ctx['ds_type'],
                    ctx['speaker_model'], ctx['cost_weight']))

        speaker.loss = ctx['loss']
        speaker.reset_to_initialization(ctx['dirs'])

        speaker.dataset_type = ctx['ds_type']
        speaker.context_type = ctx['context_type']
        speaker.cost_weight = ctx['cost_weight']
        # Loop over the recorded trials in the order given by ctx['speaker_data'].
        for datum in ctx['speaker_data'] :
            rep_num = datum['repNum']
            trial_num = datum['trialNum']
            target = utils.get_img_path(datum['targetImg'])
            print(target)
            speaker.trial_num = trial_num
            speaker.sample_num = ctx['sample_num']
            speaker.set_image(target)

            cap = speaker.generate_utterance(ctx['speaker_model'], as_string = True)
            # Round-trip through word ids to normalise the caption string.
            cap = utils.ids_to_words(utils.words_to_ids(cap, speaker.vocab), speaker.vocab)

            # Remove the '<start>' marker and, if present, a trailing
            # '<end>'/'<pad>' marker. The previous fixed slice cap[8:-6]
            # cut real words when the caption had no end marker (e.g. it
            # reached the maximum length).
            if cap.startswith('<start>') :
                cap = cap[len('<start>'):]
                for end_token in ('<end>', '<pad>') :
                    if cap.endswith(end_token) :
                        cap = cap[:-len(end_token)]
                        break
                cap = cap.strip()

            print('\nround {}, target {}, msg {}'.format(
                rep_num, utils.get_id_from_path(target), cap
            ))

            # Explicit `== True` is kept on purpose: a truthiness test would
            # also accept non-empty strings such as 'False'.
            if datum['correct'] == True :
                print('training')
                speaker.update_model(trial_num, cap)

            writer.writerow(ctx, datum, trial_num, target, cap, len(cap))
    def speak(self):
        """
        Send a caption for the specified image.

        From the second round on (json_args['roundNum'] > 0), first handles
        the previous round. If it was correct, the model is updated via
        self.update(gameid). Otherwise the last entry of that game's history
        is popped so it is not used during rehearsal. Then an S0 caption is
        generated with a freshly loaded agent and recorded via extend_history.
        The caption is written as the response.
        """
        if self.json_args['roundNum'] > 0:
            if self.json_args['prevCorrect']:
                # if correct, update based on results from previous round
                self.update(self.json_args['gameid'])
            else:
                # otherwise, remove from history (so it doesn't get hit by rehearsal)
                game_histories[self.json_args['gameid']].pop()

        # pick utterance
        agent = load_agent(self.json_args)
        cap = agent.generate_utterance('S0', as_string=True)
        # Round-trip through word ids to normalise the caption string.
        cap = utils.ids_to_words(utils.words_to_ids(cap, agent.vocab),
                                 agent.vocab)

        # Remove the '<start>' marker and, if present, a trailing
        # '<end>'/'<pad>' marker. The previous fixed slice cap[8:-6] cut real
        # words when the caption had no end marker.
        if cap.startswith('<start>'):
            cap = cap[len('<start>'):]
            for end_token in ('<end>', '<pad>'):
                if cap.endswith(end_token):
                    cap = cap[:-len(end_token)]
                    break
            cap = cap.strip()

        # update history for next round and return response
        self.extend_history(agent, self.json_args['target'], cap)
        self.write(cap)
Esempio n. 4
0
    def S0_sample(self, image, as_string = True, use_old_decoder=False,
                  beam_sample = True):
        """
        Generate a literal-speaker (S0) utterance for an image in isolation.

        Args:
          image: a path (str), loaded with utils.load_image, or an
            already-loaded image tensor.
          as_string: boolean indicating whether to convert to readable string;
            if False a list of word ids is returned.
          use_old_decoder: decode with self.orig_decoder instead of
            self.decoder.
          beam_sample: if True (default), take the top beam-search candidate;
            otherwise decode greedily.

        Returns:
          The sampled caption, truncated just after the first '<end>' or
          '<pad>' token, as a string or a list of word ids.
        """
        image_tensor = utils.load_image(image) if isinstance(image, str) else image
        feature = self.generating_encoder(image_tensor.to(device))

        decoder = self.orig_decoder if use_old_decoder else self.decoder
        if beam_sample:
            sampled_ids, _ = decoder.beam_sample(feature, 1)
        else:
            sampled_ids, _ = decoder.greedy_sample(feature)
        # Keep the first sequence: (1, max_seq_length) -> (max_seq_length)
        sampled_ids = sampled_ids[0].cpu().numpy()

        # Truncate just after the first '<end>' / '<pad>' token.
        out = []
        for word_id in sampled_ids :
            out.append(word_id)
            if self.vocab.idx2word[word_id] in ('<end>', '<pad>') :
                break
        return utils.ids_to_words(out, self.vocab) if as_string else out
def main_memory(args):
    """
    Play repeated reference games between an adaptive speaker and listener.

    For every context from construct_expt_grid(args), both agents are reset.
    Then, for rounds 1 .. args.num_reductions - 1, every target in
    ctx['dirs'] is visited once in a freshly shuffled order. Each visit
    produces (or replays) a caption, scores it with the listener, optionally
    adapts the listener, and writes a row to
    ../data/model_output/prag_speaker.csv.

    Captions are memoised by (speaker loss, speaker model, sample number,
    target, round). The speaker generates and is updated on a caption only
    the first time its key is seen; later hits replay the stored word ids.
    """
    writer = EfficiencyWriter(args, '../data/model_output/prag_speaker.csv')

    # Independent speaker and listener models (speaker is built first).
    speaker = AdaptiveAgent(args)
    listener = AdaptiveAgent(args)
    grid = construct_expt_grid(args)
    cached_caps = {}

    for ctx in grid:
        dirs = ctx['dirs']
        print(
            "\ntype: {}, speaker loss: {}, listener loss: {}, speaker model: {}"
            .format(ctx['context_type'], ctx['speaker_loss'],
                    ctx['listener_loss'], ctx['speaker_model']))

        # Configure and reset speaker first, then listener.
        for agent, loss_key in ((speaker, 'speaker_loss'),
                                (listener, 'listener_loss')):
            agent.loss = ctx[loss_key]
            agent.reset_to_initialization(dirs)

        for round_num in range(1, args.num_reductions):
            # Each round visits every target once, in a new random order.
            for target in random.sample(dirs, len(dirs)):
                print('round {}, target {}'.format(round_num, target))

                cap_key = "{}-{}-{}-{}-{}".format(
                    ctx['speaker_loss'], ctx['speaker_model'],
                    ctx['sample_num'], target, round_num)
                target_idx = dirs.index(target)
                for agent in (speaker, listener):
                    agent.set_image(target)

                if cap_key not in cached_caps:
                    # First occurrence: generate, cache, and maybe adapt speaker.
                    cap = np.array(
                        speaker.generate_utterance(ctx['speaker_model'],
                                                   as_string=False))
                    print('regular: ', speaker.generate_utterance('S0'))
                    print('prag: ', speaker.generate_utterance('S1'))
                    str_cap = utils.ids_to_words(cap, speaker.vocab)
                    cached_caps[cap_key] = cap
                    if ctx['speaker_loss'] != 'fixed':
                        speaker.update_model(round_num, str_cap)
                else:
                    cap = cached_caps[cap_key]
                    str_cap = utils.ids_to_words(cap, speaker.vocab)
                    print('found {} in utt_store!'.format(str_cap))

                # Score the caption, then adapt the listener as configured.
                scores = listener.L0_score(np.expand_dims(cap, axis=0), dirs)
                listener_loss = ctx['listener_loss']
                if listener_loss == 'tied_to_speaker':
                    listener.decoder.load_state_dict(
                        speaker.decoder.state_dict())
                elif listener_loss != 'fixed':
                    listener.update_model(round_num, str_cap)

                writer.writerow(ctx, round_num, target, str_cap, scores,
                                len(cap),
                                scores[0][target_idx].data.cpu().numpy())
Esempio n. 6
0
def main():
    """
    Caption one image with beam search and visualise its attention weights.

    Builds the graph, restores the latest checkpoint from 'model', and runs
    a step-by-step beam search on the image at `image_path`. It then prints
    the predicted sentence and calls show_result with the per-word attention
    maps. Relies on module-level settings defined elsewhere (image_size,
    maxlen, beam_width, batch_size, word2id, id2word, ...).
    """
    X = tf.placeholder(tf.float32, [None, image_size, image_size, 3])
    encoded = vgg_endpoints(X - MEAN_VALUES)['conv5_3']
    num_block = encoded.shape[1] * encoded.shape[2]
    num_filter = encoded.shape[3]

    res_op = beam_search_decode(encoded, word2id['<pad>'], len(word2id),
                                maxlen, num_block, num_filter, hidden_size,
                                embedding_size, is_training)

    # Graph nodes used by the step-wise beam search; see beam_search_decode()
    # for an explanation of each.
    initial_state = res_op[0]
    initial_memory = res_op[1]
    contexts_placeh = res_op[2]
    last_memory = res_op[3]
    last_state = res_op[4]
    last_word = res_op[5]
    contexts = res_op[6]
    current_memory = res_op[7]
    current_state = res_op[8]
    probs = res_op[9]
    alpha = res_op[10]

    # restore
    MODEL_DIR = 'model'
    # The context manager closes the session once decoding is finished.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(MODEL_DIR))

        X_data, img = pre_picture(image_path)

        # Beam search over a single image (see eval.py for the annotated version).
        contexts_, initial_memory_, initial_state_ = sess.run(
            [contexts, initial_memory, initial_state], feed_dict={X: X_data})

        # A hypothesis holds its word ids, decoder memory/state, cumulative
        # probability (product of per-step word probabilities) and attention maps.
        result = [{
            'sentence': [],
            'memory': initial_memory_[0],
            'state': initial_state_[0],
            'score': 1.0,
            'alphas': []
        }]
        complete = []  # hypotheses that emitted '<end>'
        for t in range(maxlen + 1):
            cache = []
            # Expand every live hypothesis. There is a single one at t == 0 and
            # at most beam_width afterwards; iterating over `result` avoids an
            # IndexError if fewer than beam_width survived.
            for hyp in result:
                if t == 0:
                    last_word_ = np.ones([batch_size],
                                         np.int32) * word2id['<start>']
                else:
                    last_word_ = np.array([hyp['sentence'][-1]], np.int32)

                last_memory_ = np.array([hyp['memory']], np.float32)
                last_state_ = np.array([hyp['state']], np.float32)

                current_memory_, current_state_, probs_, alpha_ = sess.run(
                    [current_memory, current_state, probs, alpha],
                    feed_dict={
                        contexts_placeh: contexts_,
                        last_memory: last_memory_,
                        last_state: last_state_,
                        last_word: last_word_
                    })

                next_probs = probs_[0]
                # Top beam_width + 1 next words by probability. The stable sort
                # keeps ties in ascending word-id order.
                top_words = np.argsort(-next_probs, kind='stable')[:beam_width + 1]

                for w in top_words:
                    w = int(w)
                    item = {
                        'sentence': hyp['sentence'] + [w],
                        'memory': current_memory_[0],
                        'state': current_state_[0],
                        'score': hyp['score'] * next_probs[w],
                        'alphas': hyp['alphas'] + [alpha_[0]]
                    }
                    if id2word[w] == '<end>':
                        complete.append(item)
                    else:
                        cache.append(item)

            cache.sort(key=lambda x: -x['score'])
            result = cache[:beam_width]
            if not result:
                # every hypothesis finished; nothing left to expand
                break

    # Output the predicted sentence and its attention weights. Choose the
    # highest-scoring finished hypothesis (the first one appended is not
    # necessarily the best); fall back to the best unfinished beam.
    if complete:
        best = max(complete, key=lambda x: x['score'])
    else:
        best = result[0]
    final_sentence = best['sentence']
    alphas = best['alphas']

    sentence = ids_to_words(final_sentence, id2word)
    print('预测结果为:', sentence)
    sentence = sentence.split(' ')

    print('attention weight可视化')
    show_result(img, sentence, alphas)
Esempio n. 7
0
 def iterate_round(self, round_num, prev_caption, as_string = True, speaker_type = 'S0') :
     """
     Update the model on the previous round's caption, then produce a new one.

     Args:
       round_num: index of the round being completed (passed to update_model).
       prev_caption: previous caption; a string when as_string is True,
         otherwise a sequence of word ids (converted with utils.ids_to_words).
       as_string: format of prev_caption and of the returned utterance.
       speaker_type: speaker used for the new utterance ('S0',
         'S0_with_cost' or 'S1'); defaults to 'S0'.

     Returns:
       The new utterance from generate_utterance.
     """
     caption = (prev_caption if as_string else
                utils.ids_to_words(prev_caption, self.vocab))
     self.update_model(round_num, caption)
     # generate_utterance takes speaker_type first. Passing as_string
     # positionally bound the boolean to speaker_type, which always ended
     # in "unknown speaker_type".
     return self.generate_utterance(speaker_type, as_string=as_string)