def forward(self, x):
    """Verify this block's scope attributes, then apply ``fc1`` and add the buffer.

    Inside the current scope this asserts that:
      * the parent scope's symbol id is ``self.prev_scope.symbol_id``;
      * the ``checkpointing`` attr value has no field of its ``value`` oneof set;
      * ``pipeline_stage_id_hint`` is 1;
      * the op-name prefixes joined with ``"."`` equal ``self.name_prefix + self.name``;
      * the ``dummy_buff`` buffer's scope has the current scope as its parent.

    Returns ``self.fc1(x) + self.dummy_buff``.
    """
    scope = scope_util.current_scope()
    scope_proto = graph_build_util.scope_to_proto(scope)
    test_case.assertEqual(
        scope_proto.parent_scope_symbol_id, self.prev_scope.symbol_id
    )

    # The test expects no checkpointing value to be set on this scope,
    # i.e. the "value" oneof of the attr message is empty.
    checkpointing_attr = scope_proto.attr_name2attr_value["checkpointing"]
    test_case.assertIsNone(checkpointing_attr.WhichOneof("value"))

    stage_int = scope_proto.attr_name2attr_value[
        "pipeline_stage_id_hint"
    ].at_int64
    test_case.assertEqual(stage_int, 1)

    # The scope stores the op-name prefix as a sequence of components;
    # joined with "." they must reproduce the block's full name.
    name = self.name_prefix + self.name
    name_in_scope = ".".join(scope_proto.scope_op_name_prefixes)
    test_case.assertEqual(name, name_in_scope)

    # Read the buffer through the block attribute before inspecting its
    # scope, preserving the original access order.
    b = self.dummy_buff
    dummy_buff_scope_proto = graph_build_util.scope_to_proto(
        self._buffers["dummy_buff"].scope
    )
    test_case.assertEqual(
        dummy_buff_scope_proto.parent_scope_symbol_id, scope.symbol_id
    )

    x = self.fc1(x)
    return x + b
def forward(self, x):
    """Assert the current scope's ``checkpointing`` attr is True, then run ``self.model``.

    Returns the output of ``self.model(x)``.
    """
    attrs = graph_build_util.scope_to_proto(
        scope_util.current_scope()
    ).attr_name2attr_value
    checkpointing_on = attrs["checkpointing"].at_bool
    test_case.assertEqual(checkpointing_on, True)
    return self.model(x)
def forward(self, x):
    """Apply ``self.linear`` to ``x``; when lazy mode is on, first check checkpointing.

    The scope inspection only happens while lazy mode is enabled. In that
    case the current scope's ``checkpointing`` attr must be True.
    """
    # NOTE(review): the original source was whitespace-collapsed; the linear
    # call is assumed to sit outside the lazy-mode guard (otherwise the
    # returned value would be unbound when lazy mode is off) — confirm.
    building_lazily = graph_build_util.lazy_mode.is_enabled()
    if building_lazily:
        current = oneflow.current_scope()
        attrs = graph_build_util.scope_to_proto(current).attr_name2attr_value
        test_case.assertEqual(attrs["checkpointing"].at_bool, True)
    return self.linear(x)
def forward(self, x):
    """Check checkpointing and pipeline stage attrs, then run ``self.conv1``.

    Asserts that the current scope has ``checkpointing`` True and
    ``pipeline_stage_id_hint`` 0. After the convolution, it asserts that
    ``self.conv1.weight`` is lazy. Returns the convolution output.
    """
    attrs = graph_build_util.scope_to_proto(
        scope_util.current_scope()
    ).attr_name2attr_value
    test_case.assertEqual(attrs["checkpointing"].at_bool, True)
    test_case.assertEqual(attrs["pipeline_stage_id_hint"].at_int64, 0)

    conv_out = self.conv1(x)
    test_case.assertTrue(self.conv1.weight.is_lazy)
    return conv_out
def build(self, x):
    """Assert graph-build context invariants, then return ``x`` unchanged.

    Checks that:
      * lazy mode is enabled;
      * the default session's type is exactly ``MultiClientSession``;
      * the current scope's ``session_id`` matches that session's ``id``;
      * the job currently being built is named ``self.name``.
    """
    test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), True)

    import oneflow.framework.scope_util as scope_util
    import oneflow.framework.session_context as session_ctx
    from oneflow.framework.multi_client_session import MultiClientSession

    default_session = session_ctx.GetDefaultSession()
    test_case.assertEqual(type(default_session), MultiClientSession)

    current_scope_proto = graph_build_util.scope_to_proto(
        scope_util.current_scope()
    )
    test_case.assertEqual(default_session.id, current_scope_proto.session_id)

    current_job_name = (
        oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
    )
    test_case.assertEqual(current_job_name, self.name)
    return x