def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
    """Find variable by local or global name."""
    assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
    return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name

def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
    """Undo the effect of dynamic loss scaling for the given expression."""
    assert tfutil.is_tf_expression(value)
    if not self.use_loss_scaling:
        return value
    return value * tfutil.exp2(-self.get_loss_scaling_var(value.device))  # pylint: disable=invalid-unary-operand-type

def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
    """Get the local name of a given variable, without any surrounding name scopes."""
    assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
    global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
    return self.var_global_to_local[global_name]

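# Illustrative usage sketch (not part of the original file): how the local and global
# variable name spaces relate. Assumes `net` is an already-constructed Network; "lod"
# stands in for any variable name that exists in net.vars.
def _example_var_lookup(net):
    var = net.find_var("lod")                       # local name -> tf.Variable
    assert net.find_var(var) is var                 # a TF expression is passed through unchanged
    global_name = var.name.split(":")[0]            # e.g. "<scope>/lod"
    assert net.get_var_local_name(global_name) == "lod"
    return var
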
def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
    """Apply dynamic loss scaling for the given expression."""
    assert tfutil.is_tf_expression(value)
    if not self.use_loss_scaling:
        return value
    return value * tfutil.exp2(self.get_loss_scaling_var(value.device))

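# Minimal sketch of how the two loss-scaling helpers pair up (not part of the original
# file). register_gradients() below applies the scaling internally; this standalone
# sketch only shows why apply and undo must bracket the gradient computation.
# `opt` is an Optimizer with use_loss_scaling enabled; `raw_loss` and `trainable_vars`
# are placeholders for the caller's graph.
def _example_loss_scaling(opt, raw_loss, trainable_vars):
    scaled_loss = opt.apply_loss_scaling(tf.cast(raw_loss, tf.float32))  # multiply by 2**log2_scale
    grads = tf.gradients(scaled_loss, trainable_vars)                    # gradients carry the same scale factor
    grads = [opt.undo_loss_scaling(g) for g in grads]                    # divide it back out before applying updates
    return grads
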
def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
    """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s)."""
    assert len(in_expr) == self.num_inputs
    assert not all(expr is None for expr in in_expr)

    # Finalize build func kwargs.
    build_kwargs = dict(self.static_kwargs)
    build_kwargs.update(dynamic_kwargs)
    build_kwargs["is_template_graph"] = False
    build_kwargs["components"] = self.components

    # Build TensorFlow graph to evaluate the network.
    with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
        assert tf.get_variable_scope().name == self.scope
        valid_inputs = [expr for expr in in_expr if expr is not None]
        final_inputs = []
        for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
            if expr is not None:
                expr = tf.identity(expr, name=name)
            else:
                expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
            final_inputs.append(expr)
        out_expr = self._build_func(*final_inputs, **build_kwargs)

    # Propagate input shapes back to the user-specified expressions.
    for expr, final in zip(in_expr, final_inputs):
        if isinstance(expr, tf.Tensor):
            expr.set_shape(final.shape)

    # Express outputs in the desired format.
    assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
    if return_as_list:
        out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
    return out_expr

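# Usage sketch (not part of the original file): wiring a Network into a larger graph.
# `G` is assumed to be a constructed Network with two inputs (latents, labels); the
# placeholder names are hypothetical.
def _example_get_output_for(G):
    latents = tf.placeholder(tf.float32, shape=[None] + G.input_shapes[0][1:], name='latents_in')
    labels = tf.placeholder(tf.float32, shape=[None] + G.input_shapes[1][1:], name='labels_in')
    images = G.get_output_for(latents, labels, is_training=True)  # dynamic kwargs are forwarded to the build func
    return images
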
def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
    """Register the gradients of the given loss function with respect to the given variables.
    Intended to be called once per GPU."""
    assert not self._updates_applied

    # Validate arguments.
    if isinstance(trainable_vars, dict):
        trainable_vars = list(trainable_vars.values())  # allow passing in Network.trainables as vars
    assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
    assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])

    if self._grad_shapes is None:
        self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars]
    assert len(trainable_vars) == len(self._grad_shapes)
    assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes))

    dev = loss.device
    assert all(var.device == dev for var in trainable_vars)

    # Register device and compute gradients.
    with tf.name_scope(self.id + "_grad"), tf.device(dev):
        if dev not in self._dev_opt:
            opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt)
            assert callable(self.optimizer_class)
            self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
            self._dev_grads[dev] = []
        loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
        grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE)  # disable gating to reduce memory usage
        grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads]  # replace disconnected gradients with zeros
        self._dev_grads[dev].append(grads)

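# Intended call pattern, sketched for illustration (not part of the original file):
# register one loss per GPU, then build the combined update op once. `opt` is an
# Optimizer, `G` a Network, and `loss_fn` a hypothetical helper that constructs the
# per-GPU loss expression. apply_updates() is assumed to be the counterpart of
# register_gradients, as suggested by the _updates_applied flag above.
def _example_register_gradients(opt, G, loss_fn, num_gpus=2):
    for gpu in range(num_gpus):
        with tf.device("/gpu:%d" % gpu):
            loss = loss_fn(G)                            # loss lives on this GPU's device
            opt.register_gradients(loss, G.trainables)   # dict form is accepted and converted to a list
    return opt.apply_updates()                           # combines the registered per-device gradients into a training op
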
def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx:
    """Create a new autosummary.

    Args:
        name:     Name to use in TensorBoard
        value:    TensorFlow expression or python value to track
        passthru: Optionally return this TF node without modifications but tack an
                  autosummary update side-effect to this node.

    Example use of the passthru mechanism:

    n = autosummary('l2loss', loss, passthru=n)

    This is a shorthand for the following code:

    with tf.control_dependencies([autosummary('l2loss', loss)]):
        n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    name_id = name.replace("/", "_")

    if tfutil.is_tf_expression(value):
        with tf.name_scope("summary_" + name_id), tf.device(value.device):
            update_op = _create_var(name, value)
            with tf.control_dependencies([update_op]):
                return tf.identity(value if passthru is None else passthru)

    else:  # python scalar or numpy array
        if name not in _immediate:
            with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
                update_value = tf.placeholder(_dtype)
                update_op = _create_var(name, update_value)
                _immediate[name] = update_op, update_value

        update_op, update_value = _immediate[name]
        tfutil.run(update_op, {update_value: value})
        return value if passthru is None else passthru

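# Illustrative sketch (not part of the original file): the two modes of autosummary.
# `loss` is a TF expression inside the graph; `lr` is a plain Python float reported
# once per training iteration from the outer loop.
def _example_autosummary(loss, lr):
    # Graph mode: the summary update rides along as a control dependency on the returned identity.
    loss = autosummary('Loss/total', loss)
    # Immediate mode: a Python scalar is fed through a cached placeholder and update op right away.
    autosummary('Progress/learning_rate', lr)
    return loss
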
def G_style(
    latents_in,                                 # First input: Latent vectors (Z) [minibatch, latent_size].
    labels_in,                                  # Second input: Conditioning labels [minibatch, label_size].
    truncation_psi        = 0.7,                # Style strength multiplier for the truncation trick. None = disable.
    truncation_cutoff     = 8,                  # Number of layers for which to apply the truncation trick. None = disable.
    truncation_psi_val    = None,               # Value for truncation_psi to use during validation.
    truncation_cutoff_val = None,               # Value for truncation_cutoff to use during validation.
    dlatent_avg_beta      = 0.995,              # Decay for tracking the moving average of W during training. None = disable.
    style_mixing_prob     = 0.9,                # Probability of mixing styles during training. None = disable.
    is_training           = False,              # Network is under training? Enables and disables specific features.
    is_validation         = False,              # Network is under validation? Chooses which value to use for truncation_psi.
    is_template_graph     = False,              # True = template graph constructed by the Network class, False = actual evaluation.
    components            = dnnlib.EasyDict(),  # Container for sub-networks. Retained between calls.
    **kwargs):                                  # Arguments for sub-networks (G_mapping and G_synthesis).

    # Validate arguments.
    assert not is_training or not is_validation
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val
    if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1):
        truncation_psi = None
    if is_training or (truncation_cutoff is not None and not tflib.is_tf_expression(truncation_cutoff) and truncation_cutoff <= 0):
        truncation_cutoff = None
    if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network('G_synthesis', func_name=G_synthesis, **kwargs)
    num_layers = components.synthesis.input_shape[1]
    dlatent_size = components.synthesis.input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network('G_mapping', func_name=G_mapping, dlatent_broadcast=num_layers, **kwargs)

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable('dlatent_avg', shape=[dlatent_size], initializer=tf.initializers.zeros(), trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(latents_in, labels_in, **kwargs)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            update_op = tf.assign(dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.name_scope('StyleMix'):
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = components.mapping.get_output_for(latents2, labels_in, **kwargs)
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            mixing_cutoff = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)
            dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2)

    # Apply truncation trick.
    if truncation_psi is not None and truncation_cutoff is not None:
        with tf.variable_scope('Truncation'):
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            coefs = tf.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones)
            dlatents = tflib.lerp(dlatent_avg, dlatents, coefs)

    # Evaluate synthesis network.
    with tf.control_dependencies([tf.assign(components.synthesis.find_var('lod'), lod_in)]):
        images_out = components.synthesis.get_output_for(dlatents, force_clean_graph=is_template_graph, **kwargs)
    return tf.identity(images_out, name='images_out')

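# Worked sketch of the truncation coefficients above (not part of the original file):
# for num_layers=18, truncation_cutoff=8, truncation_psi=0.7, layers 0-7 are lerped
# toward dlatent_avg with factor 0.7 while layers 8-17 are left untouched (factor 1.0).
def _example_truncation_coefs(num_layers=18, truncation_cutoff=8, truncation_psi=0.7):
    layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]       # shape [1, num_layers, 1]
    ones = np.ones(layer_idx.shape, dtype=np.float32)
    coefs = np.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones)
    return coefs  # per-layer lerp factors applied via tflib.lerp(dlatent_avg, dlatents, coefs)
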
def run(self,
        *in_arrays: Tuple[Union[np.ndarray, None], ...],
        input_transform: dict = None,
        output_transform: dict = None,
        return_as_list: bool = False,
        print_progress: bool = False,
        minibatch_size: int = None,
        num_gpus: int = 1,
        assume_frozen: bool = False,
        **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
    """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).

    Args:
        input_transform:    A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
                            The dict must contain a 'func' field that points to a top-level function. The function is called with the input
                            TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
        output_transform:   A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
                            The dict must contain a 'func' field that points to a top-level function. The function is called with the output
                            TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
        return_as_list:     True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
        print_progress:     Print progress to the console? Useful for very large input arrays.
        minibatch_size:     Maximum minibatch size to use, None = disable batching.
        num_gpus:           Number of GPUs to use.
        assume_frozen:      Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
        dynamic_kwargs:     Additional keyword arguments to be passed into the network build function.
    """
    assert len(in_arrays) == self.num_inputs
    assert not all(arr is None for arr in in_arrays)
    assert input_transform is None or util.is_top_level_function(input_transform["func"])
    assert output_transform is None or util.is_top_level_function(output_transform["func"])
    output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
    num_items = in_arrays[0].shape[0]
    if minibatch_size is None:
        minibatch_size = num_items

    # Construct unique hash key from all arguments that affect the TensorFlow graph.
    key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)

    def unwind_key(obj):
        if isinstance(obj, dict):
            return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
        if callable(obj):
            return util.get_top_level_function_name(obj)
        return obj

    key = repr(unwind_key(key))

    # Build graph.
    if key not in self._run_cache:
        with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
            with tf.device("/cpu:0"):
                in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))

            out_split = []
            for gpu in range(num_gpus):
                with tf.device("/gpu:%d" % gpu):
                    net_gpu = self.clone() if assume_frozen else self
                    in_gpu = in_split[gpu]

                    if input_transform is not None:
                        in_kwargs = dict(input_transform)
                        in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
                        in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)

                    assert len(in_gpu) == self.num_inputs
                    out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)

                    if output_transform is not None:
                        out_kwargs = dict(output_transform)
                        out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
                        out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)

                    assert len(out_gpu) == self.num_outputs
                    out_split.append(out_gpu)

            with tf.device("/cpu:0"):
                out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
                self._run_cache[key] = in_expr, out_expr

    # Run minibatches.
    in_expr, out_expr = self._run_cache[key]
    out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr]

    for mb_begin in range(0, num_items, minibatch_size):
        if print_progress:
            print("\r%d / %d" % (mb_begin, num_items), end="")

        mb_end = min(mb_begin + minibatch_size, num_items)
        mb_num = mb_end - mb_begin
        mb_in = [src[mb_begin:mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
        mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))

        for dst, src in zip(out_arrays, mb_out):
            dst[mb_begin:mb_end] = src

    # Done.
    if print_progress:
        print("\r%d / %d" % (num_items, num_items))

    if not return_as_list:
        out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
    return out_arrays

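# Usage sketch (not part of the original file): running a trained generator on NumPy inputs.
# `Gs` is assumed to be a Network whose two inputs are latents and labels; the transform dict
# follows the output_transform convention documented above, and convert_images_to_uint8 is
# assumed to be the image-conversion helper available in tflib.
def _example_run(Gs, latents, labels):
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    images = Gs.run(latents, labels, truncation_psi=0.7, output_transform=fmt, minibatch_size=8)
    return images  # a single NumPy array, since the generator has one output
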
def _init_graph(self) -> None:
    # Collect inputs.
    self.input_names = []
    for param in inspect.signature(self._build_func).parameters.values():
        if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
            self.input_names.append(param.name)
    self.num_inputs = len(self.input_names)
    assert self.num_inputs >= 1

    # Choose name and scope.
    if self.name is None:
        self.name = self._build_func_name
    assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
    with tf.name_scope(None):
        self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)

    # Finalize build func kwargs.
    build_kwargs = dict(self.static_kwargs)
    build_kwargs["is_template_graph"] = True
    build_kwargs["components"] = self.components

    # Build template graph.
    with tfutil.absolute_variable_scope(self.scope, reuse=tf.AUTO_REUSE), tfutil.absolute_name_scope(self.scope):  # ignore surrounding scopes
        assert tf.get_variable_scope().name == self.scope
        assert tf.get_default_graph().get_name_scope() == self.scope
        with tf.control_dependencies(None):  # ignore surrounding control dependencies
            self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
            out_expr = self._build_func(*self.input_templates, **build_kwargs)

    # Collect outputs.
    assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
    self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
    self.num_outputs = len(self.output_templates)
    assert self.num_outputs >= 1
    assert all(tfutil.is_tf_expression(t) for t in self.output_templates)

    # Perform sanity checks.
    if any(t.shape.ndims is None for t in self.input_templates):
        raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
    if any(t.shape.ndims is None for t in self.output_templates):
        raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
    if any(not isinstance(comp, Network) for comp in self.components.values()):
        raise ValueError("Components of a Network must be Networks themselves.")
    if len(self.components) != len(set(comp.name for comp in self.components.values())):
        raise ValueError("Components of a Network must have unique names.")

    # List inputs and outputs.
    self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates]
    self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates]
    self.input_shape = self.input_shapes[0]
    self.output_shape = self.output_shapes[0]
    self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]

    # List variables.
    self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
    self.vars = OrderedDict(self.own_vars)
    self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
    self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
    self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())