Example #1
from tensorflow.python.ops import resource_variable_ops


def _from_proto_fn(v, import_scope=None):
    """Build a Variable from a VariableDef proto; not yet supported under a distribution strategy."""
    if has_distribution_strategy():  # assumed to be imported from TF's distribution_strategy_context
        raise NotImplementedError(
            "Deserialization of variables is not yet supported when using "
            "distributed strategies.")
    # Fall back to the standard resource-variable deserialization and return the result.
    return resource_variable_ops._from_proto_fn(v, import_scope=import_scope)
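A minimal usage sketch, assuming `info.variables` holds serialized VariableDef protos (as in the later examples) and that `graph` is the tf.Graph owning those variable ops; both names are stand-ins, not part of the example above:

# Hypothetical usage: rebuild Variable objects from their VariableDef protos.
with graph.as_default():  # `graph` is assumed to contain the serialized variables
    variables = [_from_proto_fn(var_def) for var_def in info.variables]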
Example #2
    def trainable_var_op_to_var(self):
        """
        Return a mapping from trainable variable ops (e.g. VarHandleOps) to their Variable objects.

        Returns:
            Dict: variable op -> Variable.
        """
        with self.graph.as_default():
            return {
                self.graph.get_operation_by_name(
                    get_op_name(var_def.variable_name)):
                _from_proto_fn(var_def)
                for var_def in self.info.trainable_variables
            }
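A hedged usage sketch; `item` stands in for an instance of the (unshown) class that defines this method:

# Hypothetical usage: map each trainable variable op back to its Variable.
op_to_var = item.trainable_var_op_to_var()
for op, var in op_to_var.items():
    print(op.name, "->", var.name)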
Example #3
    def replicate(self, graph_item):
        """
        Replicate the entire graph once per local replica.

        Args:
            graph_item: The original graph item.

        Returns:
            GraphItem: The new, replicated graph item.
        """
        item = GraphItem(graph=ops.Graph())
        fwd_ctx, bwd_ctx = self._collect_while_context(graph_item.graph)
        with item.graph.as_default():
            gdef = graph_item.graph.as_graph_def()
            for i in range(self._num_local_replicas):
                # Replicate ops
                with ops.device(self._replica_device_placer(replica_id=i)):
                    import_graph_def(gdef, name=replica_prefix(i))

                # Replicate the while_loop context (control flow) if needed.
                # The order matters: we must replicate the backward context first, then the forward context.
                # TODO(Zeya): Handle cases with nested while loops, where we must replicate the
                #  parent context first and then the child context. See:
                #  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_ops.py#L938
                if bwd_ctx:
                    for ctx in bwd_ctx:
                        _ = WhileContext(context_def=ctx.to_proto(), grad_state=ctx._grad_state,
                                         import_scope=replica_prefix(i))
                if fwd_ctx:
                    for ctx in fwd_ctx:
                        _ = WhileContext(context_def=ctx.to_proto(), grad_state=ctx._grad_state,
                                         import_scope=replica_prefix(i))

            # Update savers (imported under the master replica's name scope only)
            master_replica = 0
            if graph_item.info.savers:
                item.info.update_savers(
                    [Saver.from_proto(proto, import_scope=replica_prefix(master_replica)).to_proto()
                        for proto in graph_item.info.savers],
                    replace=False
                )

            # Update gradient, variable, and table-initializer info for each replica
            for i in range(self._num_local_replicas):
                for g_name, t_name in graph_item.grad_target_name_pairs.items():
                    if isinstance(g_name, tuple):
                        new_g_name = (
                            ops.prepend_name_scope(g_name[0], replica_prefix(i)),
                            ops.prepend_name_scope(g_name[1], replica_prefix(i)),
                            ops.prepend_name_scope(g_name[2], replica_prefix(i)))
                    else:
                        new_g_name = ops.prepend_name_scope(g_name, replica_prefix(i))
                    new_t_name = ops.prepend_name_scope(t_name, replica_prefix(i))
                    item.extend_gradient_info_by_names(
                        grads=[new_g_name],
                        targets=[new_t_name]
                    )
                item.info.update_variables(
                    [_from_proto_fn(proto, import_scope=replica_prefix(i)).to_proto()
                        for proto in graph_item.info.variables],
                    replace=False
                )
                item.info.update_table_initializers(
                    [ops.prepend_name_scope(tb_init, replica_prefix(i))
                        for tb_init in graph_item.info.table_initializers],
                    replace=False
                )
        return item
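A hedged usage sketch; `GraphReplicator` and its constructor are hypothetical stand-ins for the unshown class that owns `replicate`:

# Hypothetical usage: produce a graph item containing two copies of the original
# graph, each under its own replica name scope and device placement.
replicator = GraphReplicator(num_local_replicas=2)  # hypothetical class/constructor
new_item = replicator.replicate(original_graph_item)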
Example #4
    def get_all_variables(self):
        """Get all variables in this graph item."""
        with self.graph.as_default():
            return [_from_proto_fn(var_def) for var_def in self.info.variables]
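A minimal usage sketch, again with `item` as a stand-in for the enclosing graph-item instance:

# Hypothetical usage: enumerate every Variable reconstructed from the stored protos.
for var in item.get_all_variables():
    print(var.name, var.shape)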