Example #1
# Shared imports for the examples below. The torch/numpy imports are standard;
# the repo-internal module paths (EncoderRNN, DecoderRNN, the floor encoders,
# the variation networks, the helpers and DEVICE) are assumptions inferred
# from the identifiers used in the code, not confirmed paths.
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.encoders import EncoderRNN
from modules.decoders import DecoderRNN
from modules.submodules import AbsFloorEmbEncoder, RelFloorEmbEncoder
from modules.variations import GaussianVariation, GMMVariation, LGMVariation
from modules.utils import init_module_weights, init_word_embedding, gaussian_kld
from utils.config import DEVICE
class HRED(nn.Module):
    def __init__(self, config, tokenizer):
        super(HRED, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.attr_embedding_dim = config.attr_embedding_dim
        self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
        self.n_sent_encoder_layers = config.n_sent_encoder_layers
        self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
        self.n_dial_encoder_layers = config.n_dial_encoder_layers
        self.decoder_hidden_dim = config.decoder_hidden_dim
        self.n_decoder_layers = config.n_decoder_layers
        self.use_attention = config.use_attention
        self.decode_max_len = config.decode_max_len
        self.tie_weights = config.tie_weights
        self.rnn_type = config.rnn_type
        self.gen_type = config.gen_type
        self.top_k = config.top_k
        self.top_p = config.top_p
        self.temp = config.temp
        self.word_embedding_path = config.word_embedding_path
        self.floor_encoder_type = config.floor_encoder
        # Optional attributes from config
        self.dropout_emb = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding if hasattr(config, "use_pretrained_word_embedding") else False
        # Other attributes
        self.word2id = tokenizer.word2id
        self.id2word = tokenizer.id2word
        self.vocab_size = len(tokenizer.word2id)
        self.pad_token_id = tokenizer.pad_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

        # Input components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id),
        )

        # Encoding components
        self.sent_encoder = EncoderRNN(
            input_dim=self.word_embedding_dim,
            hidden_dim=self.sent_encoder_hidden_dim,
            n_layers=self.n_sent_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            embedding=self.word_embedding,
            rnn_type=self.rnn_type,
        )
        self.dial_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=False,
            rnn_type=self.rnn_type,
        )

        # Decoding components
        self.enc2dec_hidden_fc = nn.Linear(
            self.dial_encoder_hidden_dim,
            self.n_decoder_layers * self.decoder_hidden_dim if self.rnn_type == "gru"
            else self.n_decoder_layers * self.decoder_hidden_dim * 2
        )
        self.decoder = DecoderRNN(vocab_size=len(self.word2id),
                                  input_dim=self.word_embedding_dim,
                                  hidden_dim=self.decoder_hidden_dim,
                                  feat_dim=self.dial_encoder_hidden_dim,
                                  n_layers=self.n_decoder_layers,
                                  bos_token_id=self.bos_token_id,
                                  eos_token_id=self.eos_token_id,
                                  pad_token_id=self.pad_token_id,
                                  max_len=self.decode_max_len,
                                  dropout_emb=self.dropout_emb,
                                  dropout_input=self.dropout_input,
                                  dropout_hidden=self.dropout_hidden,
                                  dropout_output=self.dropout_output,
                                  embedding=self.word_embedding,
                                  tie_weights=self.tie_weights,
                                  rnn_type=self.rnn_type,
                                  use_attention=self.use_attention,
                                  attn_dim=self.sent_encoder_hidden_dim)

        # Extra components
        # floor encoding
        if self.floor_encoder_type == "abs":
            self.floor_encoder = AbsFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim)
        elif self.floor_encoder_type == "rel":
            self.floor_encoder = RelFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim)
        else:
            self.floor_encoder = None

        # Initialization
        self._init_weights()

    def _init_weights(self):
        init_module_weights(self.enc2dec_hidden_fc)

    def _init_dec_hiddens(self, context):
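        # Project the dialog context vector into the decoder's initial hidden
        # state(s): n_layers * hidden_dim values for GRU, twice that for LSTM
        # since both h and c must be initialized (matching enc2dec_hidden_fc).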
        batch_size = context.size(0)

        hiddens = self.enc2dec_hidden_fc(context)
        if self.rnn_type == "gru":
            hiddens = hiddens.view(
                batch_size,
                self.n_decoder_layers, self.decoder_hidden_dim).transpose(
                    0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
        elif self.rnn_type == "lstm":
            hiddens = hiddens.view(batch_size, self.n_decoder_layers,
                                   self.decoder_hidden_dim, 2)
            h = hiddens[:, :, :, 0].transpose(
                0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            c = hiddens[:, :, :, 1].transpose(
                0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            hiddens = (h, c)

        return hiddens

    def _encode(self, inputs, input_floors, output_floors):
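        # Hierarchical encoding: the sentence encoder first runs over the tokens
        # of each utterance (contexts flattened to batch_size * history_len
        # sequences), then the dialog encoder runs over the per-utterance
        # encodings to produce a single context vector per dialog.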
        batch_size, history_len, max_x_sent_len = inputs.size()

        flat_inputs = inputs.view(batch_size * history_len, max_x_sent_len)
        input_lens = (inputs != self.pad_token_id).sum(-1)
        flat_input_lens = input_lens.view(batch_size * history_len)
        word_encodings, _, sent_encodings = self.sent_encoder(flat_inputs, flat_input_lens)
        word_encodings = word_encodings.view(batch_size, history_len, max_x_sent_len, -1)
        sent_encodings = sent_encodings.view(batch_size, history_len, -1)

        if self.floor_encoder is not None:
            src_floors = input_floors.view(-1)
            tgt_floors = output_floors.unsqueeze(1).repeat(1, history_len).view(-1)
            sent_encodings = sent_encodings.view(batch_size * history_len, -1)
            sent_encodings = self.floor_encoder(sent_encodings,
                                                src_floors=src_floors,
                                                tgt_floors=tgt_floors)
            sent_encodings = sent_encodings.view(batch_size, history_len, -1)

        dial_lens = (input_lens > 0).long().sum(1)  # equals number of non-padding sents
        _, _, dial_encodings = self.dial_encoder(sent_encodings, dial_lens)  # [batch_size, dial_encoder_dim]

        return word_encodings, sent_encodings, dial_encodings

    def _decode(self, inputs, context, attn_ctx=None, attn_mask=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, inputs.size(1), 1)
        ret_dict = self.decoder.forward(batch_size=batch_size,
                                        inputs=inputs,
                                        hiddens=hiddens,
                                        feats=feats,
                                        attn_ctx=attn_ctx,
                                        attn_mask=attn_mask,
                                        mode=DecoderRNN.MODE_TEACHER_FORCE)

        return ret_dict

    def _sample(self, context, attn_ctx=None, attn_mask=None, mmi_args=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, self.decode_max_len, 1)
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            hiddens=hiddens,
            feats=feats,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask,
            mode=DecoderRNN.MODE_FREE_RUN,
            gen_type=self.gen_type,
            temp=self.temp,
            top_p=self.top_p,
            top_k=self.top_k,
            mmi_args=mmi_args,
        )

        return ret_dict

    def _get_attn_mask(self, attn_keys):
        attn_mask = (attn_keys != self.pad_token_id)
        return attn_mask

    def load_model(self, model_path):
        """Load pretrained model weights from model_path

        Arguments:
            model_path {str} -- path to pretrained model weights
        """
        pretrained_state_dict = torch.load(
            model_path, map_location=lambda storage, loc: storage)
        self.load_state_dict(pretrained_state_dict)

    def train_step(self, data):
        """One training step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values
                'loss' {FloatTensor []} -- loss to backward
            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'loss' {float} -- batch loss
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()
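        # Shift for teacher forcing: the decoder consumes Y without its final
        # token and is trained to predict Y shifted left by one position.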

        batch_size = X.size(0)
        max_y_len = Y_out.size(1)

        # Forward
        word_encodings, sent_encodings, dial_encodings = self._encode(
            inputs=X, input_floors=X_floor, output_floors=Y_floor)
        attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
        attn_mask = self._get_attn_mask(X.view(batch_size, -1))
        decoder_ret_dict = self._decode(inputs=Y_in,
                                        context=dial_encodings,
                                        attn_ctx=attn_ctx,
                                        attn_mask=attn_mask)

        # Calculate loss
        loss = 0
        logits = decoder_ret_dict["logits"]
        word_losses = F.cross_entropy(logits.view(-1, self.vocab_size),
                                      Y_out.view(-1),
                                      ignore_index=self.pad_token_id,
                                      reduction="none").view(
                                          batch_size, max_y_len)
        sent_loss = word_losses.sum(1).mean(0)
        loss += sent_loss
        with torch.no_grad():
            ppl = F.cross_entropy(logits.view(-1, self.vocab_size),
                                  Y_out.view(-1),
                                  ignore_index=self.pad_token_id,
                                  reduction="mean").exp()

        # return dicts
        ret_data = {"loss": loss}
        ret_stat = {"ppl": ppl.item(), "loss": loss.item()}

        return ret_data, ret_stat

    def evaluate_step(self, data):
        """One evaluation step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values

            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'monitor' {float} -- a monitor number for learning rate scheduling
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()

        batch_size = X.size(0)

        with torch.no_grad():
            # Forward
            word_encodings, sent_encodings, dial_encodings = self._encode(
                inputs=X, input_floors=X_floor, output_floors=Y_floor)
            attn_ctx = word_encodings.view(batch_size, -1,
                                           word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))
            decoder_ret_dict = self._decode(inputs=Y_in,
                                            context=dial_encodings,
                                            attn_ctx=attn_ctx,
                                            attn_mask=attn_mask)

            # Loss
            word_loss = F.cross_entropy(
                decoder_ret_dict["logits"].view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="mean")
            ppl = torch.exp(word_loss)

        # return dicts
        ret_data = {}
        ret_stat = {"ppl": ppl.item(), "monitor": ppl.item()}

        return ret_data, ret_stat

    def test_step(self, data, mmi_args=None):
        """One test step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values
                'symbols' {LongTensor [batch_size, max_decode_len]} -- token ids of response hypothesis
            dict of statistics -- returned keys and values

        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        batch_size = X.size(0)

        with torch.no_grad():
            # Forward
            word_encodings, sent_encodings, dial_encodings = self._encode(
                inputs=X, input_floors=X_floor, output_floors=Y_floor)
            attn_ctx = word_encodings.view(batch_size, -1,
                                           word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))
            decoder_ret_dict = self._sample(context=dial_encodings,
                                            attn_ctx=attn_ctx,
                                            attn_mask=attn_mask,
                                            mmi_args=mmi_args)

        ret_data = {"symbols": decoder_ret_dict["symbols"]}
        ret_stat = {}

        return ret_data, ret_stat
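
# The sampling behavior behind gen_type / top_k / top_p / temp lives in
# DecoderRNN, which is not shown in this file. For reference, a self-contained
# sketch of the standard top-k / top-p (nucleus) logit filter such a decoder
# typically applies at each step (an illustration, not the repo's own code):
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, temp=1.0):
    # logits: FloatTensor [batch_size, vocab_size] for a single decoding step
    logits = logits / temp
    if top_k > 0:
        # keep only the k highest-scoring tokens
        kth_best = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if top_p > 0.0:
        # keep the smallest set of tokens whose cumulative probability > top_p
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        sorted_remove = cum_probs > top_p
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False  # always keep the single best token
        remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
        logits = logits.masked_fill(remove, float("-inf"))
    return logits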
Example #2
class RNNLM(nn.Module):
    def __init__(self, config, tokenizer):
        super(RNNLM, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.decoder_hidden_dim = config.decoder_hidden_dim
        self.n_decoder_layers = config.n_decoder_layers
        self.decode_max_len = config.decode_max_len
        self.tie_weights = config.tie_weights
        self.rnn_type = config.rnn_type
        # Optional attributes from config
        self.gen_type = config.gen_type if hasattr(config, "gen_type") else "greedy"
        self.top_k = config.top_k if hasattr(config, "top_k") else 0
        self.top_p = config.top_p if hasattr(config, "top_p") else 0.0
        self.temp = config.temp if hasattr(config, "temp") else 1.0
        self.dropout_emb = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding if hasattr(config, "use_pretrained_word_embedding") else False
        self.word_embedding_path = config.word_embedding_path if hasattr(config, "word_embedding_path") else None
        # Other attributes
        self.tokenizer = tokenizer
        self.word2id = tokenizer.word2id
        self.id2word = tokenizer.id2word
        self.vocab_size = len(tokenizer.word2id)
        self.pad_token_id = tokenizer.pad_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

        # Components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id
            ),
        )
        self.decoder = DecoderRNN(
            vocab_size=len(self.word2id),
            input_dim=self.word_embedding_dim,
            hidden_dim=self.decoder_hidden_dim,
            n_layers=self.n_decoder_layers,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            max_len=self.decode_max_len,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            embedding=self.word_embedding,
            tie_weights=self.tie_weights,
            rnn_type=self.rnn_type,
        )

        # Initialization
        self._init_weights()

    def _init_weights(self):
        pass

    def _random_hiddens(self, batch_size):
        if self.rnn_type == "gru":
            hiddens = torch.zeros(
                batch_size,
                self.n_decoder_layers,
                self.decoder_hidden_dim
            ).to(DEVICE).transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            nn.init.uniform_(hiddens, -1.0, 1.0)
        elif self.rnn_type == "lstm":
            hiddens = hiddens.view(
                batch_size,
                self.n_decoder_layers,
                self.decoder_hidden_dim,
                2
            ).to(DEVICE)
            nn.init.uniform_(hiddens, -1.0, 1.0)
            h = hiddens[:, :, :, 0].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            c = hiddens[:, :, :, 1].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            hiddens = (h, c)

        return hiddens

    def _decode(self, inputs):
        batch_size = inputs.size(0)
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            inputs=inputs,
            mode=DecoderRNN.MODE_TEACHER_FORCE
        )

        return ret_dict

    def _sample(self, batch_size, hiddens=None):
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            hiddens=hiddens,
            mode=DecoderRNN.MODE_FREE_RUN,
            gen_type=self.gen_type,
            temp=self.temp,
            top_p=self.top_p,
            top_k=self.top_k,
        )

        return ret_dict

    def load_model(self, model_path):
        """Load pretrained model weights from model_path

        Arguments:
            model_path {str} -- path to pretrained model weights
        """
        pretrained_state_dict = torch.load(
            model_path,
            map_location=lambda storage, loc: storage
        )
        self.load_state_dict(pretrained_state_dict)

    def train_step(self, data):
        """One training step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, max_len]} -- token ids of sentences

        Returns:
            dict of data -- returned keys and values
                'loss' {FloatTensor []} -- loss to backward
            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'loss' {float} -- batch loss
        """
        X = data["X"]
        X_in = X[:, :-1].contiguous()
        X_out = X[:, 1:].contiguous()

        batch_size = X.size(0)

        # Forward
        decoder_ret_dict = self._decode(
            inputs=X_in
        )

        # Calculate loss
        word_loss = F.cross_entropy(
            decoder_ret_dict["logits"].view(-1, self.vocab_size),
            X_out.view(-1),
            ignore_index=self.decoder.pad_token_id,
            reduction="mean"
        )
        ppl = torch.exp(word_loss)
        loss = word_loss

        # return dicts
        ret_data = {
            "loss": loss
        }
        ret_stat = {
            "ppl": ppl.item(),
            "loss": loss.item()
        }

        return ret_data, ret_stat

    def evaluate_step(self, data):
        """One evaluation step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, max_len]} -- token ids of sentences

        Returns:
            dict of data -- returned keys and values

            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'monitor' {float} -- a monitor number for learning rate scheduling
        """
        X = data["X"]
        X_in = X[:, :-1].contiguous()
        X_out = X[:, 1:].contiguous()

        batch_size = X.size(0)

        with torch.no_grad():
            decoder_ret_dict = self._decode(
                inputs=X_in
            )

            # Loss
            word_loss = F.cross_entropy(
                decoder_ret_dict["logits"].view(-1, self.vocab_size),
                X_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="mean"
            )
            ppl = torch.exp(word_loss)

        # return dicts
        ret_data = {}
        ret_stat = {
            "ppl": ppl.item(),
            "monitor": ppl.item()
        }

        return ret_data, ret_stat

    def sample_step(self, batch_size):
        """One test step

        Arguments:
            batch_size {int}

        Returns:
            dict of data -- returned keys and values
                'symbols' {LongTensor [batch_size, max_decode_len]} -- token ids of response hypothesis
            dict of statistics -- returned keys and values

        """
        with torch.no_grad():
            decoder_ret_dict = self._sample(batch_size)

        ret_data = {
            "symbols": decoder_ret_dict["symbols"]
        }
        ret_stat = {}

        return ret_data, ret_stat

    def compute_prob(self, sents):
        """Compute P(sents)

        Arguments:
            sents {List [str]} -- sentences in string form
        """
        batch_tokens = [self.tokenizer.convert_string_to_tokens(sent) for sent in sents]
        batch_token_ids = [self.tokenizer.convert_tokens_to_ids(tokens) for tokens in batch_tokens]
        batch_token_ids = self.tokenizer.convert_batch_ids_to_tensor(batch_token_ids).to(DEVICE)  # [batch_size, len]
        X_in = batch_token_ids[:, :-1]
        X_out = batch_token_ids[:, 1:]

        with torch.no_grad():
            word_ll = []
            sent_ll = []
            sent_probs = []

            batch_size = 50
            n_batches = math.ceil(X_in.size(0)/batch_size)

            for batch_idx in range(n_batches):
                begin = batch_idx * batch_size
                end = min(begin + batch_size, X_in.size(0))
                batch_X_in = X_in[begin:end]
                batch_X_out = X_out[begin:end]

                decoder_ret_dict = self._decode(
                    inputs=batch_X_in
                )
                logits = decoder_ret_dict["logits"]  # [batch_size, len-1, vocab_size]
                batch_word_ll = F.log_softmax(logits, dim=2)
                batch_gathered_word_ll = batch_word_ll.gather(2, batch_X_out.unsqueeze(2)).squeeze(2)  # [batch_size, len-1]
                batch_sent_ll = batch_gathered_word_ll.sum(1)  # [batch_size]
                batch_sent_probs = batch_sent_ll.exp()

                word_ll.append(batch_gathered_word_ll)
                sent_ll.append(batch_sent_ll)
                sent_probs.append(batch_sent_probs)

            word_ll = torch.cat(word_ll, dim=0)
            sent_ll = torch.cat(sent_ll, dim=0)
            sent_probs = torch.cat(sent_probs, dim=0)

        ret_data = {
            "word_loglikelihood": word_ll,
            "sent_loglikelihood": sent_ll,
            "sent_likelihood": sent_probs
        }

        return ret_data
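
# Minimal usage sketch for RNNLM.compute_prob (illustrative; assumes a trained
# `model` and a tokenizer matching the interface used above):
#
#   ret = model.compute_prob(["how are you ?", "i am fine ."])
#   ret["word_loglikelihood"]  # [2, len-1] log P(w_t | w_<t) per token
#   ret["sent_loglikelihood"]  # [2] per-sentence sums over tokens
#   ret["sent_likelihood"]     # [2] exp of the per-sentence sums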
Example #3
class VHCR(nn.Module):
    def __init__(self, config, tokenizer):
        super(VHCR, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.attr_embedding_dim = config.attr_embedding_dim
        self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
        self.n_sent_encoder_layers = config.n_sent_encoder_layers
        self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
        self.n_dial_encoder_layers = config.n_dial_encoder_layers
        self.latent_variable_dim = config.latent_dim
        self.decoder_hidden_dim = config.decoder_hidden_dim
        self.n_decoder_layers = config.n_decoder_layers
        self.use_attention = config.use_attention
        self.decode_max_len = config.decode_max_len
        self.tie_weights = config.tie_weights
        self.rnn_type = config.rnn_type
        self.gen_type = config.gen_type
        self.top_k = config.top_k
        self.top_p = config.top_p
        self.temp = config.temp
        self.word_embedding_path = config.word_embedding_path
        self.floor_encoder_type = config.floor_encoder
        # Optional attributes from config
        self.dropout_emb = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding if hasattr(config, "use_pretrained_word_embedding") else False
        self.n_step_annealing = config.n_step_annealing if hasattr(config, "n_step_annealing") else 1  # default 1 so _annealing_coef_term never divides by zero
        # Other attributes
        self.vocab_size = len(tokenizer.word2id)
        self.word2id = tokenizer.word2id
        self.id2word = tokenizer.id2word
        self.pad_token_id = tokenizer.pad_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.dropout_sent = 0.25

        # Input components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id
            ),
        )

        # Encoding components
        self.sent_encoder = EncoderRNN(
            input_dim=self.word_embedding_dim,
            hidden_dim=self.sent_encoder_hidden_dim,
            n_layers=self.n_sent_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            embedding=self.word_embedding,
            rnn_type=self.rnn_type,
        )
        self.dial_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim+self.latent_variable_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=False,
            rnn_type=self.rnn_type,
        )
        self.dial_infer_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            rnn_type=self.rnn_type,
        )

        # Variational components
        self.dial_post_net = GaussianVariation(
            input_dim=self.dial_encoder_hidden_dim,
            z_dim=self.latent_variable_dim
        )
        self.sent_prior_net = GaussianVariation(
            input_dim=self.dial_encoder_hidden_dim+self.latent_variable_dim,
            z_dim=self.latent_variable_dim
        )
        self.sent_post_net = GaussianVariation(
            input_dim=self.sent_encoder_hidden_dim+self.dial_encoder_hidden_dim+self.latent_variable_dim,
            z_dim=self.latent_variable_dim
        )
        self.unk_sent_vec = nn.Parameter(torch.randn(self.sent_encoder_hidden_dim))  # registered parameter; moved to device together with the module

        # Decoding components
        self.ctx_fc = nn.Linear(
            2*self.latent_variable_dim+self.dial_encoder_hidden_dim,
            self.dial_encoder_hidden_dim
        )
        self.enc2dec_hidden_fc = nn.Linear(
            self.dial_encoder_hidden_dim,
            self.n_decoder_layers*self.decoder_hidden_dim if self.rnn_type == "gru" else self.n_decoder_layers*self.decoder_hidden_dim*2
        )
        self.decoder = DecoderRNN(
            vocab_size=len(self.word2id),
            input_dim=self.word_embedding_dim,
            hidden_dim=self.decoder_hidden_dim,
            feat_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_decoder_layers,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            max_len=self.decode_max_len,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            embedding=self.word_embedding,
            tie_weights=self.tie_weights,
            rnn_type=self.rnn_type,
            use_attention=self.use_attention,
            attn_dim=self.sent_encoder_hidden_dim
        )

        # Extra components
        # Floor encoding
        if self.floor_encoder_type == "abs":
            self.floor_encoder = AbsFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim
            )
        elif self.floor_encoder_type == "rel":
            self.floor_encoder = RelFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim
            )
        else:
            self.floor_encoder = None
        # Hidden initialization
        self.dial_z2dial_enc_hidden_fc = nn.Linear(
            self.latent_variable_dim,
            self.n_dial_encoder_layers*self.dial_encoder_hidden_dim if self.rnn_type == "gru" else self.n_dial_encoder_layers*self.dial_encoder_hidden_dim*2
        )

        # Initialization
        self._init_weights()

    def _init_weights(self):
        init_module_weights(self.enc2dec_hidden_fc)
        init_module_weights(self.ctx_fc)

    def _init_dec_hiddens(self, context):
        batch_size = context.size(0)

        hiddens = self.enc2dec_hidden_fc(context)
        if self.rnn_type == "gru":
            hiddens = hiddens.view(
                batch_size,
                self.n_decoder_layers,
                self.decoder_hidden_dim
            ).transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
        elif self.rnn_type == "lstm":
            hiddens = hiddens.view(
                batch_size,
                self.n_decoder_layers,
                self.decoder_hidden_dim,
                2
            )
            h = hiddens[:, :, :, 0].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            c = hiddens[:, :, :, 1].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            hiddens = (h, c)

        return hiddens

    def _get_ctx_sent_encodings(self, inputs, input_floors, output_floors):
        batch_size, history_len, max_x_sent_len = inputs.size()

        flat_inputs = inputs.view(batch_size*history_len, max_x_sent_len)
        input_lens = (inputs != self.pad_token_id).sum(-1)
        flat_input_lens = input_lens.view(batch_size*history_len)
        word_encodings, _, sent_encodings = self.sent_encoder(flat_inputs, flat_input_lens)
        word_encodings = word_encodings.view(batch_size, history_len, max_x_sent_len, -1)
        sent_encodings = sent_encodings.view(batch_size, history_len, -1)

        if self.floor_encoder is not None:
            src_floors = input_floors.view(-1)
            tgt_floors = output_floors.unsqueeze(1).repeat(1, history_len).view(-1)
            sent_encodings = sent_encodings.view(batch_size*history_len, -1)
            sent_encodings = self.floor_encoder(
                sent_encodings,
                src_floors=src_floors,
                tgt_floors=tgt_floors
            )
            sent_encodings = sent_encodings.view(batch_size, history_len, -1)

        if self.training and self.dropout_sent > 0.0:
            # sentence-level dropout: randomly replace utterance encodings with
            # the learned unknown-sentence vector (positions shared across batch)
            indices = np.where(np.random.rand(history_len) < self.dropout_sent)[0]
            if len(indices) > 0:
                sent_encodings[:, indices, :] = self.unk_sent_vec

        return word_encodings, sent_encodings

    def _get_reply_sent_encodings(self, outputs):
        output_lens = (outputs != self.pad_token_id).sum(-1)
        word_encodings, _, sent_encodings = self.sent_encoder(outputs, output_lens)
        return sent_encodings

    def _get_dial_encodings(self, ctx_dial_lens, ctx_sent_encodings, z_dial):
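        # z_dial conditions the dialog encoder twice: it initializes the
        # encoder's hidden state and is concatenated to every step's input.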
        batch_size, history_len, _ = ctx_sent_encodings.size()

        # Init hidden states of dialog encoder from z_dial
        hiddens = self.dial_z2dial_enc_hidden_fc(z_dial)
        if self.rnn_type == "gru":
            hiddens = hiddens.view(
                batch_size,
                self.n_dial_encoder_layers,
                self.dial_encoder_hidden_dim
            ).transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
        elif self.rnn_type == "lstm":
            hiddens = hiddens.view(
                batch_size,
                self.n_dial_encoder_layers,
                self.dial_encoder_hidden_dim,
                2
            )
            h = hiddens[:, :, :, 0].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            c = hiddens[:, :, :, 1].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            hiddens = (h, c)

        # Inputs to dialog encoder
        z_dial = z_dial.unsqueeze(1).repeat(1, history_len, 1)
        dial_encoder_inputs = torch.cat([ctx_sent_encodings, z_dial], dim=2)

        dialog_lens = ctx_dial_lens
        _, _, dialog_encodings = self.dial_encoder(dial_encoder_inputs, dialog_lens, hiddens)  # [batch_size, dialog_encoder_dim]

        return dialog_encodings

    def _get_full_dial_encodings(self, ctx_dial_lens, ctx_sent_encodings, reply_sent_encodings):
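        # Posterior (inference) path: splice the reply encoding in right after
        # the context utterances so the bidirectional dial_infer_encoder sees
        # the full dialog, response included.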
        batch_size = ctx_sent_encodings.size(0)
        history_len = ctx_sent_encodings.size(1)

        full_sent_encodings = []
        for batch_idx in range(batch_size):
            encodings = []
            ctx_len = ctx_dial_lens[batch_idx].item()
            # part 1 - ctx sent encodings
            for encoding in ctx_sent_encodings[batch_idx][:ctx_len]:
                encodings.append(encoding)
            # part 2 - reply encoding
            encodings.append(reply_sent_encodings[batch_idx])
            # part 3 - padding encodings
            for encoding in ctx_sent_encodings[batch_idx][ctx_len:]:
                encodings.append(encoding)
            encodings = torch.stack(encodings, dim=0)
            full_sent_encodings.append(encodings)
        full_sent_encodings = torch.stack(full_sent_encodings, dim=0)

        full_dialog_lens = ctx_dial_lens+1  # equals number of non-padding sents
        _, _, full_dialog_encodings = self.dial_infer_encoder(full_sent_encodings, full_dialog_lens)  # [batch_size, dialog_encoder_dim]

        return full_dialog_encodings

    def _get_dial_post(self, full_dialog_encodings):
        z, mu, var = self.dial_post_net(full_dialog_encodings)

        return z, mu, var

    def _get_dial_prior(self, batch_size):
        mu = torch.FloatTensor([0.0]).to(DEVICE)
        var = torch.FloatTensor([1.0]).to(DEVICE)
        z = torch.randn([batch_size, self.latent_variable_dim]).to(DEVICE)
        return z, mu, var

    def _get_sent_post(self, reply_sent_encodings, dial_encodings, z_dial):
        sent_post_net_inputs = torch.cat([reply_sent_encodings, dial_encodings, z_dial], dim=1)
        z, mu, var = self.sent_post_net(sent_post_net_inputs)

        return z, mu, var

    def _get_sent_prior(self, dial_encodings, z_dial):
        sent_prior_net_inputs = torch.cat([dial_encodings, z_dial], dim=1)
        z, mu, var = self.sent_prior_net(sent_prior_net_inputs)

        return z, mu, var

    def _decode(self, inputs, context, attn_ctx=None, attn_mask=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, inputs.size(1), 1)
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            inputs=inputs,
            hiddens=hiddens,
            feats=feats,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask,
            mode=DecoderRNN.MODE_TEACHER_FORCE
        )

        return ret_dict

    def _sample(self, context, attn_ctx=None, attn_mask=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, self.decode_max_len, 1)
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            hiddens=hiddens,
            feats=feats,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask,
            mode=DecoderRNN.MODE_FREE_RUN,
            gen_type=self.gen_type,
            temp=self.temp,
            top_p=self.top_p,
            top_k=self.top_k,
        )

        return ret_dict

    def _get_attn_mask(self, attn_keys):
        attn_mask = (attn_keys != self.pad_token_id)
        return attn_mask

    def _annealing_coef_term(self, step):
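        # Linear KL annealing: ramps from 0 at step 0 up to 1.0 once
        # step >= n_step_annealing (e.g. step 2500 of 10000 gives 0.25)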
        return min(1.0, 1.0*step/self.n_step_annealing)

    def load_model(self, model_path):
        """Load pretrained model weights from model_path

        Arguments:
            model_path {str} -- path to pretrained model weights
        """
        pretrained_state_dict = torch.load(
            model_path,
            map_location=lambda storage, loc: storage
        )
        self.load_state_dict(pretrained_state_dict)

    def train_step(self, data, step):
        """One training step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

            step {int} -- the n-th optimization step

        Returns:
            dict of data -- returned keys and values
                'loss' {FloatTensor []} -- loss to backward
            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'sent_kld' {float} -- sentence KLD
                'dial_kld' {float} -- dialog KLD
                'kld_term' {float} -- KLD annealing coefficient
                'loss' {float} -- batch loss
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()

        batch_size = X.size(0)
        max_y_len = Y_out.size(1)
        ctx_dial_lens = ((X != self.pad_token_id).sum(-1) > 0).sum(-1)  # number of non-padding utterances per context

        # Forward
        # Encode sentences
        word_encodings, ctx_sent_encodings = self._get_ctx_sent_encodings(
            inputs=X,
            input_floors=X_floor,
            output_floors=Y_floor
        )
        reply_sent_encodings = self._get_reply_sent_encodings(
            outputs=Y,
        )
        # Encode full dialog for posterior dialog z
        full_dial_encodings = self._get_full_dial_encodings(
            ctx_dial_lens=ctx_dial_lens,
            ctx_sent_encodings=ctx_sent_encodings,
            reply_sent_encodings=reply_sent_encodings
        )
        # Get dial z
        z_dial_post, mu_dial_post, var_dial_post = self._get_dial_post(full_dial_encodings)
        # Encode dialog
        dial_encodings = self._get_dial_encodings(
            ctx_dial_lens=ctx_dial_lens,
            ctx_sent_encodings=ctx_sent_encodings,
            z_dial=z_dial_post
        )
        # Get sent z
        z_sent_post, mu_sent_post, var_sent_post = self._get_sent_post(
            reply_sent_encodings=reply_sent_encodings,
            dial_encodings=dial_encodings,
            z_dial=z_dial_post
        )
        z_sent_prior, mu_sent_prior, var_sent_prior = self._get_sent_prior(
            dial_encodings=dial_encodings,
            z_dial=z_dial_post
        )
        # Decode
        ctx_encodings = self.ctx_fc(torch.cat([dial_encodings, z_sent_post, z_dial_post], dim=1))
        attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
        attn_mask = self._get_attn_mask(X.view(batch_size, -1))
        decoder_ret_dict = self._decode(
            inputs=Y_in,
            context=ctx_encodings,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask
        )

        # Loss
        loss = 0
        # Reconstruction
        logits = decoder_ret_dict["logits"]
        word_losses = F.cross_entropy(
            logits.view(-1, self.vocab_size),
            Y_out.view(-1),
            ignore_index=self.pad_token_id,
            reduction="none"
        ).view(batch_size, max_y_len)
        sent_loss = word_losses.sum(1).mean(0)
        loss += sent_loss
        with torch.no_grad():
            ppl = F.cross_entropy(
                logits.view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.pad_token_id,
                reduction="mean"
            ).exp()
        # KLD
        kld_coef = self._annealing_coef_term(step)
        dial_kld_losses = gaussian_kld(mu_dial_post, var_dial_post)
        avg_dial_kld = dial_kld_losses.mean()
        sent_kld_losses = gaussian_kld(mu_sent_post, var_sent_post, mu_sent_prior, var_sent_prior)
        avg_sent_kld = sent_kld_losses.mean()
        loss += (avg_dial_kld+avg_sent_kld)*kld_coef
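        # gaussian_kld (defined elsewhere in the repo) presumably computes the
        # standard closed-form KL between diagonal Gaussians,
        #   KL(N(mu1, var1) || N(mu2, var2))
        #     = 0.5 * sum(log(var2/var1) + (var1 + (mu1 - mu2)^2) / var2 - 1),
        # with N(0, I) as the default second argument for the dialog-level term.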

        # return dicts
        ret_data = {
            "loss": loss
        }
        ret_stat = {
            "ppl": ppl.item(),
            "kld_term": kld_coef,
            "dial_kld": avg_dial_kld.item(),
            "sent_kld": avg_sent_kld.item(),
            "loss": loss.item()
        }

        return ret_data, ret_stat

    def evaluate_step(self, data):
        """One training step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values

            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'sent_kld' {float} -- sentence KLD
                'dial_kld' {float} -- dialog KLD
                'monitor' {float} -- a monitor number for learning rate scheduling
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()

        batch_size = X.size(0)
        max_y_len = Y_out.size(1)
        ctx_dial_lens = ((X != self.pad_token_id).sum(-1) > 0).sum(-1)

        with torch.no_grad():
            # Forward
            # Encode sentences
            word_encodings, ctx_sent_encodings = self._get_ctx_sent_encodings(
                inputs=X,
                input_floors=X_floor,
                output_floors=Y_floor
            )
            reply_sent_encodings = self._get_reply_sent_encodings(
                outputs=Y,
            )
            # Encode full dialog for posterior dialog z
            full_dial_encodings = self._get_full_dial_encodings(
                ctx_dial_lens=ctx_dial_lens,
                ctx_sent_encodings=ctx_sent_encodings,
                reply_sent_encodings=reply_sent_encodings
            )
            # Get dial z
            z_dial_post, mu_dial_post, var_dial_post = self._get_dial_post(full_dial_encodings)
            # Encode dialog
            dial_encodings = self._get_dial_encodings(
                ctx_dial_lens=ctx_dial_lens,
                ctx_sent_encodings=ctx_sent_encodings,
                z_dial=z_dial_post
            )
            # Get sent z
            z_sent_post, mu_sent_post, var_sent_post = self._get_sent_post(
                reply_sent_encodings=reply_sent_encodings,
                dial_encodings=dial_encodings,
                z_dial=z_dial_post
            )
            z_sent_prior, mu_sent_prior, var_sent_prior = self._get_sent_prior(
                dial_encodings=dial_encodings,
                z_dial=z_dial_post
            )
            # Decode
            ctx_encodings = self.ctx_fc(torch.cat([dial_encodings, z_sent_post, z_dial_post], dim=1))
            attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))
            decoder_ret_dict = self._decode(
                inputs=Y_in,
                context=ctx_encodings,
                attn_ctx=attn_ctx,
                attn_mask=attn_mask
            )

            # Loss
            # Reconstruction
            logits = decoder_ret_dict["logits"]
            word_losses = F.cross_entropy(
                logits.view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="none"
            ).view(batch_size, max_y_len)
            sent_loss = word_losses.sum(1).mean(0)
            ppl = F.cross_entropy(
                logits.view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="mean"
            ).exp()
            # KLD
            dial_kld_losses = gaussian_kld(mu_dial_post, var_dial_post)
            avg_dial_kld = dial_kld_losses.mean()
            sent_kld_losses = gaussian_kld(mu_sent_post, var_sent_post, mu_sent_prior, var_sent_prior)
            avg_sent_kld = sent_kld_losses.mean()
            # monitor loss
            monitor_loss = sent_loss + avg_dial_kld + avg_sent_kld

        # return dicts
        ret_data = {}
        ret_stat = {
            "ppl": ppl.item(),
            "dial_kld": avg_dial_kld.item(),
            "sent_kld": avg_sent_kld.item(),
            "monitor": monitor_loss.item()
        }

        return ret_data, ret_stat

    def test_step(self, data):
        """One test step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values
                'symbols' {LongTensor [batch_size, max_decode_len]} -- token ids of response hypothesis
            dict of statistics -- returned keys and values

        """
        X = data["X"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        batch_size = X.size(0)
        ctx_dial_lens = ((X != self.pad_token_id).sum(-1) > 0).sum(-1)

        with torch.no_grad():
            # Forward
            # Encode sentences
            word_encodings, ctx_sent_encodings = self._get_ctx_sent_encodings(
                inputs=X,
                input_floors=X_floor,
                output_floors=Y_floor
            )
            # Get dial z
            z_dial_prior, mu_dial_prior, var_dial_prior = self._get_dial_prior(batch_size)
            # Encode dialog
            dial_encodings = self._get_dial_encodings(
                ctx_dial_lens=ctx_dial_lens,
                ctx_sent_encodings=ctx_sent_encodings,
                z_dial=z_dial_prior
            )
            # Get sent z
            z_sent_prior, mu_sent_prior, var_sent_prior = self._get_sent_prior(
                dial_encodings=dial_encodings,
                z_dial=z_dial_prior
            )
            # Decode
            ctx_encodings = self.ctx_fc(torch.cat([dial_encodings, z_sent_prior, z_dial_prior], dim=1))
            attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))
            decoder_ret_dict = self._sample(
                context=ctx_encodings,
                attn_ctx=attn_ctx,
                attn_mask=attn_mask
            )

        ret_data = {
            "symbols": decoder_ret_dict["symbols"]
        }
        ret_stat = {}

        return ret_data, ret_stat
Example #4
class VHRED(nn.Module):
    def __init__(self, config, tokenizer):
        super(VHRED, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.attr_embedding_dim = config.attr_embedding_dim
        self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
        self.n_sent_encoder_layers = config.n_sent_encoder_layers
        self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
        self.n_dial_encoder_layers = config.n_dial_encoder_layers
        self.latent_variable_dim = config.latent_dim
        self.decoder_hidden_dim = config.decoder_hidden_dim
        self.n_decoder_layers = config.n_decoder_layers
        self.use_attention = config.use_attention
        self.decode_max_len = config.decode_max_len
        self.tie_weights = config.tie_weights
        self.rnn_type = config.rnn_type
        self.gen_type = config.gen_type
        self.top_k = config.top_k
        self.top_p = config.top_p
        self.temp = config.temp
        self.word_embedding_path = config.word_embedding_path
        self.floor_encoder_type = config.floor_encoder
        self.gaussian_mix_type = config.gaussian_mix_type
        # Optional attributes from config
        self.use_bow_loss = config.use_bow_loss if hasattr(config, "use_bow_loss") else True
        self.dropout = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_emb = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding if hasattr(config, "use_pretrained_word_embedding") else False
        self.n_step_annealing = config.n_step_annealing if hasattr(config, "n_step_annealing") else 1
        self.n_components = config.n_components if hasattr(config, "n_components") else 1
        # Other attributes
        self.vocab_size = len(tokenizer.word2id)
        self.word2id = tokenizer.word2id
        self.id2word = tokenizer.id2word
        self.pad_token_id = tokenizer.pad_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

        # Input components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id),
        )

        # Encoding components
        self.sent_encoder = EncoderRNN(
            input_dim=self.word_embedding_dim,
            hidden_dim=self.sent_encoder_hidden_dim,
            n_layers=self.n_sent_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            embedding=self.word_embedding,
            rnn_type=self.rnn_type,
        )
        self.dial_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=False,
            rnn_type=self.rnn_type,
        )

        # Variational components
        if self.n_components == 1:
            self.prior_net = GaussianVariation(
                input_dim=self.dial_encoder_hidden_dim,
                z_dim=self.latent_variable_dim,
                # large_mlp=True
            )
        elif self.n_components > 1:
            if self.gaussian_mix_type == "gmm":
                self.prior_net = GMMVariation(
                    input_dim=self.dial_encoder_hidden_dim,
                    z_dim=self.latent_variable_dim,
                    n_components=self.n_components,
                )
            elif self.gaussian_mix_type == "lgm":
                self.prior_net = LGMVariation(
                    input_dim=self.dial_encoder_hidden_dim,
                    z_dim=self.latent_variable_dim,
                    n_components=self.n_components,
                )
            else:
                raise ValueError("Unknown gaussian_mix_type: {}".format(
                    self.gaussian_mix_type))
        self.post_net = GaussianVariation(
            input_dim=self.sent_encoder_hidden_dim +
            self.dial_encoder_hidden_dim,
            z_dim=self.latent_variable_dim,
        )
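        # Note: the posterior (recognition) network sees both the encoded
        # response and the dialogue context, while the prior network sees
        # only the context; train_step ties the two together via the KL term.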
        self.latent_to_bow = nn.Sequential(
            nn.Linear(self.latent_variable_dim + self.dial_encoder_hidden_dim,
                      self.latent_variable_dim), nn.Tanh(),
            nn.Dropout(self.dropout),
            nn.Linear(self.latent_variable_dim, self.vocab_size))
        self.ctx_fc = nn.Sequential(
            nn.Linear(
                self.latent_variable_dim + self.dial_encoder_hidden_dim,
                self.dial_encoder_hidden_dim,
            ), nn.Tanh(), nn.Dropout(self.dropout))

        # Decoding components
        self.enc2dec_hidden_fc = nn.Linear(
            self.dial_encoder_hidden_dim,
            self.n_decoder_layers * self.decoder_hidden_dim if self.rnn_type
            == "gru" else self.n_decoder_layers * self.decoder_hidden_dim * 2)
        self.decoder = DecoderRNN(vocab_size=len(self.word2id),
                                  input_dim=self.word_embedding_dim,
                                  hidden_dim=self.decoder_hidden_dim,
                                  feat_dim=self.dial_encoder_hidden_dim,
                                  n_layers=self.n_decoder_layers,
                                  bos_token_id=self.bos_token_id,
                                  eos_token_id=self.eos_token_id,
                                  pad_token_id=self.pad_token_id,
                                  max_len=self.decode_max_len,
                                  dropout_emb=self.dropout_emb,
                                  dropout_input=self.dropout_input,
                                  dropout_hidden=self.dropout_hidden,
                                  dropout_output=self.dropout_output,
                                  embedding=self.word_embedding,
                                  tie_weights=self.tie_weights,
                                  rnn_type=self.rnn_type,
                                  use_attention=self.use_attention,
                                  attn_dim=self.sent_encoder_hidden_dim)

        # Extra components
        # Floor encoding
        if self.floor_encoder_type == "abs":
            self.floor_encoder = AbsFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim)
        elif self.floor_encoder_type == "rel":
            self.floor_encoder = RelFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim)
        else:
            self.floor_encoder = None

        # Initialization
        self._init_weights()

    def _init_weights(self):
        init_module_weights(self.enc2dec_hidden_fc)
        init_module_weights(self.latent_to_bow)
        init_module_weights(self.ctx_fc)

    def _init_dec_hiddens(self, context):
        batch_size = context.size(0)

        hiddens = self.enc2dec_hidden_fc(context)
        if self.rnn_type == "gru":
            hiddens = hiddens.view(
                batch_size,
                self.n_decoder_layers, self.decoder_hidden_dim).transpose(
                    0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
        elif self.rnn_type == "lstm":
            hiddens = hiddens.view(batch_size, self.n_decoder_layers,
                                   self.decoder_hidden_dim, 2)
            h = hiddens[:, :, :, 0].transpose(
                0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            c = hiddens[:, :, :, 1].transpose(
                0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            hiddens = (h, c)

        return hiddens
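
    # Shape sketch for the reshape above (illustrative numbers): with
    # n_decoder_layers=2 and decoder_hidden_dim=300, enc2dec_hidden_fc maps
    # context [batch, dial_dim] to [batch, 600] for a GRU (viewed as
    # (2, batch, 300)), or to [batch, 1200] for an LSTM, whose trailing
    # axis of size 2 is split into the (h, c) pair.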

    def _encode_sent(self, outputs):
        output_lens = (outputs != self.pad_token_id).sum(-1)
        _, _, sent_encodings = self.sent_encoder(outputs, output_lens)
        return sent_encodings

    def _encode_dial(self, inputs, input_floors, output_floors):
        batch_size, history_len, max_x_sent_len = inputs.size()

        flat_inputs = inputs.view(batch_size * history_len, max_x_sent_len)
        input_lens = (inputs != self.pad_token_id).sum(-1)
        flat_input_lens = input_lens.view(batch_size * history_len)
        word_encodings, _, sent_encodings = self.sent_encoder(
            flat_inputs, flat_input_lens)
        word_encodings = word_encodings.view(batch_size, history_len,
                                             max_x_sent_len, -1)
        sent_encodings = sent_encodings.view(batch_size, history_len, -1)

        if self.floor_encoder is not None:
            src_floors = input_floors.view(-1)
            tgt_floors = output_floors.unsqueeze(1).repeat(
                1, history_len).view(-1)
            sent_encodings = sent_encodings.view(batch_size * history_len, -1)
            sent_encodings = self.floor_encoder(sent_encodings,
                                                src_floors=src_floors,
                                                tgt_floors=tgt_floors)
            sent_encodings = sent_encodings.view(batch_size, history_len, -1)

        dial_lens = (input_lens > 0).long().sum(
            1)  # equals number of non-padding sents
        _, _, dial_encodings = self.dial_encoder(
            sent_encodings, dial_lens)  # [batch_size, dial_encoder_dim]

        return word_encodings, sent_encodings, dial_encodings
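
    # Shape flow of the hierarchical encoding above:
    #   inputs [batch_size, history_len, max_x_sent_len] (token ids)
    #   -> word_encodings [batch_size, history_len, max_x_sent_len, enc_dim]
    #   -> sent_encodings [batch_size, history_len, sent_dim]
    #   -> dial_encodings [batch_size, dial_encoder_hidden_dim]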

    def _get_prior_z(self, prior_net_input, assign_k=None, return_pi=False):
        ret = self.prior_net(context=prior_net_input,
                             assign_k=assign_k,
                             return_pi=return_pi)

        return ret

    def _get_post_z(self, post_net_input):
        z, mu, var = self.post_net(post_net_input)
        return z, mu, var

    def _get_ctx_for_decoder(self, z, dial_encodings):
        ctx_encodings = self.ctx_fc(torch.cat([z, dial_encodings], dim=1))
        return ctx_encodings

    def _decode(self, inputs, context, attn_ctx=None, attn_mask=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, inputs.size(1), 1)
        ret_dict = self.decoder.forward(batch_size=batch_size,
                                        inputs=inputs,
                                        hiddens=hiddens,
                                        feats=feats,
                                        attn_ctx=attn_ctx,
                                        attn_mask=attn_mask,
                                        mode=DecoderRNN.MODE_TEACHER_FORCE)

        return ret_dict

    def _sample(self, context, attn_ctx=None, attn_mask=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, self.decode_max_len, 1)
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            hiddens=hiddens,
            feats=feats,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask,
            mode=DecoderRNN.MODE_FREE_RUN,
            gen_type=self.gen_type,
            temp=self.temp,
            top_p=self.top_p,
            top_k=self.top_k,
        )

        return ret_dict

    def _get_attn_mask(self, attn_keys):
        attn_mask = (attn_keys != self.pad_token_id)
        return attn_mask

    def _annealing_coef_term(self, step):
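        # Linear KL annealing: the coefficient ramps from ~0 to 1.0 over the
        # first n_step_annealing optimization steps, then stays at 1.0
        # (e.g. with n_step_annealing=10000: step 2500 -> 0.25,
        # step 20000 -> 1.0).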
        return min(1.0, 1.0 * step / self.n_step_annealing)

    def compute_active_units(self, data_source, batch_size, delta=0.01):
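        """Estimate the number of active latent units.

        Uses the common criterion (cf. Burda et al., 2016): a latent
        dimension counts as active when the variance of its mean code
        across the data, Var_x[E[z|x]], is at least `delta`. The two passes
        below compute that variance for both the prior and posterior nets.
        """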
        with torch.no_grad():
            cnt = 0

            data_source.epoch_init()
            prior_mu_sum = 0
            post_mu_sum = 0
            while True:
                batch_data = data_source.next(batch_size)
                if batch_data is None:
                    break
                X, Y = batch_data["X"], batch_data["Y"]
                X_floor, Y_floor = batch_data["X_floor"], batch_data["Y_floor"]
                # encode
                word_encodings, sent_encodings, dial_encodings = self._encode_dial(
                    inputs=X, input_floors=X_floor, output_floors=Y_floor)
                # prior
                prior_net_input = dial_encodings
                _, prior_mu, _ = self._get_prior_z(prior_net_input)
                # post
                post_sent_encodings = self._encode_sent(Y)
                post_net_input = torch.cat(
                    [post_sent_encodings, dial_encodings], dim=1)
                _, post_mu, _ = self._get_post_z(post_net_input)
                # record
                cnt += prior_mu.size(0)
                prior_mu_sum += prior_mu.sum(0)
                post_mu_sum += post_mu.sum(0)
            prior_mu_mean = (prior_mu_sum / cnt).unsqueeze(
                0)  # [1, latent_dim]
            post_mu_mean = (post_mu_sum / cnt).unsqueeze(0)  # [1, latent_dim]

            data_source.epoch_init()
            prior_mu_var_sum = 0
            post_mu_var_sum = 0
            while True:
                batch_data = data_source.next(batch_size)
                if batch_data is None:
                    break
                X, Y = batch_data["X"], batch_data["Y"]
                X_floor, Y_floor = batch_data["X_floor"], batch_data["Y_floor"]
                # encode
                word_encodings, sent_encodings, dial_encodings = self._encode_dial(
                    inputs=X, input_floors=X_floor, output_floors=Y_floor)
                # prior
                prior_net_input = dial_encodings
                _, prior_mu, _ = self._get_prior_z(prior_net_input)
                # post
                post_sent_encodings = self._encode_sent(Y)
                post_net_input = torch.cat(
                    [post_sent_encodings, dial_encodings], dim=1)
                _, post_mu, _ = self._get_post_z(post_net_input)
                # record
                prior_mu_var_sum += ((prior_mu - prior_mu_mean)**2).sum(0)
                post_mu_var_sum += ((post_mu - post_mu_mean)**2).sum(0)
            prior_mu_var_mean = prior_mu_var_sum / (cnt - 1)  # [latent_dim]
            post_mu_var_mean = post_mu_var_sum / (cnt - 1)  # [latent_dim]

            prior_au = (prior_mu_var_mean >= delta).sum().item()
            post_au = (post_mu_var_mean >= delta).sum().item()
            prior_au_ratio = prior_au / self.latent_variable_dim
            post_au_ratio = post_au / self.latent_variable_dim
            return {
                "prior_au": prior_au,
                "post_au": post_au,
                "prior_au_ratio": prior_au_ratio,
                "post_au_ratio": post_au_ratio,
            }

    def load_model(self, model_path):
        """Load pretrained model weights from model_path

        Arguments:
            model_path {str} -- path to pretrained model weights
        """
        pretrained_state_dict = torch.load(
            model_path, map_location=lambda storage, loc: storage)
        self.load_state_dict(pretrained_state_dict)

    def train_step(self, data, step):
        """One training step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

            step {int} -- the n-th optimization step

        Returns:
            dict of data -- returned keys and values
                'loss' {FloatTensor []} -- loss to backpropagate
            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'kld' {float} -- KLD
                'kld_term' {float} -- KLD annealing coefficient
                'bow_loss' {float} -- bag-of-words loss (only when use_bow_loss)
                'loss' {float} -- batch loss
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()

        batch_size = X.size(0)
        max_y_len = Y_out.size(1)

        # Forward
        # Get prior z
        word_encodings, sent_encodings, dial_encodings = self._encode_dial(
            inputs=X, input_floors=X_floor, output_floors=Y_floor)
        prior_net_input = dial_encodings
        prior_z, prior_mu, prior_var = self._get_prior_z(prior_net_input)
        # Get post z
        post_sent_encodings = self._encode_sent(Y)
        post_net_input = torch.cat([post_sent_encodings, dial_encodings],
                                   dim=1)
        post_z, post_mu, post_var = self._get_post_z(post_net_input)
        # Decode
        ctx_encodings = self._get_ctx_for_decoder(post_z, dial_encodings)
        attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
        attn_mask = self._get_attn_mask(X.view(batch_size, -1))
        decoder_ret_dict = self._decode(inputs=Y_in,
                                        context=ctx_encodings,
                                        attn_ctx=attn_ctx,
                                        attn_mask=attn_mask)

        # Loss
        loss = 0
        # Reconstruction
        logits = decoder_ret_dict["logits"]
        word_losses = F.cross_entropy(logits.view(-1, self.vocab_size),
                                      Y_out.view(-1),
                                      ignore_index=self.pad_token_id,
                                      reduction="none").view(
                                          batch_size, max_y_len)
        sent_loss = word_losses.sum(1).mean(0)
        loss += sent_loss
        with torch.no_grad():
            ppl = F.cross_entropy(logits.view(-1, self.vocab_size),
                                  Y_out.view(-1),
                                  ignore_index=self.pad_token_id,
                                  reduction="mean").exp()
        # KLD
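        # gaussian_kld is assumed to return the per-example closed-form
        # KL(q || p) between diagonal Gaussians:
        #   0.5 * sum_d(log(var_p/var_q) + (var_q + (mu_q - mu_p)^2)/var_p - 1)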
        kld_coef = self._annealing_coef_term(step)
        kld_losses = gaussian_kld(
            post_mu,
            post_var,
            prior_mu,
            prior_var,
        )
        avg_kld = kld_losses.mean()
        loss += avg_kld * kld_coef
        # BOW
        if self.use_bow_loss:
            Y_out_mask = (Y_out != self.pad_token_id).float()
            bow_input = torch.cat([post_z, dial_encodings], dim=1)
            bow_logits = self.latent_to_bow(
                bow_input)  # [batch_size, vocab_size]
            bow_loss = -F.log_softmax(bow_logits, dim=1).gather(
                1, Y_out) * Y_out_mask
            bow_loss = bow_loss.sum(1).mean()
            loss += bow_loss

        # return dicts
        ret_data = {"loss": loss}
        ret_stat = {
            "ppl": ppl.item(),
            "kld_term": kld_coef,
            "kld": avg_kld.item(),
            "prior_abs_mu_mean": prior_mu.abs().mean().item(),
            "prior_var_mean": prior_var.mean().item(),
            "post_abs_mu_mean": post_mu.abs().mean().item(),
            "post_var_mean": post_var.mean().item(),
            "loss": loss.item()
        }
        if self.use_bow_loss:
            ret_stat["bow_loss"] = bow_loss.item()

        return ret_data, ret_stat

    def evaluate_step(self, data, assign_k=None):
        """One training step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values

            dict of statistics -- returned keys and values
                'post_ppl' {float} -- perplexity when decoding from the posterior z
                'prior_ppl' {float} -- perplexity when decoding from the prior z
                'kld' {float} -- KLD
                'monitor' {float} -- a monitor number for learning rate scheduling
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()

        batch_size = X.size(0)
        max_y_len = Y_out.size(1)

        with torch.no_grad():
            # Forward
            # Get prior z
            word_encodings, sent_encodings, dial_encodings = self._encode_dial(
                inputs=X, input_floors=X_floor, output_floors=Y_floor)
            prior_net_input = dial_encodings
            prior_z, prior_mu, prior_var = self._get_prior_z(
                prior_net_input=prior_net_input, assign_k=assign_k)
            # Get post z
            post_sent_encodings = self._encode_sent(Y)
            post_net_input = torch.cat([post_sent_encodings, dial_encodings],
                                       dim=1)
            post_z, post_mu, post_var = self._get_post_z(post_net_input)
            # Decode from post z
            attn_ctx = word_encodings.view(batch_size, -1,
                                           word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))
            post_ctx_encodings = self._get_ctx_for_decoder(
                post_z, dial_encodings)
            post_decoder_ret_dict = self._decode(inputs=Y_in,
                                                 context=post_ctx_encodings,
                                                 attn_ctx=attn_ctx,
                                                 attn_mask=attn_mask)
            # Decode from prior z
            prior_ctx_encodings = self._get_ctx_for_decoder(
                prior_z, dial_encodings)
            prior_decoder_ret_dict = self._decode(inputs=Y_in,
                                                  context=prior_ctx_encodings,
                                                  attn_ctx=attn_ctx,
                                                  attn_mask=attn_mask)

            # Loss
            # Reconstruction
            post_word_losses = F.cross_entropy(
                post_decoder_ret_dict["logits"].view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="none").view(batch_size, max_y_len)
            post_sent_loss = post_word_losses.sum(1).mean(0)
            post_ppl = F.cross_entropy(post_decoder_ret_dict["logits"].view(
                -1, self.vocab_size),
                                       Y_out.view(-1),
                                       ignore_index=self.decoder.pad_token_id,
                                       reduction="mean").exp()
            # Generation
            prior_word_losses = F.cross_entropy(
                prior_decoder_ret_dict["logits"].view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="none").view(batch_size, max_y_len)
            prior_sent_loss = prior_word_losses.sum(1).mean(0)
            prior_ppl = F.cross_entropy(prior_decoder_ret_dict["logits"].view(
                -1, self.vocab_size),
                                        Y_out.view(-1),
                                        ignore_index=self.decoder.pad_token_id,
                                        reduction="mean").exp()
            # KLD
            kld_losses = gaussian_kld(post_mu, post_var, prior_mu, prior_var)
            avg_kld = kld_losses.mean()
            # monitor
            monitor_loss = post_sent_loss + avg_kld

        # return dicts
        ret_data = {}
        ret_stat = {
            "post_ppl": post_ppl.item(),
            "prior_ppl": prior_ppl.item(),
            "kld": avg_kld.item(),
            "post_abs_mu_mean": post_mu.abs().mean().item(),
            "post_var_mean": post_var.mean().item(),
            "prior_abs_mu_mean": prior_mu.abs().mean().item(),
            "prior_var_mean": prior_var.mean().item(),
            "monitor": monitor_loss.item()
        }

        return ret_data, ret_stat

    def test_step(self, data, assign_k=None):
        """One test step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values
                'symbols' {LongTensor [batch_size, max_decode_len]} -- token ids of response hypothesis
                'pi' {FloatTensor} -- mixture weights from the prior network
            dict of statistics -- returned keys and values

        """
        X = data["X"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        batch_size = X.size(0)

        with torch.no_grad():
            # Forward
            # Get prior z
            word_encodings, sent_encodings, dial_encodings = self._encode_dial(
                inputs=X, input_floors=X_floor, output_floors=Y_floor)
            prior_net_input = dial_encodings
            prior_z, prior_mu, prior_var, prior_pi = self._get_prior_z(
                prior_net_input=prior_net_input,
                assign_k=assign_k,
                return_pi=True)
            # Decode
            ctx_encodings = self._get_ctx_for_decoder(prior_z, dial_encodings)
            attn_ctx = word_encodings.view(batch_size, -1,
                                           word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))
            decoder_ret_dict = self._sample(context=ctx_encodings,
                                            attn_ctx=attn_ctx,
                                            attn_mask=attn_mask)

        ret_data = {"symbols": decoder_ret_dict["symbols"], "pi": prior_pi}
        ret_stat = {}

        return ret_data, ret_stat
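
# A self-contained sanity check (illustrative; not part of the original
# listing). `gaussian_kld` above is assumed to compute the closed-form
# KL(q || p) between diagonal Gaussians, summed over latent dimensions;
# this sketch verifies that formula against torch.distributions.
def _check_diag_gaussian_kld():
    mu_q, var_q = torch.randn(4, 8), torch.rand(4, 8) + 0.1
    mu_p, var_p = torch.randn(4, 8), torch.rand(4, 8) + 0.1
    closed_form = 0.5 * (torch.log(var_p / var_q)
                         + (var_q + (mu_q - mu_p) ** 2) / var_p
                         - 1).sum(-1)
    reference = torch.distributions.kl_divergence(
        torch.distributions.Normal(mu_q, var_q.sqrt()),
        torch.distributions.Normal(mu_p, var_p.sqrt())).sum(-1)
    assert torch.allclose(closed_form, reference, atol=1e-5)
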
Example #5
class Mechanism_HRED(nn.Module):
    def __init__(self, config, tokenizer):
        super(Mechanism_HRED, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.attr_embedding_dim = config.attr_embedding_dim
        self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
        self.n_sent_encoder_layers = config.n_sent_encoder_layers
        self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
        self.n_dial_encoder_layers = config.n_dial_encoder_layers
        self.decoder_hidden_dim = config.decoder_hidden_dim
        self.n_decoder_layers = config.n_decoder_layers
        self.n_mechanisms = config.n_mechanisms
        self.latent_dim = config.latent_dim
        self.use_attention = config.use_attention
        self.decode_max_len = config.decode_max_len
        self.tie_weights = config.tie_weights
        self.rnn_type = config.rnn_type
        self.gen_type = config.gen_type
        self.top_k = config.top_k
        self.top_p = config.top_p
        self.temp = config.temp
        self.word_embedding_path = config.word_embedding_path
        self.floor_encoder_type = config.floor_encoder
        # Optional attributes from config
        self.dropout_emb = config.dropout if hasattr(config,
                                                     "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config,
                                                       "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config,
                                                        "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config,
                                                        "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding if hasattr(
            config, "use_pretrained_word_embedding") else False
        # Other attributes
        self.word2id = tokenizer.word2id
        self.id2word = tokenizer.id2word
        self.vocab_size = len(tokenizer.word2id)
        self.pad_token_id = tokenizer.pad_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

        # Input components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.
                use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id),
        )

        # Encoding components
        self.sent_encoder = EncoderRNN(
            input_dim=self.word_embedding_dim,
            hidden_dim=self.sent_encoder_hidden_dim,
            n_layers=self.n_sent_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            embedding=self.word_embedding,
            rnn_type=self.rnn_type,
        )
        self.dial_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=False,
            rnn_type=self.rnn_type,
        )

        # Mechanism components
        self.mechanism_embeddings = nn.Embedding(
            self.n_mechanisms,
            self.latent_dim,
        )
        self.ctx2mech_fc = nn.Linear(self.dial_encoder_hidden_dim,
                                     self.latent_dim)
        self.score_bilinear = nn.Parameter(
            torch.FloatTensor(self.latent_dim, self.latent_dim))
        self.ctx_mech_combine_fc = nn.Linear(
            self.dial_encoder_hidden_dim + self.latent_dim,
            self.dial_encoder_hidden_dim)

        # Decoding components
        self.enc2dec_hidden_fc = nn.Linear(
            self.dial_encoder_hidden_dim,
            self.n_decoder_layers * self.decoder_hidden_dim if self.rnn_type
            == "gru" else self.n_decoder_layers * self.decoder_hidden_dim * 2)
        self.decoder = DecoderRNN(vocab_size=len(self.word2id),
                                  input_dim=self.word_embedding_dim,
                                  hidden_dim=self.decoder_hidden_dim,
                                  feat_dim=self.dial_encoder_hidden_dim,
                                  n_layers=self.n_decoder_layers,
                                  bos_token_id=self.bos_token_id,
                                  eos_token_id=self.eos_token_id,
                                  pad_token_id=self.pad_token_id,
                                  max_len=self.decode_max_len,
                                  dropout_emb=self.dropout_emb,
                                  dropout_input=self.dropout_input,
                                  dropout_hidden=self.dropout_hidden,
                                  dropout_output=self.dropout_output,
                                  embedding=self.word_embedding,
                                  tie_weights=self.tie_weights,
                                  rnn_type=self.rnn_type,
                                  use_attention=self.use_attention,
                                  attn_dim=self.sent_encoder_hidden_dim)

        # Extra components
        # floor encoding
        if self.floor_encoder_type == "abs":
            self.floor_encoder = AbsFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim)
        elif self.floor_encoder_type == "rel":
            self.floor_encoder = RelFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim)
        else:
            self.floor_encoder = None

        # Initialization
        self._init_weights()

    def _init_weights(self):
        init_module_weights(self.mechanism_embeddings)
        init_module_weights(self.ctx2mech_fc)
        init_module_weights(self.score_bilinear)
        init_module_weights(self.ctx_mech_combine_fc)
        init_module_weights(self.enc2dec_hidden_fc)

    def _init_dec_hiddens(self, context):
        batch_size = context.size(0)

        hiddens = self.enc2dec_hidden_fc(context)
        if self.rnn_type == "gru":
            hiddens = hiddens.view(
                batch_size,
                self.n_decoder_layers, self.decoder_hidden_dim).transpose(
                    0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
        elif self.rnn_type == "lstm":
            hiddens = hiddens.view(batch_size, self.n_decoder_layers,
                                   self.decoder_hidden_dim, 2)
            h = hiddens[:, :, :, 0].transpose(
                0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            c = hiddens[:, :, :, 1].transpose(
                0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            hiddens = (h, c)

        return hiddens

    def _encode(self, inputs, input_floors, output_floors):
        batch_size, history_len, max_x_sent_len = inputs.size()

        flat_inputs = inputs.view(batch_size * history_len, max_x_sent_len)
        input_lens = (inputs != self.pad_token_id).sum(-1)
        flat_input_lens = input_lens.view(batch_size * history_len)
        word_encodings, _, sent_encodings = self.sent_encoder(
            flat_inputs, flat_input_lens)
        word_encodings = word_encodings.view(batch_size, history_len,
                                             max_x_sent_len, -1)
        sent_encodings = sent_encodings.view(batch_size, history_len, -1)

        if self.floor_encoder is not None:
            src_floors = input_floors.view(-1)
            tgt_floors = output_floors.unsqueeze(1).repeat(
                1, history_len).view(-1)
            sent_encodings = sent_encodings.view(batch_size * history_len, -1)
            sent_encodings = self.floor_encoder(sent_encodings,
                                                src_floors=src_floors,
                                                tgt_floors=tgt_floors)
            sent_encodings = sent_encodings.view(batch_size, history_len, -1)

        dial_lens = (input_lens > 0).long().sum(
            1)  # equals number of non-padding sents
        _, _, dial_encodings = self.dial_encoder(
            sent_encodings, dial_lens)  # [batch_size, dial_encoder_dim]

        return word_encodings, sent_encodings, dial_encodings

    def _compute_mechanism_probs(self, ctx_encodings):
        ctx_mech = self.ctx2mech_fc(ctx_encodings)  # [batch_size, latent_dim]
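        # Bilinear score of the context against each mechanism embedding m_k:
        #   score_k = (W_ctx h)^T B m_k
        # shapes: [batch, latent] @ [latent, latent] @ [latent, n_mechanisms]
        #   -> [batch_size, n_mechanisms]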
        mech_scores = torch.matmul(
            torch.matmul(ctx_mech, self.score_bilinear),
            self.mechanism_embeddings.weight.T)  # [batch_size, n_mechanisms]
        mech_probs = F.softmax(mech_scores, dim=1)

        return mech_probs

    def _decode(self, inputs, context, attn_ctx=None, attn_mask=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, inputs.size(1), 1)
        ret_dict = self.decoder.forward(batch_size=batch_size,
                                        inputs=inputs,
                                        hiddens=hiddens,
                                        feats=feats,
                                        attn_ctx=attn_ctx,
                                        attn_mask=attn_mask,
                                        mode=DecoderRNN.MODE_TEACHER_FORCE)

        return ret_dict

    def _sample(self, context, attn_ctx=None, attn_mask=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, self.decode_max_len, 1)
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            hiddens=hiddens,
            feats=feats,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask,
            mode=DecoderRNN.MODE_FREE_RUN,
            gen_type=self.gen_type,
            temp=self.temp,
            top_p=self.top_p,
            top_k=self.top_k,
        )

        return ret_dict

    def _get_attn_mask(self, attn_keys):
        attn_mask = (attn_keys != self.pad_token_id)
        return attn_mask

    def load_model(self, model_path):
        """Load pretrained model weights from model_path

        Arguments:
            model_path {str} -- path to pretrained model weights
        """
        pretrained_state_dict = torch.load(
            model_path, map_location=lambda storage, loc: storage)
        self.load_state_dict(pretrained_state_dict)

    def train_step(self, data):
        """One training step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values
                'loss' {FloatTensor []} -- loss to backpropagate
            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'loss' {float} -- batch loss
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()

        batch_size = X.size(0)

        # Forward
        # -- encode
        word_encodings, sent_encodings, dial_encodings = self._encode(
            inputs=X, input_floors=X_floor, output_floors=Y_floor)
        attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
        attn_mask = self._get_attn_mask(X.view(batch_size, -1))
        mech_probs = self._compute_mechanism_probs(
            dial_encodings)  # [batch_size, n_mechanisms]
        mech_embed_inputs = torch.arange(
            self.n_mechanisms, device=DEVICE)  # [n_mechanisms]
        repeated_mech_embed_inputs = mech_embed_inputs.unsqueeze(0).repeat(
            batch_size, 1).view(-1)  # [batch_size*n_mechanisms]
        repeated_mech_embeds = self.mechanism_embeddings(
            repeated_mech_embed_inputs
        )  # [batch_size*n_mechanisms, latent_dim]
        # -- repeat for each mechanism
        repeated_Y_in = Y_in.unsqueeze(1).repeat(
            1, self.n_mechanisms, 1)  # [batch_size, n_mechanisms, len]
        repeated_Y_in = repeated_Y_in.view(batch_size * self.n_mechanisms, -1)
        repeated_Y_out = Y_out.unsqueeze(1).repeat(
            1, self.n_mechanisms, 1)  # [batch_size, n_mechanisms, len]
        repeated_Y_out = repeated_Y_out.view(batch_size * self.n_mechanisms,
                                             -1)
        dial_encodings = dial_encodings.unsqueeze(1).repeat(
            1, self.n_mechanisms, 1)  # [batch_size, n_mechanisms, hidden_dim]
        dial_encodings = dial_encodings.view(batch_size * self.n_mechanisms,
                                             self.dial_encoder_hidden_dim)
        attn_ctx = attn_ctx.unsqueeze(1).repeat(1, self.n_mechanisms, 1, 1)
        attn_ctx = attn_ctx.view(batch_size * self.n_mechanisms, -1,
                                 attn_ctx.size(-1))
        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_mechanisms, 1)
        attn_mask = attn_mask.view(batch_size * self.n_mechanisms, -1)
        # -- decode
        dec_ctx = self.ctx_mech_combine_fc(
            torch.cat([dial_encodings, repeated_mech_embeds], dim=1))
        decoder_ret_dict = self._decode(inputs=repeated_Y_in,
                                        context=dec_ctx,
                                        attn_ctx=attn_ctx,
                                        attn_mask=attn_mask)

        # Calculate loss
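        # Marginal likelihood over mechanisms:
        #   log p(y|x) = log sum_k p(k|x) * p(y|x,k),
        # evaluated stably as a logsumexp over the n_mechanisms axis
        # (see the standalone demo at the end of this listing).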
        word_neglogll = F.cross_entropy(
            decoder_ret_dict["logits"].view(-1, self.vocab_size),
            repeated_Y_out.view(-1),
            ignore_index=self.decoder.pad_token_id,
            reduction="none").view(batch_size, self.n_mechanisms, -1)
        sent_logll = -word_neglogll.sum(2)  # log p(y|x,k)
        mech_logll = (mech_probs + 1e-10).log()  # log p(k|x)
        target_logll = torch.logsumexp(sent_logll + mech_logll, dim=1)
        loss = -target_logll.mean()

        with torch.no_grad():
            ppl = F.cross_entropy(decoder_ret_dict["logits"].view(
                -1, self.vocab_size),
                                  repeated_Y_out.view(-1),
                                  ignore_index=self.decoder.pad_token_id,
                                  reduction="mean").exp()

        # return dicts
        ret_data = {"loss": loss}
        ret_stat = {
            "ppl": ppl.item(),
            "loss": loss.item(),
            "mech_prob_std": mech_probs.std(1).mean().item(),
            "mech_prob_max": mech_probs.max(1)[0].mean().item()
        }

        return ret_data, ret_stat

    def evaluate_step(self, data):
        """One evaluation step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values

            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'monitor' {float} -- a monitor number for learning rate scheduling
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()

        batch_size = X.size(0)

        with torch.no_grad():
            # Forward
            word_encodings, sent_encodings, dial_encodings = self._encode(
                inputs=X, input_floors=X_floor, output_floors=Y_floor)
            attn_ctx = word_encodings.view(batch_size, -1,
                                           word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))
            mech_probs = self._compute_mechanism_probs(
                dial_encodings)  # [batch_size, n_mechanisms]
            mech_embed_inputs = mech_probs.argmax(1)  # [batch_size]
            mech_embeds = self.mechanism_embeddings(mech_embed_inputs)
            dec_ctx = self.ctx_mech_combine_fc(
                torch.cat([dial_encodings, mech_embeds], dim=1))
            decoder_ret_dict = self._decode(inputs=Y_in,
                                            context=dec_ctx,
                                            attn_ctx=attn_ctx,
                                            attn_mask=attn_mask)

            # Loss
            word_loss = F.cross_entropy(decoder_ret_dict["logits"].view(
                -1, self.vocab_size),
                                        Y_out.view(-1),
                                        ignore_index=self.decoder.pad_token_id,
                                        reduction="mean")
            ppl = torch.exp(word_loss)

        # return dicts
        ret_data = {}
        ret_stat = {"ppl": ppl.item(), "monitor": ppl.item()}

        return ret_data, ret_stat

    def test_step(self, data, sample_from="dist"):
        """One test step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

            sample_from {str} -- "dist": sample mechanism from computed probabilities
                                "random": sample mechanisms uniformly
        Returns:
            dict of data -- returned keys and values
                'symbols' {LongTensor [batch_size, max_decode_len]} -- token ids of response hypothesis
            dict of statistics -- returned keys and values

        """
        X = data["X"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        batch_size = X.size(0)

        with torch.no_grad():
            # Forward
            word_encodings, sent_encodings, dial_encodings = self._encode(
                inputs=X, input_floors=X_floor, output_floors=Y_floor)
            attn_ctx = word_encodings.view(batch_size, -1,
                                           word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))
            if sample_from == "dist":
                mech_probs = self._compute_mechanism_probs(
                    dial_encodings)  # [batch_size, n_mechanisms]
                mech_dist = torch.distributions.Categorical(mech_probs)
                mech_embed_inputs = mech_dist.sample()  # [batch_size]
            elif sample_from == "random":
                mech_embed_inputs = [
                    random.randint(0, self.n_mechanisms - 1)
                    for _ in range(batch_size)
                ]
                mech_embed_inputs = torch.LongTensor(mech_embed_inputs).to(
                    DEVICE)  # [batch_size]
            mech_embeds = self.mechanism_embeddings(mech_embed_inputs)
            dec_ctx = self.ctx_mech_combine_fc(
                torch.cat([dial_encodings, mech_embeds], dim=1))
            decoder_ret_dict = self._sample(context=dec_ctx,
                                            attn_ctx=attn_ctx,
                                            attn_mask=attn_mask)

        ret_data = {"symbols": decoder_ret_dict["symbols"]}
        ret_stat = {}

        return ret_data, ret_stat
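
# Standalone sketch (illustrative; not part of the original listing) of the
# stable mixture marginalization used in Mechanism_HRED.train_step:
#   log p(y|x) = log sum_k p(k|x) * p(y|x,k)
# is evaluated as a logsumexp over per-mechanism log-likelihoods plus log
# mixture weights, avoiding underflow in the naive product-sum.
def _mixture_loglik_demo():
    batch_size, n_mechanisms = 3, 4
    sent_logll = -torch.rand(batch_size, n_mechanisms) * 5  # log p(y|x,k)
    mech_probs = torch.softmax(torch.randn(batch_size, n_mechanisms), dim=1)
    stable = torch.logsumexp(sent_logll + (mech_probs + 1e-10).log(), dim=1)
    naive = (sent_logll.exp() * mech_probs).sum(1).log()
    assert torch.allclose(stable, naive, atol=1e-4)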