Example #1
    def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
        """Find variable by local or global name."""
        assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
        return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
Example #2
    def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Undo the effect of dynamic loss scaling for the given expression."""
        assert tfutil.is_tf_expression(value)

        if not self.use_loss_scaling:
            return value

        return value * tfutil.exp2(-self.get_loss_scaling_var(value.device))  # pylint: disable=invalid-unary-operand-type
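
undo_loss_scaling is the exact inverse of apply_loss_scaling (Example #4 below): both multiply by a power of two, exp2(+s) and exp2(-s), so the round trip only shifts the floating-point exponent and is lossless as long as no overflow occurs. A standalone NumPy sketch of that identity (not repo code):

import numpy as np

s = 10.0                          # log2 of the scaling factor, as held by the loss-scaling variable
value = np.float32(3.14159)
scaled = value * np.exp2(s)       # what apply_loss_scaling computes
restored = scaled * np.exp2(-s)   # what undo_loss_scaling computes
assert restored == value          # exact: powers of two rescale the exponent only
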
Example #3
    def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
        """Get the local name of a given variable, without any surrounding name scopes."""
        assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
        global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
        return self.var_global_to_local[global_name]
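
var_global_to_local maps a variable's TF graph name (scope prefix included, ':0' suffix stripped) back to the short per-network name used as a key in self.vars. A small lookup sketch with an illustrative scope name:

# Hypothetical lookup, assuming `net` was built under the scope 'G_synthesis':
# the variable named 'G_synthesis/lod:0' in the TF graph has the local name 'lod'.
print(net.get_var_local_name('G_synthesis/lod'))   # -> 'lod'
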
Example #4
    def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Apply dynamic loss scaling for the given expression."""
        assert tfutil.is_tf_expression(value)

        if not self.use_loss_scaling:
            return value

        return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
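
Loss scaling protects small float16 gradients from underflowing to zero: the loss is multiplied by exp2(s) before gradients are computed, and the matching exp2(-s) factor (Example #2) is removed from the gradients afterwards. A self-contained NumPy illustration of the underflow it prevents (the numbers are made up for the demo):

import numpy as np

s = 10.0                                    # log2 of the scaling factor
true_grad = 1e-8                            # a gradient too small for float16
print(np.float16(true_grad))                # 0.0 -- underflows when stored directly
scaled = np.float16(true_grad * 2.0 ** s)   # survives because the loss (hence the gradient) was pre-scaled
print(np.float32(scaled) * 2.0 ** -s)       # ~1e-08, recovered once the scaling is undone in float32
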
Example #5
    def get_output_for(
            self,
            *in_expr: TfExpression,
            return_as_list: bool = False,
            **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
        """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s)."""
        assert len(in_expr) == self.num_inputs
        assert not all(expr is None for expr in in_expr)

        # Finalize build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs.update(dynamic_kwargs)
        build_kwargs["is_template_graph"] = False
        build_kwargs["components"] = self.components

        # Build TensorFlow graph to evaluate the network.
        with tfutil.absolute_variable_scope(
                self.scope, reuse=True), tf.name_scope(self.name):
            assert tf.get_variable_scope().name == self.scope
            valid_inputs = [expr for expr in in_expr if expr is not None]
            final_inputs = []
            for expr, name, shape in zip(in_expr, self.input_names,
                                         self.input_shapes):
                if expr is not None:
                    expr = tf.identity(expr, name=name)
                else:
                    expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:],
                                    name=name)
                final_inputs.append(expr)
            out_expr = self._build_func(*final_inputs, **build_kwargs)

        # Propagate input shapes back to the user-specified expressions.
        for expr, final in zip(in_expr, final_inputs):
            if isinstance(expr, tf.Tensor):
                expr.set_shape(final.shape)

        # Express outputs in the desired format.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        if return_as_list:
            out_expr = [
                out_expr
            ] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        return out_expr
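
A usage sketch for get_output_for, assuming a two-input network (latents and labels, as in the generator in Example #8 below) and a surrounding TensorFlow 1.x graph; the placeholder names are illustrative:

import tensorflow as tf          # TensorFlow 1.x API
# `net` is assumed to be an already-constructed tflib.Network with two inputs.
latents = tf.placeholder(tf.float32, [None] + net.input_shapes[0][1:])
labels = tf.placeholder(tf.float32, [None] + net.input_shapes[1][1:])
images_out = net.get_output_for(latents, labels, is_training=True)
# return_as_list=True always yields a list, even for a single-output network.
outputs = net.get_output_for(latents, labels, is_training=True, return_as_list=True)
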
Example #6
    def register_gradients(self, loss: TfExpression,
                           trainable_vars: Union[List, dict]) -> None:
        """Register the gradients of the given loss function with respect to the given variables.
        Intended to be called once per GPU."""
        assert not self._updates_applied

        # Validate arguments.
        if isinstance(trainable_vars, dict):
            trainable_vars = list(trainable_vars.values(
            ))  # allow passing in Network.trainables as vars

        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(
            tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])

        if self._grad_shapes is None:
            self._grad_shapes = [
                tfutil.shape_to_list(var.shape) for var in trainable_vars
            ]

        assert len(trainable_vars) == len(self._grad_shapes)
        assert all(
            tfutil.shape_to_list(var.shape) == var_shape
            for var, var_shape in zip(trainable_vars, self._grad_shapes))

        dev = loss.device

        assert all(var.device == dev for var in trainable_vars)

        # Register device and compute gradients.
        with tf.name_scope(self.id + "_grad"), tf.device(dev):
            if dev not in self._dev_opt:
                opt_name = self.scope.replace(
                    "/", "_") + "_opt%d" % len(self._dev_opt)
                assert callable(self.optimizer_class)
                self._dev_opt[dev] = self.optimizer_class(
                    name=opt_name,
                    learning_rate=self.learning_rate,
                    **self.optimizer_kwargs)
                self._dev_grads[dev] = []

            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            grads = self._dev_opt[dev].compute_gradients(
                loss,
                trainable_vars,
                gate_gradients=tf.train.Optimizer.GATE_NONE
            )  # disable gating to reduce memory usage
            grads = [(g, v) if g is not None else (tf.zeros_like(v), v)
                     for g, v in grads
                     ]  # replace disconnected gradients with zeros
            self._dev_grads[dev].append(grads)
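
A per-GPU usage sketch, assuming the tflib.Optimizer constructor accepts name and learning_rate as in the StyleGAN codebase and that its apply_updates() method (not shown in this excerpt) builds the combined update op; compute_loss and minibatch_inputs are illustrative stand-ins for the caller's own loss construction:

import tensorflow as tf          # TensorFlow 1.x API
import dnnlib.tflib as tflib

opt = tflib.Optimizer(name='TrainG', learning_rate=0.001)
for gpu in range(num_gpus):
    with tf.device('/gpu:%d' % gpu):
        loss = compute_loss(net, minibatch_inputs[gpu])   # illustrative: scalar loss for this GPU tower
        opt.register_gradients(loss, net.trainables)      # a Network.trainables dict is accepted directly
train_op = opt.apply_updates()   # builds the op that applies the accumulated per-GPU gradients
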
Example #7
def autosummary(name: str,
                value: TfExpressionEx,
                passthru: TfExpressionEx = None) -> TfExpressionEx:
    """Create a new autosummary.

    Args:
        name:     Name to use in TensorBoard
        value:    TensorFlow expression or python value to track
        passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.

    Example use of the passthru mechanism:

    n = autosummary('l2loss', loss, passthru=n)

    This is a shorthand for the following code:

    with tf.control_dependencies([autosummary('l2loss', loss)]):
        n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    name_id = name.replace("/", "_")

    if tfutil.is_tf_expression(value):
        with tf.name_scope("summary_" + name_id), tf.device(value.device):
            update_op = _create_var(name, value)
            with tf.control_dependencies([update_op]):
                return tf.identity(value if passthru is None else passthru)

    else:  # python scalar or numpy array
        if name not in _immediate:
            with tfutil.absolute_name_scope(
                    "Autosummary/" +
                    name_id), tf.device(None), tf.control_dependencies(None):
                update_value = tf.placeholder(_dtype)
                update_op = _create_var(name, update_value)
                _immediate[name] = update_op, update_value

        update_op, update_value = _immediate[name]
        tfutil.run(update_op, {update_value: value})
        return value if passthru is None else passthru
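
The two branches above correspond to two calling styles: given a TF expression, the summary update is wired into the graph as a control dependency; given a plain Python or NumPy value, it is executed immediately through a cached placeholder. A short sketch of both, assuming an initialized TF session (the names and values are illustrative):

# Graph-mode: attach the summary update as a side effect of the loss node itself.
loss = autosummary('Loss/G', loss)
# Immediate-mode, inside a training loop: log a plain Python scalar right away.
autosummary('Progress/kimg', cur_nimg / 1000.0)
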
Example #8
def G_style(
    latents_in,                                     # First input: Latent vectors (Z) [minibatch, latent_size].
    labels_in,                                      # Second input: Conditioning labels [minibatch, label_size].
    truncation_psi          = 0.7,                  # Style strength multiplier for the truncation trick. None = disable.
    truncation_cutoff       = 8,                    # Number of layers for which to apply the truncation trick. None = disable.
    truncation_psi_val      = None,                 # Value for truncation_psi to use during validation.
    truncation_cutoff_val   = None,                 # Value for truncation_cutoff to use during validation.
    dlatent_avg_beta        = 0.995,                # Decay for tracking the moving average of W during training. None = disable.
    style_mixing_prob       = 0.9,                  # Probability of mixing styles during training. None = disable.
    is_training             = False,                # Network is under training? Enables and disables specific features.
    is_validation           = False,                # Network is under validation? Chooses which value to use for truncation_psi.
    is_template_graph       = False,                # True = template graph constructed by the Network class, False = actual evaluation.
    components              = dnnlib.EasyDict(),    # Container for sub-networks. Retained between calls.
    **kwargs):                                      # Arguments for sub-networks (G_mapping and G_synthesis).

    # Validate arguments.
    assert not is_training or not is_validation
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val
    if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1):
        truncation_psi = None
    if is_training or (truncation_cutoff is not None and not tflib.is_tf_expression(truncation_cutoff) and truncation_cutoff <= 0):
        truncation_cutoff = None
    if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network('G_synthesis', func_name=G_synthesis, **kwargs)
    num_layers = components.synthesis.input_shape[1]
    dlatent_size = components.synthesis.input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network('G_mapping', func_name=G_mapping, dlatent_broadcast=num_layers, **kwargs)

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable('dlatent_avg', shape=[dlatent_size], initializer=tf.initializers.zeros(), trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(latents_in, labels_in, **kwargs)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            update_op = tf.assign(dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.name_scope('StyleMix'):
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = components.mapping.get_output_for(latents2, labels_in, **kwargs)
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            mixing_cutoff = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)
            dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2)

    # Apply truncation trick.
    if truncation_psi is not None and truncation_cutoff is not None:
        with tf.variable_scope('Truncation'):
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            coefs = tf.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones)
            dlatents = tflib.lerp(dlatent_avg, dlatents, coefs)

    # Evaluate synthesis network.
    with tf.control_dependencies([tf.assign(components.synthesis.find_var('lod'), lod_in)]):
        images_out = components.synthesis.get_output_for(dlatents, force_clean_graph=is_template_graph, **kwargs)
    return tf.identity(images_out, name='images_out')
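
The 'Truncation' block above applies truncation_psi only to the first truncation_cutoff layers and then interpolates the dlatents toward dlatent_avg with those per-layer coefficients (tflib.lerp(a, b, t) computes a + (b - a) * t). A NumPy-only sketch of the coefficient construction, with illustrative sizes:

import numpy as np

num_layers, truncation_psi, truncation_cutoff = 18, 0.7, 8      # illustrative values
layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]    # shape [1, num_layers, 1]
ones = np.ones(layer_idx.shape, dtype=np.float32)
coefs = np.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones)
print(coefs.squeeze())   # 0.7 for layers 0..7, 1.0 for layers 8..17
# dlatents = dlatent_avg + (dlatents - dlatent_avg) * coefs     # what tflib.lerp computes above
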
Example #9
    def run(
        self,
        *in_arrays: Tuple[Union[np.ndarray, None], ...],
        input_transform: dict = None,
        output_transform: dict = None,
        return_as_list: bool = False,
        print_progress: bool = False,
        minibatch_size: int = None,
        num_gpus: int = 1,
        assume_frozen: bool = False,
        **dynamic_kwargs
    ) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
        """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).

        Args:
            input_transform:    A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
                                The dict must contain a 'func' field that points to a top-level function. The function is called with the input
                                TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
            output_transform:   A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
                                The dict must contain a 'func' field that points to a top-level function. The function is called with the output
                                TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
            return_as_list:     True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
            print_progress:     Print progress to the console? Useful for very large input arrays.
            minibatch_size:     Maximum minibatch size to use, None = disable batching.
            num_gpus:           Number of GPUs to use.
            assume_frozen:      Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
            dynamic_kwargs:     Additional keyword arguments to be passed into the network build function.
        """
        assert len(in_arrays) == self.num_inputs
        assert not all(arr is None for arr in in_arrays)
        assert input_transform is None or util.is_top_level_function(
            input_transform["func"])
        assert output_transform is None or util.is_top_level_function(
            output_transform["func"])
        output_transform, dynamic_kwargs = _handle_legacy_output_transforms(
            output_transform, dynamic_kwargs)
        num_items = in_arrays[0].shape[0]
        if minibatch_size is None:
            minibatch_size = num_items

        # Construct unique hash key from all arguments that affect the TensorFlow graph.
        key = dict(input_transform=input_transform,
                   output_transform=output_transform,
                   num_gpus=num_gpus,
                   assume_frozen=assume_frozen,
                   dynamic_kwargs=dynamic_kwargs)

        def unwind_key(obj):
            if isinstance(obj, dict):
                return [(key, unwind_key(value))
                        for key, value in sorted(obj.items())]
            if callable(obj):
                return util.get_top_level_function_name(obj)
            return obj

        key = repr(unwind_key(key))

        # Build graph.
        if key not in self._run_cache:
            with tfutil.absolute_name_scope(
                    self.scope + "/_Run"), tf.control_dependencies(None):
                with tf.device("/cpu:0"):
                    in_expr = [
                        tf.placeholder(tf.float32, name=name)
                        for name in self.input_names
                    ]
                    in_split = list(
                        zip(*[tf.split(x, num_gpus) for x in in_expr]))

                out_split = []
                for gpu in range(num_gpus):
                    with tf.device("/gpu:%d" % gpu):
                        net_gpu = self.clone() if assume_frozen else self
                        in_gpu = in_split[gpu]

                        if input_transform is not None:
                            in_kwargs = dict(input_transform)
                            in_gpu = in_kwargs.pop("func")(*in_gpu,
                                                           **in_kwargs)
                            in_gpu = [in_gpu] if tfutil.is_tf_expression(
                                in_gpu) else list(in_gpu)

                        assert len(in_gpu) == self.num_inputs
                        out_gpu = net_gpu.get_output_for(*in_gpu,
                                                         return_as_list=True,
                                                         **dynamic_kwargs)

                        if output_transform is not None:
                            out_kwargs = dict(output_transform)
                            out_gpu = out_kwargs.pop("func")(*out_gpu,
                                                             **out_kwargs)
                            out_gpu = [out_gpu] if tfutil.is_tf_expression(
                                out_gpu) else list(out_gpu)

                        assert len(out_gpu) == self.num_outputs
                        out_split.append(out_gpu)

                with tf.device("/cpu:0"):
                    out_expr = [
                        tf.concat(outputs, axis=0)
                        for outputs in zip(*out_split)
                    ]
                    self._run_cache[key] = in_expr, out_expr

        # Run minibatches.
        in_expr, out_expr = self._run_cache[key]
        out_arrays = [
            np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:],
                     expr.dtype.name) for expr in out_expr
        ]

        for mb_begin in range(0, num_items, minibatch_size):
            if print_progress:
                print("\r%d / %d" % (mb_begin, num_items), end="")

            mb_end = min(mb_begin + minibatch_size, num_items)
            mb_num = mb_end - mb_begin
            mb_in = [
                src[mb_begin:mb_end]
                if src is not None else np.zeros([mb_num] + shape[1:])
                for src, shape in zip(in_arrays, self.input_shapes)
            ]
            mb_out = tf.get_default_session().run(out_expr,
                                                  dict(zip(in_expr, mb_in)))

            for dst, src in zip(out_arrays, mb_out):
                dst[mb_begin:mb_end] = src

        # Done.
        if print_progress:
            print("\r%d / %d" % (num_items, num_items))

        if not return_as_list:
            out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(
                out_arrays)
        return out_arrays
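
A typical inference call, assuming tflib.init_tf() has been run and `Gs` is an already-constructed generator network (this mirrors the pretrained-network example in the StyleGAN repository); convert_images_to_uint8 is a helper exported by the same tflib package:

import numpy as np
import dnnlib.tflib as tflib

latents = np.random.randn(8, *Gs.input_shape[1:])                    # batch of 8 latent vectors
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)    # output_transform spec
images = Gs.run(latents, None, truncation_psi=0.7, minibatch_size=4, output_transform=fmt)
# A None input is replaced by zeros of the matching shape (see the zip over input_shapes above).
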
Example #10
    def _init_graph(self) -> None:
        # Collect inputs.
        self.input_names = []

        for param in inspect.signature(self._build_func).parameters.values():
            if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
                self.input_names.append(param.name)

        self.num_inputs = len(self.input_names)
        assert self.num_inputs >= 1

        # Choose name and scope.
        if self.name is None:
            self.name = self._build_func_name
        assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
        with tf.name_scope(None):
            self.scope = tf.get_default_graph().unique_name(self.name,
                                                            mark_as_used=True)

        # Finalize build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs["is_template_graph"] = True
        build_kwargs["components"] = self.components

        # Build template graph.
        with tfutil.absolute_variable_scope(
                self.scope, reuse=tf.AUTO_REUSE), tfutil.absolute_name_scope(
                    self.scope):  # ignore surrounding scopes
            assert tf.get_variable_scope().name == self.scope
            assert tf.get_default_graph().get_name_scope() == self.scope
            with tf.control_dependencies(
                    None):  # ignore surrounding control dependencies
                self.input_templates = [
                    tf.placeholder(tf.float32, name=name)
                    for name in self.input_names
                ]
                out_expr = self._build_func(*self.input_templates,
                                            **build_kwargs)

        # Collect outputs.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        self.output_templates = [
            out_expr
        ] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        self.num_outputs = len(self.output_templates)
        assert self.num_outputs >= 1
        assert all(tfutil.is_tf_expression(t) for t in self.output_templates)

        # Perform sanity checks.
        if any(t.shape.ndims is None for t in self.input_templates):
            raise ValueError(
                "Network input shapes not defined. Please call x.set_shape() for each input."
            )
        if any(t.shape.ndims is None for t in self.output_templates):
            raise ValueError(
                "Network output shapes not defined. Please call x.set_shape() where applicable."
            )
        if any(not isinstance(comp, Network)
               for comp in self.components.values()):
            raise ValueError(
                "Components of a Network must be Networks themselves.")
        if len(self.components) != len(
                set(comp.name for comp in self.components.values())):
            raise ValueError("Components of a Network must have unique names.")

        # List inputs and outputs.
        self.input_shapes = [
            tfutil.shape_to_list(t.shape) for t in self.input_templates
        ]
        self.output_shapes = [
            tfutil.shape_to_list(t.shape) for t in self.output_templates
        ]
        self.input_shape = self.input_shapes[0]
        self.output_shape = self.output_shapes[0]
        self.output_names = [
            t.name.split("/")[-1].split(":")[0] for t in self.output_templates
        ]

        # List variables.
        self.own_vars = OrderedDict(
            (var.name[len(self.scope) + 1:].split(":")[0], var)
            for var in tf.global_variables(self.scope + "/"))
        self.vars = OrderedDict(self.own_vars)
        self.vars.update((comp.name + "/" + name, var)
                         for comp in self.components.values()
                         for name, var in comp.vars.items())
        self.trainables = OrderedDict(
            (name, var) for name, var in self.vars.items() if var.trainable)
        self.var_global_to_local = OrderedDict(
            (var.name.split(":")[0], name) for name, var in self.vars.items())
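
The inputs of a Network are thus inferred from the build function's signature: every positional parameter without a default becomes a named input placeholder, while everything else must arrive through static or dynamic kwargs. A minimal sketch of a build function and the construction it enables, assuming tflib.Network accepts a top-level build function as in the StyleGAN codebase (my_build_func and its arguments are illustrative, not from the repo):

import tensorflow as tf          # TensorFlow 1.x API
import dnnlib.tflib as tflib

def my_build_func(x, width=64, is_template_graph=False, components=None, **_kwargs):
    # 'x' is the only positional parameter without a default, so it becomes the single network input.
    x.set_shape([None, width])                   # input shapes must be defined (see the sanity checks above)
    return tf.identity(x, name='y')

net = tflib.Network('MyNet', func_name=my_build_func, width=64)
print(net.input_names, net.output_names)         # ['x'] ['y']
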