def reset_optimizer_state(self) -> None:
    """Reset internal state of the underlying optimizer."""
    tfutil.assert_tf_initialized()
    tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()])
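# A minimal usage sketch (not from the original file): assumes `opt` is an
# instance of this Optimizer class and that TensorFlow has been initialized
# via tfutil. Resetting clears per-variable slot state such as Adam's moment
# estimates, e.g. when restarting training from a snapshot:
#
#     opt.reset_optimizer_state()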
def copy_own_vars_from(self, src_net: "Network") -> None:
    """Copy the values of all variables from the given network, excluding sub-networks."""
    names = [name for name in self.own_vars.keys() if name in src_net.own_vars]
    tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
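# A minimal usage sketch (illustrative names): assumes `G` and `Gs` are two
# Network instances with overlapping variable names, e.g. a training network
# and its exponential-moving-average copy. Only variables owned directly by
# the destination are copied; sub-network variables are skipped:
#
#     Gs.copy_own_vars_from(G)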
def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx:
    """Create a new autosummary.

    Args:
        name:      Name to use in TensorBoard
        value:     TensorFlow expression or python value to track
        passthru:  Optionally return this TF node without modifications but tack an
                   autosummary update side-effect to this node.
        condition: Boolean expression (or python value); the summary is updated
                   only when it evaluates to True.

    Example use of the passthru mechanism:

        n = autosummary('l2loss', loss, passthru=n)

    This is a shorthand for the following code:

        with tf.control_dependencies([autosummary('l2loss', loss)]):
            n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    name_id = name.replace("/", "_")

    if tfutil.is_tf_expression(value):
        with tf.name_scope("summary_" + name_id), tf.device(value.device):
            condition = tf.convert_to_tensor(condition, name='condition')
            update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op)
            with tf.control_dependencies([update_op]):
                return tf.identity(value if passthru is None else passthru)

    else:  # python scalar or numpy array
        assert not tfutil.is_tf_expression(passthru)
        assert not tfutil.is_tf_expression(condition)
        if condition:
            if name not in _immediate:
                with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
                    update_value = tf.placeholder(_dtype)
                    update_op = _create_var(name, update_value)
                    _immediate[name] = update_op, update_value
            update_op, update_value = _immediate[name]
            tfutil.run(update_op, {update_value: value})
        return value if passthru is None else passthru
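# A minimal usage sketch (names are illustrative, not from the original file):
# a TF expression is tracked as a graph side effect, whereas a plain Python
# value is fed immediately through a cached placeholder:
#
#     loss = autosummary("Loss/total", loss, passthru=loss)  # graph-side tracking
#     autosummary("Progress/kimg", cur_nimg / 1000.0)        # immediate Python value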
def __getstate__(self) -> dict:
    """Pickle export."""
    state = dict()
    state["version"] = 4
    state["name"] = self.name
    state["static_kwargs"] = dict(self.static_kwargs)
    state["components"] = dict(self.components)
    state["build_module_src"] = self._build_module_src
    state["build_func_name"] = self._build_func_name
    state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values()))))
    return state
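# A minimal usage sketch: because __getstate__ fetches all owned variables from
# the TF session as (name, value) pairs, a Network can be serialized with plain
# pickle, independently of the live graph. The filename is illustrative:
#
#     import pickle
#     with open("network-snapshot.pkl", "wb") as f:
#         pickle.dump(net, f)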
def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
    """Construct training op to update the registered variables based on their gradients."""
    tfutil.assert_tf_initialized()
    assert not self._updates_applied
    self._updates_applied = True
    all_ops = []

    # Check for no-op.
    if allow_no_op and len(self._devices) == 0:
        with tfutil.absolute_name_scope(self.scope):
            return tf.no_op(name='TrainingOp')

    # Clean up gradients.
    for device_idx, device in enumerate(self._devices.values()):
        with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
            for var, grad in device.grad_raw.items():

                # Filter out disconnected gradients and convert to float32.
                grad = [g for g in grad if g is not None]
                grad = [tf.cast(g, tf.float32) for g in grad]

                # Sum within the device.
                if len(grad) == 0:
                    grad = tf.zeros(var.shape)  # No gradients => zero.
                elif len(grad) == 1:
                    grad = grad[0]              # Single gradient => use as is.
                else:
                    grad = tf.add_n(grad)       # Multiple gradients => sum.

                # Scale as needed.
                scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
                scale = tf.constant(scale, dtype=tf.float32, name="scale")
                if self.minibatch_multiplier is not None:
                    scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                scale = self.undo_loss_scaling(scale)
                device.grad_clean[var] = grad * scale

    # Sum gradients across devices.
    if len(self._devices) > 1:
        with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
            for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
                if len(all_vars) > 0 and all(dim > 0 for dim in all_vars[0].shape.as_list()):  # NCCL does not support zero-sized tensors.
                    all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)]
                    all_grads = nccl_ops.all_sum(all_grads)
                    for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
                        device.grad_clean[var] = grad

    # Apply updates separately on each device.
    for device_idx, device in enumerate(self._devices.values()):
        with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
            # pylint: disable=cell-var-from-loop

            # Accumulate gradients over time.
            if self.minibatch_multiplier is None:
                acc_ok = tf.constant(True, name='acc_ok')
                device.grad_acc = OrderedDict(device.grad_clean)
            else:
                # Create variables.
                with tf.control_dependencies(None):
                    for var in device.grad_clean.keys():
                        device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
                    device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")

                # Track counter.
                count_cur = device.grad_acc_count + 1.0
                count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
                count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
                acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
                all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))

                # Track gradients.
                for var, grad in device.grad_clean.items():
                    acc_var = device.grad_acc_vars[var]
                    acc_cur = acc_var + grad
                    device.grad_acc[var] = acc_cur
                    with tf.control_dependencies([acc_cur]):
                        acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
                        acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
                    all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))

            # No overflow => apply gradients.
            all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
            apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
            all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))

            # Adjust loss scaling.
            if self.use_loss_scaling:
                ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
                ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
                ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
                all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))

            # Last device => report statistics.
            if device_idx == len(self._devices) - 1:
                all_ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
                all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
                if self.use_loss_scaling:
                    all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))

    # Initialize variables.
    self.reset_optimizer_state()
    if self.use_loss_scaling:
        tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
    if self.minibatch_multiplier is not None:
        tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])

    # Group everything into a single op.
    with tfutil.absolute_name_scope(self.scope):
        return tf.group(*all_ops, name="TrainingOp")
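# A minimal usage sketch (illustrative, and assuming a companion
# register_gradients() method has populated device.grad_raw beforehand):
# construct the training op once, then run it every step. With
# minibatch_multiplier=4, weights are updated only on every 4th run, after
# gradients have accumulated in the grad_acc variables:
#
#     opt = Optimizer(learning_rate=lr_var, minibatch_multiplier=4)
#     opt.register_gradients(loss, trainable_vars)
#     train_op = opt.apply_updates()  # may only be called once per Optimizer
#     for _ in range(num_steps):
#         tfutil.run(train_op)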
def reset_trainables(self) -> None:
    """Re-initialize all trainable variables of this network, including sub-networks."""
    tfutil.run([var.initializer for var in self.trainables.values()])
def reset_own_vars(self) -> None:
    """Re-initialize all variables of this network, excluding sub-networks."""
    tfutil.run([var.initializer for var in self.own_vars.values()])
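# A minimal sketch contrasting the two reset methods (assumes `net` is a
# Network instance): reset_trainables() re-initializes trainable variables in
# this network and all sub-networks, while reset_own_vars() re-initializes all
# of this network's own variables, trainable or not, excluding sub-networks:
#
#     net.reset_trainables()
#     net.reset_own_vars()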