Ejemplo n.º 1
0
    def __init__(self, workflow, **kwargs):
        """Creates the inner units of the LSTM block and wires them together.

        Args:
            workflow: parent workflow, passed to the base constructor.
            **kwargs: passed unchanged to the base constructor. The gate
                units and the output activation receive the same keyword
                arguments with "simple" and "name" removed.
                simple (default True): when true the output gate reads
                only the joined input; when false it reads a second
                joiner fed by the joined input and the memory cell.

        After construction "input", "prev_output" and "prev_memory" are
        demanded, and "output" / "memory" are linked from the inner units.
        """
        super(LSTM, self).__init__(workflow, **kwargs)
        self.simple = kwargs.pop("simple", True)
        # Each inner unit below is created with its own explicit name. A
        # "name" given for this block has already been consumed by the base
        # constructor; forwarding it again would raise TypeError (the same
        # keyword passed twice).
        kwargs.pop("name", None)

        # Create units
        self.ij = InputJoiner(self)
        self.input_gate = All2AllSigmoid(self, name="input_gate", **kwargs)
        self.forget_gate = All2AllSigmoid(self, name="forget_gate", **kwargs)
        self.memory_maker = All2AllTanh(self, name="memory_maker", **kwargs)

        if not self.simple:
            # Extra joiner only exists when the output gate also sees the
            # memory cell.
            self.ij_output = InputJoiner(self)
        self.output_gate = All2AllSigmoid(self, name="output_gate", **kwargs)
        self.output_activation = ForwardTanh(self, name="output_activation",
                                             **kwargs)

        self.input_mul = Multiplier(self, name="input_mul")
        self.forget_mul = Multiplier(self, name="forget_mul")
        self.summator = Summator(self, name="memory_cell")
        self.output_mul = Multiplier(self, name="output_mul")

        # Link control flow
        self.ij.link_from(self.start_point)
        self.input_gate.link_from(self.ij)
        self.forget_gate.link_from(self.ij)
        self.memory_maker.link_from(self.ij)
        self.input_mul.link_from(self.input_gate, self.memory_maker)
        self.forget_mul.link_from(self.forget_gate)
        self.summator.link_from(self.input_mul, self.forget_mul)

        if not self.simple:
            # The output gate must wait for the memory cell in this mode.
            self.ij_output.link_from(self.summator, self.ij)
            self.output_gate.link_from(self.ij_output)
        else:
            self.output_gate.link_from(self.ij)

        self.output_activation.link_from(self.summator)
        self.output_mul.link_from(self.output_activation, self.output_gate)
        self.end_point.link_from(self.output_mul)

        # Link unit attributes
        # The three gates fed by ij all read the same joined
        # (input, prev_output) data.
        self.ij.link_inputs(self, "input", "prev_output")
        self.input_gate.link_attrs(self.ij, ("input", "output"))
        self.forget_gate.link_attrs(self.ij, ("input", "output"))
        self.memory_maker.link_attrs(self.ij, ("input", "output"))
        # input_mul combines the input gate with the candidate memory.
        self.input_mul.link_attrs(self.input_gate, ("x", "output"))
        self.input_mul.link_attrs(self.memory_maker, ("y", "output"))
        # forget_mul combines the forget gate with the previous memory.
        self.forget_mul.link_attrs(self.forget_gate, ("x", "output"))
        self.forget_mul.link_attrs(self, ("y", "prev_memory"))
        # The summator (named "memory_cell") holds the new memory value.
        self.summator.link_attrs(self.input_mul, ("x", "output"))
        self.summator.link_attrs(self.forget_mul, ("y", "output"))
        self.output_activation.link_attrs(self.summator, ("input", "output"))

        if not self.simple:
            self.ij_output.link_inputs(self.ij, "output")
            self.ij_output.link_inputs(self.summator, "output")
            self.output_gate.link_attrs(self.ij_output, ("input", "output"))
        else:
            self.output_gate.link_attrs(self.ij, ("input", "output"))

        self.output_mul.link_attrs(self.output_gate, ("x", "output"))
        self.output_mul.link_attrs(self.output_activation, ("y", "output"))
        # Expose the block results: output from output_mul, memory from the
        # summator.
        self.link_attrs(self.output_mul, "output")
        self.link_attrs(self.summator, ("memory", "output"))

        self.demand("input", "prev_output", "prev_memory")
Ejemplo n.º 2
0
class LSTM(FullyConnectedOutput, AcceleratedWorkflow):
    """LSTM block.

    Must be assigned before initialize():
        input: current input vector
        prev_output: output from the previous LSTM unit (hidden state)
        prev_memory: value of memory cell from the previous LSTM unit

    Updates after run():
        output: current output (hidden state)
        memory: current value of memory cell

    Attributes:
        simple: do not connect memory cell to an output gate.
    """

    MAPPING = {"LSTM"}

    def __init__(self, workflow, **kwargs):
        """Creates the inner units of the LSTM block and wires them together.

        Args:
            workflow: parent workflow, passed to the base constructor.
            **kwargs: passed unchanged to the base constructor. The gate
                units and the output activation receive the same keyword
                arguments with "simple" and "name" removed.
                simple (default True): when true the output gate reads
                only the joined input; when false it reads a second
                joiner fed by the joined input and the memory cell.
        """
        super(LSTM, self).__init__(workflow, **kwargs)
        self.simple = kwargs.pop("simple", True)
        # Each inner unit below is created with its own explicit name. A
        # "name" given for this block has already been consumed by the base
        # constructor; forwarding it again would raise TypeError (the same
        # keyword passed twice).
        kwargs.pop("name", None)

        # Create units
        self.ij = InputJoiner(self)
        self.input_gate = All2AllSigmoid(self, name="input_gate", **kwargs)
        self.forget_gate = All2AllSigmoid(self, name="forget_gate", **kwargs)
        self.memory_maker = All2AllTanh(self, name="memory_maker", **kwargs)

        if not self.simple:
            # Extra joiner only exists when the output gate also sees the
            # memory cell.
            self.ij_output = InputJoiner(self)
        self.output_gate = All2AllSigmoid(self, name="output_gate", **kwargs)
        self.output_activation = ForwardTanh(self, name="output_activation",
                                             **kwargs)

        self.input_mul = Multiplier(self, name="input_mul")
        self.forget_mul = Multiplier(self, name="forget_mul")
        self.summator = Summator(self, name="memory_cell")
        self.output_mul = Multiplier(self, name="output_mul")

        # Link control flow
        self.ij.link_from(self.start_point)
        self.input_gate.link_from(self.ij)
        self.forget_gate.link_from(self.ij)
        self.memory_maker.link_from(self.ij)
        self.input_mul.link_from(self.input_gate, self.memory_maker)
        self.forget_mul.link_from(self.forget_gate)
        self.summator.link_from(self.input_mul, self.forget_mul)

        if not self.simple:
            # The output gate must wait for the memory cell in this mode.
            self.ij_output.link_from(self.summator, self.ij)
            self.output_gate.link_from(self.ij_output)
        else:
            self.output_gate.link_from(self.ij)

        self.output_activation.link_from(self.summator)
        self.output_mul.link_from(self.output_activation, self.output_gate)
        self.end_point.link_from(self.output_mul)

        # Link unit attributes
        # The three gates fed by ij all read the same joined
        # (input, prev_output) data.
        self.ij.link_inputs(self, "input", "prev_output")
        self.input_gate.link_attrs(self.ij, ("input", "output"))
        self.forget_gate.link_attrs(self.ij, ("input", "output"))
        self.memory_maker.link_attrs(self.ij, ("input", "output"))
        # input_mul combines the input gate with the candidate memory.
        self.input_mul.link_attrs(self.input_gate, ("x", "output"))
        self.input_mul.link_attrs(self.memory_maker, ("y", "output"))
        # forget_mul combines the forget gate with the previous memory.
        self.forget_mul.link_attrs(self.forget_gate, ("x", "output"))
        self.forget_mul.link_attrs(self, ("y", "prev_memory"))
        # The summator (named "memory_cell") holds the new memory value.
        self.summator.link_attrs(self.input_mul, ("x", "output"))
        self.summator.link_attrs(self.forget_mul, ("y", "output"))
        self.output_activation.link_attrs(self.summator, ("input", "output"))

        if not self.simple:
            self.ij_output.link_inputs(self.ij, "output")
            self.ij_output.link_inputs(self.summator, "output")
            self.output_gate.link_attrs(self.ij_output, ("input", "output"))
        else:
            self.output_gate.link_attrs(self.ij, ("input", "output"))

        self.output_mul.link_attrs(self.output_gate, ("x", "output"))
        self.output_mul.link_attrs(self.output_activation, ("y", "output"))
        # Expose the block results: output from output_mul, memory from the
        # summator.
        self.link_attrs(self.output_mul, "output")
        self.link_attrs(self.summator, ("memory", "output"))

        self.demand("input", "prev_output", "prev_memory")

    def link_weights(self, src):
        """Links the weights and bias of every gate unit to those of src.

        For each of input_gate, forget_gate, memory_maker and output_gate,
        the "weights" and "bias" attributes of this block's unit are linked
        to the unit with the same attribute name on src.

        Args:
            src: object exposing the four gate units under the same
                attribute names (presumably another LSTM block).
        """
        for attr in ("input_gate", "forget_gate", "memory_maker",
                     "output_gate"):
            getattr(self, attr).link_attrs(
                getattr(src, attr), "weights", "bias")