Ejemplo n.º 1
0
 def reset_optimizer_state(self) -> None:
     """Re-run the initializer of every variable owned by the per-device optimizers.

     Walks all entries of ``self._devices``, collects the ``initializer`` op of
     each variable reported by ``device.optimizer.variables()``, and executes
     them in a single ``tfutil.run`` call.
     """
     tfutil.assert_tf_initialized()

     init_ops = []
     for dev in self._devices.values():
         for opt_var in dev.optimizer.variables():
             init_ops.append(opt_var.initializer)

     tfutil.run(init_ops)
Ejemplo n.º 2
0
 def copy_own_vars_from(self, src_net: "Network") -> None:
     """Copy variable values from ``src_net`` into this network, skipping sub-networks.

     Only variables whose name appears in both ``self.own_vars`` and
     ``src_net.own_vars`` are copied; names present on one side only are ignored.

     Args:
         src_net: Network to read the variable values from.
     """
     # Names of own variables that also exist in the source network.
     shared_names = []
     for var_name in self.own_vars.keys():
         if var_name in src_net.own_vars:
             shared_names.append(var_name)

     # Map each destination variable to the matching source variable.
     dst_to_src = {}
     for var_name in shared_names:
         dst_to_src[self.vars[var_name]] = src_net.vars[var_name]

     # NOTE(review): tfutil.run presumably evaluates the dict values and returns
     # a dict with the same keys, which set_vars then assigns - confirm in tfutil.
     fetched = tfutil.run(dst_to_src)
     tfutil.set_vars(fetched)
Ejemplo n.º 3
0
def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx:
    """Register a value to be tracked as an autosummary.

    Args:
        name:      Name to use in TensorBoard.
        value:     TensorFlow expression or python value to track.
        passthru:  Optionally return this TF node without modifications but
                   tack an autosummary update side-effect to this node.
        condition: Whether the update should happen. When `value` is a
                   TensorFlow expression this is converted to a tensor and
                   gates the update through tf.cond; otherwise it must be a
                   plain python value and is tested for truthiness.

    Returns:
        `passthru` if it was given, otherwise `value`. When `value` is a
        TensorFlow expression the result is wrapped in tf.identity with a
        control dependency on the summary update op.

    Example use of the passthru mechanism:

    n = autosummary('l2loss', loss, passthru=n)

    This is a shorthand for the following code:

    with tf.control_dependencies([autosummary('l2loss', loss)]):
        n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    scope_suffix = name.replace("/", "_")

    if not tfutil.is_tf_expression(value):
        # Immediate mode: value is a python scalar or numpy array.
        assert not tfutil.is_tf_expression(passthru)
        assert not tfutil.is_tf_expression(condition)
        if condition:
            if name not in _immediate:
                # Build the placeholder/update-op pair once per name and cache it.
                with tfutil.absolute_name_scope("Autosummary/" + scope_suffix), tf.device(None), tf.control_dependencies(None):
                    feed_slot = tf.placeholder(_dtype)
                    _immediate[name] = _create_var(name, feed_slot), feed_slot
            cached_op, cached_slot = _immediate[name]
            tfutil.run(cached_op, {cached_slot: value})
        if passthru is None:
            return value
        return passthru

    # Graph mode: attach the update as a side effect of the returned node.
    with tf.name_scope("summary_" + scope_suffix), tf.device(value.device):
        condition = tf.convert_to_tensor(condition, name='condition')

        def _tracked_update():
            return tf.group(_create_var(name, value))

        gated_update = tf.cond(condition, _tracked_update, tf.no_op)
        with tf.control_dependencies([gated_update]):
            if passthru is None:
                return tf.identity(value)
            return tf.identity(passthru)
Ejemplo n.º 4
0
 def __getstate__(self) -> dict:
     """Build the picklable state dictionary for this network.

     The state holds a format version, the network name, shallow copies of
     ``static_kwargs`` and ``components``, the build module source and build
     function name, and a list of ``(name, value)`` pairs for the network's
     own variables as evaluated through ``tfutil.run``.
     """
     state = {
         "version": 4,
         "name": self.name,
         "static_kwargs": dict(self.static_kwargs),
         "components": dict(self.components),
         "build_module_src": self._build_module_src,
         "build_func_name": self._build_func_name,
     }

     # Evaluate all own variables in one run call and pair them with their names.
     var_names = self.own_vars.keys()
     var_values = tfutil.run(list(self.own_vars.values()))
     state["variables"] = list(zip(var_names, var_values))
     return state
Ejemplo n.º 5
0
    def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients.

        The graph is built in four stages: per-device gradient clean-up and
        scaling, summation of gradients across devices (only with more than one
        device), per-device accumulation / application of the gradients together
        with the loss-scaling adjustment, and finally initialization of the
        variables created along the way.

        May only be called once per instance: it asserts that
        ``self._updates_applied`` is not yet set and sets it immediately, also
        on the no-op path.

        Args:
            allow_no_op: If True and no devices are registered, return a bare
                ``tf.no_op`` named 'TrainingOp' instead of building the graph.

        Returns:
            A single op named "TrainingOp" under ``self.scope`` that groups all
            update ops created here.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        all_ops = []

        # Check for no-op.
        if allow_no_op and len(self._devices) == 0:
            with tfutil.absolute_name_scope(self.scope):
                return tf.no_op(name='TrainingOp')

        # Clean up gradients.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Clean%d" %
                                            device_idx), tf.device(
                                                device.name):
                # device.grad_raw maps each variable to a list of gradient
                # entries; entries may be None.
                for var, grad in device.grad_raw.items():

                    # Filter out disconnected gradients and convert to float32.
                    grad = [g for g in grad if g is not None]
                    grad = [tf.cast(g, tf.float32) for g in grad]

                    # Sum within the device.
                    if len(grad) == 0:
                        grad = tf.zeros(var.shape)  # No gradients => zero.
                    elif len(grad) == 1:
                        grad = grad[0]  # Single gradient => use as is.
                    else:
                        grad = tf.add_n(grad)  # Multiple gradients => sum.

                    # Scale as needed.
                    # The divisor uses the length of the original raw list
                    # (None entries included), not the filtered one, and is
                    # further divided by the number of devices.
                    scale = 1.0 / len(device.grad_raw[var]) / len(
                        self._devices)
                    scale = tf.constant(scale, dtype=tf.float32, name="scale")
                    if self.minibatch_multiplier is not None:
                        scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                    # NOTE(review): undo_loss_scaling is defined elsewhere;
                    # presumably it compensates for the loss scaling applied to
                    # the loss - confirm.
                    scale = self.undo_loss_scaling(scale)
                    device.grad_clean[var] = grad * scale

        # Sum gradients across devices.
        if len(self._devices) > 1:
            with tfutil.absolute_name_scope(self.scope +
                                            "/Broadcast"), tf.device(None):
                # zip() pairs the i-th grad_clean key of every device.
                # NOTE(review): this assumes all devices hold their variables
                # in the same order - confirm against the registration code.
                for all_vars in zip(*[
                        device.grad_clean.keys()
                        for device in self._devices.values()
                ]):
                    if len(all_vars) > 0 and all(
                            dim > 0 for dim in all_vars[0].shape.as_list()
                    ):  # NCCL does not support zero-sized tensors.
                        all_grads = [
                            device.grad_clean[var] for device, var in zip(
                                self._devices.values(), all_vars)
                        ]
                        # The results are written back one per device, in
                        # device order.
                        all_grads = nccl_ops.all_sum(all_grads)
                        for device, var, grad in zip(self._devices.values(),
                                                     all_vars, all_grads):
                            device.grad_clean[var] = grad

        # Apply updates separately on each device.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Apply%d" %
                                            device_idx), tf.device(
                                                device.name):
                # pylint: disable=cell-var-from-loop
                # The lambdas below capture loop variables, but each one is
                # handed to tf.cond within the same iteration.

                # Accumulate gradients over time.
                if self.minibatch_multiplier is None:
                    # No accumulation: the clean gradients are used directly
                    # and acc_ok is always True.
                    acc_ok = tf.constant(True, name='acc_ok')
                    device.grad_acc = OrderedDict(device.grad_clean)
                else:
                    # Create variables.
                    # Built under control_dependencies(None) so the variables
                    # do not inherit control dependencies from the enclosing
                    # context.
                    with tf.control_dependencies(None):
                        for var in device.grad_clean.keys():
                            device.grad_acc_vars[var] = tf.Variable(
                                tf.zeros(var.shape),
                                trainable=False,
                                name="grad_acc_var")
                        device.grad_acc_count = tf.Variable(
                            tf.zeros([]),
                            trainable=False,
                            name="grad_acc_count")

                    # Track counter.
                    # acc_ok becomes True once count + 1 reaches
                    # minibatch_multiplier; the counter is then reset to zero,
                    # otherwise it is incremented.
                    count_cur = device.grad_acc_count + 1.0
                    count_inc_op = lambda: tf.assign(device.grad_acc_count,
                                                     count_cur)
                    count_reset_op = lambda: tf.assign(device.grad_acc_count,
                                                       tf.zeros([]))
                    acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier,
                                                   tf.float32))
                    all_ops.append(
                        tf.cond(acc_ok, count_reset_op, count_inc_op))

                    # Track gradients.
                    # acc_cur (stored accumulator + current gradient) is what
                    # gets applied; the accumulator variable is zeroed when
                    # acc_ok holds and otherwise overwritten with acc_cur.
                    for var, grad in device.grad_clean.items():
                        acc_var = device.grad_acc_vars[var]
                        acc_cur = acc_var + grad
                        device.grad_acc[var] = acc_cur
                        with tf.control_dependencies([acc_cur]):
                            acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
                            acc_reset_op = lambda: tf.assign(
                                acc_var, tf.zeros(var.shape))
                            all_ops.append(
                                tf.cond(acc_ok, acc_reset_op, acc_inc_op))

                # No overflow => apply gradients.
                # all_ok requires acc_ok and every accumulated gradient to be
                # finite everywhere.
                all_ok = tf.reduce_all(
                    tf.stack([acc_ok] + [
                        tf.reduce_all(tf.is_finite(g))
                        for g in device.grad_acc.values()
                    ]))
                # Gradients are cast back to each variable's dtype before being
                # handed to the optimizer.
                apply_op = lambda: device.optimizer.apply_gradients(
                    [(tf.cast(grad, var.dtype), var)
                     for var, grad in device.grad_acc.items()])
                all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))

                # Adjust loss scaling.
                # Only when acc_ok holds: add loss_scaling_inc if all_ok,
                # otherwise subtract loss_scaling_dec.
                if self.use_loss_scaling:
                    ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var,
                                                      self.loss_scaling_inc)
                    ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var,
                                                      self.loss_scaling_dec)
                    ls_update_op = lambda: tf.group(
                        tf.cond(all_ok, ls_inc_op, ls_dec_op))
                    all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))

                # Last device => report statistics.
                if device_idx == len(self._devices) - 1:
                    all_ops.append(
                        autosummary.autosummary(self.id + "/learning_rate",
                                                self.learning_rate))
                    # Records 0 when all_ok and 1 otherwise, gated on acc_ok.
                    all_ops.append(
                        autosummary.autosummary(self.id +
                                                "/overflow_frequency",
                                                tf.where(all_ok, 0, 1),
                                                condition=acc_ok))
                    if self.use_loss_scaling:
                        all_ops.append(
                            autosummary.autosummary(
                                self.id + "/loss_scaling_log2",
                                device.loss_scaling_var))

        # Initialize variables.
        # These run immediately while the graph is being built, not as part of
        # the returned training op.
        self.reset_optimizer_state()
        if self.use_loss_scaling:
            tfutil.init_uninitialized_vars(
                [device.loss_scaling_var for device in self._devices.values()])
        if self.minibatch_multiplier is not None:
            tfutil.run([
                var.initializer for device in self._devices.values()
                for var in list(device.grad_acc_vars.values()) +
                [device.grad_acc_count]
            ])

        # Group everything into a single op.
        with tfutil.absolute_name_scope(self.scope):
            return tf.group(*all_ops, name="TrainingOp")
Ejemplo n.º 6
0
 def reset_trainables(self) -> None:
     """Run the initializer of every trainable variable of this network, sub-networks included."""
     init_ops = []
     for trainable in self.trainables.values():
         init_ops.append(trainable.initializer)
     tfutil.run(init_ops)
Ejemplo n.º 7
0
 def reset_own_vars(self) -> None:
     """Run the initializer of every variable owned directly by this network, sub-networks excluded."""
     init_ops = []
     for own_var in self.own_vars.values():
         init_ops.append(own_var.initializer)
     tfutil.run(init_ops)