def _connect_ops(self, info):
    """Wire up the ops that were previously copied for `info.sgv`.

    For every op of the subgraph view, its copy in `info.transformed_ops`
    receives the transformed data inputs, gets its `_original_op` restored
    when `info.transform_original_op_handler` can map it, and finally gets
    the transformed (non-None) control inputs attached.

    Args:
      info: the transformation state object (provides `sgv`,
        `transformed_ops` and `transform_original_op_handler`).

    Raises:
      ValueError: if a copied op already has inputs.
    """
    # pylint: disable=protected-access
    for src_op in info.sgv.ops:
        logging.debug("Finalizing op: %s", src_op.name)
        dst_op = info.transformed_ops[src_op]

        if dst_op.inputs:
            raise ValueError("The newly transformed op should not have "
                             "any inputs yet: {}".format(dst_op.name))
        # Resolve every input before attaching any of them.
        new_inputs = [self._transformed_t(info, src_t) for src_t in src_op.inputs]
        for new_t in new_inputs:
            dst_op._add_input(new_t)

        # Restore the link to the original op when it can be mapped.
        if src_op._original_op:
            mapped = info.transform_original_op_handler(info, src_op._original_op)
            if mapped is None:
                logging.debug("Could not find original op for: %s", dst_op.name)
            else:
                dst_op._original_op = mapped

        # Attach only the control inputs the handler could transform.
        mapped_ctrl = [self.transform_control_input_handler(info, ci)
                       for ci in src_op.control_inputs]
        reroute.add_control_inputs(dst_op,
                                   [ci for ci in mapped_ctrl if ci is not None])
def transform(self, feature_column):
    """Return the Tensor for `feature_column`, transforming it on first use.

    Results live in `self._columns_to_tensors`. The column fills that cache
    itself through `insert_transformed_feature`, so later calls reuse it.

    Args:
      feature_column: An instance of FeatureColumn.

    Returns:
      The Tensor representing `feature_column`; it may be newly created or
      an already cached one.

    Raises:
      ValueError: if the column did not add itself to the cache.
    """
    logging.debug('Transforming feature_column %s', feature_column)
    cache = self._columns_to_tensors
    if feature_column not in cache:
        feature_column.insert_transformed_feature(cache)
        if feature_column not in cache:
            raise ValueError('Column {} is not supported.'.format(
                feature_column.name))
    return cache[feature_column]
def _transformed_t(self, info, t, consumer_op):
    """Return the transformed tensor of `t`.

    Resolution order:
      1. `t` already transformed: return the entry in `info.transformed_ts`.
      2. `t` is an input of the subgraph: delegate to
         `transform_external_input_handler`.
      3. `t` is produced inside the subgraph but not transformed yet (it
         belongs to a graph cycle): return a temporary stand-in tensor and
         record `(t, stand_in, consumer_op)` in `info.tmp_cyclic_ts` so it can
         be rewired later.
      4. Otherwise `t` is a hidden input: delegate to
         `transform_external_hidden_input_handler`.

    Args:
      info: the transformation state object.
      t: the tensor whose transformed counterpart is requested.
      consumer_op: the op that consumes `t`; used to pick a stand-in for
        cyclic tensors and recorded alongside it.

    Returns:
      The transformed tensor, or a temporary stand-in for a cyclic tensor.
    """
    if t in info.transformed_ts:
        # If op is in the subgraph, just return its transformed counterpart.
        return info.transformed_ts[t]

    if t in info.sgv_inputs_set:
        # `t` is an input of the subgraph.
        return self.transform_external_input_handler(info, t)
    elif t.op in info.ops:
        # `t` is an internal tensor but is not transformed yet because it
        # belongs to a graph cycle.
        logging.debug("Cyclic tensor: t.name = %s", t.name)
        # Try to find an existing tensor we can use for now,
        # otherwise create one. We'll rewire this later.
        if consumer_op.type == "Merge":
            # NOTE(review): stand-in is the transform of the Merge's first
            # input; presumably valid because Merge forwards whichever input
            # is available — confirm.
            first_input = consumer_op.inputs[0]
            tmp_t_ = self._transformed_t(info, first_input, consumer_op)
        elif t.op.type == "Enter":
            # Use the transform of the Enter op's input as the stand-in.
            enter_input = t.op.inputs[0]
            tmp_t_ = self._transformed_t(info, enter_input, consumer_op)
        else:
            # No reusable tensor: create a placeholder in the target graph.
            with info.graph_.as_default():
                tmp_t_ = util.make_placeholder_from_tensor(t, scope=info.scope_,
                                                           prefix="geph_tmp")
            logging.debug("Created temporary placeholder: %s.", tmp_t_.name)
        # Register as temporary and return.
        info.tmp_cyclic_ts.append((t, tmp_t_, consumer_op))
        return tmp_t_
    else:
        # `t` is a hidden input of the subgraph.
        return self.transform_external_hidden_input_handler(info, t)
def evaluate_and_export(self):
    """Evaluate the newest checkpoint and hand the result to the exporters.

    Returns:
      The evaluation result dict, or `None` when no checkpoint exists yet or
      the newest checkpoint was already evaluated.

    Raises:
      RuntimeError: if `Estimator.evaluate` returns an empty result or one
        without a global step.
      TypeError: if the evaluation result is not a dict.
    """
    ckpt = self._estimator.latest_checkpoint()
    if not ckpt:
        self._log_err_msg('Estimator is not trained yet. Will start an '
                          'evaluation when a checkpoint is ready.')
        return None
    if ckpt == self._previous_ckpt_path:
        self._log_err_msg(
            'No new checkpoint ready for evaluation. Skip the current '
            'evaluation pass as evaluation results are expected to be same '
            'for the same checkpoint.')
        return None

    spec = self._eval_spec
    result = self._estimator.evaluate(
        input_fn=spec.input_fn,
        steps=spec.steps,
        name=spec.name,
        checkpoint_path=ckpt,
        hooks=spec.hooks)

    # Sanity checks on what the estimator returned.
    if not result:
        raise RuntimeError(
            'Internal error: `Estimator.evaluate` should never return empty '
            'result.')
    if not isinstance(result, dict):
        raise TypeError(
            '`Estimator.evaluate` should return dict. Given {}.'.format(
                type(result)))
    if ops.GraphKeys.GLOBAL_STEP not in result:
        raise RuntimeError(
            'Internal error: `Estimator.evaluate` result should have '
            '`global_step` in result. Given {}'.format(result))

    # The final export happens once training reached its step budget.
    if self._max_training_steps:
        is_the_final_export = (
            result[ops.GraphKeys.GLOBAL_STEP] >= self._max_training_steps)
    else:
        is_the_final_export = False

    self._export_eval_result(result, ckpt, is_the_final_export)

    if is_the_final_export:
        logging.debug('Calling exporter with the `is_the_final_export=True`.')
        self._is_final_export_triggered = True

    self._last_warning_time = 0
    self._previous_ckpt_path = ckpt
    return result
def _SetPath(self, path):
    """Sets the current path to watch for new events.

    Before switching, the size of the previously watched path (if any, and
    if it is not a GCS path) is recorded in `self._finalized_sizes`. If that
    size cannot be read, the error is logged and the switch still happens.

    Args:
      path: The full path of the file to watch.
    """
    old_path = self._path
    if old_path and not gcs.IsGCSPath(old_path):
        try:
            # We're done with the path, so store its size.
            size = io_wrapper.Size(old_path)
            logging.debug('Setting latest size of %s to %d', old_path, size)
            self._finalized_sizes[old_path] = size
        except (IOError, OSError) as e:
            # An unreadable or vanished old file must not prevent us from
            # moving on to the new path.
            logging.error('Unable to get size of %s: %s', old_path, e)
    self._path = path
    self._loader = self._loader_factory(path)
def __init__(self, file_path):
    """Open a record reader on `file_path`.

    The path is first passed through `resource_loader.readahead_file_path`,
    and that resolved path is what gets opened and remembered.

    Args:
      file_path: path of the record file to read.

    Raises:
      ValueError: if `file_path` is None.
      IOError: if the underlying reader could not be created.
    """
    if file_path is None:
        raise ValueError('A file path is required')
    resolved_path = resource_loader.readahead_file_path(file_path)
    logging.debug('Opening a record reader pointing at %s', resolved_path)
    self._reader = pywrap_tensorflow.PyRecordReader_New(
        compat.as_bytes(resolved_path), 0)
    # Kept so that later log messages can name the file.
    self._file_path = resolved_path
    if not self._reader:
        raise IOError('Failed to open a record reader pointing to %s' %
                      resolved_path)
def run(self):
    """Heartbeat loop: keep pinging the workers while `self._running` is set.

    No logs are fetched and no timing is adjusted; the loop only keeps the
    watchdog alive by pinging and then sleeping `self.ping_interval` seconds.
    """
    while True:
        if not self._running:
            return
        try:
            self._worker_manager.ping(request=None)
            time.sleep(self.ping_interval)
        except errors.OpError as op_error:
            # A TF error most likely means the session is broken: log it,
            # rebuild the manager and keep sending heartbeats.
            logging.debug('Caught error while sending heartbeat: %s', op_error)
            self._reset_manager()
def _input_thread_fn_for_loading(self, session, enqueue_ops, iterations):
    """Infeed thread body driven by `self._signal_queue`.

    Each signal other than `_SIGNAL.STOP` runs `enqueue_ops` `iterations`
    times in `session`; `_SIGNAL.STOP` ends the thread.
    """
    round_index = 0
    while True:
        if self._signal_queue.get() == _SIGNAL.STOP:
            logging.info('Stop Infeed input thread.')
            return
        for step in range(iterations):
            logging.debug('InfeedEnqueue data for iteration (%d, %d)',
                          round_index, step)
            session.run(enqueue_ops)
        round_index += 1
def Load(self):
    """Yield the events appended to the GCS object since the last read.

    The unread tail (starting at `self._gcs_offset`) is copied into a
    temporary file, records are read from it, and once the reader is
    exhausted `self._gcs_offset` is advanced by `reader.offset()`.
    """
    # The temp file only holds contents we haven't seen yet.
    with tempfile.NamedTemporaryFile(prefix='tf-gcs-') as tmp:
        tmp_name = tmp.name
        logging.debug('Temp file created at %s', tmp_name)
        gcs.CopyContents(self._gcs_path, self._gcs_offset, tmp)
        record_reader = pywrap_tensorflow.PyRecordReader_New(
            compat.as_bytes(tmp_name), 0)
        while record_reader.GetNext():
            parsed = event_pb2.Event()
            parsed.ParseFromString(record_reader.record())
            yield parsed
        logging.debug('No more events in %s', tmp_name)
        self._gcs_offset += record_reader.offset()
def every_n_step_end(self, step, outputs):
    """Evaluate the newest checkpoint and apply the early-stopping rule.

    Evaluation is skipped when no checkpoint exists yet or when the newest
    checkpoint is the one evaluated last time.

    Returns:
      True if early stopping was triggered on this call, False otherwise.

    Raises:
      ValueError: if `set_estimator` was never called, or the early-stopping
        metric is missing from the evaluation outputs.
    """
    super(ValidationMonitor, self).every_n_step_end(step, outputs)
    # TODO(mdan): The use of step below is probably misleading.
    # The code should probably use the step from the checkpoint, because
    # that's what is being evaluated.
    if self._estimator is None:
        raise ValueError("Missing call to set_estimator.")

    # Never evaluate the same checkpoint twice.
    latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
    if latest_path is None:
        logging.debug("Skipping evaluation since model has not been saved yet "
                      "at step %d.", step)
        return False
    if latest_path == self._latest_path:
        logging.debug("Skipping evaluation due to same checkpoint %s for step %d "
                      "as for step %d.", latest_path, step,
                      self._latest_path_step)
        return False
    self._latest_path, self._latest_path_step = latest_path, step

    # Run evaluation and log a one-line summary.
    results = self._estimator.evaluate(
        x=self.x, y=self.y, input_fn=self.input_fn,
        batch_size=self.batch_size, steps=self.eval_steps,
        metrics=self.metrics, hooks=self.hooks, name=self.name)
    summary = ", ".join("%s = %s" % (key, str(results[key])) for key in results)
    logging.info("Validation (step %d): %s", step, summary)

    if self.early_stopping_rounds is None:
        return False

    metric = self.early_stopping_metric
    if metric not in results:
        raise ValueError("Metric %s missing from outputs %s." % (
            metric, set(results.keys())))
    current_value = results[metric]
    if self._best_value is None:
        improved = True
    elif self.early_stopping_metric_minimize:
        improved = current_value < self._best_value
    else:
        improved = current_value > self._best_value
    if improved:
        self._best_value = current_value
        self._best_value_step = step

    if step - self._best_value_step < self.early_stopping_rounds:
        return False
    logging.info("Stopping. Best step: {} with {} = {}."
                 .format(self._best_value_step, metric, self._best_value))
    self._early_stopped = True
    return True
def Load(self):
    """Yield every event written to disk that has not been yielded yet.

    Calling Load repeatedly does not 'drop' events as long as the returned
    generator is not iterated over.

    Yields:
      Newly read `event_pb2.Event` protos.
    """
    while True:
        if not self._reader.GetNext():
            break
        proto = event_pb2.Event()
        proto.ParseFromString(self._reader.record())
        yield proto
    logging.debug('No more events in %s', self._file_path)
def copy_op_handler(info, op, copy_shape=True):
    """Create a copy of `op` in `info.graph_`.

    Args:
      info: Transform._TmpInfo instance.
      op: the `tf.Operation` to be copied.
      copy_shape: if True, also copy the static shape of each output tensor.

    Returns:
      A `(op, op_outputs)` tuple containing the transformed op and its outputs.
    """
    # pylint: disable=protected-access
    # Clone the NodeDef and give it a name that is unique in the new graph.
    new_node_def = deepcopy(op._node_def)
    new_node_def.name = info.graph_.unique_name(info.new_name(op.name))

    # Input/output dtypes needed to initialise the new Operation.
    out_types = op._output_types[:]
    in_types = op._input_types[:]

    # The OpDef is shared by every op of the same type, so copy it too.
    new_op_def = deepcopy(op._op_def)

    new_op = tf_ops.Operation(new_node_def, info.graph_, [], out_types,
                              [], in_types, None, new_op_def)

    if copy_shape:
        for src_t, dst_t in zip(op.outputs, new_op.outputs):
            dst_t.set_shape(src_t.get_shape())

    # Point the copy's _original_op at the transformed original, if any.
    if op._original_op:
        original_op = info.transform_original_op_handler(info, op._original_op)
        if original_op is None:
            logging.debug("Could not find original op of: %s", new_op.name)
        else:
            new_op._original_op = original_op

    info.graph_._add_op(new_op)
    return new_op, new_op.outputs
def evaluate_and_export(self):
    """Evaluate the newest checkpoint and (maybe) export the model.

    Returns:
      A tuple of `EvalResult` instance and the export results. When there is
      no checkpoint yet, or the newest one was already evaluated, the
      `EvalResult` only carries the corresponding status and the export
      results are `[]`.

    Raises:
      RuntimeError: for any unexpected internal error.
      TypeError: if evaluation result has wrong type.
    """
    ckpt = self._estimator.latest_checkpoint()
    if not ckpt:
        self._log_err_msg('Estimator is not trained yet. Will start an '
                          'evaluation when a checkpoint is ready.')
        return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT), []
    if ckpt == self._previous_ckpt_path:
        self._log_err_msg(
            'No new checkpoint ready for evaluation. Skip the current '
            'evaluation pass as evaluation results are expected to be same '
            'for the same checkpoint.')
        return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT), []

    spec = self._eval_spec
    metrics = self._estimator.evaluate(
        input_fn=spec.input_fn,
        steps=spec.steps,
        name=spec.name,
        checkpoint_path=ckpt,
        hooks=spec.hooks)

    # _EvalResult validates the metrics.
    eval_result = _EvalResult(
        status=_EvalStatus.EVALUATED,
        metrics=metrics,
        checkpoint_path=ckpt)

    # The final export happens once training reached its step budget.
    if self._max_training_steps:
        is_the_final_export = (eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]
                               >= self._max_training_steps)
    else:
        is_the_final_export = False

    export_results = self._export_eval_result(eval_result, is_the_final_export)

    if is_the_final_export:
        logging.debug('Calling exporter with the `is_the_final_export=True`.')
        self._is_final_export_triggered = True

    self._last_warning_time = 0
    self._previous_ckpt_path = ckpt
    return eval_result, export_results
def ping(self, request=None, timeout_in_ms=5000):
    """Ping all workers and return their parsed heartbeat responses.

    Args:
      request: heartbeat request proto to send; a default
        `WorkerHeartbeatRequest` is used when None.
      timeout_in_ms: timeout passed to the session run via `RunOptions`.

    Returns:
      A list of `WorkerHeartbeatResponse` protos, one per fetched result.
    """
    heartbeat = event_pb2.WorkerHeartbeatRequest() if request is None else request
    run_options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms)
    feed = {self._request_placeholder: heartbeat.SerializeToString()}
    raw_results = self._session.run(self._ops, feed_dict=feed,
                                    options=run_options)
    responses = [event_pb2.WorkerHeartbeatResponse.FromString(blob)
                 for blob in raw_results]
    logging.debug('Ping results: %s', responses)
    return responses
def _copy_ops(self, info):
    """Copy every op of `info.sgv` without connecting its inputs.

    Each (original, copy) op pair is stored in `info.transformed_ops` and
    each (original, copy) output-tensor pair in `info.transformed_ts`; every
    pair is also passed to `assign_collections_handler`.

    Raises:
      ValueError: if `transform_op_handler` returns the original op itself.
    """
    for src_op in info.sgv.ops:
        logging.debug("Copying op: %s", src_op.name)
        # TODO(fkp): return a subgraph?
        dst_op, dst_outputs = self.transform_op_handler(info, src_op)
        if dst_op is src_op:
            raise ValueError("In-place tranformation not allowed.")

        # Record the op mapping.
        info.transformed_ops[src_op] = dst_op
        self.assign_collections_handler(info, src_op, dst_op)

        # Record the output-tensor mappings.
        for src_t, dst_t in zip(src_op.outputs, dst_outputs):
            info.transformed_ts[src_t] = dst_t
            self.assign_collections_handler(info, src_t, dst_t)
def _Create(baseclass, subclass_name, *args, **kwargs):
    """Instantiate the class named `subclass_name` if it subclasses `baseclass`.

    Args:
      baseclass: The expected base class.
      subclass_name: The fully-qualified type name of the subclass to create.
      *args: Passed to the subclass constructor.
      **kwargs: Passed to the subclass constructor.

    Returns:
      A new instance of the named subclass, or None if the name could not be
      resolved or does not name a subclass of `baseclass`.
    """
    subclass = _GetClass(subclass_name)
    if subclass is None:
        # _GetClass() already logged why the lookup failed.
        return None
    if issubclass(subclass, baseclass):
        return subclass(*args, **kwargs)
    logging.debug('Class "%s" is not a subclass of "%s"', subclass_name,
                  baseclass.__name__)
    return None
def _SetPath(self, path):
    """Switch the watched event file to `path`.

    The size of the previously watched path (when there is one and it is not
    a GCS path) is stored in `self._finalized_sizes`. A failure to read that
    size is logged rather than raised.

    Args:
      path: The full path of the file to watch.
    """
    previous = self._path
    needs_size = bool(previous) and not gcs.IsGCSPath(previous)
    if needs_size:
        try:
            # Finished with the previous file: remember its final size.
            final_size = io_wrapper.Size(previous)
            logging.debug('Setting latest size of %s to %d', previous, final_size)
            self._finalized_sizes[previous] = final_size
        except (IOError, OSError) as exc:
            logging.error('Unable to get size of %s: %s', previous, exc)
    self._path = path
    self._loader = self._loader_factory(path)
def _subscribe(tensor, side_effects, control_cache):
    """Subscribe a single tensor to `side_effects`, reusing any existing subscription.

    If `tensor` is itself a subscription identity, or already feeds one
    (found among its consumers by name), the side effects are appended to
    that identity rather than creating a new one.

    Args:
      tensor: The `tf.Tensor` to be subscribed.
      side_effects: List of side_effect functions, see subscribe for details.
      control_cache: `_ControlOutputCache` helper to get control_outputs faster.

    Returns:
      The tensor that triggers the side effects, or `tensor` itself when its
      dtype is not numpy-compatible and it cannot be subscribed.
    """
    # Only numpy-compatible dtypes (see dtypes.py) can be subscribed.
    if not tensor.dtype.is_numpy_compatible:
        logging.debug(('Tensor {} has an un-supported {} type and cannot be '
                       'subscribed.').format(tensor.name, tensor.dtype))
        return tensor

    if _is_subscribed_identity(tensor):
        return _subscribe_extend(tensor, side_effects)

    # Look for an identity op created by an earlier subscription.
    prefix = tensor.op.name + '/subscription/Identity'
    matching_ops = [consumer for consumer in tensor.consumers()
                    if consumer.name.startswith(prefix)]
    assert len(matching_ops) <= 1, ('Op {} must only have one subscription '
                                    'op connected to it').format(tensor.op.name)
    if len(matching_ops) == 1:
        existing = matching_ops[0].outputs[0]
        if _is_subscribed_identity(existing):
            return _subscribe_extend(existing, side_effects)

    return _subscribe_new(tensor, side_effects, control_cache)
def _connect_control_inputs(self, info):
    """Finish each copied op of `info.sgv`: original-op link and control inputs.

    For every op, its copy in `info.transformed_ops` gets `_original_op`
    restored when `transform_original_op_handler` can map it, then receives
    the transformed control inputs (those that map to None are dropped).
    """
    for src_op in info.sgv.ops:
        logging.debug("Connecting control inputs of op: %s", src_op.name)
        dst_op = info.transformed_ops[src_op]

        # TODO(fkp): Stop worrying about _original_op and remove this code?
        # pylint: disable=protected-access
        if src_op._original_op:
            mapped = self.transform_original_op_handler(info, src_op._original_op)
            if mapped is not None:
                dst_op._original_op = mapped
            else:
                logging.debug("Could not find original op for: %s", dst_op.name)
        # pylint: enable=protected-access

        transformed = [self.transform_control_input_handler(info, ci)
                       for ci in src_op.control_inputs]
        reroute.add_control_inputs(
            dst_op, [ci for ci in transformed if ci is not None])
def Load(self):
    """Yield every event written to disk that has not been yielded yet.

    Calling Load repeatedly does not 'drop' events as long as the returned
    generator is not iterated over. Reading stops at the first DataLossError
    or OutOfRangeError.

    Yields:
      Newly read `event_pb2.Event` protos.
    """
    def _read_next():
        # Returns False when the next record could not be (fully) read.
        try:
            with errors.raise_exception_on_not_ok_status() as status:
                self._reader.GetNext(status)
        except (errors.DataLossError, errors.OutOfRangeError):
            # Partial reads are ignored: a record may be truncated. The reader
            # keeps the offset from before the failed read, so a later retry
            # will succeed.
            return False
        return True

    while _read_next():
        event = event_pb2.Event()
        event.ParseFromString(self._reader.record())
        yield event
    logging.debug('No more events in %s', self._file_path)
def _transform_t(self, t):
    """Return the transformed counterpart of tensor `t`, memoised per tensor.

    Tensors produced outside the subgraph go through the external-input or
    hidden-input handler; tensors produced inside it are obtained by
    transforming their op and taking the matching output. Any new
    counterpart is passed to `assign_collections_handler`.

    Args:
      t: the tensor to be transformed.

    Returns:
      The transformed tensor.
    """
    logging.debug("Transforming tensor: %s", t.name)
    if t in self._info.transformed_ts:
        return self._info.transformed_ts[t]

    # A None entry marks `t` as in progress so cycles can be detected.
    self._info.transformed_ts[t] = None

    producer, output_index = t.op, t.value_index
    if producer in self._info.ops:
        t_ = self._transform_op(producer).outputs[output_index]
    elif t in self._info.sgv_inputs_set:
        t_ = self.transform_external_input_handler(self._info, t)
    else:
        t_ = self.transform_external_hidden_input_handler(self._info, t)

    if t is not t_:
        self.assign_collections_handler(self._info, t, t_)
    self._info.transformed_ts[t] = t_
    return t_
def get(self, key):
    """Returns a `Tensor` for the given key.

    A `str` key looks up a raw (untransformed) feature. A `_FeatureColumn`
    key returns its cached transformed output; on a cache miss the column is
    asked to transform itself and the result is cached.

    Args:
      key: a `str` or a `_FeatureColumn`.

    Returns:
      The transformed `Tensor` corresponding to the `key`.

    Raises:
      TypeError: if `key` is neither a `str` nor a `_FeatureColumn`.
      ValueError: if a `str` key is not a known feature, or the column's
        transformation returned None.
    """
    if key in self._columns_to_tensors:
        # Raw feature, or a column transformed earlier.
        return self._columns_to_tensors[key]
    if not isinstance(key, (str, _FeatureColumn)):
        raise TypeError('"key" must be either a "str" or "_FeatureColumn". '
                        'Provided: {}'.format(key))
    if not isinstance(key, _FeatureColumn):
        raise ValueError('Feature {} is not in features dictionary.'.format(key))

    logging.debug('Transforming feature_column %s.', key)
    transformed = key._transform_feature(self)  # pylint: disable=protected-access
    if transformed is None:
        raise ValueError('Column {} is not supported.'.format(key.name))
    self._columns_to_tensors[key] = transformed
    return self._columns_to_tensors[key]
def _check_inputs(self, features, targets):
    """Validate features/targets against recorded signatures, recording them first time.

    Args:
      features: features to validate; their signatures are recorded on the
        first call.
      targets: targets to validate, or None to skip target checking.

    Raises:
      ValueError: if `features` or `targets` are incompatible with the
        previously recorded signatures.
    """
    if self._features_info is None:
        self._features_info = tensor_signature.create_signatures(features)
        logging.debug('Setting feature info to %s.', str(self._features_info))
    else:
        logging.debug('Given features: %s, required signatures: %s.',
                      str(features), str(self._features_info))
        if not tensor_signature.tensors_compatible(features, self._features_info):
            raise ValueError('Features are incompatible with given information. '
                             'Given features: %s, required signatures: %s.' %
                             (str(features), str(self._features_info)))

    if targets is None:
        return

    if self._targets_info is None:
        self._targets_info = tensor_signature.create_signatures(targets)
        logging.debug('Setting targets info to %s', str(self._targets_info))
    else:
        logging.debug('Given targets: %s, required signatures: %s.',
                      str(targets), str(self._targets_info))
        if not tensor_signature.tensors_compatible(targets, self._targets_info):
            raise ValueError('Targets are incompatible with given information. '
                             'Given targets: %s, required signatures: %s.' %
                             (str(targets), str(self._targets_info)))
def _GetClass(name):
    """Looks up a class by its fully-qualified name.

    Args:
      name: The fully-qualified type name of the class to return, e.g.
        "package.module.ClassName".

    Returns:
      The class associated with |name|, or None on error: the name is
      malformed, the module cannot be imported, the module has no such
      attribute, or the attribute is not a class.
    """
    elements = name.split('.')
    # Need at least "module.Class", and every dotted component must be
    # non-empty: an empty module path (e.g. ".Foo") would otherwise make
    # __import__ raise ValueError instead of this function returning None.
    if len(elements) < 2 or not all(elements):
        logging.debug('Malformed type: "%s"', name)
        return None
    module_path = '.'.join(elements[:-1])
    class_name = elements[-1]

    # Import the module.
    try:
        __import__(module_path)
    except ImportError as e:
        logging.debug('Unable to find module "%s": "%s"', module_path, e)
        return None
    # __import__ returns the top-level package; fetch the leaf module itself.
    module = sys.modules[module_path]

    # Look up the class.
    if not hasattr(module, class_name):
        logging.debug('Name "%s" not found in module: "%s"', class_name,
                      module_path)
        return None
    class_obj = getattr(module, class_name)

    # Check that it is actually a class.
    if not inspect.isclass(class_obj):
        logging.debug('Name does not refer to a class: "%s"', name)
        return None
    return class_obj
def _check_inputs(self, features, labels, mode):
    """Check `features` and `labels` against the signatures recorded for `mode`.

    The first call for a given `mode` records the signatures (via
    `tensor_signature.create_signatures`). Later calls for that mode verify
    compatibility with them. When `labels` is None only features are checked.

    Args:
      features: features passed for this `mode`.
      labels: labels passed for this `mode`, or None.
      mode: key under which signatures are stored in `self._features_info`
        and `self._labels_info`.

    Raises:
      ValueError: if `features` or `labels` are incompatible with the
        signatures previously recorded for `mode`.
    """
    if mode in self._features_info:
        logging.debug('Given features for mode %s: %s, required signatures: %s.',
                      mode, str(features), str(self._features_info[mode]))
        if not tensor_signature.tensors_compatible(features,
                                                   self._features_info[mode]):
            raise ValueError('Features for mode %s are incompatible with given information. '
                             'Given features: %s, required signatures: %s.' %
                             (mode, str(features), str(self._features_info[mode])))
    else:
        self._features_info[mode] = tensor_signature.create_signatures(features)
        logging.debug('Setting feature info for mode %s to %s.', mode,
                      str(self._features_info[mode]))
    if labels is not None:
        if mode in self._labels_info:
            # Log the signature recorded for this mode (not the whole per-mode
            # dict), and include the mode, matching the features message.
            logging.debug('Given labels for mode %s: %s, required signatures: %s.',
                          mode, str(labels), str(self._labels_info[mode]))
            if not tensor_signature.tensors_compatible(labels,
                                                       self._labels_info[mode]):
                raise ValueError('Labels for mode %s are incompatible with given information. '
                                 'Given labels: %s, required signatures: %s.' %
                                 (mode, str(labels), str(self._labels_info[mode])))
        else:
            self._labels_info[mode] = tensor_signature.create_signatures(labels)
            logging.debug('Setting labels info for mode %s to %s', mode,
                          str(self._labels_info[mode]))
def train(self):
    """Assemble all training components and run the training loop.

    Builds, in order: the merged dataset (its `next_batch` becomes
    `self._x`, `self._y`), the model logits and end points, the total loss,
    the optimizer, the metric ops, the train op, hooks and scaffold. It then
    calls `train(...)` with a `tf.summary.FileWriter` on
    `self._base_logs_dir`.
    """
    # Datasets.
    logging.debug('creating datasets')
    self._merged_dataset = self._get_merged_dataset()
    self._x, self._y = self._merged_dataset.next_batch
    logging.debug('successfully get dataset')

    # Build the model and get its outputs.
    logits, self._end_points = self._get_model()
    logging.debug('successfully getting logits')

    # Loss.
    total_loss = self._get_loss(logits)
    logging.debug('successfully getting total loss')

    # Optimizer.
    optimizer = self._get_optimizer()
    logging.debug('successfully getting optimizer')

    # Metric ops; the post-reset update ops are also logged during training.
    metric_ops = self._get_metrics(logits, total_loss)
    (self._metrics_summary_ops,
     self._metrics_update_ops,
     self._metrics_update_after_reset_ops) = metric_ops
    logging_tensors = self._metrics_update_after_reset_ops
    logging.debug('successfully getting metrics')

    # train_op, hooks and scaffold.
    train_op = self._get_train_op(total_loss, optimizer)
    logging.debug('successfully getting train_op')
    hooks = self._get_hooks()
    logging.debug('successfully getting hooks')
    scaffold = self._get_scaffold()
    logging.debug('successfully getting scaffold')

    logging.debug('training start')
    writer = tf.summary.FileWriter(self._base_logs_dir,
                                   graph=tf.get_default_graph())
    train(
        train_op,
        self._base_logs_dir,
        scaffold=scaffold,
        hooks=hooks,
        max_steps=self._max_steps,
        feed_fn=self._get_train_feed_fn,
        logging_tensors=logging_tensors,
        logging_every_n_steps=self._logging_every_n_steps,
        summary_writer=writer,
        summary_every_n_steps=self._summary_every_n_steps,
        save_every_n_steps=self._save_every_n_steps,
    )
def _get_accumulation_ops(graph_item, gradient, target, num_accum_required):
    """Insert a conditional accumulator for the gradient of `target`.

    Aggregates gradients from different workers using a (Sparse)
    ConditionalAccumulator. The gradient is applied to the accumulator, the
    aggregated gradient is taken with `num_accum_required`, and
    `update_consumers` / `update_control_consumers` are called so that the
    existing (control) consumers of the original gradient move onto the
    aggregated one (presumably rewiring them; the helpers are not shown here).

    Args:
      graph_item: object exposing `trainable_var_op_to_var`; only used to
        check whether `target.op` is a trainable variable op.
      gradient: the gradient for `target`: either an `ops.Tensor` (dense) or
        an object with `.values`, `.indices` and `.dense_shape` (sparse;
        presumably `ops.IndexedSlices` — confirm against callers).
      target: object whose `.op` is the variable op the gradient is for.
      num_accum_required: number passed to the accumulator's take op.

    Returns:
      A `(var_op_to_agg_grad, var_op_to_accum_apply_op)` pair of dicts keyed by
      `target.op`. For a dense gradient the first maps to `(None, agg_grad)`;
      for a sparse one to `(agg_indices, agg_values)`. Both dicts are empty
      when `target.op` is not a trainable variable op.
    """
    def _get_accum_apply_and_agg_grad(var_op, grad, indices, dense_shape):
        """Build the accumulator ops for one variable op.

        Uses a ConditionalAccumulator when `indices` is None, otherwise a
        SparseConditionalAccumulator fed with an IndexedSlices built from
        `grad`, `indices` and `dense_shape`.

        Returns:
          `(accum_apply_op, agg_grad)`. In the sparse case `agg_grad` is an
          `ops.IndexedSlices` whose indices are cast back to `indices.dtype`
          when the accumulator returned a different dtype.
        """
        if indices is None:
            tensor = variable_utils.get_read_var_tensor(var_op)
            grad_accum = data_flow_ops.ConditionalAccumulator(
                grad.dtype,
                shape=tensor.get_shape(),
                shared_name=var_op.name + "/grad_accum")
            # Get a copy of consumers list before creating accum_apply_op
            # (presumably so accum_apply_op, which consumes `grad`, is not
            # itself redirected by update_consumers below).
            grad_consumers = list(grad.consumers())
            # NOTE(review): local_step=MAX_INT64 presumably keeps the gradient
            # from being dropped as stale by the accumulator — confirm.
            accum_apply_op = grad_accum.apply_grad(
                grad, local_step=MAX_INT64,
                name=grad.op.name + '_accum_apply_grad')
            agg_grad = grad_accum.take_grad(
                num_accum_required, name=var_op.name + '_take_grad')
            update_consumers(grad_consumers, grad, agg_grad)
            update_control_consumers(get_control_consumers(grad.op), grad.op,
                                     agg_grad.op)
        else:
            grad_indexed_slices = ops.IndexedSlices(
                values=grad, indices=indices, dense_shape=dense_shape)
            grad_accum = data_flow_ops.SparseConditionalAccumulator(
                grad.dtype,
                shape=grad.shape,
                shared_name=var_op.name + "/grad_accum")
            # Get a copy of consumers list before creating accum_apply_op
            indices_consumers = list(indices.consumers())
            grad_consumers = list(grad.consumers())
            accum_apply_op = grad_accum.apply_indexed_slices_grad(
                grad_indexed_slices, local_step=MAX_INT64,
                name=grad.op.name + '_accum_apply_grad')
            agg_grad = grad_accum.take_indexed_slices_grad(
                num_accum_required, name=var_op.name + '_take_grad')
            # Cast the aggregated indices back to the caller's index dtype
            # when the accumulator returned a different one.
            agg_indices = agg_grad.indices
            if indices.dtype != agg_grad.indices.dtype:
                agg_indices = math_ops.cast(agg_grad.indices, indices.dtype)
            agg_grad = ops.IndexedSlices(values=agg_grad.values,
                                         indices=agg_indices,
                                         dense_shape=agg_grad.dense_shape)
            assert isinstance(agg_grad, ops.IndexedSlices)
            update_consumers(indices_consumers, indices, agg_grad.indices)
            update_consumers(grad_consumers, grad, agg_grad.values)
            update_control_consumers(get_control_consumers(indices.op),
                                     indices.op, agg_grad.indices.op)
            update_control_consumers(get_control_consumers(grad.op), grad.op,
                                     agg_grad.values.op)
        return accum_apply_op, agg_grad

    # Aggregate gradients from different workers using ConditionalAccumulator.
    # var_op_to_agg_grad and var_op_to_accum_apply_op are updated.
    var_op_to_agg_grad = {}
    var_op_to_accum_apply_op = {}

    # Gradients of non-trainable variables are left untouched.
    if target.op not in graph_item.trainable_var_op_to_var:
        logging.debug(
            "Gradient for non-trainable variable %s is created, "
            "do not insert accumulator for aggregating this gradient"
            % target.op.name)
        return {}, {}

    var_op = target.op
    # Split a dense Tensor vs. a sparse (values/indices/dense_shape) gradient.
    if isinstance(gradient, ops.Tensor):
        grad = gradient
        indices = None
        dense_shape = None
    else:
        grad = gradient.values
        indices = gradient.indices
        dense_shape = gradient.dense_shape
    # Create the accumulator ops on the variable's device, outside any
    # enclosing name scope.
    with ops.device(var_op.device), ops.name_scope(""):
        accum_apply_op, agg_grad = _get_accum_apply_and_agg_grad(
            var_op, grad, indices, dense_shape)
    if indices is None:
        var_op_to_agg_grad[var_op] = (None, agg_grad)
    else:
        var_op_to_agg_grad[var_op] = (agg_grad.indices, agg_grad.values)
    var_op_to_accum_apply_op[var_op] = accum_apply_op
    return var_op_to_agg_grad, var_op_to_accum_apply_op
def automatic_sharding(num_shards,
                       input_ts,
                       output_ts,
                       edge_filter=None,
                       frozen_inference=False):
    """Automatically set shards for all connected nodes in graph.

    Args:
      num_shards(int): number of shards to split graph over.
      input_ts(tf.Tensor): tensor closest to the data-feed in graph.
      output_ts(tf.Tensor): tensor closest to the output/loss in graph.
      edge_filter: a callable predicate, with the signature fn(edge), where
        edge is a tuple containing the name of the source op and the name of
        the destination op. If the predicate returns True then the graph will
        not be split at that edge. Only used if frozen_inference is False.
      frozen_inference: Flag set to True if running inference on a frozen
        graph.

    Raises:
      ValueError: if no ops are set to run on IPU device, if no consumer of
        `input_ts` is a forward op that can be sharded, or if there are fewer
        candidate subgraphs than `num_shards`.
      RuntimeError: if the output op is not part of the converted forward
        graph, or the forward graph is not a single connected component.
    """
    output_op = output_ts.op
    input_op = input_ts.op

    ipu_ops = [o for o in output_op.graph.get_operations() if 'IPU' in o.device]
    if not ipu_ops:
        raise ValueError("No ops placed on IPU device to shard.")

    # Forward ops: the explicitly marked collection when present, otherwise
    # every IPU op whose name does not look like a gradient or an update.
    marked_collection = output_op.graph.get_collection(sharding._IPU_AUTOSHARD)
    if marked_collection:
        fwd_ops = marked_collection
    else:
        fwd_ops = [
            op for op in ipu_ops
            if not any(s in op.name.lower() for s in ['gradients/', '/update_'])
        ]
    # Set membership keeps this linear instead of O(len(ipu_ops) * len(fwd_ops)).
    fwd_op_set = set(fwd_ops)
    bwd_ops = [o for o in ipu_ops if o not in fwd_op_set]
    fwd_ops = [o for o in fwd_ops if o.type not in prohibited_ops]

    if input_op not in fwd_ops:
        candidates = [op for op in input_ts.consumers() if op in fwd_ops]
        if not candidates:
            raise ValueError(
                "None of the consumers of input tensor {} is a forward op that "
                "can be sharded.".format(input_ts.name))
        input_op = candidates[0]

    if frozen_inference:
        graph = convert_inference_ops_to_nx(fwd_ops)
    else:
        graph = convert_ops_to_nx(fwd_ops, bwd_ops)

    # Check graph is a single weakly connected component
    # if not find the component with the output op in and use that
    graph_fwd = None
    for g in nx.weakly_connected_components(graph):
        if output_op.name in g:
            graph_fwd = graph.subgraph(g)
            break
    if graph_fwd is None:
        raise RuntimeError(
            "Error: output op {} is not part of the graph being "
            "auto-sharded.".format(output_op.name))
    fwd_ops = [op for op in fwd_ops if op.name in graph_fwd.nodes]

    if nx.number_weakly_connected_components(graph_fwd) != 1:
        raise RuntimeError(
            "Error: number of disconnected subgraphs in auto-sharder is {}".
            format(nx.number_weakly_connected_components(graph)))

    splitting_edges = []
    if frozen_inference:
        # Find all graph ops that when split at their output can create two
        # sub-graphs where the input and output are not in the same sub-graph
        for node in graph_fwd.nodes:
            if is_splitting_node(graph_fwd, node, input_op.name, output_op.name):
                splitting_edges.append(
                    [(node, v) for v in graph_fwd.successors(node)])
    else:
        # Find all graph edges that split the graph into two subgraphs where
        # the input and output are not in the same subgraph
        for edge in graph_fwd.edges:
            if is_splitting_edge(graph_fwd, edge, input_op.name, output_op.name):
                splitting_edges.append([edge])

        if edge_filter and callable(edge_filter):
            splitting_edges = [
                e for e in splitting_edges if not edge_filter(e[0])
            ]

    logging.debug('Possible splitting edges %s', splitting_edges)

    # Verify that we have enough subgraphs to fill all of the available shards
    if len(splitting_edges) + 1 < num_shards:
        raise ValueError(
            "There are fewer subgraphs (%s) than available shards (%s). Reduce the "
            "number of shards." % (len(splitting_edges) + 1, num_shards))

    # Given the splitting edges found find all of the subgraphs created and
    # order them
    sub_graphs = find_all_subgraphs(graph_fwd, splitting_edges, input_op.name,
                                    output_op.name,
                                    [op.name for op in fwd_ops])

    sub_graph_mem = [calculate_memory(g) for g in sub_graphs]
    logging_helper(sub_graphs)
    best_ind = group_subgraphs(sub_graph_mem, num_shards)

    logging.debug('Splitting edges %s',
                  [str(splitting_edges[x]) for x in best_ind])

    # Shard i receives the subgraphs in sub_graphs[ind_pad[i]:ind_pad[i + 1]].
    ind_pad = [0] + [i + 1 for i in best_ind] + [len(sub_graph_mem)]
    per_shard_sub_graphs = [
        graph_fwd.subgraph([
            node for g in sub_graphs[ind_pad[i]:ind_pad[i + 1]]
            for node in g.nodes
        ]) for i in range(num_shards)
    ]
    logging_helper(per_shard_sub_graphs)
    assign_shard(fwd_ops, ipu_ops, per_shard_sub_graphs)
def compute_gradients(self,
                      loss,
                      var_list=None,
                      gate_gradients=optimizer.Optimizer.GATE_OP,
                      aggregation_method=None,
                      colocate_gradients_with_ops=False,
                      grad_loss=None):
    """Compute gradients. See base class `tf.compat.v1.train.Optimizer`.

    When loss scaling is enabled (`self._is_loss_scale`), the loss is
    multiplied by the current loss scale before differentiation. It is wrapped
    in a callable when executing eagerly. The resulting gradients are passed
    through `self._down_scale`. When distributed (`self._is_distributed`),
    each non-None gradient is all-reduced with `allreduce(grad, True)` under a
    `<name>_Allreduce` name scope.

    Returns:
      `(gradient, variable)` pairs, either as returned by the wrapped
      optimizer's `compute_gradients` or their all-reduced equivalents
      (down-scaled when loss scaling is enabled).
    """
    if self._is_loss_scale:
        loss_scale = self._loss_scale_manager.get_loss_scale()
        if context.executing_eagerly():
            # Eager mode: keep the loss callable and scale it lazily.

            def scaled_loss():
                loss_val = loss()
                return loss_val * math_ops.cast(loss_scale,
                                                loss_val.dtype.base_dtype)
        else:
            # Graph mode: `loss` may be a callable or a tensor.
            if callable(loss):
                loss_val = loss()
            else:
                loss_val = loss
            scaled_loss = loss_val * math_ops.cast(
                loss_scale, loss_val.dtype.base_dtype)
        self._float_status = gen_npu_ops.npu_alloc_float_status()
    else:
        scaled_loss = loss

    logging.debug("compute_gradients...")
    gradients = self._opt.compute_gradients(
        scaled_loss,
        var_list=var_list,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)
    if not self._is_distributed:
        if self._is_loss_scale:
            return self._down_scale(gradients, loss_scale)
        else:
            return gradients

    averaged_gradients = []
    grads = []
    with tf.name_scope(self._name + "_Allreduce"):
        for grad, var in gradients:
            grads.append(grad)
            # With tailing optimization, once the last gradient has been
            # collected, _reduce_all runs over all of them and the final
            # allreduce is made to depend on self._is_overall_finite.
            if self._is_loss_scale and (
                    len(grads) == len(gradients)) and self._is_tailing_optimization:
                self._reduce_all(grads)
                with tf.get_default_graph().control_dependencies(
                        [self._is_overall_finite]):
                    avg_grad = allreduce(
                        grad, True) if grad is not None else None
                    averaged_gradients.append((avg_grad, var))
            else:
                avg_grad = allreduce(grad, True) if grad is not None else None
                averaged_gradients.append((avg_grad, var))
    if self._is_loss_scale:
        return self._down_scale(averaged_gradients, loss_scale)
    else:
        return averaged_gradients
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
    """Initializes current variables with tensors loaded from given checkpoint.

    Note: This overrides default initialization ops of specified variables and
    redefines dtype.

    Assignment map supports following syntax:

    * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
      current `scope_name` from `checkpoint_scope_name` with matching tensor
      names.
    * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
      will initialize `scope_name/variable_name` variable
      from `checkpoint_scope_name/some_other_variable`.
    * `'scope_variable_name': variable` - will initialize given `tf.Variable`
      object with tensor 'scope_variable_name' from the checkpoint.
    * `'scope_variable_name': list(variable)` - will initialize list of
      partitioned variables with tensor 'scope_variable_name' from the
      checkpoint.
    * `'/': 'scope_name/'` - will load all variables in current `scope_name`
      from checkpoint's root (e.g. no scope).

    Supports loading into partitioned variables, which are represented as
    `'<variable>/part_<part #>'`.

    Example:

    ```python
    # Say, '/tmp/model.ckpt' has the following tensors:
    #  -- name='old_scope_1/var1', shape=[20, 2]
    #  -- name='old_scope_1/var2', shape=[50, 4]
    #  -- name='old_scope_2/var3', shape=[100, 100]

    # Create new model's variables
    with tf.variable_scope('new_scope_1'):
      var1 = tf.get_variable('var1', shape=[20, 2],
                             initializer=tf.zeros_initializer())
    with tf.variable_scope('new_scope_2'):
      var2 = tf.get_variable('var2', shape=[50, 4],
                             initializer=tf.zeros_initializer())
      # Partition into 5 variables along the first axis.
      var3 = tf.get_variable(name='var3', shape=[100, 100],
                             initializer=tf.zeros_initializer(),
                             partitioner=lambda shape, dtype: [5, 1])

    # Initialize all variables in `new_scope_1` from `old_scope_1`.
    init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

    # Use names to specify which variables to initialize from checkpoint.
    init_from_checkpoint('/tmp/model.ckpt',
                         {'old_scope_1/var1': 'new_scope_1/var1',
                          'old_scope_1/var2': 'new_scope_2/var2'})

    # Or use tf.Variable objects to identify what to initialize.
    init_from_checkpoint('/tmp/model.ckpt',
                         {'old_scope_1/var1': var1,
                          'old_scope_1/var2': var2})

    # Initialize partitioned variables using variable's name
    init_from_checkpoint('/tmp/model.ckpt',
                         {'old_scope_2/var3': 'new_scope_2/var3'})

    # Or specify the list of tf.Variable objects.
    init_from_checkpoint('/tmp/model.ckpt',
                         {'old_scope_2/var3': var3._get_variable_list()})
    ```

    Args:
      ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
      assignment_map: Dict, where keys are names of the variables in the
        checkpoint and values are current variables or names of current
        variables (in default graph).

    Raises:
      tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
      ValueError: If missing variables in current graph, if a mapped checkpoint
        tensor is missing or has an incompatible shape, or if a scope mapping
        is malformed (scope names must look like 'scope/' or be '/').
    """
    ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
    reader = load_checkpoint(ckpt_dir_or_file)
    variable_map = reader.get_variable_to_shape_map()

    def is_var(x):
        return isinstance(x, variables.Variable)

    for tensor_name_in_ckpt, current_var_or_name in sorted(
            six.iteritems(assignment_map)):
        var = None
        # Check if this is Variable object or list of Variable objects (in case
        # of partitioned variables).
        if is_var(current_var_or_name) or (
                isinstance(current_var_or_name, list) and
                all(is_var(v) for v in current_var_or_name)):
            var = current_var_or_name
        else:
            store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
            # Check if this variable is in var_store.
            var = store_vars.get(current_var_or_name, None)
            # Also check if variable is partitioned as list.
            if var is None:
                var = _collect_partitioned_variable(current_var_or_name,
                                                    store_vars)
        if var is not None:
            # If 1 to 1 mapping was provided, find variable in the checkpoint.
            if tensor_name_in_ckpt not in variable_map:
                raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
                    tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
                ))
            if is_var(var):
                # Additional at-call-time checks.
                if not var.get_shape().is_compatible_with(
                        variable_map[tensor_name_in_ckpt]):
                    raise ValueError(
                        "Shape of variable %s (%s) doesn't match with shape of "
                        "tensor %s (%s) from checkpoint reader." % (
                            var.name, str(var.get_shape()),
                            tensor_name_in_ckpt,
                            str(variable_map[tensor_name_in_ckpt])
                        ))
                var_name = var.name
            else:
                var_name = ",".join(v.name for v in var)
            _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
            logging.debug("Initialize variable %s from checkpoint %s with %s",
                          var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
        else:
            # No variable with that name: treat the value as a scope name.
            scopes = ""
            # TODO(vihanjain): Support list of 'current_var_or_name' here.
            if "/" in current_var_or_name:
                scopes = current_var_or_name[:current_var_or_name.rindex("/")]
            if not scopes and current_var_or_name != "/":
                # An empty scope is only meaningful for the root '/'. For any
                # other value (e.g. a misspelled variable name without '/'),
                # the loop below would match every variable and strip the
                # first character of each name.
                raise ValueError(
                    "Variable {} (mapped from checkpoint entry {}) is not found "
                    "in the current graph and is not a valid scope name. Use "
                    "'scope_name/' for a scope or '/' for the root.".format(
                        current_var_or_name, tensor_name_in_ckpt))
            if not tensor_name_in_ckpt.endswith("/"):
                raise ValueError(
                    "Assignment map with scope only name {} should map to scope only "
                    "{}. Should be 'scope/': 'other_scope/'.".format(
                        scopes, tensor_name_in_ckpt))

            # If scope to scope mapping was provided, find all variables in the
            # scope and create variable to variable mapping.
            scope_variables = set()
            for var_name in store_vars:
                if not scopes or var_name.startswith(scopes + "/"):
                    # Consume /part_ if partitioned variable.
                    if "/part_" in var_name:
                        var_name = var_name[:var_name.index("/part_")]
                    scope_variables.add(var_name)
            for var_name in sorted(scope_variables):
                # Lookup name with specified prefix and suffix from current
                # variable. If tensor_name given is '/' (root), don't use it for
                # full name.
                full_tensor_name = var_name[len(scopes):]
                if current_var_or_name != "/":
                    full_tensor_name = full_tensor_name[1:]
                if tensor_name_in_ckpt != "/":
                    full_tensor_name = tensor_name_in_ckpt + full_tensor_name
                # Remove trailing '/', if any, in the full_tensor_name
                if full_tensor_name.endswith("/"):
                    full_tensor_name = full_tensor_name[:-1]
                if full_tensor_name not in variable_map:
                    raise ValueError(
                        "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                            full_tensor_name, var_name[len(scopes) + 1:],
                            tensor_name_in_ckpt, ckpt_dir_or_file
                        ))
                var = store_vars.get(var_name, None)
                if var is None:
                    var = _collect_partitioned_variable(var_name, store_vars)
                _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
                logging.debug("Initialize variable %s from checkpoint %s with %s",
                              var_name, ckpt_dir_or_file, full_tensor_name)
def _init_from_checkpoint(_, ckpt_dir_or_file, assignment_map):
    """See `init_from_checkpoint` for documentation.

    The first positional argument is ignored.

    Raises:
      ValueError: if a mapped checkpoint tensor is missing or has an
        incompatible shape, if a non-variable value is not a valid scope name
        ('scope/' or '/'), or if a scope maps to a non-scope checkpoint name.
    """
    ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
    reader = load_checkpoint(ckpt_dir_or_file)
    variable_map = reader.get_variable_to_shape_map()
    for tensor_name_in_ckpt, current_var_or_name in sorted(
            six.iteritems(assignment_map)):
        var = None
        # Check if this is Variable object or list of Variable objects (in case
        # of partitioned variables).
        if _is_variable(current_var_or_name) or (
                isinstance(current_var_or_name, list) and
                all(_is_variable(v) for v in current_var_or_name)):
            var = current_var_or_name
        else:
            store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
            # Check if this variable is in var_store.
            var = store_vars.get(current_var_or_name, None)
            # Also check if variable is partitioned as list.
            if var is None:
                var = _collect_partitioned_variable(current_var_or_name,
                                                    store_vars)
        if var is not None:
            # If 1 to 1 mapping was provided, find variable in the checkpoint.
            if tensor_name_in_ckpt not in variable_map:
                raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
                    tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
                ))
            if _is_variable(var):
                # Additional at-call-time checks.
                if not var.get_shape().is_compatible_with(
                        variable_map[tensor_name_in_ckpt]):
                    raise ValueError(
                        "Shape of variable %s (%s) doesn't match with shape of "
                        "tensor %s (%s) from checkpoint reader." % (
                            var.name, str(var.get_shape()),
                            tensor_name_in_ckpt,
                            str(variable_map[tensor_name_in_ckpt])
                        ))
                var_name = var.name
            else:
                var_name = ",".join(v.name for v in var)
            _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
            logging.debug("Initialize variable %s from checkpoint %s with %s",
                          var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
        else:
            # No variable with that name: treat the value as a scope name.
            scopes = ""
            # TODO(vihanjain): Support list of 'current_var_or_name' here.
            if "/" in current_var_or_name:
                scopes = current_var_or_name[:current_var_or_name.rindex("/")]
            if not scopes and current_var_or_name != "/":
                # An empty scope is only meaningful for the root '/'. For any
                # other value (e.g. a misspelled variable name without '/'),
                # the loop below would match every variable and strip the
                # first character of each name.
                raise ValueError(
                    "Variable {} (mapped from checkpoint entry {}) is not found "
                    "in the current graph and is not a valid scope name. Use "
                    "'scope_name/' for a scope or '/' for the root.".format(
                        current_var_or_name, tensor_name_in_ckpt))
            if not tensor_name_in_ckpt.endswith("/"):
                raise ValueError(
                    "Assignment map with scope only name {} should map to scope only "
                    "{}. Should be 'scope/': 'other_scope/'.".format(
                        scopes, tensor_name_in_ckpt))
            # If scope to scope mapping was provided, find all variables in the
            # scope and create variable to variable mapping.
            scope_variables = set()
            for var_name in store_vars:
                if not scopes or var_name.startswith(scopes + "/"):
                    # Consume /part_ if partitioned variable.
                    if "/part_" in var_name:
                        var_name = var_name[:var_name.index("/part_")]
                    scope_variables.add(var_name)
            for var_name in sorted(scope_variables):
                # Lookup name with specified prefix and suffix from current
                # variable. If tensor_name given is '/' (root), don't use it for
                # full name.
                full_tensor_name = var_name[len(scopes):]
                if current_var_or_name != "/":
                    full_tensor_name = full_tensor_name[1:]
                if tensor_name_in_ckpt != "/":
                    full_tensor_name = tensor_name_in_ckpt + full_tensor_name
                # Remove trailing '/', if any, in the full_tensor_name
                if full_tensor_name.endswith("/"):
                    full_tensor_name = full_tensor_name[:-1]
                if full_tensor_name not in variable_map:
                    raise ValueError(
                        "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                            full_tensor_name, var_name[len(scopes) + 1:],
                            tensor_name_in_ckpt, ckpt_dir_or_file
                        ))
                var = store_vars.get(var_name, None)
                if var is None:
                    var = _collect_partitioned_variable(var_name, store_vars)
                _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
                logging.debug("Initialize variable %s from checkpoint %s with %s",
                              var_name, ckpt_dir_or_file, full_tensor_name)
def __init__(self,
             units,
             activation='tanh',
             recurrent_activation='sigmoid',
             use_bias=True,
             kernel_initializer='glorot_uniform',
             recurrent_initializer='orthogonal',
             bias_initializer='zeros',
             kernel_regularizer=None,
             recurrent_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             recurrent_constraint=None,
             bias_constraint=None,
             dropout=0.,
             recurrent_dropout=0.,
             return_sequences=False,
             return_state=False,
             go_backwards=False,
             stateful=False,
             unroll=False,
             time_major=False,
             reset_after=True,
             **kwargs):
    """Configures the GRU layer and decides whether the cuDNN kernel may be used.

    `return_runtime` and `implementation` are read from `kwargs` (defaults
    False and 2); every other keyword is forwarded to the parent constructor.
    """
    # Testing-only flag: exposes which backend implementation grappler picked
    # in graph mode.
    self._return_runtime = kwargs.pop('return_runtime', False)
    implementation = kwargs.pop('implementation', 2)

    super(GRU, self).__init__(
        units,
        activation=activation,
        recurrent_activation=recurrent_activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        recurrent_initializer=recurrent_initializer,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        recurrent_regularizer=recurrent_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        kernel_constraint=kernel_constraint,
        recurrent_constraint=recurrent_constraint,
        bias_constraint=bias_constraint,
        dropout=dropout,
        recurrent_dropout=recurrent_dropout,
        implementation=implementation,
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        unroll=unroll,
        time_major=time_major,
        reset_after=reset_after,
        **kwargs)

    # The GPU kernel only supports this fixed, non-configurable setup.
    default_activations = (
        self.activation in (activations.tanh, tf.tanh) and
        self.recurrent_activation in (activations.sigmoid, tf.sigmoid))
    self._could_use_gpu_kernel = (
        default_activations and
        recurrent_dropout == 0 and
        not unroll and
        use_bias and
        reset_after and
        tf.compat.v1.executing_eagerly_outside_functions())

    # Only report cuDNN availability when a GPU is actually present.
    if tf.config.list_logical_devices('GPU'):
        if self._could_use_gpu_kernel:
            log_fn = logging.debug
            template = gru_lstm_utils.CUDNN_AVAILABLE_MSG
        else:
            log_fn = logging.warning
            template = gru_lstm_utils.CUDNN_NOT_AVAILABLE_MSG
        log_fn(template % self.name)

    if gru_lstm_utils.use_new_gru_lstm_impl():
        self._defun_wrapper = gru_lstm_utils.DefunWrapper(
            time_major, go_backwards, 'gru')
import collections import re from tensorflow.python.distribute.cluster_resolver import cluster_resolver from tensorflow.python.framework import errors from tensorflow.python.platform import tf_logging as logging from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.python.training import server_lib from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export try: from cloud_tpu_client import client # pylint: disable=g-import-not-at-top except ImportError: logging.debug( 'Falling back to TensorFlow client; we recommended you install the Cloud ' 'TPU client directly with pip install cloud-tpu-client.') from tensorflow.python.tpu.client import client def is_running_in_gce(): return True _TPU_DEVICE_REGEX = re.compile( r'.*task:(?P<host_id>\d+)/.*device:TPU:(?P<core_id>\d+)$') _TPU_CONN_RETRIES = 120 DeviceDetails = collections.namedtuple( 'DeviceDetails', ['device_map', 'total_cores']) @tf_export('distribute.cluster_resolver.TPUClusterResolver')
def init_fn(scaffold, session):
    """Initializes `merged_dataset`, then runs `pre_trained_init_fn` if set.

    Args:
      scaffold: Not used by this function.
      session: Passed to `merged_dataset.init` and, when it is truthy, to
        `pre_trained_init_fn`.
    """
    del scaffold  # Unused.
    merged_dataset.init(session)
    extra_init = pre_trained_init_fn
    if extra_init:
        extra_init(session)
    logging.debug('init_fn successfully processed...')
def warm_start(ckpt_to_initialize_from,
               vars_to_warm_start=".*",
               var_name_to_vocab_info=None,
               var_name_to_prev_var_name=None):
    """Warm-starts a model using the given settings.

    If you are using a tf.estimator.Estimator, this will automatically be called
    during training.

    Args:
      ckpt_to_initialize_from: [Required] A string specifying the directory with
        checkpoint file(s) or path to checkpoint from which to warm-start the
        model parameters.
      vars_to_warm_start: [Optional] One of the following:

        - A regular expression (string) that captures which variables to
          warm-start (see tf.compat.v1.get_collection). This expression will
          only consider variables in the TRAINABLE_VARIABLES collection -- if
          you need to warm-start non_TRAINABLE vars (such as optimizer
          accumulators or batch norm statistics), please use the below option.
        - A list of strings, each a regex scope provided to
          tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see
          tf.compat.v1.get_collection). For backwards compatibility reasons,
          this is separate from the single-string argument type.
        - A list of Variables to warm-start. If you do not have access to the
          `Variable` objects at the call site, please use the above option.
        - `None`, in which case only TRAINABLE variables specified in
          `var_name_to_vocab_info` will be warm-started.

        Defaults to `'.*'`, which warm-starts all variables in the
        TRAINABLE_VARIABLES collection. Note that this excludes variables such
        as accumulators and moving statistics from batch norm.
      var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
        `tf.estimator.VocabInfo`. The variable names should be "full" variables,
        not the names of the partitions. If not explicitly provided, the
        variable is assumed to have no (changes to) vocabulary.
      var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
        name of the previously-trained variable in `ckpt_to_initialize_from`.
        If not explicitly provided, the name of the variable is assumed to be
        same between previous checkpoint and current model. Note that this has
        no effect on the set of variables that is warm-started, and only
        controls name mapping (use `vars_to_warm_start` for controlling what
        variables to warm-start).

    Raises:
      ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
        configuration for variable names that are not used. This is to ensure
        a stronger check for variable configuration than relying on users to
        examine the logs.
    """
    if var_name_to_vocab_info is None:
        var_name_to_vocab_info = {}
    if var_name_to_prev_var_name is None:
        var_name_to_prev_var_name = {}
    # Log the path itself; wrapping it in a 1-tuple printed "('path',)".
    logging.info("Warm-starting from: %s", ckpt_to_initialize_from)
    grouped_variables = _get_grouped_variables(vars_to_warm_start)

    warmstarted_count = 0

    # Keep track of which var_names in var_name_to_prev_var_name and
    # var_name_to_vocab_info have been used. Err on the safer side by throwing
    # an exception if any are unused by the end of the loop. It is easy to
    # misname a variable during this configuration, in which case without this
    # check, we would fail to warm-start silently.
    prev_var_name_used = set()
    vocab_info_used = set()

    # Group the vocabless vars into one call to init_from_checkpoint.
    vocabless_vars = {}
    for var_name, variable in six.iteritems(grouped_variables):
        prev_var_name = var_name_to_prev_var_name.get(var_name)
        if prev_var_name:
            prev_var_name_used.add(var_name)
        vocab_info = var_name_to_vocab_info.get(var_name)
        if vocab_info:
            vocab_info_used.add(var_name)
            warmstarted_count += 1
            logging.debug(
                "Warm-starting variable: %s; current_vocab: %s "
                "current_vocab_size: %s prev_vocab: %s prev_vocab_size: %s "
                "current_oov: %s prev_tensor: %s initializer: %s",
                var_name,
                vocab_info.new_vocab,
                vocab_info.new_vocab_size,
                vocab_info.old_vocab,
                (vocab_info.old_vocab_size if vocab_info.old_vocab_size > 0
                 else "All"),
                vocab_info.num_oov_buckets,
                prev_var_name or "Unchanged",
                vocab_info.backup_initializer or "zero-initialized")
            _warm_start_var_with_vocab(
                variable,
                current_vocab_path=vocab_info.new_vocab,
                current_vocab_size=vocab_info.new_vocab_size,
                prev_ckpt=ckpt_to_initialize_from,
                prev_vocab_path=vocab_info.old_vocab,
                previous_vocab_size=vocab_info.old_vocab_size,
                current_oov_buckets=vocab_info.num_oov_buckets,
                prev_tensor_name=prev_var_name,
                initializer=vocab_info.backup_initializer,
                axis=vocab_info.axis)
        else:
            # For the special value of vars_to_warm_start = None,
            # we only warm-start variables with explicitly specified
            # vocabularies.
            if vars_to_warm_start:
                warmstarted_count += 1
                logging.debug("Warm-starting variable: %s; prev_var_name: %s",
                              var_name, prev_var_name or "Unchanged")
                # Because we use a default empty list in grouped_variables,
                # single unpartitioned variables will be lists here, which we
                # rectify in order for init_from_checkpoint logic to work
                # correctly.
                if len(variable) == 1:
                    variable = variable[0]
                prev_tensor_name, var = _get_var_info(variable, prev_var_name)
                vocabless_vars[prev_tensor_name] = var

    checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from, vocabless_vars)
    prev_var_name_not_used = set(
        var_name_to_prev_var_name.keys()) - prev_var_name_used
    vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used

    logging.info("Warm-started %d variables.", warmstarted_count)

    if prev_var_name_not_used:
        raise ValueError(
            "You provided the following variables in "
            "var_name_to_prev_var_name that were not used: "
            "{0}.  Perhaps you misspelled them?  Here is the list of viable "
            "variable names: {1}".format(prev_var_name_not_used,
                                         grouped_variables.keys()))
    if vocab_info_not_used:
        raise ValueError(
            "You provided the following variables in "
            "var_name_to_vocab_info that were not used: {0}. "
            " Perhaps you misspelled them?  Here is the list of viable variable "
            "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))
def __init__(self,
             units,
             activation='tanh',
             recurrent_activation='sigmoid',
             use_bias=True,
             kernel_initializer='glorot_uniform',
             recurrent_initializer='orthogonal',
             bias_initializer='zeros',
             unit_forget_bias=True,
             kernel_regularizer=None,
             recurrent_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             recurrent_constraint=None,
             bias_constraint=None,
             dropout=0.,
             recurrent_dropout=0.,
             return_sequences=False,
             return_state=False,
             go_backwards=False,
             stateful=False,
             time_major=False,
             unroll=False,
             **kwargs):
    """Builds an `LSTMCell` from the cell options and wraps it in this layer.

    `return_runtime`, `implementation` (default 2) and `enable_caching_device`
    are consumed from `kwargs`; `dtype` and `trainable` are read (not removed)
    and also forwarded to the cell. Remaining `kwargs` go to the parent
    constructor.
    """
    # return_runtime is a flag for testing, which shows the real backend
    # implementation chosen by grappler in graph mode.
    self.return_runtime = kwargs.pop('return_runtime', False)
    implementation = kwargs.pop('implementation', 2)
    if implementation == 0:
        # The adjacent literals are concatenated; keep a space between the
        # sentences.
        logging.warning('`implementation=0` has been deprecated, '
                        'and now defaults to `implementation=1`. '
                        'Please update your layer call.')
    if 'enable_caching_device' in kwargs:
        cell_kwargs = {
            'enable_caching_device': kwargs.pop('enable_caching_device')
        }
    else:
        cell_kwargs = {}
    cell = LSTMCell(
        units,
        activation=activation,
        recurrent_activation=recurrent_activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        recurrent_initializer=recurrent_initializer,
        unit_forget_bias=unit_forget_bias,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        recurrent_regularizer=recurrent_regularizer,
        bias_regularizer=bias_regularizer,
        kernel_constraint=kernel_constraint,
        recurrent_constraint=recurrent_constraint,
        bias_constraint=bias_constraint,
        dropout=dropout,
        recurrent_dropout=recurrent_dropout,
        implementation=implementation,
        dtype=kwargs.get('dtype'),
        trainable=kwargs.get('trainable', True),
        **cell_kwargs)
    super().__init__(
        cell,
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        time_major=time_major,
        unroll=unroll,
        **kwargs)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.input_spec = [InputSpec(ndim=3)]
    # One spec per state tensor (two states, each of width `units`).
    self.state_spec = [
        InputSpec(shape=(None, dim)) for dim in (self.units, self.units)
    ]
    # The GPU (cuDNN) kernel only supports this fixed configuration.
    self._could_use_gpu_kernel = (
        self.activation in (activations.tanh, tf.tanh) and
        self.recurrent_activation in (activations.sigmoid, tf.sigmoid) and
        recurrent_dropout == 0 and not unroll and use_bias and
        tf.compat.v1.executing_eagerly_outside_functions())
    if tf.config.list_logical_devices('GPU'):
        # Only show the message when there is GPU available, user will not care
        # about the cuDNN if there isn't any GPU.
        if self._could_use_gpu_kernel:
            logging.debug(gru_lstm_utils.CUDNN_AVAILABLE_MSG, self.name)
        else:
            logging.warning(gru_lstm_utils.CUDNN_NOT_AVAILABLE_MSG, self.name)

    if gru_lstm_utils.use_new_gru_lstm_impl():
        self._defun_wrapper = gru_lstm_utils.DefunWrapper(
            time_major, go_backwards, 'lstm')
def _create_variable(self, next_creator, **kwargs):
    """Implements StrategyExtendedV2._create_variable.

    Creates a `Variable` or a `ShardedVariable`. A `ShardedVariable` will be
    created if satisfying all the following criteria:
      1. `self._variable_partitioner` results in more than one partition on the
         first axis.
      2. variable's rank is greater than 0.
      3. variable is not colocated with another variable.
      4. variable's first dimension is non-zero.
    Otherwise a `Variable` will be created.

    Args:
      next_creator: See `variable_scope.variable_creator_scope`; the next
        creator in the chain.
      **kwargs: Passed through to the next creator.

    Returns:
      A `Variable` or `ShardedVariable`.

    Raises:
      ValueError: if a partitioner is set but no `initial_value` is given, or
        the partitioner returns an unsupported partition list.
    """
    var_creator = self._create_var_creator(next_creator, **kwargs)

    if "colocate_with" in kwargs:  # Never partition colocated_with variables.
        colocate_with = kwargs["colocate_with"]
        # Clear the variable scope to avoid possible conflicts between device
        # scope and colocation scope.
        with ops.device(None):
            with ops.colocate_with(colocate_with):
                var = var_creator(**kwargs)
                logging.debug(
                    "Creating variable (name:%s, shape:%r) that colocates with %s",
                    var.name, var.shape, kwargs["colocate_with"].name)
                return var

    if self._variable_partitioner is None:
        return self._create_variable_round_robin(var_creator, **kwargs)

    name = kwargs.get("name", None)
    initial_value = kwargs.get("initial_value", None)
    if initial_value is None:
        raise ValueError(
            "It looks like you are using `ParameterServerStrategy` with a "
            "`variable_partitioner`, and trying to create a variable without "
            "specifying `initial_value`. This is not allowed. Please specify the "
            "`initial_value`. This can also happen if you are trying to load a "
            "saved_model within a `ParameterServerStrategy` scope. Loading a "
            "saved_model with `variable_partitioner` is not supported.")

    # Two cases where initial_value can be a callable:
    #   1. initial_value is passed as a callable, e.g, an `initializer` class.
    #   2. restoring from checkpoint, initial_value is a
    #     "CheckpointInitialValueCallable".
    init_from_fn = callable(initial_value)

    dtype = kwargs.get("dtype", None)
    shape = kwargs.get("shape", None)
    if init_from_fn and (shape is None or dtype is None):
        init_from_fn = False
        initial_value = initial_value()
    if not init_from_fn:
        # The initial_value is created on coordinator, it will need to be sent
        # to ps for variable initialization, which can be inefficient and can
        # potentially hit the 2GB limit on protobuf serialization.
        initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
        dtype = initial_value.dtype
        shape = initial_value.shape
    else:
        shape = tensor_shape.as_shape(shape)

    if shape.rank == 0:  # Skip partitioning rank-0 variable.
        return self._create_variable_round_robin(var_creator, **kwargs)

    num_partitions = self._variable_partitioner(shape=shape, dtype=dtype)
    if not num_partitions or num_partitions[0] == 0 or any(
            v != 1 for v in num_partitions[1:]):
        raise ValueError(
            "variable_partitioner must return a list/tuple whose elements are 1"
            " besides the first element (non-zero), got: %r" % num_partitions)

    if num_partitions[0] == 1:  # no partition
        return self._create_variable_round_robin(var_creator, **kwargs)

    if shape[0] == 0:
        # Nothing to split along the first axis; the "div" computation below
        # would otherwise divide by zero.
        return self._create_variable_round_robin(var_creator, **kwargs)

    # Use "div" partition strategy to partition the variable.
    num_partitions = min(num_partitions[0], shape[0])
    base = shape[0] // num_partitions
    extra = shape[0] % num_partitions
    # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2]
    # offsets: [0, 3, 6, 8, 10]
    offsets = []
    for i in range(num_partitions):
        if i == 0:
            offsets.append(0)
        else:
            prev_shard_size = base + (1 if i - 1 < extra else 0)
            offsets.append(offsets[i - 1] + prev_shard_size)
    offsets.append(shape[0])

    def init_shard_fn(shard_index):
        if not init_from_fn:
            logging.log_if(
                logging.WARN, _INEFFICIENT_INIT_WARNING % name,
                shard_index == 0 and
                shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
            return initial_value[offsets[shard_index]:offsets[shard_index + 1]]
        partition_shape = (offsets[shard_index + 1] -
                           offsets[shard_index],) + shape[1:]
        partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:])
        arg_spec = tf_inspect.getfullargspec(initial_value)
        if ("shard_info" not in arg_spec.args and
                "shard_info" not in arg_spec.kwonlyargs):
            try:
                value = initial_value(
                    partition_shape=partition_shape,
                    partition_offset=partition_offset)
            except (TypeError, ValueError):
                # TypeError: Initializer doesn't accept kwargs
                # ValueError: Initializer doesn't accept partition kwargs
                # In both cases we go ahead creating the full value and then
                # slice.
                value = initial_value()

            if value.shape == partition_shape:
                # Initializer supports partition: value is the partition value.
                return value
            else:
                # Initializer doesn't support partition: value is the full value
                # and needs to be sliced to get the partition value.
                logging.log_if(
                    logging.WARN, _INEFFICIENT_INIT_WARNING % name,
                    shard_index == 0 and
                    shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
                return value[offsets[shard_index]:offsets[shard_index + 1]]
        else:
            # For compatibility with `CheckpointInitialValueCallable`.
            return initial_value(
                shard_info=trackable.ShardInfo(
                    shape=tensor_shape.as_shape(partition_shape),
                    offset=partition_offset))

    def bind_shard_initializer(shard_index):
        # Capture `shard_index` by value. A bare `lambda: init_shard_fn(i)`
        # inside the loop would read the loop variable when called, so any
        # deferred call would initialize every shard from the last partition.
        return lambda: init_shard_fn(shard_index)

    var_list = []
    for i in range(num_partitions):
        kwargs["shape"] = (offsets[i + 1] - offsets[i],) + shape[1:]
        kwargs["initial_value"] = bind_shard_initializer(i)
        if name is not None:
            kwargs["name"] = "{}/part_{}".format(name, i)
        var_list.append(self._create_variable_round_robin(var_creator, **kwargs))

    result = sharded_variable.ShardedVariable(var_list)
    return result
def _init_from_checkpoint(_, ckpt_dir_or_file, assignment_map):
    """See `init_from_checkpoint` for documentation.

    The first positional argument is ignored.

    Raises:
      ValueError: if a mapped checkpoint tensor is missing or has an
        incompatible shape, if a non-variable value is not a valid scope name
        ('scope/' or '/'), or if a scope maps to a non-scope checkpoint name.
    """
    ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
    reader = load_checkpoint(ckpt_dir_or_file)
    variable_map = reader.get_variable_to_shape_map()
    for tensor_name_in_ckpt, current_var_or_name in sorted(
            six.iteritems(assignment_map)):
        var = None
        # Check if this is Variable object or list of Variable objects (in case
        # of partitioned variables).
        if _is_variable(current_var_or_name) or (
                isinstance(current_var_or_name, list) and
                all(_is_variable(v) for v in current_var_or_name)):
            var = current_var_or_name
        else:
            store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
            # Check if this variable is in var_store.
            var = store_vars.get(current_var_or_name, None)
            # Also check if variable is partitioned as list.
            if var is None:
                var = _collect_partitioned_variable(current_var_or_name,
                                                    store_vars)
        if var is not None:
            # If 1 to 1 mapping was provided, find variable in the checkpoint.
            if tensor_name_in_ckpt not in variable_map:
                raise ValueError(
                    "Tensor %s is not found in %s checkpoint %s" %
                    (tensor_name_in_ckpt, ckpt_dir_or_file, variable_map))
            if _is_variable(var):
                # Additional at-call-time checks.
                if not var.get_shape().is_compatible_with(
                        variable_map[tensor_name_in_ckpt]):
                    raise ValueError(
                        "Shape of variable %s (%s) doesn't match with shape of "
                        "tensor %s (%s) from checkpoint reader." %
                        (var.name, str(var.get_shape()), tensor_name_in_ckpt,
                         str(variable_map[tensor_name_in_ckpt])))
                var_name = var.name
            else:
                var_name = ",".join(v.name for v in var)
            _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
            logging.debug("Initialize variable %s from checkpoint %s with %s",
                          var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
        else:
            # No variable with that name: treat the value as a scope name.
            scopes = ""
            # TODO(vihanjain): Support list of 'current_var_or_name' here.
            if "/" in current_var_or_name:
                scopes = current_var_or_name[:current_var_or_name.rindex("/")]
            if not scopes and current_var_or_name != "/":
                # An empty scope is only meaningful for the root '/'. For any
                # other value (e.g. a misspelled variable name without '/'),
                # the loop below would match every variable and strip the
                # first character of each name.
                raise ValueError(
                    "Variable {} (mapped from checkpoint entry {}) is not found "
                    "in the current graph and is not a valid scope name. Use "
                    "'scope_name/' for a scope or '/' for the root.".format(
                        current_var_or_name, tensor_name_in_ckpt))
            if not tensor_name_in_ckpt.endswith("/"):
                raise ValueError(
                    "Assignment map with scope only name {} should map to scope only "
                    "{}. Should be 'scope/': 'other_scope/'.".format(
                        scopes, tensor_name_in_ckpt))
            # If scope to scope mapping was provided, find all variables in the
            # scope and create variable to variable mapping.
            scope_variables = set()
            for var_name in store_vars:
                if not scopes or var_name.startswith(scopes + "/"):
                    # Consume /part_ if partitioned variable.
                    if "/part_" in var_name:
                        var_name = var_name[:var_name.index("/part_")]
                    scope_variables.add(var_name)
            for var_name in sorted(scope_variables):
                # Lookup name with specified prefix and suffix from current
                # variable. If tensor_name given is '/' (root), don't use it for
                # full name.
                full_tensor_name = var_name[len(scopes):]
                if current_var_or_name != "/":
                    full_tensor_name = full_tensor_name[1:]
                if tensor_name_in_ckpt != "/":
                    full_tensor_name = tensor_name_in_ckpt + full_tensor_name
                # Remove trailing '/', if any, in the full_tensor_name
                if full_tensor_name.endswith("/"):
                    full_tensor_name = full_tensor_name[:-1]
                if full_tensor_name not in variable_map:
                    raise ValueError(
                        "Tensor %s (%s in %s) is not found in %s checkpoint" %
                        (full_tensor_name, var_name[len(scopes) + 1:],
                         tensor_name_in_ckpt, ckpt_dir_or_file))
                var = store_vars.get(var_name, None)
                if var is None:
                    var = _collect_partitioned_variable(var_name, store_vars)
                _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
                logging.debug(
                    "Initialize variable %s from checkpoint %s with %s",
                    var_name, ckpt_dir_or_file, full_tensor_name)
def __call__(self,
             sgv,
             dst_graph,
             dst_scope,
             src_scope="",
             reuse_dst_scope=False):
    """Execute the transformation.

    Args:
      sgv: the source subgraph-view.
      dst_graph: the destination graph.
      dst_scope: the destination scope.
      src_scope: the source scope, which specifies the path from which the
        relative path of the transformed nodes are computed. For instance, if
        src_scope is a/ and dst_scope is b/, then the node a/x/y will have a
        relative path of x/y and will be transformed into b/x/y.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by appending an underscore followed by a digit (default).

    Returns:
      A tuple `(sgv, info)` where:
        `sgv` is the transformed subgraph view;
        `info` is an instance of Transformer.ResultInfo containing
        information about the transform, including mapping between
        original and transformed tensors and operations.

    Raises:
      TypeError: if `dst_graph` is not a `tf.Graph`.
      ValueError: if the arguments are invalid, or if a transformed input of
        a cyclic op cannot be found.
    """
    sgv = subgraph.make_view(sgv)
    if not isinstance(dst_graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))

    src_scope = util.scope_finalize(src_scope)
    dst_scope = util.scope_finalize(dst_scope)

    # Potentially create new scope if reuse_dst_scope is False.
    # `dst_scope[:-1]` strips the trailing "/" before asking for a unique name.
    if dst_scope and not reuse_dst_scope:
        dst_scope = util.scope_finalize(dst_graph.unique_name(dst_scope[:-1]))

    # Create temporary info used during this transform call.
    self._info = Transformer._Info(self, sgv, dst_graph, dst_scope, src_scope)
    try:
        info = self._info

        # Transform the graph starting from the output tensors.
        for output_t in info.sgv.outputs:
            self._transform_t(output_t)

        # Some ops might have been missed by the previous walk, namely, the
        # roots without any outputs. So the walk is now finalized from those
        # roots.
        remaining_roots = [
            op for op in info.sgv.ops
            if op not in info.transformed_ops and not op.outputs
        ]
        for op in remaining_roots:
            self._transform_op(op)

        # Finalize cyclic ops: their inputs could not all be transformed
        # during the walk, so they are rewired now.
        for op in info.cyclic_ops:
            logging.debug("Finalizing cyclic op: %s", op.name)
            op_ = info.transformed_ops[op]
            # `.get` so that a missing input reaches the explicit ValueError
            # below instead of surfacing as a bare KeyError.
            inputs_ = [info.transformed_ts.get(t) for t in op.inputs]
            if None in inputs_:
                raise ValueError(
                    "Could not find all the inputs of cyclic op: {}".format(
                        op_.name))
            for input_id, t_ in enumerate(inputs_):
                op_._update_input(input_id, t_)  # pylint: disable=protected-access

        sgv_ = self._transform_sgv(sgv)
        res_info = Transformer.ResultInfo(info)
    finally:
        # Drop the per-call state even when the transform fails, so a failed
        # call does not leave a stale `_info` on the transformer.
        self._info = None
    return sgv_, res_info
def every_n_step_end(self, step, outputs):
    """Evaluates the newest checkpoint and applies early stopping.

    Evaluation is skipped (returning False) when no more than
    `_check_interval_secs` seconds have passed since the last checkpoint
    check, when no checkpoint exists yet in the estimator's `model_dir`, or
    when the latest checkpoint is the one that was already evaluated.

    Args:
      step: Current training step.
      outputs: Outputs forwarded to the parent `every_n_step_end`.

    Returns:
      True if early stopping was triggered, False otherwise.

    Raises:
      ValueError: If `set_estimator` has not been called, or if
        `early_stopping_metric` is missing from the evaluation outputs.
    """
    super(ValidationMonitor, self).every_n_step_end(step, outputs)
    # TODO(mdan): The use of step below is probably misleading.
    # The code should probably use the step from the checkpoint, because
    # that's what is being evaluated.
    if self._estimator is None:
        raise ValueError("Missing call to set_estimator.")

    current_time = time.time()
    if (self._check_interval_secs is not None and
            self._last_checkpoint_check_time is not None and
            current_time - self._last_checkpoint_check_time <=
            self._check_interval_secs):
        logging.debug(
            "Skipping evaluation since less than %d seconds have passed since "
            "last check for a new checkpoint.", self._check_interval_secs)
        return False
    self._last_checkpoint_check_time = current_time

    # Check that we are not running evaluation on the same checkpoint.
    latest_path = checkpoint_management.latest_checkpoint(
        self._estimator.model_dir)
    if latest_path is None:
        logging.debug("Skipping evaluation since model has not been saved yet "
                      "at step %d.", step)
        return False
    # `latest_path` is known to be non-None from here on.
    if latest_path == self._latest_path:
        logging.debug("Skipping evaluation due to same checkpoint %s for step %d "
                      "as for step %d.", latest_path, step,
                      self._latest_path_step)
        return False
    self._latest_path = latest_path
    self._latest_path_step = step

    # Run evaluation and log it.
    validation_outputs = self._evaluate_estimator()
    stats = ["%s = %s" % (name, str(validation_outputs[name]))
             for name in validation_outputs]
    logging.info("Validation (step %d): %s", step, ", ".join(stats))

    # Early stopping logic.
    if self.early_stopping_rounds is None:
        return False
    if self.early_stopping_metric not in validation_outputs:
        raise ValueError("Metric %s missing from outputs %s." %
                         (self.early_stopping_metric,
                          set(validation_outputs.keys())))
    current_value = validation_outputs[self.early_stopping_metric]
    if self._best_value is None:
        improved = True
    elif self.early_stopping_metric_minimize:
        improved = current_value < self._best_value
    else:
        improved = current_value > self._best_value
    if improved:
        self._best_value = current_value
        self._best_metrics = copy.deepcopy(validation_outputs)
        self._best_value_step = step
    # Stop once `early_stopping_rounds` steps passed without improvement.
    if step - self._best_value_step >= self.early_stopping_rounds:
        logging.info("Stopping. Best step: %s with %s = %s.",
                     self._best_value_step, self.early_stopping_metric,
                     self._best_value)
        self._early_stopped = True
        return True
    return False
def quantize(
        saved_model_path: str,
        signature_keys: Optional[List[str]] = None,
        tags: Optional[Iterable[str]] = None,
        output_directory: Optional[str] = None,
        quantization_options: Optional[quant_opts_pb2.QuantizationOptions] = None,
        representative_dataset: Optional[_RepresentativeDataset] = None) -> ...:
    """Quantizes the given SavedModel.

    Args:
      saved_model_path: Path to the saved model. When representative_dataset
        is not provided, this should be a model trained with QAT.
      signature_keys: List of keys identifying SignatureDef containing inputs
        and outputs. If None, ["serving_default"] is used.
      tags: Set of tags identifying the MetaGraphDef within the SavedModel to
        analyze. If None, {"serve"} is used.
      output_directory: The path to save the output SavedModel (must be an
        empty directory).
      quantization_options: A set of options for quantization.
      representative_dataset: a generator that returns a dictionary in
        {input_name: input_tensor} format or a tuple with signature key and a
        dictionary in {input_name: input_tensor} format that feeds calibration
        data for quantizing model. This should be provided when the model is
        a PTQ model.

    Returns:
      A SavedModel object with TF quantization applied, or None if no
      quantization is performed.

    Raises:
      ValueError: When 1) representative_dataset is not provided for non QAT
        model for enabling static range quantization, or 2) invalid value is
        provided as a quantization method.
      NotImplementedError: When the specified quantization method is not yet
        implemented.
    """
    if tags is None:
        tags = {tag_constants.SERVING}
    if signature_keys is None:
        signature_keys = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    if quantization_options is None:
        quantization_options = quant_opts_pb2.QuantizationOptions()

    method: quant_opts_pb2.QuantizationMethod = (
        quantization_options.quantization_method)
    if method.HasField('method'):
        raise ValueError(f'Invalid value for QuantizationMethod: {method.method}.')

    if method.HasField('experimental_method'):
        experimental_method = method.experimental_method
        if experimental_method == _ExperimentalMethod.DYNAMIC_RANGE:
            return _dynamic_range_quantize(saved_model_path, signature_keys, tags,
                                           output_directory)
        if experimental_method != _ExperimentalMethod.STATIC_RANGE:
            # f-prefix is required so the offending method value is reported.
            raise NotImplementedError(
                f'Experimental quantization method {experimental_method}'
                ' is not implemented.')
    else:
        logging.debug(
            'Neither "method" nor "experimental_method" for QuantizationMethod '
            'is specified. Static range quantization is used by default.')

    # Explicit STATIC_RANGE and the unspecified default both end up here.
    return _static_range_quantize(saved_model_path, signature_keys, tags,
                                  output_directory, representative_dataset)
def run_episode(self, env, sess, features, labels, no_run_hooks, global_step,
                update_episode_op, update_timestep_op, estimator_spec):
    """Runs one episode in `env`, applying a TRPO policy update whenever the
    memory can be sampled.

    We need to differentiate between the `global_timestep` and `global_step`.

    The `global_step` gets updated directly by the `train_op` and has an effect
    on the training learning rate, especially if it gets decayed.

    The `global_timestep` on the other hand is related to the episode and how
    many times our agent acted. It has an effect on the exploration rate and
    how it's annealed.

    Args:
      env: `Environment` instance.
      sess: `MonitoredTrainingSession` instance.
      features: features passed through to `self._prepare_feed_dict`.
      labels: labels passed through to `self._prepare_feed_dict`; the
        `'tangents'` and `'theta'` entries are also used directly as feed
        keys during the policy update.
      no_run_hooks: op/tensor fetched first in every `sess.run` call here.
      global_step: `global_step` tensor, fetched on each acting step.
      update_episode_op: op run once when the episode is done.
      update_timestep_op: op run on each acting step.
      estimator_spec: `EstimatorSpec` instance.

    Returns:
      The loss evaluated after the most recent policy update performed in this
      episode, or None if no update was performed.
    """
    def apply_policy_update(feed_dict, policy_gradient):
        """Performs one TRPO step and returns the loss evaluated afterwards."""
        # TODO: move logic to
        # tensorflow.contrib.solvers.python.ops.linear_equations.conjugate_gradient
        # NOTE(review): `conjugate_gradient`, `line_search` and `EPSILON` are
        # assumed to be module-level helpers -- confirm.

        def fisher_vector_product(p):
            # Mutates `feed_dict`, which the loss evaluations below also reuse.
            feed_dict[labels['tangents']] = p
            _, fvp = sess.run(
                [no_run_hooks,
                 estimator_spec.predictions['fisher_vector_product']],
                feed_dict)
            return fvp + self.optimizer_params['cg_damping'] * p

        direction = conjugate_gradient(fisher_vector_product, -policy_gradient,
                                       self.optimizer_params['cg_iterations'])
        # 0.5 * d^T (F d), with F the damped Fisher product computed above.
        shs = 0.5 * direction.dot(fisher_vector_product(direction))
        lagrange_multiplier = np.sqrt(
            shs / self.optimizer_params['max_kl_divergence'])
        update_step = direction / (lagrange_multiplier + EPSILON)
        negative_gradient_direction = -policy_gradient.dot(direction)

        _, previous_theta = sess.run(
            [no_run_hooks, estimator_spec.extra_ops['get_theta']])

        def compute_loss(theta):
            sess.run([no_run_hooks, estimator_spec.extra_ops['set_theta']],
                     feed_dict={labels['theta']: theta})
            return sess.run([no_run_hooks, estimator_spec.loss],
                            feed_dict=feed_dict)[1]

        improved, theta = line_search(
            compute_loss, previous_theta, update_step,
            negative_gradient_direction / (lagrange_multiplier + EPSILON),
            self.optimizer_params['line_search_iterations'])

        if improved:
            logging.info('Updating with line search.')
            sess.run([no_run_hooks, estimator_spec.extra_ops['set_theta']],
                     feed_dict={labels['theta']: theta})
        elif self.optimizer_params['override_line_search']:
            logging.info('Updating with full step.')
            sess.run([no_run_hooks, estimator_spec.extra_ops['set_theta']],
                     feed_dict={labels['theta']: previous_theta + update_step})
        else:
            logging.debug('No update.')

        return sess.run([estimator_spec.loss], feed_dict=feed_dict)[0]

    env_spec = env.reset()
    last_in_memory = self.memory.get_by_index(-1)
    if last_in_memory is not None:
        dist_values = last_in_memory['dist_values']
    elif env.is_continuous:
        dist_values = [0] * 2 * env.num_actions
    else:
        dist_values = [0] * env.num_actions
    stats = Stats()
    loss = None
    while not env_spec.done:
        data = env_spec.to_dict()
        data['dist_values'] = dist_values
        # Only the action and next distribution are used; the step/timestep
        # values are fetched but not read here.
        _, _, _, action, next_dist_values = sess.run(
            [no_run_hooks, global_step, update_timestep_op,
             estimator_spec.predictions['results'],
             estimator_spec.predictions['dist_values']],
            feed_dict=self._prepare_feed_dict('act', features, labels, data))
        env_spec = env.step(action, env_spec.next_state)
        self.memory.step(dist_values=dist_values, **env_spec.to_dict())
        dist_values = next_dist_values[0]
        stats.rewards.append(env_spec.reward)

        if self.memory.can_sample():
            logging.info('Updating model.')
            feed_dict = self._prepare_feed_dict(
                'observe', features, labels, self.memory.sample(), stats,
                from_memory=True)
            _, policy_gradient = sess.run(
                [no_run_hooks, estimator_spec.predictions['policy_gradient']],
                feed_dict=feed_dict)
            if np.allclose(policy_gradient, np.zeros_like(policy_gradient)):
                # Skip only this update; the episode keeps running so its
                # end-of-episode bookkeeping still happens.
                logging.debug('Gradient zero, skipping update')
            else:
                loss = apply_policy_update(feed_dict, policy_gradient)

        if env_spec.done:
            last_in_memory = self.memory.get_by_index(-1)
            sess.run([update_episode_op], feed_dict=self._prepare_feed_dict(
                'observe', features, labels, last_in_memory, stats))

    return loss
def warm_start(ckpt_to_initialize_from,
               vars_to_warm_start=".*",
               var_name_to_vocab_info=None,
               var_name_to_prev_var_name=None):
    """Warm-starts a model using the given settings.

    If you are using a tf.estimator.Estimator, this will automatically be called
    during training.

    Args:
      ckpt_to_initialize_from: [Required] A string specifying the directory with
        checkpoint file(s) or path to checkpoint from which to warm-start the
        model parameters.
      vars_to_warm_start: [Optional] One of the following:

        - A regular expression (string) that captures which variables to
          warm-start (see tf.compat.v1.get_collection). This expression will
          only consider variables in the TRAINABLE_VARIABLES collection -- if
          you need to warm-start non_TRAINABLE vars (such as optimizer
          accumulators or batch norm statistics), please use the below option.
        - A list of strings, each a regex scope provided to
          tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see
          tf.compat.v1.get_collection). For backwards compatibility reasons,
          this is separate from the single-string argument type.
        - A list of Variables to warm-start. If you do not have access to the
          `Variable` objects at the call site, please use the above option.
        - `None`, in which case only TRAINABLE variables specified in
          `var_name_to_vocab_info` will be warm-started.

        Defaults to `'.*'`, which warm-starts all variables in the
        TRAINABLE_VARIABLES collection. Note that this excludes variables such
        as accumulators and moving statistics from batch norm.
      var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
        `tf.estimator.VocabInfo`. The variable names should be "full"
        variables, not the names of the partitions. If not explicitly
        provided, the variable is assumed to have no (changes to) vocabulary.
      var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
        name of the previously-trained variable in `ckpt_to_initialize_from`.
        If not explicitly provided, the name of the variable is assumed to be
        same between previous checkpoint and current model. Note that this has
        no effect on the set of variables that is warm-started, and only
        controls name mapping (use `vars_to_warm_start` for controlling what
        variables to warm-start).

    Raises:
      ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
        configuration for variable names that are not used. This is to ensure
        a stronger check for variable configuration than relying on users to
        examine the logs.
    """
    if var_name_to_vocab_info is None:
        var_name_to_vocab_info = {}
    if var_name_to_prev_var_name is None:
        var_name_to_prev_var_name = {}
    # Pass the path itself (not a 1-tuple) so the log shows the plain path.
    logging.info("Warm-starting from: %s", ckpt_to_initialize_from)
    grouped_variables = _get_grouped_variables(vars_to_warm_start)

    warmstarted_count = 0

    # Keep track of which var_names in var_name_to_prev_var_name and
    # var_name_to_vocab_info have been used. Err on the safer side by throwing
    # an exception if any are unused by the end of the loop. It is easy to
    # misname a variable during this configuration, in which case without this
    # check, we would fail to warm-start silently.
    prev_var_name_used = set()
    vocab_info_used = set()

    # Group the vocabless vars into one call to init_from_checkpoint.
    vocabless_vars = {}
    for var_name, variable in six.iteritems(grouped_variables):
        prev_var_name = var_name_to_prev_var_name.get(var_name)
        if prev_var_name:
            prev_var_name_used.add(var_name)
        vocab_info = var_name_to_vocab_info.get(var_name)
        if vocab_info:
            vocab_info_used.add(var_name)
            warmstarted_count += 1
            logging.debug(
                "Warm-starting variable: %s; current_vocab: %s "
                "current_vocab_size: %s prev_vocab: %s prev_vocab_size: %s "
                "current_oov: %s prev_tensor: %s initializer: %s",
                var_name, vocab_info.new_vocab, vocab_info.new_vocab_size,
                vocab_info.old_vocab,
                (vocab_info.old_vocab_size if vocab_info.old_vocab_size > 0
                 else "All"),
                vocab_info.num_oov_buckets,
                prev_var_name or "Unchanged",
                vocab_info.backup_initializer or "zero-initialized")
            _warm_start_var_with_vocab(
                variable,
                current_vocab_path=vocab_info.new_vocab,
                current_vocab_size=vocab_info.new_vocab_size,
                prev_ckpt=ckpt_to_initialize_from,
                prev_vocab_path=vocab_info.old_vocab,
                previous_vocab_size=vocab_info.old_vocab_size,
                current_oov_buckets=vocab_info.num_oov_buckets,
                prev_tensor_name=prev_var_name,
                initializer=vocab_info.backup_initializer,
                axis=vocab_info.axis)
        elif vars_to_warm_start:
            # For the special value of vars_to_warm_start = None,
            # we only warm-start variables with explicitly specified
            # vocabularies, so vocabless variables are skipped in that case.
            warmstarted_count += 1
            logging.debug("Warm-starting variable: %s; prev_var_name: %s",
                          var_name, prev_var_name or "Unchanged")
            # Because we use a default empty list in grouped_variables, single
            # unpartitioned variables will be lists here, which we rectify in
            # order for init_from_checkpoint logic to work correctly.
            if len(variable) == 1:
                variable = variable[0]
            prev_tensor_name, var = _get_var_info(variable, prev_var_name)
            vocabless_vars[prev_tensor_name] = var

    checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from, vocabless_vars)
    prev_var_name_not_used = set(var_name_to_prev_var_name) - prev_var_name_used
    vocab_info_not_used = set(var_name_to_vocab_info) - vocab_info_used

    logging.info("Warm-started %d variables.", warmstarted_count)

    if prev_var_name_not_used:
        raise ValueError(
            "You provided the following variables in "
            "var_name_to_prev_var_name that were not used: "
            "{0}. Perhaps you misspelled them? Here is the list of viable "
            "variable names: {1}".format(prev_var_name_not_used,
                                         grouped_variables.keys()))
    if vocab_info_not_used:
        raise ValueError(
            "You provided the following variables in "
            "var_name_to_vocab_info that were not used: {0}. "
            " Perhaps you misspelled them? Here is the list of viable variable "
            "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))