def __init__(self, workflow, **kwargs):
    """Create the inner units of the LSTM block and wire them together.

    Args:
        workflow: parent workflow, passed to the base-class constructor.
        **kwargs: passed unchanged to the base-class constructor. The
            same keywords, minus "simple" and "name", are forwarded to
            the gate units (input_gate, forget_gate, memory_maker,
            output_gate) and to output_activation.
            simple (bool, default True): when true, output_gate is fed
            from the joined input only; when false, an extra
            InputJoiner also feeds it the memory cell (summator) output.
    """
    super(LSTM, self).__init__(workflow, **kwargs)
    self.simple = kwargs.pop("simple", True)
    # Every inner unit below is given its own fixed name.  A caller-supplied
    # "name" (presumably meant for the block itself and already seen by the
    # parent constructor above) must not be forwarded, otherwise the calls
    # below fail with "got multiple values for keyword argument 'name'".
    kwargs.pop("name", None)

    # Create units
    self.ij = InputJoiner(self)
    self.input_gate = All2AllSigmoid(self, name="input_gate", **kwargs)
    self.forget_gate = All2AllSigmoid(self, name="forget_gate", **kwargs)
    self.memory_maker = All2AllTanh(self, name="memory_maker", **kwargs)
    if not self.simple:
        # Second joiner exists only when the output gate also sees the
        # memory cell.
        self.ij_output = InputJoiner(self)
    self.output_gate = All2AllSigmoid(self, name="output_gate", **kwargs)
    self.output_activation = ForwardTanh(
        self, name="output_activation", **kwargs)
    self.input_mul = Multiplier(self, name="input_mul")
    self.forget_mul = Multiplier(self, name="forget_mul")
    self.summator = Summator(self, name="memory_cell")
    self.output_mul = Multiplier(self, name="output_mul")

    # Link control flow
    self.ij.link_from(self.start_point)
    self.input_gate.link_from(self.ij)
    self.forget_gate.link_from(self.ij)
    self.memory_maker.link_from(self.ij)
    self.input_mul.link_from(self.input_gate, self.memory_maker)
    self.forget_mul.link_from(self.forget_gate)
    self.summator.link_from(self.input_mul, self.forget_mul)
    if not self.simple:
        self.ij_output.link_from(self.summator, self.ij)
        self.output_gate.link_from(self.ij_output)
    else:
        self.output_gate.link_from(self.ij)
    self.output_activation.link_from(self.summator)
    self.output_mul.link_from(self.output_activation, self.output_gate)
    self.end_point.link_from(self.output_mul)

    # Link unit attributes
    # The joiner combines the current input with the previous output.
    self.ij.link_inputs(self, "input", "prev_output")
    self.input_gate.link_attrs(self.ij, ("input", "output"))
    self.forget_gate.link_attrs(self.ij, ("input", "output"))
    self.memory_maker.link_attrs(self.ij, ("input", "output"))
    # memory cell operands: (input_gate, memory_maker) and
    # (forget_gate, prev_memory), combined by the summator.
    self.input_mul.link_attrs(self.input_gate, ("x", "output"))
    self.input_mul.link_attrs(self.memory_maker, ("y", "output"))
    self.forget_mul.link_attrs(self.forget_gate, ("x", "output"))
    self.forget_mul.link_attrs(self, ("y", "prev_memory"))
    self.summator.link_attrs(self.input_mul, ("x", "output"))
    self.summator.link_attrs(self.forget_mul, ("y", "output"))
    self.output_activation.link_attrs(self.summator, ("input", "output"))
    if not self.simple:
        self.ij_output.link_inputs(self.ij, "output")
        self.ij_output.link_inputs(self.summator, "output")
        self.output_gate.link_attrs(self.ij_output, ("input", "output"))
    else:
        self.output_gate.link_attrs(self.ij, ("input", "output"))
    self.output_mul.link_attrs(self.output_gate, ("x", "output"))
    self.output_mul.link_attrs(self.output_activation, ("y", "output"))
    # Expose the results on the block itself.
    self.link_attrs(self.output_mul, "output")
    self.link_attrs(self.summator, ("memory", "output"))

    self.demand("input", "prev_output", "prev_memory")
class LSTM(FullyConnectedOutput, AcceleratedWorkflow):
    """LSTM block.

    Must be assigned before initialize():
        input: current input vector
        prev_output: output from the previous LSTM unit (hidden state)
        prev_memory: value of memory cell from the previous LSTM unit

    Updates after run():
        output: current output (hidden state)
        memory: current value of memory cell

    Attributes:
        simple: do not connect memory cell to an output gate
                (defaults to True).
    """
    MAPPING = {"LSTM"}

    def __init__(self, workflow, **kwargs):
        """Create the inner units of the block and wire them together.

        Args:
            workflow: parent workflow, passed to the base-class constructor.
            **kwargs: passed unchanged to the base-class constructor. The
                same keywords, minus "simple" and "name", are forwarded to
                the gate units (input_gate, forget_gate, memory_maker,
                output_gate) and to output_activation.
                simple (bool, default True): when true, output_gate is fed
                from the joined input only; when false, an extra
                InputJoiner also feeds it the memory cell (summator)
                output.
        """
        super(LSTM, self).__init__(workflow, **kwargs)
        self.simple = kwargs.pop("simple", True)
        # Every inner unit below is given its own fixed name.  A
        # caller-supplied "name" (presumably meant for the block itself and
        # already seen by the parent constructor above) must not be
        # forwarded, otherwise the calls below fail with
        # "got multiple values for keyword argument 'name'".
        kwargs.pop("name", None)

        # Create units
        self.ij = InputJoiner(self)
        self.input_gate = All2AllSigmoid(self, name="input_gate", **kwargs)
        self.forget_gate = All2AllSigmoid(self, name="forget_gate", **kwargs)
        self.memory_maker = All2AllTanh(self, name="memory_maker", **kwargs)
        if not self.simple:
            # Second joiner exists only when the output gate also sees the
            # memory cell.
            self.ij_output = InputJoiner(self)
        self.output_gate = All2AllSigmoid(self, name="output_gate", **kwargs)
        self.output_activation = ForwardTanh(
            self, name="output_activation", **kwargs)
        self.input_mul = Multiplier(self, name="input_mul")
        self.forget_mul = Multiplier(self, name="forget_mul")
        self.summator = Summator(self, name="memory_cell")
        self.output_mul = Multiplier(self, name="output_mul")

        # Link control flow
        self.ij.link_from(self.start_point)
        self.input_gate.link_from(self.ij)
        self.forget_gate.link_from(self.ij)
        self.memory_maker.link_from(self.ij)
        self.input_mul.link_from(self.input_gate, self.memory_maker)
        self.forget_mul.link_from(self.forget_gate)
        self.summator.link_from(self.input_mul, self.forget_mul)
        if not self.simple:
            self.ij_output.link_from(self.summator, self.ij)
            self.output_gate.link_from(self.ij_output)
        else:
            self.output_gate.link_from(self.ij)
        self.output_activation.link_from(self.summator)
        self.output_mul.link_from(self.output_activation, self.output_gate)
        self.end_point.link_from(self.output_mul)

        # Link unit attributes
        # The joiner combines the current input with the previous output.
        self.ij.link_inputs(self, "input", "prev_output")
        self.input_gate.link_attrs(self.ij, ("input", "output"))
        self.forget_gate.link_attrs(self.ij, ("input", "output"))
        self.memory_maker.link_attrs(self.ij, ("input", "output"))
        # memory cell operands: (input_gate, memory_maker) and
        # (forget_gate, prev_memory), combined by the summator.
        self.input_mul.link_attrs(self.input_gate, ("x", "output"))
        self.input_mul.link_attrs(self.memory_maker, ("y", "output"))
        self.forget_mul.link_attrs(self.forget_gate, ("x", "output"))
        self.forget_mul.link_attrs(self, ("y", "prev_memory"))
        self.summator.link_attrs(self.input_mul, ("x", "output"))
        self.summator.link_attrs(self.forget_mul, ("y", "output"))
        self.output_activation.link_attrs(self.summator, ("input", "output"))
        if not self.simple:
            self.ij_output.link_inputs(self.ij, "output")
            self.ij_output.link_inputs(self.summator, "output")
            self.output_gate.link_attrs(self.ij_output, ("input", "output"))
        else:
            self.output_gate.link_attrs(self.ij, ("input", "output"))
        self.output_mul.link_attrs(self.output_gate, ("x", "output"))
        self.output_mul.link_attrs(self.output_activation, ("y", "output"))
        # Expose the results on the block itself.
        self.link_attrs(self.output_mul, "output")
        self.link_attrs(self.summator, ("memory", "output"))

        self.demand("input", "prev_output", "prev_memory")

    def link_weights(self, src):
        """Links the weights of this block to the weights of src.

        For each of input_gate, forget_gate, memory_maker and output_gate,
        the "weights" and "bias" attributes of this block's unit are linked
        to the same-named unit of src.

        Args:
            src: object exposing units under the same four attribute names
                 (presumably another LSTM block).
        """
        for attr in ("input_gate", "forget_gate", "memory_maker",
                     "output_gate"):
            getattr(self, attr).link_attrs(
                getattr(src, attr), "weights", "bias")