def reset_optimizer_state(self) -> None:
    """Reset internal state of the underlying optimizer."""
    tfutil.assert_tf_initialized()
    tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()])

def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
    tfutil.assert_tf_initialized()
    assert isinstance(name, str) or name is None
    assert func_name is not None
    assert isinstance(func_name, str) or util.is_top_level_function(func_name)
    assert util.is_pickleable(static_kwargs)

    self._init_fields()
    self.name = name
    self.static_kwargs = util.EasyDict(static_kwargs)

    # Locate the user-specified network build function.
    if util.is_top_level_function(func_name):
        func_name = util.get_top_level_function_name(func_name)
    module, self._build_func_name = util.get_module_from_obj_name(func_name)
    self._build_func = util.get_obj_from_module(module, self._build_func_name)
    assert callable(self._build_func)

    # Dig up source code for the module containing the build function.
    self._build_module_src = _import_module_src.get(module, None)
    if self._build_module_src is None:
        self._build_module_src = inspect.getsource(module)

    # Init TensorFlow graph.
    self._init_graph()
    self.reset_own_vars()

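# Usage sketch for the constructor above, assuming tfutil.init_tf() has been
# called and that a top-level build function exists at the hypothetical path
# "training.networks.G_main"; static_kwargs (here: resolution) are stored,
# pickled with the network, and forwarded to the build function.
def _example_construct_network():
    tfutil.init_tf()
    return Network("G", func_name="training.networks.G_main", resolution=256)
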
def finalize_autosummaries() -> None:
    """Create the necessary ops to include autosummaries in TensorBoard report.
    Note: This should be done only once per graph.
    """
    global _finalized
    tfutil.assert_tf_initialized()

    if _finalized:
        return None

    _finalized = True
    tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])

    # Create summary ops.
    with tf.device(None), tf.control_dependencies(None):
        for name, vars_list in _vars.items():
            name_id = name.replace("/", "_")
            with tfutil.absolute_name_scope("Autosummary/" + name_id):
                moments = tf.add_n(vars_list)
                moments /= moments[0]
                with tf.control_dependencies([moments]):  # read before resetting
                    reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
                    with tf.name_scope(None), tf.control_dependencies(reset_ops):  # reset before reporting
                        mean = moments[1]
                        std = tf.sqrt(moments[2] - tf.square(moments[1]))
                        tf.summary.scalar(name, mean)
                        if enable_custom_scalars:
                            tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
                            tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)

    # Setup layout for custom scalars.
    layout = None
    if enable_custom_scalars:
        cat_dict = OrderedDict()
        for series_name in sorted(_vars.keys()):
            p = series_name.split("/")
            cat = p[0] if len(p) >= 2 else ""
            chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
            if cat not in cat_dict:
                cat_dict[cat] = OrderedDict()
            if chart not in cat_dict[cat]:
                cat_dict[cat][chart] = []
            cat_dict[cat][chart].append(series_name)
        categories = []
        for cat_name, chart_dict in cat_dict.items():
            charts = []
            for chart_name, series_names in chart_dict.items():
                series = []
                for series_name in series_names:
                    series.append(layout_pb2.MarginChartContent.Series(
                        value=series_name,
                        lower="xCustomScalars/" + series_name + "/margin_lo",
                        upper="xCustomScalars/" + series_name + "/margin_hi"))
                margin = layout_pb2.MarginChartContent(series=series)
                charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
            categories.append(layout_pb2.Category(title=cat_name, chart=charts))
        layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
    return layout

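# Worked illustration of the moment bookkeeping above, assuming _create_var
# (defined elsewhere in this module) accumulates [count, sum(x), sum(x*x)] per
# autosummary; a plain-numpy sketch, not part of the TensorFlow graph.
def _example_moments_to_mean_std():
    import numpy as np
    x = np.array([1.0, 2.0, 3.0, 4.0])
    moments = np.array([x.size, x.sum(), np.square(x).sum()])
    moments /= moments[0]                               # same normalization as above
    mean = moments[1]                                   # 2.5
    std = np.sqrt(moments[2] - np.square(moments[1]))   # population std, ~1.118
    return mean, std
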
def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
    """Register the gradients of the given loss function with respect to the given variables.
    Intended to be called once per GPU."""
    tfutil.assert_tf_initialized()
    assert not self._updates_applied
    device = self._get_device(loss.device)

    # Validate trainables.
    if isinstance(trainable_vars, dict):
        trainable_vars = list(trainable_vars.values())  # allow passing in Network.trainables as vars
    assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
    assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
    assert all(var.device == device.name for var in trainable_vars)

    # Validate shapes.
    if self._gradient_shapes is None:
        self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
    assert len(trainable_vars) == len(self._gradient_shapes)
    assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes))

    # Report memory usage if requested.
    deps = []
    if self._report_mem_usage:
        self._report_mem_usage = False
        try:
            with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
                deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
        except tf.errors.NotFoundError:
            pass

    # Compute gradients.
    with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
        loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
        gate = tf.train.Optimizer.GATE_NONE  # disable gating to reduce memory usage
        grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate)

    # Register gradients.
    for grad, var in grad_list:
        if var not in device.grad_raw:
            device.grad_raw[var] = []
        device.grad_raw[var].append(grad)

def save_summaries(file_writer, global_step=None):
    """Call FileWriter.add_summary() with all summaries in the default graph,
    automatically finalizing and merging them on the first call.
    """
    global _merge_op
    tfutil.assert_tf_initialized()

    if _merge_op is None:
        layout = finalize_autosummaries()
        if layout is not None:
            file_writer.add_summary(layout)
        with tf.device(None), tf.control_dependencies(None):
            _merge_op = tf.summary.merge_all()

    file_writer.add_summary(_merge_op.eval(), global_step)

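# Usage sketch for save_summaries() in a training loop; the log directory and
# summary name below are illustrative, and tfutil.init_tf() is assumed to set
# up the default graph and session that _merge_op.eval() relies on.
def _example_save_summaries(log_dir="logs", cur_step=0):
    tfutil.init_tf()                                                   # default graph + session
    file_writer = tf.summary.FileWriter(log_dir, tf.get_default_graph())
    autosummary("Loss/example", 0.0)                                   # register at least one autosummary
    save_summaries(file_writer, cur_step)                              # finalizes + merges on first call
    file_writer.flush()
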
def _get_device(self, device_name: str):
    """Get internal state for the given TensorFlow device."""
    tfutil.assert_tf_initialized()
    if device_name in self._devices:
        return self._devices[device_name]

    # Initialize fields.
    device = util.EasyDict()
    device.name             = device_name
    device.optimizer        = None           # Underlying optimizer: optimizer_class
    device.loss_scaling_var = None           # Log2 of loss scaling: tf.Variable
    device.grad_raw         = OrderedDict()  # Raw gradients: var => [grad, ...]
    device.grad_clean       = OrderedDict()  # Clean gradients: var => grad
    device.grad_acc_vars    = OrderedDict()  # Accumulation sums: var => tf.Variable
    device.grad_acc_count   = None           # Accumulation counter: tf.Variable
    device.grad_acc         = OrderedDict()  # Accumulated gradients: var => grad

    # Setup TensorFlow objects.
    with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
        if device_name not in self._shared_optimizers:
            optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers)
            self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
        device.optimizer = self._shared_optimizers[device_name]
        if self.use_loss_scaling:
            device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var")

    # Register device.
    self._devices[device_name] = device
    return device

def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx:
    """Create a new autosummary.

    Args:
        name:     Name to use in TensorBoard
        value:    TensorFlow expression or python value to track
        passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.

    Example use of the passthru mechanism:

        n = autosummary('l2loss', loss, passthru=n)

    This is a shorthand for the following code:

        with tf.control_dependencies([autosummary('l2loss', loss)]):
            n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    name_id = name.replace("/", "_")

    if tfutil.is_tf_expression(value):
        with tf.name_scope("summary_" + name_id), tf.device(value.device):
            condition = tf.convert_to_tensor(condition, name='condition')
            update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op)
            with tf.control_dependencies([update_op]):
                return tf.identity(value if passthru is None else passthru)

    else:  # python scalar or numpy array
        assert not tfutil.is_tf_expression(passthru)
        assert not tfutil.is_tf_expression(condition)
        if condition:
            if name not in _immediate:
                with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
                    update_value = tf.placeholder(_dtype)
                    update_op = _create_var(name, update_value)
                    _immediate[name] = update_op, update_value
            update_op, update_value = _immediate[name]
            tfutil.run(update_op, {update_value: value})
        return value if passthru is None else passthru

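# Usage sketch of the two modes handled above, assuming tfutil.init_tf() has
# been called; loss_tensor and python_scalar are placeholders for caller data.
def _example_autosummary_modes(loss_tensor, python_scalar):
    # Graph mode: tacks the update op onto the returned tensor (passthru pattern).
    loss_tensor = autosummary("Loss/train", loss_tensor, passthru=loss_tensor)
    # Immediate mode: feeds a python/numpy value straight into the accumulator.
    autosummary("Timing/sec_per_step", python_scalar)
    return loss_tensor
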
def __setstate__(self, state: dict) -> None:
    """Pickle import."""
    # pylint: disable=attribute-defined-outside-init
    tfutil.assert_tf_initialized()
    self._init_fields()

    # Execute custom import handlers.
    for handler in _import_handlers:
        state = handler(state)

    # Set basic fields.
    assert state["version"] in [2, 3, 4]
    self.name = state["name"]
    self.static_kwargs = util.EasyDict(state["static_kwargs"])
    self.components = util.EasyDict(state.get("components", {}))
    self._build_module_src = state["build_module_src"]
    self._build_func_name = state["build_func_name"]

    # Create temporary module from the imported source code.
    module_name = "_tflib_network_import_" + uuid.uuid4().hex
    module = types.ModuleType(module_name)
    sys.modules[module_name] = module
    _import_module_src[module] = self._build_module_src
    exec(self._build_module_src, module.__dict__)  # pylint: disable=exec-used

    # Locate network build function in the temporary module.
    self._build_func = util.get_obj_from_module(module, self._build_func_name)
    assert callable(self._build_func)

    # Init TensorFlow graph.
    self._init_graph()
    self.reset_own_vars()
    tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})

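# Pickle-import sketch: __setstate__ above runs whenever a Network instance is
# unpickled, so a TF graph/session must already exist; the file name below is
# illustrative only.
def _example_load_network(pkl_path="network-snapshot.pkl"):
    import pickle
    tfutil.init_tf()
    with open(pkl_path, "rb") as f:
        net = pickle.load(f)  # invokes __setstate__ on any Network inside
    return net
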
def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
    """Construct training op to update the registered variables based on their gradients."""
    tfutil.assert_tf_initialized()
    assert not self._updates_applied
    self._updates_applied = True
    all_ops = []

    # Check for no-op.
    if allow_no_op and len(self._devices) == 0:
        with tfutil.absolute_name_scope(self.scope):
            return tf.no_op(name='TrainingOp')

    # Clean up gradients.
    for device_idx, device in enumerate(self._devices.values()):
        with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
            for var, grad in device.grad_raw.items():

                # Filter out disconnected gradients and convert to float32.
                grad = [g for g in grad if g is not None]
                grad = [tf.cast(g, tf.float32) for g in grad]

                # Sum within the device.
                if len(grad) == 0:
                    grad = tf.zeros(var.shape)  # No gradients => zero.
                elif len(grad) == 1:
                    grad = grad[0]              # Single gradient => use as is.
                else:
                    grad = tf.add_n(grad)       # Multiple gradients => sum.

                # Scale as needed.
                scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
                scale = tf.constant(scale, dtype=tf.float32, name="scale")
                if self.minibatch_multiplier is not None:
                    scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                scale = self.undo_loss_scaling(scale)
                device.grad_clean[var] = grad * scale

    # Sum gradients across devices.
    if len(self._devices) > 1:
        with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
            for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
                if len(all_vars) > 0 and all(dim > 0 for dim in all_vars[0].shape.as_list()):  # NCCL does not support zero-sized tensors.
                    all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)]
                    all_grads = nccl_ops.all_sum(all_grads)
                    for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
                        device.grad_clean[var] = grad

    # Apply updates separately on each device.
    for device_idx, device in enumerate(self._devices.values()):
        with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
            # pylint: disable=cell-var-from-loop

            # Accumulate gradients over time.
            if self.minibatch_multiplier is None:
                acc_ok = tf.constant(True, name='acc_ok')
                device.grad_acc = OrderedDict(device.grad_clean)
            else:
                # Create variables.
                with tf.control_dependencies(None):
                    for var in device.grad_clean.keys():
                        device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
                    device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")

                # Track counter.
                count_cur = device.grad_acc_count + 1.0
                count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
                count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
                acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
                all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))

                # Track gradients.
                for var, grad in device.grad_clean.items():
                    acc_var = device.grad_acc_vars[var]
                    acc_cur = acc_var + grad
                    device.grad_acc[var] = acc_cur
                    with tf.control_dependencies([acc_cur]):
                        acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
                        acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
                        all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))

            # No overflow => apply gradients.
            all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
            apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
            all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))

            # Adjust loss scaling.
            if self.use_loss_scaling:
                ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
                ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
                ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
                all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))

            # Last device => report statistics.
            if device_idx == len(self._devices) - 1:
                all_ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
                all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
                if self.use_loss_scaling:
                    all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))

    # Initialize variables.
    self.reset_optimizer_state()
    if self.use_loss_scaling:
        tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
    if self.minibatch_multiplier is not None:
        tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])

    # Group everything into a single op.
    with tfutil.absolute_name_scope(self.scope):
        return tf.group(*all_ops, name="TrainingOp")

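# End-to-end usage sketch for the Optimizer machinery above: call
# register_gradients() once per GPU and apply_updates() once to build a single
# training op. The per-GPU (loss, trainables) pairs are placeholders supplied
# by the caller; each trainables set must live on the same device as its loss.
def _example_optimizer_usage(opt, per_gpu_losses_and_vars):
    for loss, trainables in per_gpu_losses_and_vars:
        opt.register_gradients(loss, trainables)   # one call per GPU
    train_op = opt.apply_updates()                 # combined "TrainingOp"
    tfutil.run(train_op)                           # execute one training step
    return train_op
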