def __init__(self, config, tokenizer):
    super(RUBER, self).__init__()

    # Attributes
    # Attributes from config
    self.word_embedding_dim = config.word_embedding_dim
    self.attr_embedding_dim = config.attr_embedding_dim
    self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
    self.n_sent_encoder_layers = config.n_sent_encoder_layers
    self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
    self.n_dial_encoder_layers = config.n_dial_encoder_layers
    self.rnn_type = config.rnn_type
    self.word_embedding_path = config.word_embedding_path
    self.floor_encoder_type = config.floor_encoder
    self.metric = config.metric_type
    # Optional attributes from config
    self.dropout = getattr(config, "dropout", 0.0)
    self.use_pretrained_word_embedding = getattr(config, "use_pretrained_word_embedding", False)
    # Other attributes
    self.id2word = tokenizer.id2word
    self.vocab_size = len(tokenizer.word2id)
    self.pad_token_id = tokenizer.pad_token_id
    self.margin = 0.5

    # Input components
    self.word_embedding = nn.Embedding(
        self.vocab_size,
        self.word_embedding_dim,
        padding_idx=self.pad_token_id,
        _weight=init_word_embedding(
            load_pretrained_word_embedding=self.use_pretrained_word_embedding,
            pretrained_word_embedding_path=self.word_embedding_path,
            id2word=self.id2word,
            word_embedding_dim=self.word_embedding_dim,
            vocab_size=self.vocab_size,
            pad_token_id=self.pad_token_id
        ),
    )

    # Encoding components
    self.sent_encoder = EncoderRNN(
        input_dim=self.word_embedding_dim,
        hidden_dim=self.sent_encoder_hidden_dim,
        n_layers=self.n_sent_encoder_layers,
        dropout_emb=self.dropout,
        dropout_input=self.dropout,
        dropout_hidden=self.dropout,
        dropout_output=self.dropout,
        bidirectional=True,
        embedding=self.word_embedding,
        rnn_type=self.rnn_type,
    )
    self.dial_encoder = EncoderRNN(
        input_dim=self.sent_encoder_hidden_dim,
        hidden_dim=self.dial_encoder_hidden_dim,
        n_layers=self.n_dial_encoder_layers,
        dropout_emb=self.dropout,
        dropout_input=self.dropout,
        dropout_hidden=self.dropout,
        dropout_output=self.dropout,
        bidirectional=True,
        rnn_type=self.rnn_type,
    )

    # Output components
    # Regressor for unsupervised training of is-next-sentence prediction
    self.M = nn.Parameter(
        torch.FloatTensor(self.dial_encoder_hidden_dim, self.sent_encoder_hidden_dim)
    )
    self.unref_fc = nn.ModuleList([
        nn.Dropout(self.dropout),
        nn.Linear(
            self.dial_encoder_hidden_dim + self.sent_encoder_hidden_dim + 1,
            self.dial_encoder_hidden_dim
        ),
        nn.Tanh(),
        nn.Dropout(self.dropout),
        nn.Linear(self.dial_encoder_hidden_dim, 1),
        nn.Sigmoid()
    ])

    # Extra components
    # Floor encoding
    if self.floor_encoder_type == "abs":
        self.floor_encoder = AbsFloorEmbEncoder(
            input_dim=self.sent_encoder_hidden_dim,
            embedding_dim=self.attr_embedding_dim
        )
    elif self.floor_encoder_type == "rel":
        self.floor_encoder = RelFloorEmbEncoder(
            input_dim=self.sent_encoder_hidden_dim,
            embedding_dim=self.attr_embedding_dim
        )
    else:
        self.floor_encoder = None

    # Initialization
    self._init_weights()
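# A minimal sketch (an assumption, not code from this repo) of how the pieces
# above typically combine into RUBER's unreferenced score: the bilinear term
# ctx^T M reply contributes one quadratic feature, which is concatenated with
# both encodings to match the unref_fc input size of
# dial_encoder_hidden_dim + sent_encoder_hidden_dim + 1. `ctx_enc` and
# `reply_enc` are hypothetical stand-ins for the dialog- and sentence-encoder
# outputs; torch is assumed imported as elsewhere in this file.
def _unref_score_sketch(model, ctx_enc, reply_enc):
    # (batch, dial_dim) @ (dial_dim, sent_dim) -> (batch, sent_dim)
    proj = ctx_enc.matmul(model.M)
    # Bilinear score ctx^T M reply, shape (batch, 1)
    quad = (proj * reply_enc).sum(dim=1, keepdim=True)
    # Concatenate to (batch, dial_dim + sent_dim + 1)
    h = torch.cat([ctx_enc, reply_enc, quad], dim=1)
    for layer in model.unref_fc:  # ModuleList applied in order
        h = layer(h)
    return h  # sigmoid score in (0, 1)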
def __init__(self, config, tokenizer):
    super(HRED, self).__init__()

    # Attributes
    # Attributes from config
    self.word_embedding_dim = config.word_embedding_dim
    self.attr_embedding_dim = config.attr_embedding_dim
    self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
    self.n_sent_encoder_layers = config.n_sent_encoder_layers
    self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
    self.n_dial_encoder_layers = config.n_dial_encoder_layers
    self.decoder_hidden_dim = config.decoder_hidden_dim
    self.n_decoder_layers = config.n_decoder_layers
    self.use_attention = config.use_attention
    self.decode_max_len = config.decode_max_len
    self.tie_weights = config.tie_weights
    self.rnn_type = config.rnn_type
    self.gen_type = config.gen_type
    self.top_k = config.top_k
    self.top_p = config.top_p
    self.temp = config.temp
    self.word_embedding_path = config.word_embedding_path
    self.floor_encoder_type = config.floor_encoder
    # Optional attributes from config
    self.dropout_emb = getattr(config, "dropout", 0.0)
    self.dropout_input = getattr(config, "dropout", 0.0)
    self.dropout_hidden = getattr(config, "dropout", 0.0)
    self.dropout_output = getattr(config, "dropout", 0.0)
    self.use_pretrained_word_embedding = getattr(config, "use_pretrained_word_embedding", False)
    # Other attributes
    self.word2id = tokenizer.word2id
    self.id2word = tokenizer.id2word
    self.vocab_size = len(tokenizer.word2id)
    self.pad_token_id = tokenizer.pad_token_id
    self.bos_token_id = tokenizer.bos_token_id
    self.eos_token_id = tokenizer.eos_token_id

    # Input components
    self.word_embedding = nn.Embedding(
        self.vocab_size,
        self.word_embedding_dim,
        padding_idx=self.pad_token_id,
        _weight=init_word_embedding(
            load_pretrained_word_embedding=self.use_pretrained_word_embedding,
            pretrained_word_embedding_path=self.word_embedding_path,
            id2word=self.id2word,
            word_embedding_dim=self.word_embedding_dim,
            vocab_size=self.vocab_size,
            pad_token_id=self.pad_token_id
        ),
    )

    # Encoding components
    self.sent_encoder = EncoderRNN(
        input_dim=self.word_embedding_dim,
        hidden_dim=self.sent_encoder_hidden_dim,
        n_layers=self.n_sent_encoder_layers,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        bidirectional=True,
        embedding=self.word_embedding,
        rnn_type=self.rnn_type,
    )
    self.dial_encoder = EncoderRNN(
        input_dim=self.sent_encoder_hidden_dim,
        hidden_dim=self.dial_encoder_hidden_dim,
        n_layers=self.n_dial_encoder_layers,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        bidirectional=False,
        rnn_type=self.rnn_type,
    )

    # Decoding components
    self.enc2dec_hidden_fc = nn.Linear(
        self.dial_encoder_hidden_dim,
        self.n_decoder_layers * self.decoder_hidden_dim if self.rnn_type == "gru"
        else self.n_decoder_layers * self.decoder_hidden_dim * 2
    )
    self.decoder = DecoderRNN(
        vocab_size=len(self.word2id),
        input_dim=self.word_embedding_dim,
        hidden_dim=self.decoder_hidden_dim,
        feat_dim=self.dial_encoder_hidden_dim,
        n_layers=self.n_decoder_layers,
        bos_token_id=self.bos_token_id,
        eos_token_id=self.eos_token_id,
        pad_token_id=self.pad_token_id,
        max_len=self.decode_max_len,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        embedding=self.word_embedding,
        tie_weights=self.tie_weights,
        rnn_type=self.rnn_type,
        use_attention=self.use_attention,
        attn_dim=self.sent_encoder_hidden_dim
    )

    # Extra components
    # Floor encoding
    if self.floor_encoder_type == "abs":
        self.floor_encoder = AbsFloorEmbEncoder(
            input_dim=self.sent_encoder_hidden_dim,
            embedding_dim=self.attr_embedding_dim
        )
    elif self.floor_encoder_type == "rel":
        self.floor_encoder = RelFloorEmbEncoder(
            input_dim=self.sent_encoder_hidden_dim,
            embedding_dim=self.attr_embedding_dim
        )
    else:
        self.floor_encoder = None

    # Initialization
    self._init_weights()
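# A minimal sketch (an assumption, not code from this repo) of why
# enc2dec_hidden_fc's output size above depends on rnn_type: a GRU decoder
# needs one hidden state per layer, while an LSTM needs both a hidden and a
# cell state, hence the factor of 2. `dial_enc` is a hypothetical
# (batch, dial_encoder_hidden_dim) tensor of dialog encodings.
def _enc2dec_hidden_sketch(model, dial_enc):
    batch_size = dial_enc.size(0)
    flat = model.enc2dec_hidden_fc(dial_enc)
    if model.rnn_type == "gru":
        # -> (n_layers, batch, decoder_hidden_dim)
        return flat.view(
            batch_size, model.n_decoder_layers, model.decoder_hidden_dim
        ).transpose(0, 1).contiguous()
    else:
        # -> (h, c) tuple, each (n_layers, batch, decoder_hidden_dim)
        h, c = flat.chunk(2, dim=-1)
        h = h.view(batch_size, model.n_decoder_layers, model.decoder_hidden_dim).transpose(0, 1).contiguous()
        c = c.view(batch_size, model.n_decoder_layers, model.decoder_hidden_dim).transpose(0, 1).contiguous()
        return (h, c)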
def __init__(self, config, tokenizer):
    super(RNNLM, self).__init__()

    # Attributes
    # Attributes from config
    self.word_embedding_dim = config.word_embedding_dim
    self.decoder_hidden_dim = config.decoder_hidden_dim
    self.n_decoder_layers = config.n_decoder_layers
    self.decode_max_len = config.decode_max_len
    self.tie_weights = config.tie_weights
    self.rnn_type = config.rnn_type
    # Optional attributes from config
    self.gen_type = getattr(config, "gen_type", "greedy")
    self.top_k = getattr(config, "top_k", 0)
    self.top_p = getattr(config, "top_p", 0.0)
    self.temp = getattr(config, "temp", 1.0)
    self.dropout_emb = getattr(config, "dropout", 0.0)
    self.dropout_input = getattr(config, "dropout", 0.0)
    self.dropout_hidden = getattr(config, "dropout", 0.0)
    self.dropout_output = getattr(config, "dropout", 0.0)
    self.use_pretrained_word_embedding = getattr(config, "use_pretrained_word_embedding", False)
    self.word_embedding_path = getattr(config, "word_embedding_path", None)
    # Other attributes
    self.tokenizer = tokenizer
    self.word2id = tokenizer.word2id
    self.id2word = tokenizer.id2word
    self.vocab_size = len(tokenizer.word2id)
    self.pad_token_id = tokenizer.pad_token_id
    self.bos_token_id = tokenizer.bos_token_id
    self.eos_token_id = tokenizer.eos_token_id

    # Components
    self.word_embedding = nn.Embedding(
        self.vocab_size,
        self.word_embedding_dim,
        padding_idx=self.pad_token_id,
        _weight=init_word_embedding(
            load_pretrained_word_embedding=self.use_pretrained_word_embedding,
            pretrained_word_embedding_path=self.word_embedding_path,
            id2word=self.id2word,
            word_embedding_dim=self.word_embedding_dim,
            vocab_size=self.vocab_size,
            pad_token_id=self.pad_token_id
        ),
    )
    self.decoder = DecoderRNN(
        vocab_size=len(self.word2id),
        input_dim=self.word_embedding_dim,
        hidden_dim=self.decoder_hidden_dim,
        n_layers=self.n_decoder_layers,
        bos_token_id=self.bos_token_id,
        eos_token_id=self.eos_token_id,
        pad_token_id=self.pad_token_id,
        max_len=self.decode_max_len,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        embedding=self.word_embedding,
        tie_weights=self.tie_weights,
        rnn_type=self.rnn_type,
    )

    # Initialization
    self._init_weights()
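# A minimal sketch (an assumption, not code from this repo) of how the
# gen_type / top_k / top_p / temp defaults above are typically consumed at
# decoding time: greedy decoding ignores the sampling knobs, while sampling
# first divides the logits by `temp` and then truncates the distribution.
# Nucleus (top_p) filtering is analogous and omitted for brevity. `logits` is
# a hypothetical (batch, vocab_size) tensor of next-token scores.
def _next_token_sketch(model, logits):
    if model.gen_type == "greedy":
        return logits.argmax(dim=-1)
    probs = torch.softmax(logits / model.temp, dim=-1)
    if model.top_k > 0:
        # Keep only the top_k most likely tokens, then renormalize
        topk_probs, topk_ids = probs.topk(model.top_k, dim=-1)
        probs = torch.zeros_like(probs).scatter_(-1, topk_ids, topk_probs)
        probs = probs / probs.sum(dim=-1, keepdim=True)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)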
def __init__(self, config, tokenizer):
    super(VHRED, self).__init__()

    # Attributes
    # Attributes from config
    self.word_embedding_dim = config.word_embedding_dim
    self.attr_embedding_dim = config.attr_embedding_dim
    self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
    self.n_sent_encoder_layers = config.n_sent_encoder_layers
    self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
    self.n_dial_encoder_layers = config.n_dial_encoder_layers
    self.latent_variable_dim = config.latent_dim
    self.decoder_hidden_dim = config.decoder_hidden_dim
    self.n_decoder_layers = config.n_decoder_layers
    self.use_attention = config.use_attention
    self.decode_max_len = config.decode_max_len
    self.tie_weights = config.tie_weights
    self.rnn_type = config.rnn_type
    self.gen_type = config.gen_type
    self.top_k = config.top_k
    self.top_p = config.top_p
    self.temp = config.temp
    self.word_embedding_path = config.word_embedding_path
    self.floor_encoder_type = config.floor_encoder
    self.gaussian_mix_type = config.gaussian_mix_type
    # Optional attributes from config
    self.use_bow_loss = getattr(config, "use_bow_loss", True)
    self.dropout = getattr(config, "dropout", 0.0)
    self.dropout_emb = getattr(config, "dropout", 0.0)
    self.dropout_input = getattr(config, "dropout", 0.0)
    self.dropout_hidden = getattr(config, "dropout", 0.0)
    self.dropout_output = getattr(config, "dropout", 0.0)
    self.use_pretrained_word_embedding = getattr(config, "use_pretrained_word_embedding", False)
    self.n_step_annealing = getattr(config, "n_step_annealing", 1)
    self.n_components = getattr(config, "n_components", 1)
    # Other attributes
    self.vocab_size = len(tokenizer.word2id)
    self.word2id = tokenizer.word2id
    self.id2word = tokenizer.id2word
    self.pad_token_id = tokenizer.pad_token_id
    self.bos_token_id = tokenizer.bos_token_id
    self.eos_token_id = tokenizer.eos_token_id

    # Input components
    self.word_embedding = nn.Embedding(
        self.vocab_size,
        self.word_embedding_dim,
        padding_idx=self.pad_token_id,
        _weight=init_word_embedding(
            load_pretrained_word_embedding=self.use_pretrained_word_embedding,
            pretrained_word_embedding_path=self.word_embedding_path,
            id2word=self.id2word,
            word_embedding_dim=self.word_embedding_dim,
            vocab_size=self.vocab_size,
            pad_token_id=self.pad_token_id
        ),
    )

    # Encoding components
    self.sent_encoder = EncoderRNN(
        input_dim=self.word_embedding_dim,
        hidden_dim=self.sent_encoder_hidden_dim,
        n_layers=self.n_sent_encoder_layers,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        bidirectional=True,
        embedding=self.word_embedding,
        rnn_type=self.rnn_type,
    )
    self.dial_encoder = EncoderRNN(
        input_dim=self.sent_encoder_hidden_dim,
        hidden_dim=self.dial_encoder_hidden_dim,
        n_layers=self.n_dial_encoder_layers,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        bidirectional=False,
        rnn_type=self.rnn_type,
    )

    # Variational components
    # Use self.n_components (not config.n_components), which is safe when the
    # attribute is absent from config.
    if self.n_components == 1:
        self.prior_net = GaussianVariation(
            input_dim=self.dial_encoder_hidden_dim,
            z_dim=self.latent_variable_dim,
            # large_mlp=True
        )
    elif self.n_components > 1:
        if self.gaussian_mix_type == "gmm":
            self.prior_net = GMMVariation(
                input_dim=self.dial_encoder_hidden_dim,
                z_dim=self.latent_variable_dim,
                n_components=self.n_components,
            )
        elif self.gaussian_mix_type == "lgm":
            self.prior_net = LGMVariation(
                input_dim=self.dial_encoder_hidden_dim,
                z_dim=self.latent_variable_dim,
                n_components=self.n_components,
            )
        else:
            # Fail loudly instead of leaving self.prior_net undefined
            raise ValueError(f"Unknown gaussian_mix_type: {self.gaussian_mix_type}")
    self.post_net = GaussianVariation(
        input_dim=self.sent_encoder_hidden_dim + self.dial_encoder_hidden_dim,
        z_dim=self.latent_variable_dim,
    )
    self.latent_to_bow = nn.Sequential(
        nn.Linear(
            self.latent_variable_dim + self.dial_encoder_hidden_dim,
            self.latent_variable_dim
        ),
        nn.Tanh(),
        nn.Dropout(self.dropout),
        nn.Linear(self.latent_variable_dim, self.vocab_size)
    )
    self.ctx_fc = nn.Sequential(
        nn.Linear(
            self.latent_variable_dim + self.dial_encoder_hidden_dim,
            self.dial_encoder_hidden_dim,
        ),
        nn.Tanh(),
        nn.Dropout(self.dropout)
    )

    # Decoding components
    self.enc2dec_hidden_fc = nn.Linear(
        self.dial_encoder_hidden_dim,
        self.n_decoder_layers * self.decoder_hidden_dim if self.rnn_type == "gru"
        else self.n_decoder_layers * self.decoder_hidden_dim * 2
    )
    self.decoder = DecoderRNN(
        vocab_size=len(self.word2id),
        input_dim=self.word_embedding_dim,
        hidden_dim=self.decoder_hidden_dim,
        feat_dim=self.dial_encoder_hidden_dim,
        n_layers=self.n_decoder_layers,
        bos_token_id=self.bos_token_id,
        eos_token_id=self.eos_token_id,
        pad_token_id=self.pad_token_id,
        max_len=self.decode_max_len,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        embedding=self.word_embedding,
        tie_weights=self.tie_weights,
        rnn_type=self.rnn_type,
        use_attention=self.use_attention,
        attn_dim=self.sent_encoder_hidden_dim
    )

    # Extra components
    # Floor encoding
    if self.floor_encoder_type == "abs":
        self.floor_encoder = AbsFloorEmbEncoder(
            input_dim=self.sent_encoder_hidden_dim,
            embedding_dim=self.attr_embedding_dim
        )
    elif self.floor_encoder_type == "rel":
        self.floor_encoder = RelFloorEmbEncoder(
            input_dim=self.sent_encoder_hidden_dim,
            embedding_dim=self.attr_embedding_dim
        )
    else:
        self.floor_encoder = None

    # Initialization
    self._init_weights()
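# A minimal sketch (an assumption, not code from this repo) of how
# n_step_annealing above is commonly used with the prior/posterior networks:
# the KL term of the ELBO is scaled by a weight that ramps linearly from 0 to
# 1 over the first n_step_annealing training steps, which keeps the decoder
# from ignoring the latent variable early in training. `step` is the global
# training step.
def _kl_anneal_weight_sketch(model, step):
    return min(1.0, step / max(1, model.n_step_annealing))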
def __init__(self, config, tokenizer):
    super(HRESepUttrEnc, self).__init__()

    # Attributes
    # Attributes from config
    self.num_labels = len(config.dialog_acts)
    self.word_embedding_dim = config.word_embedding_dim
    self.attr_embedding_dim = config.attr_embedding_dim
    self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
    self.n_sent_encoder_layers = config.n_sent_encoder_layers
    self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
    self.n_dial_encoder_layers = config.n_dial_encoder_layers
    self.rnn_type = config.rnn_type
    self.word_embedding_path = config.word_embedding_path
    self.floor_encoder_type = config.floor_encoder
    # Optional attributes from config
    self.dropout_emb = getattr(config, "dropout", 0.0)
    self.dropout_input = getattr(config, "dropout", 0.0)
    self.dropout_hidden = getattr(config, "dropout", 0.0)
    self.dropout_output = getattr(config, "dropout", 0.0)
    self.use_pretrained_word_embedding = getattr(config, "use_pretrained_word_embedding", False)
    # Other attributes
    self.word2id = tokenizer.word2id
    self.id2word = tokenizer.id2word
    self.vocab_size = len(tokenizer.word2id)
    self.pad_token_id = tokenizer.pad_token_id

    # Input components
    self.word_embedding = nn.Embedding(
        self.vocab_size,
        self.word_embedding_dim,
        padding_idx=self.pad_token_id,
        _weight=init_word_embedding(
            load_pretrained_word_embedding=self.use_pretrained_word_embedding,
            pretrained_word_embedding_path=self.word_embedding_path,
            id2word=self.id2word,
            word_embedding_dim=self.word_embedding_dim,
            vocab_size=self.vocab_size,
            pad_token_id=self.pad_token_id
        ),
    )

    # Encoding components
    self.own_sent_encoder = EncoderRNN(
        input_dim=self.word_embedding_dim,
        hidden_dim=self.sent_encoder_hidden_dim,
        n_layers=self.n_sent_encoder_layers,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        bidirectional=True,
        embedding=self.word_embedding,
        rnn_type=self.rnn_type,
    )
    self.oth_sent_encoder = EncoderRNN(
        input_dim=self.word_embedding_dim,
        hidden_dim=self.sent_encoder_hidden_dim,
        n_layers=self.n_sent_encoder_layers,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        bidirectional=True,
        embedding=self.word_embedding,
        rnn_type=self.rnn_type,
    )
    self.dial_encoder = EncoderRNN(
        input_dim=self.sent_encoder_hidden_dim,
        hidden_dim=self.dial_encoder_hidden_dim,
        n_layers=self.n_dial_encoder_layers,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        bidirectional=False,
        rnn_type=self.rnn_type,
    )

    # Classification components
    self.output_fc = nn.Linear(self.dial_encoder_hidden_dim, self.num_labels)

    # Initialization
    self._init_weights()
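# A minimal sketch (an assumption, not code from this repo) of the idea behind
# the separate own/oth utterance encoders above: each utterance is encoded by
# both encoders, and the encoding that matches the speaker is kept, so
# utterances by the speaker of interest and by the interlocutor get distinct
# parameters. `own_enc` and `oth_enc` are hypothetical (batch, n_uttr,
# sent_dim) encodings of the same utterances from the two encoders, and
# `floors` is a hypothetical (batch, n_uttr) 0/1 speaker flag.
def _route_uttr_encoding_sketch(own_enc, oth_enc, floors):
    mask = floors.unsqueeze(-1).bool()  # (batch, n_uttr, 1)
    return torch.where(mask, own_enc, oth_enc)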
def __init__(self, config, tokenizer):
    super(VHCR, self).__init__()

    # Attributes
    # Attributes from config
    self.word_embedding_dim = config.word_embedding_dim
    self.attr_embedding_dim = config.attr_embedding_dim
    self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
    self.n_sent_encoder_layers = config.n_sent_encoder_layers
    self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
    self.n_dial_encoder_layers = config.n_dial_encoder_layers
    self.latent_variable_dim = config.latent_dim
    self.decoder_hidden_dim = config.decoder_hidden_dim
    self.n_decoder_layers = config.n_decoder_layers
    self.use_attention = config.use_attention
    self.decode_max_len = config.decode_max_len
    self.tie_weights = config.tie_weights
    self.rnn_type = config.rnn_type
    self.gen_type = config.gen_type
    self.top_k = config.top_k
    self.top_p = config.top_p
    self.temp = config.temp
    self.word_embedding_path = config.word_embedding_path
    self.floor_encoder_type = config.floor_encoder
    # Optional attributes from config
    self.dropout_emb = getattr(config, "dropout", 0.0)
    self.dropout_input = getattr(config, "dropout", 0.0)
    self.dropout_hidden = getattr(config, "dropout", 0.0)
    self.dropout_output = getattr(config, "dropout", 0.0)
    self.use_pretrained_word_embedding = getattr(config, "use_pretrained_word_embedding", False)
    self.n_step_annealing = getattr(config, "n_step_annealing", 0)
    # Other attributes
    self.vocab_size = len(tokenizer.word2id)
    self.word2id = tokenizer.word2id
    self.id2word = tokenizer.id2word
    self.pad_token_id = tokenizer.pad_token_id
    self.bos_token_id = tokenizer.bos_token_id
    self.eos_token_id = tokenizer.eos_token_id
    self.dropout_sent = 0.25

    # Input components
    self.word_embedding = nn.Embedding(
        self.vocab_size,
        self.word_embedding_dim,
        padding_idx=self.pad_token_id,
        _weight=init_word_embedding(
            load_pretrained_word_embedding=self.use_pretrained_word_embedding,
            pretrained_word_embedding_path=self.word_embedding_path,
            id2word=self.id2word,
            word_embedding_dim=self.word_embedding_dim,
            vocab_size=self.vocab_size,
            pad_token_id=self.pad_token_id
        ),
    )

    # Encoding components
    self.sent_encoder = EncoderRNN(
        input_dim=self.word_embedding_dim,
        hidden_dim=self.sent_encoder_hidden_dim,
        n_layers=self.n_sent_encoder_layers,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        bidirectional=True,
        embedding=self.word_embedding,
        rnn_type=self.rnn_type,
    )
    self.dial_encoder = EncoderRNN(
        input_dim=self.sent_encoder_hidden_dim + self.latent_variable_dim,
        hidden_dim=self.dial_encoder_hidden_dim,
        n_layers=self.n_dial_encoder_layers,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        bidirectional=False,
        rnn_type=self.rnn_type,
    )
    self.dial_infer_encoder = EncoderRNN(
        input_dim=self.sent_encoder_hidden_dim,
        hidden_dim=self.dial_encoder_hidden_dim,
        n_layers=self.n_dial_encoder_layers,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        bidirectional=True,
        rnn_type=self.rnn_type,
    )

    # Variational components
    self.dial_post_net = GaussianVariation(
        input_dim=self.dial_encoder_hidden_dim,
        z_dim=self.latent_variable_dim
    )
    self.sent_prior_net = GaussianVariation(
        input_dim=self.dial_encoder_hidden_dim + self.latent_variable_dim,
        z_dim=self.latent_variable_dim
    )
    self.sent_post_net = GaussianVariation(
        input_dim=self.sent_encoder_hidden_dim + self.dial_encoder_hidden_dim + self.latent_variable_dim,
        z_dim=self.latent_variable_dim
    )
    # Registered directly as a Parameter so it moves to the right device with
    # the module; the original trailing .to(DEVICE) returns a plain tensor on
    # GPU and silently drops the vector from model.parameters().
    self.unk_sent_vec = nn.Parameter(torch.randn(self.sent_encoder_hidden_dim))

    # Decoding components
    self.ctx_fc = nn.Linear(
        2 * self.latent_variable_dim + self.dial_encoder_hidden_dim,
        self.dial_encoder_hidden_dim
    )
    self.enc2dec_hidden_fc = nn.Linear(
        self.dial_encoder_hidden_dim,
        self.n_decoder_layers * self.decoder_hidden_dim if self.rnn_type == "gru"
        else self.n_decoder_layers * self.decoder_hidden_dim * 2
    )
    self.decoder = DecoderRNN(
        vocab_size=len(self.word2id),
        input_dim=self.word_embedding_dim,
        hidden_dim=self.decoder_hidden_dim,
        feat_dim=self.dial_encoder_hidden_dim,
        n_layers=self.n_decoder_layers,
        bos_token_id=self.bos_token_id,
        eos_token_id=self.eos_token_id,
        pad_token_id=self.pad_token_id,
        max_len=self.decode_max_len,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        embedding=self.word_embedding,
        tie_weights=self.tie_weights,
        rnn_type=self.rnn_type,
        use_attention=self.use_attention,
        attn_dim=self.sent_encoder_hidden_dim
    )

    # Extra components
    # Floor encoding
    if self.floor_encoder_type == "abs":
        self.floor_encoder = AbsFloorEmbEncoder(
            input_dim=self.sent_encoder_hidden_dim,
            embedding_dim=self.attr_embedding_dim
        )
    elif self.floor_encoder_type == "rel":
        self.floor_encoder = RelFloorEmbEncoder(
            input_dim=self.sent_encoder_hidden_dim,
            embedding_dim=self.attr_embedding_dim
        )
    else:
        self.floor_encoder = None
    # Hidden initialization
    self.dial_z2dial_enc_hidden_fc = nn.Linear(
        self.latent_variable_dim,
        self.n_dial_encoder_layers * self.dial_encoder_hidden_dim if self.rnn_type == "gru"
        else self.n_dial_encoder_layers * self.dial_encoder_hidden_dim * 2
    )

    # Initialization
    self._init_weights()
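# A minimal sketch (an assumption, not code from this repo) of what
# dropout_sent and unk_sent_vec above are for: VHCR-style utterance dropout
# replaces each sentence encoding with the learned unk_sent_vec with
# probability dropout_sent during training, forcing the conversation-level
# latent variable to carry information. `sent_encodings` is a hypothetical
# (batch, n_uttr, sent_dim) tensor.
def _sent_dropout_sketch(model, sent_encodings):
    if not model.training:
        return sent_encodings
    drop = torch.rand(sent_encodings.shape[:2], device=sent_encodings.device) < model.dropout_sent
    return torch.where(
        drop.unsqueeze(-1),
        model.unk_sent_vec.expand_as(sent_encodings),
        sent_encodings
    )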
def __init__(self, config, tokenizer):
    super(ADEM, self).__init__()

    # Attributes
    # Attributes from config
    self.word_embedding_dim = config.word_embedding_dim
    self.attr_embedding_dim = config.attr_embedding_dim
    self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
    self.n_sent_encoder_layers = config.n_sent_encoder_layers
    self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
    self.n_dial_encoder_layers = config.n_dial_encoder_layers
    assert self.sent_encoder_hidden_dim == self.dial_encoder_hidden_dim
    self.latent_variable_dim = config.latent_dim
    self.tie_weights = config.tie_weights
    self.rnn_type = config.rnn_type
    self.word_embedding_path = config.word_embedding_path
    self.floor_encoder_type = config.floor_encoder
    self.metric = config.metric_type
    # Optional attributes from config
    self.dropout_emb = getattr(config, "dropout", 0.0)
    self.dropout_input = getattr(config, "dropout", 0.0)
    self.dropout_hidden = getattr(config, "dropout", 0.0)
    self.dropout_output = getattr(config, "dropout", 0.0)
    self.use_pretrained_word_embedding = getattr(config, "use_pretrained_word_embedding", False)
    # Other attributes
    self.id2word = tokenizer.id2word
    self.vocab_size = len(tokenizer.word2id)
    self.pad_token_id = tokenizer.pad_token_id
    self.n_pca_components = 50

    # Input components
    self.word_embedding = nn.Embedding(
        self.vocab_size,
        self.word_embedding_dim,
        padding_idx=self.pad_token_id,
        _weight=init_word_embedding(
            load_pretrained_word_embedding=self.use_pretrained_word_embedding,
            pretrained_word_embedding_path=self.word_embedding_path,
            id2word=self.id2word,
            word_embedding_dim=self.word_embedding_dim,
            vocab_size=self.vocab_size,
            pad_token_id=self.pad_token_id
        ),
    )

    # Encoding components
    self.sent_encoder = EncoderRNN(
        input_dim=self.word_embedding_dim,
        hidden_dim=self.sent_encoder_hidden_dim,
        n_layers=self.n_sent_encoder_layers,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        bidirectional=True,
        embedding=self.word_embedding,
        rnn_type=self.rnn_type,
    )
    self.dial_encoder = EncoderRNN(
        input_dim=self.sent_encoder_hidden_dim,
        hidden_dim=self.dial_encoder_hidden_dim,
        n_layers=self.n_dial_encoder_layers,
        dropout_emb=self.dropout_emb,
        dropout_input=self.dropout_input,
        dropout_hidden=self.dropout_hidden,
        dropout_output=self.dropout_output,
        bidirectional=True,
        rnn_type=self.rnn_type,
    )
    self.pca = PCA(n_components=self.n_pca_components)

    # Scoring components
    self.M = nn.Linear(self.n_pca_components, self.n_pca_components)
    self.N = nn.Linear(self.n_pca_components, self.n_pca_components)
    self.alpha = None
    self.beta = None

    # Extra components
    # Floor encoding
    if self.floor_encoder_type == "abs":
        self.floor_encoder = AbsFloorEmbEncoder(
            input_dim=self.sent_encoder_hidden_dim,
            embedding_dim=self.attr_embedding_dim
        )
    elif self.floor_encoder_type == "rel":
        self.floor_encoder = RelFloorEmbEncoder(
            input_dim=self.sent_encoder_hidden_dim,
            embedding_dim=self.attr_embedding_dim
        )
    else:
        self.floor_encoder = None

    # Initialization
    self._init_weights()
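# A minimal sketch (an assumption, not code from this repo) of the ADEM
# scoring rule the components above implement: with PCA-reduced context c,
# reference r, and model response r_hat, the score is
# (c^T M r_hat + r^T N r_hat - alpha) / beta, where alpha and beta are fit
# after training to map raw scores onto the human rating scale. `ctx`, `ref`,
# and `hyp` are hypothetical (batch, n_pca_components) tensors.
def _adem_score_sketch(model, ctx, ref, hyp):
    # Dot products against the linearly transformed response encoding
    raw = (ctx * model.M(hyp)).sum(dim=1) + (ref * model.N(hyp)).sum(dim=1)
    if model.alpha is not None and model.beta is not None:
        raw = (raw - model.alpha) / model.beta
    return raw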