def _ProcessMASSInput(self, source_id, src):
  """Turns one raw source string into MASS-masked training features.

  The string is given a leading dimension of 1, tokenized with the source
  tokenizer, masked by `self.mass_layer.Mask`, and packed into `src` / `tgt`
  sub-maps. `tf.squeeze` is applied to every tensor before returning.

  Args:
    source_id: Passed to `self._GetTaskIds` to obtain the source and target
      language ids.
    src: The raw input string tensor (reshaped to [1] for tokenization).

  Returns:
    A `.NestedMap` with `src` and `tgt` sub-maps. `src` has ids, paddings,
    weights, task_ids and ids_indicator. `tgt` has ids, labels, paddings,
    weights, task_ids and ids_indicator. Off-TPU, both also carry `strs`.
  """
  # TODO(yuancao): By doing so we assume that right now for monolingual
  # eval/dev sets (xx->xx) are in double-column format (since it bypasses
  # the Mass op). Ideally we should add a dedicated eval/dev processing
  # procedure for unsupervised MT cases, so that single-column eval/devs sets
  # are also supported. This should not be handled by any specific ops like
  # Mass, but inside the TextPackedInput class.
  assert not self.do_eval, 'MASS input can only be used for training.'

  _, token_ids, token_paddings = self.StringsToIds(
      tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
  token_weights = 1 - token_paddings
  seq_len = tf.cast(tf.reduce_sum(token_weights, 1), tf.int32)
  src_lang_ids, tgt_lang_ids = self._GetTaskIds(source_id)
  masked = self.mass_layer.Mask(token_ids, token_weights, seq_len)

  # Key insertion order is kept identical to the historical layout.
  src_map = py_utils.NestedMap(
      ids=masked.src.ids,
      paddings=token_paddings,
      weights=token_weights,
      task_ids=tf.cast(token_weights, dtype=tf.int32) * src_lang_ids,
      ids_indicator=token_weights)
  tgt_map = py_utils.NestedMap(
      ids=masked.tgt.ids,
      labels=masked.tgt.labels,
      paddings=token_paddings,
      weights=masked.tgt.weights,
      task_ids=tf.ones_like(src_map.task_ids, dtype=tf.int32) * tgt_lang_ids,
      ids_indicator=token_weights)

  # Raw strings are only attached off-TPU.
  if not py_utils.use_tpu():
    src_map.strs = src
    tgt_map.strs = src

  features = py_utils.NestedMap(src=src_map, tgt=tgt_map)
  return features.Transform(tf.squeeze)
def FProp(self, theta, input_batch):
  """Encodes a batch of token ids with a stack of LSTM layers.

  Args:
    theta: A `.NestedMap` of weights for this layer and its children.
    input_batch: A `.NestedMap` with `ids` and `paddings` (asserted below to
      be rank 2 and of identical shape; presumably [batch, time] since they
      are transposed to time-major — TODO confirm) and, when
      `p.packed_input`, `segment_ids`.

  Returns:
    A `.NestedMap` with:
    - encoded: the final encoder output, in the transposed (time-major)
      layout.
    - padding: the transposed paddings with the trailing unit dim squeezed.
    - segment_id: the transposed segment ids, or None when not packed.
  """
  p = self.params
  with tf.name_scope(p.name):
    # Validate that ids is rank 2 and matches paddings, then transpose.
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings))
    ], tf.transpose(input_batch.ids))
    # Trailing unit dim added to paddings (and segment ids); squeezed off
    # again before returning.
    paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
    if p.packed_input:
      src_segment_id = tf.expand_dims(
          tf.transpose(input_batch.segment_ids), 2)
    else:
      src_segment_id = None
    xs = self.emb.EmbLookup(theta.emb, inputs)
    xs = self.ApplyClipping(theta, xs)
    summary_utils.histogram('input_emb', xs)
    xs = self.dropout.FProp(theta.dropout, xs)
    ps = paddings
    # Now the rnn layers.
    outputs_list = []
    for i in range(0, p.num_lstm_layers):
      layer = self.rnn[i]
      ys = layer.FProp(theta.rnn[i], xs, ps, segment_id=src_segment_id)
      ys = self.dropout.FProp(theta.dropout, ys)
      # Layers at index >= p.residual_start add a residual connection.
      if i >= p.residual_start:
        xs += ys  # Residual skip
        xs = self.ApplyClipping(theta, xs)
      else:
        xs = ys
      outputs_list.append(xs)
      summary_utils.histogram('layer_out_%s' % i, xs)

    if p.is_transparent:
      # Merge all per-layer outputs instead of using only the last one.
      xs = self.transparent_merger.FProp(theta.transparent_merger,
                                         outputs_list)

    # NOTE(review): the factor 2 suggests the LSTM output depth is
    # 2 * lstm_cell_size (e.g. bidirectional) — confirm against self.rnn.
    if p.lstm_cell_size * 2 != p.encoder_out_dim:
      # Project to the right depth.
      xs = self.final_proj.FProp(theta.final_proj, xs, ps)
      summary_utils.histogram('final_proj_out', xs)

    if src_segment_id is not None:
      src_segment_id = tf.squeeze(src_segment_id, [2])

    return py_utils.NestedMap(
        encoded=xs, padding=tf.squeeze(ps, [2]), segment_id=src_segment_id)
def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
  """Computes the attention context vector for one step.

  The query is `step_inputs.inputs` concatenated along axis 1. It is used to
  attend over `prepared_inputs.packed_src` with `self.atten`.

  Args:
    theta: A NestedMap containing weights' values of this layer and its
      children layers.
    prepared_inputs: A set of encoded tensors that have been pre-processed by
      PrepareExternalInputs; `packed_src` is read.
    step_inputs: A NestedMap whose 'inputs' field holds the query tensor(s)
      to concatenate along axis 1.
    padding: A [batch, 1] 0/1 float tensor, where 1.0 means that this batch
      slot is not used.
    state0: A NestedMap of state, either produced by ZeroState or a previous
      invocation of this graph; `atten_state` is read.

  Returns:
    (output, state1), where:
    - output: a NestedMap with `context` (the attention context vector) and
      `probs` (the attention probabilities after `py_utils.ApplyPadding`
      with `padding`).
    - state1: a NestedMap with `atten_context` and `atten_state`, used by
      subsequent invocations of this graph.
  """
  query = tf.concat(step_inputs.inputs, axis=1)
  context, probs, next_atten_state = self.atten.ComputeContextVectorWithSource(
      theta.atten,
      prepared_inputs.packed_src,
      query,
      attention_state=state0.atten_state)
  masked_probs = py_utils.ApplyPadding(padding, probs)
  output = py_utils.NestedMap(context=context, probs=masked_probs)
  state1 = py_utils.NestedMap(
      atten_context=context, atten_state=next_atten_state)
  return output, state1
def FnMeta(*shapes):
  """Maps tuple(tshape.Shape) -> NestedMap{flops, out_shapes}.

  The enclosing scope's `fn_out`, when set, computes the output shapes; a
  single returned `tshape.Shape` is wrapped in a 1-tuple. Without `fn_out`,
  the input shapes are returned unchanged. `fn_flops`, when set, computes
  flops; otherwise flops is the sum of the input shape sizes.
  """
  if not fn_out:
    out_shapes = shapes
  else:
    out_shapes = fn_out(*shapes)
    if isinstance(out_shapes, tshape.Shape):
      out_shapes = (out_shapes,)

  if not fn_flops:
    flops = sum(shape.size for shape in shapes)
  else:
    flops = fn_flops(*shapes)

  return py_utils.NestedMap(flops=flops, out_shapes=out_shapes)
def ParseAndProcess(*cols):
  """Parses a Tensorflow example into features.

  With one column, that column is the serialized record. With two columns,
  the input is treated as a key/value store and the second column is the
  record. The record is parsed with `self.GetFeatureSpec()` and passed to
  `self._PreprocessInputBatch`.
  """
  assert len(cols) in [
      1, 2
  ], ('BaseExampleInputGenerator supports one or two column input')
  # The record is always the last column (the value in the KV case).
  record = cols[-1]
  parsed = tf.io.parse_example(record, self.GetFeatureSpec())
  return self._PreprocessInputBatch(py_utils.NestedMap(parsed))
def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
  """Produces a query vector and a context vector for the next decoder step.

  The previous step's attention context is fed to `self.query_generator`.
  The resulting query is then fed to `self.attention`.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    prepared_inputs: A set of encoded tensors that have been pre-processed by
      PrepareExternalInputs.
    step_inputs: Unused. All of the input for this step comes from
      external_inputs and previous step state.
    padding: A [batch, 1] 0/1 float tensor, where 1.0 means that this batch
      slot is not used.
    state0: A NestedMap of state, either produced by ZeroState or a previous
      invocation of this graph.

  Returns:
    (output, state1):
    - output: a NestedMap with `atten_context`, `atten_query` and
      `atten_probs` (attention probabilities for each input vector).
    - state1: a NestedMap with `atten_state` and `query_state`, used by
      subsequent invocations of this graph.
  """
  del step_inputs  # Unused; see docstring.
  prev_context = state0.atten_state.atten_context
  query_out, next_query_state = self.query_generator.FProp(
      theta.query_generator, prepared_inputs.query_generator,
      py_utils.NestedMap(inputs=[prev_context]), padding, state0.query_state)
  query = query_out.output

  atten_out, next_atten_state = self.attention.FProp(
      theta.attention, prepared_inputs.attention,
      py_utils.NestedMap(inputs=[query]), padding, state0.atten_state)

  output = py_utils.NestedMap(
      atten_context=atten_out.context,
      atten_query=query,
      atten_probs=atten_out.probs)
  state1 = py_utils.NestedMap(
      atten_state=next_atten_state, query_state=next_query_state)
  return output, state1
def FPropMeta(cls, p, *args):
  """FPropMeta for `p.repeat` stacked copies of `p.body`.

  The leading dimension of each input shape is dropped before the body's
  FPropMeta is consulted. `p.repeat` is then prepended to every body output
  shape. Flops are the body's flops times `p.repeat`. None entries pass
  through as None.
  """
  py_utils.CheckShapes(args)
  body_inputs = []
  for arg in args:
    if arg is None:
      body_inputs.append(None)
    else:
      body_inputs.append(tshape.Shape(arg.get_shape().as_list()[1:]))

  body_meta = p.body.cls.FPropMeta(p.body, *body_inputs)
  py_utils.CheckShapes(body_meta.out_shapes)

  stacked_shapes = tuple(
      None if shape is None else tshape.Shape([p.repeat] + shape[:])
      for shape in body_meta.out_shapes)
  return py_utils.NestedMap(
      flops=body_meta.flops * p.repeat, out_shapes=stacked_shapes)
def _ConsumeMap(self):
  """Parses the NestedMap that starts at the current token position.

  Expects '(' followed by zero or more `key = item` entries separated by ','
  and a closing ')'. `self._i` is advanced as tokens are consumed.

  Returns:
    The NestedMap that starts at the current token position.

  Raises:
    ValueError: If tokens run out, or the current token is not '('. Errors
      raised by the `_Consume*` helpers propagate unchanged.
  """
  if self._i >= len(self._tokens):
    raise ValueError('Ran out of tokens while looking for a NestedMap.')
  if self._tokens[self._i] != '(':
    raise ValueError('Expected ( at token position %d' % (self._i))
  self._i += 1

  parsed = py_utils.NestedMap()
  if self._MaybeConsumeSymbol(')'):
    # Empty NestedMaps are allowed.
    return parsed

  while self._i < len(self._tokens):
    key = self._ConsumeKey()
    self._ConsumeSymbol('=')
    parsed[key] = self._ConsumeItem()
    if self._MaybeConsumeSymbol(')'):
      return parsed
    self._ConsumeSymbol(',')

  raise ValueError('Ran out of tokens while looking for end of NestedMap.')
def _BuildDataSourceWithMetadata(self, task_id=None):
  """Read and return input batch from `p.file_pattern`.

  `p.file_pattern` may be a string file_pattern or a list of (file_pattern,
  weight, [bprop_variable_filter]) tuples. bprop_variable_filter is optional.
  When bprop_variable_filter is used, batches will always contain the
  examples from the same source. Otherwise, examples from different sources
  may be mixed together.

  If `p.file_datasource` is unset but `p.file_pattern` is set, the batch is
  built directly by `self._DataSourceFromFilePattern`. Otherwise it is built
  by `self.datasource.BuildDataSource`.

  Side effects: sets `self._bprop_onehot` and `self._bprop_variable_filters`
  when the data source returns `selected_bprop` / `bprop_variable_filters`.

  Args:
    task_id: Host index for partitioning input shards.

  Returns:
    A `.NestedMap` containing
    - data: a tuple of tf.Tensor or `.NestedMap` of tf.Tensor same as
      `self._DataSourceFromFilePattern()`
    - source_selected: a tensor of size [batch_size, number of data sources]
      or None (defaulted to None when the data source does not provide it).
    - selected_bprop: a tensor of size [number of data sources]; present only
      if the data source provides it.
    - bprop_variable_filters: a list of bprop_variable filters for each
      source; present only if the data source provides it.

  Raises:
    Whatever `self._DataSourceFromFilePattern` or
    `self.datasource.BuildDataSource` raises (e.g. ValueError for invalid
    data-source params). This method itself raises nothing.
  """
  p = self.params
  if not p.file_datasource and p.file_pattern:
    # This is a workaround for subclasses which have defined
    # their own data source-like functionality.
    tf.logging.info(
        'Creating data source-like output from class %s using '
        'file_pattern %s', self, p.file_pattern)
    ret = py_utils.NestedMap()
    ret.data = self._DataSourceFromFilePattern(p.file_pattern, task_id=task_id)
  else:
    tf.logging.info(
        'Building data source %s with params %s and '
        'file_pattern %s', self.datasource, self.datasource.params,
        p.file_pattern)
    ret = self.datasource.BuildDataSource(
        self._DataSourceFromFilePattern, task_id=task_id)
  if 'selected_bprop' in ret:
    self._bprop_onehot = ret.selected_bprop
  if 'bprop_variable_filters' in ret:
    self._bprop_variable_filters = ret.bprop_variable_filters
  # Callers may read `source_selected` unconditionally.
  if 'source_selected' not in ret:
    ret.source_selected = None
  return ret
def FPropWithProjectedInput(self, theta, state0, inputs):
  """FProp with inputs already projected.

  This method is for parallelizing the input projection across time steps to
  accelerate training.

  Schematically (argument packing into NestedMaps omitted; see Args for the
  real structure), the following are equivalent:

  >>> inputs = <a tensor of [T, B, D]>
  >>> paddings = tf.zeros([T, B])
  >>> theta = cell.theta
  >>> state = cell.zero_state(theta, B)

  # a. Use FProp().
  >>> for i in range(T):
  ...  state, _ = cell.FProp(theta, inputs[i, :, :], paddings, state)

  # b. Use FPropWithProjectedInput().
  >>> proj_inputs = cell.ProjectInputSequence(theta, inputs)
  >>> for i in range(T):
  ...  state, _ = cell.FPropWithProjectedInput(
  ...      theta, proj_inputs[i, :, :], paddings, state)

  Args:
    theta: a NestedMap of layer weights. Notably, it's expected to contain
      separate weight tensors for input and hidden state projections, for
      performance reasons, under the key 'wm_i' (input) and 'wm_h' (hidden
      state).
    state0: A NestedMap with the same structure as return value of
      `self.zero_state()`.
    inputs: A NestedMap with the following fields:
      - proj_inputs: A single Tensor of shape [batch, 4 * hidden_dim].
      - padding: A Tensor of shape [batch, 1].
      - reset_mask: A Tensor of shape [batch, 1].

  Returns:
    state1: A NestedMap of the same structure as `state0`.
    extras: Intermediate results to facilitate backprop. A NestedMap (always
      empty here).
  """
  if self.params.reset_cell_state:
    # Deep-copied so state0 is not modified; presumably _ResetState
    # updates its argument in place — confirm.
    state0_modified = self._ResetState(state0.DeepCopy(), inputs)
  else:
    state0_modified = state0
  xmw = self._MixWithProjectedInput(theta, state0_modified,
                                    inputs.proj_inputs)
  # Shallow copy so setting `act` does not mutate the caller's `inputs`.
  gates_input = inputs.copy()
  gates_input.act = [inputs.proj_inputs]
  state1 = self._Gates(xmw, theta, state0_modified, gates_input)
  return state1, py_utils.NestedMap()
def PrepareExternalInputs(self, theta, external_inputs):
  """Returns the prepared external inputs, e.g., packed_src for attention.

  Starts from a deep copy of `external_inputs`. For each child that is a
  `Step` (or a tuple/list of Steps), the entry under the child's name is
  replaced by what the child's PrepareExternalInputs returns. Each child is
  given `external_inputs[name]`, or an empty NestedMap if that is missing.

  Raises:
    ValueError: If a child tuple/list mixes Steps and non-Steps.
  """
  if not external_inputs:
    external_inputs = py_utils.NestedMap()
  packed = external_inputs.DeepCopy()

  for name, child in six.iteritems(self.children):
    child_inputs = external_inputs.get(name, py_utils.NestedMap())
    if isinstance(child, (tuple, list)):
      prepared = [
          sub.PrepareExternalInputs(theta[name][idx], child_inputs)
          for idx, sub in enumerate(child)
          if isinstance(sub, Step)
      ]
      # Lists with no Steps are left alone; partially-Step lists are errors.
      if not prepared:
        continue
      if len(prepared) != len(child):
        raise ValueError('Expecting child list to be instances of Step.')
      packed[name] = type(child)(prepared)
    elif isinstance(child, Step):
      packed[name] = child.PrepareExternalInputs(theta[name], child_inputs)

  return packed
def _common_gpipe_transformer_fprop_meta(p, inputs, *args): """GPipe FPropMeta function.""" # TODO(huangyp): return accurate estimate of flops. py_utils.CheckShapes((inputs, )) flops_per_element = 5 src_time, source_batch, dim = inputs flops = flops_per_element * src_time * src_time * source_batch * dim args = args if isinstance(args, tuple) else (args, ) if not p.has_aux_atten and p.is_transparent: # Transparent Encoder FPropMeta if p.transparent_merger_tpl is not None: args = args[:5] + ( inputs, tshape.Shape([p.transparent_merger_tpl.num_sources])) args = args[:6] + (tshape.Shape([args[6][0] - 1]), ) if p.final_enc_layer: args = args[:5] + (None, None) return py_utils.NestedMap(flops=flops, out_shapes=(inputs, ) + args)
def _DecodeStep():
  """Decode call to be compiled for TPU.

  Instantiates the decode model in eval mode (with opportunistic variable
  reuse), attaches `self._decode_input` as its input, runs Decode on a
  dequeued batch, and outfeeds the flattened metrics.

  Returns:
    A one-element list holding the outfeed enqueue op.
  """
  with py_utils.OpportunisticVariableReuseScope(True):
    with cluster_factory.SetEval(True):
      self._decode_model = self._decode_task_params.Instantiate()
      task = self._decode_model.GetTask()
      self._decode_model_task = task
      task.AddChild('input', self._decode_input)
      batch = task.input_generator.TpuDequeueBatch()
      self.metrics_nm = py_utils.NestedMap(task.Decode(batch))
      # Under SPMD the outfeed is pinned to TPU core 0.
      with tf.device(tpu.core(0) if self.spmd else ''):
        enqueue_op = tpu_ops.outfeed_enqueue_tuple(self.metrics_nm.Flatten())
  return [enqueue_op]
def FPropMeta(cls, p, *args):
  """FPropMeta for a sequential stack of `p.sub` layers.

  The trailing `p.num_act_inputs` args pass through unchanged. The remaining
  args flow through each sub-layer's FPropMeta in turn. The first output
  shape of each layer named in `p.act_fetch_layers` is appended after the
  pass-through args, in `p.act_fetch_layers` order.
  """
  assert len(args) > p.num_act_inputs
  if p.num_act_inputs > 0:
    seq_args = args[:-p.num_act_inputs]
    extra_args = args[-p.num_act_inputs:]
  else:
    seq_args = args
    extra_args = ()

  total_flops = 0
  fetched_shapes = {}
  for sub in p.sub:
    sub_meta = sub.cls.FPropMeta(sub, *seq_args)
    if sub.name in p.act_fetch_layers:
      fetched_shapes[sub.name] = sub_meta.out_shapes[0]
    total_flops += sub_meta.flops
    seq_args = sub_meta.out_shapes

  extra_args += tuple(fetched_shapes[name] for name in p.act_fetch_layers)
  return py_utils.NestedMap(
      flops=total_flops, out_shapes=seq_args + extra_args)
def ZeroState(self, theta, prepared_inputs, batch_size):
  """Creates a zero state NestedMap for this step.

  Args:
    theta: variables used by sub-steps.
    prepared_inputs: Output from a call to PrepareExternalInputs.
    batch_size: The number of items in the batch that FProp will process.

  Returns:
    A NestedMap of ZeroState results for each sub-step, keyed by the
    sub-step's name. The results are built under the `self.params.name`
    name scope.
  """
  with tf.name_scope(self.params.name):
    zero_states = {
        seq.name: seq.step.ZeroState(theta[seq.name],
                                     prepared_inputs[seq.name], batch_size)
        for seq in self._seq
    }
  return py_utils.NestedMap(zero_states)
def BuildDataSource(self, data_source_from_file_pattern_fn, task_id=None):
  """Read and return input batch from p.file_patterns list weighted by p.weights.

  Examples in the batch will be mixed together from different file_pattern
  source proportionally to the weights.

  Args:
    data_source_from_file_pattern_fn: a function that takes a comma-joined
      file_pattern plus an `input_source_weights` keyword argument (and a
      `task_id` keyword argument when `task_id` is not None) and returns an
      input batch.
    task_id: Optional host index for partitioning input shards. It is only
      forwarded when not None, so functions that don't accept `task_id` keep
      working. Callers such as `_BuildDataSourceWithMetadata` pass it.

  Returns:
    A NestedMap containing:
      data: a tuple of tf.Tensor or `.NestedMap` of tf.Tensor
      bprop_variable_filters: a list of '' (no filter), one per file pattern.

  Raises:
    ValueError: If p.file_patterns or p.weights is not a list, their lengths
      differ, an element of p.file_patterns is not a string, or a file
      pattern contains a comma.
  """
  p = self.params
  if not isinstance(p.file_patterns, list):
    raise ValueError('Expected a list, got %s' % (p.file_patterns, ))
  if not isinstance(p.weights, list):
    raise ValueError('Expected a list, got %s' % (p.weights, ))
  if len(p.file_patterns) != len(p.weights):
    raise ValueError(
        'Expected p.file_patterns and p.weights to be the same length. '
        'Found %d file_patterns, and %d weights' %
        (len(p.file_patterns), len(p.weights)))
  # TODO(rosenberg) confirm that weights are numeric
  if not all(isinstance(x, six.string_types) for x in p.file_patterns):
    raise ValueError(
        'Expected all elements of p.file_patterns to be strings')

  file_patterns = p.file_patterns
  weights = p.weights
  # Patterns are joined with ',' below, so a comma inside one would be
  # ambiguous.
  for file_pattern in file_patterns:
    if ',' in file_pattern:
      raise ValueError(
          'Can not use commas in file_pattern when within-batch '
          'mixing is used. file_pattern: %s' % (file_pattern, ))

  fn_kwargs = {'input_source_weights': weights}
  if task_id is not None:
    fn_kwargs['task_id'] = task_id
  ret = py_utils.NestedMap()
  ret.data = data_source_from_file_pattern_fn(','.join(file_patterns),
                                              **fn_kwargs)
  ret.bprop_variable_filters = [''] * len(file_patterns)
  return ret
def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
  """Returns the slice of `prepared_inputs` at time step `state0.t`.

  Args:
    theta: unused.
    prepared_inputs: Output from a call to PrepareExternalInputs. Every
      tensor in it is sliced along axis `self.params.axis`.
    step_inputs: unused.
    padding: unused.
    state0: A NestedMap whose field `t` is the current time index, produced
      by either ZeroState or a previous invocation of this FProp step.

  Returns:
    (output, state1), both of which are NestedMaps. output has the structure
    of `prepared_inputs`, with each tensor replaced by its slice at time
    `state0.t` and the time axis removed. state1 is
    `NestedMap(t=state0.t + 1)`.
  """
  del theta
  del step_inputs
  del padding

  def _Slice(tensor):
    """Return a slice of this tensor at time=state0.t."""
    shape = py_utils.GetShape(tensor)

    # All zeros except for t in the time dimension.
    # e.g. if params.axis=1, begin is [0, t, 0, 0, 0, ...]
    # NOTE(review): assumes state0.t is int32 to match the int32 `size`
    # below — confirm.
    begin = tf.one_hot(self.params.axis, tf.rank(tensor), on_value=state0.t)

    # Same as shape, but with a 1 in the time dimension.
    # e.g. if params.axis=1, shape is [shape[0], 1, shape[2], shape[3], ...]
    size = tf.concat([
        shape[0:self.params.axis],
        tf.constant([1], dtype=tf.int32), shape[self.params.axis + 1:]
    ], axis=0)

    # Make a slice where the time dimension is fixed at state0.t.
    time_slice = tf.slice(tensor, begin, size)

    # Remove the time dimension.
    return tf.squeeze(time_slice, axis=self.params.axis)

  output = prepared_inputs.Transform(_Slice)
  state1 = py_utils.NestedMap(t=state0.t + 1)
  return output, state1
def PrepareExternalInputs(self, theta, external_inputs):
  """Delegates external inputs preparation to sub-layers.

  Every sub-step receives the same `external_inputs` and its own
  `theta.sub[i]`.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    external_inputs: A `.NestedMap` object. The structure of the internal
      fields is defined by the sub-steps.

  Returns:
    A `.NestedMap` whose `sub` field is a list holding each sub-step's
    pre-processed version of the external_inputs, in sub-step order.
  """
  prepared = [
      step.PrepareExternalInputs(theta.sub[idx], external_inputs)
      for idx, step in enumerate(self.sub)
  ]
  return py_utils.NestedMap(sub=prepared)
def ZeroState(self, theta, prepared_inputs, batch_size):
  """Computes a zero state for each sub-step.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    prepared_inputs: An output from PrepareExternalInputs.
    batch_size: The number of items in the batch that FProp will process.

  Returns:
    A `.NestedMap` whose `sub` field lists each sub-step's state0 object, in
    sub-step order.
  """
  # NOTE(review): every sub-step receives the whole `prepared_inputs`, not
  # `prepared_inputs.sub[i]`. If this class's PrepareExternalInputs packs
  # per-sub-step inputs under `sub`, this looks like a mismatch — confirm.
  zero_states = [
      step.ZeroState(theta.sub[idx], prepared_inputs, batch_size)
      for idx, step in enumerate(self.sub)
  ]
  return py_utils.NestedMap(sub=zero_states)
def _CellFn(unused_theta, state0, theta_i):
  """Recurrent cell function wrapper of body.FProp.

  Unpacks the positional FProp arguments from `state0` and runs
  `self.body.FProp` with the per-step weights `theta_i`. The outputs are
  packed back into the next state.

  Returns:
    (state1, empty NestedMap).
  """
  fprop_inputs = _StateToArgs(state0)

  # Give each per-step weight the static shape of its stacked counterpart,
  # minus the leading dimension.
  for step_var, stacked_var in zip(theta_i.Flatten(), theta_stack.Flatten()):
    if stacked_var is None:
      continue
    step_var.set_shape(tf.TensorShape(stacked_var.shape.as_list()[1:]))

  fprop_outputs = _ToTuple(self.body.FProp(theta_i, *fprop_inputs))
  # Outputs become the next repetition's inputs, so the arity must match.
  assert len(fprop_outputs) == len(fprop_inputs)

  return _ArgsToState(fprop_outputs), py_utils.NestedMap()
def FPropMeta(cls, p, inputs, *args):
  """FPropMeta for an embedding layer.

  The 2-D `inputs` shape (d0, d1) becomes (d0, d1, embedding_dim). When
  `p.add_tgt_embedding_layer` is set, `args[1]` is expanded the same way.
  Positional output slots 5 and 6 are replaced with None. Slots from index 7
  onward are kept only when `p.ret_task_ids` is set.
  """
  # TODO(ankurbpn): return accurate estimate of flops.
  py_utils.CheckShapes((inputs, ))
  flops_per_element = 2  # Is this correct?
  vocab_size = p.token_emb.vocab_size
  emb_dim = p.token_emb.embedding_dim
  dim0, dim1 = inputs
  flops = flops_per_element * dim0 * dim1 * emb_dim * vocab_size

  if not isinstance(args, tuple):
    args = (args, )
  embedded_inputs = tshape.Shape([dim0, dim1, emb_dim])
  out_args = list(args)
  if p.add_tgt_embedding_layer:
    tgt_dim0, tgt_dim1 = args[1]
    out_args[1] = tshape.Shape([tgt_dim0, tgt_dim1, emb_dim])

  tail = out_args[7:] if p.ret_task_ids else []
  out_args = out_args[:5] + [None, None] + tail
  return py_utils.NestedMap(
      flops=flops, out_shapes=(embedded_inputs, ) + tuple(out_args))
def ZeroState(self, theta, prepared_inputs, batch_size):
  """Produce a zero state for this step.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    prepared_inputs: A set of inputs pre-processed by using
      PrepareExternalInputs.
    batch_size: Number of elements in the batched input.

  Returns:
    state0, a NestedMap with `query_state` (from the query generator) and
    `atten_state` (from the attention step), to pass to FProp on its first
    invocation.
  """
  # Keyword arguments are evaluated left to right, so the query generator's
  # zero state is still built before the attention one.
  return py_utils.NestedMap(
      query_state=self.query_generator.ZeroState(
          theta.query_generator, prepared_inputs.query_generator, batch_size),
      atten_state=self.attention.ZeroState(theta.attention,
                                           prepared_inputs.attention,
                                           batch_size))
def FProp(self, theta, *args):
  """Runs the graph of sub-layers, wiring tensors by endpoint name.

  Each positional arg is stored under the matching `p.input_endpoints` name.
  Each sub-layer in `self._seq` then runs on the tensors named by its
  signature's inputs, and its outputs are stored under its signature's
  output names. The tensors named by `p.output_endpoints` are returned.

  Args:
    theta: A NestedMap of weights, indexed by sub-layer name.
    *args: tf.Tensor or NestedMap-of-tf.Tensor inputs, one per input
      endpoint.

  Returns:
    The single output tensor when there is exactly one output endpoint,
    otherwise a tuple of output tensors.

  Raises:
    ValueError: If the number of args differs from `p.input_endpoints`.
  """
  p = self.params
  self._fprop = GraphTensors()
  tensors = self._fprop
  with tf.name_scope(p.name):
    if len(args) != len(p.input_endpoints):
      raise ValueError(
          'Wrong number of inputs for {}: required={}, provided={}'.format(
              p.name, len(p.input_endpoints), len(args)))
    for endpoint, value in zip(p.input_endpoints, args):
      if not isinstance(value, py_utils.NestedMap):
        assert isinstance(value, tf.Tensor)
      else:
        assert all(isinstance(x, tf.Tensor) for x in value.Flatten()), value
      tensors.StoreTensor(endpoint, value)

    for idx, (name, sig, child) in enumerate(self._seq):
      child_theta = theta[name]
      child_args = py_utils.NestedMap(inputs=sig.inputs).Transform(
          tensors.GetTensor).inputs
      tf.logging.vlog(1, 'signature: %s', p.sub[idx][0])
      tf.logging.vlog(1, 'GraphLayer: call %s %s %d %s', child.params.name,
                      child, len(child_args), str(child_args))
      results = child.FProp(child_theta, *child_args)
      # A single declared output is wrapped so it can be zipped like a tuple.
      if len(sig.outputs) == 1:
        results = (results,)
      assert len(sig.outputs) == len(results)
      for endpoint, value in zip(sig.outputs, results):
        tensors.StoreTensor(endpoint, value)

    outputs = tuple(tensors.GetTensor(x) for x in p.output_endpoints)
    return outputs[0] if len(outputs) == 1 else outputs
def ZeroState(self, theta, prepared_inputs, batch_size):
  """Produce a zero state for this step.

  Runs one attention computation with an all-zeros query, starting from the
  attention module's zero state. It returns the resulting context and
  attention state.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    prepared_inputs: A set of inputs pre-processed by using
      PrepareExternalInputs; `src` and `packed_src` are read.
    batch_size: Number of elements in the batched input.

  Returns:
    state0, a NestedMap with `atten_context` and `atten_state`, to pass to
    FProp on its first invocation.
  """
  p = self.params
  # `src` is asserted rank 3; its first dimension is the max sequence length.
  src_len = py_utils.GetShape(prepared_inputs.src, 3)[0]
  initial_state = self.atten.ZeroAttentionState(src_len, batch_size)
  zero_query = tf.zeros([batch_size, p.atten.query_dim],
                        dtype=py_utils.FPropDtype(p))
  context, _, atten_state = self.atten.ComputeContextVectorWithSource(
      theta.atten,
      prepared_inputs.packed_src,
      zero_query,
      attention_state=initial_state)
  return py_utils.NestedMap(atten_context=context, atten_state=atten_state)
def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
  """Perform inference on a stateless layer.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    prepared_inputs: unused.
    step_inputs: A NestedMap containing 'inputs', which are passed directly to
      the layer.
    padding: A 0/1 float tensor of shape [batch_size]; 1.0 means that this
      batch element is empty in this step. Forwarded as the `padding` keyword
      only when not None.
    state0: unused.

  Returns:
    (output, state1), where output is the output of the layer, and state1 is
    an empty NestedMap.
  """
  del state0, prepared_inputs
  layer_kwargs = {} if padding is None else {'padding': padding}
  output = self.layer.FProp(theta.layer, step_inputs.inputs, **layer_kwargs)
  return output, py_utils.NestedMap()
def FPropMeta(cls, p, inputs, *args):
  """FPropMeta producing a [d1, d2, p.num_classes] logits shape.

  (d1, d2) are the first two dims of `args[1]` when `p.inputs_from_decoder`
  is set, otherwise of `inputs`. Flops is the fixed constant 100, whatever
  the shapes.
  """
  source_shape = args[1] if p.inputs_from_decoder else inputs
  dim1, dim2 = source_shape[:2]
  logits_shape = tshape.Shape([dim1, dim2, p.num_classes])
  return py_utils.NestedMap(flops=100, out_shapes=(logits_shape, ))
def FProp(self,
          theta,
          x,
          x_paddings=None,
          eos_id=1,
          force_sample_last_token=True):
  """Applies SymbolInsertionLayer.

  We take in a `x`, which represents the groundtruth sequence (i.e., English
  sequence). We return a sampled rollin (observed) canvas (i.e., random subset
  of the English sequence), as well as the target (indices) for an
  insertion-based model (i.e., the targets given the random observed subset).

  Args:
    theta: Ignored, this can be None.
    x: The symbol ids of shape `[batch_size, time_dim]`.
    x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where 0
      is valid and 1 is invalid. Defaults to all zeros (float32) when None.
    eos_id: The <eos> token id to represent end-of-slot.
    force_sample_last_token: Set True to force sample the last token of `x`.

  Returns:
    A `NestedMap`.
      - canvas: The canvas (based off of the `rollin_policy`) of shape
        [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need not be
        equal.
      - canvas_indices: The canvas indices (into `x`).
      - canvas_paddings: The paddings of `canvas_indices`.
      - target_indices: The target indices of shape [num_targets, 3].
        `num_targets` is the number of total targets in the entire batch.
        [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
        captures the token. Each row [batch, slot, vocab] represents the
        indices of the target -- i.e., the batch, slot and vocab combination
        of the target. Typical usage of these indices is to tf.gather_nd
        the log-probs (from the softmax layer).
      - target_weights: The target weights.

  Raises:
    ValueError: If the (resolved) rollin policy or the oracle policy is not
      'uniform'.
  """
  p = self.params

  batch_size = py_utils.GetShape(x)[0]
  time_dim = py_utils.GetShape(x)[1]

  if x_paddings is None:
    x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

  oracle_policy = p.oracle_policy
  # A rollin policy of 'oracle' means "use the oracle policy".
  rollin_policy = (
      oracle_policy if p.rollin_policy == 'oracle' else p.rollin_policy)

  if rollin_policy != 'uniform':
    raise ValueError('Unknown or unsupported rollin policy: %s' %
                     rollin_policy)
  if oracle_policy != 'uniform':
    raise ValueError('Unknown or unsupported oracle policy: %s' %
                     oracle_policy)

  # Number of valid (non-padded) tokens per example.
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

  # Compute the desired length per example in the batch.
  ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
  if force_sample_last_token:
    # Canvas length in [1, x_len]: at least the forced last token.
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32), x_len - 1) + 1
  else:
    # Canvas length in [0, x_len].
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32), x_len)
  # Compute the maximum length across the batch.
  c_len_max = tf.reduce_max(c_len)

  # Grab subset of random valid indices per example.
  # Padded positions (>= x_len) get a -1e9 logit so they sort last.
  z_logits = tf.cast(
      tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
      tf.float32) * -1e9
  if force_sample_last_token:
    # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
    # accomplish this by add +LARGE_NUMBER to the logits.
    z_logits += tf.cast(
        tf.equal(
            tf.expand_dims(tf.range(time_dim), 0),
            tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
  # Gumbel-max trick to sample (we only sample valid positions per sample in
  # the batch).
  z = -tf.math.log(-tf.math.log(
      tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
  # Sorting all positions by perturbed logit gives a random permutation with
  # valid positions first.
  unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

  # Trim everything > c_len_max.
  c_indices = c_indices[:, :c_len_max]

  # Invalidate any indices >= c_len, we use the last index as the default
  # invalid index.
  c_indices = tf.where(
      tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
      c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

  # Materialize the canvas.
  # Sorting restores the original token order among the sampled positions.
  c_indices = tf.sort(c_indices)
  c = tf.gather_nd(
      x,
      tf.stack([
          tf.reshape(
              tf.tile(
                  tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
              [-1]),
          tf.reshape(c_indices, [-1])
      ], 1))
  c = tf.reshape(c, [batch_size, c_len_max])

  # Compute the paddings.
  c_paddings = 1 - tf.sequence_mask(c_len, c_len_max, dtype=x_paddings.dtype)
  # Zero out canvas entries that are padding.
  c *= tf.cast(1 - c_paddings, tf.int32)

  # (batch, position) pairs for every canvas entry, used to mark which
  # tokens of `x` were observed.
  indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
          [batch_size * c_len_max, 1]),
      tf.reshape(c_indices, [batch_size * c_len_max, 1])
  ], 1)
  # NOTE(review): padded canvas entries all point at time_dim - 1, and
  # tf.scatter_nd sums duplicate indices, so that position can get a count
  # > 1 — confirm that this is intended.
  x_token_is_observed = tf.scatter_nd(
      indices, tf.ones([batch_size * c_len_max], tf.int32),
      py_utils.GetShape(x))
  # `x_segments` captures which slot each `x` belongs to (both observed and
  # tokens that need to be observed).
  x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

  x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
  # Shift right by one; position 0 treats its "previous" token as observed.
  prev_x_token_is_observed = tf.pad(
      x_token_is_observed[:, :-1], [[0, 0], [1, 0]], constant_values=True)
  x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
  prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
  x_is_valid = tf.cast(1 - x_paddings, tf.bool)
  x_is_valid = tf.reshape(x_is_valid, [-1])

  # Remap all the observed to <eos>, note some of these need a zero weight
  # (or else there would be <eos> and valid token in the same slot).
  target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
  target_indices = tf.where(
      x_token_is_observed,
      tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

  # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
  # we may want to weigh this term by the original sequence length.
  target_weights = tf.ones_like(target_indices, tf.float32)

  # We need to set all the weights for <eos> which actually have valid tokens
  # in the slot to zero.
  target_weights = tf.where(
      x_token_is_observed & ~prev_x_token_is_observed,
      tf.zeros_like(target_weights), target_weights)

  # TODO(williamchan): Consider dropping the entries w/ weight zero.

  # Add the batch and slot indices.
  target_indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, time_dim]),
          [batch_size * time_dim, 1]),
      tf.reshape(x_segments, [-1, 1]), target_indices
  ], 1)

  # Select only the valid indices. The selected valid ones include slots w/
  # <eos>.
  target_indices = target_indices[x_is_valid]
  target_weights = target_weights[x_is_valid]

  return py_utils.NestedMap(
      canvas=c,
      canvas_indices=c_indices,
      canvas_paddings=c_paddings,
      target_indices=target_indices,
      target_weights=target_weights)
def __init__(self):
  """Starts with no named tensors registered."""
  # Backing store for named tensors (presumably filled by StoreTensor and
  # read by GetTensor — confirm).
  named_tensors = py_utils.NestedMap()
  self._named_tensors = named_tensors
def FPropMeta(cls, p, *args):
  """FPropMeta for `p.repeat` sequential applications of `p.body`.

  Flops are the body's flops times `p.repeat`. The output shapes are the
  input shapes. NOTE(review): this assumes the body preserves shapes; the
  body's out_shapes are only validated, not compared with `args`.
  """
  py_utils.CheckShapes(args)
  body_meta = p.body.cls.FPropMeta(p.body, *args)
  py_utils.CheckShapes(body_meta.out_shapes)
  repeated_flops = body_meta.flops * p.repeat
  return py_utils.NestedMap(flops=repeated_flops, out_shapes=args)
def FPropMeta(cls, p, *args):
  """Zero-cost FPropMeta: validates the shapes and returns them unchanged."""
  del p  # Not needed for a shape-preserving, zero-flop layer.
  py_utils.CheckShapes(args)
  return py_utils.NestedMap(flops=0, out_shapes=args)