Example #1
    def _get_step_inputs(self, dec_inputs: ModelIO) -> ModelIO:

        if hasattr(dec_inputs, 'h'):
            h = dec_inputs.h
        elif hasattr(dec_inputs, 'enc_hidden'):
            h = dec_inputs.enc_hidden
        else:
            # Fall back to the last encoder output as the hidden state.
            h = dec_inputs.enc_outputs[-1]

        if hasattr(dec_inputs, 'x'):
            # Not the first step. Use outputs from previous step instead
            x = dec_inputs.x
        elif hasattr(dec_inputs, 'transform'):
            # Use the transformation token from the input
            x = dec_inputs.transform[1:-1]  # strip <sos> and <eos> tokens
        else:
            log.error(
                f"I don't have any input to use for the step from {dec_inputs}."
            )
            raise SystemError

        dec_step_input = ModelIO({"x": x, "h": h})

        if hasattr(dec_inputs, 'enc_outputs'):
            dec_step_input.set_attribute('enc_outputs', dec_inputs.enc_outputs)

        return dec_step_input
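All of these snippets pass data around in a `ModelIO` wrapper that is never shown. Judging from the calls above (the dict constructor, `hasattr` checks, `set_attribute`, and `set_attributes`), it is a plain attribute bag; a minimal sketch, assuming exactly that and nothing more:

class ModelIO:
    """Attribute bag passed between encoders, decoders, and step functions."""

    def __init__(self, attributes=None):
        # Expose each dict key as a plain attribute so callers can write
        # `io.x` / `io.h` and test for presence with hasattr().
        for key, value in (attributes or {}).items():
            setattr(self, key, value)

    def set_attribute(self, key, value):
        setattr(self, key, value)

    def set_attributes(self, attributes):
        for key, value in attributes.items():
            setattr(self, key, value)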
Example #2
    def forward_step(self,
                     step_input: ModelIO,
                     src_mask: Optional[Tensor] = None) -> ModelIO:

        unit_input = F.relu(self._embedding(step_input.x))
        h, c = step_input.h, step_input.c

        # The recurrent unit expects (seq_len, batch, features) inputs; add a
        # length-1 sequence dimension when given a single step.
        if len(unit_input.shape) == 2:
            unit_input = unit_input.unsqueeze(0)

        # Likewise, ensure h has a (layers, batch, features) shape before
        # packing the LSTM's (h, c) state tuple.
        if len(h.shape) == 2:
            h = h.unsqueeze(0)

        hidden = (h, c)

        if src_mask is not None:
            unit_input, attn = self.compute_attention(unit_input,
                                                      step_input.enc_outputs,
                                                      h, src_mask)

        _, state = self._unit(unit_input, hidden)
        # Project the top layer's final hidden state to vocabulary logits.
        y = self._out(state[0][-1])

        step_result = ModelIO({"y": y, "h": state[0], "c": state[1]})
        if src_mask is not None:
            step_result.set_attribute("attn", attn)

        return step_result
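A sketch of driving a single LSTM decoding step, assuming a decoder whose `_embedding`, `_unit` (an LSTM), and `_out` are already built; the attribute names come from the snippet, while `decoder`, `prev_tokens`, `h`, and `c` are illustrative:

# Hypothetical single-step invocation.
step_input = ModelIO({
    "x": prev_tokens,  # (batch,) token ids from the previous step
    "h": h,            # (layers, batch, hidden) hidden state
    "c": c,            # (layers, batch, hidden) LSTM cell state
})
step_result = decoder.forward_step(step_input)  # src_mask=None: no attention
next_tokens = step_result.y.argmax(dim=1)       # greedy choice per batch item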
Example #3
    def forward_step(self,
                     step_input: ModelIO,
                     src_mask: Optional[Tensor] = None) -> ModelIO:

        h = step_input.h
        if len(h.shape) == 2:
            h = h.unsqueeze(0)

        unit_input = F.relu(self._embedding(step_input.x))
        if len(unit_input.shape) == 2:
            unit_input = unit_input.unsqueeze(0)

        if src_mask is not None:
            unit_input, attn = self.compute_attention(unit_input,
                                                      step_input.enc_outputs,
                                                      h, src_mask)

        # print("unit_input", unit_input.shape)
        # print("attn", attn.shape)

        _, state = self._unit(unit_input, h)
        y = self._out(state[-1])

        step_result = ModelIO({"y": y, "h": state})
        if src_mask is not None:
            step_result.set_attribute("attn", attn)

        return step_result
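This is the GRU counterpart of Example #2: the recurrent state is a single tensor, so the output projection reads `state[-1]` (top layer) rather than `state[0][-1]`, and there is no cell state `c` to thread through.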
Example #4
    def forward_expression(self, expressions):
        """
    ...
    """

        representations = []
        sources = []

        # Compute forward pass on each sub-expression
        for term in expressions:
            if isinstance(term, str):
                representations.append(term)
            else:
                enc_input = ModelIO({"source": term.source})
                sources.append(term.source)
                enc_output = self._encoder(enc_input)
                representations.append(enc_output)

        # Perform the arithmetic operation on the representations
        buffer = {"enc_hidden": None, "enc_outputs": None}
        operation = None
        for r in representations:
            if isinstance(r, str):
                # An operator token: remember which op to apply next.
                operation = torch.add if r == "+" else torch.subtract
            elif operation is None:
                # First operand: seed the buffer with its attributes.
                for key in buffer:
                    buffer[key] = getattr(r, key, None)
            else:
                # Later operand: fold it into the running buffer.
                for key in buffer:
                    try:
                        buffer[key] = operation(buffer[key], getattr(r, key))
                    except (AttributeError, TypeError):
                        buffer[key] = None
                operation = None

        dec_inputs = ModelIO({
            "source": sources[0],
            "transform": expressions[0].annotation
        })

        for key in buffer:
            if buffer[key] is not None:
                dec_inputs.set_attribute(key, buffer[key])

        dec_output = self._decoder(dec_inputs, tf_ratio=0.0)
        return dec_output.dec_outputs
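A hedged sketch of the expected `expressions` format, inferred from the loop above: an alternating sequence of example objects (anything with a `.source`, and `.annotation` on the first element) and operator strings. The `term_*` objects are hypothetical stand-ins:

# Hypothetical invocation: embedding arithmetic over three examples,
# i.e. decode(enc(a) - enc(b) + enc(c)).
result = model.forward_expression([term_a, "-", term_b, "+", term_c])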
Example #5
    def forward(self,
                batch: Batch,
                tf_ratio: float = 0.0,
                plot_trajectories=False):
        """
    Runs the forward pass.

    batch (torchtext Batch): batch of [source, annotation, target]
    tf_ratio (float in range [0, 1]): if present, probability of using teacher
      forcing.
    """

        enc_input = ModelIO({"source": batch.source})
        enc_output = self._encoder(enc_input)

        enc_output.set_attributes({
            "source": batch.source,
            "transform": batch.annotation
        })

        if hasattr(batch, 'target'):
            enc_output.set_attribute("target", batch.target)

        dec_output = self._decoder(enc_output, tf_ratio=tf_ratio)

        return dec_output.dec_outputs
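A minimal sketch of a training-style call, assuming a torchtext-style `batch` with `source`, `annotation`, and `target` fields; the criterion and `PAD_IDX` are illustrative assumptions, not from the source:

import torch.nn as nn

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)  # PAD_IDX: assumed pad id
logits = model(batch, tf_ratio=0.5)  # (gen_len, batch, vocab); see Example #9
loss = criterion(logits.reshape(-1, logits.shape[-1]), batch.target.reshape(-1))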
Example #6
    def _get_step_inputs(self, dec_inputs: ModelIO) -> ModelIO:

        # An LSTM encoder's enc_hidden is an (h, c) tuple; split it so the
        # base implementation only ever sees the hidden state h.
        # TODO: This is hacky....make the logic nicer here.
        if hasattr(dec_inputs, 'enc_hidden') and isinstance(
                dec_inputs.enc_hidden, tuple):
            dec_inputs.c = dec_inputs.enc_hidden[1]
            dec_inputs.enc_hidden = dec_inputs.enc_hidden[0]

        # Get default implementation
        step_input = super()._get_step_inputs(dec_inputs)
        batch_size = step_input.h.shape[0]

        if hasattr(dec_inputs, 'c'):
            # We're @ timestep t>0, and the decoder has already produced a 'c'
            step_input.set_attribute('c', dec_inputs.c)
        else:
            # We're @ timestep t=0, so create an initial 'c' of all zeros.
            c = torch.zeros(self.num_layers, batch_size,
                            self.hidden_size).to(avd)
            step_input.set_attribute('c', c)

        return step_input
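Throughout these snippets, tensors are moved with `.to(avd)`; `avd` is never defined here, but it reads like a module-level device handle. A plausible definition, assuming standard PyTorch device selection:

import torch

# Assumed definition of the `avd` global used by the snippets above.
avd = torch.device("cuda" if torch.cuda.is_available() else "cpu")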
Example #7
    def forward(self, dec_input: ModelIO, tf_ratio: float) -> ModelIO:
        """
    Try and keep the same signature as SequenceDecoder.
    """

        seq_len, batch_size, _ = dec_input.enc_outputs.shape

        teacher_forcing = random.random() < tf_ratio
        if teacher_forcing and not hasattr(dec_input, 'target'):
            log.error("You must provide a 'target' to use teacher forcing.")
            raise SystemError

        if hasattr(dec_input, 'target'):
            gen_len = dec_input.target.shape[0]
        else:
            gen_len = self.max_length

        # tgt = inputs to the decoder, starting with TRANS token and appending
        #       dec_input.target[i] or predicted token
        # mem = encoder outputs, used for multi-headed attention
        tgt = dec_input.transform[1:-1]  # strip <sos> and <eos> tokens
        mem = dec_input.enc_outputs

        has_finished = torch.zeros(batch_size, dtype=torch.bool).to(avd)

        for i in range(gen_len):

            # Re-embed every time since we need positional encoding to take into
            # account the new tokens in context of the old ones.
            tgt_emb = self._embedding(tgt)

            # Ensures that once a model outputs token <t> @ position i, it will
            # always output <t> @ i even for further timesteps
            tgt_mask = self._generate_square_subsequent_mask(
                tgt.shape[0]).to(avd)

            # Calculate the next predicted token from output
            out = self._out(
                self._unit(tgt=tgt_emb, memory=mem, tgt_mask=tgt_mask))
            predicted = out[-1].argmax(dim=1)

            has_finished[predicted == self.EOS_IDX] = True
            if all(has_finished):
                break
            else:
                new_tgt = dec_input.target[i] if teacher_forcing else predicted
                tgt = torch.cat((tgt, new_tgt.unsqueeze(0)), dim=0).to(avd)

        return ModelIO({"dec_outputs": out})
Example #8
    def forward(self, enc_input: ModelIO) -> ModelIO:
        """
      Compute the forward pass.
    """
        enc = self.module(enc_input.source)

        output = ModelIO()
        if isinstance(enc, tuple):
            # Recurrent encoders return (outputs, final hidden state).
            enc_outputs, enc_hidden = enc
            output.set_attributes({
                "enc_outputs": enc_outputs,
                "enc_hidden": enc_hidden
            })
        else:
            # Non-recurrent encoders return the outputs tensor alone.
            output.set_attributes({"enc_outputs": enc})

        return output
Example #9
    def forward(self, dec_input: ModelIO, tf_ratio: float) -> ModelIO:
        """
    Computes the forward pass of the decoder.

    Paramters:
      - dec_input: wrapper object for the various inputs to the decoder. This
          allows for variadic parameters to account for various units' different
          input requirements (i.e., LSTMs require a `cell`)
      - tf_ratio (float in range [0.0, 1.0]): chance that teacher_forcing is
          used for a given batch. If tf_ratio is not `None`, a `target` must
          be present in `dec_input`.
    """

        seq_len, batch_size, _ = dec_input.enc_outputs.shape

        teacher_forcing = random.random() < tf_ratio
        if teacher_forcing and not hasattr(dec_input, 'target'):
            log.error("You must provide a 'target' to use teacher forcing.")
            raise SystemError

        # This check is still needed; TODO: should the outputs be padded
        # somehow when teacher forcing is not used?
        if hasattr(dec_input, 'target'):
            gen_len = dec_input.target.shape[0]
        else:
            gen_len = self.max_length

        # Get input to decoder unit
        dec_step_input = self._get_step_inputs(dec_input)

        # Skeletons for the decoder outputs
        has_finished = torch.zeros(batch_size, dtype=torch.bool).to(avd)
        dec_outputs = torch.zeros(gen_len, batch_size, self.vocab_size).to(avd)
        dec_outputs[:, :, self.PAD_IDX] = 1.0
        dec_hiddens = torch.zeros(gen_len, batch_size,
                                  self.hidden_size).to(avd)

        # Attention
        if self.attention_type is not None:
            attention = torch.zeros(gen_len, batch_size, seq_len).to(avd)
            src_mask = create_mask(dec_input.source, self._src_vocab)
        else:
            src_mask = None

        for i in range(gen_len):

            # Get forward_step pass
            step_result = self.forward_step(dec_step_input, src_mask)
            step_prediction = step_result.y.argmax(dim=1)

            # Update results
            dec_outputs[i] = step_result.y
            dec_hiddens[i] = step_result.h[-1]
            if self.attention_type is not None:
                attention[i] = step_result.attn

            # Check if we're done
            has_finished[step_prediction == self.EOS_IDX] = True
            if all(has_finished):
                break
            else:
                # Otherwise, iterate x, h and repeat
                x = dec_input.target[i] if teacher_forcing else step_prediction

                # A little hacky, but we want to use every value from the step
                # result for the next step EXCEPT x, which should come from
                # either the target or the step_prediction.
                step_result.set_attribute('x', x)
                step_result.set_attribute('enc_outputs', dec_input.enc_outputs)
                dec_step_input = self._get_step_inputs(step_result)

        output = ModelIO({
            "dec_outputs": dec_outputs,
            "dec_hiddens": dec_hiddens
        })

        if self.attention_type is not None:
            output.set_attribute("attention", attention)

        return output
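`create_mask` is also unseen. Given that it takes the source batch and the source vocabulary, a plausible sketch is a padding mask; the `vocab.stoi` / `'<pad>'` conventions and the (src_len, batch) source layout are assumptions from torchtext usage elsewhere in the snippets:

def create_mask(source: Tensor, vocab) -> Tensor:
    # True where the source token is padding, so attention can ignore it.
    pad_idx = vocab.stoi['<pad>']
    return (source == pad_idx).transpose(0, 1)  # -> (batch, src_len)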
Example #10
    def forward(self, enc_input: ModelIO) -> ModelIO:

        # `self.module` appears to be a pretrained transformer encoder (its
        # output exposes `.last_hidden_state`); `self.unit` then re-encodes
        # those hidden states.
        embedded = self.module(enc_input.source)
        encoded = self.unit(embedded.last_hidden_state)
        return ModelIO({"enc_outputs": encoded})