    def talk(self,
             max_diag_length,
             batch_input_data,
             batch_input_kb,
             agent1,
             agent2,
             worker_step,
             batch_size,
             speaker=None):
        """The main procedure to generate a single self play conversation."""
        # Parse the batched input into intents, ground-truth actions, and KBs.
        bs_intent, bs_truth_action, bs_kb = self.parse_input(
            batch_input_data, batch_input_kb)
        # remember the roles of agents
        self.agents = [agent1, agent2]
        # In self-play training the speaker will be None, so we randomly choose
        # an initial speaker and initialize the utterance.
        # In self-play evaluation the speaker is specified, so we use it as is.
        if speaker is None:
            speaker = int(np.random.random() < 0.5)
        # Create the Conversation instance for this batch.
        conv = Conversation(max_diag_length, self.turn1_token,
                            self.turn2_token, batch_size, speaker)

        # Generate the conversation turn by turn in batch mode until all
        # conversations have terminated (finished = True) or the maximum
        # number of turns is reached.
        turn = 0
        finished = False
        while (not finished) and turn < self.max_dialogue_turns:
            finished = self.generate_utterance(bs_intent, conv, bs_kb, speaker,
                                               turn, batch_size)
            #  Change the speaker as we move to the next turn.
            speaker = (speaker + 1) % 2
            turn += 1

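        # Compute batch rewards by comparing the generated utterances and
        # predicted actions against the ground-truth actions and the KB.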
        all_rewards = dialogue_utils.compute_reward_batch(
            conv.utt_arr, conv.action_arr, bs_truth_action, bs_kb,
            self.hparams)
        metrics = dialogue_utils.calculate_reward_metrics(all_rewards)
        metrics['num_turns'] = turn

        # Write per-step metrics to the summary writer (debug mode only).
        if self.summary_writer and self.hparams.debug:
            for key in metrics:
                utils.add_summary(self.summary_writer, worker_step,
                                  self.dialogue_mode + '_' + key + '_ws',
                                  metrics[key])

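        # Collect the generated utterances and predicted actions as training data.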
        utt_arr, bs_pred_action = conv.get_train_data()

        if self.hparams.debug:
            print('self_play debug: ' + bs_intent[0])
            print('self_play debug: all_rewards', all_rewards[0])
            print('self_play debug: ' + ' '.join(utt_arr[0]))
            print('self_play debug: ' + ' '.join(bs_pred_action[0]))
            sys.stdout.flush()
        return (bs_intent, bs_pred_action, bs_truth_action, utt_arr,
                bs_kb), turn, metrics

    def maybe_train(self, sample, speaker, global_step, force=False):
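        """Buffers a self-play sample and trains once enough samples are queued.

        RL and/or supervised updates are run when the buffer reaches the
        training threshold or when force is True. Returns the new global step
        if any training ran, otherwise None.
        """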
        self.train_samples.append(sample)
        if force or len(self.train_samples) >= self.train_threadhold:
            # first generate training examples
            data_arr = []
            kb_arr = []
            for sample in self.train_samples:  # each sample is a batch of data
                intent, pred_action, truth_action, utterance, kb = sample  # batch version
                all_rewards = dialogue_utils.compute_reward_batch(
                    utterance, pred_action, truth_action, kb,
                    self.hparams)  # batch version
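                # Only the first component (the training reward) is used here.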
                train_reward, _, _, _, _, _, _, _, _ = all_rewards
                final_reward = train_reward
                reward_diag, reward_action = self.scale_reward_batch(
                    final_reward, self.gamma, utterance)  # in batches
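                # Flatten each predicted action sequence into a space-separated string.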
                flat_pred_action = []
                for k in range(len(pred_action)):
                    flat_pred_action.append(' '.join(pred_action[k]))

                new_data_arr = self.format_samples_batch(
                    batch_intent=intent,
                    batch_pred_action=flat_pred_action,
                    batch_truth_action=truth_action,
                    batch_utterance=utterance,
                    batch_reward_diag=reward_diag,
                    batch_reward_action=reward_action,
                    batch_size=self.update_batch_size)
                data_arr.extend(new_data_arr)
                kb_arr.extend(kb)
            data_output, kb_output = data_arr, kb_arr
            new_global_step = None
            self.train_samples = []  # clean up
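            # Select the iterator handle for the current self-play iterator mode.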
            self_play_handle = self.mutable_handles[self.iterator_mode]
            if self.hparams.rl_training:
                new_global_step = self.do_rl_training(
                    data_output, kb_output, self.update_batch_size,
                    self.mutable_model, self.mutable_sess, speaker, global_step,
                    self_play_handle)

            print('self.hparams.self_play_sl_multiplier=',
                  self.hparams.self_play_sl_multiplier)
            if self.hparams.self_play_sl_multiplier >= 0:  # run that many SL updates (0 means none)
                print('do', self.hparams.self_play_sl_multiplier,
                      'supervised training')
                for _ in range(self.hparams.self_play_sl_multiplier):
                    new_global_step = self.do_SL_training(
                        self.mutable_model, self.mutable_sess, global_step,
                        self.mutable_handles[0])
            else:
                print('do one supervised training')
                if self.train_counter >= abs(
                        self.hparams.self_play_sl_multiplier):
                    new_global_step = self.do_SL_training(
                        self.mutable_model, self.mutable_sess, global_step,
                        self.mutable_handles[0])
                    self.train_counter = 0
                else:
                    self.train_counter += 1

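            # Log the running ratio of supervised (SL) updates to RL updates.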
            if self.summary_writer:
                utils.add_summary(
                    self.summary_writer, new_global_step,
                    self.dialogue_mode + '_' + 'sl_rl',
                    self.num_sl_updates * 1.0 / (self.num_rl_updates + 0.0001))

            return new_global_step
        return None