Example 1
    def observe(self, observation):
        """Update state, given a previous reply by other agent. In case of
        questioner, it can be goal description at start of episode.
        """
        self.observation = observation

        # if episode not done, tokenize, embed and update state
        # at the end of dialog episode, perform backward pass and step
        if not observation.get('episode_done', False):
            text_tokens = self.tokenize(observation['text'])
            token_embeds = self.listen_net(text_tokens)
            if 'image' in observation:
                token_embeds = torch.cat((token_embeds, observation['image']), 1)
                token_embeds = token_embeds.squeeze(1)
            self.h_state, self.c_state = self.state_net(token_embeds,
                                                        (self.h_state, self.c_state))
        else:
            if observation.get('reward', None):
                for action in self.actions:
                    action.reinforce(observation['reward'])
                autograd_backward(self.actions, [None for _ in self.actions],
                                  retain_graph=True)
                # clamp all gradients between (-5, 5)
                for module in self.modules:
                    for parameter in module.parameters():
                        parameter.grad.data.clamp_(min=-5, max=5)
                optimizer.step()
            else:
                # start of dialog episode
                optimizer.zero_grad()
                self.reset()
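
The example above only shows the agent side. Below is a minimal sketch of a driver loop that would exercise such an observe() method; every name in it (Questioner, Answerer, num_rounds, goal_description, caption, image, episode_reward) is a hypothetical stand-in, not part of the quoted source. Each round, one agent's act() output is fed to the other agent's observe(); the terminal observation carries 'episode_done' and the episode reward, which is what triggers the backward pass and optimizer step seen in the code, while an 'episode_done' observation without a reward takes the reset branch.

# Hypothetical driver loop (sketch only; agent classes and variables are assumed).
q_bot = Questioner()   # hypothetical questioner agent exposing observe()/act()
a_bot = Answerer()     # hypothetical answerer agent exposing observe()/act()

# 'episode_done' with no reward hits the reset branch: zero_grad() + reset()
q_bot.observe({'episode_done': True})
a_bot.observe({'episode_done': True})

q_bot.observe({'text': goal_description})           # goal description at episode start
a_bot.observe({'text': caption, 'image': image})    # answerer also receives the image

for _ in range(num_rounds):
    a_bot.observe(q_bot.act())   # answerer hears the question
    q_bot.observe(a_bot.act())   # questioner hears the answer

# End of episode: the reward-bearing terminal observation triggers the REINFORCE update.
terminal = {'episode_done': True, 'reward': episode_reward}
q_bot.observe(terminal)
a_bot.observe(terminal)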
Example 2
    def observe(self, observation):
        """Update state, given a previous reply by other agent. In case of
        questioner, it can be goal description at start of episode.
        """
        self.observation = observation

        # if episode not done, tokenize, embed and update state
        # at the end of dialog episode, perform backward pass and step
        if not observation.get('episode_done', False):
            text_tokens = self.tokenize(observation['text'])
            token_embeds = self.listen_net(text_tokens)
            if 'image' in observation:
                token_embeds = torch.cat((token_embeds, observation['image']),
                                         1)
                token_embeds = token_embeds.squeeze(1)
            self.h_state, self.c_state = self.state_net(
                token_embeds, (self.h_state, self.c_state))
        else:
            if observation.get('reward', None):
                for action in self.actions:
                    action.reinforce(observation['reward'])
                autograd_backward(self.actions, [None for _ in self.actions],
                                  retain_graph=True)
                # clamp all gradients between (-5, 5)
                for module in self.modules:
                    for parameter in module.parameters():
                        parameter.grad.data.clamp_(min=-5, max=5)
                optimizer.step()
            else:
                # start of dialog episode
                optimizer.zero_grad()
                self.reset()
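
Both examples rely on the old PyTorch stochastic-function API: action.reinforce(reward) attaches the episode reward to each sampled action Variable, autograd_backward (which appears to be torch.autograd.backward applied over self.actions) then propagates the REINFORCE gradients, and the per-parameter loop clamps every gradient into (-5, 5) before the optimizer step. That reinforce() API was removed in later PyTorch releases, so the snippet below is only a hedged, self-contained sketch of an equivalent update using the current API, with explicit log-probabilities and torch.nn.utils.clip_grad_value_; the toy policy, dimensions, round count, and reward are placeholders, not values from the source.

import torch
import torch.nn as nn

# Toy stand-ins for the agent's policy and optimizer (placeholders only).
policy = nn.Linear(8, 4)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)

optimizer.zero_grad()
log_probs = []
for _ in range(3):                                   # e.g. three dialog rounds
    logits = policy(torch.randn(1, 8))
    dist = torch.distributions.Categorical(logits=logits)
    action = dist.sample()                           # sampled token / action
    log_probs.append(dist.log_prob(action))

reward = 1.0                                         # episode reward at 'episode_done'
loss = -(torch.stack(log_probs).sum() * reward)      # REINFORCE surrogate loss
loss.backward()
torch.nn.utils.clip_grad_value_(policy.parameters(), clip_value=5.0)  # same (-5, 5) clamp
optimizer.step()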
Example 3
    def observe(self, observation):
        """Given an input token, interact for next round."""
        self.observation = observation
        if not observation.get('episode_done'):
            # embed and pass through LSTM
            token_embeds = self.listen_net(observation['text'])

            # concat with image representation (valid for abot)
            if 'image' in observation:
                token_embeds = torch.cat((token_embeds, observation['image']),
                                         1)
            # remove all dimensions with size one
            token_embeds = token_embeds.squeeze(1)
            # update agent state using these tokens
            self.h_state, self.c_state = self.rnn(token_embeds,
                                                  (self.h_state, self.c_state))
        else:
            if observation.get('reward') is not None:
                for action in self.actions:
                    action.reinforce(observation['reward'])
                autograd_backward(self.actions, [None for _ in self.actions],
                                  retain_graph=True)

                # clamp all gradients between (-5, 5)
                for parameter in self.parameters():
                    parameter.grad.data.clamp_(min=-5, max=5)