Example #1
File: vrbot.py  Project: lddsdu/VRBot
    def forward4infer(self, pv_state, pv_hidden, pv_r_u, pv_r_u_len):
        bi_pv_hidden = self.hidden_bidirectional(pv_hidden)
        pv_r_u_enc4post, hidden4post = self.encode_sequence(self.encoder4post,
                                                            pv_r_u,
                                                            pv_r_u_len,
                                                            bi_pv_hidden)

        pv_r_u_enc4state, hidden4state = self.encode_sequence(self.encoder4state,
                                                              pv_r_u,
                                                              pv_r_u_len,
                                                              bi_pv_hidden)

        pv_r_u_enc4action, hidden4action = self.encode_sequence(self.encoder4action,
                                                                pv_r_u,
                                                                pv_r_u_len,
                                                                bi_pv_hidden)

        prior_state_prob, _ = self.pst.forward(hidden4state, pv_state, pv_r_u, pv_r_u_enc4state)
        state_index = VRBot.sampling(prior_state_prob, strategy="prob")
        state_word_index = self.lg_interpreter.loc2glo(state_index)
        state_word_prob = one_hot_scatter(state_word_index, self.vocab_size, dtype=torch.float)
        state_embed = self.embedder(state_word_index)

        prior_gumbel_intention, prior_action_prob, _ = self.ppn.forward(hidden4action,
                                                                        state_index,
                                                                        None,
                                                                        pv_r_u_enc4action,
                                                                        pv_r_u_len)

        ranking_floor = 1
        if ranking_floor > 0:
            values, indices = prior_action_prob.topk(ranking_floor, -1)
            prior_action_prob = prior_action_prob.new_zeros(prior_action_prob.shape)
            values = values / values.sum(-1).unsqueeze(-1)
            prior_action_prob.scatter_(-1, indices, values)

        action_index = VRBot.sampling(prior_action_prob, strategy="max")
        action_word_index = self.lg_interpreter.loc2glo(action_index)
        action_word_prob = one_hot_scatter(action_word_index, self.vocab_size, dtype=torch.float)
        action_embed = self.embedder(action_word_index)

        gen_log_probs = self.beam_decoder.forward(hidden4post,
                                                  None,
                                                  pv_r_u_enc4post,
                                                  state_word_prob,
                                                  state_embed,
                                                  action_word_prob,
                                                  action_embed,
                                                  pv_r_u_len,
                                                  pv_r_u,
                                                  mask_state_prob=VO.mask_state_prob)

        if len(gen_log_probs.shape) == 3:
            bst_trajectory = torch.argmax(gen_log_probs, -1)
        else:
            bst_trajectory = gen_log_probs

        return [bst_trajectory, state_index, action_index, hidden4post]
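The `ranking_floor` block in `forward4infer` keeps only the top-`ranking_floor` action probabilities and renormalises them before sampling. Below is a minimal stand-alone sketch of that step; the helper name `topk_renormalize` and the toy shapes are illustrative, not part of the project.

import torch

def topk_renormalize(prob: torch.Tensor, k: int) -> torch.Tensor:
    # Keep the k largest entries along the last dim and renormalise them,
    # mirroring the `ranking_floor` block above (prior_action_prob is B, A, V).
    values, indices = prob.topk(k, dim=-1)
    values = values / values.sum(dim=-1, keepdim=True)
    pruned = prob.new_zeros(prob.shape)
    pruned.scatter_(-1, indices, values)
    return pruned

# Toy usage: with k=1 this reduces to a one-hot at the argmax, which is why the
# code above follows it with VRBot.sampling(..., strategy="max").
p = torch.softmax(torch.randn(2, 3, 10), dim=-1)
print(topk_renormalize(p, 3).sum(-1))  # every distribution sums to 1 again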
Example #2
    def stage_train(self, dataset: SessionDataset, epoch, state_train=False):
        all_stage = "[{}]".format("STATE" if state_train else "ACTION")
        model_name = self.model.__class__.__name__
        dataset_bs = dataset.batch_size
        pad_loss = torch.tensor(0.0, dtype=torch.float, device=TO.device)

        # INITIAL HIDDEN & STATE
        # 1, B, H
        _init_pv_hidden = torch.zeros((1, dataset_bs, self.model.hidden_dim),
                                      dtype=torch.float,
                                      device=TO.device)
        # B, S, Vi
        _init_pv_state = torch.zeros(
            (dataset_bs, self.model.state_num, self.model.inner_vocab_size),
            dtype=torch.float,
            device=TO.device)
        _init_pv_state[:, :, 0] = 1.0

        self.cache["super"]["pv_hidden"] = _init_pv_hidden
        self.cache["super"]["pv_state"] = _init_pv_state
        self.cache["unsuper"]["pv_hidden"] = _init_pv_hidden.clone()
        self.cache["unsuper"]["pv_state"] = _init_pv_state.clone()

        # TRAINING PROCESS
        train_bar = tqdm(dataset.load_data())  # input tensors iterator

        for input_tensors, inherited, materialistic in train_bar:
            self.global_step += 1
            loss_log_tmp = "[TASK-UUID {} EPOCH {:0>2} STEP {:0>6} TRAIN-STAGE {}] - {}"
            loss_log_template = loss_log_tmp.format(TO.task_uuid, epoch,
                                                    self.global_step,
                                                    all_stage, "{loss_info}")

            if len(input_tensors) == 5:
                pv_r_u, pv_r_u_len, r, r_len, gth_i = input_tensors
                gth_i = gth_i.float()
                gth_a, gth_s = None, None
                batch_supervised = False
            elif len(input_tensors) == 7:
                pv_r_u, pv_r_u_len, r, r_len, gth_s, gth_i, gth_a = input_tensors
                gth_i = gth_i.float()
                batch_supervised = True
            else:
                raise RuntimeError("error input tensor numbers {}".format(
                    len(input_tensors)))

            key1 = "super" if batch_supervised else "unsuper"
            pv_hidden = self.cache[key1]["pv_hidden"]
            pv_state = self.cache[key1]["pv_state"]
            pv_hidden, pv_state = self.hidden_state_mask(
                pv_hidden, pv_state, inherited, materialistic)

            if state_train:
                vrbot_train_stage.state_train_tick()
                train_state_rets = self.model.forward(
                    pv_state,
                    pv_hidden,
                    pv_r_u,
                    pv_r_u_len,
                    gth_i,
                    r,
                    r_len,
                    gth_action=gth_a,
                    gth_state=gth_s,
                    train_stage=TRAIN.TRAIN_STATE,
                    supervised=batch_supervised)

                gen_log_probs1 = train_state_rets[0]
                post_state_prob = train_state_rets[2]
                prior_state_prob = train_state_rets[3]
                state_gumbel_prob = train_state_rets[4]
                action_gumbel_prob = train_state_rets[7]
                prior_intention = train_state_rets[9]
                hidden4post = train_state_rets[10]

                # loss
                state_nll_loss, _ = self.model.nll_loss(
                    gen_log_probs1,
                    r.detach()[:, 1:])
                state_kl_loss = self.model.kl_loss(prior_state_prob,
                                                   post_state_prob.detach())

                if prior_intention is None or gth_i is None:
                    prior_intention_ce_loss = pad_loss
                else:
                    prior_intention_ce_loss = self.model.cross_entropy_loss(
                        prior_intention, gth_i)

                if batch_supervised:
                    aux_state_loss = (
                        self.model.state_nll(prior_state_prob, gth_s) +
                        self.model.state_nll(post_state_prob, gth_s))
                else:
                    aux_state_loss = pad_loss

                loss = state_nll_loss + state_kl_loss + prior_intention_ce_loss + aux_state_loss
                loss_template = "{} loss: {:.4f} r_nll: {:.4f} s_kl: {:.4f} p_i_ce: {:.4f} aux_s_nll: {:.4f}"
                loss_info = loss_template.format(
                    "SUPER" if batch_supervised else "UNSUPER", loss.item(),
                    state_nll_loss.item(), state_kl_loss.item(),
                    prior_intention_ce_loss.item(), aux_state_loss.item())

            else:
                vrbot_train_stage.state_train_tick()
                vrbot_train_stage.action_train_tick()
                train_policy_rets = self.model.forward(
                    pv_state,
                    pv_hidden,
                    pv_r_u,
                    pv_r_u_len,
                    gth_i,
                    r,
                    r_len,
                    gth_action=gth_a,
                    gth_state=gth_s,
                    train_stage=TRAIN.TRAIN_POLICY,
                    supervised=batch_supervised)

                gen_log_probs1 = train_policy_rets[0]
                gen_log_probs2 = train_policy_rets[1]
                post_state_prob = train_policy_rets[2]
                prior_state_prob = train_policy_rets[3]
                state_gumbel_prob = train_policy_rets[4]
                post_action_prob = train_policy_rets[5]
                prior_action_prob = train_policy_rets[6]
                action_gumbel_prob = train_policy_rets[7]
                post_intention = train_policy_rets[8]
                prior_intention = train_policy_rets[9]
                hidden4post = train_policy_rets[10]

                state_nll_loss, _ = self.model.nll_loss(
                    gen_log_probs1,
                    r.detach()[:, 1:])
                action_nll_loss, _ = self.model.nll_loss(
                    gen_log_probs2,
                    r.detach()[:, 1:])

                if batch_supervised:
                    state_kl_loss = pad_loss
                    action_kl_loss = pad_loss
                else:
                    state_kl_loss = self.model.kl_loss(
                        prior_state_prob, post_state_prob.detach())
                    action_kl_loss = self.model.kl_loss(
                        prior_action_prob, post_action_prob.detach())

                if prior_intention is None:
                    intention_loss = pad_loss
                else:
                    prior_intention_ce_loss = self.model.cross_entropy_loss(
                        prior_intention, gth_i)
                    post_intention_ce_loss = self.model.cross_entropy_loss(
                        post_intention, gth_i)
                    if batch_supervised:
                        intention_kl_loss = pad_loss
                    else:
                        intention_kl_loss = self.model.kl_loss(
                            prior_intention, post_intention.detach())
                    intention_loss = prior_intention_ce_loss + post_intention_ce_loss + intention_kl_loss

                if batch_supervised:
                    aux_state_loss = (
                        self.model.state_nll(prior_state_prob, gth_s) +
                        self.model.state_nll(post_state_prob, gth_s))
                    aux_action_loss = (
                        self.model.state_nll(prior_action_prob, gth_a) +
                        self.model.state_nll(post_action_prob, gth_a))
                else:
                    aux_state_loss, aux_action_loss = pad_loss, pad_loss

                a_lambda, i_lambda = 0.2, 1.0
                loss = (state_nll_loss + a_lambda * action_nll_loss +
                        state_kl_loss + a_lambda * action_kl_loss +
                        i_lambda * intention_loss + aux_state_loss +
                        aux_action_loss)

                log_tmp = "{} s-nll: {:.4f} a-nll: {:.4f} s-kl: {:.4f} " \
                          "a-kl: {:.4f} i-l: {:.4f} sa-l: {:.4f} aa-l: {:.4f}"
                loss_info = log_tmp.format(
                    "SUPER" if batch_supervised else "UNSUPER",
                    state_nll_loss.item(), action_nll_loss.item(),
                    state_kl_loss.item(), action_kl_loss.item(),
                    intention_loss.item(), aux_state_loss.item(),
                    aux_action_loss.item())

            loss_log_info = loss_log_template.format(loss_info=loss_info)
            train_bar.set_description(loss_log_info)
            if self.global_step % TO.log_loss_interval == 0:
                engine_logger.info(loss_log_info)

            self.optimizer.zero_grad()
            loss = loss / float(TO.gradient_stack)
            loss.backward(retain_graph=False)

            if self.global_step % TO.gradient_stack == 0:
                self.optimizer.step()

            if self.global_step % TO.decay_interval == 0:
                engine_logger.info("learning rate decay")
                self.adjust_learning_rate(self.optimizer, TO.decay_rate,
                                          TO.mini_lr)

            if self.global_step % TO.valid_eval_interval == 0:
                self.test_with_log(self.valid_dataset,
                                   epoch,
                                   model_name,
                                   mode="valid")

            if self.global_step % TO.test_eval_interval == 0:
                self.test_with_log(self.test_dataset,
                                   epoch,
                                   model_name,
                                   mode="test")

            # copy weight decay
            self.tick(state_train=state_train, action_train=not state_train)

            pv_hidden = hidden4post
            if batch_supervised:  # use the previous gth state
                pv_state = one_hot_scatter(gth_s,
                                           state_gumbel_prob.size(-1),
                                           dtype=torch.float)
            else:
                pv_state = state_gumbel_prob

            self.cache[key1]["pv_hidden"] = pv_hidden.detach()
            self.cache[key1]["pv_state"] = pv_state.detach()
Example #3
    def test(self, dataset: SessionDataset, mode="test"):
        assert mode == "test" or mode == "valid"
        print("SESSION NUM: {}".format(len(dataset.sessions)))
        dataset_bs = dataset.batch_size
        pv_hidden = torch.zeros((1, dataset_bs, self.model.hidden_dim),
                                dtype=torch.float,
                                device=TO.device)
        pv_state = torch.zeros(
            (dataset_bs, self.model.state_num, self.model.inner_vocab_size),
            dtype=torch.float,
            device=TO.device)
        pv_state[:, :, 0] = 1.0

        # cache
        all_targets = []
        all_outputs = []

        engine_logger.info("{} INFERENCE START ...".format(mode.upper()))
        session_cropper = SessionCropper(dataset.batch_size)

        self.model.eval()
        with torch.no_grad():
            for input_tensors, inherited, materialistic in tqdm(
                    dataset.load_data()):
                if len(input_tensors) == 5:
                    pv_r_u, pv_r_u_len, r, r_len, gth_intention = input_tensors
                    gth_s, gth_a = None, None
                elif len(input_tensors) == 7:
                    pv_r_u, pv_r_u_len, r, r_len, gth_s, gth_intention, gth_a = input_tensors
                else:
                    raise RuntimeError

                pv_hidden, pv_state = self.hidden_state_mask(
                    pv_hidden, pv_state, inherited, materialistic)

                gen_log_probs, state_index, action_index, hidden4post = self.model.forward(
                    pv_state, pv_hidden, pv_r_u, pv_r_u_len, None)

                posts = self.iw_interpreter.interpret_tensor2nl(pv_r_u)
                targets = self.iw_interpreter.interpret_tensor2nl(r[:, 1:])
                outputs = self.iw_interpreter.interpret_tensor2nl(
                    gen_log_probs)
                states = self.ii_interpreter.interpret_tensor2nl(state_index)
                actions = self.ii_interpreter.interpret_tensor2nl(action_index)

                if gth_s is not None:
                    gth_states = self.ii_interpreter.interpret_tensor2nl(gth_s)
                else:
                    gth_states = ["<pad>"] * len(posts)

                inherited = inherited.detach().cpu().numpy().tolist()
                materialistic = materialistic.detach().cpu().numpy().tolist()
                session_cropper.step_on(posts, targets, outputs, states,
                                        actions, inherited, materialistic,
                                        gth_states)
                all_targets += targets
                all_outputs += outputs

                # for next loop
                pv_hidden = hidden4post
                pv_state = one_hot_scatter(state_index,
                                           self.model.inner_vocab_size,
                                           dtype=torch.float)

        self.model.train()
        engine_logger.info("{} INFERENCE FINISHED".format(mode.upper()))
        metrics = eval_bleu([all_targets], all_outputs)

        return all_targets, all_outputs, metrics, session_cropper
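Both the training and test loops carry the predicted state forward by expanding an index tensor into one-hot vectors with `one_hot_scatter`. The helper itself is not shown in these examples; the sketch below is an assumed implementation matching how it is called (`(indices, num_classes, dtype)`), not the project's own code.

import torch

def one_hot_scatter(indices: torch.Tensor, num_classes: int,
                    dtype=torch.float) -> torch.Tensor:
    # Assumed behaviour: expand an index tensor of shape (..., S) into
    # one-hot vectors of shape (..., S, num_classes).
    placeholder = torch.zeros(*indices.shape, num_classes,
                              dtype=dtype, device=indices.device)
    return placeholder.scatter_(-1, indices.unsqueeze(-1), 1)

# E.g. carrying a predicted state forward, as at the end of the test loop:
state_index = torch.tensor([[2, 0, 5]])     # B=1, state_num=3
pv_state = one_hot_scatter(state_index, 8)  # (1, 3, 8), one-hot along the last dim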
Example #4
    def forward(self,
                hidden,
                state,
                gth_intention,
                pv_r_u_enc,
                pv_r_u_len=None,
                gth_action=None,
                supervised=False):
        state_word = self.lg_interpreter.loc2glo(state)
        state_embed = self.embedder(state_word)
        tmp = []
        s_hidden = None
        for i in range(state_embed.size(1)):
            _, s_hidden = self.rnn_enc_cell.forward(state_embed[:, i:i + 1, :],
                                                    s_hidden)  # [1, B, H]
            tmp.append(s_hidden.permute(1, 0, 2))  # [B, 1, E] * S
        state_embed = torch.cat(tmp, 1)  # B, S, E

        node_embedding = None
        head_nodes = None
        node_efficient = None
        head_flag_bit = None
        graph_context = None

        rets = self.graph_db.graph_construct(state.cpu().numpy().tolist())

        adjacent_matrix, head_nodes, node_efficient, head_flag_bit, edge_type_matrix = [
            torch.tensor(r).to(TO.device) for r in rets
        ]
        node_embedding = self.r_gat.forward(adjacent_matrix, head_nodes,
                                            head_flag_bit, edge_type_matrix)
        graph_context = self.graph_attn.forward(node_embedding, node_efficient,
                                                head_flag_bit,
                                                hidden.permute(1, 0, 2))

        intention = self.intention_detector.forward(
            state_embed,
            hidden,
            pv_r_u_enc,
            pv_r_u_len,
            r_enc=None,
            graph_context=graph_context)

        gumbel_intention = self.intention_gumbel_softmax.forward(
            intention, vrbot_train_stage.a_tau)

        if (gth_intention is None) or (not self.training):
            last_dim = gumbel_intention.size(-1)
            gth_intention = one_hot_scatter(intention.argmax(-1),
                                            last_dim,
                                            dtype=torch.float)

        hidden = self.hidden_type_linear.forward(
            torch.cat([hidden.squeeze(0), gth_intention], dim=-1))

        action, gumbel_action = self.basic_policy_network.forward(
            hidden,
            state_embed,
            state,
            pv_r_u_enc,
            pv_r_u_len,
            r=None,
            r_enc=None,
            gth_action=gth_action,
            supervised=supervised,
            node_embedding=node_embedding,  # B, N, E
            head_nodes=head_nodes,  # B, N
            node_efficient=node_efficient,  # B, N
            head_flag_bit=head_flag_bit)  # B, N

        return intention, action, gumbel_action
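The start of this forward pass re-encodes the state slots one token at a time with `rnn_enc_cell`, so each slot's representation also conditions on the slots before it. A minimal reproduction of that loop with an ordinary `nn.GRU`; the dimensions are illustrative, and the project apparently uses equal embedding and hidden sizes.

import torch
import torch.nn as nn

B, S, E, H = 2, 4, 8, 8
rnn_enc_cell = nn.GRU(E, H, batch_first=True)

state_embed = torch.randn(B, S, E)
tmp, s_hidden = [], None
for i in range(S):
    # Feed one slot embedding at a time; the hidden state carries the prefix.
    _, s_hidden = rnn_enc_cell(state_embed[:, i:i + 1, :], s_hidden)  # 1, B, H
    tmp.append(s_hidden.permute(1, 0, 2))                             # B, 1, H
state_embed = torch.cat(tmp, 1)                                       # B, S, H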
Example #5
    def forward_gru(
            self,
            hidden,
            state_emb,
            pv_r_u_enc,
            pv_r_u_len,
            r=None,
            r_enc=None,
            r_mask=None,
            gth_action=None,
            mask_gen=False,
            supervised=False,
            node_embedding=None,  # B, N, E
            head_nodes=None,  # B, N
            node_efficient=None,  # B, N
            head_flag_bit=None):
        batch_size = hidden.size(0)

        # B, 1, E
        state_context, _ = self.embed_attn.forward(hidden.unsqueeze(1),
                                                   state_emb)
        hidden = hidden + self.embed2hidden_linear(state_context).squeeze(1)
        hidden = hidden.unsqueeze(0)  # 1, B, H

        # init input
        step_input = torch.zeros(batch_size,
                                 1,
                                 self.know_vocab_size,
                                 dtype=torch.float,
                                 device=TO.device)
        step_input[:, :, 0] = 1.0  # B, 1, K
        step_input_embed = self.know_prob_embed(step_input)  # B, 1, E

        actions = []
        gumbel_actions = []

        for i in range(self.action_num):
            # B, 1, E + H
            state_context, _ = self.embed_attn.forward(hidden.permute(1, 0, 2),
                                                       state_emb)
            pv_r_u_mask = reverse_sequence_mask(pv_r_u_len, pv_r_u_enc.size(1))
            post_context, _ = self.hidden_attn.forward(hidden.permute(1, 0, 2),
                                                       pv_r_u_enc,
                                                       mask=pv_r_u_mask)
            pv_s_input = torch.cat(
                [step_input_embed, state_context, post_context], dim=-1)

            next_action_hidden, hidden = self.gru.forward(pv_s_input, hidden)

            probs = self.action_pred(batch_size,
                                     next_action_hidden,
                                     r=r,
                                     r_enc=r_enc,
                                     r_mask=r_mask,
                                     mask_gen=mask_gen,
                                     node_embedding=node_embedding,
                                     head_nodes=head_nodes,
                                     node_efficient=node_efficient,
                                     head_flag_bit=head_flag_bit)
            actions.append(probs)

            if (self.training and TO.auto_regressive
                    and gth_action is not None
                    and supervised and not TO.no_action_super):
                gth_step_input = gth_action[:, i:i + 1]
                gth_step_input = one_hot_scatter(gth_step_input,
                                                 self.know_vocab_size,
                                                 dtype=torch.float)
                step_input_embed = self.know_prob_embed(gth_step_input)
            else:
                if TO.auto_regressive:
                    if self.training:
                        gumbel_probs = self.action_gumbel_softmax_sampling(
                            probs)
                    else:
                        max_indices = probs.argmax(-1)
                        gumbel_probs = one_hot_scatter(max_indices,
                                                       probs.size(2),
                                                       dtype=torch.float)

                    step_input_embed = self.know_prob_embed(gumbel_probs)
                    gumbel_actions.append(gumbel_probs)
                else:
                    step_input_embed = self.know_prob_embed(probs)

        actions = torch.cat(actions, dim=1)
        if len(gumbel_actions) == 0:
            return actions, None

        gumbel_actions = torch.cat(gumbel_actions, dim=1)
        return actions, gumbel_actions
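During decoding, the next step input is the Gumbel-softmax relaxation of the predicted distribution while training, and a hard arg-max one-hot at inference. The sketch below reproduces that switch with `torch.nn.functional.gumbel_softmax`; the project's `action_gumbel_softmax_sampling` presumably wraps something similar, and `tau` and the shapes are illustrative.

import torch
import torch.nn.functional as F

def next_step_input(probs: torch.Tensor, training: bool, tau: float = 1.0):
    # Training: differentiable Gumbel-softmax relaxation of the step output.
    # Inference: hard arg-max one-hot, as in the `else` branch above.
    if training:
        return F.gumbel_softmax(torch.log(probs + 1e-20), tau=tau, dim=-1)
    max_indices = probs.argmax(-1, keepdim=True)
    return torch.zeros_like(probs).scatter_(-1, max_indices, 1.0)

probs = torch.softmax(torch.randn(2, 1, 10), dim=-1)  # B, 1, K
print(next_step_input(probs, training=True).shape)    # torch.Size([2, 1, 10])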
Example #6
    def forward(self,
                hidden,
                pv_state,
                pv_state_emb,
                pv_r_u,
                pv_r_u_enc,
                gth_state=None,
                supervised=False):
        batch_size = pv_state_emb.size(0)
        states = []
        gumbel_states = []

        multi_hiddens = None
        step_input_embed = None
        if self.gen_strategy == "gru":
            hidden = hidden.unsqueeze(0)  # 1, B, H
            step_input = torch.zeros(batch_size,
                                     1,
                                     self.know_vocab_size,
                                     dtype=torch.float,
                                     device=TO.device)
            step_input[:, :, 0] = 1.0  # B, 1, K
            step_input_embed = self.know_prob_embed(step_input)  # B, 1, E
        elif self.gen_strategy == "mlp":
            multi_hiddens = self.step_linear(hidden)  # B, S, H
        else:
            raise NotImplementedError

        for i in range(self.state_num):
            if self.gen_strategy == "gru":
                pv_s_context, _ = self.embed_attn.forward(
                    hidden.permute(1, 0, 2), pv_state_emb)
                pv_r_u_context, _ = self.hidden_attn.forward(
                    hidden.permute(1, 0, 2), pv_r_u_enc)
                pv_s_input = torch.cat(
                    [pv_s_context, pv_r_u_context, step_input_embed],
                    dim=-1)  # B, 1, E + H + E
                next_state_hidden, hidden = self.gru.forward(
                    pv_s_input, hidden)  # B, 1, H | 1, B, H
            elif self.gen_strategy == "mlp":
                next_state_hidden = multi_hiddens[:, i:i + 1, :]  # B, 1, H
            else:
                raise NotImplementedError

            next_state = self.hidden_projection(next_state_hidden)
            logits = [next_state]
            indexs = []

            if self.with_copy:
                pv_state_weight = self.embed_copy_attn.forward(
                    next_state_hidden,
                    pv_state_emb,
                    mask=None,
                    not_softmax=True,
                    return_weight_only=True)
                logits.append(pv_state_weight)

                pv_r_u_know_index = self.lg_interpreter.glo2loc(pv_r_u)
                pv_r_u_mask = (pv_r_u_know_index == 0)
                pv_r_u_weight = self.hidden_copy_attn.forward(
                    next_state_hidden,
                    pv_r_u_enc,
                    mask=pv_r_u_mask,
                    not_softmax=True,
                    return_weight_only=True)
                logits.append(pv_r_u_weight)
                indexs.append(pv_r_u_know_index)

                logits = torch.cat(logits, -1)
                indexs = torch.cat(indexs, -1).unsqueeze(1)

            probs = self.word_softmax(logits)

            if self.with_copy:
                gen_probs = probs[:, :, :self.know_vocab_size]

                pv_state_copy_probs = probs[:, :, self.know_vocab_size:
                                            self.know_vocab_size + DO.state_num]
                pv_state_copy_probs = torch.bmm(pv_state_copy_probs, pv_state)

                copy_probs = probs[:, :, self.know_vocab_size + DO.state_num:]
                copy_probs_placeholder = torch.zeros(batch_size,
                                                     1,
                                                     self.know_vocab_size,
                                                     device=TO.device)
                copy_probs = copy_probs_placeholder.scatter_add(
                    2, indexs, copy_probs)

                probs = gen_probs + pv_state_copy_probs + copy_probs

            states.append(probs)

            if self.training and TO.auto_regressive and (
                    gth_state is not None) and supervised:
                gth_step_input = gth_state[:, i:i + 1]  # B, 1
                gth_step_input = one_hot_scatter(gth_step_input,
                                                 self.know_vocab_size,
                                                 dtype=torch.float)
                step_input_embed = self.know_prob_embed(gth_step_input)
            else:
                if TO.auto_regressive:
                    if self.training:
                        gumbel_probs = self.state_gumbel_softmax_sampling(
                            probs)  # gumbel-softmax
                    else:
                        max_indices = probs.argmax(-1)  # B, S
                        gumbel_probs = one_hot_scatter(max_indices,
                                                       probs.size(2),
                                                       dtype=torch.float)

                    step_input_embed = self.know_prob_embed(gumbel_probs)
                    gumbel_states.append(gumbel_probs)
                else:
                    step_input_embed = self.know_prob_embed(probs)  # B, 1, E

        states = torch.cat(states, dim=1)
        if len(gumbel_states) == 0:
            return states, None
        gumbel_states = torch.cat(gumbel_states, dim=1)
        return states, gumbel_states
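The copy branch above maps attention weights over `pv_r_u` back onto the inner vocabulary with `scatter_add` and sums them with the generation distribution. A minimal sketch of that combination step; the shapes and the even split between the two branches are illustrative.

import torch

B, L, K = 2, 5, 10                                               # batch, source length, inner vocab
gen_probs = torch.softmax(torch.randn(B, 1, K), dim=-1) * 0.5    # "generate" branch
copy_probs = torch.softmax(torch.randn(B, 1, L), dim=-1) * 0.5   # "copy" branch over pv_r_u
src_know_index = torch.randint(0, K, (B, 1, L))                  # source token -> vocab id

# Scatter the copy mass onto the vocabulary positions of the source tokens,
# then add it to the generation distribution, as in the with_copy block above.
placeholder = torch.zeros(B, 1, K)
copy_over_vocab = placeholder.scatter_add(2, src_know_index, copy_probs)
probs = gen_probs + copy_over_vocab                              # B, 1, K; each row sums to 1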