def testSimpleStacked(self):
  g = tf.Graph()
  with g.as_default():
    devices = ['/cpu:0'] * 3
    cell_fns = [self.Poly, self.Identity, self.Identity]
    cell_grads = [None] * 3
    cell_outs = [lambda x: x] * 3
    cell_out_grads = [lambda x: x] * 3
    w0 = tf.constant(2.)
    w1 = tf.constant(0.)
    w2 = tf.constant(0.)
    thetas = [
        py_utils.NestedMap(x=w0),
        py_utils.NestedMap(x=w1),
        py_utils.NestedMap(x=w2)
    ]
    init_states = [py_utils.NestedMap(s=tf.constant(0.))] * 3
    inputs = py_utils.NestedMap(
        c=tf.constant([1., 2., 1., 0.]),
        padding=tf.constant([0., 0., 0., 1.]))
    output, _ = recurrent.StackedRecurrent(
        devices=devices,
        cell_fns=cell_fns,
        cell_grads=cell_grads,
        cell_outs=cell_outs,
        cell_out_grads=cell_out_grads,
        thetas=thetas,
        init_states=init_states,
        inputs=inputs)
    dw0, dw1, dw2 = tf.gradients(tf.reduce_sum(output.s), [w0, w1, w2])

  with self.session(graph=g) as sess:
    (output, dw0, dw1, dw2) = sess.run([output.s, dw0, dw1, dw2])

  self.assertAllClose(output, [1., 4., 9., 0.])
  self.assertAllClose(dw2, 0.)
  self.assertAllClose(dw1, 0.)
  self.assertAllClose(dw0, 7.)
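# Illustrative sketch (not part of the original test): the cell functions
# handed to recurrent.StackedRecurrent above (self.Poly, self.Identity) are
# assumed to follow the usual Recurrent cell contract,
#   cell_fn(theta, state0, inputs) -> (state1, extras),
# where every argument and result is a py_utils.NestedMap. A minimal
# pass-through cell consistent with that contract could look like the
# hypothetical _PassThroughCell below; the name and field handling are only
# for illustration.
@staticmethod
def _PassThroughCell(theta, state0, inputs):
  del theta, state0  # A pass-through cell ignores weights and prior state.
  # Copy every per-step input field into the new state; extras stay empty.
  return py_utils.Transform(tf.identity, inputs), py_utils.NestedMap()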
def FProp(self, theta, *args):
  """Runs multiple cells on different devices in a pipelined manner.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    *args: Non-keyworded variable length argument list of input tensors.

  Returns:
    A list of output tensors.
  """
  # TODO(huangyp): handle optional None inputs.
  p = self.params
  if p.is_eval:
    outputs = _ToTuple(args)
    for (name, l) in self._before_layers:
      outputs = _ToTuple(outputs)
      outputs = l.FProp(theta[name], *outputs)
    for (name, l) in self._cells:
      outputs = _ToTuple(outputs)
      outputs = l.FProp(theta[name], *outputs)
    return outputs

  num_cells = len(p.cell_tpl)
  cluster = self.cluster

  # Compute shapes of input and output tensors.
  input_tensors = _ToTuple(args)
  mini_batch_size = input_tensors[0].get_shape().as_list()[p.batch_dim]
  if p.state_dtype:
    state_dtype = p.state_dtype
  else:
    state_dtype = input_tensors[0].dtype
  if p.num_micro_batches > mini_batch_size:
    p.num_micro_batches = mini_batch_size
  micro_batch_size = mini_batch_size // p.num_micro_batches

  input_shapes = ()
  for input_tensor in input_tensors:
    if input_tensor is not None:
      input_shape = input_tensor.get_shape().as_list()
      input_shape[p.batch_dim] = micro_batch_size
      input_shapes += (tf.TensorShape(input_shape),)
    else:
      input_shapes += (None,)
  state_shapes = self._CalculateOutputShapes(input_shapes)

  def GetCellFn(i):
    """Gets the ith feature extraction layer."""

    def CellFn(theta, state0, inputs):
      """A cell fn executed inside StackedRecurrent."""
      del state0
      fprop_inputs = []
      for input_idx in range(len(state_shapes[i])):
        name = 's{}'.format(input_idx)
        if state_shapes[i][input_idx] is not None:
          inputs[name].set_shape(state_shapes[i][input_idx])
          fprop_inputs.append(inputs[name])
        else:
          fprop_inputs.append(None)

      with CellFnFropOpReplacementWrapper():
        tf.logging.info('cell {} input {}'.format(i, fprop_inputs))
        mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
        SetOverWriteGlobalStep(mb_tensor)
        _, cell = self._cells[i]
        outputs = cell.FProp(theta, *fprop_inputs)

      state1 = py_utils.NestedMap()
      state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
      outputs = _ToTuple(outputs)
      assert len(outputs) == len(state_shapes[i + 1])
      for output_idx in range(len(outputs)):
        if outputs[output_idx] is not None:
          name = 's{}'.format(output_idx)
          state1[name] = outputs[output_idx]
      return state1, py_utils.NestedMap()

    return CellFn

  cell_fns = []
  accumulator_layers = []
  thetas = []
  init_states = []
  devices = []
  for cell_idx in range(num_cells):
    cell_name, cell = self._cells[cell_idx]
    accumulator_layers.append(cell)
    cell_fns.append(GetCellFn(cell_idx))
    thetas.append(theta[cell_name])
    init_state = py_utils.NestedMap()
    init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
    for output_idx in range(len(state_shapes[cell_idx + 1])):
      name = 's{}'.format(output_idx)
      if state_shapes[cell_idx + 1][output_idx] is not None:
        init_state[name] = tf.zeros(
            state_shapes[cell_idx + 1][output_idx], dtype=state_dtype)
    init_states.append(init_state)
    devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

  cell_grads = [None] * num_cells
  cell_outs = [lambda x: x] * num_cells
  cell_out_grads = [lambda x: x] * num_cells

  with tf.device(devices[0]):
    previous = input_tensors
    for (name, l) in self._before_layers:
      previous = l.FProp(theta[name], *previous)
      previous = _ToTuple(previous)
    inputs = py_utils.NestedMap()
    gs_tensor = py_utils.GetGlobalStep()
    inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
        tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
        for t in range(p.num_micro_batches)
    ])
    # TODO(huangyp, dehao): apply dehao's trick to reshape the input tensor
    # to [p.num_micro_batches, -1, 128].
    for output_idx, output_tensor in enumerate(previous):
      name = 's{}'.format(output_idx)
      if output_tensor is not None:
        output_tensor = tf.stack(
            tf.split(output_tensor, p.num_micro_batches, axis=p.batch_dim))
        inputs[name] = output_tensor

  output, _ = recurrent.StackedRecurrent(
      devices=devices,
      cell_fns=cell_fns,
      cell_grads=cell_grads,
      cell_outs=cell_outs,
      cell_out_grads=cell_out_grads,
      thetas=thetas,
      init_states=init_states,
      inputs=inputs,
      accumulator_layers=accumulator_layers,
      unused_acc_state=True)

  with tf.device(devices[-1]):
    output_tensors = []
    for output_idx in range(len(state_shapes[-1])):
      state_shape = state_shapes[-1][output_idx]
      if state_shape is None:
        output_tensors.append(None)
        continue
      output_name = 's{}'.format(output_idx)
      output_tensor = output[output_name]
      if p.batch_dim != 0:
        perm = list(range(1, p.batch_dim + 1)) + [0]
        perm += list(range(p.batch_dim + 1, len(state_shape) + 1))
        output_tensor = tf.transpose(output_tensor, perm=perm)
      state_shape[p.batch_dim] *= p.num_micro_batches
      output_tensor = tf.reshape(output_tensor, state_shape)
      output_tensors.append(output_tensor)
    tf.logging.info('pipeline output = {}'.format(output_tensors))
    if len(output_tensors) == 1:
      return output_tensors[0]
    return tuple(output_tensors)
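# Illustrative sketch (not part of the original file): how FProp above slices a
# mini-batch into micro-batches with tf.stack(tf.split(...)) and later restores
# the original layout with the same transpose/reshape pattern. The helper name
# _MicroBatchRoundTripSketch and the concrete sizes are hypothetical.
def _MicroBatchRoundTripSketch():
  num_micro_batches, batch_dim = 2, 1
  x = tf.reshape(tf.range(24), [3, 4, 2])  # [time, batch, feature]
  # Split along the batch dimension and stack micro-batches in front:
  # shape becomes [num_micro_batches, time, batch // num_micro_batches, feature].
  stacked = tf.stack(tf.split(x, num_micro_batches, axis=batch_dim))
  # Undo it: move the leading micro-batch axis to just before batch_dim, then
  # merge it back into the batch dimension, exactly as the output path does.
  perm = list(range(1, batch_dim + 1)) + [0]
  perm += list(range(batch_dim + 1, len(x.shape) + 1))
  restored = tf.reshape(tf.transpose(stacked, perm=perm), tf.shape(x))
  return x, restored  # The two tensors are expected to be element-wise equal.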
def FProp(self, theta, *args):
  """Runs multiple cells on different devices in a pipelined manner.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    *args: Non-keyworded variable length argument list of input tensors.

  Returns:
    A list of output tensors.
  """
  # TODO(huangyp): handle optional None inputs.
  p = self.params
  if self.do_eval:
    outputs = copy.copy(args)
    for (name, l) in self._before_layers + self._cells:
      outputs = _ToTuple(outputs)
      outputs = l.FProp(theta[name], *outputs)
    return outputs

  num_cells = len(p.cell_tpl)
  cluster = self.cluster

  # Compute shapes of input and output tensors.
  input_shapes = self._get_input_shapes(*args)
  state_dtype = self._get_state_dtype(*args)
  state_shapes = self._CalculateOutputShapes(input_shapes)
  tf.logging.info('state_shapes={}'.format(state_shapes))

  def GetCellFn(i):
    """Gets the ith feature extraction layer."""

    def CellFn(theta, state0, inputs):
      """A cell fn executed inside StackedRecurrent."""
      del state0

      def _FPropInputSetShape(name, t_shape):
        if t_shape is None:
          return None
        inputs[name].set_shape(t_shape.ToTensorShape().as_list())
        return inputs[name]

      if p.nested_map_fprop:
        # pylint: disable=protected-access
        fprop_inputs = state_shapes[i]._RecursiveMap(_FPropInputSetShape)
        # pylint: enable=protected-access
      else:
        fprop_inputs = []
        for input_idx, input_shape in enumerate(state_shapes[i]):
          name = 's{}'.format(input_idx)
          fprop_inputs.append(_FPropInputSetShape(name, input_shape))

      with py_utils.RemoveAssertContext(remove=True):
        with CellFnFPropOpReplacementWrapper():
          tf.logging.info('cell {} input {}'.format(i, fprop_inputs))
          mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
          SetOverWriteGlobalStep(mb_tensor)
          _, cell = self._cells[i]
          fprop_inputs = _ToTuple(fprop_inputs)
          outputs = cell.FProp(theta, *fprop_inputs)

      if p.nested_map_fprop:
        assert py_utils.IsCompatible(outputs, state_shapes[i + 1])
        state1 = outputs.Filter(lambda x: x is not None)
      else:
        state1 = py_utils.NestedMap()
        outputs = _ToTuple(outputs)
        assert len(outputs) == len(state_shapes[i + 1])
        for output_idx in range(len(outputs)):
          if outputs[output_idx] is not None:
            name = 's{}'.format(output_idx)
            state1[name] = outputs[output_idx]
      state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
      return state1, py_utils.NestedMap()

    return CellFn

  cell_fns = []
  accumulator_layers = []
  thetas = []
  init_states = []
  devices = []
  for cell_idx in range(num_cells):
    cell_name, cell = self._cells[cell_idx]
    accumulator_layers.append(cell)
    cell_fns.append(GetCellFn(cell_idx))
    thetas.append(theta[cell_name])

    def _TfZeros(t_shape):
      if t_shape is None:
        return None
      return tf.zeros(t_shape.ToTensorShape().as_list(), dtype=state_dtype)

    if p.nested_map_fprop:
      init_state = py_utils.Transform(_TfZeros, state_shapes[cell_idx + 1])
      init_state = init_state.Filter(lambda x: x is not None)
    else:
      init_state = py_utils.NestedMap()
      for output_idx, state in enumerate(state_shapes[cell_idx + 1]):
        state = _TfZeros(state)
        if state is not None:
          name = 's{}'.format(output_idx)
          init_state[name] = state
    init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
    init_states.append(init_state)
    devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

  cell_grads = [None] * num_cells
  cell_outs = [lambda x: x] * num_cells
  cell_out_grads = [lambda x: x] * num_cells

  with tf.device(devices[0]):
    previous = _ToTuple(args)
    for (name, l) in self._before_layers:
      previous = l.FProp(theta[name], *previous)
      previous = _ToTuple(previous)

    def _StackAndSplit(x):
      # Split tensors into micro-batches.
      if x is None:
        return None
      return tf.stack(tf.split(x, p.num_micro_batches, axis=p.batch_dim))

    if p.nested_map_fprop:
      inputs = py_utils.Transform(_StackAndSplit, previous[0])
      inputs = inputs.Filter(lambda x: x is not None)
    else:
      inputs = py_utils.NestedMap()
      for output_idx, output_tensor in enumerate(previous):
        output_tensor = _StackAndSplit(output_tensor)
        if output_tensor is not None:
          name = 's{}'.format(output_idx)
          inputs[name] = output_tensor
    gs_tensor = py_utils.GetGlobalStep()
    inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
        tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
        for t in range(p.num_micro_batches)
    ])
  tf.logging.info('pipeline input = {}'.format(inputs))

  output_state, _ = recurrent.StackedRecurrent(
      devices=devices,
      cell_fns=cell_fns,
      cell_grads=cell_grads,
      cell_outs=cell_outs,
      cell_out_grads=cell_out_grads,
      thetas=thetas,
      init_states=init_states,
      inputs=inputs,
      accumulator_layers=accumulator_layers,
      unused_acc_state=True)

  with tf.device(devices[-1]):

    def _ReshapeRetVal(name, t_shape):
      """Restores the shape of tensors split into micro-batches."""
      if t_shape is None:
        return None
      output_tensor = output_state[name]
      if p.batch_dim != 0:
        perm = list(range(1, p.batch_dim + 1)) + [0]
        perm += list(range(p.batch_dim + 1, t_shape.rank + 1))
        output_tensor = tf.transpose(output_tensor, perm=perm)
      output_shape = t_shape.ToTensorShape().as_list()
      output_shape[p.batch_dim] *= p.num_micro_batches
      output_tensor = tf.reshape(output_tensor, output_shape)
      return output_tensor

    # Construct the final return values from output_state.
    if p.nested_map_fprop:
      # pylint: disable=protected-access
      output_tensors = state_shapes[-1]._RecursiveMap(_ReshapeRetVal)
      # pylint: enable=protected-access
    else:
      output_tensors = []
      for output_idx, state_shape in enumerate(state_shapes[-1]):
        output_name = 's{}'.format(output_idx)
        output_tensor = _ReshapeRetVal(output_name, state_shape)
        output_tensors.append(output_tensor)
      if len(output_tensors) == 1:
        output_tensors = output_tensors[0]
      else:
        output_tensors = tuple(output_tensors)
    tf.logging.info('pipeline output = {}'.format(output_tensors))
    return output_tensors
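# Illustrative sketch (not part of the original file): the per-micro-batch
# "global step" tensor built in FProp above. Micro-batch t is fed the value
# gs * num_micro_batches + t, so SetOverWriteGlobalStep inside CellFn sees a
# distinct, monotonically increasing step for each micro-batch. The helper
# name and arguments below are hypothetical.
def _MicroBatchStepsSketch(gs_tensor, num_micro_batches, state_dtype):
  # For gs_tensor == 3 and num_micro_batches == 4 this yields [12, 13, 14, 15].
  return tf.stack([
      tf.cast(gs_tensor * num_micro_batches + t, dtype=state_dtype)
      for t in range(num_micro_batches)
  ])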
def _BuildStackedRecurrentElman(self, seqlen, trailing_pad_len, batch, dims,
                                layers):
  tf.set_random_seed(342462)
  np.random.seed(32540)

  seqlen += trailing_pad_len
  dtype = tf.float64

  def CreateTheta():
    return py_utils.NestedMap(
        w=tf.constant(
            np.random.uniform(0, 0.2, (2 * dims, dims)), dtype=dtype),
        b=tf.constant(np.random.uniform(0, 0.2, (dims,)), dtype=dtype))

  def CreateState0():
    return py_utils.NestedMap(
        h=tf.constant(np.random.uniform(0, 0.2, (batch, dims)), dtype=dtype),
        padding=tf.constant([[0]] * batch, dtype=dtype))

  devices = ['/cpu:0'] * layers
  cell_fns = [self.Elman] * layers
  cell_grads = [self.ElmanGrad] * layers
  cell_outs = [self.ElmanOut] * layers
  cell_out_grads = [self.ElmanOutGrad] * layers
  thetas = [CreateTheta() for _ in range(layers)]
  init_states = [CreateState0() for _ in range(layers)]
  padding = np.zeros((seqlen, batch, 1))
  padding[-trailing_pad_len:, :, :] = 1.
  padding[-trailing_pad_len - 3:-trailing_pad_len - 1, :, :] = 1.
  inputs = py_utils.NestedMap(
      x=tf.constant(
          np.random.uniform(0, 0.2, (seqlen, batch, dims)), dtype=dtype),
      padding=tf.constant(padding, dtype=dtype))
  output, _ = recurrent.StackedRecurrent(
      devices=devices,
      cell_fns=cell_fns,
      cell_grads=cell_grads,
      cell_outs=cell_outs,
      cell_out_grads=cell_out_grads,
      thetas=thetas,
      init_states=init_states,
      inputs=inputs)
  o = output.x
  if 'padding' in inputs:
    o *= (1 - inputs.padding)
  loss = tf.reduce_sum(tf.square(o))

  xs = recurrent.Flatten(thetas + [py_utils.NestedMap(x=inputs.x)])
  dxs = tf.gradients(ys=loss, xs=xs)

  # Reference implementation using Recurrent().
  ref = inputs
  for i in range(layers):
    ref = self.ElmanOut(
        recurrent.Recurrent(
            cell_fn=cell_fns[i],
            cell_grad=cell_grads[i],
            theta=thetas[i],
            state0=init_states[i],
            inputs=ref)[0])
  return ref.x, output.x, loss, xs, dxs
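# Illustrative sketch (not the actual self.Elman used above): an Elman-style
# cell consistent with the shapes created by CreateTheta/CreateState0 --
# theta.w is (2 * dims, dims), theta.b is (dims,), and state0.h and the
# per-step inputs.x are (batch, dims). The real test cell and its gradient
# and output helpers (ElmanGrad, ElmanOut, ElmanOutGrad) live elsewhere in
# this file; this only shows the general form of such a cell.
@staticmethod
def _ElmanCellSketch(theta, state0, inputs):
  # h1 = sigmoid([x, h0] @ w + b), with the previous state carried over at
  # padded time steps.
  xw = tf.matmul(tf.concat([inputs.x, state0.h], axis=1), theta.w)
  h1 = tf.sigmoid(xw + theta.b)
  padding = inputs.padding
  h1 = h1 * (1. - padding) + state0.h * padding
  return py_utils.NestedMap(h=h1, padding=padding), py_utils.NestedMap()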