Example #1
 def __init__(self,
              model_dim,
              ff_dim,
              nheads,
              nlayers,
              dropout=0.0,
              mem_att_dropout=0.0,
              mem_dim=None,
              len_prop=True,
              rating_prop=True,
              rouge_prop=True,
              pov_prop=True):
     super(Plugin, self).__init__()
     assert len_prop or rating_prop or rouge_prop or pov_prop
     self.model_dim = model_dim
     self.tr_stack = TransformerStack(model_dim=model_dim,
                                      mem_dim=mem_dim,
                                      num_layers=nlayers,
                                      dropout=dropout,
                                      mem_att_dropout=mem_att_dropout,
                                      ff_dim=ff_dim,
                                      nheads=nheads)
     self.fns = ModuleDict()
     if len_prop:
         self.fns[ModelF.LEN_PROP] = Linear(model_dim, 1)
     if rating_prop:
         self.fns[ModelF.RATING_PROP] = Linear(model_dim, 1)
     if rouge_prop:
         self.fns[ModelF.ROUGE_PROP] = Sequential()
         self.fns[ModelF.ROUGE_PROP].add_module('lin', Linear(model_dim, 3))
         self.fns[ModelF.ROUGE_PROP].add_module('sigmoid', Sigmoid())
     if pov_prop:
         self.fns[ModelF.POV_PROP] = Sequential()
         self.fns[ModelF.POV_PROP].add_module("lin", Linear(model_dim, 4))
         self.fns[ModelF.POV_PROP].add_module('softmax', Softmax(dim=-1))
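
The constructor above only registers the per-property heads; using them amounts to iterating the ModuleDict. A minimal, self-contained sketch of that pattern (the head names and dimensions below are illustrative, not taken from the original model):

import torch
from torch import nn

# Hypothetical heads, mirroring the structure built in the constructor above.
heads = nn.ModuleDict({
    "len_prop": nn.Linear(64, 1),
    "rouge_prop": nn.Sequential(nn.Linear(64, 3), nn.Sigmoid()),
})

hidden = torch.randn(8, 64)  # e.g. pooled decoder states, one per summary
props = {name: head(hidden) for name, head in heads.items()}
# props["len_prop"] has shape (8, 1); props["rouge_prop"] has shape (8, 3)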
Example #2
    def _build_cycle_layers(self, hidden_channels, num_layers, cycle_depth):
        """
        Build layers that convolve over the cyclic activations.
        """
        # Construct embedding layers for input
        self.cycle_embedders = ModuleDict()
        for k, key in zip(self.used_cycles, self._cycle_keys):
            self.cycle_embedders[key] = Embedding(4, hidden_channels)

        self.cycle_blocks = ModuleList()
        self.a2c_blocks = ModuleList()
        self.c2a_blocks = ModuleList()

        # Construct layers that convolve over cycles and move to/from atom activations
        for _ in range(num_layers):
            cb_k_dict = ModuleDict()  # Dict of intracycle blocks
            a2c_k_dict = ModuleDict()  # Dict of linears for conversions
            c2a_k_dict = ModuleDict()  # Dict of linears for conversions
            for k, key in zip(self.used_cycles, self._cycle_keys):
                cb_k_dict[key] = blocks.PathBlock(hidden_channels,
                                                  k,
                                                  num_resid_blocks=cycle_depth,
                                                  dropout=self.dropout)
                a2c_k_dict[key] = Linear(hidden_channels, hidden_channels)
                c2a_k_dict[key] = Linear(hidden_channels, hidden_channels)
            self.cycle_blocks.append(cb_k_dict)
            self.a2c_blocks.append(a2c_k_dict)
            self.c2a_blocks.append(c2a_k_dict)
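
A rough sketch of how these per-layer dictionaries might be traversed in a forward pass (the atom/cycle bookkeeping is a placeholder, not taken from the original model):

    def _cycle_pass(self, atom_h):
        # Sketch only: per layer and per cycle size, project atom features into
        # cycle space, convolve over the cycle, and project the result back.
        for layer_idx in range(len(self.cycle_blocks)):
            for key in self._cycle_keys:
                c = self.a2c_blocks[layer_idx][key](atom_h)
                c = self.cycle_blocks[layer_idx][key](c)
                atom_h = atom_h + self.c2a_blocks[layer_idx][key](c)
        return atom_h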
Example #3
 def __init__(self, nfft, cfg):
     super().__init__()
     params_lookup = {'peak': ['w0', 'bw', 'g'], 'notch': ['w0', 'bw']}
     self.features = {f'{t}{i}': params_lookup[t] for i, t in enumerate(cfg['spectral_features'])}
     input_size = nfft // 2 + 1
     # add first stack
     frontend_sizes = cfg['frontend']['sizes']
     frontend_sizes.insert(0, input_size)
     frontend_act = cfg['frontend']['activation_type']
     frontend_do = cfg['frontend']['dropout_rate']
     self.frontend = self.dense_stack(frontend_sizes, frontend_act, frontend_do)
     # add a stack for each spectral feature to predict (peak or notch)
     specfeat_sizes = cfg['features_block']['sizes']
     specfeat_sizes.insert(0, frontend_sizes[-1])
     specfeat_act = cfg['features_block']['activation_type']
     specfeat_do = cfg['features_block']['dropout_rate']
     specfeat_dict = {k: self.dense_stack(specfeat_sizes, specfeat_act, specfeat_do) for k in self.features}
     self.spectral_features = ModuleDict(specfeat_dict)
     # add a stack for each parameter of each spectral features (w0, bw, g)
     specparams_sizes = cfg['parameters_block']['sizes']
     specparams_sizes.insert(0, specfeat_sizes[-1])
     specparams_sizes.append(1)
     specparams_act = cfg['parameters_block']['activation_type']
     specparams_do = cfg['parameters_block']['dropout_rate']
     specparams_dict = {f'{k}_{p}': self.dense_stack(specparams_sizes, specparams_act, specparams_do) for k in self.features for p in self.features[k]}
     self.spectral_parameters = ModuleDict(specparams_dict)
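
A plausible forward pass for this three-level layout, chaining the frontend through the per-feature stacks and then the per-parameter heads (a sketch only, assuming dense_stack returns an ordinary sequential module; it is not taken from the original source):

 def forward(self, x):
     h = self.frontend(x)
     out = {}
     for feat, feat_stack in self.spectral_features.items():
         hf = feat_stack(h)
         for p in self.features[feat]:
             # one scalar prediction per (feature, parameter) pair
             out[f'{feat}_{p}'] = self.spectral_parameters[f'{feat}_{p}'](hf)
     return out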
Example #4
    def __init__(self,
                 vocabs: Dict[str, Vocabulary],
                 config: Config,
                 pre_load_model: bool = True):
        super().__init__(config=config)

        self.embeddings = ModuleDict()
        self.embeddings[const.TARGET] = TokenEmbeddings(
            num_embeddings=len(vocabs[const.TARGET]),
            pad_idx=vocabs[const.TARGET].pad_id,
            config=config.embeddings.target,
            vectors=vocabs[const.TARGET].vectors,
        )
        self.embeddings[const.SOURCE] = TokenEmbeddings(
            num_embeddings=len(vocabs[const.SOURCE]),
            pad_idx=vocabs[const.SOURCE].pad_id,
            config=config.embeddings.source,
            vectors=vocabs[const.SOURCE].vectors,
        )

        total_size = sum(emb.size() for emb in self.embeddings.values())
        self._sizes = {
            const.TARGET: total_size * self.config.window_size,
            const.SOURCE: total_size * self.config.window_size,
        }
Example #5
    def __init__(
        self,
        n_stages,
        nf=128,
        nf_out=3,
        n_rnb=2,
        conv_layer=NormConv2d,
        spatial_size=256,
        final_act=True,
        dropout_prob=0.0,
    ):
        super().__init__()
        assert (2 ** (n_stages - 1)) == spatial_size
        self.final_act = final_act
        self.blocks = ModuleDict()
        self.ups = ModuleDict()
        self.n_stages = n_stages
        self.n_rnb = n_rnb
        for i_s in range(self.n_stages - 2, 0, -1):
            # for the final stage, halve the number of filters
            if i_s == 1:
                # upsampling operations
                self.ups.update(
                    {
                        f"s{i_s+1}": Upsample(
                            in_channels=nf, out_channels=nf // 2, conv_layer=conv_layer,
                        )
                    }
                )
                nf = nf // 2
            else:
                # upsampling operations
                self.ups.update(
                    {
                        f"s{i_s+1}": Upsample(
                            in_channels=nf, out_channels=nf, conv_layer=conv_layer,
                        )
                    }
                )

            # resnet blocks
            for ir in range(self.n_rnb, 0, -1):
                stage = f"s{i_s}_{ir}"
                self.blocks.update(
                    {
                        stage: VUnetResnetBlock(
                            nf,
                            use_skip=True,
                            conv_layer=conv_layer,
                            dropout_prob=dropout_prob,
                        )
                    }
                )

        # final 1x1 convolution
        self.final_layer = conv_layer(nf, nf_out, kernel_size=1)

        # conditionally: set final activation
        if self.final_act:
            self.final_act = nn.Tanh()
Example #6
    def __init__(
        self,
        encoder_output_channels: int,
        scales: List[int],
    ):
        super().__init__()
        self.scales = scales
        self.decoder_channels = [16, 32, 64, 128, 256]

        self.convs: ModuleDict[str, Module] = ModuleDict()
        self.scaled_convs: ModuleDict[str, Module] = ModuleDict()
        for i in range(4, -1, -1):
            # upconv_0
            input_channels = (encoder_output_channels
                              if i == 4 else self.decoder_channels[i + 1])
            self.convs[f"{i}0"] = ConvBlock(
                input_channels,
                self.decoder_channels[i],
            )
            # upconv_1
            input_channels = self.decoder_channels[i]
            self.convs[f"{i}1"] = ConvBlock(
                input_channels,
                self.decoder_channels[i],
            )
        for s in self.scales:
            self.scaled_convs[f"{s}"] = Conv3x3(self.decoder_channels[s], 1)
        self.sigmoid = Sigmoid()
Example #7
    def __init__(
        self,
        n_stages,
        nf_in=3,
        nf_start=64,
        nf_max=128,
        n_rnb=2,
        conv_layer=NormConv2d,
    ):
        super().__init__()
        self.in_op = conv_layer(nf_in, nf_start, kernel_size=1)
        nf = nf_start
        self.blocks = ModuleDict()
        self.downs = ModuleDict()
        self.n_rnb = n_rnb
        self.n_stages = n_stages
        for i_s in range(self.n_stages):
            # prepare resnet blocks per stage
            if i_s > 0:
                self.downs.update(
                    {
                        f"s{i_s+1}": Downsample(
                            nf, min(2 * nf, nf_max), conv_layer=conv_layer
                        )
                    }
                )
                nf = min(2 * nf, nf_max)

            for ir in range(self.n_rnb):
                stage = f"s{i_s+1}_{ir+1}"
                self.blocks.update({stage: VUnetResnetBlock(nf, conv_layer=conv_layer)})
Example #8
    def __init__(
        self,
        encoder_channels: List[int],
        scales: List[int],
        use_skips: bool = True,
    ):
        super().__init__()
        self.use_skips = use_skips
        self.scales = scales

        self.encoder_channels = encoder_channels
        self.decoder_channels = array([16, 32, 64, 128, 256])

        self.convs: ModuleDict[str, Module] = ModuleDict()
        self.scaled_convs: ModuleDict[str, Module] = ModuleDict()
        for i in range(4, -1, -1):
            # upconv_0
            input_channels = (self.encoder_channels[-1]
                              if i == 4 else self.decoder_channels[i + 1])
            self.convs[f"{i}0"] = ConvBlock(
                input_channels,
                self.decoder_channels[i],
            )
            # upconv_1
            input_channels = self.decoder_channels[i]
            if self.use_skips and i > 0:
                input_channels += self.encoder_channels[i - 1]
            self.convs[f"{i}1"] = ConvBlock(
                input_channels,
                self.decoder_channels[i],
            )
        for s in self.scales:
            self.scaled_convs[f"{s}"] = Conv3x3(self.decoder_channels[s], 1)
        self.sigmoid = Sigmoid()
Example #9
    def __init__(self, multitaskdict):

        """
        Args:
            multitaskdict (dict): dictionary of items used for setting up the networks.
        Returns:
            None
        """

        super(AuTopologyReadOut, self).__init__()

        trainable = multitaskdict["trainable_prior"]
        Fr = multitaskdict["Fr"]
        Lh = multitaskdict["Lh"]
        # bond_terms = multitaskdict.get("bond_terms", ["morse"])
        # angle_terms =  multitaskdict.get("angle_terms", ['harmonic'])  # harmonic and/or cubic and/or quartic
        # dihedral_terms = multitaskdict.get("dihedral_terms", ['OPLS'])  # OPLS and/or multiharmonic
        # improper_terms = multitaskdict.get("improper_terms", ['harmonic'])  # harmonic
        # pair_terms = multitaskdict.get("pair_terms", ['LJ'])  # coulomb and/or LJ and/or induced_dipole
        autopology_keys = multitaskdict["output_keys"]

        default_terms_dict = {
            "bond_terms": ["morse"],
            "angle_terms": ["harmonic"],
            "dihedral_terms": ["OPLS"],
            "improper_terms": ["harmonic"],
            "pair_terms": ["LJ", "coulombs"]
        }

        # self.terms = {
        #     'bond': bond_terms,
        #     'angle': angle_terms,
        #     'dihedral': dihedral_terms,
        #     'improper': improper_terms,
        #     'pair': pair_terms
        # }
        self.terms = {}

        # remove terms that are not included
        for top in ['bond', 'angle', 'dihedral', 'improper', 'pair']:
            if top + '_terms' in multitaskdict.keys():
                self.terms[top] = multitaskdict.get(top + '_terms', default_terms_dict[top + '_terms'])


        topologynet = {key: {} for key in autopology_keys}
        for key in autopology_keys:
            for top in self.terms.keys():
                if top + '_terms' in multitaskdict:
                    topologynet[key][top] = TopologyNet[top](Fr, Lh, self.terms[top], trainable=trainable)


        # module dictionary of the form {"energy_0": {"bond": BondNet0, "angle": AngletNet0},
        # "energy_1": {"bond": BondNet1, "angle": AngletNet1} }
        self.auto_modules = ModuleDict({key: ModuleDict({top: topologynet[key][top] for top in
            self.terms.keys()}) for key in autopology_keys})

        # energy offset for each state
        self.offset = ModuleDict({key: ParameterPredictor(Fr, Lh, 1)
            for key in autopology_keys})
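
Given the nested layout described in the comment above, a typical readout would sum each topology term's contribution per output key and add the learned offset. A hedged sketch (the argument names r and batch are invented here; the real forward differs in detail):

    def forward(self, r, batch):
        results = {}
        for key, term_nets in self.auto_modules.items():
            # sum bond/angle/dihedral/... contributions for this output key
            energy = sum(net(r, batch) for net in term_nets.values())
            results[key] = energy + self.offset[key](r)
        return results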
Example #10
 def __init__(self):
     super().__init__()
     self.metrics_list = ModuleList([DummyMetric() for _ in range(2)])
     self.metrics_dict = ModuleDict({"a": DummyMetric(), "b": DummyMetric()})
     self.metrics_collection_dict = MetricCollection({"a": DummyMetric(), "b": DummyMetric()})
     self.metrics_collection_dict_nested = ModuleDict(
         {"a": ModuleList([ModuleDict({"b": DummyMetric()}), DummyMetric()])}
     )
Example #11
 def __init__(self, model_dict, mode='sum'):
     """Summary
     
     Args:
         model_dict (TYPE): Description
         mode (str, optional): Description
     """
     super().__init__()
     self.models = ModuleDict(model_dict)
Example #12
class VUnetEncoder(nn.Module):
    def __init__(
        self,
        n_stages,
        nf_in=3,
        nf_start=64,
        nf_max=128,
        n_rnb=2,
        conv_layer=NormConv2d,
        dropout_prob=0.0,
    ):
        super().__init__()
        self.in_op = conv_layer(nf_in, nf_start, kernel_size=1)
        nf = nf_start
        self.blocks = ModuleDict()
        self.downs = ModuleDict()
        self.n_rnb = n_rnb
        self.n_stages = n_stages
        for i_s in range(self.n_stages):
            # prepare resnet blocks per stage
            if i_s > 0:
                self.downs.update(
                    {
                        f"s{i_s+1}": Downsample(
                            nf, min(2 * nf, nf_max), conv_layer=conv_layer
                        )
                    }
                )
                nf = min(2 * nf, nf_max)

            for ir in range(self.n_rnb):
                stage = f"s{i_s+1}_{ir+1}"
                self.blocks.update(
                    {
                        stage: VUnetResnetBlock(
                            nf, conv_layer=conv_layer, dropout_prob=dropout_prob
                        )
                    }
                )

    def forward(self, x):
        out = {}
        h = self.in_op(x)
        for ir in range(self.n_rnb):
            h = self.blocks[f"s1_{ir+1}"](h)
            out[f"s1_{ir+1}"] = h

        for i_s in range(1, self.n_stages):

            h = self.downs[f"s{i_s+1}"](h)

            for ir in range(self.n_rnb):
                stage = f"s{i_s+1}_{ir+1}"
                h = self.blocks[stage](h)
                out[stage] = h

        return out
Example #13
    def __init__(self,
                 hparams: "SentenceEmbeddings.Options",
                 statistics,
                 plugins=None):

        super().__init__()
        self.hparams = hparams
        self.mode = hparams.mode
        self.plugins = ModuleDict(plugins) if plugins is not None else {}

        # embedding
        input_dims = {}
        if hparams.dim_word != 0:
            self.word_embeddings = Embedding(len(statistics.words),
                                             hparams.dim_word,
                                             padding_idx=0)
            self.word_dropout = FeatureDropout2(hparams.word_dropout)
            input_dims["word"] = hparams.dim_word

        if hparams.dim_postag != 0:
            self.pos_embeddings = Embedding(len(statistics.postags),
                                            hparams.dim_postag,
                                            padding_idx=0)
            self.pos_dropout = FeatureDropout2(hparams.postag_dropout)
            input_dims["postag"] = hparams.dim_postag

        if hparams.dim_char > 0:
            self.bilm = None
            self.character_lookup = Embedding(len(statistics.characters),
                                              hparams.dim_char_input)
            self.char_embeded = CharacterEmbedding.get(
                hparams.character_embedding,
                dim_char_input=hparams.dim_char_input,
                input_size=hparams.dim_char)
            if not hparams.replace_unk_with_chars:
                input_dims["char"] = hparams.dim_char
            else:
                assert hparams.dim_word == hparams.dim_char
        else:
            self.character_lookup = None

        for name, plugin in self.plugins.items():
            input_dims[name] = plugin.output_dim

        if hparams.mode == "concat":
            self.output_dim = sum(input_dims.values())
        else:
            assert hparams.mode == "add"
            uniq_input_dims = list(set(input_dims.values()))
            if len(uniq_input_dims) != 1:
                raise ValueError(f"Different input dims: {input_dims}")
            print(input_dims)
            self.output_dim = uniq_input_dims[0]

        self.input_layer_norm = LayerNorm(self.output_dim, eps=1e-6) \
            if hparams.input_layer_norm else None
Example #14
    def __init__(
        self,
        n_stages,
        nf,
        device,
        n_rnb=2,
        n_auto_groups=4,
        conv_layer=NormConv2d,
        dropout_prob=0.0,
    ):
        super().__init__()
        self.device = device
        self.blocks = ModuleDict()
        self.channel_norm = ModuleDict()
        self.conv1x1 = conv_layer(nf, nf, 1)
        self.up = Upsample(in_channels=nf, out_channels=nf, conv_layer=conv_layer)
        self.depth_to_space = DepthToSpace(block_size=2)
        self.space_to_depth = SpaceToDepth(block_size=2)
        self.n_stages = n_stages
        self.n_rnb = n_rnb
        # number of autoregressively modeled groups
        self.n_auto_groups = n_auto_groups
        for i_s in range(self.n_stages, self.n_stages - 2, -1):
            self.channel_norm.update({f"s{i_s}": conv_layer(2 * nf, nf, 1)})
            for ir in range(self.n_rnb):
                self.blocks.update(
                    {
                        f"s{i_s}_{ir+1}": VUnetResnetBlock(
                            nf,
                            use_skip=True,
                            conv_layer=conv_layer,
                            dropout_prob=dropout_prob,
                        )
                    }
                )

        self.auto_blocks = ModuleList()
        # resnet blocks modelling the autoregressive groups
        for i_a in range(4):
            if i_a < 1:
                self.auto_blocks.append(
                    VUnetResnetBlock(
                        nf, conv_layer=conv_layer, dropout_prob=dropout_prob
                    )
                )
                self.param_converter = conv_layer(4 * nf, nf, kernel_size=1)
            else:
                self.auto_blocks.append(
                    VUnetResnetBlock(
                        nf,
                        use_skip=True,
                        conv_layer=conv_layer,
                        dropout_prob=dropout_prob,
                    )
                )
Example #15
    def __init__(self, model_dict, mode='sum'):
        super().__init__()
        implemented_mode = ['sum', 'mean']

        if mode not in implemented_mode:
            raise NotImplementedError(
                '{} mode is not implemented for Stack'.format(mode))

        # TODO: implement a check for readout keys

        self.models = ModuleDict(model_dict)
        self.mode = mode
Example #16
    def __init__(
        self,
        vocab: Vocabulary,
        source_embedder: TextFieldEmbedder,
        embedding_dropout: float,
        encoder: Seq2VecEncoder,
        max_decoding_steps: int,
        beam_size: int = 10,
        target_names: List[str] = None,
        target_namespace: str = "tokens",
        target_embedding_dim: int = None,
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)
        target_names = target_names or ["xintent", "xreact", "oreact"]

        # Note: The original tweaks the embeddings for "personx" to be the mean
        # across the embeddings for "he", "she", "him" and "her". Similarly for
        # "personx's" and so forth. We could consider that here as a well.
        self._source_embedder = source_embedder
        self._embedding_dropout = nn.Dropout(embedding_dropout)
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        # Warning: The different decoders share a vocabulary! This may be
        # counterintuitive, but consider the case of xreact and oreact. A
        # reaction of "happy" could easily apply to both the subject of the
        # event and others. This could become less appropriate as more decoders
        # are added.
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with that of the final hidden states of the encoder.
        self._decoder_output_dim = self._encoder.get_output_dim()
        target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim(
        )

        self._states = ModuleDict()
        for name in target_names:
            self._states[name] = StateDecoder(num_classes,
                                              target_embedding_dim,
                                              self._decoder_output_dim)

        self._beam_search = BeamSearch(self._end_index,
                                       beam_size=beam_size,
                                       max_steps=max_decoding_steps)
Example #17
    def __init__(self, config):

        super().__init__()

        self.nametag = 'MultiProjectorsManager'

        self.setup_classes_factory()

        self.classes = self.get_classes(config)

        self.projectors = ModuleDict(
            {k: cl(config)
             for k, cl in self.classes.items()})
Example #18
    def __init__(
        self,
        num_classes: int = 2,
        learning_rate: float = 0.005,
        drop_lr_on_plateau_patience: int = 10,
        model_kwargs: Optional[Dict] = None,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.drop_lr_on_plateau_patience = drop_lr_on_plateau_patience

        if model_kwargs is None:
            self.model_kwargs = {}
        else:
            self.model_kwargs = model_kwargs

        self.model = self.build_model()

        self.validation_metrics = ModuleDict({
            # "AP50": AveragePrecision(
            #     num_foreground_classes=num_classes - 1,
            #     iou_thresholds=[0.5],
            #     iou_type="mask",
            #     ap_calculation_type="COCO",
            # ),
            # "AP75": AveragePrecision(
            #     num_foreground_classes=num_classes - 1,
            #     iou_thresholds=[0.75],
            #     iou_type="mask",
            #     ap_calculation_type="COCO",
            # ),
            "mAP":
            AveragePrecision(
                num_foreground_classes=num_classes - 1,
                iou_thresholds=np.arange(0.5, 1, 0.05),
                iou_type="mask",
                ap_calculation_type="COCO",
            ),
            "confusion_matrix":
            ConfusionMatrix(
                num_classes,
                iou_type="mask",
                iou_threshold=0.5,
                score_threshold=0.5,
            ),
        })

        self.main_validation_metric_name = "mAP"
Example #19
    def __init__(
        self,
        inputs_dims: Dict[str, int],
        vocabs: Dict[str, Vocabulary],
        config: Config,
        pretraining: bool = False,
    ):
        super().__init__(config=config)

        self.inputs_dims = inputs_dims
        self.vocabs = OrderedDict()
        self.config = config
        self.pretraining = pretraining
        self._metrics = None

        self.masked_word_outputs = ModuleDict()

        if self.pretraining:
            if const.TARGET not in vocabs:
                raise ValueError(
                    f'Asked to pretrain the encoder (`pretrain`) but no '
                    f'vocabulary was provided for {const.TARGET}')
            if const.TARGET_LOGITS not in self.inputs_dims:
                raise ValueError(
                    'Asked to pretrain the encoder (`pretrain`) but no '
                    'target data was provided')
            self.masked_word_outputs[const.TARGET] = MaskedWordOutput(
                input_size=self.inputs_dims[const.TARGET_LOGITS],
                pad_idx=vocabs[const.TARGET].pad_id,
                start_idx=vocabs[const.TARGET].bos_id,
                stop_idx=vocabs[const.TARGET].eos_id,
            )
            self.vocabs[const.TARGET] = vocabs[const.TARGET]

        if self.config.fine_tune:
            # Target side; use PE for fine-tuning
            if const.PE not in vocabs:
                raise ValueError(
                    f'Asked to fine-tune the encoder (`fine_tune`) but no '
                    f'vocabulary was provided for {const.PE}')
            if const.PE_LOGITS not in self.inputs_dims:
                raise ValueError(
                    'Asked to fine-tune the encoder (`fine_tune`) but no '
                    'post-edit (PE) data was provided')
            self.masked_word_outputs[const.PE] = MaskedWordOutput(
                input_size=self.inputs_dims[const.PE_LOGITS],
                pad_idx=vocabs[const.PE].pad_id,
                start_idx=vocabs[const.PE].bos_id,
                stop_idx=vocabs[const.PE].eos_id,
            )
            self.vocabs[const.PE] = vocabs[const.PE]
Example #20
    def _init_attention(self):
        '''
            Initialises the attention map vectors/matrices.

            Takes the attention_type (span, sentence, span-param or
            sentence-param) as a parameter to decide the size of the
            attention matrix.
        '''

        self.att_map_repr = ModuleDict({})
        self.att_map_W = ModuleDict({})
        self.att_map_V = ModuleDict({})
        self.att_map_context = ModuleDict({})
        for prot in self.attention_type.keys():
            # Token representation
            if self.attention_type[prot]['repr'] == "span":
                repr_dim = 2 * self.reduced_embedding_dim
                self.att_map_repr[prot] = Linear(self.reduced_embedding_dim,
                                                 1,
                                                 bias=False)
                self.att_map_W[prot] = Linear(self.reduced_embedding_dim,
                                              self.reduced_embedding_dim)
                self.att_map_V[prot] = Linear(self.reduced_embedding_dim,
                                              1,
                                              bias=False)
            elif self.attention_type[prot]['repr'] == "param":
                repr_dim = 2 * self.reduced_embedding_dim
                self.att_map_repr[prot] = Linear(self.reduced_embedding_dim,
                                                 self.reduced_embedding_dim,
                                                 bias=False)
                self.att_map_W[prot] = Linear(2 * self.reduced_embedding_dim,
                                              self.reduced_embedding_dim)
                self.att_map_V[prot] = Linear(self.reduced_embedding_dim,
                                              1,
                                              bias=False)
            else:
                repr_dim = self.reduced_embedding_dim

            # Context representation
            # There is no attention for argument davidsonian
            if self.attention_type[prot]['context'] == 'param':
                self.att_map_context[prot] = Linear(repr_dim,
                                                    self.reduced_embedding_dim,
                                                    bias=False)
            elif self.attention_type[prot][
                    'context'] == 'david' and prot == 'arg':
                self.att_map_context[prot] = Linear(repr_dim,
                                                    self.reduced_embedding_dim,
                                                    bias=False)
Example #21
class Stack(torch.nn.Module):
    """Summary
    
    Attributes:
        models (TYPE): Description
    """
    def __init__(self, model_dict, mode='sum'):
        """Summary
        
        Args:
            model_dict (TYPE): Description
            mode (str, optional): Description
        """
        super().__init__()
        self.models = ModuleDict(model_dict)

    def forward(self, x):
        """Summary
        
        Args:
            x (TYPE): Description
        
        Returns:
            TYPE: Description
        """
        for i, key in enumerate(self.models.keys()):
            if i == 0:
                result = self.models[key](x).sum().reshape(-1)
            else:
                new_result = self.models[key](x).sum().reshape(-1)
                result += new_result

        return result
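
A short usage sketch for this Stack (the linear models below are placeholders; real model_dict entries would be energy networks returning per-sample tensors):

import torch
from torch import nn

model_dict = {"energy_0": nn.Linear(8, 1), "energy_1": nn.Linear(8, 1)}
stack = Stack(model_dict)           # Stack as defined above
result = stack(torch.randn(4, 8))   # summed scalar contribution from every model, shape (1,)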
Example #22
    def __init__(
        self, decomposition: Dict[str, List[int]], batch_shape: torch.Size
    ) -> None:
        super().__init__(batch_shape=batch_shape)
        self.decomposition = decomposition

        num_param = len(next(iter(decomposition.values())))
        for active_parameters in decomposition.values():
            # check number of parameters are same in each decomp
            if len(active_parameters) != num_param:
                raise ValueError(
                    "num of parameters needs to be same across all contexts"
                )

        self._indexers = {
            context: torch.tensor(active_params)
            for context, active_params in self.decomposition.items()
        }

        self.base_kernel = MaternKernel(
            nu=2.5,
            ard_num_dims=num_param,
            batch_shape=batch_shape,
            lengthscale_prior=GammaPrior(3.0, 6.0),
        )

        self.kernel_dict = {}  # scaled kernel for each parameter space partition
        for context in list(decomposition.keys()):
            self.kernel_dict[context] = ScaleKernel(
                base_kernel=self.base_kernel, outputscale_prior=GammaPrior(2.0, 15.0)
            )
        self.kernel_dict = ModuleDict(self.kernel_dict)
Example #23
    def _init_regression(self):
        '''
            Define the linear maps
        '''

        # Output regression parameters
        self.linmaps = ModuleDict(
            {prot: ModuleList([])
             for prot in self.all_attributes.keys()})

        for prot in self.all_attributes.keys():
            last_size = self.reduced_embedding_dim
            # Handle varying size of dimension depending on representation
            if self.attention_type[prot]['repr'] == "root":
                if self.attention_type[prot]['context'] != "none":
                    last_size *= 2
            else:
                if self.attention_type[prot]['context'] == "none":
                    last_size *= 2
                else:
                    last_size *= 3
            # self.layer_norm[prot] = torch.nn.LayerNorm(last_size)
            last_size += self.hand_feat_dim
            for out_size in self.layers:
                linmap = Linear(last_size, out_size)
                self.linmaps[prot].append(linmap)
                last_size = out_size
            final_linmap = Linear(last_size, self.output_size)
            self.linmaps[prot].append(final_linmap)

        # Dropout layer
        self.dropout = Dropout()
Example #24
    def __init__(self,
                 convs: Dict[EdgeType, Module],
                 aggr: Optional[str] = "sum"):
        super().__init__()

        src_node_types = set([key[0] for key in convs.keys()])
        dst_node_types = set([key[-1] for key in convs.keys()])
        if len(src_node_types - dst_node_types) > 0:
            warnings.warn(
                f"There exist node types ({src_node_types - dst_node_types}) "
                f"whose representations do not get updated during message "
                f"passing as they do not occur as destination type in any "
                f"edge type. This may lead to unexpected behaviour.")

        self.convs = ModuleDict({'__'.join(k): v for k, v in convs.items()})
        self.aggr = aggr
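
ModuleDict keys must be plain strings, which is why the edge-type tuples are joined with '__' above. A minimal dispatch skeleton (a sketch of the pattern, not the library's actual forward) that recovers the tuples when routing messages:

    def forward(self, x_dict, edge_index_dict):
        # Sketch only: run each relation's conv and collect results per destination type.
        out_dict = {}
        for key, conv in self.convs.items():
            src, rel, dst = key.split('__')
            out = conv((x_dict[src], x_dict[dst]), edge_index_dict[(src, rel, dst)])
            out_dict.setdefault(dst, []).append(out)
        # A full implementation would reduce each list according to self.aggr (e.g. sum).
        return out_dict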
Example #25
    def __init__(
        self,
        name: Optional[str] = None,
        tasks: Optional[Union[EmmentalTask, List[EmmentalTask]]] = None,
    ) -> None:
        super().__init__()
        self.name = name if name is not None else type(self).__name__

        # Initiate the model attributes
        self.module_pool: ModuleDict = ModuleDict()
        self.task_names: Set[str] = set()
        self.task_flows: Dict[str, Any] = dict()  # TODO: make it concrete
        self.loss_funcs: Dict[str, Callable] = dict()
        self.output_funcs: Dict[str, Callable] = dict()
        self.scorers: Dict[str, Scorer] = dict()
        self.weights: Dict[str, float] = dict()

        # Build network with given tasks
        if tasks is not None:
            self._build_network(tasks)

        if Meta.config["meta_config"]["verbose"]:
            logger.info(f"Created emmental model {self.name} that contains "
                        f"task {self.task_names}.")

        # Move model to specified device
        self._move_to_device()
Example #26
 def __init__(self, backbone: nn.Module, decoders: nn.ModuleDict,
              tasks: list):
     super(MultiTaskModel, self).__init__()
     assert (set(decoders.keys()) == set(tasks))
     self.backbone = backbone
     self.decoders = decoders
     self.tasks = tasks
Example #27
def savemodel(model: nn.ModuleDict, path) -> None:
    """
    Save AffinityModel.

    Parameters
    ----------
    model: torch.nn.ModuleDict
        Model
    path:
        Save path
    """
    torch.save(
        {
            "args": {
                "n_species": model.n_species,
                "aev_length": model.aev_length,
                # Drop first layer size which is n_species
                "layers_sizes": model.layers_sizes[1:],  # type: ignore
                "dropp": model.dropp,
            },
            "state_dict": model.state_dict(),
        },
        path,
    )

    mlflow.log_artifact(path)
Example #28
    def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
        """The training/validation/test step.

        Override for custom behavior.
        """
        x, y = batch
        y_hat = self(x)
        y, y_hat = self.apply_filtering(y, y_hat)
        output = {"y_hat": y_hat}
        y_hat = self.to_loss_format(output["y_hat"])
        losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}

        y_hat = self.to_metrics_format(output["y_hat"])

        logs = {}

        for name, metric in metrics.items():
            if isinstance(metric, torchmetrics.metric.Metric):
                metric(y_hat, y)
                logs[
                    name] = metric  # log the metric itself if it is of type Metric
            else:
                logs[name] = metric(y_hat, y)

        if len(losses.values()) > 1:
            logs["total_loss"] = sum(losses.values())
            return logs["total_loss"], logs

        output["loss"] = self.compute_loss(losses)
        output["logs"] = self.compute_logs(logs, losses)
        output["y"] = y
        return output
Example #29
 def __init__(self,
              n_atom_basis,
              n_filters,
              n_gaussians,
              cutoff,
              trainable_gauss,
              ):
     super(SchNetConv, self).__init__()
     self.moduledict = ModuleDict({
         'message_edge_filter': Sequential(
             GaussianSmearing(
                 start=0.0,
                 stop=cutoff,
                 n_gaussians=n_gaussians,
                 trainable=trainable_gauss
             ),
             Dense(in_features=n_gaussians, out_features=n_gaussians),
             shifted_softplus(),
             Dense(in_features=n_gaussians, out_features=n_filters)
         ),
         'message_node_filter': Dense(in_features=n_atom_basis, out_features=n_filters),
         'update_function': Sequential(
             Dense(in_features=n_filters, out_features=n_atom_basis),
             shifted_softplus(),
             Dense(in_features=n_atom_basis, out_features=n_atom_basis)
         )
     })
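
A hedged sketch of how these named sub-networks are typically combined in a SchNet-style continuous-filter convolution (variable names are assumed; the original class defines its own message and update logic):

 def message(self, distances, node_feats):
     # Sketch only: filters computed from interatomic distances modulate
     # the projected features of neighbouring atoms.
     w = self.moduledict['message_edge_filter'](distances)
     h = self.moduledict['message_node_filter'](node_feats)
     return w * h

 def update(self, aggregated_message):
     # Map aggregated messages back to the atom-feature dimension.
     return self.moduledict['update_function'](aggregated_message)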
Example #30
    def __init__(
        self,
        name: Optional[str] = None,
        tasks: Optional[Union[EmmentalTask, List[EmmentalTask]]] = None,
    ) -> None:
        """Initialize EmmentalModel."""
        super().__init__()
        self.name = name if name is not None else type(self).__name__

        # Initiate the model attributes
        self.module_pool: ModuleDict = ModuleDict()
        self.task_names: Set[str] = set()
        self.task_flows: Dict[str, Any] = dict()  # TODO: make it concrete
        self.loss_funcs: Dict[str, Callable] = dict()
        self.output_funcs: Dict[str, Callable] = dict()
        self.scorers: Dict[str, Scorer] = dict()
        self.action_outputs: Dict[
            str, Optional[List[Union[Tuple[str, str], Tuple[str, int]]]]
        ] = dict()
        self.module_device: Dict[str, Union[int, str, torch.device]] = {}
        self.task_weights: Dict[str, float] = dict()
        self.require_prob_for_evals: Dict[str, bool] = dict()
        self.require_pred_for_evals: Dict[str, bool] = dict()

        # Build network with given tasks
        if tasks is not None:
            self.add_tasks(tasks)

        if Meta.config["meta_config"]["verbose"]:
            logger.info(
                f"Created emmental model {self.name} that contains "
                f"task {self.task_names}."
            )
Example #31
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 embedding_dropout: float,
                 encoder: Seq2VecEncoder,
                 max_decoding_steps: int,
                 beam_size: int = 10,
                 target_names: List[str] = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None) -> None:
        super().__init__(vocab)
        target_names = target_names or ["xintent", "xreact", "oreact"]

        # Note: The original tweaks the embeddings for "personx" to be the mean
        # across the embeddings for "he", "she", "him" and "her". Similarly for
        # "personx's" and so forth. We could consider that here as a well.
        self._source_embedder = source_embedder
        self._embedding_dropout = nn.Dropout(embedding_dropout)
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
        # Warning: The different decoders share a vocabulary! This may be
        # counterintuitive, but consider the case of xreact and oreact. A
        # reaction of "happy" could easily apply to both the subject of the
        # event and others. This could become less appropriate as more decoders
        # are added.
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with that of the final hidden states of the encoder.
        self._decoder_output_dim = self._encoder.get_output_dim()
        target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim()

        self._states = ModuleDict()
        for name in target_names:
            self._states[name] = StateDecoder(
                    num_classes,
                    target_embedding_dim,
                    self._decoder_output_dim
            )

        self._beam_search = BeamSearch(
                self._end_index,
                beam_size=beam_size,
                max_steps=max_decoding_steps
        )
Example #32
class Event2Mind(Model):
    """
    This ``Event2Mind`` class is a :class:`Model` which takes an event
    sequence, encodes it, and then uses the encoded representation to decode
    several mental state sequences.

    It is based on `the paper by Rashkin et al.
    <https://www.semanticscholar.org/paper/Event2Mind/b89f8a9b2192a8f2018eead6b135ed30a1f2144d>`_

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (``tokens``) or the target tokens can have a different namespace, in which case it needs to
        be specified as ``target_namespace``.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences.
    embedding_dropout: float, required
        The amount of dropout to apply after the source tokens have been embedded.
    encoder : ``Seq2VecEncoder``, required
        The encoder of the "encoder/decoder" model.
    max_decoding_steps : int, required
        Length of decoded sequences.
    beam_size : int, optional (default = 10)
        The width of the beam search.
    target_names: ``List[str]``, optional, (default = ['xintent', 'xreact', 'oreact'])
        Names of the target fields matching those in the ``Instance`` objects.
    target_namespace : str, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : int, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 embedding_dropout: float,
                 encoder: Seq2VecEncoder,
                 max_decoding_steps: int,
                 beam_size: int = 10,
                 target_names: List[str] = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None) -> None:
        super().__init__(vocab)
        target_names = target_names or ["xintent", "xreact", "oreact"]

        # Note: The original tweaks the embeddings for "personx" to be the mean
        # across the embeddings for "he", "she", "him" and "her". Similarly for
        # "personx's" and so forth. We could consider that here as a well.
        self._source_embedder = source_embedder
        self._embedding_dropout = nn.Dropout(embedding_dropout)
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
        # Warning: The different decoders share a vocabulary! This may be
        # counterintuitive, but consider the case of xreact and oreact. A
        # reaction of "happy" could easily apply to both the subject of the
        # event and others. This could become less appropriate as more decoders
        # are added.
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with that of the final hidden states of the encoder.
        self._decoder_output_dim = self._encoder.get_output_dim()
        target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim()

        self._states = ModuleDict()
        for name in target_names:
            self._states[name] = StateDecoder(
                    num_classes,
                    target_embedding_dim,
                    self._decoder_output_dim
            )

        self._beam_search = BeamSearch(
                self._end_index,
                beam_size=beam_size,
                max_steps=max_decoding_steps
        )

    def _update_recall(self,
                       all_top_k_predictions: torch.Tensor,
                       target_tokens: Dict[str, torch.LongTensor],
                       target_recall: UnigramRecall) -> None:
        targets = target_tokens["tokens"]
        target_mask = get_text_field_mask(target_tokens)
        # See comment in _get_loss.
        # TODO(brendanr): Do we need contiguous here?
        relevant_targets = targets[:, 1:].contiguous()
        relevant_mask = target_mask[:, 1:].contiguous()
        target_recall(
                all_top_k_predictions,
                relevant_targets,
                relevant_mask,
                self._end_index
        )

    def _get_num_decoding_steps(self,
                                target_tokens: Optional[Dict[str, torch.LongTensor]]) -> int:
        if target_tokens:
            targets = target_tokens["tokens"]
            target_sequence_length = targets.size()[1]
            # The last input from the target is either padding or the end
            # symbol.  Either way, we don't have to process it. (To be clear,
            # we do still output and compare against the end symbol, but there
            # is no need to take the end symbol as input to the decoder.)
            return target_sequence_length - 1
        else:
            return self._max_decoding_steps

    @overrides
    def forward(self,  # type: ignore
                source: Dict[str, torch.LongTensor],
                **target_tokens: Dict[str, Dict[str, torch.LongTensor]]) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing the target sequences.

        Parameters
        ----------
        source : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the source
            ``TextField``. This will be passed through a ``TextFieldEmbedder``
            and then through an encoder.
        target_tokens : ``Dict[str, Dict[str, torch.LongTensor]]``:
            Dictionary from name to output of ``Textfield.as_array()`` applied
            on target ``TextField``. We assume that the target tokens are also
            represented as a ``TextField``.
        """
        # (batch_size, input_sequence_length, embedding_dim)
        embedded_input = self._embedding_dropout(self._source_embedder(source))
        source_mask = get_text_field_mask(source)
        # (batch_size, encoder_output_dim)
        final_encoder_output = self._encoder(embedded_input, source_mask)
        output_dict = {}

        # Perform greedy search so we can get the loss.
        if target_tokens:
            if target_tokens.keys() != self._states.keys():
                target_only = target_tokens.keys() - self._states.keys()
                states_only = self._states.keys() - target_tokens.keys()
                raise Exception("Mismatch between target_tokens and self._states. Keys in " +
                                f"targets only: {target_only} Keys in states only: {states_only}")
            total_loss = 0
            for name, state in self._states.items():
                loss = self.greedy_search(
                        final_encoder_output=final_encoder_output,
                        target_tokens=target_tokens[name],
                        target_embedder=state.embedder,
                        decoder_cell=state.decoder_cell,
                        output_projection_layer=state.output_projection_layer
                )
                total_loss += loss
                output_dict[f"{name}_loss"] = loss

            # Use mean loss (instead of the sum of the losses) to be comparable to the paper.
            output_dict["loss"] = total_loss / len(self._states)

        # Perform beam search to obtain the predictions.
        if not self.training:
            batch_size = final_encoder_output.size()[0]
            for name, state in self._states.items():
                start_predictions = final_encoder_output.new_full(
                        (batch_size,), fill_value=self._start_index, dtype=torch.long)
                start_state = {"decoder_hidden": final_encoder_output}

                # (batch_size, 10, num_decoding_steps)
                all_top_k_predictions, log_probabilities = self._beam_search.search(
                        start_predictions, start_state, state.take_step)

                if target_tokens:
                    self._update_recall(all_top_k_predictions, target_tokens[name], state.recall)
                output_dict[f"{name}_top_k_predictions"] = all_top_k_predictions
                output_dict[f"{name}_top_k_log_probabilities"] = log_probabilities

        return output_dict

    def greedy_search(self,
                      final_encoder_output: torch.LongTensor,
                      target_tokens: Dict[str, torch.LongTensor],
                      target_embedder: Embedding,
                      decoder_cell: GRUCell,
                      output_projection_layer: Linear) -> torch.FloatTensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the cross entropy between this sequence and ``target_tokens``.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()`` applied on some target ``TextField``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._get_num_decoding_steps(target_tokens)
        targets = target_tokens["tokens"]
        decoder_hidden = final_encoder_output
        step_logits = []
        for timestep in range(num_decoding_steps):
            # See https://github.com/allenai/allennlp/issues/1134.
            input_choices = targets[:, timestep]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            # list of (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))
        # (batch_size, num_decoding_steps, num_classes)
        logits = torch.cat(step_logits, 1)
        target_mask = get_text_field_mask(target_tokens)
        return self._get_loss(logits, targets, target_mask)

    def greedy_predict(self,
                       final_encoder_output: torch.LongTensor,
                       target_embedder: Embedding,
                       decoder_cell: GRUCell,
                       output_projection_layer: Linear) -> torch.Tensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the predicted sequence.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._max_decoding_steps
        decoder_hidden = final_encoder_output
        batch_size = final_encoder_output.size()[0]
        predictions = [final_encoder_output.new_full(
                (batch_size,), fill_value=self._start_index, dtype=torch.long
        )]
        for _ in range(num_decoding_steps):
            input_choices = predictions[-1]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            predictions.append(predicted_classes)
        all_predictions = torch.cat([ps.unsqueeze(1) for ps in predictions], 1)
        # Drop start symbol and return.
        return all_predictions[:, 1:]

    @staticmethod
    def _get_loss(logits: torch.LongTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.FloatTensor:
        """
        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        relevant_targets = targets[:, 1:].contiguous()  # (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()  # (batch_size, num_decoding_steps)
        loss = sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)
        return loss

    def decode_all(self, predicted_indices: torch.Tensor) -> List[List[str]]:
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [self.vocab.get_token_from_index(x, namespace=self._target_namespace)
                                for x in indices]
            all_predicted_tokens.append(predicted_tokens)
        return all_predicted_tokens

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, List[List[str]]]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds fields for the tokens to the ``output_dict``.
        """
        for name in self._states:
            top_k_predicted_indices = output_dict[f"{name}_top_k_predictions"][0]
            output_dict[f"{name}_top_k_predicted_tokens"] = [self.decode_all(top_k_predicted_indices)]

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics = {}
        # Recall@10 needs beam search which doesn't happen during training.
        if not self.training:
            for name, state in self._states.items():
                all_metrics[name] = state.recall.get_metric(reset=reset)
        return all_metrics