Example #1
0
 def reset_optimizer_state(self) -> None:
     """Re-run the initializer of every optimizer variable on every registered device."""
     tfutil.assert_tf_initialized()

     # Collect the initializers device by device, then run them in one call.
     init_ops = []
     for device in self._devices.values():
         for opt_var in device.optimizer.variables():
             init_ops.append(opt_var.initializer)
     tfutil.run(init_ops)
Example #2
0
    def __init__(self,
                 name: str = None,
                 func_name: Any = None,
                 **static_kwargs):
        """Construct a network by locating a user-supplied build function and building its graph.

        Args:
            name:          Network name; must be a string or None.
            func_name:     Build function, given either as a string or as a
                           top-level function object. Must not be None.
            static_kwargs: Extra keyword arguments stored on the network;
                           they must be pickleable.
        """
        tfutil.assert_tf_initialized()
        assert name is None or isinstance(name, str)
        assert func_name is not None
        assert isinstance(func_name, str) or util.is_top_level_function(func_name)
        assert util.is_pickleable(static_kwargs)

        self._init_fields()
        self.name = name
        self.static_kwargs = util.EasyDict(static_kwargs)

        # Resolve the build function; a function object is first turned into its name.
        if util.is_top_level_function(func_name):
            func_name = util.get_top_level_function_name(func_name)
        build_module, self._build_func_name = util.get_module_from_obj_name(func_name)
        self._build_func = util.get_obj_from_module(build_module, self._build_func_name)
        assert callable(self._build_func)

        # Use the source text registered in _import_module_src for this module,
        # falling back to reading it through inspect.
        self._build_module_src = _import_module_src.get(build_module, None)
        if self._build_module_src is None:
            self._build_module_src = inspect.getsource(build_module)

        # Build the TensorFlow graph and reset the network's own variables.
        self._init_graph()
        self.reset_own_vars()
Example #3
0
def finalize_autosummaries() -> None:
    """Create the ops needed to include all autosummaries in the TensorBoard report.

    Note: This should be done only once per graph. A repeated call does nothing
    and returns None.

    Returns:
        The value of `summary_lib.custom_scalar_pb(...)` describing the margin-chart
        layout when `enable_custom_scalars` is set; otherwise None.
    """
    global _finalized
    tfutil.assert_tf_initialized()

    if _finalized:
        return None

    _finalized = True
    every_var = []
    for vars_list in _vars.values():
        every_var.extend(vars_list)
    tfutil.init_uninitialized_vars(every_var)

    # Create summary ops.
    with tf.device(None), tf.control_dependencies(None):
        for name, vars_list in _vars.items():
            scope_id = name.replace("/", "_")
            with tfutil.absolute_name_scope("Autosummary/" + scope_id):
                # Normalize by element 0 so that elements 1 and 2 can be read as the
                # mean and the mean of squares (presumably element 0 is a sample
                # count; it is filled in by _create_var, which is not shown here).
                moments = tf.add_n(vars_list)
                moments = moments / moments[0]
                with tf.control_dependencies([moments]):  # read before resetting
                    reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
                    with tf.name_scope(None), tf.control_dependencies(reset_ops):  # reset before reporting
                        mean = moments[1]
                        std = tf.sqrt(moments[2] - tf.square(moments[1]))
                        tf.summary.scalar(name, mean)
                        if enable_custom_scalars:
                            tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
                            tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)

    # Setup layout for custom scalars.
    layout = None
    if enable_custom_scalars:
        # Group series as category -> chart -> [series names], derived from the
        # "/"-separated parts of each series name.
        grouped = OrderedDict()
        for series_name in sorted(_vars):
            parts = series_name.split("/")
            cat = parts[0] if len(parts) >= 2 else ""
            chart = "/".join(parts[1:-1]) if len(parts) >= 3 else parts[-1]
            grouped.setdefault(cat, OrderedDict()).setdefault(chart, []).append(series_name)

        categories = []
        for cat_name, chart_dict in grouped.items():
            charts = []
            for chart_name, series_names in chart_dict.items():
                series = [
                    layout_pb2.MarginChartContent.Series(
                        value=series_name,
                        lower="xCustomScalars/" + series_name + "/margin_lo",
                        upper="xCustomScalars/" + series_name + "/margin_hi")
                    for series_name in series_names
                ]
                margin = layout_pb2.MarginChartContent(series=series)
                charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
            categories.append(layout_pb2.Category(title=cat_name, chart=charts))
        layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
    return layout
Example #4
0
    def register_gradients(self, loss: TfExpression,
                           trainable_vars: Union[List, dict]) -> None:
        """Compute the gradients of `loss` w.r.t. `trainable_vars` and record them per device.

        Intended to be called once per GPU, and only before the updates have been applied.

        Args:
            loss:           TensorFlow expression to differentiate; its device
                            selects the per-device state that receives the gradients.
            trainable_vars: Non-empty list of variables, or a dict whose values
                            are the variables (e.g. Network.trainables).
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        device = self._get_device(loss.device)

        # Normalize and validate the trainables.
        if isinstance(trainable_vars, dict):
            trainable_vars = list(trainable_vars.values())  # allow passing in Network.trainables as vars
        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
        assert all(var.device == device.name for var in trainable_vars)

        # The first call fixes the expected shapes; every call must match them.
        if self._gradient_shapes is None:
            self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
        assert len(trainable_vars) == len(self._gradient_shapes)
        assert all(var.shape.as_list() == expected_shape
                   for var, expected_shape in zip(trainable_vars, self._gradient_shapes))

        # Report memory usage once, if requested; a missing op is silently tolerated.
        deps = []
        if self._report_mem_usage:
            self._report_mem_usage = False
            try:
                with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
                    mem_usage_gb = tf.contrib.memory_stats.BytesInUse() / 2**30
                    deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", mem_usage_gb))
            except tf.errors.NotFoundError:
                pass

        # Compute gradients of the loss-scaled, float32 loss.
        with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
            scaled_loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            grad_list = device.optimizer.compute_gradients(
                loss=scaled_loss,
                var_list=trainable_vars,
                gate_gradients=tf.train.Optimizer.GATE_NONE)  # disable gating to reduce memory usage

        # Append each gradient to the list kept for its variable.
        for grad, var in grad_list:
            device.grad_raw.setdefault(var, []).append(grad)
Example #5
0
def save_summaries(file_writer, global_step=None):
    """Write all summaries of the default graph through FileWriter.add_summary().

    On the first call the autosummaries are finalized, the custom-scalar layout
    (if any) is written, and the merged summary op is created and cached.
    """
    global _merge_op
    tfutil.assert_tf_initialized()

    first_call = _merge_op is None
    if first_call:
        custom_layout = finalize_autosummaries()
        if custom_layout is not None:
            file_writer.add_summary(custom_layout)
        with tf.device(None), tf.control_dependencies(None):
            _merge_op = tf.summary.merge_all()

    merged_summary = _merge_op.eval()
    file_writer.add_summary(merged_summary, global_step)
Example #6
0
    def _get_device(self, device_name: str):
        """Return the internal state record for `device_name`, creating it on first use."""
        tfutil.assert_tf_initialized()
        if device_name in self._devices:
            return self._devices[device_name]

        # Start from an empty record; TensorFlow objects are filled in below.
        device = util.EasyDict()
        device.name = device_name
        # Underlying optimizer: optimizer_class
        device.optimizer = None
        # Log2 of loss scaling: tf.Variable
        device.loss_scaling_var = None
        # Raw gradients: var => [grad, ...]
        device.grad_raw = OrderedDict()
        # Clean gradients: var => grad
        device.grad_clean = OrderedDict()
        # Accumulation sums: var => tf.Variable
        device.grad_acc_vars = OrderedDict()
        # Accumulation counter: tf.Variable
        device.grad_acc_count = None
        # Accumulated gradients: var => grad
        device.grad_acc = OrderedDict()

        # Create the optimizer (one per device name, reused afterwards) and the
        # optional loss-scaling variable on the target device.
        with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
            if device_name not in self._shared_optimizers:
                optimizer_index = len(self._shared_optimizers)
                optimizer_name = self.scope.replace("/", "_") + "_opt%d" % optimizer_index
                self._shared_optimizers[device_name] = self.optimizer_class(
                    name=optimizer_name,
                    learning_rate=self.learning_rate,
                    **self.optimizer_kwargs)
            device.optimizer = self._shared_optimizers[device_name]
            if self.use_loss_scaling:
                initial_scaling = np.float32(self.loss_scaling_init)
                device.loss_scaling_var = tf.Variable(initial_scaling, trainable=False, name="loss_scaling_var")

        # Remember the record so later calls return the same object.
        self._devices[device_name] = device
        return device
Example #7
0
def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx:
    """Create a new autosummary.

    Args:
        name:      Name to use in TensorBoard
        value:     TensorFlow expression or python value to track
        passthru:  Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
        condition: The summary is updated only when this is true. May be a TF
                   expression only when `value` is one as well.

    Returns:
        `passthru` if it was given, otherwise `value`. When `value` is a TF
        expression the result is a tf.identity of it that depends on the update op.

    Example use of the passthru mechanism:

    n = autosummary('l2loss', loss, passthru=n)

    This is a shorthand for the following code:

    with tf.control_dependencies([autosummary('l2loss', loss)]):
        n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    name_id = name.replace("/", "_")

    # TensorFlow expression: attach a conditional update op to the returned node.
    if tfutil.is_tf_expression(value):
        with tf.name_scope("summary_" + name_id), tf.device(value.device):
            condition = tf.convert_to_tensor(condition, name='condition')
            update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op)
            with tf.control_dependencies([update_op]):
                return tf.identity(value if passthru is None else passthru)

    # python scalar or numpy array: run the update immediately.
    assert not tfutil.is_tf_expression(passthru)
    assert not tfutil.is_tf_expression(condition)
    result = value if passthru is None else passthru
    if not condition:
        return result

    # The placeholder-fed update op is built once per name and reused afterwards.
    if name not in _immediate:
        with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
            feed_placeholder = tf.placeholder(_dtype)
            feed_update_op = _create_var(name, feed_placeholder)
            _immediate[name] = feed_update_op, feed_placeholder
    update_op, update_value = _immediate[name]
    tfutil.run(update_op, {update_value: value})
    return result
Example #8
0
    def __setstate__(self, state: dict) -> None:
        """Pickle import: rebuild the network and its variables from `state`.

        NOTE: the module source stored in the state is executed with exec, so
        only unpickle data that comes from a trusted source.
        """
        # pylint: disable=attribute-defined-outside-init
        tfutil.assert_tf_initialized()
        self._init_fields()

        # Let every registered import handler transform the state in turn.
        for handler in _import_handlers:
            state = handler(state)

        # Restore the plain fields.
        assert state["version"] in [2, 3, 4]
        self.name = state["name"]
        self.static_kwargs = util.EasyDict(state["static_kwargs"])
        self.components = util.EasyDict(state.get("components", {}))
        self._build_module_src = state["build_module_src"]
        self._build_func_name = state["build_func_name"]

        # Execute the imported source inside a fresh, uniquely named module.
        tmp_module_name = "_tflib_network_import_" + uuid.uuid4().hex
        tmp_module = types.ModuleType(tmp_module_name)
        sys.modules[tmp_module_name] = tmp_module
        _import_module_src[tmp_module] = self._build_module_src
        exec(self._build_module_src, tmp_module.__dict__)  # pylint: disable=exec-used

        # Look up the build function in that temporary module.
        self._build_func = util.get_obj_from_module(tmp_module, self._build_func_name)
        assert callable(self._build_func)

        # Rebuild the graph, then load the saved variable values into it.
        self._init_graph()
        self.reset_own_vars()
        restored_values = {}
        for var_name, var_value in state["variables"]:
            restored_values[self.find_var(var_name)] = var_value
        tfutil.set_vars(restored_values)
Example #9
0
    def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients.

        May only be called once per instance: it asserts that updates have not
        been applied yet and then marks them as applied. Besides building ops,
        it also runs variable initializers (optimizer state, loss scaling
        variables and, when a minibatch multiplier is set, the gradient
        accumulation variables).

        Args:
            allow_no_op: If True and no device has been registered, return a
                         no-op named 'TrainingOp' instead of building updates.

        Returns:
            A single op named "TrainingOp" that groups all update ops.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        all_ops = []

        # Check for no-op.
        if allow_no_op and len(self._devices) == 0:
            with tfutil.absolute_name_scope(self.scope):
                return tf.no_op(name='TrainingOp')

        # Clean up gradients.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Clean%d" %
                                            device_idx), tf.device(
                                                device.name):
                for var, grad in device.grad_raw.items():

                    # Filter out disconnected gradients and convert to float32.
                    grad = [g for g in grad if g is not None]
                    grad = [tf.cast(g, tf.float32) for g in grad]

                    # Sum within the device.
                    if len(grad) == 0:
                        grad = tf.zeros(var.shape)  # No gradients => zero.
                    elif len(grad) == 1:
                        grad = grad[0]  # Single gradient => use as is.
                    else:
                        grad = tf.add_n(grad)  # Multiple gradients => sum.

                    # Scale as needed.
                    # The divisor counts every gradient registered for this variable
                    # on this device (including None entries) times the number of devices.
                    scale = 1.0 / len(device.grad_raw[var]) / len(
                        self._devices)
                    scale = tf.constant(scale, dtype=tf.float32, name="scale")
                    if self.minibatch_multiplier is not None:
                        scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                    scale = self.undo_loss_scaling(scale)
                    device.grad_clean[var] = grad * scale

        # Sum gradients across devices.
        if len(self._devices) > 1:
            with tfutil.absolute_name_scope(self.scope +
                                            "/Broadcast"), tf.device(None):
                # Variables are matched across devices by their position in each
                # device's grad_clean (assumes every device registered its
                # variables in the same order — TODO confirm against callers).
                for all_vars in zip(*[
                        device.grad_clean.keys()
                        for device in self._devices.values()
                ]):
                    if len(all_vars) > 0 and all(
                            dim > 0 for dim in all_vars[0].shape.as_list()
                    ):  # NCCL does not support zero-sized tensors.
                        all_grads = [
                            device.grad_clean[var] for device, var in zip(
                                self._devices.values(), all_vars)
                        ]
                        all_grads = nccl_ops.all_sum(all_grads)
                        for device, var, grad in zip(self._devices.values(),
                                                     all_vars, all_grads):
                            device.grad_clean[var] = grad

        # Apply updates separately on each device.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Apply%d" %
                                            device_idx), tf.device(
                                                device.name):
                # pylint: disable=cell-var-from-loop

                # Accumulate gradients over time.
                if self.minibatch_multiplier is None:
                    # No accumulation: every step is a complete one.
                    acc_ok = tf.constant(True, name='acc_ok')
                    device.grad_acc = OrderedDict(device.grad_clean)
                else:
                    # Create variables.
                    with tf.control_dependencies(None):
                        for var in device.grad_clean.keys():
                            device.grad_acc_vars[var] = tf.Variable(
                                tf.zeros(var.shape),
                                trainable=False,
                                name="grad_acc_var")
                        device.grad_acc_count = tf.Variable(
                            tf.zeros([]),
                            trainable=False,
                            name="grad_acc_count")

                    # Track counter.
                    # acc_ok becomes true once the incremented counter reaches
                    # minibatch_multiplier; the counter is then reset to zero,
                    # otherwise it is stored incremented.
                    count_cur = device.grad_acc_count + 1.0
                    count_inc_op = lambda: tf.assign(device.grad_acc_count,
                                                     count_cur)
                    count_reset_op = lambda: tf.assign(device.grad_acc_count,
                                                       tf.zeros([]))
                    acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier,
                                                   tf.float32))
                    all_ops.append(
                        tf.cond(acc_ok, count_reset_op, count_inc_op))

                    # Track gradients.
                    for var, grad in device.grad_clean.items():
                        acc_var = device.grad_acc_vars[var]
                        acc_cur = acc_var + grad
                        device.grad_acc[var] = acc_cur
                        # NOTE(review): presumably the control dependency makes the
                        # running sum get read before the accumulator is overwritten
                        # or reset below — confirm.
                        with tf.control_dependencies([acc_cur]):
                            acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
                            acc_reset_op = lambda: tf.assign(
                                acc_var, tf.zeros(var.shape))
                            all_ops.append(
                                tf.cond(acc_ok, acc_reset_op, acc_inc_op))

                # No overflow => apply gradients.
                # all_ok requires acc_ok and every accumulated gradient to be finite.
                all_ok = tf.reduce_all(
                    tf.stack([acc_ok] + [
                        tf.reduce_all(tf.is_finite(g))
                        for g in device.grad_acc.values()
                    ]))
                apply_op = lambda: device.optimizer.apply_gradients(
                    [(tf.cast(grad, var.dtype), var)
                     for var, grad in device.grad_acc.items()])
                all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))

                # Adjust loss scaling.
                # Only on steps where acc_ok holds: increase the scaling variable
                # when all_ok, decrease it otherwise.
                if self.use_loss_scaling:
                    ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var,
                                                      self.loss_scaling_inc)
                    ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var,
                                                      self.loss_scaling_dec)
                    ls_update_op = lambda: tf.group(
                        tf.cond(all_ok, ls_inc_op, ls_dec_op))
                    all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))

                # Last device => report statistics.
                if device_idx == len(self._devices) - 1:
                    all_ops.append(
                        autosummary.autosummary(self.id + "/learning_rate",
                                                self.learning_rate))
                    # Records 0 when all_ok and 1 otherwise, on acc_ok steps only.
                    all_ops.append(
                        autosummary.autosummary(self.id +
                                                "/overflow_frequency",
                                                tf.where(all_ok, 0, 1),
                                                condition=acc_ok))
                    if self.use_loss_scaling:
                        all_ops.append(
                            autosummary.autosummary(
                                self.id + "/loss_scaling_log2",
                                device.loss_scaling_var))

        # Initialize variables.
        self.reset_optimizer_state()
        if self.use_loss_scaling:
            tfutil.init_uninitialized_vars(
                [device.loss_scaling_var for device in self._devices.values()])
        if self.minibatch_multiplier is not None:
            tfutil.run([
                var.initializer for device in self._devices.values()
                for var in list(device.grad_acc_vars.values()) +
                [device.grad_acc_count]
            ])

        # Group everything into a single op.
        with tfutil.absolute_name_scope(self.scope):
            return tf.group(*all_ops, name="TrainingOp")