def _remap_element(self, ele_type, name):
    """Remap element based on type."""
    graph = self._graph_item.graph
    if ele_type is ResourceVariable:
        res = get_read_var_tensor(graph.get_tensor_by_name(name).op)
    else:
        # Default element mapper, including the RefVariable case
        res = graph.as_graph_element(name, allow_tensor=True, allow_operation=True)
    return res
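# Hedged usage sketch (the names 'w:0' and 'train_op' are illustrative, not
# from this module): a ResourceVariable name is remapped to the output of its
# ReadVariableOp, while anything else falls through to a plain graph lookup.
#   self._remap_element(ResourceVariable, 'w:0')  # -> tensor read from 'w' via its ReadVariableOp
#   self._remap_element(None, 'train_op')         # -> graph.as_graph_element lookup (RefVariable case included)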
def _get_accum_apply_and_agg_grad(var_op, grad, indices, dense_shape):
    """Create a conditional accumulator for `grad` and rewire its consumers.

    Returns the accumulator apply op and the aggregated gradient taken from
    the accumulator. `num_accum_required` is expected to be bound in an
    enclosing scope.
    """
    if indices is None:
        # Dense gradient: accumulate full tensors.
        tensor = variable_utils.get_read_var_tensor(var_op)
        grad_accum = data_flow_ops.ConditionalAccumulator(
            grad.dtype,
            shape=tensor.get_shape(),
            shared_name=var_op.name + "/grad_accum")
        # Get a copy of consumers list before creating accum_apply_op
        grad_consumers = list(grad.consumers())
        accum_apply_op = grad_accum.apply_grad(
            grad, local_step=MAX_INT64,
            name=grad.op.name + '_accum_apply_grad')
        agg_grad = grad_accum.take_grad(
            num_accum_required, name=var_op.name + '_take_grad')
        update_consumers(grad_consumers, grad, agg_grad)
        update_control_consumers(get_control_consumers(grad.op),
                                 grad.op, agg_grad.op)
    else:
        # Sparse gradient: accumulate IndexedSlices.
        grad_indexed_slices = ops.IndexedSlices(
            values=grad, indices=indices, dense_shape=dense_shape)
        grad_accum = data_flow_ops.SparseConditionalAccumulator(
            grad.dtype, shape=grad.shape,
            shared_name=var_op.name + "/grad_accum")
        # Get a copy of consumers list before creating accum_apply_op
        indices_consumers = list(indices.consumers())
        grad_consumers = list(grad.consumers())
        accum_apply_op = grad_accum.apply_indexed_slices_grad(
            grad_indexed_slices, local_step=MAX_INT64,
            name=grad.op.name + '_accum_apply_grad')
        agg_grad = grad_accum.take_indexed_slices_grad(
            num_accum_required, name=var_op.name + '_take_grad')
        # Indices taken from the accumulator may not match the original
        # indices dtype; cast back if needed.
        agg_indices = agg_grad.indices
        if indices.dtype != agg_grad.indices.dtype:
            agg_indices = math_ops.cast(agg_grad.indices, indices.dtype)
        agg_grad = ops.IndexedSlices(
            values=agg_grad.values,
            indices=agg_indices,
            dense_shape=agg_grad.dense_shape)
        assert isinstance(agg_grad, ops.IndexedSlices)
        update_consumers(indices_consumers, indices, agg_grad.indices)
        update_consumers(grad_consumers, grad, agg_grad.values)
        update_control_consumers(get_control_consumers(indices.op),
                                 indices.op, agg_grad.indices.op)
        update_control_consumers(get_control_consumers(grad.op),
                                 grad.op, agg_grad.values.op)
    return accum_apply_op, agg_grad
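# A minimal sketch of the dense accumulation pattern used above (hypothetical
# helper, not part of this module; assumes the module-level imports and
# MAX_INT64 visible in the surrounding code): each worker applies its local
# gradient to a shared ConditionalAccumulator, and take_grad blocks until
# `num_required` gradients have been applied, then returns their mean.
def _sketch_dense_grad_accumulation(grad, num_required):
    accum = data_flow_ops.ConditionalAccumulator(
        grad.dtype,
        shape=grad.get_shape(),
        shared_name='sketch/grad_accum')
    # MAX_INT64 as local_step keeps the gradient from ever being dropped
    # as stale relative to the accumulator's global step.
    apply_op = accum.apply_grad(grad, local_step=MAX_INT64)
    with ops.control_dependencies([apply_op]):
        # Mean of the num_required gradients applied since the last take.
        return accum.take_grad(num_required)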
def _mirror_read_var_ops(self, other):
    """
    Mirror read var ops onto a proxy variable.

    Args:
        other (Variable): the proxy variable whose read ops replace
            the original variable's read ops.
    """
    assert other in self._proxy_vars
    for old_read_var_op in self._read_var_ops:
        if old_read_var_op == get_read_var_tensor(self._this_op).op:
            # The default read op maps to the proxy's cached graph element.
            new_read_var_op = other._graph_element.op
        else:
            # Any other read maps to a fresh read of the proxy's value.
            new_read_var_op = other.value().op
        self._read_var_ops_mappings[other][old_read_var_op] = new_read_var_op
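# Illustration (hypothetical contents): after mirroring, the per-proxy mapping
# looks roughly like
#   self._read_var_ops_mappings[proxy_var] == {
#       <original ReadVariableOp>: <proxy read op>,
#       ...
#   }
# which _update_all_consumers (called after mirroring in _build_proxy_on) can
# then use to swap each consumer's input over to the matching proxy read.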
def _build_proxy_on(self, destination_device):
    """
    Build a proxy of the original variable on `destination_device`.

    Args:
        destination_device (DeviceSpecV2): the destination device on which
            the proxy is placed.
    """
    is_gpu = destination_device.device_type.upper() == 'GPU' \
        if destination_device.device_type else False
    prefix = replica_prefix(destination_device.device_index) if is_gpu \
        else replica_prefix('CPU')
    with ops.device(destination_device):
        proxy_var = variable_scope.get_variable(
            ops.prepend_name_scope(self._this_op.name, prefix),
            dtype=self._dtype,
            initializer=self._initial_value,
            trainable=False)
    self._graph_item.info.update_variables(
        [proxy_var], replace=False)  # Should we update graph_item.info?
    self._proxy_vars.append(proxy_var)
    self._proxy_var_init_ops.append(
        proxy_var.assign(get_read_var_tensor(self._this_op)))
    self._mirror_all_read_var_ops()
    self._update_all_consumers()
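# A minimal sketch of the proxy pattern implemented above (hypothetical
# standalone helper; assumes this module's ops/variable_scope imports and the
# get_read_var_tensor helper): place a non-trainable copy of a variable on
# another device, plus an assign op that refreshes it from the original.
def _sketch_build_proxy(original_var, destination_device):
    with ops.device(destination_device):
        proxy = variable_scope.get_variable(
            original_var.op.name + '/proxy',
            dtype=original_var.dtype.base_dtype,
            initializer=original_var.initial_value,
            trainable=False)
    # Re-sync the proxy with the original variable's current value.
    refresh_op = proxy.assign(get_read_var_tensor(original_var.op))
    return proxy, refresh_op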