Example #1
    def _process_dialogue(self, data):
        new_dlgs = []
        all_sent_lens = []
        all_dlg_lens = []

        for key, raw_dlg in data.items():
            norm_dlg = [Pack(speaker=USR, utt=[BOS, BOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size)]
            for t_id in range(len(raw_dlg['db'])):
                usr_utt = [BOS] + self.tokenize(raw_dlg['usr'][t_id]) + [EOS]
                sys_utt = [BOS] + self.tokenize(raw_dlg['sys'][t_id]) + [EOS]
                norm_dlg.append(Pack(speaker=USR, utt=usr_utt, db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id]))
                norm_dlg.append(Pack(speaker=SYS, utt=sys_utt, db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id]))
                all_sent_lens.extend([len(usr_utt), len(sys_utt)])
            # append an end-of-dialogue marker to terminate the dialog
            norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size))
            # if self.config.to_learn == 'usr':
            #     norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size))
            all_dlg_lens.append(len(raw_dlg['db']))
            processed_goal = self._process_goal(raw_dlg['goal'])
            new_dlgs.append(Pack(dlg=norm_dlg, goal=processed_goal, key=key))

        self.logger.info('Max utt len = %d, mean utt len = %.2f' % (
            np.max(all_sent_lens), float(np.mean(all_sent_lens))))
        self.logger.info('Max dlg len = %d, mean dlg len = %.2f' % (
            np.max(all_dlg_lens), float(np.mean(all_dlg_lens))))
        return new_dlgs
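
These snippets all rely on a `Pack` container. A minimal sketch, assuming it is simply a dict with attribute access (the real class may carry more helpers):

class Pack(dict):
    # dict with attribute access, mirroring how the snippets use it
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def copy(self):
        return Pack(**self)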
Example #2
    def _to_id_corpus(self, name, data):
        results = []
        for dlg in data:
            if len(dlg.dlg) < 1:
                continue
            id_dlg = []
            for turn, parsed_turn in zip(dlg.dlg, dlg.parsed_dlg):
                id_turn = Pack(
                    utt=self._sent2id(turn.utt),
                    speaker=turn.speaker,
                    parsed=self._goal2id(parsed_turn),
                )
                id_dlg.append(id_turn)
            id_goal = self._goal2id(dlg.goal)
            id_out = self._outcome2id(dlg.out)

            # data added for debugging and PR
            id_partner_goal = self._goal2id(dlg.usr_goal)

            results.append(
                Pack(
                    dlg=id_dlg,
                    goal=id_goal,
                    out=id_out,
                    partner_goal=id_partner_goal,
                    valid_partner_goals=get_valid_contexts_ints(id_goal),
                    partitions=get_latent_powerset(id_goal),
                    dlg_text=dlg.dlg,
                ))
        return results
Example #3
    def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False, get_marginals=False):
        clf = False  # the clf argument is ignored; the generation path below always runs
        if not clf:
            ctx_lens = data_feed['context_lens']  # (batch_size, )
            ctx_utts = self.np2var(data_feed['contexts'], LONG)  # (batch_size, max_ctx_len, max_utt_len)
            ctx_confs = self.np2var(data_feed['context_confs'], FLOAT)  # (batch_size, max_ctx_len)
            out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
            goals = self.np2var(data_feed['goals'], LONG)  # (batch_size, goal_len)
            batch_size = len(ctx_lens)

            # encode goal info
            goals_h = self.goal_encoder(goals)  # (batch_size, goal_nhid)

            enc_inputs, _, _ = self.utt_encoder(ctx_utts, feats=ctx_confs,
                                                goals=goals_h)  # (batch_size, max_ctx_len, num_directions*utt_cell_size)

            # enc_outs: (batch_size, max_ctx_len, ctx_cell_size)
            # enc_last: tuple, (h_n, c_n)
            # h_n: (num_layers*num_directions, batch_size, ctx_cell_size)
            # c_n: (num_layers*num_directions, batch_size, ctx_cell_size)
            enc_outs, enc_last = self.ctx_encoder(enc_inputs, input_lengths=ctx_lens, goals=None)

            # get decoder inputs
            dec_inputs = out_utts[:, :-1]
            labels = out_utts[:, 1:].contiguous()

            # pack attention context
            if self.config.dec_use_attn:
                attn_context = enc_outs
            else:
                attn_context = None

            # create decoder initial states
            dec_init_state = self.connector(enc_last)

            # decode
            dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
                                                                   dec_inputs=dec_inputs,
                                                                   # (batch_size, response_size-1)
                                                                   dec_init_state=dec_init_state,  # tuple: (h, c)
                                                                   attn_context=attn_context,
                                                                   # (batch_size, max_ctx_len, ctx_cell_size)
                                                                   mode=mode,
                                                                   gen_type=gen_type,
                                                                   beam_size=self.config.beam_size,
                                                                   goal_hid=goals_h)  # (batch_size, goal_nhid)

            if get_marginals:
                return Pack(
                    dec_outputs=dec_outputs,
                    labels=labels,
                )
            if mode == GEN:
                return ret_dict, labels
            if return_latent:
                return Pack(nll=self.nll(dec_outputs, labels),
                            latent_action=dec_init_state)
            else:
                return Pack(nll=self.nll(dec_outputs, labels))
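
The `self.nll(dec_outputs, labels)` criterion is not shown in these excerpts. A minimal sketch, assuming `dec_outputs` holds per-token log-probabilities and padding labels are ignored by index (names hypothetical):

import torch.nn.functional as F

def nll(log_probs, labels, padding_idx=0):
    # log_probs: (batch, seq_len, vocab); labels: (batch, seq_len)
    return F.nll_loss(log_probs.reshape(-1, log_probs.size(-1)),
                      labels.reshape(-1),
                      ignore_index=padding_idx)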
Example #4
    def forward(self,
                data_feed,
                mode,
                clf=False,
                gen_type='greedy',
                return_latent=False,
                use_py=True):
        ctx_lens = data_feed['context_lens']  # (batch_size, )
        short_ctx_utts = self.np2var(
            self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
        out_utts = self.np2var(data_feed['outputs'],
                               LONG)  # (batch_size, max_out_len)
        bs_label = self.np2var(data_feed['bs'],
                               FLOAT)  # (batch_size, bs_size)
        db_label = self.np2var(data_feed['db'],
                               FLOAT)  # (batch_size, db_size)
        batch_size = len(ctx_lens)

        utt_summary, _, enc_outs = self.utt_encoder(
            short_ctx_utts.unsqueeze(1))

        # get decoder inputs
        dec_inputs = out_utts[:, :-1]
        labels = out_utts[:, 1:].contiguous()

        # pack attention context
        if self.config.dec_use_attn:
            attn_context = enc_outs
        else:
            attn_context = None

        # create decoder initial states
        dec_init_state = self.policy(
            th.cat([bs_label, db_label,
                    utt_summary.squeeze(1)], dim=1)).unsqueeze(0)

        # decode
        if self.config.dec_rnn_cell == 'lstm':
            # h_dec_init_state = utt_summary.squeeze(1).unsqueeze(0)
            dec_init_state = tuple([dec_init_state, dec_init_state])

        dec_outputs, dec_hidden_state, ret_dict = self.decoder(
            batch_size=batch_size,
            dec_inputs=dec_inputs,
            # (batch_size, response_size-1)
            dec_init_state=dec_init_state,  # tuple: (h, c)
            attn_context=attn_context,
            # (batch_size, max_ctx_len, ctx_cell_size)
            mode=mode,
            gen_type=gen_type,
            beam_size=self.config.beam_size)
        if mode == GEN:
            return ret_dict, labels
        if return_latent:
            return Pack(nll=self.nll(dec_outputs, labels),
                        latent_action=dec_init_state)
        else:
            return Pack(nll=self.nll(dec_outputs, labels))
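
`self.np2var` appears throughout but is never defined in these excerpts. A plausible free-function sketch, assuming it converts numpy arrays to tensors of the requested dtype and moves them to GPU when configured:

import numpy as np
import torch as th

LONG, FLOAT = th.long, th.float  # assumed dtype flags

def np2var(array, dtype, use_gpu=False):
    tensor = th.as_tensor(np.asarray(array)).to(dtype)
    return tensor.cuda() if use_gpu else tensor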
Example #5
        def transform(token_list, usr_goal, sys_goal):
            usr, sys = [], []
            parsed_usr, parsed_sys = [], []
            num_proposals = 0
            ptr = 0
            while ptr < len(token_list):
                turn_ptr = ptr
                turn_list = []
                while True:
                    cur_token = token_list[turn_ptr]
                    turn_list.append(cur_token)
                    turn_ptr += 1
                    if cur_token == EOS:
                        ptr = turn_ptr
                        break
                all_sent_lens.append(len(turn_list))
                if turn_list[0] == USR:
                    usr.append(Pack(utt=turn_list, speaker=USR))
                    # assume usr is agent 0
                    prop = parse_c(" ".join(turn_list), usr_goal,
                                   merge=True).proposal
                    parsed = [
                        str(x) for x in ([
                            prop[0]["book"],
                            prop[0]["hat"],
                            prop[0]["ball"],
                            prop[1]["book"],
                            prop[1]["hat"],
                            prop[1]["ball"],
                        ] if prop is not None else [-1] * 6)
                    ]
                    parsed_usr.append(parsed)
                    if prop is not None:
                        num_proposals += 1
                elif turn_list[0] == SYS:
                    sys.append(Pack(utt=turn_list, speaker=SYS))
                    # assume sys is agent 1
                    prop = parse_c(" ".join(turn_list), sys_goal,
                                   merge=True).proposal
                    parsed = [
                        str(x) for x in ([
                            prop[1]["book"],
                            prop[1]["hat"],
                            prop[1]["ball"],
                            prop[0]["book"],
                            prop[0]["hat"],
                            prop[0]["ball"],
                        ] if prop is not None else [-1] * 6)
                    ]
                    parsed_sys.append(parsed)
                    if prop is not None:
                        num_proposals += 1
                else:
                    raise ValueError('Invalid speaker')

            all_dlg_lens.append(len(usr) + len(sys))
            return usr, sys, parsed_usr, parsed_sys, num_proposals
Example #6
 def _to_id_corpus(self, name, data):
     results = []
     for dlg in data:
         if len(dlg.dlg) < 1:
             continue
         id_dlg = []
         for turn in dlg.dlg:
             id_turn = Pack(utt=self._sent2id(turn.utt),
                            speaker=turn.speaker,
                            db=turn.db, bs=turn.bs)
             id_dlg.append(id_turn)
         id_goal = self._goal2id(dlg.goal)
         results.append(Pack(dlg=id_dlg, goal=id_goal, key=dlg.key))
     return results
Example #7
 def _to_id_corpus(self, name, data):
     results = []
     for dlg in data:
         if len(dlg.dlg) < 1:
             continue
         id_dlg = []
         for turn in dlg.dlg:
             id_turn = Pack(utt=self._sent2id(turn.utt),
                            speaker=turn.speaker)
             id_dlg.append(id_turn)
         id_goal = self._goal2id(dlg.goal)
         id_out = self._outcome2id(dlg.out)
         results.append(Pack(dlg=id_dlg, goal=id_goal, out=id_out))
     return results
Example #8
    def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
        ctx_lens = data_feed['context_lens']  # (batch_size, )
        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, bs_size)
        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, db_size)
        batch_size = len(ctx_lens)

        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))

        # get decoder inputs
        dec_inputs = out_utts[:, :-1]
        labels = out_utts[:, 1:].contiguous()

        # create decoder initial states
        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)

        # create decoder initial states
        if self.simple_posterior:
            q_mu, q_logvar = self.c2z(enc_last)
            sample_z = self.gauss_connector(q_mu, q_logvar)
            p_mu, p_logvar = self.zero, self.zero
        else:
            p_mu, p_logvar = self.c2z(enc_last)
            # encode response and use posterior to find q(z|x, c)
            x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
            q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))

            # use prior at inference time, otherwise use posterior
            if mode == GEN or use_py:
                sample_z = self.gauss_connector(p_mu, p_logvar)
            else:
                sample_z = self.gauss_connector(q_mu, q_logvar)

        # pack attention context
        dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
        attn_context = None

        # decode
        if self.config.dec_rnn_cell == 'lstm':
            dec_init_state = tuple([dec_init_state, dec_init_state])

        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
                                                               dec_inputs=dec_inputs,
                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
                                                               attn_context=attn_context,
                                                               mode=mode,
                                                               gen_type=gen_type,
                                                               beam_size=self.config.beam_size)
        if mode == GEN:
            ret_dict['sample_z'] = sample_z
            return ret_dict, labels

        else:
            result = Pack(nll=self.nll(dec_outputs, labels))
            result['pi_kl'] = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar)
            return result
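
`self.gauss_connector(mu, logvar)` presumably draws a reparameterized Gaussian sample (the standard VAE trick); a minimal sketch:

import torch as th

def gauss_sample(mu, logvar):
    # z = mu + sigma * eps, with eps ~ N(0, I)
    std = th.exp(0.5 * logvar)
    return mu + std * th.randn_like(std)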
Example #9
    def _prepare_batch_flat(self, selected_index):
        rows = [self.data[idx] for idx in selected_index]

        ctx_utts, ctx_lens = [], []
        out_utts, out_lens = [], []
        goals, goal_lens = [], []

        for row in rows:
            in_row, out_row, goal_row = row.context, row.response, row.goal

            # source context
            batch_ctx = []
            for turn in in_row:
                batch_ctx.append(
                    self.pad_to(self.max_utt_len, turn.utt, do_pad=True))
            ctx_utts.append(batch_ctx)
            ctx_lens.append(len(batch_ctx))

            # target response
            out_utt = list(out_row.utt)
            out_utts.append(out_utt)
            out_lens.append(len(out_utt))

            # goal
            goals.append(goal_row)
            goal_lens.append(len(goal_row))

        vec_ctx_lens = np.array(ctx_lens)  # (batch_size, ), number of turns
        max_ctx_len = np.max(vec_ctx_lens)
        vec_ctx_utts = np.zeros(
            (self.batch_size, max_ctx_len, self.max_utt_len), dtype=np.int32)
        # confs is used to add some hand-crafted features
        vec_ctx_confs = np.ones((self.batch_size, max_ctx_len),
                                dtype=np.float32)
        vec_out_lens = np.array(out_lens)  # (batch_size, ), number of tokens
        max_out_len = np.max(vec_out_lens)
        vec_out_utts = np.zeros((self.batch_size, max_out_len), dtype=np.int32)

        max_goal_len, min_goal_len = max(goal_lens), min(goal_lens)
        if max_goal_len != min_goal_len or max_goal_len != 6:
            print('FATAL ERROR!')
            exit(-1)
        self.goal_len = max_goal_len
        vec_goals = np.zeros((self.batch_size, self.goal_len), dtype=np.int32)

        for b_id in range(self.batch_size):
            vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id]
            vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id]
            vec_goals[b_id, :] = goals[b_id]

        return Pack(context_lens=vec_ctx_lens,
                    contexts=vec_ctx_utts,
                    context_confs=vec_ctx_confs,
                    output_lens=vec_out_lens,
                    outputs=vec_out_utts,
                    goals=vec_goals)
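
`self.pad_to(self.max_utt_len, turn.utt, do_pad=True)` is assumed to truncate or right-pad token lists to a fixed length; a sketch of that assumed behavior (pad id hypothetical):

def pad_to(max_len, tokens, do_pad=True, pad_id=0):
    # truncate over-long utterances; right-pad short ones when do_pad is set
    if len(tokens) >= max_len:
        return tokens[:max_len]
    if do_pad:
        return tokens + [pad_id] * (max_len - len(tokens))
    return tokens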
Example #10
        def transform(token_list):
            usr, sys = [], []
            ptr = 0
            while ptr < len(token_list):
                turn_ptr = ptr
                turn_list = []
                while True:
                    cur_token = token_list[turn_ptr]
                    turn_list.append(cur_token)
                    turn_ptr += 1
                    if cur_token == EOS:
                        ptr = turn_ptr
                        break
                all_sent_lens.append(len(turn_list))
                if turn_list[0] == USR:
                    usr.append(Pack(utt=turn_list, speaker=USR))
                elif turn_list[0] == SYS:
                    sys.append(Pack(utt=turn_list, speaker=SYS))
                else:
                    raise ValueError('Invalid speaker')

            all_dlg_lens.append(len(usr) + len(sys))
            return usr, sys
Example #11
    def flatten_dialog_seq(self, data, backward_size):
        """
        Turn each dialog in list of dialogs into a list of context, response pairs.

        Backward_size indicates how many previous utterances to condition on.
        This should be limited to 1 or 2 at most limiting dependencies.

        The speaker is SYS, so USR utterances are not modeled.
        """
        results = []
        for dlg in data:
            goal = dlg.goal
            context_responses = []
            parsed_context_responses = []
            for i in range(1, len(dlg.dlg)):
                if dlg.dlg[i].speaker == USR:
                    continue
                e_idx = i
                s_idx = max(0, e_idx - backward_size)
                response = dlg.dlg[i].copy()
                response['utt'] = self.pad_to(self.max_utt_len,
                                              response.utt,
                                              do_pad=False)
                parsed_response = response.parsed
                context = []
                parsed_context = []
                for turn in dlg.dlg[s_idx:e_idx]:
                    turn['utt'] = self.pad_to(self.max_utt_len,
                                              turn.utt,
                                              do_pad=False)
                    context.append(turn)
                    parsed_context.append(turn.parsed)
                context_responses.append(
                    Pack(
                        context=context,
                        response=response,
                        goal=goal,
                        parsed_context=parsed_context,
                        parsed_response=parsed_response,
                        partner_goal=dlg.partner_goal,
                        valid_partner_goals=dlg.valid_partner_goals,
                        partitions=dlg.partitions,
                    ))
            results.append(context_responses)
        return results
Example #12
 def flatten_dialog(self, data, backward_size):
     results = []
     for dlg in data:
         goal = dlg.goal
         for i in range(1, len(dlg.dlg)):
             if dlg.dlg[i].speaker == USR:
                 continue
             e_idx = i
             s_idx = max(0, e_idx - backward_size)
             response = dlg.dlg[i].copy()
             response['utt'] = self.pad_to(self.max_utt_len,
                                           response.utt,
                                           do_pad=False)
             context = []
             for turn in dlg.dlg[s_idx:e_idx]:
                 turn['utt'] = self.pad_to(self.max_utt_len,
                                           turn.utt,
                                           do_pad=False)
                 context.append(turn)
             results.append(
                 Pack(context=context, response=response, goal=goal))
     return results
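
For intuition, a toy rendering of the sliding context window flatten_dialog builds (dummy strings stand in for turn Packs):

# for dlg = [usr0, sys1, usr2, sys3] and backward_size=2, the SYS turn
# at i=3 pairs context dlg[1:3] with response dlg[3]
dlg = ['usr0', 'sys1', 'usr2', 'sys3']
i, backward_size = 3, 2
s_idx = max(0, i - backward_size)
print(dlg[s_idx:i], '->', dlg[i])  # ['sys1', 'usr2'] -> 'sys3'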
Example #13
 def flatten_dialog(self, data, backward_size):
     results = []
     indexes = []
     batch_indexes = []
     resp_set = set()
     for dlg in data:
         goal = dlg.goal
         key = dlg.key
         batch_index = []
         for i in range(1, len(dlg.dlg)):
             if dlg.dlg[i].speaker == USR:
                 continue
             e_idx = i
             s_idx = max(0, e_idx - backward_size)
             response = dlg.dlg[i].copy()
             response['utt'] = self.pad_to(self.max_utt_len,
                                           response.utt,
                                           do_pad=False)
             resp_set.add(json.dumps(response.utt))
             context = []
             for turn in dlg.dlg[s_idx:e_idx]:
                 turn['utt'] = self.pad_to(self.max_utt_len,
                                           turn.utt,
                                           do_pad=False)
                 context.append(turn)
             results.append(
                 Pack(context=context,
                      response=response,
                      goal=goal,
                      key=key))
             indexes.append(len(indexes))
             batch_index.append(indexes[-1])
         if len(batch_index) > 0:
             batch_indexes.append(batch_index)
     print("Unique resp {}".format(len(resp_set)))
     return results, indexes, batch_indexes
Example #14
config = Pack(
    random_seed=10,
    train_path='../data/negotiate/train.txt',
    val_path='../data/negotiate/val.txt',
    test_path='../data/negotiate/test.txt',
    last_n_model=4,
    max_utt_len=20,
    #backward_size = 14,
    backward_size=8,
    #batch_size = 16,
    batch_size=4,
    use_gpu=True,
    op='adam',
    init_lr=0.001,
    l2_norm=0.00001,
    momentum=0.0,
    grad_clip=10.0,
    dropout=0.3,
    max_epoch=50,
    embed_size=256,
    num_layers=1,
    #num_layers = 2,
    utt_rnn_cell='gru',
    utt_cell_size=128,
    bi_utt_cell=True,
    enc_use_attn=False,
    ctx_rnn_cell='gru',
    ctx_cell_size=256,
    bi_ctx_cell=False,
    #dec_use_attn = True,
    dec_use_attn=False,
    dec_rnn_cell='gru',
    # must be same as ctx_cell_size due to the passed initial state
    dec_cell_size=256,
    dec_attn_mode='cat',
    #
    beam_size=20,
    fix_train_batch=False,
    avg_type='real_word',
    print_step=100,
    ckpt_step=400,
    #ckpt_step = 2523,
    improve_threshold=0.996,
    patient_increase=2.0,
    save_model=True,
    early_stop=False,
    gen_type='greedy',
    preview_batch_num=50,
    max_dec_len=40,
    k=domain_info.input_length(),
    goal_embed_size=64,
    goal_nhid=64,
    init_range=0.1,
    pretrain_folder='2019-12-08-18-45-47-sl_word_dlg_num',
    forward_only=False,
    #forward_only = True,
    # different batching style
    seq=True,
    # use oracle context and proposal parse
    oracle_context=True,
    #oracle_context = False,
    #oracle_parse = False,
    oracle_parse=True,
    semisupervised=False,
    #prop_weight = 0.1,
    prop_weight=1,
    #prop_weight = 0,
    tie_prop_utt_enc=False,
)
Example #15
    def forward(self,
                data_feed,
                mode,
                clf=False,
                gen_type='greedy',
                use_py=None,
                return_latent=False,
                get_marginals=False):
        ctx_lens = data_feed['context_lens']  # (batch_size, )
        ctx_utts = self.np2var(data_feed['contexts'],
                               LONG)  # (batch_size, max_ctx_len, max_utt_len)
        ctx_confs = self.np2var(data_feed['context_confs'],
                                FLOAT)  # (batch_size, max_ctx_len)
        out_utts = self.np2var(data_feed['outputs'],
                               LONG)  # (batch_size, max_out_len)
        goals = self.np2var(data_feed['goals'], LONG)  # (batch_size, goal_len)
        batch_size = len(ctx_lens)

        # encode goal info
        goals_h = self.goal_encoder(goals)  # (batch_size, goal_nhid)

        enc_inputs, _, _ = self.utt_encoder(
            ctx_utts,
            feats=ctx_confs,
            goals=goals_h,
        )  # (batch_size, max_ctx_len, num_directions*utt_cell_size)

        # enc_outs: (batch_size, max_ctx_len, ctx_cell_size)
        # enc_last: tuple, (h_n, c_n)
        # h_n: (num_layers*num_directions, batch_size, ctx_cell_size)
        # c_n: (num_layers*num_directions, batch_size, ctx_cell_size)
        enc_outs, enc_last = self.ctx_encoder(enc_inputs,
                                              input_lengths=ctx_lens,
                                              goals=None)

        partitions = self.np2var(data_feed.partitions, LONG)
        num_partitions = self.np2var(data_feed.num_partitions, INT)
        # oracle input
        partner_goals = self.np2var(data_feed.true_partner_goals, LONG)
        parsed_outputs = self.np2var(data_feed.parsed_outputs, LONG)
        # true partner item values
        partner_goals_h = self.goal_encoder(partner_goals)

        # proposal prediction
        prop_enc_inputs, _, _ = self.prop_utt_encoder(
            ctx_utts,
            feats=ctx_confs,
            goals=goals_h,
        )  # (batch_size, max_ctx_len, num_directions*utt_cell_size)

        # enc_outs: (batch_size, max_ctx_len, ctx_cell_size)
        # enc_last: tuple, (h_n, c_n)
        # h_n: (num_layers*num_directions, batch_size, ctx_cell_size)
        # c_n: (num_layers*num_directions, batch_size, ctx_cell_size)
        prop_enc_outs, prop_enc_last = self.prop_ctx_encoder(
            enc_inputs if self.config.tie_prop_utt_enc else prop_enc_inputs,
            input_lengths=ctx_lens,
            goals=partner_goals_h if self.config.oracle_context else None,
        )

        my_state_emb_out = self.res_layer_out(
            th.cat([
                self.book_emb_out(partitions[:, :, 0]),
                self.hat_emb_out(partitions[:, :, 1]),
                self.ball_emb_out(partitions[:, :, 2]),
            ], -1))
        your_state_emb_out = self.res_layer_out(
            th.cat([
                self.book_emb_out(partitions[:, :, 3]),
                self.hat_emb_out(partitions[:, :, 4]),
                self.ball_emb_out(partitions[:, :, 5]),
            ], -1))
        state_emb_out = th.cat([my_state_emb_out, your_state_emb_out], -1)

        big_goals_h = self.res_goal_mlp(
            th.cat([
                goals_h.unsqueeze(1).expand(
                    state_emb_out.shape[0],
                    state_emb_out.shape[1],
                    goals_h.shape[-1],
                ),
                state_emb_out,
            ], -1))
        #import pdb; pdb.set_trace()

        z_size = partitions.shape[1]

        prop_mask = (partitions == parsed_outputs.unsqueeze(1)).all(-1)
        logits_prop = th.einsum("nsh,nh->ns", state_emb_out, prop_enc_last[-1])
        mask = ~(th.arange(
            z_size,
            device=num_partitions.device, dtype=num_partitions.dtype).repeat(
                partitions.shape[0], 1) < num_partitions.unsqueeze(-1))
        logp_prop = logits_prop.masked_fill(mask, float("-inf")).log_softmax(-1)

        if self.config.semisupervised:
            # re-use params or make new ones? re-using can only hurt
            # TODO: use new parameters
            my_state_emb = self.res_layer(
                th.cat([
                    self.book_emb(partitions[:, :, 0]),
                    self.hat_emb(partitions[:, :, 1]),
                    self.ball_emb(partitions[:, :, 2]),
                ], -1))
            your_state_emb = self.res_layer(
                th.cat([
                    self.book_emb(partitions[:, :, 3]),
                    self.hat_emb(partitions[:, :, 4]),
                    self.ball_emb(partitions[:, :, 5]),
                ], -1))
            noise_state_emb = th.cat([my_state_emb, your_state_emb], -1)
            logp_tprop_prop = th.einsum("nth,nsh->nts", noise_state_emb,
                                        state_emb_out).log_softmax(1)
            nll_prop = -self.config.prop_weight * (
                logp_tprop_prop +
                logp_prop.unsqueeze(-2)).logsumexp(-1)[prop_mask].mean()
        else:
            nll_prop = -self.config.prop_weight * logp_prop[prop_mask].mean()

        # get decoder inputs
        dec_inputs = out_utts[:, :-1]
        labels = out_utts[:, 1:].contiguous()

        # pack attention context
        if self.config.dec_use_attn:
            attn_context = enc_outs
        else:
            attn_context = None

        # create decoder initial states
        dec_init_state = (self.connector(enc_last)
                          if self.out_backward_size is not None else None)

        if mode == GEN:
            N, Z, H = big_goals_h.shape
            if gen_type == "sampled":
                sampled_proposal_indices = logp_prop.exp().multinomial(1)
            elif gen_type == "greedy":
                sampled_proposal_indices = logp_prop.argmax(-1)
            else:
                raise ValueError(f"Unknown gen_type: {gen_type}")
            sampled_goals_h = big_goals_h.gather(
                1,
                sampled_proposal_indices.view(N, 1, 1).expand(N, 1,
                                                              H)).squeeze(1)
            dec_outputs, dec_hidden_state, ret_dict = self.decoder(
                batch_size=batch_size,
                dec_inputs=dec_inputs,
                # (batch_size, response_size-1)
                dec_init_state=dec_init_state,
                attn_context=attn_context,
                # (batch_size, max_ctx_len, ctx_cell_size)
                mode=mode,
                gen_type=gen_type,
                beam_size=self.config.beam_size,
                # my goal, your goal, and the proposal!!! a lot
                goal_hid=sampled_goals_h,
                #goal_hid=big_goals_h[prop_mask],
            )  # (batch_size, goal_nhid)
            return ret_dict, labels

        # decode
        N, T = dec_inputs.shape
        dec_outputs, dec_hidden_state, ret_dict = self.decoder(
            batch_size=batch_size * z_size,
            dec_inputs=dec_inputs.repeat(1, z_size).view(-1, T),
            # (batch_size, response_size-1)
            dec_init_state=dec_init_state.repeat(1, 1, z_size).view(
                1, z_size *
                batch_size, -1) if dec_init_state is not None else None,
            attn_context=attn_context,
            # (batch_size, max_ctx_len, ctx_cell_size)
            mode=mode,
            gen_type=gen_type,
            beam_size=self.config.beam_size,
            # my goal, your goal, and the proposal!!! a lot
            goal_hid=big_goals_h.view(-1, 128),
        )  # (batch_size, goal_nhid)
        V = dec_outputs.shape[-1]
        #logp_w_prop = dec_outputs.view(N, z_size, T, V)[prop_mask]
        T_out = dec_outputs.shape[-2]
        logp_w_prop = (dec_outputs.view(N, z_size, T_out, V) +
                       logp_prop.view(N, z_size, 1, 1))
        logp_w = logp_w_prop.logsumexp(1)

        if get_marginals:
            N, Z, T, V = logp_w_prop.shape
            logp_prop_w = logp_w_prop.gather(
                -1,
                labels.view(N, 1, T, 1).expand(N, Z, T, 1),
            ).squeeze(-1).sum(-1).log_softmax(1)

            best_prop_model = logp_prop_w.argmax(-1)
            parsed_prop = prop_mask.argmax(-1)
            if self.config.semisupervised:
                logp_tprop = (logp_tprop_prop +
                              logp_prop.unsqueeze(-2)).logsumexp(-1)
                best_tprop_model = logp_tprop.argmax(-1)

            out_utts_text = [[
                self.vocab[x] for x in xs if x != self.vocab_dict["<pad>"]
            ] for xs in out_utts]
            ctx_utts_text = [[
                self.vocab[x] for xs in xss for x in xs
                if x != self.vocab_dict["<pad>"]
            ] for xss in ctx_utts]

            def get(i):
                print(partitions[i][best_prop_model[i]])
                if self.config.semisupervised:
                    print(partitions[i][best_tprop_model[i]])
                print(partitions[i][parsed_prop[i]])
                print(" ".join(out_utts_text[i]))
                print(" ".join(ctx_utts_text[i]))

            #import pdb; pdb.set_trace()
            """
            return Pack(
                dec_outputs = dec_outputs,
                labels = labels,
                logp_prop = logp_prop,
                log_marginals_prop = log_marginals_prop,
                logp_w_prop = logp_w_prop,
                logp_w = logp_w,
            )
            """
            return Pack(
                nll=self.nll(logp_w, labels),
                nll_prop=nll_prop,
            )
        if mode == GEN:
            return ret_dict, labels
        if return_latent:
            return Pack(nll=self.nll(logp_w, labels),
                        latent_action=dec_init_state)
        else:
            return Pack(nll=self.nll(logp_w, labels), nll_prop=nll_prop)
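
The core move above is a mixture over latent proposals: log p(w_t) = logsumexp_z [log p(w_t | z) + log p(z)]. A self-contained toy mirroring the logsumexp step:

import torch as th

N, Z, T, V = 2, 4, 5, 10                               # toy sizes
logp_w_given_z = th.randn(N, Z, T, V).log_softmax(-1)  # log p(w_t | z)
logp_z = th.randn(N, Z).log_softmax(-1)                # log p(z)
logp_w = (logp_w_given_z + logp_z.view(N, Z, 1, 1)).logsumexp(1)
assert logp_w.shape == (N, T, V)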
Example #16
config = Pack(
    train_path='../data/negotiate/train.txt',
    val_path='../data/negotiate/val.txt',
    test_path='../data/negotiate/test.txt',
    last_n_model=4,
    max_utt_len=20,
    #backward_size = 14,
    backward_size=8,
    #backward_size = 1,
    #batch_size = 32,
    batch_size=4,
    grad_clip=10.0,
    use_gpu=True,
    op='adam',
    init_lr=0.001,
    l2_norm=0.00001,
    momentum=0.0,
    dropout=0.3,
    max_epoch=100,
    embed_size=256,
    #num_layers = 1,
    num_layers=2,
    utt_rnn_cell='gru',
    utt_cell_size=128,
    bi_utt_cell=True,
    enc_use_attn=False,
    ctx_rnn_cell='gru',
    ctx_cell_size=256,
    bi_ctx_cell=False,
    z_size=128,
    #beta = 0.01,
    #simple_posterior = False,
    #use_pr = True,
    dec_use_attn=False,
    dec_rnn_cell='gru',
    # must be same as ctx_cell_size due to the passed initial state
    dec_cell_size=256,
    dec_attn_mode='cat',
    #
    fix_train_batch=False,
    fix_batch=False,
    beam_size=20,
    avg_type='real_word',
    print_step=100,
    ckpt_step=400,
    #ckpt_step = 2523,
    improve_threshold=0.996,
    patient_increase=2.0,
    save_model=True,
    early_stop=False,
    gen_type='greedy',
    preview_batch_num=1,
    max_dec_len=40,
    k=domain_info.input_length(),
    goal_embed_size=64,
    goal_nhid=64,
    init_range=0.1,
    pretrain_folder='2019-12-06-02-20-58-sl_hmm',
    forward_only=False,
    #forward_only = True,
    # options for sequence LVMs
    seq=True,
    noisy_proposal_labels=True,
    sup_proposal_labels=False,
    #sup_proposal_labels = True,
    label_weight=0.1,
    #label_weight = 1,
)
Example #17
    def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
        ctx_lens = data_feed['context_lens']  # (batch_size, )
        ctx_utts = self.np2var(data_feed['contexts'], LONG)  # (batch_size, max_ctx_len, max_utt_len)
        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
        goals = self.np2var(data_feed['goals'], LONG)  # (batch_size, goal_len)
        batch_size = len(ctx_lens)

        # encode goal info
        goals_h = self.goal_encoder(goals)  # (batch_size, goal_nhid)

        enc_inputs, _, _ = self.utt_encoder(ctx_utts, goals=goals_h)
        # (batch_size, max_ctx_len, num_directions*utt_cell_size)

        # enc_outs: (batch_size, max_ctx_len, ctx_cell_size)
        # enc_last: tuple, (h_n, c_n)
        # h_n: (num_layers*num_directions, batch_size, ctx_cell_size)
        # c_n: (num_layers*num_directions, batch_size, ctx_cell_size)
        enc_outs, enc_last = self.ctx_encoder(enc_inputs, input_lengths=ctx_lens, goals=None)

        # get decoder inputs
        dec_inputs = out_utts[:, :-1]
        labels = out_utts[:, 1:].contiguous()

        # create decoder initial states
        if self.simple_posterior:
            logits_qy, log_qy = self.c2z(enc_last)
            sample_y = self.gumbel_connector(logits_qy)
            log_py = self.log_uniform_y
        else:
            logits_py, log_py = self.c2z(enc_last)
            # encode response and use posterior to find q(z|x, c)
            x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1), goals=goals_h)
            logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1).unsqueeze(0)], dim=2))
            # use prior at inference time, otherwise use posterior
            if mode == GEN or use_py:
                sample_y = self.gumbel_connector(logits_py)
            else:
                sample_y = self.gumbel_connector(logits_qy)

        # pack attention context
        if self.config.dec_use_attn:
            z_embeddings = th.t(self.z_embedding.weight).split(self.config.k_size, dim=0)
            attn_context = []
            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
            for z_id in range(self.config.y_size):
                attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
            attn_context = th.cat(attn_context, dim=1)
            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
        else:
            attn_context = None
            dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))

        # decode
        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
                                                               dec_inputs=dec_inputs,
                                                               # (batch_size, response_size-1)
                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
                                                               attn_context=attn_context,
                                                               # (batch_size, max_ctx_len, ctx_cell_size)
                                                               mode=mode,
                                                               gen_type=gen_type,
                                                               beam_size=self.config.beam_size,
                                                               goal_hid=goals_h)  # (batch_size, goal_nhid)


        if mode == GEN:
            return ret_dict, labels
        else:
            # regularize q(y) toward uniform
            avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size))
            avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15)
            mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True)
            pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True)
            pi_h = self.entropy_loss(log_qy, unit_average=True)
            results = Pack(nll=self.nll(dec_outputs, labels), mi=mi, pi_kl=pi_kl, pi_h=pi_h)

            if return_latent:
                results['latent_action'] = dec_init_state

            return results
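
`self.gumbel_connector(logits)` presumably draws Gumbel-softmax samples (Jang et al., 2017); a minimal sketch, with the straight-through option the later examples pass as `hard=`:

import torch as th
import torch.nn.functional as F

def gumbel_softmax_sample(logits, temperature=1.0, hard=False):
    gumbel = -th.log(-th.log(th.rand_like(logits) + 1e-20) + 1e-20)
    y = F.softmax((logits + gumbel) / temperature, dim=-1)
    if hard:
        # straight-through: discrete forward pass, soft gradients
        y_hard = F.one_hot(y.argmax(-1), logits.size(-1)).float()
        y = (y_hard - y).detach() + y
    return y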
Example #18
    def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
        ctx_lens = data_feed['context_lens']  # (batch_size, )
        ctx_utts = self.np2var(data_feed['contexts'], LONG)  # (batch_size, max_ctx_len, max_utt_len)
        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
        goals = self.np2var(data_feed['goals'], LONG)  # (batch_size, goal_len)
        batch_size = len(ctx_lens)

        # encode goal info
        goals_h = self.goal_encoder(goals)  # (batch_size, goal_nhid)

        enc_inputs, _, _ = self.utt_encoder(ctx_utts, goals=goals_h)
        # (batch_size, max_ctx_len, num_directions*utt_cell_size)

        # enc_outs: (batch_size, max_ctx_len, ctx_cell_size)
        # enc_last: tuple, (h_n, c_n)
        # h_n: (num_layers*num_directions, batch_size, ctx_cell_size)
        # c_n: (num_layers*num_directions, batch_size, ctx_cell_size)
        enc_outs, enc_last = self.ctx_encoder(enc_inputs, input_lengths=ctx_lens, goals=None)

        # get decoder inputs
        dec_inputs = out_utts[:, :-1]
        labels = out_utts[:, 1:].contiguous()

        # create decoder initial states
        if self.simple_posterior:
            q_mu, q_logvar = self.c2z(enc_last)
            sample_z = self.gauss_connector(q_mu, q_logvar)
            p_mu, p_logvar = self.zero, self.zero
        else:
            p_mu, p_logvar = self.c2z(enc_last)
            # encode response and use posterior to find q(z|x, c)
            x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1),  goals=goals_h)
            q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1).unsqueeze(0)], dim=2))

            # use prior at inference time, otherwise use posterior
            if mode == GEN or use_py:
                sample_z = self.gauss_connector(p_mu, p_logvar)
            else:
                sample_z = self.gauss_connector(q_mu, q_logvar)

        # pack attention context
        dec_init_state = self.z_embedding(sample_z)
        attn_context = None

        # decode
        if self.config.dec_rnn_cell == 'lstm':
            dec_init_state = tuple([dec_init_state, dec_init_state])

        # decode
        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
                                                               dec_inputs=dec_inputs,
                                                               # (batch_size, response_size-1)
                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
                                                               attn_context=attn_context,
                                                               # (batch_size, max_ctx_len, ctx_cell_size)
                                                               mode=mode,
                                                               gen_type=gen_type,
                                                               beam_size=self.config.beam_size,
                                                               goal_hid=goals_h)  # (batch_size, goal_nhid)

        if mode == GEN:
            ret_dict['sample_z'] = sample_z
            return ret_dict, labels

        else:
            result = Pack(nll=self.nll(dec_outputs, labels))
            result['pi_kl'] = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar)
            return result
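
`self.gauss_kl` is not shown; a sketch of the diagonal-Gaussian KL it presumably computes, KL(q || p) with q = N(q_mu, exp(q_logvar)) and p = N(p_mu, exp(p_logvar)), averaged over the batch:

import torch as th

def gauss_kl(q_mu, q_logvar, p_mu, p_logvar):
    kl = 0.5 * (p_logvar - q_logvar
                + (q_logvar.exp() + (q_mu - p_mu).pow(2)) / p_logvar.exp()
                - 1.0)
    return kl.sum(-1).mean()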
Example #19
config = Pack(
    random_seed=10,
    train_path='../data/norm-multi-woz/train_dials.json',
    valid_path='../data/norm-multi-woz/val_dials.json',
    test_path='../data/norm-multi-woz/test_dials.json',
    max_vocab_size=1000,
    max_utt_len=50,
    max_dec_len=50,
    last_n_model=5,
    backward_size=2,
    batch_size=32,
    use_gpu=True,
    op='adam',
    init_lr=0.001,
    l2_norm=1e-05,
    momentum=0.0,
    grad_clip=5.0,
    dropout=0.5,
    max_epoch=50,
    embed_size=100,
    num_layers=1,
    utt_rnn_cell='gru',
    utt_cell_size=300,
    bi_utt_cell=True,
    enc_use_attn=True,
    dec_use_attn=False,
    dec_rnn_cell='lstm',
    # must be same as ctx_cell_size due to the passed initial state
    dec_cell_size=300,
    dec_attn_mode='cat',
    #
    beam_size=20,
    fix_batch=True,
    fix_train_batch=False,
    avg_type='word',
    print_step=500,
    ckpt_step=1771,
    improve_threshold=0.996,
    patient_increase=2.0,
    save_model=True,
    early_stop=False,
    gen_type='greedy',
    preview_batch_num=None,
    k=domain_info.input_length(),
    init_range=0.1,
    pretrain_folder='2018-11-13-21-27-21-sys_sl_bdu2resp',
    forward_only=False)
Example #20
    def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
        ctx_lens = data_feed['context_lens']  # (batch_size, )
        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, bs_size)
        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, db_size)
        batch_size = len(ctx_lens)

        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))

        # get decoder inputs
        dec_inputs = out_utts[:, :-1]
        labels = out_utts[:, 1:].contiguous()

        # create decoder initial states
        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
        # create decoder initial states
        if self.simple_posterior:
            logits_qy, log_qy = self.c2z(enc_last)
            sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN)
            log_py = self.log_uniform_y
        else:
            logits_py, log_py = self.c2z(enc_last)
            # encode response and use posterior to find q(z|x, c)
            x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
            if self.contextual_posterior:
                logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))
            else:
                logits_qy, log_qy = self.xc2z(x_h.squeeze(1))

            # use prior at inference time, otherwise use posterior
            if mode == GEN or use_py is True:
                sample_y = self.gumbel_connector(logits_py, hard=False)
            else:
                sample_y = self.gumbel_connector(logits_qy, hard=True)

        # pack attention context
        if self.config.dec_use_attn:
            z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
            attn_context = []
            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
            for z_id in range(self.y_size):
                attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
            attn_context = th.cat(attn_context, dim=1)
            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
        else:
            dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
            attn_context = None

        # decode
        if self.config.dec_rnn_cell == 'lstm':
            dec_init_state = tuple([dec_init_state, dec_init_state])

        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
                                                               dec_inputs=dec_inputs,
                                                               # (batch_size, response_size-1)
                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
                                                               attn_context=attn_context,
                                                               # (batch_size, max_ctx_len, ctx_cell_size)
                                                               mode=mode,
                                                               gen_type=gen_type,
                                                               beam_size=self.config.beam_size)
        if mode == GEN:
            ret_dict['sample_z'] = sample_y
            ret_dict['log_qy'] = log_qy
            return ret_dict, labels

        else:
            result = Pack(nll=self.nll(dec_outputs, labels))
            # regularize q(y) toward uniform
            avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size))
            avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15)
            b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True)
            mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True)
            pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True)
            q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size)
            p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2)

            result['pi_kl'] = pi_kl

            result['diversity'] = th.mean(p)
            result['nll'] = self.nll(dec_outputs, labels)
            result['b_pr'] = b_pr
            result['mi'] = mi
            return result
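
A sketch of the categorical KL that `cat_kl_loss` presumably computes (the real call also takes batch_size; simplified here), with inputs given as log-probabilities:

import torch as th

def cat_kl(log_qy, log_py, unit_average=True):
    # KL(q || p) = sum_k q_k (log q_k - log p_k)
    qy = log_qy.exp()
    kl = (qy * (log_qy - log_py)).sum(-1)
    return kl.mean() if unit_average else kl.sum()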
Example #21
    def _prepare_batch_seq(self, selected_index):
        dlgs = [self.data[idx] for idx in selected_index]

        dlg_idxs, dlg_lens = [], []
        ctx_utts, ctx_lens = [], []
        out_utts, out_lens = [], []
        goals, goal_lens = [], []
        partner_goals_list, num_partner_goals = [], []
        partitions, num_partitions = [], []
        true_partner_goals = []
        parsed_out_utts = []
        parsed_ctx_utts = []

        # flatten dialogs here
        # keep pointers
        for i, rows in enumerate(dlgs):
            dlg_len = 0
            for row in rows:
                in_row, out_row, goal_row = row.context, row.response, row.goal

                # source context
                batch_ctx = []
                for turn in in_row:
                    batch_ctx.append(
                        self.pad_to(self.max_utt_len, turn.utt, do_pad=True))
                ctx_utts.append(batch_ctx)
                ctx_lens.append(len(batch_ctx))

                # target response
                out_utt = list(out_row.utt)
                out_utts.append(out_utt)
                out_lens.append(len(out_utt))

                # goal
                goals.append(goal_row)
                goal_lens.append(len(goal_row))

                # valid partner goals
                partner_goals = row.valid_partner_goals
                partner_goals_list.append(partner_goals)
                num_partner_goals.append(len(partner_goals))

                # partitions
                _partitions = row.partitions
                # list of list of tuples, each tuple is a goal
                # and the inner list represents all possible partner goals
                partitions.append(_partitions)
                num_partitions.append(len(_partitions))

                # dialog index for getting features in sequence model
                dlg_idxs.append(i)

                # true partner goal
                true_partner_goals.append(row.partner_goal)

                # parsed features
                parsed_out_utts.append(out_row.parsed)
                parsed_ctx_utts.append([x.parsed for x in row.context])

                dlg_len += 1

            dlg_lens.append(dlg_len)

        effective_batch_size = len(goals)

        vec_ctx_lens = np.array(ctx_lens)  # (batch_size, ), number of turns
        max_ctx_len = np.max(vec_ctx_lens)
        vec_ctx_utts = np.zeros(
            (effective_batch_size, max_ctx_len, self.max_utt_len),
            dtype=np.int32)
        # confs is used to add some hand-crafted features
        vec_ctx_confs = np.ones((effective_batch_size, max_ctx_len),
                                dtype=np.float32)
        vec_out_lens = np.array(out_lens)  # (batch_size, ), number of tokens
        max_out_len = np.max(vec_out_lens)
        vec_out_utts = np.zeros((effective_batch_size, max_out_len),
                                dtype=np.int32)

        max_goal_len, min_goal_len = max(goal_lens), min(goal_lens)
        if max_goal_len != min_goal_len or max_goal_len != 6:
            print('FATAL ERROR!')
            exit(-1)
        self.goal_len = max_goal_len
        #vec_goals = np.zeros((effective_batch_size, self.goal_len), dtype=np.int32)
        vec_goals = np.array(goals, dtype=np.int32)

        max_partner_goals = max(num_partner_goals)
        vec_partner_goals = np.zeros(
            (effective_batch_size, max_partner_goals, self.goal_len),
            dtype=np.int32,
        )
        vec_num_partner_goals = np.array(num_partner_goals)

        # pad to the max partition count in the batch
        max_partitions = max(num_partitions)
        #max_partitions = 128
        vec_partitions = np.zeros(
            # 3 item types x 2 agents = 6 counts
            (effective_batch_size, max_partitions, 6),
            dtype=np.int32,
        )
        vec_num_partitions = np.array(num_partitions)

        vec_dlg_idxs = np.array(dlg_idxs, dtype=np.int32)
        vec_dlg_lens = np.array(dlg_lens, dtype=np.int32)

        vec_true_partner_goals = np.array(true_partner_goals, dtype=np.int32)
        vec_parsed_out_utts = np.array(parsed_out_utts, dtype=np.int32)
        vec_parsed_ctx_utts = np.ones(
            (effective_batch_size, max_ctx_len, 6),
            dtype=np.int32,
        ) * 11  # values occupy [0, 10]; counts never exceed 10, so 11 is the padding id

        #
        for b_id in range(effective_batch_size):
            vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id]
            vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id]
            for pg_id in range(num_partner_goals[b_id]):
                vec_partner_goals[b_id,
                                  pg_id, :] = partner_goals_list[b_id][pg_id]
            for p_id in range(num_partitions[b_id]):
                vec_partitions[b_id, p_id, :] = partitions[b_id][p_id]
            vec_parsed_ctx_utts[
                b_id, :vec_ctx_lens[b_id], :] = parsed_ctx_utts[b_id]

        return Pack(
            context_lens=vec_ctx_lens,
            contexts=vec_ctx_utts,
            context_confs=vec_ctx_confs,
            output_lens=vec_out_lens,
            outputs=vec_out_utts,
            goals=vec_goals,
            partner_goals=vec_partner_goals,
            num_partner_goals=vec_num_partner_goals,
            dlg_idxs=vec_dlg_idxs,
            dlg_lens=vec_dlg_lens,
            partitions=vec_partitions,
            num_partitions=vec_num_partitions,
            # oracle values
            true_partner_goals=vec_true_partner_goals,
            parsed_contexts=vec_parsed_ctx_utts,
            parsed_outputs=vec_parsed_out_utts,
        )
Example #22
    def _process_dialogue(self, data):
        def transform(token_list):
            usr, sys = [], []
            ptr = 0
            while ptr < len(token_list):
                turn_ptr = ptr
                turn_list = []
                while True:
                    cur_token = token_list[turn_ptr]
                    turn_list.append(cur_token)
                    turn_ptr += 1
                    if cur_token == EOS:
                        ptr = turn_ptr
                        break
                all_sent_lens.append(len(turn_list))
                if turn_list[0] == USR:
                    usr.append(Pack(utt=turn_list, speaker=USR))
                elif turn_list[0] == SYS:
                    sys.append(Pack(utt=turn_list, speaker=SYS))
                else:
                    raise ValueError('Invalid speaker')

            all_dlg_lens.append(len(usr) + len(sys))
            return usr, sys

        new_dlg = []
        all_sent_lens = []
        all_dlg_lens = []
        for raw_dlg in data:
            raw_words = raw_dlg.split()

            # process dialogue text
            cur_dlg = []
            words = raw_words[raw_words.index('<dialogue>') + 1: raw_words.index('</dialogue>')]
            words += [EOS]
            usr_first = True
            if words[0] == SYS:
                words = [USR, BOD, EOS] + words
                usr_first = True
            elif words[0] == USR:
                words = [SYS, BOD, EOS] + words
                usr_first = False
            else:
                print('FATAL ERROR!!! ({})'.format(words))
                exit(-1)
            usr_utts, sys_utts = transform(words)
            for usr_turn, sys_turn in zip(usr_utts, sys_utts):
                if usr_first:
                    cur_dlg.append(usr_turn)
                    cur_dlg.append(sys_turn)
                else:
                    cur_dlg.append(sys_turn)
                    cur_dlg.append(usr_turn)
            if len(usr_utts) - len(sys_utts) == 1:
                cur_dlg.append(usr_utts[-1])
            elif len(sys_utts) - len(usr_utts) == 1:
                cur_dlg.append(sys_utts[-1])

            # process goal (6 digits)
            # FIXME: this reads the partner's goal (<partner_input>) rather
            # than the agent's own (<input>)
            cur_goal = raw_words[raw_words.index('<partner_input>') + 1: raw_words.index('</partner_input>')]
            # cur_goal = raw_words[raw_words.index('<input>')+1: raw_words.index('</input>')]
            if len(cur_goal) != 6:
                raise ValueError('malformed goal ({})'.format(cur_goal))

            # process outcome (6 tokens)
            cur_out = raw_words[raw_words.index('<output>') + 1: raw_words.index('</output>')]
            if len(cur_out) != 6:
                raise ValueError('malformed outcome ({})'.format(cur_out))

            new_dlg.append(Pack(dlg=cur_dlg, goal=cur_goal, out=cur_out))

        print('Max utt len = %d, mean utt len = %.2f' % (
            np.max(all_sent_lens), float(np.mean(all_sent_lens))))
        print('Max dlg len = %d, mean dlg len = %.2f' % (
            np.max(all_dlg_lens), float(np.mean(all_dlg_lens))))
        return new_dlg
    def forward(self,
                data_feed,
                mode,
                clf=False,
                gen_type='greedy',
                use_pz=None,
                return_latent=False):
        ctx_lens = data_feed['context_lens']  # (batch_size, )
        ctx_utts = self.np2var(data_feed['contexts'],
                               LONG)  # (batch_size, max_ctx_len, max_utt_len)
        out_utts = self.np2var(data_feed['outputs'],
                               LONG)  # (batch_size, max_out_len)
        goals = self.np2var(data_feed['goals'], LONG)  # (batch_size, goal_len)
        partitions = self.np2var(data_feed.partitions, LONG)
        num_partitions = self.np2var(data_feed.num_partitions, INT)
        # effective batch size
        batch_size = len(ctx_lens)
        true_batch_size = data_feed.dlg_idxs.max()

        # encode goal info
        goals_h = self.goal_encoder(goals)  # (batch_size, goal_nhid)

        enc_inputs, _, _ = self.utt_encoder(ctx_utts, goals=goals_h)
        # (batch_size, max_ctx_len, num_directions*utt_cell_size)

        # enc_outs: (batch_size, max_ctx_len, ctx_cell_size)
        # enc_last: tuple, (h_n, c_n)
        # h_n: (num_layers*num_directions, batch_size, ctx_cell_size)
        # c_n: (num_layers*num_directions, batch_size, ctx_cell_size)
        enc_outs, enc_last = self.ctx_encoder(enc_inputs,
                                              input_lengths=ctx_lens,
                                              goals=None)

        # get decoder inputs
        dec_inputs = out_utts[:, :-1]
        labels = out_utts[:, 1:].contiguous()

        logits_pz_t, log_pz_t = self.c2z(enc_last)

        # encode response and use posterior to find q(z|x, c)
        x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1), goals=goals_h)
        logits_qz_t, log_qz_t = self.xc2z(
            th.cat([enc_last, x_h.squeeze(1).unsqueeze(0)], dim=2))

        state_emb = self.res_layer(
            self.item_emb(partitions).view(-1, self.z_size, 3 * 32))
        _, psi_zr_zl = self.hmm_potentials(state_emb, lengths=num_partitions)

        # REMINDER: transpose the last two dimensions of the HMM potentials
        # for torch_struct, then reshape and run the HMM over psi_zr_zl

        # pack attention context
        attn_context = enc_outs if self.config.dec_use_attn else None

        # create decoder initial states
        dec_init_state = self.connector(enc_last)

        # decode
        dec_outputs, dec_hidden_state, ret_dict = self.decoder(
            batch_size=batch_size,
            dec_inputs=dec_inputs,
            # (batch_size, response_size-1)
            dec_init_state=dec_init_state,  # tuple: (h, c)
            attn_context=attn_context,
            # (batch_size, max_ctx_len, ctx_cell_size)
            mode=mode,
            gen_type=gen_type,
            beam_size=self.config.beam_size,
            goal_hid=goals_h,  # (batch_size, goal_nhid)
        )

        if mode == GEN:
            return ret_dict, labels
        else:
            # regularize q(z) toward uniform via a mutual-information term
            avg_log_qz = th.exp(
                log_qz_t.view(-1, self.config.y_size, self.config.k_size))
            avg_log_qz = th.log(th.mean(avg_log_qz, dim=0) + 1e-15)
            mi = self.entropy_loss(avg_log_qz,
                                   unit_average=True) - self.entropy_loss(
                                       log_qz_t, unit_average=True)
            pi_kl = self.cat_kl_loss(log_qz_t,
                                     log_pz_t,
                                     batch_size,
                                     unit_average=True)
            pi_h = self.entropy_loss(log_qz_t, unit_average=True)
            results = Pack(nll=self.nll(dec_outputs, labels),
                           mi=mi,
                           pi_kl=pi_kl,
                           pi_h=pi_h)

            if return_latent:
                results['latent_action'] = dec_init_state

            return results
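
The regularizers in the branch above (mi, pi_kl, pi_h) are computed with
repo helpers (entropy_loss, cat_kl_loss). As a sketch only, assuming log_qz
and log_pz are (batch, k) categorical log-distributions, plain-PyTorch
versions of those quantities could look like:

import torch as th

def cat_kl(log_q, log_p):
    # KL(q || p) for categorical distributions given in log space
    return (log_q.exp() * (log_q - log_p)).sum(-1).mean()

def cat_entropy(log_q):
    # average entropy of categorical distributions given in log space
    return -(log_q.exp() * log_q).sum(-1).mean()

log_qz = th.randn(4, 8).log_softmax(-1)
log_pz = th.randn(4, 8).log_softmax(-1)
pi_kl = cat_kl(log_qz, log_pz)   # pulls the posterior toward the prior
pi_h = cat_entropy(log_qz)       # posterior entropy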
Example #24
0
    def forward(self,
                data_feed,
                mode,
                clf=False,
                gen_type='greedy',
                use_py=None,
                return_latent=False,
                get_marginals=False):
        clf = False  # the classification branch is unused; always decode
        if not clf:
            ctx_lens = data_feed['context_lens']  # (batch_size, )
            ctx_utts = self.np2var(
                data_feed['contexts'],
                LONG)  # (batch_size, max_ctx_len, max_utt_len)
            ctx_confs = self.np2var(data_feed['context_confs'],
                                    FLOAT)  # (batch_size, max_ctx_len)
            out_utts = self.np2var(data_feed['outputs'],
                                   LONG)  # (batch_size, max_out_len)
            goals = self.np2var(data_feed['goals'],
                                LONG)  # (batch_size, goal_len)
            batch_size = len(ctx_lens)

            # encode goal info
            goals_h = self.goal_encoder(goals)  # (batch_size, goal_nhid)

            enc_inputs, _, _ = self.utt_encoder(
                ctx_utts, feats=ctx_confs, goals=goals_h
            )  # (batch_size, max_ctx_len, num_directions*utt_cell_size)

            # enc_outs: (batch_size, max_ctx_len, ctx_cell_size)
            # enc_last: tuple, (h_n, c_n)
            # h_n: (num_layers*num_directions, batch_size, ctx_cell_size)
            # c_n: (num_layers*num_directions, batch_size, ctx_cell_size)
            enc_outs, enc_last = self.ctx_encoder(enc_inputs,
                                                  input_lengths=ctx_lens,
                                                  goals=None)

            partitions = self.np2var(data_feed.partitions, LONG)
            num_partitions = self.np2var(data_feed.num_partitions, INT)
            # oracle input
            partner_goals = self.np2var(data_feed.true_partner_goals, LONG)
            parsed_outputs = self.np2var(data_feed.parsed_outputs, LONG)
            # true partner item values
            partner_goals_h = self.goal_encoder(partner_goals)
            # true next utterance proposal parse
            my_state_emb = self.res_layer(
                th.cat([
                    self.book_emb(parsed_outputs[:, 0]),
                    self.hat_emb(parsed_outputs[:, 1]),
                    self.ball_emb(parsed_outputs[:, 2]),
                ], -1))
            your_state_emb = self.res_layer(
                th.cat([
                    self.book_emb(parsed_outputs[:, 3]),
                    self.hat_emb(parsed_outputs[:, 4]),
                    self.ball_emb(parsed_outputs[:, 5]),
                ], -1))

            if self.config.oracle_context and self.config.oracle_parse:
                big_goals_h = self.res_goal_mlp(
                    th.cat([
                        goals_h,
                        partner_goals_h,
                        my_state_emb,
                        your_state_emb,
                    ], -1))
            elif self.config.oracle_context:
                big_goals_h = self.res_goal_mlp(
                    th.cat([
                        goals_h,
                        partner_goals_h,
                    ], -1))
            elif self.config.oracle_parse:
                big_goals_h = self.res_goal_mlp(
                    th.cat([
                        goals_h,
                        my_state_emb,
                        your_state_emb,
                    ], -1))
            else:
                # neither oracle flag set: big_goals_h would be undefined below
                raise ValueError(
                    'oracle_context or oracle_parse must be enabled')

            # proposal prediction
            prop_enc_inputs, _, _ = self.prop_utt_encoder(
                ctx_utts, feats=ctx_confs, goals=goals_h
            )  # (batch_size, max_ctx_len, num_directions*utt_cell_size)

            # enc_outs: (batch_size, max_ctx_len, ctx_cell_size)
            # enc_last: tuple, (h_n, c_n)
            # h_n: (num_layers*num_directions, batch_size, ctx_cell_size)
            # c_n: (num_layers*num_directions, batch_size, ctx_cell_size)
            prop_enc_outs, prop_enc_last = self.prop_ctx_encoder(
                prop_enc_inputs, input_lengths=ctx_lens, goals=None)

            my_state_emb_out = self.res_layer_out(
                th.cat([
                    self.book_emb_out(partitions[:, :, 0]),
                    self.hat_emb_out(partitions[:, :, 1]),
                    self.ball_emb_out(partitions[:, :, 2]),
                ], -1))
            your_state_emb_out = self.res_layer_out(
                th.cat([
                    self.book_emb_out(partitions[:, :, 3]),
                    self.hat_emb_out(partitions[:, :, 4]),
                    self.ball_emb_out(partitions[:, :, 5]),
                ], -1))
            state_emb_out = th.cat([my_state_emb_out, your_state_emb_out], -1)

            label_mask = (partitions == parsed_outputs.unsqueeze(1)).all(-1)
            logits_label = th.einsum("nsh,nh->ns", state_emb_out,
                                     prop_enc_last[-1])
            # mask out padded candidate partitions
            mask = (th.arange(partitions.shape[1],
                              device=num_partitions.device,
                              dtype=num_partitions.dtype).repeat(
                                  partitions.shape[0], 1)
                    >= num_partitions.unsqueeze(-1))
            logp_label = logits_label.masked_fill(
                mask, float("-inf")).log_softmax(-1)
            nll_label = -logp_label[label_mask].mean()

            # get decoder inputs
            dec_inputs = out_utts[:, :-1]
            labels = out_utts[:, 1:].contiguous()

            # pack attention context
            if self.config.dec_use_attn:
                attn_context = enc_outs
            else:
                attn_context = None

            # create decoder initial states
            dec_init_state = self.connector(enc_last)

            # decode
            dec_outputs, dec_hidden_state, ret_dict = self.decoder(
                batch_size=batch_size,
                dec_inputs=dec_inputs,
                # (batch_size, response_size-1)
                dec_init_state=dec_init_state,  # tuple: (h, c)
                attn_context=attn_context,
                # (batch_size, max_ctx_len, ctx_cell_size)
                mode=mode,
                gen_type=gen_type,
                beam_size=self.config.beam_size,
                # goal_hid packs my goal, the partner goal, and the proposal
                goal_hid=big_goals_h,  # (batch_size, goal_nhid)
            )

            if get_marginals:
                return Pack(
                    dec_outputs=dec_outputs,
                    labels=labels,
                )
            if mode == GEN:
                return ret_dict, labels
            if return_latent:
                return Pack(nll=self.nll(dec_outputs, labels),
                            latent_action=dec_init_state)
            else:
                return Pack(nll=self.nll(dec_outputs, labels),
                            nll_label=nll_label)
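
The proposal-label loss above scores a variable number of candidate
partitions per example, masks the padded candidates to -inf, and normalizes
with a log-softmax. A small self-contained sketch of that masking pattern
(all shapes and values here are illustrative):

import torch as th

scores = th.randn(2, 5)                 # (batch, max_candidates)
num_cands = th.tensor([3, 5])           # true candidate counts per example
pad = th.arange(5).unsqueeze(0) >= num_cands.unsqueeze(1)
logp = scores.masked_fill(pad, float('-inf')).log_softmax(-1)
gold = th.tensor([1, 4])                # index of the gold candidate
nll = -logp.gather(-1, gold.unsqueeze(-1)).mean()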
config = Pack(
    train_path='../data/negotiate/train.txt',
    val_path='../data/negotiate/val.txt',
    test_path='../data/negotiate/test.txt',
    last_n_model=5,
    max_utt_len=20,
    backward_size=14,
    batch_size=32,
    grad_clip=3.0,
    use_gpu=True,
    op='adam',
    init_lr=0.001,
    l2_norm=0.00001,
    momentum=0.0,
    dropout=0.5,
    max_epoch=50,
    embed_size=256,
    num_layers=1,
    utt_rnn_cell='gru',
    utt_cell_size=128,
    bi_utt_cell=True,
    enc_use_attn=False,
    ctx_rnn_cell='gru',
    ctx_cell_size=256,
    bi_ctx_cell=False,
    y_size=200,
    beta=1.0,
    simple_posterior=False,
    use_pr=True,
    dec_use_attn=False,
    dec_rnn_cell='gru',  # must match ctx_rnn_cell: the context state initializes the decoder
    dec_cell_size=256,  # must equal ctx_cell_size for the same reason
    dec_attn_mode='cat',
    #
    fix_train_batch=False,
    fix_batch=False,
    beam_size=20,
    avg_type='real_word',
    print_step=100,
    ckpt_step=400,
    improve_threshold=0.996,
    patient_increase=2.0,
    save_model=True,
    early_stop=False,
    gen_type='greedy',
    preview_batch_num=1,
    max_dec_len=40,
    k=domain_info.input_length(),
    goal_embed_size=64,
    goal_nhid=64,
    init_range=0.1,
    pretrain_folder='2018-11-19-21-28-29-sl_latent',
    forward_only=False,
)
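
Pack is used throughout both with attribute access (config.beam_size) and
item access (sv_config['use_gpu']), and is serialized with json.dump. A
minimal dict subclass with that behavior, offered as an assumption about
what Pack is rather than its actual definition:

class Pack(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

p = Pack(beam_size=20, use_gpu=True)
assert p.beam_size == 20 and p['use_gpu'] is True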
Example #26
0
def main():
    start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                               time.localtime(time.time()))
    print('[START]', start_time, '=' * 30)

    # RL configuration
    folder = '2019-06-20-10-24-23-sl_gauss'
    epoch_id = '28'

    env = 'gpu'
    sim_epoch_id = '23'
    simulator_folder = '2019-06-20-09-19-39-sl_word'
    exp_dir = os.path.join('config_log_model', folder, 'rl-' + start_time)
    if not os.path.exists(exp_dir):
        os.mkdir(exp_dir)

    rl_config = Pack(
        train_path='../data/negotiate/train.txt',
        val_path='../data/negotiate/val.txt',
        test_path='../data/negotiate/test.txt',
        selfplay_path='../data/negotiate/selfplay.txt',
        selfplay_eval_path='../data/negotiate/selfplay_eval.txt',
        sim_config_path=os.path.join('config_log_model', simulator_folder,
                                     'config.json'),
        sim_model_path=os.path.join('config_log_model', simulator_folder,
                                    '{}-model'.format(sim_epoch_id)),
        sv_config_path=os.path.join('config_log_model', folder, 'config.json'),
        sv_model_path=os.path.join('config_log_model', folder,
                                   '{}-model'.format(epoch_id)),
        rl_config_path=os.path.join(exp_dir, 'rl_config.json'),
        rl_model_path=os.path.join(exp_dir, 'rl_model'),
        ppl_best_model_path=os.path.join(exp_dir, 'ppl_best_model'),
        reward_best_model_path=os.path.join(exp_dir, 'reward_best_model'),
        judger_model_path=os.path.join('../FB', 'sv_model.th'),
        judger_config_path=os.path.join('../FB', 'judger_config.json'),
        record_path=exp_dir,
        record_freq=50,
        use_gpu=env == 'gpu',
        nepoch=4,
        nepisode=0,
        sv_train_freq=0,  # TODO: also controlled in main.py; keep in sync
        eval_freq=0,
        max_words=100,
        rl_lr=0.2,
        momentum=0.1,
        nesterov=True,
        gamma=0.95,
        rl_clip=1.0,
        ref_text='../data/negotiate/train.txt',
        domain='object_division',
        max_nego_turn=50,
        random_seed=0,
        use_latent_rl=True)

    # save configuration
    with open(rl_config.rl_config_path, 'w') as f:
        json.dump(rl_config, f, indent=4)

    # set random seed
    set_seed(rl_config.random_seed)

    # load previous supervised learning configuration and corpus
    sv_config = Pack(json.load(open(rl_config.sv_config_path)))
    sim_config = Pack(json.load(open(rl_config.sim_config_path)))

    # TODO revise the use_gpu in the config
    sv_config['use_gpu'] = rl_config.use_gpu
    sim_config['use_gpu'] = rl_config.use_gpu
    corpus = DealCorpus(sv_config)

    # load models for two agents
    # TARGET AGENT
    sys_model = models_deal.GaussHRED(corpus, sv_config)
    if sv_config.use_gpu:  # TODO gpu -> cpu transfer
        sys_model.cuda()
    sys_model.load_state_dict(
        th.load(rl_config.sv_model_path,
                map_location=lambda storage, location: storage))
    # we don't want to use Dropout during RL
    sys_model.eval()
    sys = LatentRlAgent(sys_model,
                        corpus,
                        rl_config,
                        name='System',
                        use_latent_rl=rl_config.use_latent_rl)

    # SIMULATOR we keep usr frozen, i.e. we don't update its parameters
    usr_model = models_deal.HRED(corpus, sim_config)
    if sim_config.use_gpu:  # TODO gpu -> cpu transfer
        usr_model.cuda()
    usr_model.load_state_dict(
        th.load(rl_config.sim_model_path,
                map_location=lambda storage, location: storage))
    usr_model.eval()
    usr_type = LstmAgent
    usr = usr_type(usr_model, corpus, rl_config, name='User')

    # load FB judger model
    judger_config = Pack(json.load(open(rl_config.judger_config_path)))
    judger_config['cuda'] = rl_config.use_gpu
    judger_config['data'] = '../data/negotiate'
    judger_device_id = FB_use_cuda(judger_config.cuda)
    judger_word_corpus = FbWordCorpus(judger_config.data,
                                      freq_cutoff=judger_config.unk_threshold,
                                      verbose=True)
    judger_model = FbDialogModel(judger_word_corpus.word_dict,
                                 judger_word_corpus.item_dict,
                                 judger_word_corpus.context_dict,
                                 judger_word_corpus.output_length,
                                 judger_config, judger_device_id)
    if judger_device_id is not None:
        judger_model.cuda(judger_device_id)
    judger_model.load_state_dict(
        th.load(rl_config.judger_model_path,
                map_location=lambda storage, location: storage))
    judger_model.eval()
    judger = Judger(judger_model, judger_device_id)

    # initialize communication dialogue between two agents
    dialog = Dialog([sys, usr], judger, rl_config)
    ctx_gen = ContextGenerator(rl_config.selfplay_path)

    # simulation module
    dialog_eval = DialogEval([sys, usr], judger, rl_config)
    ctx_gen_eval = ContextGeneratorEval(rl_config.selfplay_eval_path)

    # start RL
    reinforce = Reinforce(dialog, ctx_gen, corpus, sv_config, sys_model,
                          usr_model, rl_config, dialog_eval, ctx_gen_eval)
    reinforce.run()

    # save sys model
    th.save(sys_model.state_dict(), rl_config.rl_model_path)

    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    print('[END]', end_time, '=' * 30)
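
set_seed is called above but not defined in this snippet. A conventional
implementation, offered only as an assumption about what it does, seeds
every RNG the training loop might touch:

import random
import numpy as np
import torch as th

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)
    if th.cuda.is_available():
        th.cuda.manual_seed_all(seed)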
Example #27
0
config = Pack(
    seed=10,
    train_path='../data/norm-multi-woz/train_dials.json',
    valid_path='../data/norm-multi-woz/val_dials.json',
    test_path='../data/norm-multi-woz/test_dials.json',
    max_vocab_size=1000,
    last_n_model=5,
    max_utt_len=50,
    max_dec_len=50,
    backward_size=2,
    batch_size=32,
    use_gpu=True,
    op='adam',
    init_lr=0.001,
    l2_norm=1e-05,
    momentum=0.0,
    grad_clip=5.0,
    dropout=0.5,
    max_epoch=100,
    embed_size=100,
    num_layers=1,
    utt_rnn_cell='gru',
    utt_cell_size=300,
    bi_utt_cell=True,
    enc_use_attn=True,
    dec_use_attn=True,
    dec_rnn_cell='lstm',
    dec_cell_size=300,
    dec_attn_mode='cat',
    y_size=10,
    k_size=20,
    beta=0.001,
    simple_posterior=True,
    contextual_posterior=True,
    use_mi=False,
    use_pr=True,
    use_diversity=False,
    #
    beam_size=20,
    fix_batch=True,
    fix_train_batch=False,
    avg_type='word',
    print_step=300,
    ckpt_step=1416,
    improve_threshold=0.996,
    patient_increase=2.0,
    save_model=True,
    early_stop=False,
    gen_type='greedy',
    preview_batch_num=None,
    k=domain_info.input_length(),
    init_range=0.1,
    pretrain_folder='2019-06-20-21-43-06-sl_cat',
    forward_only=False)
Example #28
0
    def _prepare_batch(self, selected_index):
        rows = [self.data[idx] for idx in selected_index]

        ctx_utts, ctx_lens = [], []
        out_utts, out_lens = [], []

        out_bs, out_db = [], []
        goals, goal_lens = [], [[] for _ in range(len(self.domains))]
        keys = []

        for row in rows:
            in_row, out_row, goal_row = row.context, row.response, row.goal

            # source context
            keys.append(row.key)
            batch_ctx = []
            for turn in in_row:
                batch_ctx.append(
                    self.pad_to(self.max_utt_len, turn.utt, do_pad=True))
            ctx_utts.append(batch_ctx)
            ctx_lens.append(len(batch_ctx))

            # target response
            out_utt = list(out_row.utt)
            out_utts.append(out_utt)
            out_lens.append(len(out_utt))

            out_bs.append(out_row.bs)
            out_db.append(out_row.db)

            # goal
            goals.append(goal_row)
            for i, d in enumerate(self.domains):
                goal_lens[i].append(len(goal_row[d]))

        batch_size = len(ctx_lens)
        vec_ctx_lens = np.array(ctx_lens)  # (batch_size, ), number of turns
        max_ctx_len = np.max(vec_ctx_lens)
        vec_ctx_utts = np.zeros((batch_size, max_ctx_len, self.max_utt_len),
                                dtype=np.int32)
        vec_out_bs = np.array(out_bs)  # (batch_size, 94)
        vec_out_db = np.array(out_db)  # (batch_size, 30)
        vec_out_lens = np.array(out_lens)  # (batch_size, ), number of tokens
        max_out_len = np.max(vec_out_lens)
        vec_out_utts = np.zeros((batch_size, max_out_len), dtype=np.int32)

        max_goal_lens = [max(ls) for ls in goal_lens]
        min_goal_lens = [min(ls) for ls in goal_lens]
        if max_goal_lens != min_goal_lens:
            raise ValueError('goal lengths must be constant within each domain')
        self.goal_lens = max_goal_lens
        vec_goals_list = [
            np.zeros((batch_size, l), dtype=np.float32) for l in self.goal_lens
        ]

        for b_id in range(batch_size):
            vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id]
            vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id]
            for i, d in enumerate(self.domains):
                vec_goals_list[i][b_id, :] = goals[b_id][d]

        return Pack(
            context_lens=vec_ctx_lens,  # (batch_size, )
            contexts=vec_ctx_utts,  # (batch_size, max_ctx_len, max_utt_len)
            output_lens=vec_out_lens,  # (batch_size, )
            outputs=vec_out_utts,  # (batch_size, max_out_len)
            bs=vec_out_bs,  # (batch_size, 94)
            db=vec_out_db,  # (batch_size, 30)
            goals_list=vec_goals_list,  # 7*(batch_size, bow_len), bow_len differs per domain
            keys=keys)
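
pad_to is used above to force every utterance to max_utt_len. A plausible
implementation, stated as an assumption about the helper rather than its
actual definition: truncate long sequences and right-pad short ones with 0,
the padding id implied by the np.zeros buffers above.

def pad_to(max_len, tokens, do_pad=True):
    if len(tokens) >= max_len:
        return tokens[:max_len]
    if not do_pad:
        return tokens
    return tokens + [0] * (max_len - len(tokens))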
Example #29
0
def main():
    start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                               time.localtime(time.time()))
    print('[START]', start_time, '=' * 30)
    # RL configuration
    env = 'gpu'
    pretrained_folder = '2019-06-20-22-49-55-sl_cat'
    pretrained_model_id = 41

    exp_dir = os.path.join('sys_config_log_model', pretrained_folder,
                           'rl-' + start_time)
    # create exp folder
    if not os.path.exists(exp_dir):
        os.mkdir(exp_dir)

    rl_config = Pack(
        train_path='../data/norm-multi-woz/train_dials.json',
        valid_path='../data/norm-multi-woz/val_dials.json',
        test_path='../data/norm-multi-woz/test_dials.json',
        sv_config_path=os.path.join('sys_config_log_model', pretrained_folder,
                                    'config.json'),
        sv_model_path=os.path.join('sys_config_log_model', pretrained_folder,
                                   '{}-model'.format(pretrained_model_id)),
        rl_config_path=os.path.join(exp_dir, 'rl_config.json'),
        rl_model_path=os.path.join(exp_dir, 'rl_model'),
        ppl_best_model_path=os.path.join(exp_dir, 'ppl_best.model'),
        reward_best_model_path=os.path.join(exp_dir, 'reward_best.model'),
        record_path=exp_dir,
        record_freq=200,
        sv_train_freq=0,  # TODO: also controlled in main.py; keep in sync
        use_gpu=env == 'gpu',
        nepoch=10,
        nepisode=0,
        tune_pi_only=False,
        max_words=100,
        temperature=1.0,
        episode_repeat=1.0,
        rl_lr=0.01,
        momentum=0.0,
        nesterov=False,
        gamma=0.99,
        rl_clip=5.0,
        random_seed=100,
    )

    # save configuration
    with open(rl_config.rl_config_path, 'w') as f:
        json.dump(rl_config, f, indent=4)

    # set random seed
    set_seed(rl_config.random_seed)

    # load previous supervised learning configuration and corpus
    sv_config = Pack(json.load(open(rl_config.sv_config_path)))
    sv_config['dropout'] = 0.0
    sv_config['use_gpu'] = rl_config.use_gpu
    corpus = NormMultiWozCorpus(sv_config)

    # TARGET AGENT
    sys_model = SysPerfectBD2Cat(corpus, sv_config)
    if sv_config.use_gpu:
        sys_model.cuda()
    sys_model.load_state_dict(
        th.load(rl_config.sv_model_path,
                map_location=lambda storage, location: storage))
    sys_model.eval()
    sys = OfflineLatentRlAgent(sys_model,
                               corpus,
                               rl_config,
                               name='System',
                               tune_pi_only=rl_config.tune_pi_only)

    # start RL
    reinforce = OfflineTaskReinforce(sys, corpus, sv_config, sys_model,
                                     rl_config, task_generate)
    reinforce.run()

    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    print('[END]', end_time, '=' * 30)
Example #30
0
    def forward(
        self,
        data_feed,
        mode,
        clf=False,
        gen_type='greedy',
        use_py=None,
        return_latent=False,
        get_marginals=False,
    ):
        ctx_lens = data_feed['context_lens']  # (batch_size, )
        ctx_utts = self.np2var(data_feed['contexts'],
                               LONG)  # (batch_size, max_ctx_len, max_utt_len)
        ctx_confs = self.np2var(data_feed['context_confs'],
                                FLOAT)  # (batch_size, max_ctx_len)
        out_utts = self.np2var(data_feed['outputs'],
                               LONG)  # (batch_size, max_out_len)
        goals = self.np2var(data_feed['goals'], LONG)  # (batch_size, goal_len)
        partitions = self.np2var(data_feed.partitions, LONG)
        num_partitions = self.np2var(data_feed.num_partitions, INT)
        batch_size = len(ctx_lens)

        self.z_size = data_feed.num_partitions.max()  # latent states in this batch

        # oracle
        parsed_outputs = self.np2var(data_feed.parsed_outputs, LONG)
        partner_goals = self.np2var(data_feed.true_partner_goals, LONG)

        # encode goal info
        goals_h = self.goal_encoder(goals)  # (batch_size, goal_nhid)

        enc_inputs, _, _ = self.utt_encoder(
            ctx_utts,
            feats=ctx_confs,
            goals=goals_h,
        )  # (batch_size, max_ctx_len, num_directions*utt_cell_size)

        # enc_outs: (batch_size, max_ctx_len, ctx_cell_size)
        # enc_last: tuple, (h_n, c_n)
        # h_n: (num_layers*num_directions, batch_size, ctx_cell_size)
        # c_n: (num_layers*num_directions, batch_size, ctx_cell_size)
        enc_outs, enc_last = self.ctx_encoder(enc_inputs,
                                              input_lengths=ctx_lens,
                                              goals=None)

        # get decoder inputs
        dec_inputs = out_utts[:, :-1]
        labels = out_utts[:, 1:].contiguous()

        # pack attention context
        if self.config.dec_use_attn:
            attn_context = enc_outs
        else:
            attn_context = None

        # create decoder initial states (overridden below by the per-state
        # expansion used for HMM decoding)
        dec_init_state = self.connector(enc_last)

        # transition matrix
        ctx_input = self.prior_res_layer(enc_last[-1])
        my_state_emb = self.res_layer(
            th.cat([
                self.book_emb(partitions[:, :, 0]),
                self.hat_emb(partitions[:, :, 1]),
                self.ball_emb(partitions[:, :, 2]),
            ], -1))
        your_state_emb = self.res_layer(
            th.cat([
                self.book_emb(partitions[:, :, 3]),
                self.hat_emb(partitions[:, :, 4]),
                self.ball_emb(partitions[:, :, 5]),
            ], -1))
        state_emb = th.cat([my_state_emb, your_state_emb], -1)
        my_state_emb_out = self.res_layer_out(
            th.cat([
                self.book_emb_out(partitions[:, :, 0]),
                self.hat_emb_out(partitions[:, :, 1]),
                self.ball_emb_out(partitions[:, :, 2]),
            ], -1))
        your_state_emb_out = self.res_layer_out(
            th.cat([
                self.book_emb_out(partitions[:, :, 3]),
                self.hat_emb_out(partitions[:, :, 4]),
                self.ball_emb_out(partitions[:, :, 5]),
            ], -1))
        state_emb_out = th.cat([my_state_emb_out, your_state_emb_out], -1)

        goals_h = self.res_goal_mlp(
            th.cat([
                goals_h.unsqueeze(1).expand(
                    state_emb.shape[0], state_emb.shape[1], goals_h.shape[-1]),
                state_emb,
            ], -1)).view(-1, 128)  # flatten to (N * z_size, 128)

        # for noisy labels
        if self.noisy_proposal_labels:
            # transition from state to label
            label_mask = (partitions == parsed_outputs.unsqueeze(1)).all(-1)
            logp_label_z = th.einsum("nsh,nth->nts", state_emb,
                                     state_emb_out).log_softmax(-1)
            # outer dim t should be output label

        phi_zt, psi_zl_zr = self.hmm_potentials(state_emb,
                                                state_emb_out,
                                                ctx_input,
                                                lengths=num_partitions)
        logp_zt = phi_zt.log_softmax(-1)
        logp_zr_zl = psi_zl_zr.log_softmax(-1).transpose(-1, -2)

        # decode
        N, T = dec_inputs.shape
        dec_init_state = enc_last.repeat(1, 1, self.z_size).view(
            self.config.num_layers, N * self.z_size, -1)
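        # note: repeat(1, 1, z).view(...) duplicates each batch element z
        # times contiguously, i.e. it matches repeat_interleave along the
        # batch dim, so row n * z_size + j of the expanded batch is example
        # n paired with latent state j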
        dec_outputs, dec_hidden_state, ret_dict = self.decoder(
            batch_size=batch_size * self.z_size,
            dec_inputs=dec_inputs.repeat(1, 1, self.z_size).view(
                -1, T),  # (batch_size, response_size-1)
            dec_init_state=dec_init_state,  # tuple: (h, c)
            attn_context=None,  # (batch_size, max_ctx_len, ctx_cell_size)
            mode=mode,
            gen_type=gen_type,
            beam_size=self.config.beam_size,
            goal_hid=goals_h,  # (batch_size, goal_nhid)
        )

        BLAM, T, V = dec_outputs.shape  # BLAM = batch_size * z_size
        # all word probs, they need to be summed over
        # `log p(xt) = \sum_i \log p(w_ti)`
        logp_wt_zt = dec_outputs.view(N, self.z_size, T, V).gather(
            -1,
            labels.view(N, 1, T, 1).expand(N, self.z_size, T, 1),
        ).squeeze(-1)

        # get rid of padding, mask to 0
        logp_xt_zt = (logp_wt_zt.masked_fill(
            labels.unsqueeze(1) == self.nll.padding_idx, 0).sum(-1))

        # linear-chain message passing over the latent states; the scores
        # must be normalized over each z_t so the LM probs stay normalized
        dlg_idxs = data_feed.dlg_idxs
        t = 0
        ll_label = 0
        prev_zt = logp_zt[t]

        logp_xt = [(logp_xt_zt[t] + prev_zt).logsumexp(-1)]
        if (self.training and self.noisy_proposal_labels
                and label_mask[0].any()):
            if not self.config.sup_proposal_labels:
                # predict noisy proposal from hidden state
                ll_label += (logp_label_z[t] + prev_zt.unsqueeze(-1)
                             ).logsumexp(0)[label_mask[t]].logsumexp(0)
            else:
                ll_label += prev_zt[label_mask[t]].logsumexp(0)
        for t in range(1, N):
            if dlg_idxs[t] != dlg_idxs[t - 1]:
                # restart hmm
                prev_zt = logp_zt[t]
                logp_xt.append((logp_xt_zt[t] + prev_zt).logsumexp(-1))
            else:
                # continue
                # unsqueeze is unnecessary, broadcasting handles it
                prev_zt = (prev_zt.unsqueeze(-2) + logp_zr_zl[t]).logsumexp(-1)
                #prev_zt = logp_zt[t]
                logp_xt.append((logp_xt_zt[t] + prev_zt).logsumexp(-1))
            if (self.training and self.noisy_proposal_labels
                    and label_mask[t].any()):
                if not self.config.sup_proposal_labels:
                    # predict noisy proposal from hidden state
                    ll_label += (logp_label_z[t] + prev_zt.unsqueeze(-1)
                                 ).logsumexp(0)[label_mask[t]].logsumexp(0)
                else:
                    ll_label += prev_zt[label_mask[t]].logsumexp(0)
        logp_xt = th.stack(logp_xt)
        if self.nll.avg_type == "real_word":
            nll_word = -(logp_xt / (labels.sign().sum(-1).float())).mean()
        elif self.nll.avg_type == "word":
            nll_word = -(logp_xt.sum() / labels.sign().sum())
        else:
            raise ValueError("Unknown reduction type")

        if self.training and self.noisy_proposal_labels and label_mask.any():
            nll_label = -(self.config.label_weight * ll_label
                          / label_mask.any(-1).sum().float())
        else:
            nll_label = th.zeros(1).to(nll_word.device)

        if get_marginals:
            return Pack(
                dec_outputs=dec_outputs,
                logp_xt=logp_xt,
                labels=labels,
            )
        if mode == GEN:
            return ret_dict, labels
        if return_latent:
            return Pack(nll=nll_word, latent_action=dec_init_state)
        else:
            return Pack(nll_label=nll_label, nll_word=nll_word)
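
The time loop above is a log-space forward recursion over a linear chain
whose state is reset at dialogue boundaries. A compact, self-contained
sketch of the same recursion with made-up shapes (T turns, K latent states;
all tensors here are random stand-ins):

import torch as th

T, K = 5, 4
logp_zt = th.randn(T, K).log_softmax(-1)        # per-turn initial scores
logp_zr_zl = th.randn(T, K, K).log_softmax(-2)  # transitions: [t, next, prev]
logp_xt_zt = th.randn(T, K)                     # emission log-probs
dlg_idxs = [0, 0, 0, 1, 1]                      # two dialogues in the batch

logp_xt = []
for t in range(T):
    if t == 0 or dlg_idxs[t] != dlg_idxs[t - 1]:
        prev = logp_zt[t]  # restart the chain at a dialogue boundary
    else:
        prev = (prev.unsqueeze(-2) + logp_zr_zl[t]).logsumexp(-1)
    logp_xt.append((logp_xt_zt[t] + prev).logsumexp(-1))
print(th.stack(logp_xt))  # per-turn marginal utterance log-probs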