    @staticmethod
    def from_options(vocabs, opts):
        """

        Args:
            vocabs:
            opts:
                predict_target (bool): Predict target tags
                predict_source (bool): Predict source tags
                predict_gaps (bool): Predict gap tags
                token_level (bool): Train predictor using PE field.
                sentence_level (bool): Predict Sentence Scores
                sentence_ll (bool): Use likelihood loss for sentence scores
                                    (instead of squared error)
                binary_level: Predict binary sentence labels
                target_bad_weight: Weight for target tags bad class. Default 3.0
                source_bad_weight: Weight for source tags bad class. Default 3.0
                gaps_bad_weight: Weight for gap tags bad class. Default 3.0

        Returns:

        """
        predictor_src = predictor_tgt = None
        if opts.load_pred_source:
            predictor_src = Predictor.from_file(opts.load_pred_source)
        if opts.load_pred_target:
            predictor_tgt = Predictor.from_file(opts.load_pred_target)

        model = Estimator(
            vocabs,
            predictor_tgt=predictor_tgt,
            predictor_src=predictor_src,
            hidden_est=opts.hidden_est,
            rnn_layers_est=opts.rnn_layers_est,
            mlp_est=opts.mlp_est,
            dropout_est=opts.dropout_est,
            start_stop=opts.start_stop,
            predict_target=opts.predict_target,
            predict_gaps=opts.predict_gaps,
            predict_source=opts.predict_source,
            token_level=opts.token_level,
            sentence_level=opts.sentence_level,
            sentence_ll=opts.sentence_ll,
            binary_level=opts.binary_level,
            target_bad_weight=opts.target_bad_weight,
            source_bad_weight=opts.source_bad_weight,
            gaps_bad_weight=opts.gaps_bad_weight,
            hidden_pred=opts.hidden_pred,
            rnn_layers_pred=opts.rnn_layers_pred,
            dropout_pred=opts.dropout_pred,
            share_embeddings=opts.share_embeddings,
            embedding_sizes=opts.embedding_sizes,
            target_embeddings_size=opts.target_embeddings_size,
            source_embeddings_size=opts.source_embeddings_size,
            out_embeddings_size=opts.out_embeddings_size,
            predict_inverse=opts.predict_inverse,
        )
        return model
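
# --- Illustrative usage sketch (not part of the original source). The option
# names below mirror the attributes read by `from_options`; the values and
# the `vocabs` object are placeholders only.
#
#     from argparse import Namespace
#
#     opts = Namespace(
#         load_pred_source=None, load_pred_target=None,
#         hidden_est=125, rnn_layers_est=1, mlp_est=True, dropout_est=0.0,
#         start_stop=False, predict_target=True, predict_gaps=False,
#         predict_source=False, token_level=False, sentence_level=True,
#         sentence_ll=False, binary_level=False,
#         target_bad_weight=3.0, source_bad_weight=3.0, gaps_bad_weight=3.0,
#         hidden_pred=400, rnn_layers_pred=2, dropout_pred=0.0,
#         share_embeddings=False, embedding_sizes=0,
#         target_embeddings_size=200, source_embeddings_size=200,
#         out_embeddings_size=200, predict_inverse=False,
#     )
#     model = Estimator.from_options(vocabs, opts)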
def test_extend_vocabs(extend_vocab):
    options = extend_vocab
    OOV_SRC = 'oov_word_src'
    OOV_TGT = 'oov_word_tgt'

    fieldset = Predictor.fieldset(wmt18_format=options.wmt18_format)
    vocabs_fieldset = extend_vocabs_fieldset.build_fieldset(fieldset)

    dataset, _ = build_training_datasets(
        fieldset, extend_vocabs_fieldset=vocabs_fieldset, **vars(options))
    assert OOV_SRC in dataset.fields[constants.SOURCE].vocab.stoi
    assert OOV_TGT in dataset.fields[constants.TARGET].vocab.stoi

    fieldset = Predictor.fieldset(wmt18_format=options.wmt18_format)
    options.extend_source_vocab = None
    options.extend_target_vocab = None
    dataset, _ = build_training_datasets(fieldset, **vars(options))
    assert OOV_SRC not in dataset.fields[constants.SOURCE].vocab.stoi
    assert OOV_TGT not in dataset.fields[constants.TARGET].vocab.stoi
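
# --- Note (illustrative, inferred from the assertions above): the
# `extend_vocab` fixture presumably points `extend_source_vocab` and
# `extend_target_vocab` at extra vocabulary files containing 'oov_word_src'
# and 'oov_word_tgt'. Clearing those two options in the second half of the
# test should then leave the OOV tokens out of the built vocabularies.
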
    def __init__(self,
                 vocabs,
                 predictor_tgt=None,
                 predictor_src=None,
                 **kwargs):

        super().__init__(vocabs=vocabs, ConfigCls=EstimatorConfig, **kwargs)

        if predictor_src:
            self.config.update(predictor_src.config)
        elif predictor_tgt:
            self.config.update(predictor_tgt.config)

        # Predictor Settings #
        predict_tgt = (self.config.predict_target or self.config.predict_gaps
                       or self.config.sentence_level)
        if predict_tgt and not predictor_tgt:
            predictor_tgt = Predictor(
                vocabs=vocabs,
                predict_inverse=False,
                hidden_pred=self.config.hidden_pred,
                rnn_layers_pred=self.config.rnn_layers_pred,
                dropout_pred=self.config.dropout_pred,
                target_embeddings_size=self.config.target_embeddings_size,
                source_embeddings_size=self.config.source_embeddings_size,
                out_embeddings_size=self.config.out_embeddings_size,
            )
        if self.config.predict_source and not predictor_src:
            # Source-side predictor: predict_inverse=True swaps source and
            # target, so tags are predicted for the source sentence.
            predictor_src = Predictor(
                vocabs=vocabs,
                predict_inverse=True,
                hidden_pred=self.config.hidden_pred,
                rnn_layers_pred=self.config.rnn_layers_pred,
                dropout_pred=self.config.dropout_pred,
                target_embeddings_size=self.config.target_embeddings_size,
                source_embeddings_size=self.config.source_embeddings_size,
                out_embeddings_size=self.config.out_embeddings_size,
            )

        # Update the predictor vocabs when training at the token level.
        # They are needed by the `get_mask` call in the predictor's forward
        # pass to determine padding IDs for the `pe` side.
        if self.config.token_level:
            if predictor_src:
                predictor_src.vocabs = vocabs
            if predictor_tgt:
                predictor_tgt.vocabs = vocabs

        self.predictor_tgt = predictor_tgt
        self.predictor_src = predictor_src

        predictor_hidden = self.config.hidden_pred
        embedding_size = self.config.out_embeddings_size
        # Per-token estimator input: predictor hidden states from both
        # directions plus the predictor's output embedding (PreQEFV).
        input_size = 2 * predictor_hidden + embedding_size

        self.nb_classes = len(const.LABELS)
        self.lstm_input_size = input_size

        self.mlp = None
        self.sentence_pred = None
        self.sentence_sigma = None
        self.binary_pred = None
        self.binary_scale = None

        # Build Model #

        if self.config.start_stop:
            # Learned start and stop vectors for the PreQEFV sequence.
            self.start_PreQEFV = nn.Parameter(torch.zeros(
                1, 1, embedding_size))
            self.end_PreQEFV = nn.Parameter(torch.zeros(1, 1, embedding_size))

        if self.config.mlp_est:
            self.mlp = nn.Sequential(
                nn.Linear(input_size, self.config.hidden_est), nn.Tanh())
            self.lstm_input_size = self.config.hidden_est

        self.lstm = nn.LSTM(
            input_size=self.lstm_input_size,
            hidden_size=self.config.hidden_est,
            num_layers=self.config.rnn_layers_est,
            batch_first=True,
            dropout=self.config.dropout_est,
            bidirectional=True,
        )
        self.embedding_out = nn.Linear(2 * self.config.hidden_est,
                                       self.nb_classes)
        if self.config.predict_gaps:
            self.embedding_out_gaps = nn.Linear(4 * self.config.hidden_est,
                                                self.nb_classes)
        self.dropout = None
        if self.config.dropout_est:
            self.dropout = nn.Dropout(self.config.dropout_est)

        # Multitask Learning Objectives #
        # Sentence-level input: the estimator LSTM's final hidden state,
        # flattened over layers and directions.
        sentence_input_size = (2 * self.config.rnn_layers_est *
                               self.config.hidden_est)
        if self.config.sentence_level:
            self.sentence_pred = nn.Sequential(
                nn.Linear(sentence_input_size, sentence_input_size // 2),
                nn.Sigmoid(),
                nn.Linear(sentence_input_size // 2, sentence_input_size // 4),
                nn.Sigmoid(),
                nn.Linear(sentence_input_size // 4, 1),
            )
            self.sentence_sigma = None
            if self.config.sentence_ll:
                # Predict truncated Gaussian distribution
                self.sentence_sigma = nn.Sequential(
                    nn.Linear(sentence_input_size, sentence_input_size // 2),
                    nn.Sigmoid(),
                    nn.Linear(sentence_input_size // 2,
                              sentence_input_size // 4),
                    nn.Sigmoid(),
                    nn.Linear(sentence_input_size // 4, 1),
                    nn.Sigmoid(),
                )
        if self.config.binary_level:
            self.binary_pred = nn.Sequential(
                nn.Linear(sentence_input_size, sentence_input_size // 2),
                nn.Tanh(),
                nn.Linear(sentence_input_size // 2, sentence_input_size // 4),
                nn.Tanh(),
                nn.Linear(sentence_input_size // 4, 2),
            )

        # Build Losses #

        # FIXME: Remove dependency on magic numbers
        self.xents = nn.ModuleDict()
        weight = make_loss_weights(self.nb_classes, const.BAD_ID,
                                   self.config.target_bad_weight)

        self.xents[const.TARGET_TAGS] = nn.CrossEntropyLoss(
            reduction='sum', ignore_index=const.PAD_TAGS_ID, weight=weight)
        if self.config.predict_source:
            weight = make_loss_weights(self.nb_classes, const.BAD_ID,
                                       self.config.source_bad_weight)
            self.xents[const.SOURCE_TAGS] = nn.CrossEntropyLoss(
                reduction='sum', ignore_index=const.PAD_TAGS_ID, weight=weight)
        if self.config.predict_gaps:
            weight = make_loss_weights(self.nb_classes, const.BAD_ID,
                                       self.config.gaps_bad_weight)
            self.xents[const.GAP_TAGS] = nn.CrossEntropyLoss(
                reduction='sum', ignore_index=const.PAD_TAGS_ID, weight=weight)
        if self.config.sentence_level and not self.config.sentence_ll:
            self.mse_loss = nn.MSELoss(reduction='sum')
        if self.config.binary_level:
            self.xent_binary = nn.CrossEntropyLoss(reduction='sum')
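
# --- Illustrative sketch (assumption, not necessarily the library's actual
# implementation): `make_loss_weights` is expected to return a per-class
# weight vector with the BAD class up-weighted, roughly:
#
#     import torch
#
#     def make_loss_weights(nb_classes, target_idx, weight):
#         """All-ones weight vector with `weight` at position `target_idx`."""
#         weights = torch.ones(nb_classes)
#         weights[target_idx] = weight
#         return weights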