Code Example #1
 def _remap_element(self, ele_type, name):
     """Return the graph element named `name`.

     A ResourceVariable is remapped to its read tensor; anything else is
     looked up directly as a tensor or operation.
     """
     graph = self._graph_item.graph
     if ele_type is ResourceVariable:
         res = get_read_var_tensor(graph.get_tensor_by_name(name).op)
     else:  # Default element mapper, including the RefVariable case
         res = graph.as_graph_element(name,
                                      allow_tensor=True,
                                      allow_operation=True)
     return res
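
For orientation, a minimal sketch (assuming TF1 graph mode) of what the ResourceVariable branch resolves: looking the variable up by name yields its VarHandleOp, and get_read_var_tensor (autodist's variable_utils helper, see Example #2) evidently returns the read tensor that remapped consumers should use instead of the handle.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

v = tf.get_variable("w", shape=[2], use_resource=True)
graph = tf.get_default_graph()
handle_op = graph.get_tensor_by_name("w:0").op  # op handed to get_read_var_tensor
print(handle_op.type)  # VarHandleOp
# The auto-created read tensor that consumers see after remapping:
print(graph.get_tensor_by_name("w/Read/ReadVariableOp:0"))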
Code Example #2
 def _get_accum_apply_and_agg_grad(var_op, grad, indices, dense_shape):
     """Accumulate `grad` and reroute its consumers to the aggregated gradient.

     `num_accum_required` and `MAX_INT64` are defined outside this snippet:
     the number of gradients to accumulate before `take_grad` returns, and a
     `local_step` large enough that no applied gradient is rejected as stale.
     """
     if indices is None:  # dense gradient
         tensor = variable_utils.get_read_var_tensor(var_op)
         grad_accum = data_flow_ops.ConditionalAccumulator(
             grad.dtype,
             shape=tensor.get_shape(),
             shared_name=var_op.name + "/grad_accum")
         # Get a copy of consumers list before creating accum_apply_op
         grad_consumers = list(grad.consumers())
         accum_apply_op = grad_accum.apply_grad(grad,
                                                local_step=MAX_INT64,
                                                name=grad.op.name +
                                                '_accum_apply_grad')
         agg_grad = grad_accum.take_grad(num_accum_required,
                                         name=var_op.name +
                                         '_take_grad')
         update_consumers(grad_consumers, grad, agg_grad)
         update_control_consumers(get_control_consumers(grad.op),
                                  grad.op, agg_grad.op)
     else:  # sparse gradient, handled as IndexedSlices
         grad_indexed_slices = ops.IndexedSlices(
             values=grad, indices=indices, dense_shape=dense_shape)
         grad_accum = data_flow_ops.SparseConditionalAccumulator(
             grad.dtype,
             shape=grad.shape,
             shared_name=var_op.name + "/grad_accum")
         # Get a copy of consumers list before creating accum_apply_op
         indices_consumers = list(indices.consumers())
         grad_consumers = list(grad.consumers())
         accum_apply_op = grad_accum.apply_indexed_slices_grad(
             grad_indexed_slices,
             local_step=MAX_INT64,
             name=grad.op.name + '_accum_apply_grad')
         agg_grad = grad_accum.take_indexed_slices_grad(
             num_accum_required, name=var_op.name + '_take_grad')
         # take_grad may return indices in a different dtype (e.g. int64);
         # cast back so downstream consumers keep the original index dtype.
         agg_indices = agg_grad.indices
         if indices.dtype != agg_grad.indices.dtype:
             agg_indices = math_ops.cast(agg_grad.indices,
                                         indices.dtype)
         agg_grad = ops.IndexedSlices(values=agg_grad.values,
                                      indices=agg_indices,
                                      dense_shape=agg_grad.dense_shape)
         assert isinstance(agg_grad, ops.IndexedSlices)
         update_consumers(indices_consumers, indices, agg_grad.indices)
         update_consumers(grad_consumers, grad, agg_grad.values)
         update_control_consumers(get_control_consumers(indices.op),
                                  indices.op, agg_grad.indices.op)
         update_control_consumers(get_control_consumers(grad.op),
                                  grad.op, agg_grad.values.op)
     return accum_apply_op, agg_grad
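
The accumulator calls above follow TF1's public ConditionalAccumulator API. Below is a standalone sketch (names and shapes illustrative) of the apply/take cycle, with a large local_step standing in for MAX_INT64 so no gradient is rejected as stale.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

accum = tf.ConditionalAccumulator(tf.float32, shape=[2],
                                  shared_name="demo/grad_accum")
grad = tf.constant([1.0, 2.0])
apply_op = accum.apply_grad(grad, local_step=2**63 - 1)
# take_grad blocks until `num_required` gradients have been applied,
# then returns their average and resets the accumulator.
avg = accum.take_grad(num_required=1)

with tf.Session() as sess:
    sess.run(apply_op)
    print(sess.run(avg))  # [1. 2.]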
Code Example #3
    def _mirror_read_var_ops(self, other):
        """
        Mirror this variable's read ops onto a proxy variable.

        Args:
            other (Variable): the proxy variable whose read ops should
                stand in for this variable's reads.
        """
        assert other in self._proxy_vars
        for old_read_var_op in self._read_var_ops:
            if old_read_var_op == get_read_var_tensor(self._this_op).op:
                new_read_var_op = other._graph_element.op
            else:
                new_read_var_op = other.value().op
            self._read_var_ops_mappings[other][
                old_read_var_op] = new_read_var_op
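
A hedged sketch of the op set that `self._read_var_ops` presumably tracks: in TF1 graph mode, a resource variable's reads are the ReadVariableOp consumers of its handle.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

v = tf.get_variable("p", shape=[], use_resource=True)
_ = v.read_value()  # create a second read op
read_var_ops = {op for op in v.handle.consumers()
                if op.type == "ReadVariableOp"}
print(sorted(op.name for op in read_var_ops))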
Code Example #4
    def _build_proxy_on(self, destination_device):
        """
        Build a proxy of the original variable on `destination_device`.

        Args:
            destination_device (DeviceSpecV2): the device on which to place the proxy.
        """
        is_gpu = (destination_device.device_type.upper() == 'GPU'
                  if destination_device.device_type else False)
        prefix = (replica_prefix(destination_device.device_index)
                  if is_gpu else replica_prefix('CPU'))
        with ops.device(destination_device):
            proxy_var = variable_scope.get_variable(
                ops.prepend_name_scope(self._this_op.name, prefix),
                dtype=self._dtype,
                initializer=self._initial_value,
                trainable=False)
        self._graph_item.info.update_variables(
            [proxy_var], replace=False)  # Should we update graph_item.info?
        self._proxy_vars.append(proxy_var)
        self._proxy_var_init_ops.append(
            proxy_var.assign(get_read_var_tensor(self._this_op)))
        self._mirror_all_read_var_ops()
        self._update_all_consumers()
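
The proxy's name is derived by prefixing a replica scope. A minimal check of TF's ops.prepend_name_scope, with replica_prefix as a hypothetical stand-in for autodist's helper:

from tensorflow.python.framework import ops

def replica_prefix(replica):
    # Hypothetical stand-in for autodist's replica_prefix helper.
    return "Replica-{}".format(replica)

print(ops.prepend_name_scope("w", replica_prefix(0)))      # Replica-0/w
print(ops.prepend_name_scope("w", replica_prefix("CPU")))  # Replica-CPU/w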