Example #1
    def __init__(self, config, tokenizer):
        super(RUBER, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.attr_embedding_dim = config.attr_embedding_dim
        self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
        self.n_sent_encoder_layers = config.n_sent_encoder_layers
        self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
        self.n_dial_encoder_layers = config.n_dial_encoder_layers
        self.rnn_type = config.rnn_type
        self.word_embedding_path = config.word_embedding_path
        self.floor_encoder_type = config.floor_encoder
        self.metric = config.metric_type
        # Optional attributes from config
        self.dropout = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding if hasattr(config, "use_pretrained_word_embedding") else False
        # Other attributes
        self.id2word = tokenizer.id2word
        self.vocab_size = len(tokenizer.word2id)
        self.pad_token_id = tokenizer.pad_token_id
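        # Margin for the max-margin ranking loss used to train the unreferenced metric
        # (an inference from the name; the loss itself lives outside this constructor)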
        self.margin = 0.5

        # Input components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id
            ),
        )

        # Encoding components
        self.sent_encoder = EncoderRNN(
            input_dim=self.word_embedding_dim,
            hidden_dim=self.sent_encoder_hidden_dim,
            n_layers=self.n_sent_encoder_layers,
            dropout_emb=self.dropout,
            dropout_input=self.dropout,
            dropout_hidden=self.dropout,
            dropout_output=self.dropout,
            bidirectional=True,
            embedding=self.word_embedding,
            rnn_type=self.rnn_type,
        )
        self.dial_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout,
            dropout_input=self.dropout,
            dropout_hidden=self.dropout,
            dropout_output=self.dropout,
            bidirectional=True,
            rnn_type=self.rnn_type,
        )

        # Output components
        # regressor for unsupervised training on next-utterance prediction
        self.M = nn.Parameter(
            torch.FloatTensor(self.dial_encoder_hidden_dim, self.sent_encoder_hidden_dim)
        )
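        # MLP scoring head over [context encoding; response encoding; bilinear score],
        # hence the +1 in the input dim; nn.ModuleList has no forward of its own, so
        # these layers are presumably applied sequentially in forward()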
        self.unref_fc = nn.ModuleList(
            [
                nn.Dropout(self.dropout),
                nn.Linear(self.dial_encoder_hidden_dim+self.sent_encoder_hidden_dim+1, self.dial_encoder_hidden_dim),
                nn.Tanh(),
                nn.Dropout(self.dropout),
                nn.Linear(self.dial_encoder_hidden_dim, 1),
                nn.Sigmoid()
            ]
        )

        # Extra components
        # floor encoding
        if self.floor_encoder_type == "abs":
            self.floor_encoder = AbsFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim
            )
        elif self.floor_encoder_type == "rel":
            self.floor_encoder = RelFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim
            )
        else:
            self.floor_encoder = None

        # Initialization
        self._init_weights()
Example #2
    def __init__(self, config, tokenizer):
        super(HRED, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.attr_embedding_dim = config.attr_embedding_dim
        self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
        self.n_sent_encoder_layers = config.n_sent_encoder_layers
        self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
        self.n_dial_encoder_layers = config.n_dial_encoder_layers
        self.decoder_hidden_dim = config.decoder_hidden_dim
        self.n_decoder_layers = config.n_decoder_layers
        self.use_attention = config.use_attention
        self.decode_max_len = config.decode_max_len
        self.tie_weights = config.tie_weights
        self.rnn_type = config.rnn_type
        self.gen_type = config.gen_type
        self.top_k = config.top_k
        self.top_p = config.top_p
        self.temp = config.temp
        self.word_embedding_path = config.word_embedding_path
        self.floor_encoder_type = config.floor_encoder
        # Optional attributes from config
        self.dropout_emb = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding if hasattr(config, "use_pretrained_word_embedding") else False
        # Other attributes
        self.word2id = tokenizer.word2id
        self.id2word = tokenizer.id2word
        self.vocab_size = len(tokenizer.word2id)
        self.pad_token_id = tokenizer.pad_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

        # Input components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id
            ),
        )

        # Encoding components
        self.sent_encoder = EncoderRNN(
            input_dim=self.word_embedding_dim,
            hidden_dim=self.sent_encoder_hidden_dim,
            n_layers=self.n_sent_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            embedding=self.word_embedding,
            rnn_type=self.rnn_type,
        )
        self.dial_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=False,
            rnn_type=self.rnn_type,
        )

        # Decoding components
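        # Project the final dialogue-encoder state to the decoder's initial hidden
        # state; the LSTM case needs twice the size (hidden state h plus cell state c)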
        self.enc2dec_hidden_fc = nn.Linear(
            self.dial_encoder_hidden_dim,
            self.n_decoder_layers * self.decoder_hidden_dim if self.rnn_type == "gru"
            else self.n_decoder_layers * self.decoder_hidden_dim * 2
        )
        self.decoder = DecoderRNN(vocab_size=len(self.word2id),
                                  input_dim=self.word_embedding_dim,
                                  hidden_dim=self.decoder_hidden_dim,
                                  feat_dim=self.dial_encoder_hidden_dim,
                                  n_layers=self.n_decoder_layers,
                                  bos_token_id=self.bos_token_id,
                                  eos_token_id=self.eos_token_id,
                                  pad_token_id=self.pad_token_id,
                                  max_len=self.decode_max_len,
                                  dropout_emb=self.dropout_emb,
                                  dropout_input=self.dropout_input,
                                  dropout_hidden=self.dropout_hidden,
                                  dropout_output=self.dropout_output,
                                  embedding=self.word_embedding,
                                  tie_weights=self.tie_weights,
                                  rnn_type=self.rnn_type,
                                  use_attention=self.use_attention,
                                  attn_dim=self.sent_encoder_hidden_dim)

        # Extra components
        # floor encoding
        if self.floor_encoder_type == "abs":
            self.floor_encoder = AbsFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim)
        elif self.floor_encoder_type == "rel":
            self.floor_encoder = RelFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim)
        else:
            self.floor_encoder = None

        # Initialization
        self._init_weights()
Example #3
    def __init__(self, config, tokenizer):
        super(RNNLM, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.decoder_hidden_dim = config.decoder_hidden_dim
        self.n_decoder_layers = config.n_decoder_layers
        self.decode_max_len = config.decode_max_len
        self.tie_weights = config.tie_weights
        self.rnn_type = config.rnn_type
        # Optional attributes from config
        self.gen_type = config.gen_type if hasattr(config, "gen_type") else "greedy"
        self.top_k = config.top_k if hasattr(config, "top_k") else 0
        self.top_p = config.top_p if hasattr(config, "top_p") else 0.0
        self.temp = config.temp if hasattr(config, "temp") else 1.0
        self.dropout_emb = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding if hasattr(config, "use_pretrained_word_embedding") else False
        self.word_embedding_path = config.word_embedding_path if hasattr(config, "word_embedding_path") else None
        # Other attributes
        self.tokenizer = tokenizer
        self.word2id = tokenizer.word2id
        self.id2word = tokenizer.id2word
        self.vocab_size = len(tokenizer.word2id)
        self.pad_token_id = tokenizer.pad_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

        # Components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id
            ),
        )
        self.decoder = DecoderRNN(
            vocab_size=len(self.word2id),
            input_dim=self.word_embedding_dim,
            hidden_dim=self.decoder_hidden_dim,
            n_layers=self.n_decoder_layers,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            max_len=self.decode_max_len,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            embedding=self.word_embedding,
            tie_weights=self.tie_weights,
            rnn_type=self.rnn_type,
        )

        # Initialization
        self._init_weights()
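
A minimal usage sketch for the constructor above (illustrative, not from the source): it assumes a SimpleNamespace config carrying the required attributes and a toy tokenizer exposing word2id, id2word, and the pad/bos/eos token ids that the constructor reads; optional attributes (gen_type, top_k, dropout, ...) fall back to their hasattr defaults.

    from types import SimpleNamespace

    # Hypothetical stand-in for the real tokenizer; only the attributes
    # read by __init__ are provided.
    class ToyTokenizer:
        def __init__(self, words):
            specials = ["<pad>", "<bos>", "<eos>"]
            self.word2id = {w: i for i, w in enumerate(specials + words)}
            self.id2word = {i: w for w, i in self.word2id.items()}
            self.pad_token_id = self.word2id["<pad>"]
            self.bos_token_id = self.word2id["<bos>"]
            self.eos_token_id = self.word2id["<eos>"]

    config = SimpleNamespace(
        word_embedding_dim=300,
        decoder_hidden_dim=512,
        n_decoder_layers=2,
        decode_max_len=40,
        tie_weights=False,
        rnn_type="gru",
    )
    model = RNNLM(config, ToyTokenizer(["hello", "world"]))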
Example #4
    def __init__(self, config, tokenizer):
        super(VHRED, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.attr_embedding_dim = config.attr_embedding_dim
        self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
        self.n_sent_encoder_layers = config.n_sent_encoder_layers
        self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
        self.n_dial_encoder_layers = config.n_dial_encoder_layers
        self.latent_variable_dim = config.latent_dim
        self.decoder_hidden_dim = config.decoder_hidden_dim
        self.n_decoder_layers = config.n_decoder_layers
        self.use_attention = config.use_attention
        self.decode_max_len = config.decode_max_len
        self.tie_weights = config.tie_weights
        self.rnn_type = config.rnn_type
        self.gen_type = config.gen_type
        self.top_k = config.top_k
        self.top_p = config.top_p
        self.temp = config.temp
        self.word_embedding_path = config.word_embedding_path
        self.floor_encoder_type = config.floor_encoder
        self.gaussian_mix_type = config.gaussian_mix_type
        # Optional attributes from config
        self.use_bow_loss = config.use_bow_loss if hasattr(config, "use_bow_loss") else True
        self.dropout = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_emb = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding if hasattr(config, "use_pretrained_word_embedding") else False
        self.n_step_annealing = config.n_step_annealing if hasattr(config, "n_step_annealing") else 1
        self.n_components = config.n_components if hasattr(config, "n_components") else 1
        # Other attributes
        self.vocab_size = len(tokenizer.word2id)
        self.word2id = tokenizer.word2id
        self.id2word = tokenizer.id2word
        self.pad_token_id = tokenizer.pad_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

        # Input components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id
            ),
        )

        # Encoding components
        self.sent_encoder = EncoderRNN(
            input_dim=self.word_embedding_dim,
            hidden_dim=self.sent_encoder_hidden_dim,
            n_layers=self.n_sent_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            embedding=self.word_embedding,
            rnn_type=self.rnn_type,
        )
        self.dial_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=False,
            rnn_type=self.rnn_type,
        )

        # Variational components
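        # Prior p(z | context) and posterior q(z | context, response) over the latent variable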
        # Use the guarded self.n_components (config.n_components may be absent)
        if self.n_components == 1:
            self.prior_net = GaussianVariation(
                input_dim=self.dial_encoder_hidden_dim,
                z_dim=self.latent_variable_dim,
                # large_mlp=True
            )
        elif self.n_components > 1:
            if self.gaussian_mix_type == "gmm":
                self.prior_net = GMMVariation(
                    input_dim=self.dial_encoder_hidden_dim,
                    z_dim=self.latent_variable_dim,
                    n_components=self.n_components,
                )
            elif self.gaussian_mix_type == "lgm":
                self.prior_net = LGMVariation(
                    input_dim=self.dial_encoder_hidden_dim,
                    z_dim=self.latent_variable_dim,
                    n_components=self.n_components,
                )
            else:
                raise ValueError(f"Unknown gaussian_mix_type: {self.gaussian_mix_type}")
        self.post_net = GaussianVariation(
            input_dim=self.sent_encoder_hidden_dim + self.dial_encoder_hidden_dim,
            z_dim=self.latent_variable_dim,
        )
        self.latent_to_bow = nn.Sequential(
            nn.Linear(self.latent_variable_dim + self.dial_encoder_hidden_dim, self.latent_variable_dim),
            nn.Tanh(),
            nn.Dropout(self.dropout),
            nn.Linear(self.latent_variable_dim, self.vocab_size),
        )
        self.ctx_fc = nn.Sequential(
            nn.Linear(self.latent_variable_dim + self.dial_encoder_hidden_dim, self.dial_encoder_hidden_dim),
            nn.Tanh(),
            nn.Dropout(self.dropout),
        )

        # Decoding components
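        # Project the latent-conditioned context vector to the decoder's initial hidden
        # state; the LSTM case needs twice the size (hidden state h plus cell state c)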
        self.enc2dec_hidden_fc = nn.Linear(
            self.dial_encoder_hidden_dim,
            self.n_decoder_layers * self.decoder_hidden_dim if self.rnn_type == "gru"
            else self.n_decoder_layers * self.decoder_hidden_dim * 2
        )
        self.decoder = DecoderRNN(vocab_size=len(self.word2id),
                                  input_dim=self.word_embedding_dim,
                                  hidden_dim=self.decoder_hidden_dim,
                                  feat_dim=self.dial_encoder_hidden_dim,
                                  n_layers=self.n_decoder_layers,
                                  bos_token_id=self.bos_token_id,
                                  eos_token_id=self.eos_token_id,
                                  pad_token_id=self.pad_token_id,
                                  max_len=self.decode_max_len,
                                  dropout_emb=self.dropout_emb,
                                  dropout_input=self.dropout_input,
                                  dropout_hidden=self.dropout_hidden,
                                  dropout_output=self.dropout_output,
                                  embedding=self.word_embedding,
                                  tie_weights=self.tie_weights,
                                  rnn_type=self.rnn_type,
                                  use_attention=self.use_attention,
                                  attn_dim=self.sent_encoder_hidden_dim)

        # Extra components
        # Floor encoding
        if self.floor_encoder_type == "abs":
            self.floor_encoder = AbsFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim)
        elif self.floor_encoder_type == "rel":
            self.floor_encoder = RelFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim)
        else:
            self.floor_encoder = None

        # Initialization
        self._init_weights()
Example #5
    def __init__(self, config, tokenizer):
        super(HRESepUttrEnc, self).__init__()

        # Attributes
        # Attributes from config
        self.num_labels = len(config.dialog_acts)
        self.word_embedding_dim = config.word_embedding_dim
        self.attr_embedding_dim = config.attr_embedding_dim
        self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
        self.n_sent_encoder_layers = config.n_sent_encoder_layers
        self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
        self.n_dial_encoder_layers = config.n_dial_encoder_layers
        self.rnn_type = config.rnn_type
        self.word_embedding_path = config.word_embedding_path
        self.floor_encoder_type = config.floor_encoder
        # Optional attributes from config
        self.dropout_emb = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding if hasattr(config, "use_pretrained_word_embedding") else False
        # Other attributes
        self.word2id = tokenizer.word2id
        self.id2word = tokenizer.id2word
        self.vocab_size = len(tokenizer.word2id)
        self.pad_token_id = tokenizer.pad_token_id

        # Input components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id
            ),
        )

        # Encoding components
        self.own_sent_encoder = EncoderRNN(
            input_dim=self.word_embedding_dim,
            hidden_dim=self.sent_encoder_hidden_dim,
            n_layers=self.n_sent_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            embedding=self.word_embedding,
            rnn_type=self.rnn_type,
        )
        self.oth_sent_encoder = EncoderRNN(
            input_dim=self.word_embedding_dim,
            hidden_dim=self.sent_encoder_hidden_dim,
            n_layers=self.n_sent_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            embedding=self.word_embedding,
            rnn_type=self.rnn_type,
        )
        self.dial_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=False,
            rnn_type=self.rnn_type,
        )

        # Classification components
        self.output_fc = nn.Linear(self.dial_encoder_hidden_dim, self.num_labels)

        # Initialization
        self._init_weights()
Example #6
    def __init__(self, config, tokenizer):
        super(VHCR, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.attr_embedding_dim = config.attr_embedding_dim
        self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
        self.n_sent_encoder_layers = config.n_sent_encoder_layers
        self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
        self.n_dial_encoder_layers = config.n_dial_encoder_layers
        self.latent_variable_dim = config.latent_dim
        self.decoder_hidden_dim = config.decoder_hidden_dim
        self.n_decoder_layers = config.n_decoder_layers
        self.use_attention = config.use_attention
        self.decode_max_len = config.decode_max_len
        self.tie_weights = config.tie_weights
        self.rnn_type = config.rnn_type
        self.gen_type = config.gen_type
        self.top_k = config.top_k
        self.top_p = config.top_p
        self.temp = config.temp
        self.word_embedding_path = config.word_embedding_path
        self.floor_encoder_type = config.floor_encoder
        # Optional attributes from config
        self.dropout_emb = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding if hasattr(config, "use_pretrained_word_embedding") else False
        self.n_step_annealing = config.n_step_annealing if hasattr(config, "n_step_annealing") else 0
        # Other attributes
        self.vocab_size = len(tokenizer.word2id)
        self.word2id = tokenizer.word2id
        self.id2word = tokenizer.id2word
        self.pad_token_id = tokenizer.pad_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.dropout_sent = 0.25

        # Input components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id
            ),
        )

        # Encoding components
        self.sent_encoder = EncoderRNN(
            input_dim=self.word_embedding_dim,
            hidden_dim=self.sent_encoder_hidden_dim,
            n_layers=self.n_sent_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            embedding=self.word_embedding,
            rnn_type=self.rnn_type,
        )
        self.dial_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim+self.latent_variable_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=False,
            rnn_type=self.rnn_type,
        )
        self.dial_infer_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            rnn_type=self.rnn_type,
        )

        # Variational components
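        # VHCR keeps two latent variables: a dialogue-level z (dial_post_net) and
        # per-utterance z's (sent_prior_net / sent_post_net)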
        self.dial_post_net = GaussianVariation(
            input_dim=self.dial_encoder_hidden_dim,
            z_dim=self.latent_variable_dim
        )
        self.sent_prior_net = GaussianVariation(
            input_dim=self.dial_encoder_hidden_dim+self.latent_variable_dim,
            z_dim=self.latent_variable_dim
        )
        self.sent_post_net = GaussianVariation(
            input_dim=self.sent_encoder_hidden_dim+self.dial_encoder_hidden_dim+self.latent_variable_dim,
            z_dim=self.latent_variable_dim
        )
        # Note: calling .to(DEVICE) on an nn.Parameter returns a plain tensor and would
        # silently unregister it; the module's own .to()/.cuda() moves this parameter.
        self.unk_sent_vec = nn.Parameter(torch.randn(self.sent_encoder_hidden_dim))

        # Decoding components
        self.ctx_fc = nn.Linear(
            2*self.latent_variable_dim+self.dial_encoder_hidden_dim,
            self.dial_encoder_hidden_dim
        )
        self.enc2dec_hidden_fc = nn.Linear(
            self.dial_encoder_hidden_dim,
            self.n_decoder_layers*self.decoder_hidden_dim if self.rnn_type == "gru" else self.n_decoder_layers*self.decoder_hidden_dim*2
        )
        self.decoder = DecoderRNN(
            vocab_size=len(self.word2id),
            input_dim=self.word_embedding_dim,
            hidden_dim=self.decoder_hidden_dim,
            feat_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_decoder_layers,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            max_len=self.decode_max_len,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            embedding=self.word_embedding,
            tie_weights=self.tie_weights,
            rnn_type=self.rnn_type,
            use_attention=self.use_attention,
            attn_dim=self.sent_encoder_hidden_dim
        )

        # Extra components
        # Floor encoding
        if self.floor_encoder_type == "abs":
            self.floor_encoder = AbsFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim
            )
        elif self.floor_encoder_type == "rel":
            self.floor_encoder = RelFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim
            )
        else:
            self.floor_encoder = None
        # Hidden initialization
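        # Project the dialogue-level latent to the dialogue encoder's initial hidden state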
        self.dial_z2dial_enc_hidden_fc = nn.Linear(
            self.latent_variable_dim,
            self.n_dial_encoder_layers*self.dial_encoder_hidden_dim if self.rnn_type == "gru" else self.n_dial_encoder_layers*self.dial_encoder_hidden_dim*2
        )

        # Initialization
        self._init_weights()
Example #7
    def __init__(self, config, tokenizer):
        super(ADEM, self).__init__()

        # Attributes
        # Attributes from config
        self.word_embedding_dim = config.word_embedding_dim
        self.attr_embedding_dim = config.attr_embedding_dim
        self.sent_encoder_hidden_dim = config.sent_encoder_hidden_dim
        self.n_sent_encoder_layers = config.n_sent_encoder_layers
        self.dial_encoder_hidden_dim = config.dial_encoder_hidden_dim
        self.n_dial_encoder_layers = config.n_dial_encoder_layers
        assert self.sent_encoder_hidden_dim == self.dial_encoder_hidden_dim
        self.latent_variable_dim = config.latent_dim
        self.tie_weights = config.tie_weights
        self.rnn_type = config.rnn_type
        self.word_embedding_path = config.word_embedding_path
        self.floor_encoder_type = config.floor_encoder
        self.metric = config.metric_type
        # Optional attributes from config
        self.dropout_emb = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_input = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_hidden = config.dropout if hasattr(config, "dropout") else 0.0
        self.dropout_output = config.dropout if hasattr(config, "dropout") else 0.0
        self.use_pretrained_word_embedding = config.use_pretrained_word_embedding if hasattr(config, "use_pretrained_word_embedding") else False
        # Other attributes
        self.id2word = tokenizer.id2word
        self.vocab_size = len(tokenizer.word2id)
        self.pad_token_id = tokenizer.pad_token_id
        self.n_pca_components = 50

        # Input components
        self.word_embedding = nn.Embedding(
            self.vocab_size,
            self.word_embedding_dim,
            padding_idx=self.pad_token_id,
            _weight=init_word_embedding(
                load_pretrained_word_embedding=self.use_pretrained_word_embedding,
                pretrained_word_embedding_path=self.word_embedding_path,
                id2word=self.id2word,
                word_embedding_dim=self.word_embedding_dim,
                vocab_size=self.vocab_size,
                pad_token_id=self.pad_token_id
            ),
        )

        # Encoding components
        self.sent_encoder = EncoderRNN(
            input_dim=self.word_embedding_dim,
            hidden_dim=self.sent_encoder_hidden_dim,
            n_layers=self.n_sent_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            embedding=self.word_embedding,
            rnn_type=self.rnn_type,
        )
        self.dial_encoder = EncoderRNN(
            input_dim=self.sent_encoder_hidden_dim,
            hidden_dim=self.dial_encoder_hidden_dim,
            n_layers=self.n_dial_encoder_layers,
            dropout_emb=self.dropout_emb,
            dropout_input=self.dropout_input,
            dropout_hidden=self.dropout_hidden,
            dropout_output=self.dropout_output,
            bidirectional=True,
            rnn_type=self.rnn_type,
        )
        self.pca = PCA(n_components=self.n_pca_components)

        # Scoring components
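        # ADEM-style score on PCA-reduced encodings: ctx^T M r_hat + ref^T N r_hat,
        # with alpha and beta presumably fitted elsewhere before scoring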
        self.M = nn.Linear(self.n_pca_components, self.n_pca_components)
        self.N = nn.Linear(self.n_pca_components, self.n_pca_components)
        self.alpha = None
        self.beta = None

        # Extra components
        # floor encoding
        if self.floor_encoder_type == "abs":
            self.floor_encoder = AbsFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim)
        elif self.floor_encoder_type == "rel":
            self.floor_encoder = RelFloorEmbEncoder(
                input_dim=self.sent_encoder_hidden_dim,
                embedding_dim=self.attr_embedding_dim)
        else:
            self.floor_encoder = None

        # Initialization
        self._init_weights()