def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
    """Find variable by local or global name."""
    assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
    return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
    """Get the local name of a given variable, without any surrounding name scopes."""
    assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
    global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
    return self.var_global_to_local[global_name]
def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
    """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s)."""
    assert len(in_expr) == self.num_inputs
    assert not all(expr is None for expr in in_expr)

    # Finalize build func kwargs.
    build_kwargs = dict(self.static_kwargs)
    build_kwargs.update(dynamic_kwargs)
    build_kwargs["is_template_graph"] = False
    build_kwargs["components"] = self.components

    # Build TensorFlow graph to evaluate the network.
    with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
        assert tf.get_variable_scope().name == self.scope
        valid_inputs = [expr for expr in in_expr if expr is not None]
        final_inputs = []
        for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
            if expr is not None:
                expr = tf.identity(expr, name=name)
            else:
                expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
            final_inputs.append(expr)
        out_expr = self._build_func(*final_inputs, **build_kwargs)

    # Propagate input shapes back to the user-specified expressions.
    for expr, final in zip(in_expr, final_inputs):
        if isinstance(expr, tf.Tensor):
            expr.set_shape(final.shape)

    # Express outputs in the desired format.
    assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
    if return_as_list:
        out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
    return out_expr
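
# Usage sketch (illustrative, not part of the library): how get_output_for() is typically used to
# wire an existing Network into a larger TF1 graph. It assumes `net` is an already-built
# dnnlib.tflib Network whose build function takes two inputs (e.g. latents and labels); the
# placeholder names and the `is_training` kwarg are assumptions, not guaranteed by this API.
def _example_get_output_for(net):
    import tensorflow as tf
    latents_in = tf.placeholder(tf.float32, [None] + net.input_shapes[0][1:], name="latents_in")
    labels_in = tf.placeholder(tf.float32, [None] + net.input_shapes[1][1:], name="labels_in")
    # Builds output expressions that reuse the network's existing variables (reuse=True above).
    images_out = net.get_output_for(latents_in, labels_in, is_training=True)
    return latents_in, labels_in, images_out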
def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
    """Register the gradients of the given loss function with respect to the given variables. Intended to be called once per GPU."""
    tfutil.assert_tf_initialized()
    assert not self._updates_applied
    device = self._get_device(loss.device)

    # Validate trainables.
    if isinstance(trainable_vars, dict):
        trainable_vars = list(trainable_vars.values())  # allow passing in Network.trainables as vars
    assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
    assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
    assert all(var.device == device.name for var in trainable_vars)

    # Validate shapes.
    if self._gradient_shapes is None:
        self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
    assert len(trainable_vars) == len(self._gradient_shapes)
    assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes))

    # Report memory usage if requested.
    deps = []
    if self._report_mem_usage:
        self._report_mem_usage = False
        try:
            with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
                deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
        except tf.errors.NotFoundError:
            pass

    # Compute gradients.
    with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
        loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
        gate = tf.train.Optimizer.GATE_NONE  # disable gating to reduce memory usage
        grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate)

    # Register gradients.
    for grad, var in grad_list:
        if var not in device.grad_raw:
            device.grad_raw[var] = []
        device.grad_raw[var].append(grad)
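
# Usage sketch (illustrative, not part of the library): the per-GPU pattern register_gradients()
# is designed for. Build the loss on each GPU against variables that live on that GPU (the
# assertions above require loss and variables to share a device), register the gradients once per
# GPU, then apply all accumulated updates in one op. `opt` is assumed to be a dnnlib.tflib
# Optimizer; `make_per_gpu_net` and `compute_loss` are hypothetical helpers standing in for
# network cloning and loss construction.
def _example_register_gradients(opt, make_per_gpu_net, compute_loss, num_gpus=2):
    import tensorflow as tf
    for gpu in range(num_gpus):
        with tf.device("/gpu:%d" % gpu):
            net_gpu = make_per_gpu_net(gpu)                   # network whose variables live on this GPU
            loss = compute_loss(net_gpu)                      # scalar loss expression on the same device
            opt.register_gradients(loss, net_gpu.trainables)  # called once per GPU, as documented
    return opt.apply_updates()                                # aggregates the registered gradients and applies them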
def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx:
    """Create a new autosummary.

    Args:
        name:      Name to use in TensorBoard
        value:     TensorFlow expression or python value to track
        passthru:  Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
        condition: TensorFlow expression or python value; the summary is only updated when it evaluates to True.

    Example use of the passthru mechanism:

    n = autosummary('l2loss', loss, passthru=n)

    This is a shorthand for the following code:

    with tf.control_dependencies([autosummary('l2loss', loss)]):
        n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    name_id = name.replace("/", "_")

    if tfutil.is_tf_expression(value):
        with tf.name_scope("summary_" + name_id), tf.device(value.device):
            condition = tf.convert_to_tensor(condition, name='condition')
            update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op)
            with tf.control_dependencies([update_op]):
                return tf.identity(value if passthru is None else passthru)

    else:  # python scalar or numpy array
        assert not tfutil.is_tf_expression(passthru)
        assert not tfutil.is_tf_expression(condition)
        if condition:
            if name not in _immediate:
                with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
                    update_value = tf.placeholder(_dtype)
                    update_op = _create_var(name, update_value)
                    _immediate[name] = update_op, update_value
            update_op, update_value = _immediate[name]
            tfutil.run(update_op, {update_value: value})
        return value if passthru is None else passthru
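
# Usage sketch (illustrative, not part of the library): both calling conventions of autosummary().
# A plain Python value is pushed into the summary variable immediately, while a TF expression gets
# a graph-side update attached via the passthru mechanism described in the docstring above.
# Assumes TensorFlow has been initialized through tfutil and that `loss` is a scalar TF expression;
# the summary names are arbitrary.
def _example_autosummary(loss, learning_rate):
    # Python scalar: the value is recorded right away via tfutil.run().
    autosummary("Progress/learning_rate", learning_rate)
    # TF expression with passthru: returns `loss` unchanged but with an update side-effect attached.
    return autosummary("Loss/total", loss, passthru=loss)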
def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
    """Undo the effect of dynamic loss scaling for the given expression."""
    assert tfutil.is_tf_expression(value)
    if not self.use_loss_scaling:
        return value
    return value * tfutil.exp2(-self.get_loss_scaling_var(value.device))  # pylint: disable=invalid-unary-operand-type
def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
    """Apply dynamic loss scaling for the given expression."""
    assert tfutil.is_tf_expression(value)
    if not self.use_loss_scaling:
        return value
    return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
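
# Note: apply_loss_scaling() and undo_loss_scaling() are exact inverses; the per-device scaling
# variable holds a log2 exponent, so multiplying by exp2(+var) and later by exp2(-var) cancels.
# A minimal sketch (not the library's own update path) of pairing them around gradient
# computation, assuming `opt` is an Optimizer with use_loss_scaling enabled and `var_list` holds
# trainable variables on the same device as `loss`:
def _example_loss_scaling(opt, loss, var_list):
    import tensorflow as tf
    scaled_loss = opt.apply_loss_scaling(tf.cast(loss, tf.float32))  # scale up before differentiation
    scaled_grads = tf.gradients(scaled_loss, var_list)
    return [opt.undo_loss_scaling(g) for g in scaled_grads]          # scale back down afterwards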
def run(self,
        *in_arrays: Tuple[Union[np.ndarray, None], ...],
        input_transform: dict = None,
        output_transform: dict = None,
        return_as_list: bool = False,
        print_progress: bool = False,
        minibatch_size: int = None,
        num_gpus: int = 1,
        assume_frozen: bool = False,
        **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
    """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).

    Args:
        input_transform:    A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
                            The dict must contain a 'func' field that points to a top-level function. The function is called with the input
                            TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
        output_transform:   A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
                            The dict must contain a 'func' field that points to a top-level function. The function is called with the output
                            TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
        return_as_list:     True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
        print_progress:     Print progress to the console? Useful for very large input arrays.
        minibatch_size:     Maximum minibatch size to use, None = disable batching.
        num_gpus:           Number of GPUs to use.
        assume_frozen:      Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
        dynamic_kwargs:     Additional keyword arguments to be passed into the network build function.
    """
    assert len(in_arrays) == self.num_inputs
    assert not all(arr is None for arr in in_arrays)
    assert input_transform is None or util.is_top_level_function(input_transform["func"])
    assert output_transform is None or util.is_top_level_function(output_transform["func"])
    output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
    num_items = in_arrays[0].shape[0]
    if minibatch_size is None:
        minibatch_size = num_items

    # Construct unique hash key from all arguments that affect the TensorFlow graph.
    key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)

    def unwind_key(obj):
        if isinstance(obj, dict):
            return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
        if callable(obj):
            return util.get_top_level_function_name(obj)
        return obj

    key = repr(unwind_key(key))

    # Build graph.
    if key not in self._run_cache:
        with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
            with tf.device("/cpu:0"):
                in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))

            out_split = []
            for gpu in range(num_gpus):
                with tf.device("/gpu:%d" % gpu):
                    net_gpu = self.clone() if assume_frozen else self
                    in_gpu = in_split[gpu]

                    if input_transform is not None:
                        in_kwargs = dict(input_transform)
                        in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
                        in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)

                    assert len(in_gpu) == self.num_inputs
                    out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)

                    if output_transform is not None:
                        out_kwargs = dict(output_transform)
                        out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
                        out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)

                    assert len(out_gpu) == self.num_outputs
                    out_split.append(out_gpu)

            with tf.device("/cpu:0"):
                out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
                self._run_cache[key] = in_expr, out_expr

    # Run minibatches.
    in_expr, out_expr = self._run_cache[key]
    out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr]

    for mb_begin in range(0, num_items, minibatch_size):
        if print_progress:
            print("\r%d / %d" % (mb_begin, num_items), end="")

        mb_end = min(mb_begin + minibatch_size, num_items)
        mb_num = mb_end - mb_begin
        mb_in = [src[mb_begin:mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
        mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))

        for dst, src in zip(out_arrays, mb_out):
            dst[mb_begin:mb_end] = src

    # Done.
    if print_progress:
        print("\r%d / %d" % (num_items, num_items))

    if not return_as_list:
        out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
    return out_arrays
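
# Usage sketch (illustrative, not part of the library): evaluating a Network on NumPy inputs with
# run(). Assumes `net` is a two-input Network (e.g. latents and labels) and that TensorFlow has
# been initialized; `convert_images_to_uint8` is passed in as a stand-in for any top-level
# function usable as an output_transform, and the `nchw_to_nhwc` kwarg is an assumption about
# that function rather than part of run() itself.
def _example_run(net, convert_images_to_uint8, num_images=8):
    import numpy as np
    latents = np.random.randn(num_images, *net.input_shapes[0][1:])
    labels = None  # omitted inputs are replaced with zeros of the correct shape (see mb_in above)
    images = net.run(latents, labels,
                     minibatch_size=4,
                     num_gpus=1,
                     output_transform=dict(func=convert_images_to_uint8, nchw_to_nhwc=True))
    return images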
def _init_graph(self) -> None:
    # Collect inputs.
    self.input_names = []
    for param in inspect.signature(self._build_func).parameters.values():
        if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
            self.input_names.append(param.name)
    self.num_inputs = len(self.input_names)
    assert self.num_inputs >= 1

    # Choose name and scope.
    if self.name is None:
        self.name = self._build_func_name
    assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
    with tf.name_scope(None):
        self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)

    # Finalize build func kwargs.
    build_kwargs = dict(self.static_kwargs)
    build_kwargs["is_template_graph"] = True
    build_kwargs["components"] = self.components

    # Build template graph.
    with tfutil.absolute_variable_scope(self.scope, reuse=False), tfutil.absolute_name_scope(self.scope):  # ignore surrounding scopes
        assert tf.get_variable_scope().name == self.scope
        assert tf.get_default_graph().get_name_scope() == self.scope
        with tf.control_dependencies(None):  # ignore surrounding control dependencies
            self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
            out_expr = self._build_func(*self.input_templates, **build_kwargs)

    # Collect outputs.
    assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
    self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
    self.num_outputs = len(self.output_templates)
    assert self.num_outputs >= 1
    assert all(tfutil.is_tf_expression(t) for t in self.output_templates)

    # Perform sanity checks.
    if any(t.shape.ndims is None for t in self.input_templates):
        raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
    if any(t.shape.ndims is None for t in self.output_templates):
        raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
    if any(not isinstance(comp, Network) for comp in self.components.values()):
        raise ValueError("Components of a Network must be Networks themselves.")
    if len(self.components) != len(set(comp.name for comp in self.components.values())):
        raise ValueError("Components of a Network must have unique names.")

    # List inputs and outputs.
    self.input_shapes = [t.shape.as_list() for t in self.input_templates]
    self.output_shapes = [t.shape.as_list() for t in self.output_templates]
    self.input_shape = self.input_shapes[0]
    self.output_shape = self.output_shapes[0]
    self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]

    # List variables.
    self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
    self.vars = OrderedDict(self.own_vars)
    self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
    self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
    self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())