Example #1
    def validate(self, args):
        split = 'dev'
        self.critic_model.eval()
        total_stats = Statistics()
        print('=' * 20, 'VALIDATION', '=' * 20)
        for scenario in self.scenarios[split][:200]:
            controller = self._get_controller(scenario, split=split)
            example = controller.simulate(args.max_turns, verbose=args.verbose)
            session = controller.sessions[self.training_agent]
            reward = self.get_reward(example, session)

            output, output2 = self._critic_run_batch(session.dialogue.states)
            rewards = self._get_rewards(reward, output.size(0))
            masks = self._get_masks(output.size(0))
            output2 = self._get_utility(rewards, output2, masks,
                                        args.discount_factor)

            loss = F.smooth_l1_loss(output, output2)

            stats = Statistics(loss=loss.float())
            total_stats.update(stats)
        print('=' * 20, 'END VALIDATION', '=' * 20)
        self.critic_model.train()
        return total_stats
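
For reference, here is a minimal standalone sketch of the critic-target pattern above: a smooth-L1 loss between predicted state values and a discounted final reward. The discounted_targets helper, the tensor sizes, and the discount value are illustrative assumptions, not taken from the repository.

import torch
import torch.nn.functional as F

def discounted_targets(final_reward, num_states, discount=0.95):
    # The last state keeps the full reward; each earlier state is
    # discounted once more, mirroring a backwards pass over the dialogue.
    factors = torch.tensor([discount ** i for i in range(num_states - 1, -1, -1)])
    return final_reward * factors

values = torch.randn(5)                    # hypothetical critic outputs, one per state
targets = discounted_targets(1.0, num_states=5)
loss = F.smooth_l1_loss(values, targets)   # same loss form as in the validation loop
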
Example #2
    def update(self, batch_iter, reward, model, discount=0.95):
        model.train()

        nll = []
        # batch_iter gives a dialogue
        stats = Statistics()
        dec_state = None
        for batch in batch_iter:

            policy, price, pvar = self._run_batch(
                batch)  # (seq_len, batch_size, rnn_size)
            loss, batch_stats = self._compute_loss(batch,
                                                   policy=policy,
                                                   price=(price, pvar),
                                                   loss=self.train_loss)
            stats.update(batch_stats)

            loss = loss.view(-1)
            nll.append(loss)

            # TODO: Don't backprop fully.
            # if dec_state is not None:
            #     dec_state.detach()

        nll = torch.cat(nll)  # (total_seq_len, batch_size)

        # Discount the final reward backwards through the dialogue: the last
        # step keeps the full reward, earlier steps get progressively less.
        rewards = [torch.ones(1, 1) * reward]
        for i in range(1, nll.size(0)):
            rewards.append(rewards[-1] * discount)
        rewards = rewards[::-1]
        rewards = torch.cat(rewards)

        if self.cuda:
            loss = nll.view(-1).mul(rewards.view(-1).cuda()).mean()
        else:
            loss = nll.view(-1).mul(rewards.view(-1)).mean()

        model.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.)
        self.optim.step()

        return torch.cat([
            loss.view(-1),
            nll.mean().view(-1),
            torch.ones(1, device=loss.device) * stats.mean_loss(0),
            torch.ones(1, device=loss.device) * stats.mean_loss(1)
        ]).view(1, -1).cpu().data.numpy()
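
As a self-contained illustration of the reward-weighted objective computed in update(), the sketch below builds the same kind of discounted reward vector and scales a per-step NLL by it. The dummy tensors and the reinforce_loss name are assumptions for illustration only.

import torch

def reinforce_loss(nll, reward, discount=0.95):
    # The last step keeps the undiscounted reward and earlier steps are
    # discounted progressively more, matching the reversed list in update().
    steps = nll.size(0)
    factors = torch.tensor([discount ** i for i in range(steps - 1, -1, -1)])
    return nll.view(-1).mul(reward * factors).mean()

nll = torch.rand(10, requires_grad=True)   # hypothetical per-step NLL of a 10-step dialogue
loss = reinforce_loss(nll, reward=0.5)
loss.backward()                            # in update() this follows model.zero_grad()
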
Example #3
    def validate(self, args):
        split = 'dev'
        self.model.eval()
        total_stats = Statistics()
        print('=' * 20, 'VALIDATION', '=' * 20)
        for scenario in self.scenarios[split][:200]:
            controller = self._get_controller(scenario, split=split)
            example = controller.simulate(args.max_turns, verbose=args.verbose)
            session = controller.sessions[self.training_agent]
            reward = self.get_reward(example, session)
            stats = Statistics(reward=reward)
            total_stats.update(stats)
        print('=' * 20, 'END VALIDATION', '=' * 20)
        self.model.train()
        return total_stats
Example #4
    def validate(self, args):
        split = 'dev'
        self.model.eval()
        total_stats = Statistics()
        print('=' * 20, 'VALIDATION', '=' * 20)
        for scenario in self.scenarios[split][:200]:
            controller = self._get_controller(scenario, split=split)
            # training_agent is the ToM model
            controller.sessions[self.training_agent].set_controller(controller)
            example = controller.simulate(args.max_turns, verbose=args.verbose)
            if args.verbose:
                strs = example.to_text()
                for s in strs:
                    print(s)
                print("reward: [0]{} [1]{}".format(self.all_rewards[0][-1],
                                                   self.all_rewards[1][-1]))
            session = controller.sessions[self.training_agent]
            reward = self.get_reward(example, session)
            stats = Statistics(reward=reward)
            total_stats.update(stats)
        print('=' * 20, 'END VALIDATION', '=' * 20)
        return total_stats
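
All of the validate() variants above follow the same accumulate-then-average pattern over at most 200 dev scenarios. A minimal sketch of that pattern, with a hypothetical run_episode callable standing in for the controller and session machinery:

def average_reward(scenarios, run_episode, max_eval=200):
    # Evaluate at most max_eval scenarios and return the mean reward,
    # mirroring the scenarios[split][:200] slice in validate().
    total, count = 0.0, 0
    for scenario in scenarios[:max_eval]:
        total += run_episode(scenario)
        count += 1
    return total / max(count, 1)
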
Example #5
    def _stats(self, loss, word_num):
        return Statistics(loss.item(), 0, word_num, 0, 0)