def _get_dL_dS_stat_dict(self, dlds_nested, dsds_nested):
  """Build an OrderedDict of normalized gradient-norm tensors.

  Keys have the form '||dL[k]/dS...||' where k is the (negative)
  hub.error_injection_step. Each value is the infinity-norm of the
  back-propagated gradient sequence, normalized by its last step.

  :param dlds_nested: nested structure of dL/dS tensors (each rank 2)
  :param dsds_nested: nested structure of dS/dS' tensors (each rank 3)
  :return: OrderedDict mapping norm keys to tensors
  """
  dlds_list, _ = ravel_nested_stuff(dlds_nested, with_indices=True)
  dsds_list, index_list = ravel_nested_stuff(dsds_nested, with_indices=True)
  stat_dict = OrderedDict()
  single_state = len(dlds_list) == 1
  for dlds, dsds, idx in zip(dlds_list, dsds_list, index_list):
    # Sanity-check the raveled pieces
    assert isinstance(idx, list)
    assert isinstance(dlds, tf.Tensor) and isinstance(dsds, tf.Tensor)
    assert len(dlds.shape) == 2 and len(dsds.shape) == 3
    # Compose the key for this state block, e.g. 'dL[-1]/dS1-2'
    if single_state:
      grad_name = 'S'
    else:
      grad_name = 'S{}'.format('-'.join([str(i + 1) for i in idx]))
    assert hub.error_injection_step < 0
    grad_name = 'dL[{}]/d{}'.format(hub.error_injection_step, grad_name)
    # Propagated gradient has shape [Ts, state_size]
    propagated = self._sandwich_bottom(dlds, dsds)
    # Keep a batch dimension of size 1 (important for downstream code)
    propagated = tf.stack([propagated])
    # Infinity-norm over the state axis, normalized by the final step
    norm = tf.norm(propagated, ord=np.inf, axis=2)
    norm = norm / norm[0, -1]
    stat_dict['||{}||'.format(grad_name)] = norm
  return stat_dict
def _get_dL_dS_dict(self, dlds_nested, dsds_nested):
  """Build a nested OrderedDict of per-step gradient blocks.

  Outer keys look like 'dL/dS...'; each maps to an inner OrderedDict
  whose keys look like 'dL/dS...[j]', one entry per row of the
  lower-triangular gradient matrix (T = num_steps rows).

  :param dlds_nested: nested structure of dL/dS tensors (each rank 2)
  :param dsds_nested: nested structure of dS/dS' tensors (each rank 3)
  :return: OrderedDict of OrderedDicts of tensors
  """
  dlds_list, _ = ravel_nested_stuff(dlds_nested, with_indices=True)
  dsds_list, index_list = ravel_nested_stuff(dsds_nested, with_indices=True)
  result = OrderedDict()
  for dlds, dsds, idx in zip(dlds_list, dsds_list, index_list):
    # Sanity-check the raveled pieces
    assert isinstance(idx, list)
    assert isinstance(dlds, tf.Tensor) and isinstance(dsds, tf.Tensor)
    assert len(dlds.shape) == 2 and len(dsds.shape) == 3
    # Compose the key for dL/dSi
    grad_name = ('S' if len(dlds_list) == 1
                 else 'S{}'.format('-'.join([str(i + 1) for i in idx])))
    grad_name = 'dL/d{}'.format(grad_name)
    block_dict = OrderedDict()
    result[grad_name] = block_dict
    # (dL/dSi)j is a T-by-T lower-triangular matrix (T = num_steps)
    triangle = self._form_triangle(dlds, dsds)
    assert isinstance(triangle, tf.Tensor)
    rows = tf.split(triangle, triangle.shape.as_list()[0])
    for j, row in enumerate(rows):
      # Honor the configured cap on states per block (0 means no cap)
      if 0 < hub.max_states_per_block == j: break
      block_dict['{}[{}]'.format(grad_name, j + 1)] = row
    # NOTE: an aggregated '[*]' entry (sum of |triangle| rows) is
    # intentionally not produced for now.
  return result
def _calc_dS_dS_prev(self, states, pre_states):
  """Compute the Jacobian of current states w.r.t. previous states.

  Each state tensor is split into unit-width slices along its last
  axis so that tf.gradients yields one Jacobian column per slice; the
  columns are then stacked into a single rank-3 tensor.

  :param states: tuple/list (possibly nested) of current state tensors
  :param pre_states: tuple/list (possibly nested) of previous states
  :return: 1-tuple containing the stacked Jacobian tensor
  """
  assert isinstance(states, (tuple, list))
  assert isinstance(pre_states, (tuple, list))
  # Flatten any nesting first
  states = ravel_nested_stuff(states)
  pre_states = ravel_nested_stuff(pre_states)
  # One slice per state unit, so each tf.gradients call gives a column
  unit_slices = []
  for s in states:
    unit_slices.extend(tf.split(s, s.shape[1], axis=-1))
  columns = [tf.concat(tf.gradients(u, pre_states), axis=-1)
             for u in unit_slices]
  return (tf.stack(columns, axis=-1),)
def _calc_dL_dS_prev(self, loss, pre_states):
  """dS in dL/dS must be an integral whole.

  Computes dL/dS_prev and, when the previous state is made of several
  tensors, concatenates the per-tensor gradients along the last axis
  so the result is a single gradient tensor.

  :param loss: scalar loss tensor
  :param pre_states: (possibly nested) previous state tensors
  :return: tuple containing the (concatenated) gradient tensor
  """
  assert isinstance(loss, tf.Tensor)
  grads = tf.gradients(loss, ravel_nested_stuff(pre_states))
  assert isinstance(grads, (tuple, list))
  # Merge multi-tensor states into one integral gradient
  if len(grads) > 1: grads = [tf.concat(grads, axis=-1)]
  return tuple(grads)
def _register_memories(pre_states):
  """Register memory tensors as a dict into tfr.collections.

  Depending on hub flags, each previous-state tensor is exported
  and/or added to the S_IN_DYDS dict collection under a key of the
  form 'S' (single state) or 'S<i-j-...>' (nested states).

  :param pre_states: list/tuple (possibly nested) of state tensors
  """
  # Nothing to do unless at least one consumer flag is on
  if not hub.use_default_s_in_dy_ds and not hub.export_states: return
  assert isinstance(pre_states, (list, tuple))
  tensors, indices = ravel_nested_stuff(pre_states, with_indices=True)
  single_state = len(tensors) == 1
  for tensor, idx in zip(tensors, indices):
    assert isinstance(idx, list)
    key = ('S' if single_state
           else 'S{}'.format('-'.join([str(i + 1) for i in idx])))
    if hub.export_states:
      context.add_tensor_to_export(key, tensor)
    if hub.use_default_s_in_dy_ds:
      context.add_to_dict_collection(context.S_IN_DYDS, key, tensor)
def _get_jacobian_stat(self, dsds_nested):
  """Build an OrderedDict of Jacobian Frobenius-norm tensors.

  Keys have the form '||dS/dS||' (single state) or '||dS/dS(i-j)||'
  (nested states); each value is the norm over the last two axes.

  :param dsds_nested: nested structure of dS/dS tensors (each rank 3)
  :return: OrderedDict mapping norm keys to tensors
  """
  dsds_list, index_list = ravel_nested_stuff(dsds_nested, with_indices=True)
  stat_dict = OrderedDict()
  single_state = len(dsds_list) == 1
  for dsds, idx in zip(dsds_list, index_list):
    assert isinstance(idx, list) and isinstance(dsds, tf.Tensor)
    assert len(dsds.shape) == 3
    # Compose the key for dSi/dSi
    if single_state:
      grad_index = ''
    else:
      grad_index = '({})'.format('-'.join([str(i + 1) for i in idx]))
    key = '||dS/dS{}||'.format(grad_index)
    # Pretend dsds has a batch dimension of size 1
    batched = tf.stack([dsds])
    stat_dict[key] = tf.norm(batched, axis=[-2, -1])
  return stat_dict