class HRED(nn.Module):
    def __init__(self, config, tokenizer):
        super(HRED, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.attr_embedding_dim = config.attr_embedding_dim
        self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
        self.n_sent_encoder_layers = config.n_sent_encoder_layers
        self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
        self.n_dial_encoder_layers = config.n_dial_encoder_layers
        self.decoder_hidden_dim = config.decoder_hidden_dim
        self.n_decoder_layers = config.n_decoder_layers
        self.use_attention = config.use_attention
        self.decode_max_len = config.decode_max_len
        self.tie_weights = config.tie_weights
        self.rnn_type = config.rnn_type
        self.gen_type = config.gen_type
        self.top_k = config.top_k
        self.top_p = config.top_p
        self.temp = config.temp
        self.word_embedding_path = config.word_embedding_path
        self.floor_encoder_type = config.floor_encoder
        # Optional attributes from config
        self.dropout_emb = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding \
            if hasattr(config, "use_pretrained_word_embedding") else False
        # Other attributes
        self.word2id = tokenizer.word2id
        self.id2word = tokenizer.id2word
        self.vocab_size = len(tokenizer.word2id)
        self.pad_token_id = tokenizer.pad_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

        # Input components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id
            ),
        )

        # Encoding components
        self.sent_encoder = EncoderRNN(
            input_dim=self.word_embedding_dim,
            hidden_dim=self.sent_encoder_hidden_dim,
            n_layers=self.n_sent_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            embedding=self.word_embedding,
            rnn_type=self.rnn_type,
        )
        self.dial_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=False,
            rnn_type=self.rnn_type,
        )

        # Decoding components
        self.enc2dec_hidden_fc = nn.Linear(
            self.dial_encoder_hidden_dim,
            self.n_decoder_layers * self.decoder_hidden_dim if self.rnn_type == "gru"
            else self.n_decoder_layers * self.decoder_hidden_dim * 2
        )
        self.decoder = DecoderRNN(
            vocab_size=len(self.word2id),
            input_dim=self.word_embedding_dim,
            hidden_dim=self.decoder_hidden_dim,
            feat_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_decoder_layers,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            max_len=self.decode_max_len,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            embedding=self.word_embedding,
            tie_weights=self.tie_weights,
            rnn_type=self.rnn_type,
            use_attention=self.use_attention,
            attn_dim=self.sent_encoder_hidden_dim
        )

        # Extra components
        # floor encoding
        if self.floor_encoder_type == "abs":
            self.floor_encoder = AbsFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim
            )
        elif self.floor_encoder_type == "rel":
            self.floor_encoder = RelFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim
            )
        else:
            self.floor_encoder = None

        # Initialization
        self._init_weights()

    def _init_weights(self):
        init_module_weights(self.enc2dec_hidden_fc)

    def _init_dec_hiddens(self, context):
        # Map the context vector to the decoder's initial hidden states
        batch_size = context.size(0)
        hiddens = self.enc2dec_hidden_fc(context)
        if self.rnn_type == "gru":
            hiddens = hiddens.view(
                batch_size,
                self.n_decoder_layers,
                self.decoder_hidden_dim
            ).transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
        elif self.rnn_type == "lstm":
            hiddens = hiddens.view(
                batch_size,
                self.n_decoder_layers,
                self.decoder_hidden_dim,
                2
            )
            h = hiddens[:, :, :, 0].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            c = hiddens[:, :, :, 1].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            hiddens = (h, c)
        return hiddens

    def _encode(self, inputs, input_floors, output_floors):
        batch_size, history_len, max_x_sent_len = inputs.size()

        # Sentence-level encoding
        flat_inputs = inputs.view(batch_size * history_len, max_x_sent_len)
        input_lens = (inputs != self.pad_token_id).sum(-1)
        flat_input_lens = input_lens.view(batch_size * history_len)
        word_encodings, _, sent_encodings = self.sent_encoder(flat_inputs, flat_input_lens)
        word_encodings = word_encodings.view(batch_size, history_len, max_x_sent_len, -1)
        sent_encodings = sent_encodings.view(batch_size, history_len, -1)

        # Inject floor (speaker-turn) information into sentence encodings
        if self.floor_encoder is not None:
            src_floors = input_floors.view(-1)
            tgt_floors = output_floors.unsqueeze(1).repeat(1, history_len).view(-1)
            sent_encodings = sent_encodings.view(batch_size * history_len, -1)
            sent_encodings = self.floor_encoder(
                sent_encodings,
                src_floors=src_floors,
                tgt_floors=tgt_floors
            )
            sent_encodings = sent_encodings.view(batch_size, history_len, -1)

        # Dialog-level encoding
        dial_lens = (input_lens > 0).long().sum(1)  # equals number of non-padding sents
        _, _, dial_encodings = self.dial_encoder(sent_encodings, dial_lens)  # [batch_size, dial_encoder_dim]

        return word_encodings, sent_encodings, dial_encodings

    def _decode(self, inputs, context, attn_ctx=None, attn_mask=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, inputs.size(1), 1)
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            inputs=inputs,
            hiddens=hiddens,
            feats=feats,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask,
            mode=DecoderRNN.MODE_TEACHER_FORCE
        )
        return ret_dict

    def _sample(self, context, attn_ctx=None, attn_mask=None, mmi_args=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, self.decode_max_len, 1)
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            hiddens=hiddens,
            feats=feats,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask,
            mode=DecoderRNN.MODE_FREE_RUN,
            gen_type=self.gen_type,
            temp=self.temp,
            top_p=self.top_p,
            top_k=self.top_k,
            mmi_args=mmi_args,
        )
        return ret_dict

    def _get_attn_mask(self, attn_keys):
        attn_mask = (attn_keys != self.pad_token_id)
        return attn_mask

    def load_model(self, model_path):
        """Load pretrained model weights from model_path

        Arguments:
            model_path {str} -- path to pretrained model weights
        """
        pretrained_state_dict = torch.load(
            model_path,
            map_location=lambda storage, loc: storage
        )
        self.load_state_dict(pretrained_state_dict)

    def train_step(self, data):
        """One training step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values
                'loss' {FloatTensor []} -- loss to backward
            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'loss' {float} -- batch loss
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()

        batch_size = X.size(0)
        max_y_len = Y_out.size(1)

        # Forward
        word_encodings, sent_encodings, dial_encodings = self._encode(
            inputs=X,
            input_floors=X_floor,
            output_floors=Y_floor
        )
        attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
        attn_mask = self._get_attn_mask(X.view(batch_size, -1))
        decoder_ret_dict = self._decode(
            inputs=Y_in,
            context=dial_encodings,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask
        )

        # Calculate loss
        loss = 0
        logits = decoder_ret_dict["logits"]
        word_losses = F.cross_entropy(
            logits.view(-1, self.vocab_size),
            Y_out.view(-1),
            ignore_index=self.pad_token_id,
            reduction="none"
        ).view(batch_size, max_y_len)
        sent_loss = word_losses.sum(1).mean(0)
        loss += sent_loss
        with torch.no_grad():
            ppl = F.cross_entropy(
                logits.view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.pad_token_id,
                reduction="mean"
            ).exp()

        # return dicts
        ret_data = {"loss": loss}
        ret_stat = {"ppl": ppl.item(), "loss": loss.item()}

        return ret_data, ret_stat

    def evaluate_step(self, data):
        """One evaluation step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values
            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'monitor' {float} -- a monitor number for learning rate scheduling
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()

        batch_size = X.size(0)

        with torch.no_grad():
            # Forward
            word_encodings, sent_encodings, dial_encodings = self._encode(
                inputs=X,
                input_floors=X_floor,
                output_floors=Y_floor
            )
            attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))
            decoder_ret_dict = self._decode(
                inputs=Y_in,
                context=dial_encodings,
                attn_ctx=attn_ctx,
                attn_mask=attn_mask
            )

            # Loss
            word_loss = F.cross_entropy(
                decoder_ret_dict["logits"].view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="mean"
            )
            ppl = torch.exp(word_loss)

        # return dicts
        ret_data = {}
        ret_stat = {"ppl": ppl.item(), "monitor": ppl.item()}

        return ret_data, ret_stat

    def test_step(self, data, mmi_args=None):
        """One test step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values
                'symbols' {LongTensor [batch_size, max_decode_len]} -- token ids of response hypothesis
            dict of statistics -- returned keys and values
        """
        X = data["X"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]

        batch_size = X.size(0)

        with torch.no_grad():
            # Forward
            word_encodings, sent_encodings, dial_encodings = self._encode(
                inputs=X,
                input_floors=X_floor,
                output_floors=Y_floor
            )
            attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))
            decoder_ret_dict = self._sample(
                context=dial_encodings,
                attn_ctx=attn_ctx,
                attn_mask=attn_mask,
                mmi_args=mmi_args
            )

        ret_data = {"symbols": decoder_ret_dict["symbols"]}
        ret_stat = {}

        return ret_data, ret_stat
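
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It shows one
# way to drive HRED.train_step from a training loop. The `data_source` object
# and its epoch_init()/next() interface mirror the one consumed by
# VHRED.compute_active_units below; `config` and `tokenizer` stand in for
# whatever this repo's config/tokenizer objects provide. This is an assumed
# harness, not the repo's actual training script.
# ---------------------------------------------------------------------------
def _example_hred_training_loop(config, tokenizer, data_source, batch_size=32):
    model = HRED(config, tokenizer).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    model.train()
    data_source.epoch_init()
    while True:
        # batch_data: {"X", "X_floor", "Y", "Y_floor"} as documented in train_step
        batch_data = data_source.next(batch_size)
        if batch_data is None:
            break
        ret_data, ret_stat = model.train_step(batch_data)
        optimizer.zero_grad()
        ret_data["loss"].backward()  # token NLL summed per sentence, then batch-averaged
        optimizer.step()
        print("loss=%.4f ppl=%.2f" % (ret_stat["loss"], ret_stat["ppl"]))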
class RNNLM(nn.Module):
    def __init__(self, config, tokenizer):
        super(RNNLM, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.decoder_hidden_dim = config.decoder_hidden_dim
        self.n_decoder_layers = config.n_decoder_layers
        self.decode_max_len = config.decode_max_len
        self.tie_weights = config.tie_weights
        self.rnn_type = config.rnn_type
        # Optional attributes from config
        self.gen_type = config.gen_type if hasattr(config, "gen_type") else "greedy"
        self.top_k = config.top_k if hasattr(config, "top_k") else 0
        self.top_p = config.top_p if hasattr(config, "top_p") else 0.0
        self.temp = config.temp if hasattr(config, "temp") else 1.0
        self.dropout_emb = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding \
            if hasattr(config, "use_pretrained_word_embedding") else False
        self.word_embedding_path = config.word_embedding_path \
            if hasattr(config, "word_embedding_path") else None
        # Other attributes
        self.tokenizer = tokenizer
        self.word2id = tokenizer.word2id
        self.id2word = tokenizer.id2word
        self.vocab_size = len(tokenizer.word2id)
        self.pad_token_id = tokenizer.pad_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

        # Components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id
            ),
        )
        self.decoder = DecoderRNN(
            vocab_size=len(self.word2id),
            input_dim=self.word_embedding_dim,
            hidden_dim=self.decoder_hidden_dim,
            n_layers=self.n_decoder_layers,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            max_len=self.decode_max_len,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            embedding=self.word_embedding,
            tie_weights=self.tie_weights,
            rnn_type=self.rnn_type,
        )

        # Initialization
        self._init_weights()

    def _init_weights(self):
        pass

    def _random_hiddens(self, batch_size):
        # Draw random initial decoder states uniformly from [-1, 1]
        if self.rnn_type == "gru":
            hiddens = torch.zeros(
                batch_size,
                self.n_decoder_layers,
                self.decoder_hidden_dim
            ).to(DEVICE).transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            nn.init.uniform_(hiddens, -1.0, 1.0)
        elif self.rnn_type == "lstm":
            hiddens = torch.zeros(
                batch_size,
                self.n_decoder_layers,
                self.decoder_hidden_dim,
                2
            ).to(DEVICE)
            nn.init.uniform_(hiddens, -1.0, 1.0)
            h = hiddens[:, :, :, 0].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            c = hiddens[:, :, :, 1].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            hiddens = (h, c)
        return hiddens

    def _decode(self, inputs):
        batch_size = inputs.size(0)
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            inputs=inputs,
            mode=DecoderRNN.MODE_TEACHER_FORCE
        )
        return ret_dict

    def _sample(self, batch_size, hiddens=None):
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            hiddens=hiddens,
            mode=DecoderRNN.MODE_FREE_RUN,
            gen_type=self.gen_type,
            temp=self.temp,
            top_p=self.top_p,
            top_k=self.top_k,
        )
        return ret_dict

    def load_model(self, model_path):
        """Load pretrained model weights from model_path

        Arguments:
            model_path {str} -- path to pretrained model weights
        """
        pretrained_state_dict = torch.load(
            model_path,
            map_location=lambda storage, loc: storage
        )
        self.load_state_dict(pretrained_state_dict)

    def train_step(self, data):
        """One training step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, max_len]} -- token ids of sentences

        Returns:
            dict of data -- returned keys and values
                'loss' {FloatTensor []} -- loss to backward
            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'loss' {float} -- batch loss
        """
        X = data["X"]
        X_in = X[:, :-1].contiguous()
        X_out = X[:, 1:].contiguous()

        batch_size = X.size(0)

        # Forward
        decoder_ret_dict = self._decode(inputs=X_in)

        # Calculate loss
        word_loss = F.cross_entropy(
            decoder_ret_dict["logits"].view(-1, self.vocab_size),
            X_out.view(-1),
            ignore_index=self.decoder.pad_token_id,
            reduction="mean"
        )
        ppl = torch.exp(word_loss)
        loss = word_loss

        # return dicts
        ret_data = {
            "loss": loss
        }
        ret_stat = {
            "ppl": ppl.item(),
            "loss": loss.item()
        }

        return ret_data, ret_stat

    def evaluate_step(self, data):
        """One evaluation step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, max_len]} -- token ids of sentences

        Returns:
            dict of data -- returned keys and values
            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'monitor' {float} -- a monitor number for learning rate scheduling
        """
        X = data["X"]
        X_in = X[:, :-1].contiguous()
        X_out = X[:, 1:].contiguous()

        batch_size = X.size(0)

        with torch.no_grad():
            decoder_ret_dict = self._decode(inputs=X_in)

            # Loss
            word_loss = F.cross_entropy(
                decoder_ret_dict["logits"].view(-1, self.vocab_size),
                X_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="mean"
            )
            ppl = torch.exp(word_loss)

        # return dicts
        ret_data = {}
        ret_stat = {
            "ppl": ppl.item(),
            "monitor": ppl.item()
        }

        return ret_data, ret_stat

    def sample_step(self, batch_size):
        """One test step

        Arguments:
            batch_size {int}

        Returns:
            dict of data -- returned keys and values
                'symbols' {LongTensor [batch_size, max_decode_len]} -- token ids of response hypothesis
            dict of statistics -- returned keys and values
        """
        with torch.no_grad():
            decoder_ret_dict = self._sample(batch_size)

        ret_data = {
            "symbols": decoder_ret_dict["symbols"]
        }
        ret_stat = {}

        return ret_data, ret_stat

    def compute_prob(self, sents):
        """Compute P(sents)

        Arguments:
            sents {List [str]} -- sentences in string form
        """
        batch_tokens = [self.tokenizer.convert_string_to_tokens(sent) for sent in sents]
        batch_token_ids = [self.tokenizer.convert_tokens_to_ids(tokens) for tokens in batch_tokens]
        batch_token_ids = self.tokenizer.convert_batch_ids_to_tensor(batch_token_ids).to(DEVICE)  # [batch_size, len]
        X_in = batch_token_ids[:, :-1]
        X_out = batch_token_ids[:, 1:]

        with torch.no_grad():
            word_ll = []
            sent_ll = []
            sent_probs = []
            # Score in mini-batches of 50 sentences to bound memory use
            batch_size = 50
            n_batches = math.ceil(X_in.size(0) / batch_size)
            for batch_idx in range(n_batches):
                begin = batch_idx * batch_size
                end = min(begin + batch_size, X_in.size(0))
                batch_X_in = X_in[begin:end]
                batch_X_out = X_out[begin:end]

                decoder_ret_dict = self._decode(inputs=batch_X_in)
                logits = decoder_ret_dict["logits"]  # [batch_size, len-1, vocab_size]
                batch_word_ll = F.log_softmax(logits, dim=2)
                batch_gathered_word_ll = batch_word_ll.gather(2, batch_X_out.unsqueeze(2)).squeeze(2)  # [batch_size, len-1]
                batch_sent_ll = batch_gathered_word_ll.sum(1)  # [batch_size]
                batch_sent_probs = batch_sent_ll.exp()

                word_ll.append(batch_gathered_word_ll)
                sent_ll.append(batch_sent_ll)
                sent_probs.append(batch_sent_probs)

            word_ll = torch.cat(word_ll, dim=0)
            sent_ll = torch.cat(sent_ll, dim=0)
            sent_probs = torch.cat(sent_probs, dim=0)

        ret_data = {
            "word_loglikelihood": word_ll,
            "sent_loglikelihood": sent_ll,
            "sent_likelihood": sent_probs
        }

        return ret_data
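
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): scoring raw
# sentences with RNNLM.compute_prob. The example sentences are made up; the
# sketch relies only on the tokenizer interface that compute_prob itself uses.
# ---------------------------------------------------------------------------
def _example_rnnlm_scoring(config, tokenizer):
    lm = RNNLM(config, tokenizer).to(DEVICE)
    lm.eval()
    ret = lm.compute_prob(["how are you ?", "you how are ?"])
    # Returned tensors are aligned with the input sentences along dim 0
    for sent_ll, sent_prob in zip(ret["sent_loglikelihood"], ret["sent_likelihood"]):
        print("log P(sent) = %.4f, P(sent) = %.6g" % (sent_ll.item(), sent_prob.item()))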
class VHCR(nn.Module):
    def __init__(self, config, tokenizer):
        super(VHCR, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.attr_embedding_dim = config.attr_embedding_dim
        self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
        self.n_sent_encoder_layers = config.n_sent_encoder_layers
        self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
        self.n_dial_encoder_layers = config.n_dial_encoder_layers
        self.latent_variable_dim = config.latent_dim
        self.decoder_hidden_dim = config.decoder_hidden_dim
        self.n_decoder_layers = config.n_decoder_layers
        self.use_attention = config.use_attention
        self.decode_max_len = config.decode_max_len
        self.tie_weights = config.tie_weights
        self.rnn_type = config.rnn_type
        self.gen_type = config.gen_type
        self.top_k = config.top_k
        self.top_p = config.top_p
        self.temp = config.temp
        self.word_embedding_path = config.word_embedding_path
        self.floor_encoder_type = config.floor_encoder
        # Optional attributes from config
        self.dropout_emb = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding \
            if hasattr(config, "use_pretrained_word_embedding") else False
        self.n_step_annealing = config.n_step_annealing \
            if hasattr(config, "n_step_annealing") else 0
        # Other attributes
        self.vocab_size = len(tokenizer.word2id)
        self.word2id = tokenizer.word2id
        self.id2word = tokenizer.id2word
        self.pad_token_id = tokenizer.pad_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.dropout_sent = 0.25  # probability of masking a context sentence during training

        # Input components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id
            ),
        )

        # Encoding components
        self.sent_encoder = EncoderRNN(
            input_dim=self.word_embedding_dim,
            hidden_dim=self.sent_encoder_hidden_dim,
            n_layers=self.n_sent_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            embedding=self.word_embedding,
            rnn_type=self.rnn_type,
        )
        self.dial_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim + self.latent_variable_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=False,
            rnn_type=self.rnn_type,
        )
        self.dial_infer_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            rnn_type=self.rnn_type,
        )

        # Variational components
        self.dial_post_net = GaussianVariation(
            input_dim=self.dial_encoder_hidden_dim,
            z_dim=self.latent_variable_dim
        )
        self.sent_prior_net = GaussianVariation(
            input_dim=self.dial_encoder_hidden_dim + self.latent_variable_dim,
            z_dim=self.latent_variable_dim
        )
        self.sent_post_net = GaussianVariation(
            input_dim=self.sent_encoder_hidden_dim + self.dial_encoder_hidden_dim + self.latent_variable_dim,
            z_dim=self.latent_variable_dim
        )
        # Learned stand-in vector for masked context sentences;
        # registered as a parameter so it moves to DEVICE with the module
        self.unk_sent_vec = nn.Parameter(torch.randn(self.sent_encoder_hidden_dim))

        # Decoding components
        self.ctx_fc = nn.Linear(
            2 * self.latent_variable_dim + self.dial_encoder_hidden_dim,
            self.dial_encoder_hidden_dim
        )
        self.enc2dec_hidden_fc = nn.Linear(
            self.dial_encoder_hidden_dim,
            self.n_decoder_layers * self.decoder_hidden_dim if self.rnn_type == "gru"
            else self.n_decoder_layers * self.decoder_hidden_dim * 2
        )
        self.decoder = DecoderRNN(
            vocab_size=len(self.word2id),
            input_dim=self.word_embedding_dim,
            hidden_dim=self.decoder_hidden_dim,
            feat_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_decoder_layers,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            max_len=self.decode_max_len,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            embedding=self.word_embedding,
            tie_weights=self.tie_weights,
            rnn_type=self.rnn_type,
            use_attention=self.use_attention,
            attn_dim=self.sent_encoder_hidden_dim
        )

        # Extra components
        # Floor encoding
        if self.floor_encoder_type == "abs":
            self.floor_encoder = AbsFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim
            )
        elif self.floor_encoder_type == "rel":
            self.floor_encoder = RelFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim
            )
        else:
            self.floor_encoder = None

        # Hidden initialization
        self.dial_z2dial_enc_hidden_fc = nn.Linear(
            self.latent_variable_dim,
            self.n_dial_encoder_layers * self.dial_encoder_hidden_dim if self.rnn_type == "gru"
            else self.n_dial_encoder_layers * self.dial_encoder_hidden_dim * 2
        )

        # Initialization
        self._init_weights()

    def _init_weights(self):
        init_module_weights(self.enc2dec_hidden_fc)
        init_module_weights(self.ctx_fc)

    def _init_dec_hiddens(self, context):
        batch_size = context.size(0)
        hiddens = self.enc2dec_hidden_fc(context)
        if self.rnn_type == "gru":
            hiddens = hiddens.view(
                batch_size,
                self.n_decoder_layers,
                self.decoder_hidden_dim
            ).transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
        elif self.rnn_type == "lstm":
            hiddens = hiddens.view(
                batch_size,
                self.n_decoder_layers,
                self.decoder_hidden_dim,
                2
            )
            h = hiddens[:, :, :, 0].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            c = hiddens[:, :, :, 1].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            hiddens = (h, c)
        return hiddens

    def _get_ctx_sent_encodings(self, inputs, input_floors, output_floors):
        batch_size, history_len, max_x_sent_len = inputs.size()

        flat_inputs = inputs.view(batch_size * history_len, max_x_sent_len)
        input_lens = (inputs != self.pad_token_id).sum(-1)
        flat_input_lens = input_lens.view(batch_size * history_len)
        word_encodings, _, sent_encodings = self.sent_encoder(flat_inputs, flat_input_lens)
        word_encodings = word_encodings.view(batch_size, history_len, max_x_sent_len, -1)
        sent_encodings = sent_encodings.view(batch_size, history_len, -1)

        if self.floor_encoder is not None:
            src_floors = input_floors.view(-1)
            tgt_floors = output_floors.unsqueeze(1).repeat(1, history_len).view(-1)
            sent_encodings = sent_encodings.view(batch_size * history_len, -1)
            sent_encodings = self.floor_encoder(
                sent_encodings,
                src_floors=src_floors,
                tgt_floors=tgt_floors
            )
            sent_encodings = sent_encodings.view(batch_size, history_len, -1)

        # Sentence-level dropout: replace randomly chosen context sentences
        # with the learned unknown-sentence vector during training
        if self.training and self.dropout_sent > 0.0:
            indices = np.where(np.random.rand(history_len) < self.dropout_sent)[0]
            if len(indices) > 0:
                sent_encodings[:, indices, :] = self.unk_sent_vec

        return word_encodings, sent_encodings

    def _get_reply_sent_encodings(self, outputs):
        output_lens = (outputs != self.pad_token_id).sum(-1)
        word_encodings, _, sent_encodings = self.sent_encoder(outputs, output_lens)
        return sent_encodings

    def _get_dial_encodings(self, ctx_dial_lens, ctx_sent_encodings, z_dial):
        batch_size, history_len, _ = ctx_sent_encodings.size()

        # Init hidden states of dialog encoder from z_dial
        hiddens = self.dial_z2dial_enc_hidden_fc(z_dial)
        if self.rnn_type == "gru":
            hiddens = hiddens.view(
                batch_size,
                self.n_dial_encoder_layers,
                self.dial_encoder_hidden_dim
            ).transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
        elif self.rnn_type == "lstm":
            hiddens = hiddens.view(
                batch_size,
                self.n_dial_encoder_layers,
                self.dial_encoder_hidden_dim,
                2
            )
            h = hiddens[:, :, :, 0].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            c = hiddens[:, :, :, 1].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            hiddens = (h, c)

        # Inputs to dialog encoder
        z_dial = z_dial.unsqueeze(1).repeat(1, history_len, 1)
        dial_encoder_inputs = torch.cat([ctx_sent_encodings, z_dial], dim=2)
        dialog_lens = ctx_dial_lens
        _, _, dialog_encodings = self.dial_encoder(
            dial_encoder_inputs,
            dialog_lens,
            hiddens
        )  # [batch_size, dialog_encoder_dim]

        return dialog_encodings

    def _get_full_dial_encodings(self, ctx_dial_lens, ctx_sent_encodings, reply_sent_encodings):
        # Build [ctx sents, reply sent, padding sents] sequences for the inference encoder
        batch_size = ctx_sent_encodings.size(0)
        history_len = ctx_sent_encodings.size(1)

        full_sent_encodings = []
        for batch_idx in range(batch_size):
            encodings = []
            ctx_len = ctx_dial_lens[batch_idx].item()
            # part 1 - ctx sent encodings
            for encoding in ctx_sent_encodings[batch_idx][:ctx_len]:
                encodings.append(encoding)
            # part 2 - reply encoding
            encodings.append(reply_sent_encodings[batch_idx])
            # part 3 - padding encodings
            for encoding in ctx_sent_encodings[batch_idx][ctx_len:]:
                encodings.append(encoding)
            encodings = torch.stack(encodings, dim=0)
            full_sent_encodings.append(encodings)
        full_sent_encodings = torch.stack(full_sent_encodings, dim=0)

        full_dialog_lens = ctx_dial_lens + 1  # equals number of non-padding sents
        _, _, full_dialog_encodings = self.dial_infer_encoder(
            full_sent_encodings,
            full_dialog_lens
        )  # [batch_size, dialog_encoder_dim]

        return full_dialog_encodings

    def _get_dial_post(self, full_dialog_encodings):
        z, mu, var = self.dial_post_net(full_dialog_encodings)
        return z, mu, var

    def _get_dial_prior(self, batch_size):
        # Standard normal prior over the dialog-level latent variable
        mu = torch.FloatTensor([0.0]).to(DEVICE)
        var = torch.FloatTensor([1.0]).to(DEVICE)
        z = torch.randn([batch_size, self.latent_variable_dim]).to(DEVICE)
        return z, mu, var

    def _get_sent_post(self, reply_sent_encodings, dial_encodings, z_dial):
        sent_post_net_inputs = torch.cat([reply_sent_encodings, dial_encodings, z_dial], dim=1)
        z, mu, var = self.sent_post_net(sent_post_net_inputs)
        return z, mu, var

    def _get_sent_prior(self, dial_encodings, z_dial):
        sent_prior_net_inputs = torch.cat([dial_encodings, z_dial], dim=1)
        z, mu, var = self.sent_prior_net(sent_prior_net_inputs)
        return z, mu, var

    def _decode(self, inputs, context, attn_ctx=None, attn_mask=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, inputs.size(1), 1)
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            inputs=inputs,
            hiddens=hiddens,
            feats=feats,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask,
            mode=DecoderRNN.MODE_TEACHER_FORCE
        )
        return ret_dict

    def _sample(self, context, attn_ctx=None, attn_mask=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, self.decode_max_len, 1)
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            hiddens=hiddens,
            feats=feats,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask,
            mode=DecoderRNN.MODE_FREE_RUN,
            gen_type=self.gen_type,
            temp=self.temp,
            top_p=self.top_p,
            top_k=self.top_k,
        )
        return ret_dict

    def _get_attn_mask(self, attn_keys):
        attn_mask = (attn_keys != self.pad_token_id)
        return attn_mask

    def _annealing_coef_term(self, step):
        # KLD annealing coefficient rises linearly from 0 to 1 over n_step_annealing steps
        if self.n_step_annealing <= 0:
            return 1.0
        return min(1.0, 1.0 * step / self.n_step_annealing)

    def load_model(self, model_path):
        """Load pretrained model weights from model_path

        Arguments:
            model_path {str} -- path to pretrained model weights
        """
        pretrained_state_dict = torch.load(
            model_path,
            map_location=lambda storage, loc: storage
        )
        self.load_state_dict(pretrained_state_dict)

    def train_step(self, data, step):
        """One training step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence
            step {int} -- the n-th optimization step

        Returns:
            dict of data -- returned keys and values
                'loss' {FloatTensor []} -- loss to backward
            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'sent_kld' {float} -- sentence KLD
                'dial_kld' {float} -- dialog KLD
                'kld_term' {float} -- KLD annealing coefficient
                'loss' {float} -- batch loss
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()

        batch_size = X.size(0)
        max_y_len = Y_out.size(1)
        ctx_dial_lens = ((X != self.pad_token_id).sum(-1) > 0).sum(-1)

        # Forward
        # Encode sentences
        word_encodings, ctx_sent_encodings = self._get_ctx_sent_encodings(
            inputs=X,
            input_floors=X_floor,
            output_floors=Y_floor
        )
        reply_sent_encodings = self._get_reply_sent_encodings(
            outputs=Y,
        )
        # Encode full dialog for posterior dialog z
        full_dial_encodings = self._get_full_dial_encodings(
            ctx_dial_lens=ctx_dial_lens,
            ctx_sent_encodings=ctx_sent_encodings,
            reply_sent_encodings=reply_sent_encodings
        )
        # Get dial z
        z_dial_post, mu_dial_post, var_dial_post = self._get_dial_post(full_dial_encodings)
        # Encode dialog
        dial_encodings = self._get_dial_encodings(
            ctx_dial_lens=ctx_dial_lens,
            ctx_sent_encodings=ctx_sent_encodings,
            z_dial=z_dial_post
        )
        # Get sent z
        z_sent_post, mu_sent_post, var_sent_post = self._get_sent_post(
            reply_sent_encodings=reply_sent_encodings,
            dial_encodings=dial_encodings,
            z_dial=z_dial_post
        )
        z_sent_prior, mu_sent_prior, var_sent_prior = self._get_sent_prior(
            dial_encodings=dial_encodings,
            z_dial=z_dial_post
        )
        # Decode
        ctx_encodings = self.ctx_fc(torch.cat([dial_encodings, z_sent_post, z_dial_post], dim=1))
        attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
        attn_mask = self._get_attn_mask(X.view(batch_size, -1))
        decoder_ret_dict = self._decode(
            inputs=Y_in,
            context=ctx_encodings,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask
        )

        # Loss
        loss = 0
        # Reconstruction
        logits = decoder_ret_dict["logits"]
        word_losses = F.cross_entropy(
            logits.view(-1, self.vocab_size),
            Y_out.view(-1),
            ignore_index=self.pad_token_id,
            reduction="none"
        ).view(batch_size, max_y_len)
        sent_loss = word_losses.sum(1).mean(0)
        loss += sent_loss
        with torch.no_grad():
            ppl = F.cross_entropy(
                logits.view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.pad_token_id,
                reduction="mean"
            ).exp()
        # KLD
        kld_coef = self._annealing_coef_term(step)
        dial_kld_losses = gaussian_kld(mu_dial_post, var_dial_post)
        avg_dial_kld = dial_kld_losses.mean()
        sent_kld_losses = gaussian_kld(mu_sent_post, var_sent_post, mu_sent_prior, var_sent_prior)
        avg_sent_kld = sent_kld_losses.mean()
        loss += (avg_dial_kld + avg_sent_kld) * kld_coef

        # return dicts
        ret_data = {
            "loss": loss
        }
        ret_stat = {
            "ppl": ppl.item(),
            "kld_term": kld_coef,
            "dial_kld": avg_dial_kld.item(),
            "sent_kld": avg_sent_kld.item(),
            "loss": loss.item()
        }

        return ret_data, ret_stat

    def evaluate_step(self, data):
        """One evaluation step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values
            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'sent_kld' {float} -- sentence KLD
                'dial_kld' {float} -- dialog KLD
                'monitor' {float} -- a monitor number for learning rate scheduling
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()

        batch_size = X.size(0)
        max_y_len = Y_out.size(1)
        ctx_dial_lens = ((X != self.pad_token_id).sum(-1) > 0).sum(-1)

        with torch.no_grad():
            # Forward
            # Encode sentences
            word_encodings, ctx_sent_encodings = self._get_ctx_sent_encodings(
                inputs=X,
                input_floors=X_floor,
                output_floors=Y_floor
            )
            reply_sent_encodings = self._get_reply_sent_encodings(
                outputs=Y,
            )
            # Encode full dialog for posterior dialog z
            full_dial_encodings = self._get_full_dial_encodings(
                ctx_dial_lens=ctx_dial_lens,
                ctx_sent_encodings=ctx_sent_encodings,
                reply_sent_encodings=reply_sent_encodings
            )
            # Get dial z
            z_dial_post, mu_dial_post, var_dial_post = self._get_dial_post(full_dial_encodings)
            # Encode dialog
            dial_encodings = self._get_dial_encodings(
                ctx_dial_lens=ctx_dial_lens,
                ctx_sent_encodings=ctx_sent_encodings,
                z_dial=z_dial_post
            )
            # Get sent z
            z_sent_post, mu_sent_post, var_sent_post = self._get_sent_post(
                reply_sent_encodings=reply_sent_encodings,
                dial_encodings=dial_encodings,
                z_dial=z_dial_post
            )
            z_sent_prior, mu_sent_prior, var_sent_prior = self._get_sent_prior(
                dial_encodings=dial_encodings,
                z_dial=z_dial_post
            )
            # Decode
            ctx_encodings = self.ctx_fc(torch.cat([dial_encodings, z_sent_post, z_dial_post], dim=1))
            attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))
            decoder_ret_dict = self._decode(
                inputs=Y_in,
                context=ctx_encodings,
                attn_ctx=attn_ctx,
                attn_mask=attn_mask
            )

            # Loss
            # Reconstruction
            logits = decoder_ret_dict["logits"]
            word_losses = F.cross_entropy(
                logits.view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="none"
            ).view(batch_size, max_y_len)
            sent_loss = word_losses.sum(1).mean(0)
            ppl = F.cross_entropy(
                logits.view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="mean"
            ).exp()
            # KLD
            dial_kld_losses = gaussian_kld(mu_dial_post, var_dial_post)
            avg_dial_kld = dial_kld_losses.mean()
            sent_kld_losses = gaussian_kld(mu_sent_post, var_sent_post, mu_sent_prior, var_sent_prior)
            avg_sent_kld = sent_kld_losses.mean()
            # monitor loss
            monitor_loss = sent_loss + avg_dial_kld + avg_sent_kld

        # return dicts
        ret_data = {}
        ret_stat = {
            "ppl": ppl.item(),
            "dial_kld": avg_dial_kld.item(),
            "sent_kld": avg_sent_kld.item(),
            "monitor": monitor_loss.item()
        }

        return ret_data, ret_stat

    def test_step(self, data):
        """One test step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values
                'symbols' {LongTensor [batch_size, max_decode_len]} -- token ids of response hypothesis
            dict of statistics -- returned keys and values
        """
        X = data["X"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]

        batch_size = X.size(0)
        ctx_dial_lens = ((X != self.pad_token_id).sum(-1) > 0).sum(-1)

        with torch.no_grad():
            # Forward
            # Encode sentences
            word_encodings, ctx_sent_encodings = self._get_ctx_sent_encodings(
                inputs=X,
                input_floors=X_floor,
                output_floors=Y_floor
            )
            # Get dial z
            z_dial_prior, mu_dial_prior, var_dial_prior = self._get_dial_prior(batch_size)
            # Encode dialog
            dial_encodings = self._get_dial_encodings(
                ctx_dial_lens=ctx_dial_lens,
                ctx_sent_encodings=ctx_sent_encodings,
                z_dial=z_dial_prior
            )
            # Get sent z
            z_sent_prior, mu_sent_prior, var_sent_prior = self._get_sent_prior(
                dial_encodings=dial_encodings,
                z_dial=z_dial_prior
            )
            # Decode
            ctx_encodings = self.ctx_fc(torch.cat([dial_encodings, z_sent_prior, z_dial_prior], dim=1))
            attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))
            decoder_ret_dict = self._sample(
                context=ctx_encodings,
                attn_ctx=attn_ctx,
                attn_mask=attn_mask
            )

        ret_data = {
            "symbols": decoder_ret_dict["symbols"]
        }
        ret_stat = {}

        return ret_data, ret_stat
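
# ---------------------------------------------------------------------------
# Reference sketch of the imported gaussian_kld helper used by the VAE-style
# models above and below. Assumption: it computes the closed-form KL
# divergence between diagonal Gaussians, KL(N(mu1, var1) || N(mu2, var2)),
# per batch element, with the two-argument form (as in VHCR's dialog-level
# KLD) measuring against a standard normal prior. The helper actually
# imported by this file may differ in detail.
# ---------------------------------------------------------------------------
def _reference_gaussian_kld(mu1, var1, mu2=None, var2=None):
    if mu2 is None:
        mu2 = torch.zeros_like(mu1)  # standard normal prior
    if var2 is None:
        var2 = torch.ones_like(var1)
    # 0.5 * sum_d [ log(var2/var1) + (var1 + (mu1 - mu2)^2) / var2 - 1 ]
    kld = 0.5 * (
        torch.log(var2) - torch.log(var1)
        + (var1 + (mu1 - mu2).pow(2)) / var2
        - 1.0
    )
    return kld.sum(dim=1)  # [batch_size]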
class VHRED(nn.Module):
    def __init__(self, config, tokenizer):
        super(VHRED, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.attr_embedding_dim = config.attr_embedding_dim
        self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
        self.n_sent_encoder_layers = config.n_sent_encoder_layers
        self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
        self.n_dial_encoder_layers = config.n_dial_encoder_layers
        self.latent_variable_dim = config.latent_dim
        self.decoder_hidden_dim = config.decoder_hidden_dim
        self.n_decoder_layers = config.n_decoder_layers
        self.use_attention = config.use_attention
        self.decode_max_len = config.decode_max_len
        self.tie_weights = config.tie_weights
        self.rnn_type = config.rnn_type
        self.gen_type = config.gen_type
        self.top_k = config.top_k
        self.top_p = config.top_p
        self.temp = config.temp
        self.word_embedding_path = config.word_embedding_path
        self.floor_encoder_type = config.floor_encoder
        self.gaussian_mix_type = config.gaussian_mix_type
        # Optional attributes from config
        self.use_bow_loss = config.use_bow_loss if hasattr(config, "use_bow_loss") else True
        self.dropout = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_emb = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding \
            if hasattr(config, "use_pretrained_word_embedding") else False
        self.n_step_annealing = config.n_step_annealing \
            if hasattr(config, "n_step_annealing") else 1
        self.n_components = config.n_components \
            if hasattr(config, "n_components") else 1
        # Other attributes
        self.vocab_size = len(tokenizer.word2id)
        self.word2id = tokenizer.word2id
        self.id2word = tokenizer.id2word
        self.pad_token_id = tokenizer.pad_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

        # Input components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id
            ),
        )

        # Encoding components
        self.sent_encoder = EncoderRNN(
            input_dim=self.word_embedding_dim,
            hidden_dim=self.sent_encoder_hidden_dim,
            n_layers=self.n_sent_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            embedding=self.word_embedding,
            rnn_type=self.rnn_type,
        )
        self.dial_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=False,
            rnn_type=self.rnn_type,
        )

        # Variational components
        if self.n_components == 1:
            self.prior_net = GaussianVariation(
                input_dim=self.dial_encoder_hidden_dim,
                z_dim=self.latent_variable_dim,
                # large_mlp=True
            )
        elif self.n_components > 1:
            if self.gaussian_mix_type == "gmm":
                self.prior_net = GMMVariation(
                    input_dim=self.dial_encoder_hidden_dim,
                    z_dim=self.latent_variable_dim,
                    n_components=self.n_components,
                )
            elif self.gaussian_mix_type == "lgm":
                self.prior_net = LGMVariation(
                    input_dim=self.dial_encoder_hidden_dim,
                    z_dim=self.latent_variable_dim,
                    n_components=self.n_components,
                )
        self.post_net = GaussianVariation(
            input_dim=self.sent_encoder_hidden_dim + self.dial_encoder_hidden_dim,
            z_dim=self.latent_variable_dim,
        )
        self.latent_to_bow = nn.Sequential(
            nn.Linear(
                self.latent_variable_dim + self.dial_encoder_hidden_dim,
                self.latent_variable_dim
            ),
            nn.Tanh(),
            nn.Dropout(self.dropout),
            nn.Linear(self.latent_variable_dim, self.vocab_size)
        )
        self.ctx_fc = nn.Sequential(
            nn.Linear(
                self.latent_variable_dim + self.dial_encoder_hidden_dim,
                self.dial_encoder_hidden_dim,
            ),
            nn.Tanh(),
            nn.Dropout(self.dropout)
        )

        # Decoding components
        self.enc2dec_hidden_fc = nn.Linear(
            self.dial_encoder_hidden_dim,
            self.n_decoder_layers * self.decoder_hidden_dim if self.rnn_type == "gru"
            else self.n_decoder_layers * self.decoder_hidden_dim * 2
        )
        self.decoder = DecoderRNN(
            vocab_size=len(self.word2id),
            input_dim=self.word_embedding_dim,
            hidden_dim=self.decoder_hidden_dim,
            feat_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_decoder_layers,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            max_len=self.decode_max_len,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            embedding=self.word_embedding,
            tie_weights=self.tie_weights,
            rnn_type=self.rnn_type,
            use_attention=self.use_attention,
            attn_dim=self.sent_encoder_hidden_dim
        )

        # Extra components
        # Floor encoding
        if self.floor_encoder_type == "abs":
            self.floor_encoder = AbsFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim
            )
        elif self.floor_encoder_type == "rel":
            self.floor_encoder = RelFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim
            )
        else:
            self.floor_encoder = None

        # Initialization
        self._init_weights()

    def _init_weights(self):
        init_module_weights(self.enc2dec_hidden_fc)
        init_module_weights(self.latent_to_bow)
        init_module_weights(self.ctx_fc)

    def _init_dec_hiddens(self, context):
        batch_size = context.size(0)
        hiddens = self.enc2dec_hidden_fc(context)
        if self.rnn_type == "gru":
            hiddens = hiddens.view(
                batch_size,
                self.n_decoder_layers,
                self.decoder_hidden_dim
            ).transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
        elif self.rnn_type == "lstm":
            hiddens = hiddens.view(
                batch_size,
                self.n_decoder_layers,
                self.decoder_hidden_dim,
                2
            )
            h = hiddens[:, :, :, 0].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            c = hiddens[:, :, :, 1].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            hiddens = (h, c)
        return hiddens

    def _encode_sent(self, outputs):
        output_lens = (outputs != self.pad_token_id).sum(-1)
        _, _, sent_encodings = self.sent_encoder(outputs, output_lens)
        return sent_encodings

    def _encode_dial(self, inputs, input_floors, output_floors):
        batch_size, history_len, max_x_sent_len = inputs.size()

        flat_inputs = inputs.view(batch_size * history_len, max_x_sent_len)
        input_lens = (inputs != self.pad_token_id).sum(-1)
        flat_input_lens = input_lens.view(batch_size * history_len)
        word_encodings, _, sent_encodings = self.sent_encoder(flat_inputs, flat_input_lens)
        word_encodings = word_encodings.view(batch_size, history_len, max_x_sent_len, -1)
        sent_encodings = sent_encodings.view(batch_size, history_len, -1)

        if self.floor_encoder is not None:
            src_floors = input_floors.view(-1)
            tgt_floors = output_floors.unsqueeze(1).repeat(1, history_len).view(-1)
            sent_encodings = sent_encodings.view(batch_size * history_len, -1)
            sent_encodings = self.floor_encoder(
                sent_encodings,
                src_floors=src_floors,
                tgt_floors=tgt_floors
            )
            sent_encodings = sent_encodings.view(batch_size, history_len, -1)

        dial_lens = (input_lens > 0).long().sum(1)  # equals number of non-padding sents
        _, _, dial_encodings = self.dial_encoder(sent_encodings, dial_lens)  # [batch_size, dial_encoder_dim]

        return word_encodings, sent_encodings, dial_encodings

    def _get_prior_z(self, prior_net_input, assign_k=None, return_pi=False):
        ret = self.prior_net(
            context=prior_net_input,
            assign_k=assign_k,
            return_pi=return_pi
        )
        return ret

    def _get_post_z(self, post_net_input):
        z, mu, var = self.post_net(post_net_input)
        return z, mu, var

    def _get_ctx_for_decoder(self, z, dial_encodings):
        ctx_encodings = self.ctx_fc(torch.cat([z, dial_encodings], dim=1))
        # ctx_encodings = self.ctx_fc(z)
        return ctx_encodings

    def _decode(self, inputs, context, attn_ctx=None, attn_mask=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, inputs.size(1), 1)
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            inputs=inputs,
            hiddens=hiddens,
            feats=feats,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask,
            mode=DecoderRNN.MODE_TEACHER_FORCE
        )
        return ret_dict

    def _sample(self, context, attn_ctx=None, attn_mask=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, self.decode_max_len, 1)
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            hiddens=hiddens,
            feats=feats,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask,
            mode=DecoderRNN.MODE_FREE_RUN,
            gen_type=self.gen_type,
            temp=self.temp,
            top_p=self.top_p,
            top_k=self.top_k,
        )
        return ret_dict

    def _get_attn_mask(self, attn_keys):
        attn_mask = (attn_keys != self.pad_token_id)
        return attn_mask

    def _annealing_coef_term(self, step):
        return min(1.0, 1.0 * step / self.n_step_annealing)

    def compute_active_units(self, data_source, batch_size, delta=0.01):
        # Active units: latent dimensions whose posterior/prior mean varies
        # across the data by at least delta (two-pass mean/variance estimate)
        with torch.no_grad():
            cnt = 0
            data_source.epoch_init()
            prior_mu_sum = 0
            post_mu_sum = 0
            while True:
                batch_data = data_source.next(batch_size)
                if batch_data is None:
                    break

                X, Y = batch_data["X"], batch_data["Y"]
                X_floor, Y_floor = batch_data["X_floor"], batch_data["Y_floor"]

                # encode
                word_encodings, sent_encodings, dial_encodings = self._encode_dial(
                    inputs=X,
                    input_floors=X_floor,
                    output_floors=Y_floor
                )
                # prior
                prior_net_input = dial_encodings
                _, prior_mu, _ = self._get_prior_z(prior_net_input)
                # post
                post_sent_encodings = self._encode_sent(Y)
                post_net_input = torch.cat([post_sent_encodings, dial_encodings], dim=1)
                _, post_mu, _ = self._get_post_z(post_net_input)
                # record
                cnt += prior_mu.size(0)
                prior_mu_sum += prior_mu.sum(0)
                post_mu_sum += post_mu.sum(0)

            prior_mu_mean = (prior_mu_sum / cnt).unsqueeze(0)  # [1, latent_dim]
            post_mu_mean = (post_mu_sum / cnt).unsqueeze(0)  # [1, latent_dim]

            data_source.epoch_init()
            prior_mu_var_sum = 0
            post_mu_var_sum = 0
            while True:
                batch_data = data_source.next(batch_size)
                if batch_data is None:
                    break

                X, Y = batch_data["X"], batch_data["Y"]
                X_floor, Y_floor = batch_data["X_floor"], batch_data["Y_floor"]

                # encode
                word_encodings, sent_encodings, dial_encodings = self._encode_dial(
                    inputs=X,
                    input_floors=X_floor,
                    output_floors=Y_floor
                )
                # prior
                prior_net_input = dial_encodings
                _, prior_mu, _ = self._get_prior_z(prior_net_input)
                # post
                post_sent_encodings = self._encode_sent(Y)
                post_net_input = torch.cat([post_sent_encodings, dial_encodings], dim=1)
                _, post_mu, _ = self._get_post_z(post_net_input)
                # record
                prior_mu_var_sum += ((prior_mu - prior_mu_mean)**2).sum(0)
                post_mu_var_sum += ((post_mu - post_mu_mean)**2).sum(0)

            prior_mu_var_mean = prior_mu_var_sum / (cnt - 1)  # [latent_dim]
            post_mu_var_mean = post_mu_var_sum / (cnt - 1)  # [latent_dim]

            prior_au = (prior_mu_var_mean >= delta).sum().item()
            post_au = (post_mu_var_mean >= delta).sum().item()
            prior_au_ratio = prior_au / self.latent_variable_dim
            post_au_ratio = post_au / self.latent_variable_dim

            return {
                "prior_au": prior_au,
                "post_au": post_au,
                "prior_au_ratio": prior_au_ratio,
                "post_au_ratio": post_au_ratio,
            }

    def load_model(self, model_path):
        """Load pretrained model weights from model_path

        Arguments:
            model_path {str} -- path to pretrained model weights
        """
        pretrained_state_dict = torch.load(
            model_path,
            map_location=lambda storage, loc: storage
        )
        self.load_state_dict(pretrained_state_dict)

    def train_step(self, data, step):
        """One training step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence
            step {int} -- the n-th optimization step

        Returns:
            dict of data -- returned keys and values
                'loss' {FloatTensor []} -- loss to backward
            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'kld' {float} -- KLD
                'kld_term' {float} -- KLD annealing coefficient
                'bow_loss' {float} -- bag-of-words loss
                'loss' {float} -- batch loss
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()

        batch_size = X.size(0)
        max_y_len = Y_out.size(1)

        # Forward
        # Get prior z
        word_encodings, sent_encodings, dial_encodings = self._encode_dial(
            inputs=X,
            input_floors=X_floor,
            output_floors=Y_floor
        )
        prior_net_input = dial_encodings
        prior_z, prior_mu, prior_var = self._get_prior_z(prior_net_input)
        # Get post z
        post_sent_encodings = self._encode_sent(Y)
        post_net_input = torch.cat([post_sent_encodings, dial_encodings], dim=1)
        post_z, post_mu, post_var = self._get_post_z(post_net_input)
        # Decode
        ctx_encodings = self._get_ctx_for_decoder(post_z, dial_encodings)
        attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
        attn_mask = self._get_attn_mask(X.view(batch_size, -1))
        decoder_ret_dict = self._decode(
            inputs=Y_in,
            context=ctx_encodings,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask
        )

        # Loss
        loss = 0
        # Reconstruction
        logits = decoder_ret_dict["logits"]
        word_losses = F.cross_entropy(
            logits.view(-1, self.vocab_size),
            Y_out.view(-1),
            ignore_index=self.pad_token_id,
            reduction="none"
        ).view(batch_size, max_y_len)
        sent_loss = word_losses.sum(1).mean(0)
        loss += sent_loss
        with torch.no_grad():
            ppl = F.cross_entropy(
                logits.view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.pad_token_id,
                reduction="mean"
            ).exp()
        # KLD
        kld_coef = self._annealing_coef_term(step)
        kld_losses = gaussian_kld(
            post_mu,
            post_var,
            prior_mu,
            prior_var,
        )
        avg_kld = kld_losses.mean()
        loss += avg_kld * kld_coef
        # BOW
        if self.use_bow_loss:
            Y_out_mask = (Y_out != self.pad_token_id).float()
            bow_input = torch.cat([post_z, dial_encodings], dim=1)
            bow_logits = self.latent_to_bow(bow_input)  # [batch_size, vocab_size]
            bow_loss = -F.log_softmax(bow_logits, dim=1).gather(1, Y_out) * Y_out_mask
            bow_loss = bow_loss.sum(1).mean()
            loss += bow_loss

        # return dicts
        ret_data = {"loss": loss}
        ret_stat = {
            "ppl": ppl.item(),
            "kld_term": kld_coef,
            "kld": avg_kld.item(),
            "prior_abs_mu_mean": prior_mu.abs().mean().item(),
            "prior_var_mean": prior_var.mean().item(),
            "post_abs_mu_mean": post_mu.abs().mean().item(),
            "post_var_mean": post_var.mean().item(),
            "loss": loss.item()
        }
        if self.use_bow_loss:
            ret_stat["bow_loss"] = bow_loss.item()

        return ret_data, ret_stat

    def evaluate_step(self, data, assign_k=None):
        """One evaluation step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence
            assign_k {int} -- (optional) passed through to the prior network

        Returns:
            dict of data -- returned keys and values
            dict of statistics -- returned keys and values
                'post_ppl' {float} -- perplexity when decoding from the posterior z
                'prior_ppl' {float} -- perplexity when decoding from the prior z
                'kld' {float} -- KLD
                'monitor' {float} -- a monitor number for learning rate scheduling
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()

        batch_size = X.size(0)
        max_y_len = Y_out.size(1)

        with torch.no_grad():
            # Forward
            # Get prior z
            word_encodings, sent_encodings, dial_encodings = self._encode_dial(
                inputs=X,
                input_floors=X_floor,
                output_floors=Y_floor
            )
            prior_net_input = dial_encodings
            prior_z, prior_mu, prior_var = self._get_prior_z(
                prior_net_input=prior_net_input,
                assign_k=assign_k
            )
            # Get post z
            post_sent_encodings = self._encode_sent(Y)
            post_net_input = torch.cat([post_sent_encodings, dial_encodings], dim=1)
            post_z, post_mu, post_var = self._get_post_z(post_net_input)
            # Decode from post z
            attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))
            post_ctx_encodings = self._get_ctx_for_decoder(post_z, dial_encodings)
            post_decoder_ret_dict = self._decode(
                inputs=Y_in,
                context=post_ctx_encodings,
                attn_ctx=attn_ctx,
                attn_mask=attn_mask
            )
            # Decode from prior z
            prior_ctx_encodings = self._get_ctx_for_decoder(prior_z, dial_encodings)
            prior_decoder_ret_dict = self._decode(
                inputs=Y_in,
                context=prior_ctx_encodings,
                attn_ctx=attn_ctx,
                attn_mask=attn_mask
            )

            # Loss
            # Reconstruction
            post_word_losses = F.cross_entropy(
                post_decoder_ret_dict["logits"].view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="none"
            ).view(batch_size, max_y_len)
            post_sent_loss = post_word_losses.sum(1).mean(0)
            post_ppl = F.cross_entropy(
                post_decoder_ret_dict["logits"].view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="mean"
            ).exp()
            # Generation
            prior_word_losses = F.cross_entropy(
                prior_decoder_ret_dict["logits"].view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="none"
            ).view(batch_size, max_y_len)
            prior_sent_loss = prior_word_losses.sum(1).mean(0)
            prior_ppl = F.cross_entropy(
                prior_decoder_ret_dict["logits"].view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="mean"
            ).exp()
            # KLD
            kld_losses = gaussian_kld(post_mu, post_var, prior_mu, prior_var)
            avg_kld = kld_losses.mean()
            # monitor
            monitor_loss = post_sent_loss + avg_kld

        # return dicts
        ret_data = {}
        ret_stat = {
            "post_ppl": post_ppl.item(),
            "prior_ppl": prior_ppl.item(),
            "kld": avg_kld.item(),
            "post_abs_mu_mean": post_mu.abs().mean().item(),
            "post_var_mean": post_var.mean().item(),
            "prior_abs_mu_mean": prior_mu.abs().mean().item(),
            "prior_var_mean": prior_var.mean().item(),
            "monitor": monitor_loss.item()
        }

        return ret_data, ret_stat

    def test_step(self, data, assign_k=None):
        """One test step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence
            assign_k {int} -- (optional) passed through to the prior network

        Returns:
            dict of data -- returned keys and values
                'symbols' {LongTensor [batch_size, max_decode_len]} -- token ids of response hypothesis
                'pi' -- mixture weights returned by the prior network
            dict of statistics -- returned keys and values
        """
        X = data["X"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]

        batch_size = X.size(0)

        with torch.no_grad():
            # Forward
            # Get prior z
            word_encodings, sent_encodings, dial_encodings = self._encode_dial(
                inputs=X,
                input_floors=X_floor,
                output_floors=Y_floor
            )
            prior_net_input = dial_encodings
            prior_z, prior_mu, prior_var, prior_pi = self._get_prior_z(
                prior_net_input=prior_net_input,
                assign_k=assign_k,
                return_pi=True
            )
            # Decode
            ctx_encodings = self._get_ctx_for_decoder(prior_z, dial_encodings)
            attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))
            decoder_ret_dict = self._sample(
                context=ctx_encodings,
                attn_ctx=attn_ctx,
                attn_mask=attn_mask
            )

        ret_data = {"symbols": decoder_ret_dict["symbols"], "pi": prior_pi}
        ret_stat = {}

        return ret_data, ret_stat
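
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): sampling
# responses from a trained VHRED with test_step. It assumes id2word maps
# token ids back to word strings (the inverse of word2id built in __init__).
# With a mixture prior (n_components > 1), assign_k can be passed through to
# the prior network to pin generation to one component.
# ---------------------------------------------------------------------------
def _example_vhred_generation(model, batch_data, assign_k=None):
    model.eval()
    ret_data, _ = model.test_step(batch_data, assign_k=assign_k)
    symbols = ret_data["symbols"]  # [batch_size, max_decode_len] token ids
    for hyp_ids in symbols.tolist():
        # Truncate the hypothesis at the first end-of-sentence token
        if model.eos_token_id in hyp_ids:
            hyp_ids = hyp_ids[:hyp_ids.index(model.eos_token_id)]
        print(" ".join(model.id2word[i] for i in hyp_ids))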
class Mechanism_HRED(nn.Module):
    def __init__(self, config, tokenizer):
        super(Mechanism_HRED, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.attr_embedding_dim = config.attr_embedding_dim
        self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
        self.n_sent_encoder_layers = config.n_sent_encoder_layers
        self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
        self.n_dial_encoder_layers = config.n_dial_encoder_layers
        self.decoder_hidden_dim = config.decoder_hidden_dim
        self.n_decoder_layers = config.n_decoder_layers
        self.n_mechanisms = config.n_mechanisms
        self.latent_dim = config.latent_dim
        self.use_attention = config.use_attention
        self.decode_max_len = config.decode_max_len
        self.tie_weights = config.tie_weights
        self.rnn_type = config.rnn_type
        self.gen_type = config.gen_type
        self.top_k = config.top_k
        self.top_p = config.top_p
        self.temp = config.temp
        self.word_embedding_path = config.word_embedding_path
        self.floor_encoder_type = config.floor_encoder
        # Optional attributes from config
        self.dropout_emb = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding \
            if hasattr(config, "use_pretrained_word_embedding") else False
        # Other attributes
        self.word2id = tokenizer.word2id
        self.id2word = tokenizer.id2word
        self.vocab_size = len(tokenizer.word2id)
        self.pad_token_id = tokenizer.pad_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

        # Input components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id),
        )

        # Encoding components
        self.sent_encoder = EncoderRNN(
            input_dim=self.word_embedding_dim,
            hidden_dim=self.sent_encoder_hidden_dim,
            n_layers=self.n_sent_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            embedding=self.word_embedding,
            rnn_type=self.rnn_type,
        )
        self.dial_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=False,
            rnn_type=self.rnn_type,
        )

        # Mechanism components
        self.mechanism_embeddings = nn.Embedding(
            self.n_mechanisms,
            self.latent_dim,
        )
        self.ctx2mech_fc = nn.Linear(self.dial_encoder_hidden_dim,
                                     self.latent_dim)
        self.score_bilinear = nn.Parameter(
            torch.FloatTensor(self.latent_dim, self.latent_dim))
        self.ctx_mech_combine_fc = nn.Linear(
            self.dial_encoder_hidden_dim + self.latent_dim,
            self.dial_encoder_hidden_dim)

        # Decoding components
        self.enc2dec_hidden_fc = nn.Linear(
            self.dial_encoder_hidden_dim,
            self.n_decoder_layers * self.decoder_hidden_dim
            if self.rnn_type == "gru"
            else self.n_decoder_layers * self.decoder_hidden_dim * 2)
        self.decoder = DecoderRNN(vocab_size=len(self.word2id),
                                  input_dim=self.word_embedding_dim,
                                  hidden_dim=self.decoder_hidden_dim,
                                  feat_dim=self.dial_encoder_hidden_dim,
                                  n_layers=self.n_decoder_layers,
                                  bos_token_id=self.bos_token_id,
                                  eos_token_id=self.eos_token_id,
                                  pad_token_id=self.pad_token_id,
                                  max_len=self.decode_max_len,
                                  dropout_emb=self.dropout_emb,
                                  dropout_input=self.dropout_input,
                                  dropout_hidden=self.dropout_hidden,
                                  dropout_output=self.dropout_output,
                                  embedding=self.word_embedding,
                                  tie_weights=self.tie_weights,
                                  rnn_type=self.rnn_type,
                                  use_attention=self.use_attention,
                                  attn_dim=self.sent_encoder_hidden_dim)

        # Extra components
        # floor encoding
        if self.floor_encoder_type == "abs":
            self.floor_encoder = AbsFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim)
        elif self.floor_encoder_type == "rel":
            self.floor_encoder = RelFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim)
        else:
            self.floor_encoder = None

        # Initialization
        self._init_weights()

    def _init_weights(self):
        init_module_weights(self.mechanism_embeddings)
        init_module_weights(self.ctx2mech_fc)
        init_module_weights(self.score_bilinear)
        init_module_weights(self.ctx_mech_combine_fc)
        init_module_weights(self.enc2dec_hidden_fc)

    def _init_dec_hiddens(self, context):
        batch_size = context.size(0)
        hiddens = self.enc2dec_hidden_fc(context)
        if self.rnn_type == "gru":
            hiddens = hiddens.view(
                batch_size,
                self.n_decoder_layers,
                self.decoder_hidden_dim
            ).transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
        elif self.rnn_type == "lstm":
            hiddens = hiddens.view(batch_size, self.n_decoder_layers,
                                   self.decoder_hidden_dim, 2)
            h = hiddens[:, :, :, 0].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            c = hiddens[:, :, :, 1].transpose(0, 1).contiguous()  # (n_layers, batch_size, hidden_dim)
            hiddens = (h, c)
        return hiddens

    def _encode(self, inputs, input_floors, output_floors):
        batch_size, history_len, max_x_sent_len = inputs.size()

        flat_inputs = inputs.view(batch_size * history_len, max_x_sent_len)
        input_lens = (inputs != self.pad_token_id).sum(-1)
        flat_input_lens = input_lens.view(batch_size * history_len)
        word_encodings, _, sent_encodings = self.sent_encoder(
            flat_inputs, flat_input_lens)
        word_encodings = word_encodings.view(batch_size, history_len,
                                             max_x_sent_len, -1)
        sent_encodings = sent_encodings.view(batch_size, history_len, -1)
        if self.floor_encoder is not None:
            src_floors = input_floors.view(-1)
            tgt_floors = output_floors.unsqueeze(1).repeat(
                1, history_len).view(-1)
            sent_encodings = sent_encodings.view(batch_size * history_len, -1)
            sent_encodings = self.floor_encoder(sent_encodings,
                                                src_floors=src_floors,
                                                tgt_floors=tgt_floors)
            sent_encodings = sent_encodings.view(batch_size, history_len, -1)
        dial_lens = (input_lens > 0).long().sum(1)  # equals number of non-padding sents
        _, _, dial_encodings = self.dial_encoder(
            sent_encodings, dial_lens)  # [batch_size, dial_encoder_dim]

        return word_encodings, sent_encodings, dial_encodings

    def _compute_mechanism_probs(self, ctx_encodings):
        ctx_mech = self.ctx2mech_fc(ctx_encodings)  # [batch_size, latent_dim]
        mech_scores = torch.matmul(
            torch.matmul(ctx_mech, self.score_bilinear),
            self.mechanism_embeddings.weight.T)  # [batch_size, n_mechanisms]
        mech_probs = F.softmax(mech_scores, dim=1)
        return mech_probs

    def _decode(self, inputs, context, attn_ctx=None, attn_mask=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, inputs.size(1), 1)
        ret_dict = self.decoder.forward(batch_size=batch_size,
                                        inputs=inputs,
                                        hiddens=hiddens,
                                        feats=feats,
                                        attn_ctx=attn_ctx,
                                        attn_mask=attn_mask,
                                        mode=DecoderRNN.MODE_TEACHER_FORCE)
        return ret_dict

    def _sample(self, context, attn_ctx=None, attn_mask=None):
        batch_size = context.size(0)
        hiddens = self._init_dec_hiddens(context)
        feats = context.unsqueeze(1).repeat(1, self.decode_max_len, 1)
        ret_dict = self.decoder.forward(
            batch_size=batch_size,
            hiddens=hiddens,
            feats=feats,
            attn_ctx=attn_ctx,
            attn_mask=attn_mask,
            mode=DecoderRNN.MODE_FREE_RUN,
            gen_type=self.gen_type,
            temp=self.temp,
            top_p=self.top_p,
            top_k=self.top_k,
        )
        return ret_dict

    def _get_attn_mask(self, attn_keys):
        attn_mask = (attn_keys != self.pad_token_id)
        return attn_mask

    def load_model(self, model_path):
        """Load pretrained model weights from model_path

        Arguments:
            model_path {str} -- path to pretrained model weights
        """
        pretrained_state_dict = torch.load(
            model_path, map_location=lambda storage, loc: storage)
        self.load_state_dict(pretrained_state_dict)

    def train_step(self, data):
        """One training step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values
                'loss' {FloatTensor []} -- loss to backward

            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'loss' {float} -- batch loss
                'mech_prob_std' {float} -- mean (over batch) std of mechanism probabilities
                'mech_prob_max' {float} -- mean (over batch) max mechanism probability
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()
        batch_size = X.size(0)

        # Forward
        # -- encode
        word_encodings, sent_encodings, dial_encodings = self._encode(
            inputs=X, input_floors=X_floor, output_floors=Y_floor)
        attn_ctx = word_encodings.view(batch_size, -1, word_encodings.size(-1))
        attn_mask = self._get_attn_mask(X.view(batch_size, -1))
        mech_probs = self._compute_mechanism_probs(
            dial_encodings)  # [batch_size, n_mechanisms]
        mech_embed_inputs = torch.LongTensor(
            list(range(self.n_mechanisms))).to(DEVICE)  # [n_mechanisms]
        repeated_mech_embed_inputs = mech_embed_inputs.unsqueeze(0).repeat(
            batch_size, 1).view(-1)  # [batch_size*n_mechanisms]
        repeated_mech_embeds = self.mechanism_embeddings(
            repeated_mech_embed_inputs)  # [batch_size*n_mechanisms, latent_dim]

        # -- repeat for each mechanism
        repeated_Y_in = Y_in.unsqueeze(1).repeat(
            1, self.n_mechanisms, 1)  # [batch_size, n_mechanisms, len]
        repeated_Y_in = repeated_Y_in.view(batch_size * self.n_mechanisms, -1)
        repeated_Y_out = Y_out.unsqueeze(1).repeat(
            1, self.n_mechanisms, 1)  # [batch_size, n_mechanisms, len]
        repeated_Y_out = repeated_Y_out.view(batch_size * self.n_mechanisms, -1)
        dial_encodings = dial_encodings.unsqueeze(1).repeat(
            1, self.n_mechanisms, 1)  # [batch_size, n_mechanisms, hidden_dim]
        dial_encodings = dial_encodings.view(batch_size * self.n_mechanisms,
                                             self.dial_encoder_hidden_dim)
        attn_ctx = attn_ctx.unsqueeze(1).repeat(1, self.n_mechanisms, 1, 1)
        attn_ctx = attn_ctx.view(batch_size * self.n_mechanisms, -1,
                                 attn_ctx.size(-1))
        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_mechanisms, 1)
        attn_mask = attn_mask.view(batch_size * self.n_mechanisms, -1)

        # -- decode, conditioning on every mechanism in parallel
        dec_ctx = self.ctx_mech_combine_fc(
            torch.cat([dial_encodings, repeated_mech_embeds], dim=1))
        decoder_ret_dict = self._decode(inputs=repeated_Y_in,
                                        context=dec_ctx,
                                        attn_ctx=attn_ctx,
                                        attn_mask=attn_mask)

        # Calculate loss: negative marginal log-likelihood over mechanisms,
        # -log p(Y|ctx) = -logsumexp_m [log p(Y|ctx, m) + log p(m|ctx)]
        word_neglogll = F.cross_entropy(
            decoder_ret_dict["logits"].view(-1, self.vocab_size),
            repeated_Y_out.view(-1),
            ignore_index=self.decoder.pad_token_id,
            reduction="none").view(batch_size, self.n_mechanisms, -1)
        sent_logll = word_neglogll.sum(2) * (-1)
        mech_logll = (mech_probs + 1e-10).log()
        sent_mech_logll = sent_logll + mech_logll
        target_logll = torch.logsumexp(sent_mech_logll, dim=1)
        target_neglogll = target_logll * (-1)
        loss = target_neglogll.mean()
        with torch.no_grad():
            ppl = F.cross_entropy(
                decoder_ret_dict["logits"].view(-1, self.vocab_size),
                repeated_Y_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="mean").exp()

        # return dicts
        ret_data = {"loss": loss}
        ret_stat = {
            "ppl": ppl.item(),
            "loss": loss.item(),
            "mech_prob_std": mech_probs.std(1).mean().item(),
            "mech_prob_max": mech_probs.max(1)[0].mean().item()
        }

        return ret_data, ret_stat

    def evaluate_step(self, data):
        """One evaluation step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y' {LongTensor [batch_size, max_y_sent_len]} -- token ids of response sentence
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence

        Returns:
            dict of data -- returned keys and values

            dict of statistics -- returned keys and values
                'ppl' {float} -- perplexity
                'monitor' {float} -- a monitor number for learning rate scheduling
        """
        X, Y = data["X"], data["Y"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        Y_in = Y[:, :-1].contiguous()
        Y_out = Y[:, 1:].contiguous()
        batch_size = X.size(0)

        with torch.no_grad():
            # Forward
            word_encodings, sent_encodings, dial_encodings = self._encode(
                inputs=X, input_floors=X_floor, output_floors=Y_floor)
            attn_ctx = word_encodings.view(batch_size, -1,
                                           word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))
            mech_probs = self._compute_mechanism_probs(
                dial_encodings)  # [batch_size, n_mechanisms]
            # use the most probable mechanism for evaluation
            mech_embed_inputs = mech_probs.argmax(1)  # [batch_size]
            mech_embeds = self.mechanism_embeddings(mech_embed_inputs)
            dec_ctx = self.ctx_mech_combine_fc(
                torch.cat([dial_encodings, mech_embeds], dim=1))
            decoder_ret_dict = self._decode(inputs=Y_in,
                                            context=dec_ctx,
                                            attn_ctx=attn_ctx,
                                            attn_mask=attn_mask)

            # Loss
            word_loss = F.cross_entropy(
                decoder_ret_dict["logits"].view(-1, self.vocab_size),
                Y_out.view(-1),
                ignore_index=self.decoder.pad_token_id,
                reduction="mean")
            ppl = torch.exp(word_loss)

        # return dicts
        ret_data = {}
        ret_stat = {"ppl": ppl.item(), "monitor": ppl.item()}

        return ret_data, ret_stat

    def test_step(self, data, sample_from="dist"):
        """One test step

        Arguments:
            data {dict of data} -- required keys and values:
                'X' {LongTensor [batch_size, history_len, max_x_sent_len]} -- token ids of context sentences
                'X_floor' {LongTensor [batch_size, history_len]} -- floors of context sentences
                'Y_floor' {LongTensor [batch_size]} -- floor of response sentence
            sample_from {str} -- "dist": sample a mechanism from the computed probabilities
                                 "random": sample a mechanism uniformly at random

        Returns:
            dict of data -- returned keys and values
                'symbols' {LongTensor [batch_size, max_decode_len]} -- token ids of response hypothesis

            dict of statistics -- returned keys and values
        """
        X = data["X"]
        X_floor, Y_floor = data["X_floor"], data["Y_floor"]
        batch_size = X.size(0)

        with torch.no_grad():
            # Forward
            word_encodings, sent_encodings, dial_encodings = self._encode(
                inputs=X, input_floors=X_floor, output_floors=Y_floor)
            attn_ctx = word_encodings.view(batch_size, -1,
                                           word_encodings.size(-1))
            attn_mask = self._get_attn_mask(X.view(batch_size, -1))

            if sample_from == "dist":
                mech_probs = self._compute_mechanism_probs(
                    dial_encodings)  # [batch_size, n_mechanisms]
                mech_dist = torch.distributions.Categorical(mech_probs)
                mech_embed_inputs = mech_dist.sample()  # [batch_size]
            elif sample_from == "random":
                mech_embed_inputs = [
                    random.randint(0, self.n_mechanisms - 1)
                    for _ in range(batch_size)
                ]
                mech_embed_inputs = torch.LongTensor(mech_embed_inputs).to(
                    DEVICE)  # [batch_size]
            mech_embeds = self.mechanism_embeddings(mech_embed_inputs)
            dec_ctx = self.ctx_mech_combine_fc(
                torch.cat([dial_encodings, mech_embeds], dim=1))
            decoder_ret_dict = self._sample(context=dec_ctx,
                                            attn_ctx=attn_ctx,
                                            attn_mask=attn_mask)

        ret_data = {"symbols": decoder_ret_dict["symbols"]}
        ret_stat = {}

        return ret_data, ret_stat
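
# ---------------------------------------------------------------------------
# Self-contained sanity check (illustrative, not part of the original code) of
# the mechanism marginalization used in Mechanism_HRED.train_step. The training
# target is the marginal likelihood over the K discrete mechanisms,
#     -log p(Y | ctx) = -logsumexp_m [ log p(Y | ctx, m) + log p(m | ctx) ],
# which is what sent_logll + mech_logll followed by torch.logsumexp computes.
# Random tensors stand in for the decoder and mechanism-probability outputs.

if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, n_mechanisms = 2, 4
    # stand-in for per-mechanism sentence log-likelihoods log p(Y | ctx, m)
    sent_logll = -torch.rand(batch_size, n_mechanisms) * 10
    # stand-in for mechanism probabilities p(m | ctx) from _compute_mechanism_probs
    mech_probs = F.softmax(torch.randn(batch_size, n_mechanisms), dim=1)
    target_logll = torch.logsumexp(sent_logll + (mech_probs + 1e-10).log(), dim=1)
    # equivalent mixture computed directly in probability space
    direct_logll = (mech_probs * sent_logll.exp()).sum(1).log()
    assert torch.allclose(target_logll, direct_logll, atol=1e-5)
    print("marginal NLL:", (-target_logll).mean().item())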