def tell(self, episode):
        """Append the next queued utterance to the episode.

        On first use, lazily runs self.prepare() to populate episode.to_tell;
        thereafter each call consumes one queued message (if any remain).
        """
        if not hasattr(episode, 'to_tell'):
            self.prepare(episode)

        queue = episode.to_tell
        if not queue:
            return
        episode.append(codraw_data.TellGroup(queue.pop(0)))
# --- Beispiel #2 (Example #2) --- scrape-site separator; original raw marker was "Beispiel #2" / "0"
    def tell(self, episode):
        """Emit a message describing the most recently selected clipart.

        Takes the top-``self.num_candidates`` cliparts most similar to the
        selected one (among cliparts with a human reference message), rolls
        each candidate message through the paired drawer model, and appends
        the message whose simulated drawing best matches the target clipart.
        """
        clipart = episode.get_last(codraw_data.SelectClipart).clipart

        # Nearest neighbours by clipart_similarity among cliparts that have
        # an associated reference message.
        candidate_cliparts = heapq.nlargest(
            self.num_candidates,
            self.datagen.clipart_to_msg,
            key=lambda cand_clipart: clipart_similarity(cand_clipart, clipart))

        candidate_msgs = [
            self.datagen.clipart_to_msg[cand_clipart]
            for cand_clipart in candidate_cliparts
        ]

        # Cliparts selected in earlier rounds (excluding the current one)
        # form the canvas context the drawer model sees.
        expected_context = [
            event.clipart for event in episode
            if isinstance(event, codraw_data.SelectClipart)
        ][:-1]

        candidate_responses = [
            self.drawer_model.just_draw(msg, expected_context)
            for msg in candidate_msgs
        ]

        # Score each simulated drawing against the single target clipart and
        # keep the best-scoring message.
        best_idx = np.argmax([
            scene_similarity(response_scene, [clipart])
            for response_scene in candidate_responses
        ])

        episode.append(codraw_data.TellGroup(candidate_msgs[best_idx]))
    def on_receive_message(self, message=None, noOfMsg=None):
        """Record an incoming human message in the episode and run the model.

        A teller agent records the human's reply as a ReplyGroup; any other
        agent records it as a TellGroup. Messages arriving after disconnect
        are logged and dropped.
        """
        if self.disconnected:
            print("[ERROR] Disconnected bot received a message!")
            return
        print(f"Got human message {noOfMsg}: {message}")
        assert message is not None

        # Choose the event type by our role, then record and react.
        if self.agent_type == codraw_data.Agent.TELLER:
            group_cls = codraw_data.ReplyGroup
        else:
            group_cls = codraw_data.TellGroup
        self.episode.append(group_cls(message))
        self.run_model_actions()
    def just_draw(self, msg, scene=None, *args, **kwargs):
        """Run this drawer model on a single message and return what it drew.

        Builds a throwaway Episode containing only the message and the given
        canvas context, invokes self.draw on it, and returns either an
        AbstractScene (if the model emitted a DrawGroup) or a single clipart
        (if it emitted a DrawClipart).

        Fix: the original used a mutable default argument (``scene=[]``),
        which is shared across calls; a ``None`` sentinel is the safe,
        backward-compatible equivalent.
        """
        assert hasattr(self, 'draw'), "Model is not a drawer"
        if scene is None:
            scene = []
        episode = Episode([codraw_data.TellGroup(msg), codraw_data.ObserveCanvas(scene)])
        if isinstance(self, nn.Module):
            self.eval()  # inference only: disable dropout/batch-norm updates
        self.draw(episode, *args, **kwargs)

        # Multi-clipart output takes precedence over single-clipart output.
        event_multi = episode.get_last(codraw_data.DrawGroup)
        if event_multi is not None:
            return codraw_data.AbstractScene(event_multi.cliparts)

        event_single = episode.get_last(codraw_data.DrawClipart)
        return event_single.clipart
    def prepare(self, episode):
        """Decode the whole conversation up front and cache it on the episode.

        Runs the attention/LSTM decoder for up to ``self.max_rounds`` rounds,
        generating one utterance per round token-by-token. Nothing is emitted
        directly; the per-round event lists are stored in ``episode.to_tell``
        so that ``tell()`` can replay them one round at a time.
        """
        true_scene = episode.get_last(codraw_data.ObserveTruth).scene

        # Encode the episode once; the resulting attention keys/values (both
        # pre-LSTM and post-LSTM) are reused for every decoded token.
        example_batch = self.datagen.tensors_from_episode(episode)
        b_scene_mask, ks_prelstm, vs_prelstm, ks, vs = self.forward(
            example_batch, return_loss=False)

        to_tell = []

        lstm_state = None  # carried across conversation rounds!

        for round in range(self.max_rounds):
            # Each round starts from the start-of-sentence token.
            tokens = [self.datagen.vocabulary_dict['<S>']]
            events_this_round = []
            # Longest utterance in all of CoDraw is 39 words
            # Humans have a 140-char limit, but this is not easy to enforce with
            # word-level tokenization
            for wordnum in range(50):
                # Embed the most recent token; [None, None, :] adds batch and
                # sequence dims of size 1 for the LSTM input.
                token_emb = self.word_embs(
                    torch.tensor(
                        tokens[-1],
                        dtype=torch.long).to(cuda_if_available))[None, None, :]
                # Attend over the scene before the LSTM, feed the concatenation
                # through the LSTM, then attend again on the LSTM output.
                attended_values_prelstm = self.attn_prelstm(
                    token_emb,
                    ks=ks_prelstm,
                    vs=vs_prelstm,
                    k_mask=b_scene_mask)
                lstm_in = torch.cat([token_emb, attended_values_prelstm], -1)
                lstm_out, lstm_state = self.lstm(lstm_in, lstm_state)
                attended_values = self.attn(lstm_out,
                                            ks=ks,
                                            vs=vs,
                                            k_mask=b_scene_mask)
                pre_project = torch.cat([lstm_out, attended_values], -1)

                # At the start of each utterance (last token is <S>), optionally
                # predict which cliparts are still undrawn and record that as a
                # TellerIntention event for this round.
                if tokens[-1] == self.datagen.vocabulary_dict[
                        '<S>'] and self.prediction_loss_scale > 0:
                    assert not events_this_round
                    if self.predict_for_full_library:
                        # Predictor also sees the scene mask when predicting
                        # states for the full clipart library.
                        clipart_state_predictor_in = torch.cat([
                            lstm_out,
                            b_scene_mask.to(torch.float)[None, :, :],
                        ], -1)
                    else:
                        clipart_state_predictor_in = lstm_out
                    clipart_state_logits = self.clipart_state_predictor(
                        clipart_state_predictor_in).view(
                            self.datagen.NUM_INDEX,
                            self.datagen.NUM_CLIPART_STATES)
                    clipart_state_selected = clipart_state_logits.argmax(
                        dim=-1)
                    # Cliparts from the ground-truth scene whose predicted
                    # state is "undrawn".
                    undrawn = AbstractScene([
                        c for c in true_scene if clipart_state_selected[c.idx]
                        == self.datagen.CLIPART_STATE_UNDRAWN
                    ])
                    intention = codraw_data.TellerIntention(drawn=None,
                                                            undrawn=undrawn,
                                                            draw_next=None)
                    events_this_round.append(intention)

                # Project to vocabulary logits. <S> is never a valid output,
                # and </TELL> is forbidden as the very first generated word of
                # the conversation.
                word_logits = self.word_project(pre_project[0, 0, :])
                word_logits[
                    self.datagen.vocabulary_dict['<S>']] = -float('inf')
                if round == 0 and wordnum == 0:
                    word_logits[self.datagen.
                                vocabulary_dict['</TELL>']] = -float('inf')

                # Greedy argmax or temperature-scaled sampling.
                if self.inference_method == 'greedy':
                    next_token = int(word_logits.argmax())
                elif self.inference_method == 'sample':
                    next_token = int(
                        torch.multinomial(
                            F.softmax(word_logits / self.sampling_temperature,
                                      dim=-1)[None, :], 1).item())
                else:
                    raise ValueError(
                        f"Invalid inference_method: {self.inference_method}")

                assert next_token != self.datagen.vocabulary_dict['<S>']
                tokens.append(next_token)
                # </S> ends the utterance; </TELL> ends the whole conversation.
                if next_token == self.datagen.vocabulary_dict['</S>']:
                    break
                elif next_token == self.datagen.vocabulary_dict['</TELL>']:
                    break

            if tokens[-1] == self.datagen.vocabulary_dict['</TELL>']:
                break

            # Detokenize, dropping the leading <S> and trailing end token.
            msg = " ".join([self.datagen.vocabulary[i] for i in tokens[1:-1]])
            events_this_round.append(codraw_data.TellGroup(msg))
            to_tell.append(events_this_round)

        episode.to_tell = to_tell