コード例 #1
0
ファイル: RAdam.py プロジェクト: zgd716/CorefQA
 def _get_beta_accumulators(self):
     """Return the ``(beta1_power, beta2_power)`` non-slot variables.

     The lookup runs under ``ops.init_scope()``. When executing eagerly the
     ``graph`` key passed to ``_get_non_slot_variable`` is ``None``;
     otherwise it is the current default graph. This mirrors the eager-aware
     accumulator lookup of the stock TF v1 Adam optimizer.

     Returns:
         Tuple ``(beta1_power, beta2_power)`` as returned by
         ``self._get_non_slot_variable``.
     """
     # Function-scope import so the method works whether or not the module
     # already imports `context` at the top level.
     from tensorflow.python.eager import context

     with ops.init_scope():
         # NOTE(review): assumes the tf.compat.v1 Optimizer base class keys
         # eager-mode non-slot variables with graph=None; passing the global
         # default graph here would then make the lookup miss. Confirm
         # against the TF version in use.
         if context.executing_eagerly():
             graph = None
         else:
             graph = ops.get_default_graph()
         return (
             self._get_non_slot_variable("beta1_power", graph=graph),
             self._get_non_slot_variable("beta2_power", graph=graph),
         )
コード例 #2
0
 def _get_beta_accumulators(self):
     """Look up the two beta power accumulators of this optimizer.

     Inside ``ops.init_scope()`` the graph key is ``None`` when executing
     eagerly and the current default graph otherwise.

     Returns:
         Tuple ``(beta1_power, beta2_power)`` as returned by
         ``self._get_non_slot_variable``.
     """
     with ops.init_scope():
         target_graph = (None if context.executing_eagerly()
                         else ops.get_default_graph())
         beta1_power = self._get_non_slot_variable(
             "beta1_power", graph=target_graph)
         beta2_power = self._get_non_slot_variable(
             "beta2_power", graph=target_graph)
     return beta1_power, beta2_power
コード例 #3
0
ファイル: test_graph_item.py プロジェクト: zeta1999/autodist
def test_graph_item_context_scope():
    """``GraphItem.as_default()`` installs and then clears the default item.

    Inside the context, ``graph_item._default_graph_item`` must be ``i1``, the
    yielded item's ``_graph`` must be ``g1``, and the default TF graph must be
    ``g1``. After the context exits, the default item is reset to ``None``.
    An attribute set on the yielded item must then be readable on ``i1``.
    """
    g1 = ops.Graph()
    i1 = graph_item.GraphItem(graph=g1)
    assert graph_item._default_graph_item is None
    with i1.as_default() as item:
        # Identity checks: the context must install these exact objects,
        # not merely equal ones.
        assert graph_item._default_graph_item is i1
        assert item._graph is g1
        assert ops.get_default_graph() is g1
        item.new_attr = 'new_value'
    assert graph_item._default_graph_item is None
    assert i1.new_attr == 'new_value'
コード例 #4
0
 def begin(self):
     """Resolve the graph tensors this hook reads.

     If no ``summary_writer`` is set and ``output_dir`` is truthy, a writer is
     obtained from ``SummaryWriterCache``. The method then fetches these
     tensors from the default graph:

     - ``model/<FAKE_PROTEINS>:0``
     - ``model/<LABELS>:0``
     - ``model/d_score:0``
     - the global-step read tensor.

     Raises:
         RuntimeError: If any of those tensors cannot be found.
     """
     if self.summary_writer is None and self.output_dir:
         self.summary_writer = SummaryWriterCache.get(self.output_dir)
     graph = ops.get_default_graph()

     def _model_tensor(name, what):
         # Graph.get_tensor_by_name raises KeyError for unknown names (it
         # never returns None), so a `is None` check afterwards cannot fire.
         # Translate the KeyError into the RuntimeError this hook reports.
         try:
             return graph.get_tensor_by_name("model/" + name + ":0")
         except KeyError:
             raise RuntimeError("Could not get %s tensor" % what)

     self.fake_seq = _model_tensor(FAKE_PROTEINS, "fake seq")
     self.labels = _model_tensor(LABELS, "labels")
     self.d_score = _model_tensor("d_score", "d_score")
     self.global_step_tensor = training_util._get_or_create_global_step_read()
     if self.global_step_tensor is None:
         raise RuntimeError("Could not get global step tensor")
コード例 #5
0
 def begin(self):
     """Fetch ``model/<variable_name_to_restore>:0`` from the default graph.

     The resulting tensor is stored on ``self.acid_embeddings``.
     """
     default_graph = ops.get_default_graph()
     tensor_name = "model/" + self.variable_name_to_restore + ":0"
     self.acid_embeddings = default_graph.get_tensor_by_name(tensor_name)
コード例 #6
0
ファイル: RAdam.py プロジェクト: zgd716/CorefQA
 def _get_niter(self):
     """Return the ``niter`` non-slot variable.

     The lookup runs under ``ops.init_scope()``. The ``graph`` key is ``None``
     when executing eagerly and the current default graph otherwise, matching
     the eager-aware accumulator lookup used by TF v1 Adam.

     Returns:
         The value returned by ``self._get_non_slot_variable("niter", ...)``.
     """
     # Function-scope import so the method works whether or not the module
     # already imports `context` at the top level.
     from tensorflow.python.eager import context

     with ops.init_scope():
         # NOTE(review): assumes the tf.compat.v1 Optimizer base class keys
         # eager-mode non-slot variables with graph=None; confirm against the
         # TF version in use.
         if context.executing_eagerly():
             graph = None
         else:
             graph = ops.get_default_graph()
         return self._get_non_slot_variable("niter", graph=graph)