Exemplo n.º 1
0
 def forward(self, x):
     """Verify the scope active during this call, then return ``fc1(x) + dummy_buff``.

     The checks, in order:
       * the current scope's parent symbol id equals ``self.prev_scope.symbol_id``;
       * the "checkpointing" attribute has no member of its "value" oneof set;
       * the "pipeline_stage_id_hint" attribute holds the int64 value 1;
       * the scope's op-name prefixes, joined with ".", equal
         ``self.name_prefix + self.name``;
       * the scope recorded on the "dummy_buff" buffer has the current scope
         as its parent.

     ``test_case``, ``scope_util`` and ``graph_build_util`` are resolved from
     an enclosing scope that is not part of this snippet.
     """
     current = scope_util.current_scope()
     current_proto = graph_build_util.scope_to_proto(current)

     # The active scope must hang directly off the previously recorded scope.
     test_case.assertEqual(
         current_proto.parent_scope_symbol_id, self.prev_scope.symbol_id
     )

     attrs = current_proto.attr_name2attr_value
     test_case.assertEqual(attrs["checkpointing"].WhichOneof("value"), None)
     test_case.assertEqual(attrs["pipeline_stage_id_hint"].at_int64, 1)

     # Dot-joined op-name prefixes must spell out this module's full name.
     expected_name = self.name_prefix + self.name
     test_case.assertEqual(
         expected_name, ".".join(current_proto.scope_op_name_prefixes)
     )

     # Read the buffer first, then inspect the scope stored on its entry.
     buff = self.dummy_buff
     buff_scope_proto = graph_build_util.scope_to_proto(
         self._buffers["dummy_buff"].scope
     )
     test_case.assertEqual(
         buff_scope_proto.parent_scope_symbol_id, current.symbol_id
     )

     return self.fc1(x) + buff
Exemplo n.º 2
0
 def forward(self, x):
     """Assert that checkpointing is on in the current scope, then run ``self.model``.

     The current scope is converted to its proto form and the bool value of
     its "checkpointing" attribute is required to equal ``True``.
     ``test_case``, ``scope_util`` and ``graph_build_util`` are resolved from
     an enclosing scope that is not part of this snippet.
     """
     proto = graph_build_util.scope_to_proto(scope_util.current_scope())
     checkpointing_attr = proto.attr_name2attr_value["checkpointing"]
     test_case.assertEqual(checkpointing_attr.at_bool, True)
     return self.model(x)
Exemplo n.º 3
0
 def forward(self, x):
     """Run ``self.linear`` on ``x``, checking the scope first when lazy mode is on.

     Only when ``graph_build_util.lazy_mode.is_enabled()`` is true, the
     current scope's "checkpointing" attribute is required to hold the bool
     value ``True``. Otherwise no scope check is performed.
     ``test_case``, ``oneflow`` and ``graph_build_util`` are resolved from an
     enclosing scope that is not part of this snippet.
     """
     lazy_enabled = graph_build_util.lazy_mode.is_enabled()
     if lazy_enabled:
         proto = graph_build_util.scope_to_proto(oneflow.current_scope())
         checkpointing_attr = proto.attr_name2attr_value["checkpointing"]
         test_case.assertEqual(checkpointing_attr.at_bool, True)
     return self.linear(x)
Exemplo n.º 4
0
 def forward(self, x):
     """Verify scope attributes, apply ``self.conv1`` and check its weight is lazy.

     The current scope must have "checkpointing" set to the bool ``True`` and
     "pipeline_stage_id_hint" set to the int64 value 0. After the convolution
     has been applied, ``self.conv1.weight.is_lazy`` must be truthy.
     ``test_case``, ``scope_util`` and ``graph_build_util`` are resolved from
     an enclosing scope that is not part of this snippet.
     """
     proto = graph_build_util.scope_to_proto(scope_util.current_scope())
     attrs = proto.attr_name2attr_value

     test_case.assertEqual(attrs["checkpointing"].at_bool, True)
     test_case.assertEqual(attrs["pipeline_stage_id_hint"].at_int64, 0)

     result = self.conv1(x)
     # The weight is inspected only after conv1 has been invoked.
     test_case.assertTrue(self.conv1.weight.is_lazy)
     return result
Exemplo n.º 5
0
            def build(self, x):
                """Check the build-time environment and return ``x`` unchanged.

                The checks, in order:
                  * lazy mode is reported as enabled;
                  * the default session's type is ``MultiClientSession``;
                  * the current scope's ``session_id`` equals the session's ``id``;
                  * the current job name equals ``self.name``.

                ``test_case``, ``oneflow`` and ``graph_build_util`` are resolved
                from an enclosing scope that is not part of this snippet.
                """
                lazy_enabled = graph_build_util.lazy_mode.is_enabled()
                test_case.assertEqual(lazy_enabled, True)

                import oneflow.framework.session_context as session_ctx
                from oneflow.framework.multi_client_session import MultiClientSession

                default_session = session_ctx.GetDefaultSession()
                test_case.assertEqual(type(default_session), MultiClientSession)

                import oneflow.framework.scope_util as scope_util

                current_proto = graph_build_util.scope_to_proto(
                    scope_util.current_scope()
                )
                test_case.assertEqual(default_session.id, current_proto.session_id)

                job_name = (
                    oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
                )
                test_case.assertEqual(job_name, self.name)
                return x