Example #1
    def reset_optimizer_state(self) -> None:
        """Reset internal state of the underlying optimizer."""
        tfutil.assert_tf_initialized()
        tfutil.run([var.initializer
                    for opt in self._dev_opt.values()
                    for var in opt.variables()])
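
A minimal usage sketch, assuming `opt` is an instance of the Optimizer class this method belongs to and a TF session is already initialized:

# Hypothetical usage: restart a training phase with fresh optimizer moments.
opt.reset_optimizer_state()  # reruns the initializer of every slot variable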
Example #2
    def __init__(self,
                 name: str = None,
                 func_name: Any = None,
                 **static_kwargs):
        tfutil.assert_tf_initialized()
        assert isinstance(name, str) or name is None
        assert func_name is not None
        assert isinstance(func_name, str) or util.is_top_level_function(func_name)
        assert util.is_pickleable(static_kwargs)

        self._init_fields()
        self.name = name
        self.static_kwargs = util.EasyDict(static_kwargs)

        # Locate the user-specified network build function.
        if util.is_top_level_function(func_name):
            func_name = util.get_top_level_function_name(func_name)
        module, self._build_func_name = util.get_module_from_obj_name(func_name)
        self._build_func = util.get_obj_from_module(module, self._build_func_name)
        assert callable(self._build_func)

        # Dig up source code for the module containing the build function.
        self._build_module_src = _import_module_src.get(module, None)
        if self._build_module_src is None:
            self._build_module_src = inspect.getsource(module)

        # Init TensorFlow graph.
        self._init_graph()
        self.reset_own_vars()
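
A hedged construction sketch; the `dnnlib.tflib` import path and the `training.networks.G_main` build-function path follow the StyleGAN-style layout and are assumptions, not part of the snippet above:

import dnnlib.tflib as tflib

tflib.init_tf()  # the constructor asserts that TF has been initialized

# func_name may be a module path string or a top-level function object.
net = tflib.Network(
    name="G",                              # stored as self.name
    func_name="training.networks.G_main",  # hypothetical build-function path
    resolution=1024)                       # extra kwargs become static_kwargs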
Example #3
def save_summaries(file_writer, global_step=None):
    """Call FileWriter.add_summary() with all summaries in the default graph,
    automatically finalizing and merging them on the first call.
    """
    global _merge_op
    tfutil.assert_tf_initialized()

    if _merge_op is None:
        layout = finalize_autosummaries()
        if layout is not None:
            file_writer.add_summary(layout)
        with tf.device(None), tf.control_dependencies(None):
            _merge_op = tf.summary.merge_all()

    file_writer.add_summary(_merge_op.eval(), global_step)
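
A usage sketch, assuming this function lives in the autosummary module (with its `_merge_op` global) and the log directory name is illustrative:

import tensorflow as tf

writer = tf.summary.FileWriter("logs/run0", tf.get_default_graph())
for step in range(1000):
    # ... run one training step here ...
    save_summaries(writer, global_step=step)  # merges once, then reuses _merge_op
writer.close()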
Example #4
def autosummary(name: str,
                value: TfExpressionEx,
                passthru: TfExpressionEx = None) -> TfExpressionEx:
    """Create a new autosummary.

    Args:
        name:     Name to use in TensorBoard
        value:    TensorFlow expression or python value to track
        passthru: Optionally return this TF node unmodified, attaching the autosummary update to it as a side effect.

    Example use of the passthru mechanism:

    n = autosummary('l2loss', loss, passthru=n)

    This is a shorthand for the following code:

    with tf.control_dependencies([autosummary('l2loss', loss)]):
        n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    name_id = name.replace("/", "_")

    if tfutil.is_tf_expression(value):
        with tf.name_scope("summary_" + name_id), tf.device(value.device):
            update_op = _create_var(name, value)
            with tf.control_dependencies([update_op]):
                return tf.identity(value if passthru is None else passthru)

    else:  # python scalar or numpy array
        if name not in _immediate:
            with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
                update_value = tf.placeholder(_dtype)
                update_op = _create_var(name, update_value)
                _immediate[name] = update_op, update_value

        update_op, update_value = _immediate[name]
        tfutil.run(update_op, {update_value: value})
        return value if passthru is None else passthru
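
Both branches in one short sketch; `loss_tensor` is an assumed scalar tf.Tensor, and the series names are illustrative:

# Tensor path: attaches an update op and returns the value unchanged.
loss_tracked = autosummary("Loss/total", loss_tensor)

# Python-scalar path: feeds the value through a cached placeholder immediately.
autosummary("Timing/sec_per_step", 0.42)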
Example #5
    def __setstate__(self, state: dict) -> None:
        """Pickle import."""
        # pylint: disable=attribute-defined-outside-init
        tfutil.assert_tf_initialized()
        self._init_fields()

        # Execute custom import handlers.
        for handler in _import_handlers:
            state = handler(state)

        # Set basic fields.
        assert state["version"] in [2, 3]
        self.name = state["name"]
        self.static_kwargs = util.EasyDict(state["static_kwargs"])
        self.components = util.EasyDict(state.get("components", {}))
        self._build_module_src = state["build_module_src"]
        self._build_func_name = state["build_func_name"]

        # Create temporary module from the imported source code.
        module_name = "_tflib_network_import_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        sys.modules[module_name] = module
        _import_module_src[module] = self._build_module_src
        exec(self._build_module_src, module.__dict__)  # pylint: disable=exec-used

        # Locate network build function in the temporary module.
        self._build_func = util.get_obj_from_module(module, self._build_func_name)
        assert callable(self._build_func)

        # Init TensorFlow graph.
        self._init_graph()
        self.reset_own_vars()
        tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})
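
Since `__setstate__` is invoked by pickle, a hedged round-trip sketch looks like this (the filename is illustrative):

import pickle
import dnnlib.tflib as tflib

tflib.init_tf()
with open("network-snapshot.pkl", "rb") as f:   # hypothetical pickle file
    net = pickle.load(f)                        # pickle calls Network.__setstate__
print(net.name, list(net.static_kwargs.keys()))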
Example #6
def finalize_autosummaries():
    """Create the necessary ops to include autosummaries in TensorBoard report.
    Note: This should be done only once per graph.
    Returns the custom_scalar layout on the first call, None on later calls.
    """
    global _finalized
    tfutil.assert_tf_initialized()

    if _finalized:
        return None

    _finalized = True
    tfutil.init_uninitialized_vars(
        [var for vars_list in _vars.values() for var in vars_list])

    # Create summary ops.
    with tf.device(None), tf.control_dependencies(None):
        for name, vars_list in _vars.items():
            name_id = name.replace("/", "_")
            with tfutil.absolute_name_scope("Autosummary/" + name_id):
                moments = tf.add_n(vars_list)
                moments /= moments[0]
                with tf.control_dependencies([moments]):  # read before resetting
                    reset_ops = [
                        tf.assign(var, tf.zeros(3, dtype=_dtype))
                        for var in vars_list
                    ]
                    with tf.name_scope(None), tf.control_dependencies(reset_ops):  # reset before reporting
                        mean = moments[1]
                        std = tf.sqrt(moments[2] - tf.square(moments[1]))
                        tf.summary.scalar(name, mean)
                        tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
                        tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)

    # Group by category and chart name.
    cat_dict = OrderedDict()
    for series_name in sorted(_vars.keys()):
        p = series_name.split("/")
        cat = p[0] if len(p) >= 2 else ""
        chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
        if cat not in cat_dict:
            cat_dict[cat] = OrderedDict()
        if chart not in cat_dict[cat]:
            cat_dict[cat][chart] = []
        cat_dict[cat][chart].append(series_name)

    # Setup custom_scalar layout.
    categories = []
    for cat_name, chart_dict in cat_dict.items():
        charts = []
        for chart_name, series_names in chart_dict.items():
            series = []
            for series_name in series_names:
                series.append(
                    layout_pb2.MarginChartContent.Series(
                        value=series_name,
                        lower="xCustomScalars/" + series_name + "/margin_lo",
                        upper="xCustomScalars/" + series_name + "/margin_hi"))
            margin = layout_pb2.MarginChartContent(series=series)
            charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
        categories.append(layout_pb2.Category(title=cat_name, chart=charts))
    layout = summary_lib.custom_scalar_pb(
        layout_pb2.Layout(category=categories))
    return layout
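
The returned layout is a custom-scalar Summary protobuf meant to be written once; save_summaries() in Example #3 does exactly this, but it can also be done by hand:

import tensorflow as tf

writer = tf.summary.FileWriter("logs/run0")     # illustrative log directory
layout = finalize_autosummaries()               # None on every call after the first
if layout is not None:
    writer.add_summary(layout)                  # installs the margin charts in TensorBoard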
Example #7
    def apply_updates(self) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients."""
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        devices = list(self._dev_grads.keys())
        total_grads = sum(len(grads) for grads in self._dev_grads.values())
        assert len(devices) >= 1 and total_grads >= 1
        ops = []

        with tfutil.absolute_name_scope(self.scope):
            # Cast gradients to FP32 and calculate partial sum within each device.
            dev_grads = OrderedDict()  # device => [(grad, var), ...]

            for dev_idx, dev in enumerate(devices):
                with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev):
                    sums = []

                    for gv in zip(*self._dev_grads[dev]):
                        assert all(v is gv[0][1] for g, v in gv)
                        g = [tf.cast(g, tf.float32) for g, v in gv]
                        g = g[0] if len(g) == 1 else tf.add_n(g)
                        sums.append((g, gv[0][1]))

                    dev_grads[dev] = sums

            # Sum gradients across devices.
            if len(devices) > 1:
                with tf.name_scope("SumAcrossGPUs"), tf.device(None):
                    for var_idx, grad_shape in enumerate(self._grad_shapes):
                        g = [dev_grads[dev][var_idx][0] for dev in devices]

                        if np.prod(grad_shape):  # nccl does not support zero-sized tensors
                            g = nccl_ops.all_sum(g)

                        for dev, gg in zip(devices, g):
                            dev_grads[dev][var_idx] = (
                                gg, dev_grads[dev][var_idx][1])

            # Apply updates separately on each device.
            for dev_idx, (dev, grads) in enumerate(dev_grads.items()):
                with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev):
                    # Scale gradients as needed.
                    if self.use_loss_scaling or total_grads > 1:
                        with tf.name_scope("Scale"):
                            coef = tf.constant(np.float32(1.0 / total_grads),
                                               name="coef")
                            coef = self.undo_loss_scaling(coef)
                            grads = [(g * coef, v) for g, v in grads]

                    # Check for overflows.
                    with tf.name_scope("CheckOverflow"):
                        grad_ok = tf.reduce_all(tf.stack(
                            [tf.reduce_all(tf.is_finite(g)) for g, v in grads]))

                    # Update weights and adjust loss scaling.
                    with tf.name_scope("UpdateWeights"):
                        # pylint: disable=cell-var-from-loop
                        opt = self._dev_opt[dev]
                        ls_var = self.get_loss_scaling_var(dev)

                        if not self.use_loss_scaling:
                            ops.append(
                                tf.cond(grad_ok,
                                        lambda: opt.apply_gradients(grads),
                                        tf.no_op))
                        else:
                            ops.append(
                                tf.cond(
                                    grad_ok,
                                    lambda: tf.group(
                                        tf.assign_add(ls_var, self.loss_scaling_inc),
                                        opt.apply_gradients(grads)),
                                    lambda: tf.group(
                                        tf.assign_sub(ls_var, self.loss_scaling_dec))))

                    # Report statistics on the last device.
                    if dev == devices[-1]:
                        with tf.name_scope("Statistics"):
                            ops.append(
                                autosummary.autosummary(
                                    self.id + "/learning_rate",
                                    self.learning_rate))
                            ops.append(
                                autosummary.autosummary(
                                    self.id + "/overflow_frequency",
                                    tf.where(grad_ok, 0, 1)))

                            if self.use_loss_scaling:
                                ops.append(
                                    autosummary.autosummary(
                                        self.id + "/loss_scaling_log2",
                                        ls_var))

            # Initialize variables and group everything into a single op.
            self.reset_optimizer_state()
            tfutil.init_uninitialized_vars(list(self._dev_ls_var.values()))

            return tf.group(*ops, name="TrainingOp")
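
A hedged end-to-end sketch of the register-then-apply pattern this method expects; `register_gradients()` is the companion method of the same Optimizer class, and `build_loss()` is a hypothetical helper:

import tensorflow as tf

opt = Optimizer(name="TrainG", learning_rate=0.002)

with tf.device("/gpu:0"):
    loss = build_loss()                                 # hypothetical scalar loss tensor
    opt.register_gradients(loss, tf.trainable_variables())  # one call per device

train_op = opt.apply_updates()   # builds the op shown above
tfutil.run(train_op)             # execute one training step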