Example #1
    def testSimpleStacked(self):
        g = tf.Graph()
        with g.as_default():
            devices = ['/cpu:0'] * 3
            cell_fns = [self.Poly, self.Identity, self.Identity]
            cell_grads = [None] * 3
            cell_outs = [lambda x: x] * 3
            cell_out_grads = [lambda x: x] * 3
            w0 = tf.constant(2.)
            w1 = tf.constant(0.)
            w2 = tf.constant(0.)
            thetas = [
                py_utils.NestedMap(x=w0),
                py_utils.NestedMap(x=w1),
                py_utils.NestedMap(x=w2)
            ]
            init_states = [py_utils.NestedMap(s=tf.constant(0.))] * 3
            inputs = py_utils.NestedMap(c=tf.constant([1., 2., 1., 0.]),
                                        padding=tf.constant([0., 0., 0., 1.]))
            output, _ = recurrent.StackedRecurrent(
                devices=devices,
                cell_fns=cell_fns,
                cell_grads=cell_grads,
                cell_outs=cell_outs,
                cell_out_grads=cell_out_grads,
                thetas=thetas,
                init_states=init_states,
                inputs=inputs)
            dw0, dw1, dw2 = tf.gradients(tf.reduce_sum(output.s), [w0, w1, w2])

        with self.session(graph=g) as sess:
            (output, dw0, dw1, dw2) = sess.run([output.s, dw0, dw1, dw2])

        self.assertAllClose(output, [1., 4., 9., 0.])
        self.assertAllClose(dw2, 0.)
        self.assertAllClose(dw1, 0.)
        self.assertAllClose(dw0, 7.)
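The expected values in the assertions can be reproduced by hand. A minimal sketch, assuming the Poly test cell implements Horner's rule s_t = s_{t-1} * x + c_t and that the padded final step emits zero (both consistent with the asserted numbers); the two Identity cells simply pass the sequence through, so dw1 and dw2 stay zero:

x, coeffs, padding = 2.0, [1., 2., 1., 0.], [0., 0., 0., 1.]
s, ds_dx, outputs, grad = 0.0, 0.0, [], 0.0
for c, pad in zip(coeffs, padding):
    if pad == 0.0:
        s, ds_dx = s * x + c, s + x * ds_dx   # Horner step and its d/dx
        outputs.append(s)
        grad += ds_dx                         # contributes to d(sum of outputs)/dx
    else:
        outputs.append(0.0)                   # padded step emits zero
print(outputs, grad)                          # [1.0, 4.0, 9.0, 0.0] 7.0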
Example #2
  def FProp(self, theta, *args):
    """Run multiple cells in different devices in a pipelining manner.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      *args: Non-keyworded variable length argument list of input tensors.

    Returns:
      A list of output tensors
    """
    # TODO(huangyp): handle optional None inputs.
    p = self.params
    if p.is_eval:
      outputs = _ToTuple(args)
      for (name, l) in self._before_layers:
        outputs = _ToTuple(outputs)
        outputs = l.FProp(theta[name], *outputs)
      for (name, l) in self._cells:
        outputs = _ToTuple(outputs)
        outputs = l.FProp(theta[name], *outputs)
      return outputs

    num_cells = len(p.cell_tpl)
    cluster = self.cluster

    # Compute shapes of input and output tensors.
    input_tensors = _ToTuple(args)
    mini_batch_size = input_tensors[0].get_shape().as_list()[p.batch_dim]
    if p.state_dtype:
      state_dtype = p.state_dtype
    else:
      state_dtype = input_tensors[0].dtype
    if p.num_micro_batches > mini_batch_size:
      p.num_micro_batches = mini_batch_size
    micro_batch_size = mini_batch_size // p.num_micro_batches

    input_shapes = ()
    for input_tensor in input_tensors:
      if input_tensor is not None:
        input_shape = input_tensor.get_shape().as_list()
        input_shape[p.batch_dim] = micro_batch_size
        input_shapes += (tf.TensorShape(input_shape),)
      else:
        input_shapes += (None,)
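    # Note: the tf.split into micro-batches below assumes mini_batch_size is
    # evenly divisible by p.num_micro_batches.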

    state_shapes = self._CalculateOutputShapes(input_shapes)
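    # state_shapes[i] holds the per-micro-batch input shapes of cell i;
    # state_shapes[i + 1] holds that cell's output shapes.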

    def GetCellFn(i):
      """Get the ith feature extraction layer."""

      def CellFn(theta, state0, inputs):
        """A cell fn is exectued inside of StackedRecurrent."""
        del state0
        frop_inputs = []
        for input_idx in range(len(state_shapes[i])):
          name = 's{}'.format(input_idx)
          if state_shapes[i][input_idx] is not None:
            inputs[name].set_shape(state_shapes[i][input_idx])
            frop_inputs.append(inputs[name])
          else:
            frop_inputs.append(None)

        with CellFnFropOpReplacementWrapper():
          tf.logging.info('cell {} input {}'.format(i, frop_inputs))
          mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
          SetOverWriteGlobalStep(mb_tensor)
          _, cell = self._cells[i]
          outputs = cell.FProp(theta, *frop_inputs)

        state1 = py_utils.NestedMap()
        state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
        outputs = _ToTuple(outputs)
        assert len(outputs) == len(state_shapes[i + 1])
        for output_idx in range(len(outputs)):
          if outputs[output_idx] is not None:
            name = 's{}'.format(output_idx)
            state1[name] = outputs[output_idx]
        return state1, py_utils.NestedMap()

      return CellFn

    cell_fns = []
    accumulator_layers = []
    thetas = []
    init_states = []
    devices = []
    for cell_idx in range(num_cells):
      cell_name, cell = self._cells[cell_idx]
      accumulator_layers.append(cell)
      cell_fns.append(GetCellFn(cell_idx))
      thetas.append(theta[cell_name])
      init_state = py_utils.NestedMap()
      init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
      for output_idx in range(len(state_shapes[cell_idx + 1])):
        name = 's{}'.format(output_idx)
        if state_shapes[cell_idx + 1][output_idx] is not None:
          init_state[name] = tf.zeros(
              state_shapes[cell_idx + 1][output_idx], dtype=state_dtype)
      init_states.append(init_state)
      devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

    cell_grads = [None] * num_cells
    cell_outs = [lambda x: x] * num_cells
    cell_out_grads = [lambda x: x] * num_cells

    with tf.device(devices[0]):
      previous = input_tensors
      for (name, l) in self._before_layers:
        previous = l.FProp(theta[name], *previous)
        previous = _ToTuple(previous)
      inputs = py_utils.NestedMap()
      gs_tensor = py_utils.GetGlobalStep()
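      # Give every micro-batch a distinct step value (global_step *
      # num_micro_batches + t); CellFn installs it via SetOverWriteGlobalStep.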
      inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
          tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
          for t in range(p.num_micro_batches)
      ])

      # TODO(huangyp, dehao): apply dehao's trick to reshape the input tensor
      # to [p.num_micro_batches, -1, 128].
      for output_idx, output_tensor in enumerate(previous):
        name = 's{}'.format(output_idx)
        if output_tensor is not None:
          output_tensor = tf.stack(
              tf.split(output_tensor, p.num_micro_batches, axis=p.batch_dim))
          inputs[name] = output_tensor

    output, _ = recurrent.StackedRecurrent(
        devices=devices,
        cell_fns=cell_fns,
        cell_grads=cell_grads,
        cell_outs=cell_outs,
        cell_out_grads=cell_out_grads,
        thetas=thetas,
        init_states=init_states,
        inputs=inputs,
        accumulator_layers=accumulator_layers,
        unused_acc_state=True)

    with tf.device(devices[-1]):
      output_tensors = []
      for output_idx in range(len(state_shapes[-1])):
        state_shape = state_shapes[-1][output_idx]
        if state_shape is None:
          output_tensors.append(None)
          continue
        output_name = 's{}'.format(output_idx)
        output_tensor = output[output_name]
        if p.batch_dim != 0:
          perm = list(range(1, p.batch_dim + 1)) + [0]
          perm += list(range(p.batch_dim + 1, len(state_shape) + 1))
          output_tensor = tf.transpose(output_tensor, perm=perm)
        state_shape[p.batch_dim] *= p.num_micro_batches
        output_tensor = tf.reshape(output_tensor, state_shape)
        output_tensors.append(output_tensor)
      tf.logging.info('pipeline output = {}'.format(output_tensors))
      if len(output_tensors) == 1:
        return output_tensors[0]
      return tuple(output_tensors)
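The micro-batch bookkeeping at the two ends of FProp is a split/stack followed by its exact inverse. A small NumPy sketch of the round trip for a non-zero batch_dim; the tensor contents and sizes are made up purely for illustration:

import numpy as np

num_micro_batches, batch_dim = 4, 1
x = np.arange(2 * 8 * 3).reshape(2, 8, 3)  # e.g. [time, batch, depth]

# Forward: split the batch into micro-batches and stack them in front,
# mirroring tf.stack(tf.split(x, num_micro_batches, axis=batch_dim)).
micro = np.stack(np.split(x, num_micro_batches, axis=batch_dim))

# Backward: move the leading micro-batch axis next to batch_dim, then merge
# the two axes back into the full batch, mirroring the transpose + reshape
# at the end of FProp above.
perm = list(range(1, batch_dim + 1)) + [0] + list(range(batch_dim + 1, x.ndim + 1))
restored = micro.transpose(perm).reshape(x.shape)
assert (restored == x).all()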
Example #3
  def FProp(self, theta, *args):
    """Run multiple cells in different devices in a pipelining manner.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      *args: Non-keyworded variable length argument list of input tensors.

    Returns:
      A list of output tensors
    """
    # TODO(huangyp): handle optional None inputs.
    p = self.params
    if self.do_eval:
      outputs = copy.copy(args)
      for (name, l) in self._before_layers + self._cells:
        outputs = _ToTuple(outputs)
        outputs = l.FProp(theta[name], *outputs)
      return outputs

    num_cells = len(p.cell_tpl)
    cluster = self.cluster

    # Compute shapes of input and output tensors.
    input_shapes = self._get_input_shapes(*args)
    state_dtype = self._get_state_dtype(*args)
    state_shapes = self._CalculateOutputShapes(input_shapes)
    tf.logging.info('state_shapes={}'.format(state_shapes))

    def GetCellFn(i):
      """Get the ith feature extraction layer."""

      def CellFn(theta, state0, inputs):
        """A cell fn is exectued inside of StackedRecurrent."""
        del state0

        def _FPropInputSetShape(name, t_shape):
          if t_shape is None:
            return None
          inputs[name].set_shape(t_shape.ToTensorShape().as_list())
          return inputs[name]

        if p.nested_map_fprop:
          # pylint: disable=protected-access
          fprop_inputs = state_shapes[i]._RecursiveMap(_FPropInputSetShape)
          # pylint: enable=protected-access
        else:
          fprop_inputs = []
          for input_idx, input_shape in enumerate(state_shapes[i]):
            name = 's{}'.format(input_idx)
            fprop_inputs.append(_FPropInputSetShape(name, input_shape))

        with py_utils.RemoveAssertContext(remove=True):
          with CellFnFPropOpReplacementWrapper():
            tf.logging.info('cell {} input {}'.format(i, fprop_inputs))
            mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
            SetOverWriteGlobalStep(mb_tensor)
            _, cell = self._cells[i]
            fprop_inputs = _ToTuple(fprop_inputs)
            outputs = cell.FProp(theta, *fprop_inputs)

        if p.nested_map_fprop:
          assert py_utils.IsCompatible(outputs, state_shapes[i + 1])
          state1 = outputs.Filter(lambda x: x is not None)
        else:
          state1 = py_utils.NestedMap()
          outputs = _ToTuple(outputs)
          assert len(outputs) == len(state_shapes[i + 1])
          for output_idx in range(len(outputs)):
            if outputs[output_idx] is not None:
              name = 's{}'.format(output_idx)
              state1[name] = outputs[output_idx]
        state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
        return state1, py_utils.NestedMap()

      return CellFn

    cell_fns = []
    accumulator_layers = []
    thetas = []
    init_states = []
    devices = []
    for cell_idx in range(num_cells):
      cell_name, cell = self._cells[cell_idx]
      accumulator_layers.append(cell)
      cell_fns.append(GetCellFn(cell_idx))
      thetas.append(theta[cell_name])

      def _TfZeros(t_shape):
        if t_shape is None:
          return None
        return tf.zeros(t_shape.ToTensorShape().as_list(), dtype=state_dtype)

      if p.nested_map_fprop:
        init_state = py_utils.Transform(_TfZeros, state_shapes[cell_idx + 1])
        init_state = init_state.Filter(lambda x: x is not None)
      else:
        init_state = py_utils.NestedMap()
        for output_idx, state in enumerate(state_shapes[cell_idx + 1]):
          state = _TfZeros(state)
          if state is not None:
            name = 's{}'.format(output_idx)
            init_state[name] = state
      init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
      init_states.append(init_state)

      devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

    cell_grads = [None] * num_cells
    cell_outs = [lambda x: x] * num_cells
    cell_out_grads = [lambda x: x] * num_cells

    with tf.device(devices[0]):
      previous = _ToTuple(args)
      for (name, l) in self._before_layers:
        previous = l.FProp(theta[name], *previous)
        previous = _ToTuple(previous)

      def _StackAndSplit(x):
        # Split tensors into microbatches.
        if x is None:
          return None
        return tf.stack(tf.split(x, p.num_micro_batches, axis=p.batch_dim))

      if p.nested_map_fprop:
        inputs = py_utils.Transform(_StackAndSplit, previous[0])
        inputs = inputs.Filter(lambda x: x is not None)
      else:
        inputs = py_utils.NestedMap()
        for output_idx, output_tensor in enumerate(previous):
          output_tensor = _StackAndSplit(output_tensor)
          if output_tensor is not None:
            name = 's{}'.format(output_idx)
            inputs[name] = output_tensor
      gs_tensor = py_utils.GetGlobalStep()
      inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
          tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
          for t in range(p.num_micro_batches)
      ])
    tf.logging.info('pipeline input = {}'.format(inputs))
    output_state, _ = recurrent.StackedRecurrent(
        devices=devices,
        cell_fns=cell_fns,
        cell_grads=cell_grads,
        cell_outs=cell_outs,
        cell_out_grads=cell_out_grads,
        thetas=thetas,
        init_states=init_states,
        inputs=inputs,
        accumulator_layers=accumulator_layers,
        unused_acc_state=True)

    with tf.device(devices[-1]):

      def _ReshapeRetVal(name, t_shape):
        """Restore shape for tensors in microbatches."""
        if t_shape is None:
          return None
        output_tensor = output_state[name]
        if p.batch_dim != 0:
          perm = list(range(1, p.batch_dim + 1)) + [0]
          perm += list(range(p.batch_dim + 1, t_shape.rank + 1))
          output_tensor = tf.transpose(output_tensor, perm=perm)
        output_shape = t_shape.ToTensorShape().as_list()
        output_shape[p.batch_dim] *= p.num_micro_batches
        output_tensor = tf.reshape(output_tensor, output_shape)
        return output_tensor

      # Construct the final return values from output_state.
      if p.nested_map_fprop:
        # pylint: disable=protected-access
        output_tensors = state_shapes[-1]._RecursiveMap(_ReshapeRetVal)
        # pylint: enable=protected-access
      else:
        output_tensors = []
        for output_idx, state_shape in enumerate(state_shapes[-1]):
          output_name = 's{}'.format(output_idx)
          output_tensor = _ReshapeRetVal(output_name, state_shape)
          output_tensors.append(output_tensor)
        if len(output_tensors) == 1:
          output_tensors = output_tensors[0]
        else:
          output_tensors = tuple(output_tensors)
      tf.logging.info('pipeline output = {}'.format(output_tensors))
      return output_tensors
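Both the nested_map_fprop branch and the positional branch above rely on the same convention: positional cell outputs travel through StackedRecurrent as named state entries 's0', 's1', ..., with None outputs dropped and restored on the way out. A stand-in sketch of that packing, using plain dicts instead of py_utils.NestedMap (the helper names are hypothetical):

def PackOutputs(outputs):
    """Packs positional outputs into named state entries, dropping Nones."""
    return {'s{}'.format(i): out for i, out in enumerate(outputs) if out is not None}

def UnpackState(state, shapes):
    """Restores the positional layout; entries without a shape come back as None."""
    return tuple(state.get('s{}'.format(i)) if shape is not None else None
                 for i, shape in enumerate(shapes))

packed = PackOutputs(('logits', None, 'paddings'))
assert packed == {'s0': 'logits', 's2': 'paddings'}
assert UnpackState(packed, ('shape0', None, 'shape2')) == ('logits', None, 'paddings')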
Example #4
    def _BuildStackedRecurrentElman(self, seqlen, trailing_pad_len, batch,
                                    dims, layers):
        tf.set_random_seed(342462)
        np.random.seed(32540)

        seqlen += trailing_pad_len
        dtype = tf.float64

        def CreateTheta():
            return py_utils.NestedMap(
                w=tf.constant(np.random.uniform(0, 0.2, (2 * dims, dims)),
                              dtype=dtype),
                b=tf.constant(np.random.uniform(0, 0.2, (dims, )),
                              dtype=dtype))

        def CreateState0():
            return py_utils.NestedMap(h=tf.constant(np.random.uniform(
                0, 0.2, (batch, dims)),
                                                    dtype=dtype),
                                      padding=tf.constant([[0]] * batch,
                                                          dtype=dtype))

        devices = ['/cpu:0'] * layers
        cell_fns = [self.Elman] * layers
        cell_grads = [self.ElmanGrad] * layers
        cell_outs = [self.ElmanOut] * layers
        cell_out_grads = [self.ElmanOutGrad] * layers
        thetas = [CreateTheta() for _ in range(layers)]
        init_states = [CreateState0() for _ in range(layers)]
        padding = np.zeros((seqlen, batch, 1))
        padding[-trailing_pad_len:, :, :] = 1.
        padding[-trailing_pad_len - 3:-trailing_pad_len - 1, :, :] = 1.
        inputs = py_utils.NestedMap(x=tf.constant(np.random.uniform(
            0, 0.2, (seqlen, batch, dims)),
                                                  dtype=dtype),
                                    padding=tf.constant(padding, dtype=dtype))
        output, _ = recurrent.StackedRecurrent(devices=devices,
                                               cell_fns=cell_fns,
                                               cell_grads=cell_grads,
                                               cell_outs=cell_outs,
                                               cell_out_grads=cell_out_grads,
                                               thetas=thetas,
                                               init_states=init_states,
                                               inputs=inputs)
        o = output.x
        if 'padding' in inputs:
            o *= (1 - inputs.padding)
        loss = tf.reduce_sum(tf.square(o))

        xs = recurrent.Flatten(thetas + [py_utils.NestedMap(x=inputs.x)])
        dxs = tf.gradients(ys=loss, xs=xs)

        # Reference implementation using Recurrent().
        ref = inputs
        for i in range(layers):
            ref = self.ElmanOut(
                recurrent.Recurrent(cell_fn=cell_fns[i],
                                    cell_grad=cell_grads[i],
                                    theta=thetas[i],
                                    state0=init_states[i],
                                    inputs=ref)[0])
        return ref.x, output.x, loss, xs, dxs
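A minimal sketch of how this helper might be driven from a test. It assumes the pipelined stack and the Recurrent()-based reference agree on the unpadded prefix of the sequence; the test name and sizes are illustrative, not taken from the original suite:

    def testStackedElmanMatchesReference(self):
        g = tf.Graph()
        with g.as_default():
            ref_x, out_x, _, _, _ = self._BuildStackedRecurrentElman(
                seqlen=7, trailing_pad_len=2, batch=4, dims=8, layers=3)
            # Compare only the unpadded prefix; the trailing padded steps need
            # not agree between the two implementations.
            ref_x, out_x = ref_x[:-2], out_x[:-2]
        with self.session(graph=g) as sess:
            ref_val, out_val = sess.run([ref_x, out_x])
        self.assertAllClose(ref_val, out_val)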