Example #1
class MultiProjectorManager(Module):
    def __init__(self, config):

        super().__init__()

        self.nametag = 'MultiProjectorsManager'

        self.setup_classes_factory()

        self.classes = self.get_classes(config)

        self.projectors = ModuleDict(
            {k: cl(config)
             for k, cl in self.classes.items()})
        # print('projectors: {}'.format(self.projectors))

    def setup_classes_factory(self):
        self.factory_dict = {
            "HyperbolicProjector": NickelProjector,
            'CosineProjector': Type2VecProjector
        }

    def get_classes(self, conf):
        self.names = conf[self.nametag]['PROJECTOR_CONFIGS'].split(' ')
        classes = {}
        for name in self.names:
            classes[conf[name]['NAME']] = self.factory_dict[conf[name]
                                                            ['Class']]
        return classes

    def forward(self, vec):
        projections = {}
        for k, projector in self.projectors.items():
            projections[k] = projector(vec)
        return projections
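
The class above depends on project-specific projector classes (NickelProjector, Type2VecProjector) and a particular config layout. A minimal, self-contained sketch of the same ModuleDict-of-projectors pattern, using invented LinearProjector/ToyMultiProjector names and a plain dict instead of the config object:

import torch
from torch import nn


class LinearProjector(nn.Module):
    """Toy stand-in for a projector class such as NickelProjector."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, vec):
        return self.proj(vec)


class ToyMultiProjector(nn.Module):
    def __init__(self, in_dim: int, head_dims: dict):
        super().__init__()
        # One sub-module per projection head; ModuleDict registers them so their
        # parameters are tracked and moved together with the parent module.
        self.projectors = nn.ModuleDict(
            {name: LinearProjector(in_dim, dim) for name, dim in head_dims.items()})

    def forward(self, vec):
        # Return a dict of projections, one entry per registered head.
        return {name: proj(vec) for name, proj in self.projectors.items()}


model = ToyMultiProjector(in_dim=16, head_dims={"hyperbolic": 8, "cosine": 4})
out = model(torch.randn(2, 16))
print({k: v.shape for k, v in out.items()})  # e.g. {'hyperbolic': (2, 8), 'cosine': (2, 4)}
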
Example #2
    def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
        """The training/validation/test step.

        Override for custom behavior.
        """
        x, y = batch
        y_hat = self(x)
        y, y_hat = self.apply_filtering(y, y_hat)
        output = {"y_hat": y_hat}
        y_hat = self.to_loss_format(output["y_hat"])
        losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}

        y_hat = self.to_metrics_format(output["y_hat"])

        logs = {}

        for name, metric in metrics.items():
            if isinstance(metric, torchmetrics.metric.Metric):
                metric(y_hat, y)
                logs[
                    name] = metric  # log the metric itself if it is of type Metric
            else:
                logs[name] = metric(y_hat, y)

        if len(losses.values()) > 1:
            logs["total_loss"] = sum(losses.values())
            return logs["total_loss"], logs

        output["loss"] = self.compute_loss(losses)
        output["logs"] = self.compute_logs(logs, losses)
        output["y"] = y
        return output
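
The step above builds a dictionary of losses and sums them when more than one loss function is configured. A small, framework-free sketch of that aggregation; the loss names and functions here are purely illustrative:

import torch
from torch import nn

# Registering the loss modules in a ModuleDict mirrors how self.loss_fn is typically stored.
loss_fns = nn.ModuleDict({"mse": nn.MSELoss(), "l1": nn.L1Loss()})

y_hat = torch.randn(4, 3)
y = torch.randn(4, 3)

losses = {name: fn(y_hat, y) for name, fn in loss_fns.items()}

logs = dict(losses)
if len(losses) > 1:
    # Report the individual terms and optimise their sum.
    logs["total_loss"] = sum(losses.values())
total = logs.get("total_loss", next(iter(losses.values())))
print({k: float(v) for k, v in logs.items()}, float(total))
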
Example #3
class Plugin(Module):
    """Module used to compute prop values that are passed to the decoder."""
    def __init__(self,
                 model_dim,
                 ff_dim,
                 nheads,
                 nlayers,
                 dropout=0.0,
                 mem_att_dropout=0.0,
                 mem_dim=None,
                 len_prop=True,
                 rating_prop=True,
                 rouge_prop=True,
                 pov_prop=True):
        super(Plugin, self).__init__()
        assert len_prop or rating_prop or rouge_prop or pov_prop
        self.model_dim = model_dim
        self.tr_stack = TransformerStack(model_dim=model_dim,
                                         mem_dim=mem_dim,
                                         num_layers=nlayers,
                                         dropout=dropout,
                                         mem_att_dropout=mem_att_dropout,
                                         ff_dim=ff_dim,
                                         nheads=nheads)
        self.fns = ModuleDict()
        if len_prop:
            self.fns[ModelF.LEN_PROP] = Linear(model_dim, 1)
        if rating_prop:
            self.fns[ModelF.RATING_PROP] = Linear(model_dim, 1)
        if rouge_prop:
            self.fns[ModelF.ROUGE_PROP] = Sequential()
            self.fns[ModelF.ROUGE_PROP].add_module('lin', Linear(model_dim, 3))
            self.fns[ModelF.ROUGE_PROP].add_module('sigmoid', Sigmoid())
        if pov_prop:
            self.fns[ModelF.POV_PROP] = Sequential()
            self.fns[ModelF.POV_PROP].add_module("lin", Linear(model_dim, 4))
            self.fns[ModelF.POV_PROP].add_module('softmax', Softmax(dim=-1))

    def forward(self, mem, mem_bin_mask):
        """Computes props by running one-step transformer.

        Args:
            mem: [batch_size, seq_len, model_dim]
            mem_bin_mask: [batch_size, seq_len]

        Returns:
            dict mapping prop names to computed values [batch_size, dim].
        """
        bs = mem.size(0)
        dummy_inp = T.ones((bs, 1, self.model_dim), device=mem.device)
        out, _, tr_arts = self.tr_stack(tgt=dummy_inp,
                                        mem=mem,
                                        mem_key_padding_mask=mem_bin_mask)
        out = out.squeeze(1)
        props = {n: fn(out).squeeze(-1) for n, fn in self.fns.items()}
        mem_att_wts = tr_arts[MEM_ATT_WTS]
        return props, mem_att_wts
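
Plugin keeps one small head per property in a ModuleDict and applies all of them to the same pooled representation. A compact sketch of that heads-over-shared-features pattern; the property names and dimensions are made up:

import torch
from torch import nn

model_dim = 32
heads = nn.ModuleDict({
    "len": nn.Linear(model_dim, 1),
    "rating": nn.Linear(model_dim, 1),
    "rouge": nn.Sequential(nn.Linear(model_dim, 3), nn.Sigmoid()),
    "pov": nn.Sequential(nn.Linear(model_dim, 4), nn.Softmax(dim=-1)),
})

pooled = torch.randn(8, model_dim)  # stand-in for the one-step transformer output
props = {name: head(pooled).squeeze(-1) for name, head in heads.items()}
print({k: tuple(v.shape) for k, v in props.items()})
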
Example #4
 def compute_and_log_metrics(self, logits: torch.Tensor,
                             targets: torch.Tensor, subject_ids: List[str],
                             is_training: bool, metrics: ModuleDict,
                             logger: DataframeLogger, current_epoch: int,
                             data_split: ModelExecutionMode) -> None:
     """
     Computes all the metrics for a given (logits, labels) pair, and writes them to the loggers.
     :param logits: The model output before normalization.
     :param targets: The expected model outputs.
     :param subject_ids: The subject IDs for the present minibatch.
     :param is_training: If True, write the metrics as training metrics, otherwise as validation metrics.
     :param metrics: A dictionary mapping from names of prediction targets to a list of metric computers,
     as returned by create_metric_computers.
     :param logger: An object of type DataframeLogger which can be used for logging within this function.
     :param current_epoch: Current epoch number.
     :param data_split: ModelExecutionMode object indicating if this is the train or validation split.
     :return:
     """
     per_subject_outputs: List[Tuple[str, str, torch.Tensor,
                                     torch.Tensor]] = []
     for i, (prediction_target, metric_list) in enumerate(metrics.items()):
         # mask the model outputs and labels if required
         masked = get_masked_model_outputs_and_labels(
             logits[:, i, ...], targets[:, i, ...], subject_ids)
         # compute metrics on valid masked tensors only
         if masked is not None:
             _logits = masked.model_outputs.data
             _posteriors = self.get_post_loss_logits_normalization_function(
             )(_logits)
             # Classification metrics expect labels as integers, but they are float throughout the rest of the code
             labels_dtype = torch.int if self.is_classification_model else _posteriors.dtype
             _labels = masked.labels.data.to(dtype=labels_dtype)
             _subject_ids = masked.subject_ids
             assert _subject_ids is not None
             for metric in metric_list:
                 if isinstance(
                         metric,
                         ScalarMetricsBase) and metric.compute_from_logits:
                     metric(_logits, _labels)
                 else:
                     metric(_posteriors, _labels)
             per_subject_outputs.extend(
                 zip(_subject_ids, [prediction_target] * len(_subject_ids),
                     _posteriors.tolist(), _labels.tolist()))
     # Write a full breakdown of per-subject predictions and labels to a file. These files are local to the current
     # rank in distributed training, and will be aggregated after training.
     for subject, prediction_target, model_output, label in per_subject_outputs:
         logger.add_record({
             LoggingColumns.Epoch.value: current_epoch,
             LoggingColumns.Patient.value: subject,
             LoggingColumns.Hue.value: prediction_target,
             LoggingColumns.ModelOutput.value: model_output,
             LoggingColumns.Label.value: label,
             LoggingColumns.DataSplit.value: data_split.value
         })
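
Here `metrics` maps each prediction target to a list of metric computers, so the natural container is a ModuleDict of ModuleLists. A minimal sketch of that layout and the iteration pattern; the metric class is a toy stand-in (not the project's ScalarMetricsBase), and masking and per-subject logging are omitted:

import torch
from torch import nn


class MeanAbsError(nn.Module):
    """Toy metric computer used only for illustration."""

    def forward(self, posteriors, labels):
        return (posteriors - labels.float()).abs().mean()


metrics = nn.ModuleDict({
    "target_a": nn.ModuleList([MeanAbsError()]),
    "target_b": nn.ModuleList([MeanAbsError()]),
})

logits = torch.randn(5, 2)            # [batch, num_prediction_targets]
labels = torch.randint(0, 2, (5, 2))

for i, (target, metric_list) in enumerate(metrics.items()):
    posteriors = torch.sigmoid(logits[:, i])
    for metric in metric_list:
        value = metric(posteriors, labels[:, i])
        print(target, float(value))
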
Example #5
    def _log_epoch(self, ep: int, modules: nn.ModuleDict) -> None:
        r"""
        Log the results of the epoch
        :param ep: Epoch number
        :param modules: Modules to log
        """
        flds = []
        for _, module in modules.items():
            flds.extend(module.epoch_log_fields())

        flds.append(time.time() - self._train_start)
        self._logger.log(ep, flds)
Example #6
    def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
        """Implement the core logic for the training/validation/test step. By default this includes:

            - Inference on the current batch
            - Calculating the loss
            - Calculating relevant metrics

        Override for custom behavior.

        Args:
            batch: The output of your dataloader. Can either be a single Tensor or a list of Tensors.
            batch_idx: Integer displaying index of this batch
            metrics: A module dict containing metrics for calculating relevant training statistics

        Returns:
            A dict containing both the loss and relevant metrics
        """
        x, y = batch
        y_hat = self(x)
        y, y_hat = self.apply_filtering(y, y_hat)
        output = {OutputKeys.OUTPUT: y_hat}
        y_hat = self.to_loss_format(output[OutputKeys.OUTPUT])
        losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}

        y_hat = self.to_metrics_format(output[OutputKeys.OUTPUT])

        logs = {}

        for name, metric in metrics.items():
            if isinstance(metric, torchmetrics.metric.Metric):
                metric(y_hat, y)
                # PL 1.4.0 -> 1.4.9 tries to deepcopy the metric.
                # Sometimes _forward_cache is not a leaf, so we convert it to one.
                if not metric._forward_cache.is_leaf and not _PL_GREATER_EQUAL_1_5_0:
                    metric._forward_cache = metric._forward_cache.clone(
                    ).detach()
                logs[
                    name] = metric  # log the metric itself if it is of type Metric
            else:
                logs[name] = metric(y_hat, y)

        if len(losses.values()) > 1:
            logs["total_loss"] = sum(losses.values())
            return logs["total_loss"], logs

        output[OutputKeys.LOSS] = self.compute_loss(losses)
        output[OutputKeys.LOGS] = self.compute_logs(logs, losses)
        output[OutputKeys.TARGET] = y
        output[OutputKeys.BATCH_SIZE] = y.shape[0] if isinstance(
            y, torch.Tensor) else None
        return output
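
When the ModuleDict values are torchmetrics Metric objects, the metric object itself is put into the logs so the framework can accumulate and reduce it across steps. A small standalone sketch of that branching, assuming a recent torchmetrics release where Accuracy takes a task argument; the Lightning logging call is replaced by a print:

import torch
from torch import nn
import torchmetrics

# Keeping metrics in a ModuleDict keeps their internal state on the right device.
metrics = nn.ModuleDict({
    "accuracy": torchmetrics.Accuracy(task="multiclass", num_classes=3),
})

y_hat = torch.randn(8, 3)            # logits
y = torch.randint(0, 3, (8,))

logs = {}
for name, metric in metrics.items():
    if isinstance(metric, torchmetrics.Metric):
        metric(y_hat, y)             # update the metric's internal state for this batch
        logs[name] = metric          # a LightningModule would pass this object to self.log
    else:
        logs[name] = metric(y_hat, y)
print(logs["accuracy"].compute())
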
Example #7
    def _fit(self, modules: nn.ModuleDict, train_dl: DeviceDataLoader,
             valid_dl: DeviceDataLoader):
        r""" Fits \p modules' learners to the training and validation \p DataLoader objects """
        self._configure_fit_vars(modules)

        for mod_name, module in modules.items():
            lr = config.get_learner_val(mod_name,
                                        LearnerParams.Attribute.LEARNING_RATE)
            wd = config.get_learner_val(mod_name,
                                        LearnerParams.Attribute.WEIGHT_DECAY)
            is_lin_ff = config.DATASET.is_synthetic(
            ) and module.module.num_hidden_layers == 0
            if is_lin_ff:
                module.optim = LBFGS(module.parameters(), lr=lr)
            else:
                module.optim = AdamW(module.parameters(),
                                     lr=lr,
                                     weight_decay=wd,
                                     amsgrad=True)
            logging.debug(
                f"{mod_name} Optimizer: {module.optim.__class__.__name__}")

        for ep in range(1, config.NUM_EPOCH + 1):
            # noinspection PyUnresolvedReferences
            for _, module in modules.items():
                module.epoch_start()

            for batch in train_dl:
                for _, module in modules.items():
                    module.process_batch(batch)

            for _, module in modules.items():
                module.calc_valid_loss(valid_dl)
            self._log_epoch(ep, modules)
        self._restore_best_model(modules)
        self.eval()
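
_fit gives every entry of the ModuleDict its own optimizer. A trimmed sketch of that per-module optimizer setup; the config.get_learner_val lookups are replaced by hard-coded hyperparameters and the modules are toy heads:

import torch
from torch import nn
from torch.optim import AdamW, LBFGS

modules = nn.ModuleDict({
    "linear_head": nn.Linear(4, 1),  # no hidden layers -> LBFGS below
    "mlp_head": nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1)),
})

optims = {}
for name, module in modules.items():
    # Mirror the snippet's split: LBFGS for the purely linear model, AdamW otherwise.
    if isinstance(module, nn.Linear):
        optims[name] = LBFGS(module.parameters(), lr=1.0)
    else:
        optims[name] = AdamW(module.parameters(), lr=1e-3, weight_decay=1e-4, amsgrad=True)
    print(name, optims[name].__class__.__name__)
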
Example #8
 def _configure_fit_vars(self, modules: nn.ModuleDict):
     r""" Set initial values/construct all variables used in a fit method """
     # Fields that apply regardless of loss method
     tb_dir = apu.utils.BASE_DIR / "tb"
     TrainingLogger.create_tensorboard(tb_dir)
     names, sizes = [], []
     for _, module in modules.items():
         _name, _size = module.logger_field_info()
         names.extend(_name)
         sizes.extend(_size)
     # Always log the time in number of seconds
     names.append("Time")
     sizes.append(10)
     self._logger = TrainingLogger(names,
                                   sizes,
                                   logger_name=apu.utils.LOGGER_NAME,
                                   tb_grp_name=self._name)
Example #9
class SentenceEmbeddings(Module):
    @dataclass
    class Options(OptionsBase):
        dim_word: "word embedding dim" = 100
        dim_postag: "postag embedding dim. 0 for not using postag" = 100
        dim_char_input: "character embedding input dim" = 100
        dim_char: "character embedding dim. 0 for not using character" = 100
        word_dropout: "word embedding dropout" = 0.4
        postag_dropout: "postag embedding dropout" = 0.2
        character_embedding: CharacterEmbedding.Options = field(
            default_factory=CharacterEmbedding.Options)
        input_layer_norm: "Use layer norm on input embeddings" = True
        mode: str = argfield("concat", choices=["add", "concat"])
        replace_unk_with_chars: bool = False

    def __init__(self,
                 hparams: "SentenceEmbeddings.Options",
                 statistics,
                 plugins=None):

        super().__init__()
        self.hparams = hparams
        self.mode = hparams.mode
        self.plugins = ModuleDict(plugins) if plugins is not None else {}

        # embedding
        input_dims = {}
        if hparams.dim_word != 0:
            self.word_embeddings = Embedding(len(statistics.words),
                                             hparams.dim_word,
                                             padding_idx=0)
            self.word_dropout = FeatureDropout2(hparams.word_dropout)
            input_dims["word"] = hparams.dim_word

        if hparams.dim_postag != 0:
            self.pos_embeddings = Embedding(len(statistics.postags),
                                            hparams.dim_postag,
                                            padding_idx=0)
            self.pos_dropout = FeatureDropout2(hparams.postag_dropout)
            input_dims["postag"] = hparams.dim_postag

        if hparams.dim_char > 0:
            self.bilm = None
            self.character_lookup = Embedding(len(statistics.characters),
                                              hparams.dim_char_input)
            self.char_embeded = CharacterEmbedding.get(
                hparams.character_embedding,
                dim_char_input=hparams.dim_char_input,
                input_size=hparams.dim_char)
            if not hparams.replace_unk_with_chars:
                input_dims["char"] = hparams.dim_char
            else:
                assert hparams.dim_word == hparams.dim_char
        else:
            self.character_lookup = None

        for name, plugin in self.plugins.items():
            input_dims[name] = plugin.output_dim

        if hparams.mode == "concat":
            self.output_dim = sum(input_dims.values())
        else:
            assert hparams.mode == "add"
            uniq_input_dims = list(set(input_dims.values()))
            if len(uniq_input_dims) != 1:
                raise ValueError(f"Different input dims: {input_dims}")
            print(input_dims)
            self.output_dim = uniq_input_dims[0]

        self.input_layer_norm = LayerNorm(self.output_dim, eps=1e-6) \
            if hparams.input_layer_norm else None

    def reset_parameters(self):
        torch.nn.init.xavier_normal_(self.word_embeddings.weight.data)
        if self.hparams.dim_postag != 0:
            torch.nn.init.xavier_normal_(self.pos_embeddings.weight.data)
        if self.character_lookup is not None:
            torch.nn.init.xavier_normal_(self.character_lookup.weight.data)

    def forward(self, inputs, unk_idx=1):
        all_features = []

        if self.character_lookup is not None:
            # use character embedding instead
            # batch_size, bucket_size, word_length, embedding_dims
            char_embeded_4d = self.character_lookup(inputs.chars)
            word_embeded_by_char = self.char_embeded(inputs.word_lengths,
                                                     char_embeded_4d)
            if not self.hparams.replace_unk_with_chars:
                all_features.append(word_embeded_by_char)

        if self.hparams.dim_word != 0:
            word_embedding = self.word_dropout(
                self.word_embeddings(inputs.words))
            if self.hparams.dim_char and self.hparams.replace_unk_with_chars:
                unk = inputs.words.eq(unk_idx)
                # noinspection PyUnboundLocalVariable
                unk_word_embeded_by_char = word_embeded_by_char[unk]
                word_embedding[unk] = unk_word_embeded_by_char
            all_features.append(word_embedding)

        if self.hparams.dim_postag != 0:
            all_features.append(
                self.pos_dropout(self.pos_embeddings(inputs.postags)))

        for plugin in self.plugins.values():
            plugin_output = plugin(inputs)
            # FIXME: remove these two ugly tweak
            if plugin_output.shape[1] == inputs.words.shape[1] + 2:
                plugin_output = plugin_output[:, 1:-1]
            # pad external embedding to dim_word
            # if self.mode == "add" and plugin_output.shape[-1] < self.hparams.dim_word:
            #     plugin_output = torch.cat(
            #         [plugin_output,
            #          plugin_output.new_zeros(
            #              (*inputs.words.shape, self.hparams.dim_word - plugin_output.shape[-1]))], -1)
            all_features.append(plugin_output)

        if self.mode == "concat":
            total_input_embeded = torch.cat(all_features, -1)
        else:
            total_input_embeded = sum(all_features)

        if self.input_layer_norm is not None:
            total_input_embeded = self.input_layer_norm(total_input_embeded)

        return total_input_embeded
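
The "concat" and "add" modes above differ only in how the per-source embeddings are combined; "add" additionally requires every source (including plugin outputs) to share one dimension, which is what the uniq_input_dims check enforces. A tiny illustration with dummy tensors:

import torch

word = torch.randn(2, 7, 100)    # [batch, seq, dim_word]
postag = torch.randn(2, 7, 100)  # [batch, seq, dim_postag]

concat = torch.cat([word, postag], dim=-1)  # output_dim = 100 + 100
added = word + postag                       # output_dim = 100, dims must match
print(concat.shape, added.shape)
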
Example #10
def process_state_temporal_kv(state: Dict, observation_net: nn.ModuleDict):
    x: List[torch.Tensor] = []
    for key, net in observation_net.items():
        x.append(process_state_temporal(state[key], net))
    x = torch.cat(x, dim=-1)
    return x
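
process_state_temporal_kv relies on the observation dict and the ModuleDict sharing keys: each sub-network consumes its own slice of the state, and the results are concatenated. A self-contained sketch with toy per-key encoders; process_state_temporal is replaced by a direct module call:

import torch
from torch import nn

observation_net = nn.ModuleDict({
    "position": nn.Linear(3, 8),
    "velocity": nn.Linear(3, 8),
})

state = {
    "position": torch.randn(5, 3),
    "velocity": torch.randn(5, 3),
}

# Matching keys drive the pairing between state entries and encoders.
features = [net(state[key]) for key, net in observation_net.items()]
x = torch.cat(features, dim=-1)
print(x.shape)  # torch.Size([5, 16])
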
Example #11
class QEOutputs(MetaModule):
    class Config(BaseConfig):
        word_level: WordLevelConfig = WordLevelConfig()
        sentence_level: SentenceLevelConfig = SentenceLevelConfig()
        sentence_loss_weight: float = 1.0
        'Multiplier for sentence_level loss weight.'

        dropout: float = 0.0
        last_activation: bool = False
        n_layers_output: int = 3

    def __init__(self, inputs_dims, vocabs: Dict[str, Vocabulary],
                 config: Config):
        super().__init__(config=config)

        self.inputs_dims = inputs_dims
        self.config = config
        self.vocabs = OrderedDict()
        self._metrics = None

        self.word_outputs = ModuleDict()

        tags_config = [
            (self.config.word_level.target, const.TARGET_TAGS),
            (self.config.word_level.source, const.SOURCE_TAGS),
            (self.config.word_level.gaps, const.GAP_TAGS),
        ]
        tags_sides = [tag for predict_tag, tag in tags_config if predict_tag]
        for tag_side in tags_sides:
            if tag_side not in vocabs:
                raise KeyError(
                    f'Asked to output {tag_side} but there is no vocabulary for it.'
                )
        if const.TARGET_TAGS in vocabs and self.config.word_level.target:
            class_weights = make_classes_loss_weights(
                vocab=vocabs[const.TARGET_TAGS],
                label_weights=self.config.word_level.class_weights[
                    const.TARGET_TAGS],
            )
            self.word_outputs[const.TARGET_TAGS] = WordLevelOutput(
                input_size=self.inputs_dims[const.TARGET],
                output_size=vocabs[const.TARGET_TAGS].net_length(),
                pad_idx=vocabs[const.TARGET_TAGS].pad_id,
                class_weights=class_weights,
                remove_first=vocabs[const.TARGET].bos_id,
                remove_last=vocabs[const.TARGET].eos_id,
            )
            self.vocabs[const.TARGET_TAGS] = vocabs[const.TARGET_TAGS]
        if const.GAP_TAGS in vocabs and self.config.word_level.gaps:
            class_weights = make_classes_loss_weights(
                vocab=vocabs[const.GAP_TAGS],
                label_weights=self.config.word_level.class_weights[
                    const.GAP_TAGS],
            )
            self.word_outputs[const.GAP_TAGS] = GapTagsOutput(
                input_size=self.inputs_dims[const.TARGET],
                output_size=vocabs[const.GAP_TAGS].net_length(),
                pad_idx=vocabs[const.GAP_TAGS].pad_id,
                class_weights=class_weights,
                remove_first=vocabs[const.TARGET].bos_id,
                remove_last=vocabs[const.TARGET].eos_id,
            )
            self.vocabs[const.GAP_TAGS] = vocabs[const.GAP_TAGS]
        if const.SOURCE_TAGS in vocabs and self.config.word_level.source:
            class_weights = make_classes_loss_weights(
                vocab=vocabs[const.SOURCE_TAGS],
                label_weights=self.config.word_level.class_weights[
                    const.SOURCE_TAGS],
            )
            self.word_outputs[const.SOURCE_TAGS] = WordLevelOutput(
                input_size=self.inputs_dims[const.SOURCE],
                output_size=vocabs[const.SOURCE_TAGS].net_length(),
                pad_idx=vocabs[const.SOURCE_TAGS].pad_id,
                class_weights=class_weights,
                remove_first=vocabs[const.SOURCE].bos_id,
                remove_last=vocabs[const.SOURCE].eos_id,
            )
            self.vocabs[const.SOURCE_TAGS] = vocabs[const.SOURCE_TAGS]

        # Sentence level
        self.sentence_outputs = ModuleDict()

        if self.config.sentence_level.hter:
            if False:  # FIXME: add flag for regressing over average of word predictions
                self.sentence_outputs[
                    const.SENTENCE_SCORES] = SentenceFromLogits()
            else:
                if const.TARGET_SENTENCE in self.inputs_dims:
                    input_size = self.inputs_dims[const.TARGET_SENTENCE]
                else:
                    input_size = self.inputs_dims[const.TARGET]
                if self.config.sentence_level.use_distribution:
                    sentence_scores = SentenceScoreDistribution(
                        input_size=input_size)
                else:
                    sentence_scores = SentenceScoreRegression(
                        input_size=input_size,
                        num_layers=self.config.n_layers_output,
                        final_activation=self.config.last_activation,
                    )
                self.sentence_outputs[const.SENTENCE_SCORES] = sentence_scores
        # Binary sentence level
        if self.config.sentence_level.binary:
            if const.TARGET_SENTENCE in self.inputs_dims:
                input_size = self.inputs_dims[const.TARGET_SENTENCE]
            else:
                input_size = self.inputs_dims[const.TARGET]
            self.sentence_outputs[const.BINARY] = BinarySentenceScore(
                input_size=input_size)

    def forward(self, features: Dict[str, Tensor],
                batch_inputs: MultiFieldBatch) -> Dict[str, Tensor]:
        outputs = OrderedDict()

        if self.config.word_level.target:
            if const.TARGET_TAGS in self.word_outputs and const.TARGET in features:
                outputs[const.TARGET_TAGS] = self.word_outputs[
                    const.TARGET_TAGS](features[const.TARGET], batch_inputs)
            elif const.TARGET_TAGS not in self.word_outputs:
                logger.warning(
                    f'Asked to output {const.TARGET_TAGS} but model has no layers for '
                    f'it; turning it off now.')
                self.config.word_level.target = False
            else:
                logger.warning(
                    f'Asked to output {const.TARGET_TAGS} but no features for '
                    f'{const.TARGET} were provided')
        if self.config.word_level.gaps:
            if const.GAP_TAGS in self.word_outputs and const.TARGET in features:
                outputs[const.GAP_TAGS] = self.word_outputs[const.GAP_TAGS](
                    features[const.TARGET], batch_inputs)
            elif const.GAP_TAGS not in self.word_outputs:
                logger.warning(
                    f'Asked to output {const.GAP_TAGS} but model has no layers for it; '
                    f'turning it off now.')
                self.config.word_level.gaps = False
            else:
                logger.warning(
                    f'Asked to output {const.GAP_TAGS} but no features for '
                    f'{const.TARGET} were provided')
        if self.config.word_level.source:
            if const.SOURCE_TAGS in self.word_outputs and const.SOURCE in features:
                outputs[const.SOURCE_TAGS] = self.word_outputs[
                    const.SOURCE_TAGS](features[const.SOURCE], batch_inputs)
            elif const.SOURCE_TAGS not in self.word_outputs:
                logger.warning(
                    f'Asked to output {const.SOURCE_TAGS} but model has no layers for '
                    f'it; turning it off now.')
                self.config.word_level.source = False
            else:
                logger.warning(
                    f'Asked to output {const.SOURCE_TAGS} but no features for '
                    f'{const.SOURCE} were provided.')

        # Sentence score and binary score prediction
        if self.config.sentence_level.hter:
            if False:  # FIXME: add flag for predicting from logits average
                _, lengths, *_ = batch_inputs[const.TARGET]
                sentence_score = self.sentence_pred(outputs[const.TARGET_TAGS],
                                                    lengths)
                outputs[const.SENTENCE_SCORES] = sentence_score
            else:
                if const.SENTENCE_SCORES in self.sentence_outputs and (
                        const.TARGET_SENTENCE in features
                        or const.TARGET in features):
                    sentence_features = features.get(const.TARGET_SENTENCE)
                    if sentence_features is None:
                        sentence_features = features[const.TARGET][:, 0]
                    sentence_scores = self.sentence_outputs[
                        const.SENTENCE_SCORES](sentence_features, batch_inputs)
                    outputs[const.SENTENCE_SCORES] = sentence_scores
                elif const.SENTENCE_SCORES not in self.sentence_outputs:
                    logger.warning(
                        f'Asked to output {const.SENTENCE_SCORES} but model has no '
                        f'layers for it; turning it off now.')
                    self.config.sentence_level.hter = False
                else:
                    logger.warning(
                        f'Asked to output {const.SENTENCE_SCORES} but no features for '
                        f'{const.TARGET_SENTENCE} or for {const.TARGET} were provided.'
                    )

        if self.config.sentence_level.binary:
            if const.BINARY in self.sentence_outputs and (
                    const.TARGET_SENTENCE in features
                    or const.TARGET in features):
                sentence_features = features.get(const.TARGET_SENTENCE)
                if sentence_features is None:
                    sentence_features = features[const.TARGET][:, 0]
                outputs[const.BINARY] = self.sentence_outputs[const.BINARY](
                    sentence_features, batch_inputs)
            elif const.BINARY not in self.sentence_outputs:
                logger.warning(
                    f'Asked to output {const.BINARY} but model has no layers for it; '
                    f'turning it off now.')
                self.config.sentence_level.binary = False
            else:
                logger.warning(
                    f'Asked to output {const.BINARY} but no features for '
                    f'{const.TARGET_SENTENCE} or for {const.TARGET} were provided.'
                )

        return outputs

    def loss(self, model_out: Dict[str, Tensor],
             batch: MultiFieldBatch) -> Dict[str, Tensor]:
        loss_dict = self.word_losses(model_out, batch)

        loss_sent_dict = self.sentence_losses(model_out, batch)
        for name, loss_value in loss_sent_dict.items():
            loss_dict[name] = loss_value * self.config.sentence_loss_weight

        loss_dict[const.LOSS] = sum(loss.sum()
                                    for _, loss in loss_dict.items())
        return loss_dict

    def word_losses(self, model_out: Dict[str, Tensor],
                    batch_outputs: MultiFieldBatch):
        """Compute sequence tagging loss."""
        word_loss = OrderedDict()
        for tag, layer in self.word_outputs.items():
            if tag in model_out:
                if tag not in batch_outputs:
                    raise ValueError(
                        f'Model predicted {tag} but true target is not in the batch.'
                    )
                logits = model_out[tag]
                logits = logits.transpose(1, 2)
                y = batch_outputs[tag].tensor
                try:
                    word_loss[tag] = layer.loss_fn(logits, y)
                except ValueError as e:
                    raise ValueError(f'with {tag}: {e}')
        return word_loss

    def sentence_losses(self, model_out: Dict[str, Tensor],
                        batch_outputs: MultiFieldBatch):
        """Compute sentence score loss."""
        sent_loss = OrderedDict()
        for label, layer in self.sentence_outputs.items():
            if label in model_out:
                if label not in batch_outputs:
                    raise ValueError(
                        f'Model predicted {label} but true target is not in the batch.'
                    )
                prediction = model_out[label]
                y = batch_outputs[label]
                sent_loss[label] = layer.loss_fn(prediction, y)
        return sent_loss

    def metrics_step(
        self,
        batch: MultiFieldBatch,
        model_out: Dict[str, Tensor],
        loss_dict: Dict[str, Tensor],
    ) -> Dict[str, Tensor]:
        metrics_dict = {}
        for metric in self.metrics:
            metrics_dict[metric.name] = metric.step(model_out=model_out,
                                                    batch=batch,
                                                    losses=loss_dict)
        return metrics_dict

    def metrics_end(self, steps: List[Dict[str, Tensor]], prefix=''):
        metrics_steps = defaultdict(list)
        for step in steps:
            for name, output in step.items():
                metrics_steps[name].append(output)
        metrics_steps = dict(metrics_steps)

        summary = {}
        for metric in self.metrics:
            summary.update(
                metric.compute(metrics_steps[metric.name], prefix=prefix))
        return summary

    @property
    def metrics(self) -> List[Metric]:
        if self._metrics is None:
            metrics = []
            if self.config.word_level.target and self.config.word_level.gaps:
                metrics += tag_metrics(
                    const.TARGET_TAGS,
                    const.GAP_TAGS,
                    prefix='WMT19_',
                    labels=self.labels(const.TARGET_TAGS),
                )
            if self.config.word_level.target:
                metrics += tag_metrics(const.TARGET_TAGS,
                                       labels=self.labels(const.TARGET_TAGS))
            if self.config.word_level.gaps:
                metrics += tag_metrics(const.GAP_TAGS,
                                       labels=self.labels(const.GAP_TAGS))
            if self.config.word_level.source:
                metrics += tag_metrics(const.SOURCE_TAGS,
                                       labels=self.labels(const.SOURCE_TAGS))

            if self.config.sentence_level.hter:
                metrics.append(PearsonMetric(const.SENTENCE_SCORES, prefix=''))
                metrics.append(SpearmanMetric(const.SENTENCE_SCORES,
                                              prefix=''))
                metrics.append(RMSEMetric(const.SENTENCE_SCORES, prefix=''))
            if self.config.sentence_level.binary:
                metrics.append(
                    CorrectMetric(
                        const.BINARY,
                        prefix='binary_',
                        labels=self.labels(const.TARGET_TAGS),
                    ))
            # metrics.append(LogMetric(log_targets=[(const.LOSS, const.LOSS)]))
            self._metrics = metrics
        return self._metrics

    def labels(self, field: str) -> List[str]:
        return [
            label for label in self.vocabs[field].itos
            if label not in self.vocabs[field].specials
        ]

    def decode_outputs(
        self,
        model_out: Dict[str, Tensor],
        batch_inputs: MultiFieldBatch,
        positive_class_label: str = const.BAD,
    ) -> Dict[str, List]:
        outputs = self.decode_word_outputs(model_out, batch_inputs,
                                           positive_class_label)
        outputs.update(self.decode_sentence_outputs(model_out))
        return outputs

    def decode_word_outputs(
        self,
        model_out: Dict[str, Tensor],
        batch_inputs: MultiFieldBatch,
        positive_class_label: str = const.BAD,
    ) -> Dict[str, List]:
        outputs = {}

        tags_config = [
            (const.TARGET_TAGS, 'target'),
            (const.SOURCE_TAGS, 'source'),
            (const.GAP_TAGS, 'target'),
        ]
        for key, input_side in tags_config:
            if key in model_out:
                # Models are assumed to return logits, not probabilities
                logits = model_out[key]
                probs = torch.softmax(logits, dim=-1)

                # Get string labels
                predicted_labels = probs.argmax(dim=-1, keepdim=False).tolist()

                # Get BAD probability
                class_index = torch.tensor(
                    [self.vocabs[key].token_to_id(positive_class_label)],
                    device=probs.device,
                    dtype=torch.long,
                )
                class_probs = probs.index_select(-1, class_index)
                class_probs = class_probs.squeeze(-1).tolist()

                # Convert into the right number of tokens per sample
                # Get lengths so we can unmask predictions and get rid of pads
                lengths = batch_inputs[input_side].number_of_tokens.clone()
                if key == const.GAP_TAGS:
                    lengths += 1  # Append one extra token

                for i, sample in enumerate(class_probs):
                    class_probs[i] = sample[:lengths[i]]
                for i, sample in enumerate(predicted_labels):
                    predicted_labels[i] = [
                        self.vocabs[key].id_to_token(x)
                        for x in sample[:lengths[i]]
                    ]

                outputs[key] = class_probs
                outputs[f'{key}_labels'] = predicted_labels

        return outputs

    @staticmethod
    def decode_sentence_outputs(
            model_out: Dict[str, Tensor]) -> Dict[str, List]:
        outputs = {}

        if const.SENTENCE_SCORES in model_out:
            sentence_scores = model_out[const.SENTENCE_SCORES]
            if isinstance(sentence_scores, tuple):
                # By convention, the first element is the scores; the rest is extra
                #  data specific to that layer (e.g., here the rest are mean and std).
                extras = torch.stack(sentence_scores[1]).T.tolist()
                outputs[f'{const.SENTENCE_SCORES}_extras'] = extras
                sentence_scores = sentence_scores[0]
            outputs[const.SENTENCE_SCORES] = sentence_scores.tolist()
        if const.BINARY in model_out:
            logits = model_out[const.BINARY]
            probs = torch.softmax(logits, dim=-1)
            class_probs = probs[...,
                                0]  # FIXME: shouldn't this 0 be 1, for BAD?
            outputs[const.BINARY] = class_probs.tolist()

        return outputs
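
QEOutputs registers an output head only when both the config flag and the corresponding vocabulary are present, and then iterates whatever ended up in the ModuleDict for the forward pass and the losses. A reduced sketch of that conditional-head pattern; the config, heads, and shapes are simplified stand-ins rather than the real WordLevelOutput layers:

import torch
from torch import nn

config = {"target_tags": True, "source_tags": False, "sentence_scores": True}
hidden = 16

word_outputs = nn.ModuleDict()
if config["target_tags"]:
    word_outputs["target_tags"] = nn.Linear(hidden, 3)   # per-token tag logits
if config["source_tags"]:
    word_outputs["source_tags"] = nn.Linear(hidden, 3)

sentence_outputs = nn.ModuleDict()
if config["sentence_scores"]:
    sentence_outputs["sentence_scores"] = nn.Linear(hidden, 1)

features = torch.randn(2, 6, hidden)                      # [batch, seq, hidden]
outputs = {name: head(features) for name, head in word_outputs.items()}
outputs.update({name: head(features[:, 0]).squeeze(-1)    # sentence heads use a pooled vector
                for name, head in sentence_outputs.items()})
print({k: tuple(v.shape) for k, v in outputs.items()})
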
Example #12
def get_trunk_forward_outputs(
    feat: torch.Tensor,
    out_feat_keys: List[str],
    feature_blocks: nn.ModuleDict,
    feature_mapping: Dict[str, str] = None,
    use_checkpointing: bool = False,
    checkpointing_splits: int = 2,
) -> List[torch.Tensor]:
    """
    Args:
        feat: model input.
        out_feat_keys: a list/tuple with the feature names of the features that
            the function should return. By default the last feature of the network
            is returned.
        feature_blocks: ModuleDict containing feature blocks in the model
        feature_mapping: an optional correspondence table in between the requested
            feature names and the model's.

    Returns:
        out_feats: a list with the asked output features placed in the same order as in
        `out_feat_keys`.
    """

    # Sanitize inputs
    if feature_mapping is not None:
        out_feat_keys = [feature_mapping[f] for f in out_feat_keys]

    out_feat_keys, max_out_feat = parse_out_keys_arg(
        out_feat_keys, list(feature_blocks.keys())
    )

    # Forward pass over the trunk
    unique_out_feats = {}
    unique_out_feat_keys = list(set(out_feat_keys))

    # FIXME: Ideally this should only be done once at construction time
    if use_checkpointing:
        feature_blocks = checkpoint_trunk(
            feature_blocks, unique_out_feat_keys, checkpointing_splits
        )

        # If feat is the first input to the network, it doesn't have requires_grad,
        # which will make checkpoint's backward function not being called. So we need
        # to set it to true here.
        feat.requires_grad = True

    # Go through the blocks, and save the features as we go
    # NOTE: we are not doing several forward passes but instead just checking
    # whether the feature is requested to be returned.
    for i, (feature_name, feature_block) in enumerate(feature_blocks.items()):
        # The last chunk has to be non-volatile
        if use_checkpointing and i < len(feature_blocks) - 1:
            # Un-freeze the running stats in any BN layer
            for m in filter(lambda x: isinstance(x, _bn_cls), feature_block.modules()):
                m.track_running_stats = m.training

            feat = checkpoint(feature_block, feat)

            # Freeze the running stats in any BN layer
            # the checkpointing process will have to do another FW pass
            for m in filter(lambda x: isinstance(x, _bn_cls), feature_block.modules()):
                m.track_running_stats = False
        else:
            feat = feature_block(feat)

        # This feature is requested, store. If the same feature is requested several
        # times, we return the feature several times.
        if feature_name in unique_out_feat_keys:
            unique_out_feats[feature_name] = feat

        # Early exit if all the features have been collected
        if i == max_out_feat and not use_checkpointing:
            break

    # now return the features as requested by the user. If there are no duplicate keys,
    # return as is.
    if len(unique_out_feat_keys) == len(out_feat_keys):
        return list(unique_out_feats.values())

    output_feats = []
    for key_name in out_feat_keys:
        output_feats.append(unique_out_feats[key_name])
    return output_feats
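
Because ModuleDict preserves insertion order, the trunk can be run block by block while collecting only the requested intermediate features and exiting early. A minimal sketch of that traversal, with parse_out_keys_arg replaced by a direct index computation and the checkpointing branch omitted:

import torch
from torch import nn

feature_blocks = nn.ModuleDict({
    "conv1": nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()),
    "conv2": nn.Sequential(nn.Conv2d(8, 16, 3, padding=1), nn.ReLU()),
    "pool": nn.AdaptiveAvgPool2d(1),
})

out_feat_keys = ["conv2", "pool"]
max_out_feat = max(list(feature_blocks.keys()).index(k) for k in out_feat_keys)

feat = torch.randn(1, 3, 16, 16)
collected = {}
for i, (name, block) in enumerate(feature_blocks.items()):
    feat = block(feat)
    if name in out_feat_keys:
        collected[name] = feat
    if i == max_out_feat:      # early exit once everything requested is collected
        break

print([(k, tuple(v.shape)) for k, v in collected.items()])
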
Example #13
class LightningMaskRCNN(pl.LightningModule):
    """Lightning version of the torchvision Mask R-CNN architecture with Stochastic Gradient
        Descent, mAP validation and DropLROnPlateau learning rate scheduler.

    :param num_classes: Number of classes of the Mask R-CNN (including the background, so the
        minimum is 2).
    :param learning_rate: Learning rate of the SGD optimizer.
    :param drop_lr_on_plateau_patience: Patience of the `DropLROnPlateau` learning rate scheduler,
        until it drops the learning rate by a factor of 10.
    :param model_kwargs: Keyword arguments which are given to
        `torchvision.models.detection.maskrcnn_resnet50_fpn` during model creation.
    """
    def __init__(
        self,
        num_classes: int = 2,
        learning_rate: float = 0.005,
        drop_lr_on_plateau_patience: int = 10,
        model_kwargs: Optional[Dict] = None,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.drop_lr_on_plateau_patience = drop_lr_on_plateau_patience

        if model_kwargs is None:
            self.model_kwargs = {}
        else:
            self.model_kwargs = model_kwargs

        self.model = self.build_model()

        self.validation_metrics = ModuleDict({
            # "AP50": AveragePrecision(
            #     num_foreground_classes=num_classes - 1,
            #     iou_thresholds=[0.5],
            #     iou_type="mask",
            #     ap_calculation_type="COCO",
            # ),
            # "AP75": AveragePrecision(
            #     num_foreground_classes=num_classes - 1,
            #     iou_thresholds=[0.75],
            #     iou_type="mask",
            #     ap_calculation_type="COCO",
            # ),
            "mAP":
            AveragePrecision(
                num_foreground_classes=num_classes - 1,
                iou_thresholds=np.arange(0.5, 1, 0.05),
                iou_type="mask",
                ap_calculation_type="COCO",
            ),
            "confusion_matrix":
            ConfusionMatrix(
                num_classes,
                iou_type="mask",
                iou_threshold=0.5,
                score_threshold=0.5,
            ),
        })

        self.main_validation_metric_name = "mAP"

    def build_model(self) -> torchvision.models.detection.MaskRCNN:
        """Builds the Mask R-CNN model. Based on
            `torchvision.models.detection.maskrcnn_resnet50_fpn`.

        :return: Mask R-CNN model
        """
        # Load an instance segmentation model pre-trained on COCO.
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(
            pretrained=True,
            **self.model_kwargs,
        )

        # Replace the pretrained box and the mask heads.
        in_features_box = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
            in_features_box, self.num_classes)

        in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
        hidden_layer = 256
        model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
            in_features_mask, hidden_layer, self.num_classes)

        return model

    def forward(
        self,
        images: Tuple[Image, ...],
        targets: Optional[Tuple[Annotation, ...]] = None
    ) -> Union[List[Annotation], PartialLosses]:
        """Forward pass through the Mask R-CNN.

        :param images: Input images.
        :param targets: Ground truth annotations.
        :return:
            During training: partial losses of the Mask R-CNN heads
            During validation and test: predictions
        """
        return self.model(images, targets)

    def training_step(self, batch: Batch, batch_idx: int) -> Loss:
        """Takes a batch, inputs it into the model and retrieves and logs losses of the prediction heads.
            Calculate the sum of losses and return it.

        :param batch: Batch of images and ground truths.
        :param batch_idx: Index of the current batch.
        :return: Sum of the losses of the prediction heads.
        """
        images, targets = batch
        partial_losses = self(images, targets)

        # TODO: Implement loss weights.
        loss = sum(partial_loss for partial_loss in partial_losses.values())

        self.log("training/loss", loss, on_epoch=True, on_step=False)

        for key, value in partial_losses.items():
            self.log("training/" + key, value, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch: Batch,
                        batch_idx: int) -> ValidationOutput:
        """Take a batch from the validation data set and input its images into the model to
            retrieve the associated predictions. Return predictions and ground truths for later use.

        :param batch: Batch of images and ground truths.
        :param batch_idx: Index of the current batch.
        :return: Predictions and ground truths.
        """
        images, targets = batch
        predictions = self(images)

        return {"predictions": predictions, "targets": targets}

    def validation_step_end(self, output: ValidationOutput) -> None:
        """Calculate and log the validation_metrics.

        :param output: Outputs of the validation step.
        """
        for metric_name, metric in self.validation_metrics.items():
            metric(output["predictions"], output["targets"])

            tag = f"validation/{metric_name}"

            if isinstance(metric, ConfusionMatrix):
                # TODO: Replace when https://github.com/PyTorchLightning/pytorch-lightning/pull/6227
                #  has been merged.
                self.logger.experiment.add_figure(
                    tag=tag,
                    figure=metric.plot(
                        self.trainer.val_dataloaders[0].dataset.class_names),
                    global_step=self.global_step,
                    close=True,
                )

            else:
                self.log(tag, metric)

            if metric_name == self.main_validation_metric_name:
                self.log("hp_metric", metric)

    def test_step(self, batch: Batch, batch_idx: int) -> TestOutput:
        """Take a batch from the test data set and input its images into the model to
            retrieve the associated predictions. Return predictions and ground truths for later
            use.

        :param batch: Batch of images and ground truths.
        :param batch_idx: Index of the current batch.
        :return: Predictions, ground truths and input images.
        """
        images, _ = batch
        predictions = self(images)

        return {"predictions": predictions}

    def configure_optimizers(self, ) -> OptimizerConfiguration:
        """Configure the SGD optimizer and the ReduceLROnPlateau learning rate scheduler.

        :return: Dictionary with the optimizer, the learning rate scheduler and the name of the
            metric monitored by the learning rate scheduler.
        """

        # TODO: Test adding weight_decay= 0.0005.
        optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)

        return {
            "optimizer":
            optimizer,
            "lr_scheduler":
            ReduceLROnPlateau(optimizer,
                              mode="max",
                              patience=self.drop_lr_on_plateau_patience),
            "monitor":
            "validation/mAP",
        }
Example #14
 def _restore_best_model(modules: nn.ModuleDict):
     r""" Restores the best trained model from disk """
     for _, module in modules.items():
         module.restore_best()
Example #15
class ShallowShapleyNetwork(ShapleyNetwork, ABC):
    r"""
    The architecture of Shallow Shapley Networks.

    How it works:
        For each of the Shapley modules, we have a corresponding index list,
        which indicates which inputs to take into account in that Shapley
        module. From there we could just iterate through the dictionary
    """

    def __init__(
            self,
            module_dict: ModuleDict or dict,
            reference_values: torch.Tensor = None,
            dimensions: ModuleDimensions = None
    ):
        r"""
        Initialization

        Args:
            module_dict (): Only support the following form of ``modules``:
                ModuleDict[str:ShapleyModule]: this contains already
                initialized torch
                    modules, each of which will be wrapped in a
                    :class:`ShapleyModule`.
            reference_values (): the reference values for each of the variables
                Defaults to zeros
        """
        # input validation: correct the type of the module_dict argument
        module_dict = module_dict if isinstance(module_dict, ModuleDict) \
            else ModuleDict(module_dict)
        module_dim = list(module_dict.values())[0].dimensions
        # input validation: correct the input for dimensions
        if dimensions is None:
            dimensions = ModuleDimensions(
                features=self._get_input_features(module_dict),
                in_channel=module_dim.in_channel,
                out_channel=module_dim.out_channel
            )

        super(ShallowShapleyNetwork, self).__init__(
            dimensions=dimensions,
            reference_values=reference_values,
        )
        if isinstance(module_dict, ModuleDict):
            self.module_dict = module_dict
        else:
            self.module_dict = ModuleDict(module_dict)
        # align the reference values of the entire model with those of the 
        # Shapley modules
        self.align_reference_values()

        assert len(self.module_dict) != 0
        assert isinstance(self.module_dict, ModuleDict)
        assert isinstance(list(module_dict.values())[0], ShapleyModule)

    def _get_input_features(self, module_dict: ModuleDict) -> int:
        r"""
        Get the number of input features by the ModuleDict input instance

        Args:
            module_dict (): the dictionary which contains the keys and 
                their corresponding modules.

        Returns:
            the number of features used.

        """
        indices = [self._get_keys(key) for key in module_dict.keys()]
        return len(set(flatten(indices)))

    def explained_forward(self, inputs: torch.Tensor) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        The forward method that performs explanation
        Args:
            inputs (): of shape (*, features, input_channels)

        Returns:
            the exact Shapley values for each input features

        """
        # setup place holders
        out = [0] * self.dimensions.features
        biases = 0
        for _, (key, module) in enumerate(self.module_dict.items()):
            indices = self._get_keys(key)
            bias, output = module.explained_forward(
                named_tensor_vector_indexing_single_dim(
                    inputs, dim=NAME_FEATURES, indices=indices))
            output = output.rename(None)

            biases = bias + biases
            for i, index in enumerate(indices):
                out[index] = out[index] + output[..., i, :]

        zeros = torch.zeros_like(output[..., 0, :])
        out = [o if isinstance(o, torch.Tensor) else zeros for o in out]
        shapley_values = torch.stack(
            out, -2
        ).refine_names(
            NAME_BATCH_SIZE, ..., NAME_FEATURES, NAME_META_CHANNELS)

        return shapley_values, biases

    def unexplained_forward(self, inputs: torch.Tensor):
        r"""
        The forward operation that is not explained, this is the usual
        forward method for the underlying function. In other words, this 
        works exactly as the underlying function does.

        Args:
            inputs: the input samples.
        """
        output = 0
        for _, (key, module) in enumerate(self.module_dict.items()):
            indices = self._get_keys(key)
            out = module.unexplained_forward(
                named_tensor_vector_indexing_single_dim(
                    inputs, dim=NAME_FEATURES, indices=indices))
            output = output + out

        return output

    @staticmethod
    def _get_keys(key_string: str) -> List[int]:
        r"""
        Get feature index from the key string
        :param key_string: the key string from the dictionary
        :return: the first feature index and the second feature index
        """
        return [int(key) for key in re.findall(r"\d+", key_string)]

    def align_reference_values(self, reference_values: torch.Tensor = None):
        """
        put the overall reference values to the first stage-reference values
        Returns:

        """
        if reference_values is None:
            reference_values = self.reference_values
        for key, module in self.module_dict.items():
            module.reference_values = named_tensor_vector_indexing_single_dim(
                reference_values, NAME_FEATURES,
                self._get_keys(key))
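
ShallowShapleyNetwork encodes the input feature indices each module is responsible for in the ModuleDict key itself and recovers them with a regex. A stripped-down sketch of that key-to-indices routing; the modules here are ordinary Linear layers rather than ShapleyModules, and plain slicing replaces the named-tensor indexing helper:

import re
import torch
from torch import nn

# Each key names the input feature indices that its module consumes.
module_dict = nn.ModuleDict({
    "0_1": nn.Linear(2, 1),
    "2_3_4": nn.Linear(3, 1),
})


def get_keys(key_string):
    return [int(k) for k in re.findall(r"\d+", key_string)]


inputs = torch.randn(4, 5)            # [batch, features]
output = 0
for key, module in module_dict.items():
    indices = get_keys(key)
    output = output + module(inputs[:, indices])
print(output.shape)                   # torch.Size([4, 1])
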
Example #16
class AuTopologyReadOut(nn.Module):

    """
    Class for reading out results from a convolution using AuTopology.
    Attributes:
        terms (dict): dictionary of the types of AuTopology potentials used
            for each kind of topology (e.g. Morse for harmonic, LJ for pairs,
            etc.)
        auto_modules (torch.nn.ModuleDict): module dictionary for all the topology
            nets associated with each energy state. E.g. of the form {"energy_0":
            {"bond": BondNet0, "angle": AngletNet0}, "energy_1": {"bond": BondNet1,
            "angle": AngletNet1} }.
    """

    def __init__(self, multitaskdict):

        """
        Args:
            multitaskdict (dict): dictionary of items used for setting up the networks.
        Returns:
            None
        """

        super(AuTopologyReadOut, self).__init__()

        trainable = multitaskdict["trainable_prior"]
        Fr = multitaskdict["Fr"]
        Lh = multitaskdict["Lh"]
        # bond_terms = multitaskdict.get("bond_terms", ["morse"])
        # angle_terms =  multitaskdict.get("angle_terms", ['harmonic'])  # harmonic and/or cubic and/or quartic
        # dihedral_terms = multitaskdict.get("dihedral_terms", ['OPLS'])  # OPLS and/or multiharmonic
        # improper_terms = multitaskdict.get("improper_terms", ['harmonic'])  # harmonic
        # pair_terms = multitaskdict.get("pair_terms", ['LJ'])  # coulomb and/or LJ and/or induced_dipole
        autopology_keys = multitaskdict["output_keys"]

        default_terms_dict = {
            "bond_terms": ["morse"],
            "angle_terms": ["harmonic"],
            "dihedral_terms": ["OPLS"],
            "improper_terms": ["harmonic"],
            "pair_terms": ["LJ", "coulombs"]
        }

        # self.terms = {
        #     'bond': bond_terms,
        #     'angle': angle_terms,
        #     'dihedral': dihedral_terms,
        #     'improper': improper_terms,
        #     'pair': pair_terms
        # }
        self.terms = {}

        # remove terms that are not included
        for top in ['bond', 'angle', 'dihedral', 'improper', 'pair']:
            if top + '_terms' in multitaskdict.keys():
                self.terms[top] = multitaskdict.get(top + '_terms', default_terms_dict[top + '_terms'])


        topologynet = {key: {} for key in autopology_keys}
        for key in autopology_keys:
            for top in self.terms.keys():
                if top + '_terms' in multitaskdict:
                    topologynet[key][top] = TopologyNet[top](Fr, Lh, self.terms[top], trainable=trainable)


        # module dictionary of the form {"energy_0": {"bond": BondNet0, "angle": AngletNet0},
        # "energy_1": {"bond": BondNet1, "angle": AngletNet1} }
        self.auto_modules = ModuleDict({key: ModuleDict({top: topologynet[key][top] for top in
            self.terms.keys()}) for key in autopology_keys})

        # energy offset for each state
        self.offset = ModuleDict({key: ParameterPredictor(Fr, Lh, 1)
            for key in autopology_keys})


    def forward(self, r, batch, xyz, take_grad=True):

        output = dict()

        # loop through output keys (e.g. energy_0 and energy_1)
        for output_key, top_set in self.auto_modules.items():
            E = {key: 0.0 for key in list(self.terms.keys()) + ['total']}
            learned_params = {}
            # loop through associated topology nets (e.g. BondNet0 and AngleNet0 or
            # BondNet1 and AngleNet1)
            for top, top_net in top_set.items():
                E[top] = top_net(r, batch, xyz)
                learned_params[top] = top_net.learned_params
                E['total'] += E[top]

            N = batch["num_atoms"].cpu().numpy().tolist()
            offset = torch.split(self.offset[output_key](r), N)
            offset = (torch.stack([torch.sum(item) for item in offset])).reshape(-1, 1)

            output[output_key] = E["total"] + offset

            if take_grad:
                grad = compute_grad(inputs=xyz, output=E["total"])
                output[output_key + "_grad"] = grad

        return output
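
A compact, runnable sketch of the nested ModuleDict layout used above, with tiny nn.Linear layers standing in for the real topology nets (BondNet, AngleNet, and so on): one inner ModuleDict of topology nets per output key, and per-state totals accumulated much as in the forward pass above. Names and sizes here are illustrative only.

import torch
from torch import nn

class NestedReadOut(nn.Module):
    def __init__(self, output_keys, topologies, feat_dim):
        super().__init__()
        # {"energy_0": {"bond": ..., "angle": ...}, "energy_1": {...}}
        self.auto_modules = nn.ModuleDict({
            key: nn.ModuleDict({top: nn.Linear(feat_dim, 1) for top in topologies})
            for key in output_keys
        })

    def forward(self, r):
        output = {}
        for output_key, top_set in self.auto_modules.items():
            total = 0.0
            for top, net in top_set.items():
                total = total + net(r)  # sum the contribution of each topology
            output[output_key] = total
        return output

model = NestedReadOut(["energy_0", "energy_1"], ["bond", "angle"], feat_dim=8)
print({k: v.shape for k, v in model(torch.randn(5, 8)).items()})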
Example #17
def get_tunk_forward_interpolated_outputs(
    input_type: str,  # bgr or rgb or lab
    interpolate_out_feat_key_name: str,
    remove_padding_before_feat_key_name: str,
    feat: MultiDimensionalTensor,
    feature_blocks: nn.ModuleDict,
    feature_mapping: Dict[str, str] = None,
    use_checkpointing: bool = False,
    checkpointing_splits: int = 2,
) -> List[torch.Tensor]:
    """
    Args:
        input_type (str): whether the model input should be RGB or BGR or LAB
        interpolate_out_feat_key_name (str): what feature dimensions should be
            used to interpolate the mask
        remove_padding_before_feat_key_name (str): name of the feature block for which
            the input should have padding removed using the interpolated mask
        feat (MultiDimensionalTensor): model input
        feature_blocks (nn.ModuleDict): ModuleDict containing feature blocks in the model
        feature_mapping (Dict[str, str]): an optional correspondence table in between
            the requested feature names and the model's.

    Returns:
        out_feats: a list with the requested output features placed in the same order
            as in `out_feat_keys`.
    """
    if feature_mapping is not None:
        interpolate_out_feat_key_name = feature_mapping[interpolate_out_feat_key_name]

    model_input = transform_model_input_data_type(feat.tensor, input_type)
    out = get_trunk_forward_outputs(
        feat=model_input,
        out_feat_keys=[interpolate_out_feat_key_name],
        feature_blocks=feature_blocks,
        use_checkpointing=use_checkpointing,
        checkpointing_splits=checkpointing_splits,
    )
    # mask is of shape N x H x W and has 1.0 value for places with padding
    # we interpolate the mask spatially to N x out.shape[-2] x out.shape[-1].
    interp_mask = F.interpolate(feat.mask[None].float(), size=out[0].shape[-2:]).to(
        torch.bool
    )[0]

    # we want to iterate over the rest of the feature blocks now
    _, max_out_feat = parse_out_keys_arg(
        [interpolate_out_feat_key_name], list(feature_blocks.keys())
    )
    for i, (feature_name, feature_block) in enumerate(feature_blocks.items()):
        # We have already run the forward pass up to max_out_feat;
        # now forward through the rest of the blocks.
        if i >= (max_out_feat + 1):
            if remove_padding_before_feat_key_name and (
                feature_name == remove_padding_before_feat_key_name
            ):
                # negate the mask so that the padded locations have 0.0 and the non-padded
                # locations have 1.0. This is used to extract the h, w of the original tensors.
                interp_mask = (~interp_mask).chunk(len(feat.image_sizes))
                tensors = out[0].chunk(len(feat.image_sizes))
                res = []
                for idx, tensor in enumerate(tensors):
                    w = torch.sum(interp_mask[idx][0], dim=0)[0]
                    h = torch.sum(interp_mask[idx][0], dim=1)[0]
                    res.append(feature_block(tensor[:, :, :w, :h]))
                out[0] = torch.cat(res)
            else:
                out[0] = feature_block(out[0])
    return out
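
The key step above is resizing the boolean padding mask to the feature map's spatial size and then inverting it so that row/column sums recover each image's unpadded extent. The stand-alone sketch below illustrates just that step; the shapes and the single padded image are made up for illustration.

import torch
import torch.nn.functional as F

feat = torch.randn(2, 16, 14, 14)      # N x C x H' x W' feature map
mask = torch.zeros(2, 56, 56)          # N x H x W, 1.0 where the input is padding
mask[0, 40:, :] = 1.0                  # first image is padded below row 40

# interpolate the mask spatially to the feature map resolution, as above
interp_mask = F.interpolate(mask[None], size=feat.shape[-2:]).to(torch.bool)[0]
valid = ~interp_mask                   # True on non-padded locations

# column sums count valid rows, row sums count valid columns (in feature cells)
h = int(valid[0].sum(dim=0).max())     # unpadded height of image 0 (~40/56 of 14 rows)
w = int(valid[0].sum(dim=1).max())     # unpadded width of image 0 (all 14 columns)
print(h, w)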
Example #18
class TLMOutputs(MetaModule):
    class Config(BaseConfig):
        fine_tune: bool = False
        """Continue training an encoder on the post-edited text.
        Recommended if you have access to PE.
        Requires setting `system.data.train.input.pe`, `system.data.valid.input.pe`"""

        # pretrain: bool = False
        # """Train an encoder from scratch on parallel corpora.
        # Used to pretrain TLM models (like the Predictor).
        # """

    def __init__(
        self,
        inputs_dims: Dict[str, int],
        vocabs: Dict[str, Vocabulary],
        config: Config,
        pretraining: bool = False,
    ):
        super().__init__(config=config)

        self.inputs_dims = inputs_dims
        self.vocabs = OrderedDict()
        self.config = config
        self.pretraining = pretraining
        self._metrics = None

        self.masked_word_outputs = ModuleDict()

        if self.pretraining:
            if const.TARGET not in vocabs:
                raise ValueError(
                    f'Asked to pretrain the encoder (`pretrain`) but no '
                    f'vocabulary was provided for {const.TARGET}')
            if const.TARGET_LOGITS not in self.inputs_dims:
                raise ValueError(
                    'Asked to pretrain the encoder (`pretrain`) but no '
                    'target data was provided')
            self.masked_word_outputs[const.TARGET] = MaskedWordOutput(
                input_size=self.inputs_dims[const.TARGET_LOGITS],
                pad_idx=vocabs[const.TARGET].pad_id,
                start_idx=vocabs[const.TARGET].bos_id,
                stop_idx=vocabs[const.TARGET].eos_id,
            )
            self.vocabs[const.TARGET] = vocabs[const.TARGET]

        if self.config.fine_tune:
            # Target side; use PE for fine-tuning
            if const.PE not in vocabs:
                raise ValueError(
                    f'Asked to fine-tune the encoder (`fine_tune`) but no '
                    f'vocabulary was provided for {const.PE}')
            if const.PE_LOGITS not in self.inputs_dims:
                raise ValueError(
                    'Asked to fine-tune the encoder (`fine_tune`) but no '
                    'post-edit (PE) data was provided')
            self.masked_word_outputs[const.PE] = MaskedWordOutput(
                input_size=self.inputs_dims[const.PE_LOGITS],
                pad_idx=vocabs[const.PE].pad_id,
                start_idx=vocabs[const.PE].bos_id,
                stop_idx=vocabs[const.PE].eos_id,
            )
            self.vocabs[const.PE] = vocabs[const.PE]

    def forward(self, features, batch_inputs):
        outputs = OrderedDict()

        if const.PE_LOGITS in features and const.PE in self.masked_word_outputs:
            outputs[const.PE] = self.masked_word_outputs[const.PE](
                features[const.PE_LOGITS])
        if const.TARGET_LOGITS in features and const.TARGET in self.masked_word_outputs:
            outputs[const.TARGET] = self.masked_word_outputs[const.TARGET](
                features[const.TARGET_LOGITS])

        return outputs

    def loss(self, model_out, batch_outputs):
        loss_dict = OrderedDict()
        for output_side, layer in self.masked_word_outputs.items():
            if output_side in model_out:
                if output_side not in batch_outputs:
                    raise ValueError(
                        f'Model predicted {output_side} but true target is not in the '
                        f'batch.')
                target = batch_outputs[output_side].tensor
                # There are predictions for first/last element, so we want to
                #  mask them out if they are BOS and EOS tokens.
                target = replace_token(target, layer.start_idx, layer.pad_idx)
                target = replace_token(target, layer.stop_idx, layer.pad_idx)
                # predicted class must be in dim 1 for the cross-entropy loss
                logits = model_out[output_side]
                logits = logits.transpose(1, 2)

                loss_dict[output_side] = layer.loss_fn(logits, target)

        loss_dict[const.LOSS] = sum(loss.sum()
                                    for _, loss in loss_dict.items())
        return loss_dict

    def metrics_step(self, batch, model_out, loss_dict):
        metrics_dict = {}
        for metric in self.metrics:
            metrics_dict[metric.name] = metric.step(model_out=model_out,
                                                    batch=batch,
                                                    losses=loss_dict)
        return metrics_dict

    def metrics_end(self, steps, prefix=''):
        metrics_steps = defaultdict(list)
        for step in steps:
            for name, output in step.items():
                metrics_steps[name].append(output)
        metrics_steps = dict(metrics_steps)

        summary = {}
        for metric in self.metrics:
            summary.update(
                metric.compute(metrics_steps[metric.name], prefix=prefix))
        return summary

    @property
    def metrics(self):
        if self._metrics is None:
            metrics = []
            for output_side, layer in self.masked_word_outputs.items():
                metrics.append(PerplexityMetric(output_side))
                metrics.append(
                    ExpectedErrorMetric(output_side,
                                        labels=self.labels(output_side)))
                metrics.append(
                    CorrectMetric(output_side,
                                  labels=self.labels(output_side)))
            self._metrics = metrics
        return self._metrics

    def labels(self, field):
        return [
            label for label in self.vocabs[field].itos
            if label not in self.vocabs[field].specials
        ]
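
As a stand-alone illustration of the pattern above, here is a small sketch with hypothetical names: output heads are registered in a ModuleDict only for the sides that were requested, and the per-side losses are then combined into a single total. Plain nn.Linear heads and CrossEntropyLoss replace MaskedWordOutput and its masked loss.

import torch
from torch import nn
from collections import OrderedDict

class WordOutputs(nn.Module):
    def __init__(self, input_dims, vocab_size=100, use_target=True, use_pe=False):
        super().__init__()
        self.heads = nn.ModuleDict()
        if use_target:
            self.heads["target"] = nn.Linear(input_dims["target"], vocab_size)
        if use_pe:
            self.heads["pe"] = nn.Linear(input_dims["pe"], vocab_size)
        self.loss_fn = nn.CrossEntropyLoss(reduction="none")

    def loss(self, features, gold):
        loss_dict = OrderedDict()
        for side, head in self.heads.items():
            logits = head(features[side]).transpose(1, 2)  # classes go in dim 1
            loss_dict[side] = self.loss_fn(logits, gold[side])
        loss_dict["loss"] = sum(l.sum() for l in loss_dict.values())
        return loss_dict

model = WordOutputs({"target": 32})
features = {"target": torch.randn(2, 7, 32)}
gold = {"target": torch.randint(0, 100, (2, 7))}
print(model.loss(features, gold)["loss"])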
Example #19
class Event2Mind(Model):
    """
    This ``Event2Mind`` class is a :class:`Model` which takes an event
    sequence, encodes it, and then uses the encoded representation to decode
    several mental state sequences.

    It is based on `the paper by Rashkin et al.
    <https://www.semanticscholar.org/paper/Event2Mind/b89f8a9b2192a8f2018eead6b135ed30a1f2144d>`_

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (``tokens``) or the target tokens can have a different namespace, in which case it needs to
        be specified as ``target_namespace``.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences.
    embedding_dropout: float, required
        The amount of dropout to apply after the source tokens have been embedded.
    encoder : ``Seq2VecEncoder``, required
        The encoder of the "encoder/decoder" model.
    max_decoding_steps : int, required
        Length of decoded sequences.
    beam_size : int, optional (default = 10)
        The width of the beam search.
    target_names: ``List[str]``, optional, (default = ['xintent', 'xreact', 'oreact'])
        Names of the target fields matching those in the ``Instance`` objects.
    target_namespace : str, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : int, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 embedding_dropout: float,
                 encoder: Seq2VecEncoder,
                 max_decoding_steps: int,
                 beam_size: int = 10,
                 target_names: List[str] = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None) -> None:
        super().__init__(vocab)
        target_names = target_names or ["xintent", "xreact", "oreact"]

        # Note: The original tweaks the embeddings for "personx" to be the mean
        # across the embeddings for "he", "she", "him" and "her". Similarly for
        # "personx's" and so forth. We could consider that here as a well.
        self._source_embedder = source_embedder
        self._embedding_dropout = nn.Dropout(embedding_dropout)
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        # Warning: The different decoders share a vocabulary! This may be
        # counterintuitive, but consider the case of xreact and oreact. A
        # reaction of "happy" could easily apply to both the subject of the
        # event and others. This could become less appropriate as more decoders
        # are added.
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with that of the final hidden states of the encoder.
        self._decoder_output_dim = self._encoder.get_output_dim()
        target_embedding_dim = (target_embedding_dim
                                or self._source_embedder.get_output_dim())

        self._states = ModuleDict()
        for name in target_names:
            self._states[name] = StateDecoder(num_classes,
                                              target_embedding_dim,
                                              self._decoder_output_dim)

        self._beam_search = BeamSearch(self._end_index,
                                       beam_size=beam_size,
                                       max_steps=max_decoding_steps)

    def _update_recall(self, all_top_k_predictions: torch.Tensor,
                       target_tokens: Dict[str, torch.LongTensor],
                       target_recall: UnigramRecall) -> None:
        targets = target_tokens["tokens"]
        target_mask = get_text_field_mask(target_tokens)
        # See comment in _get_loss.
        # TODO(brendanr): Do we need contiguous here?
        relevant_targets = targets[:, 1:].contiguous()
        relevant_mask = target_mask[:, 1:].contiguous()
        target_recall(all_top_k_predictions, relevant_targets, relevant_mask,
                      self._end_index)

    def _get_num_decoding_steps(
            self, target_tokens: Optional[Dict[str, torch.LongTensor]]) -> int:
        if target_tokens:
            targets = target_tokens["tokens"]
            target_sequence_length = targets.size()[1]
            # The last input from the target is either padding or the end
            # symbol.  Either way, we don't have to process it. (To be clear,
            # we do still output and compare against the end symbol, but there
            # is no need to take the end symbol as input to the decoder.)
            return target_sequence_length - 1
        else:
            return self._max_decoding_steps

    @overrides
    def forward(
        self,  # type: ignore
        source: Dict[str, torch.LongTensor],
        **target_tokens: Dict[str, Dict[str, torch.LongTensor]]
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing the target sequences.

        Parameters
        ----------
        source : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the source
            ``TextField``. This will be passed through a ``TextFieldEmbedder``
            and then through an encoder.
        target_tokens : ``Dict[str, Dict[str, torch.LongTensor]]``:
            Dictionary from name to output of ``Textfield.as_array()`` applied
            on target ``TextField``. We assume that the target tokens are also
            represented as a ``TextField``.
        """
        # (batch_size, input_sequence_length, embedding_dim)
        embedded_input = self._embedding_dropout(self._source_embedder(source))
        source_mask = get_text_field_mask(source)
        # (batch_size, encoder_output_dim)
        final_encoder_output = self._encoder(embedded_input, source_mask)
        output_dict = {}

        # Perform greedy search so we can get the loss.
        if target_tokens:
            if target_tokens.keys() != self._states.keys():
                target_only = target_tokens.keys() - self._states.keys()
                states_only = self._states.keys() - target_tokens.keys()
                raise Exception(
                    "Mismatch between target_tokens and self._states. Keys in " +
                    f"targets only: {target_only} Keys in states only: {states_only}")
            total_loss = 0
            for name, state in self._states.items():
                loss = self.greedy_search(
                    final_encoder_output=final_encoder_output,
                    target_tokens=target_tokens[name],
                    target_embedder=state.embedder,
                    decoder_cell=state.decoder_cell,
                    output_projection_layer=state.output_projection_layer)
                total_loss += loss
                output_dict[f"{name}_loss"] = loss

            # Use mean loss (instead of the sum of the losses) to be comparable to the paper.
            output_dict["loss"] = total_loss / len(self._states)

        # Perform beam search to obtain the predictions.
        if not self.training:
            batch_size = final_encoder_output.size()[0]
            for name, state in self._states.items():
                start_predictions = final_encoder_output.new_full(
                    (batch_size, ),
                    fill_value=self._start_index,
                    dtype=torch.long)
                start_state = {"decoder_hidden": final_encoder_output}

                # (batch_size, 10, num_decoding_steps)
                all_top_k_predictions, log_probabilities = self._beam_search.search(
                    start_predictions, start_state, state.take_step)

                if target_tokens:
                    self._update_recall(all_top_k_predictions,
                                        target_tokens[name], state.recall)
                output_dict[f"{name}_top_k_predictions"] = all_top_k_predictions
                output_dict[f"{name}_top_k_log_probabilities"] = log_probabilities

        return output_dict

    def greedy_search(self, final_encoder_output: torch.LongTensor,
                      target_tokens: Dict[str, torch.LongTensor],
                      target_embedder: Embedding, decoder_cell: GRUCell,
                      output_projection_layer: Linear) -> torch.FloatTensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the cross entropy between this sequence and ``target_tokens``.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()`` applied on some target ``TextField``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._get_num_decoding_steps(target_tokens)
        targets = target_tokens["tokens"]
        decoder_hidden = final_encoder_output
        step_logits = []
        for timestep in range(num_decoding_steps):
            # See https://github.com/allenai/allennlp/issues/1134.
            input_choices = targets[:, timestep]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            # list of (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))
        # (batch_size, num_decoding_steps, num_classes)
        logits = torch.cat(step_logits, 1)
        target_mask = get_text_field_mask(target_tokens)
        return self._get_loss(logits, targets, target_mask)

    def greedy_predict(self, final_encoder_output: torch.LongTensor,
                       target_embedder: Embedding, decoder_cell: GRUCell,
                       output_projection_layer: Linear) -> torch.Tensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the predicted sequence.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._max_decoding_steps
        decoder_hidden = final_encoder_output
        batch_size = final_encoder_output.size()[0]
        predictions = [
            final_encoder_output.new_full((batch_size, ),
                                          fill_value=self._start_index,
                                          dtype=torch.long)
        ]
        for _ in range(num_decoding_steps):
            input_choices = predictions[-1]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            predictions.append(predicted_classes)
        all_predictions = torch.cat([ps.unsqueeze(1) for ps in predictions], 1)
        # Drop start symbol and return.
        return all_predictions[:, 1:]

    @staticmethod
    def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.FloatTensor:
        """
        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()
        relevant_mask = target_mask[:, 1:].contiguous()
        loss = sequence_cross_entropy_with_logits(logits, relevant_targets,
                                                  relevant_mask)
        return loss

    def decode_all(self, predicted_indices: torch.Tensor) -> List[List[str]]:
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        return all_predicted_tokens

    @overrides
    def decode(
            self,
            output_dict: Dict[str,
                              torch.Tensor]) -> Dict[str, List[List[str]]]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds fields for the tokens to the ``output_dict``.
        """
        for name in self._states:
            top_k_predicted_indices = output_dict[f"{name}_top_k_predictions"][0]
            output_dict[f"{name}_top_k_predicted_tokens"] = [
                self.decode_all(top_k_predicted_indices)
            ]

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics = {}
        # Recall@10 needs beam search which doesn't happen during training.
        if not self.training:
            for name, state in self._states.items():
                all_metrics[name] = state.recall.get_metric(reset=reset)
        return all_metrics
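
The per-state decoding in greedy_search above boils down to a standard teacher-forced GRUCell loop. The self-contained sketch below reproduces that loop with hypothetical sizes, outside of the AllenNLP machinery: the gold token at each step is embedded, fed to the cell, and the projected hidden state is collected as that step's logits.

import torch
from torch import nn

vocab_size, emb_dim, hidden_dim = 50, 16, 32
embedder = nn.Embedding(vocab_size, emb_dim)
decoder_cell = nn.GRUCell(emb_dim, hidden_dim)
projection = nn.Linear(hidden_dim, vocab_size)

batch_size, target_len = 4, 6
final_encoder_output = torch.randn(batch_size, hidden_dim)  # seeds the decoder state
targets = torch.randint(0, vocab_size, (batch_size, target_len))

decoder_hidden = final_encoder_output
step_logits = []
for timestep in range(target_len - 1):   # the last target token is never fed in
    decoder_input = embedder(targets[:, timestep])
    decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
    step_logits.append(projection(decoder_hidden).unsqueeze(1))

logits = torch.cat(step_logits, 1)        # (batch_size, num_decoding_steps, num_classes)
print(logits.shape)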