Example #1
0
            def CellFn(theta, state0, inputs):
                """A cell fn is executed inside of StackedRecurrent.

                Runs the FProp of cell `i` (captured from the enclosing scope)
                on the current step's inputs and repackages its outputs as the
                next state.

                Args:
                  theta: Layer parameters, forwarded unchanged to `cell.FProp`.
                  state0: Ignored.
                  inputs: A mapping holding one cell input under key
                    's<index>' for every entry of `state_shapes[i]` that is not
                    None, plus the micro-batch tensor under
                    `_MICRO_BATCH_STATE_NAME`.

                Returns:
                  A `(state1, extras)` tuple. `state1` is a `NestedMap` with the
                  micro-batch tensor under `_MICRO_BATCH_STATE_NAME` and every
                  non-None cell output under 's<index>'; `extras` is an empty
                  `NestedMap`.
                """
                del state0

                def _FPropInputSetShape(name, t_shape):
                    """Returns inputs[name] with its static shape set, or None."""
                    # A None shape marks an input that is passed to FProp as None.
                    if t_shape is None:
                        return None
                    inputs[name].set_shape(t_shape)
                    return inputs[name]

                fprop_inputs = []
                for input_idx, input_shape in enumerate(state_shapes[i]):
                    name = 's{}'.format(input_idx)
                    fprop_inputs.append(_FPropInputSetShape(name, input_shape))

                # NOTE(review): presumably strips assert ops created while the
                # cell graph is built; confirm against py_utils.
                with py_utils.RemoveAssertContext(remove=True):
                    with CellFnFPropOpReplacementWrapper():
                        tf.logging.info('cell {} input {}'.format(
                            i, fprop_inputs))
                        mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
                        # NOTE(review): name suggests the micro-batch tensor
                        # overrides the global step while the cell runs; confirm.
                        SetOverWriteGlobalStep(mb_tensor)
                        _, cell = self._cells[i]
                        outputs = cell.FProp(theta, *fprop_inputs)

                state1 = py_utils.NestedMap()
                state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
                outputs = _ToTuple(outputs)
                assert len(outputs) == len(state_shapes[i + 1]), (
                    'cell {} returned {} outputs, expected {}'.format(
                        i, len(outputs), len(state_shapes[i + 1])))
                # None outputs are dropped; the rest keep their positional
                # index in the 's<index>' key.
                for output_idx, output in enumerate(outputs):
                    if output is not None:
                        state1['s{}'.format(output_idx)] = output
                return state1, py_utils.NestedMap()
Example #2
0
            def CellFn(theta, state0, inputs):
                """A cell fn is executed inside of StackedRecurrent.

                Runs the FProp of cell `i` (captured from the enclosing scope)
                and repackages its outputs as the next state. When
                `p.nested_map_fprop` is set, inputs and outputs follow the
                NestedMap structure of `state_shapes`; otherwise they are
                positional and keyed 's0', 's1', ....

                Args:
                  theta: Layer parameters, forwarded unchanged to `cell.FProp`.
                  state0: Ignored.
                  inputs: A mapping holding the cell inputs plus the micro-batch
                    tensor under `_MICRO_BATCH_STATE_NAME`.

                Returns:
                  A `(state1, extras)` tuple. `state1` is a `NestedMap` holding
                  the non-None cell outputs and the micro-batch tensor under
                  `_MICRO_BATCH_STATE_NAME`; `extras` is an empty `NestedMap`.
                """
                del state0

                def _FPropInputSetShape(name, t_shape):
                    """Returns inputs[name] with its static shape set, or None."""
                    # A None shape marks an input that is passed to FProp as None.
                    if t_shape is None:
                        return None
                    inputs[name].set_shape(t_shape.ToTensorShape().as_list())
                    return inputs[name]

                if p.nested_map_fprop:
                    # pylint: disable=protected-access
                    fprop_inputs = state_shapes[i]._RecursiveMap(
                        _FPropInputSetShape)
                    # pylint: enable=protected-access
                else:
                    fprop_inputs = []
                    for input_idx, input_shape in enumerate(state_shapes[i]):
                        name = 's{}'.format(input_idx)
                        fprop_inputs.append(
                            _FPropInputSetShape(name, input_shape))

                # NOTE(review): presumably strips assert ops created while the
                # cell graph is built; confirm against py_utils.
                with py_utils.RemoveAssertContext(remove=True):
                    with CellFnFPropOpReplacementWrapper():
                        tf.logging.info('cell {} input {}'.format(
                            i, fprop_inputs))
                        mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
                        # NOTE(review): name suggests the micro-batch tensor
                        # overrides the global step while the cell runs; confirm.
                        SetOverWriteGlobalStep(mb_tensor)
                        _, cell = self._cells[i]
                        # NOTE(review): in nested mode fprop_inputs is a single
                        # NestedMap; _ToTuple presumably wraps it so it is passed
                        # as one positional argument. Confirm.
                        fprop_inputs = _ToTuple(fprop_inputs)
                        outputs = cell.FProp(theta, *fprop_inputs)

                if p.nested_map_fprop:
                    assert py_utils.IsCompatible(
                        outputs, state_shapes[i + 1]), (
                            'cell {} outputs are incompatible with '
                            'state_shapes[{}]'.format(i, i + 1))
                    state1 = outputs.Filter(lambda x: x is not None)
                else:
                    state1 = py_utils.NestedMap()
                    outputs = _ToTuple(outputs)
                    assert len(outputs) == len(state_shapes[i + 1]), (
                        'cell {} returned {} outputs, expected {}'.format(
                            i, len(outputs), len(state_shapes[i + 1])))
                    # None outputs are dropped; the rest keep their positional
                    # index in the 's<index>' key.
                    for output_idx, output in enumerate(outputs):
                        if output is not None:
                            state1['s{}'.format(output_idx)] = output
                state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
                return state1, py_utils.NestedMap()