Exemplo n.º 1
0
 def forward(self):
     """Invoke ``_C.dispatch_coco_reader`` with this module's settings.

     The reader options (annotation file, image dir, batch size, shuffling,
     seed, aspect-ratio grouping, empty-image filtering, stride partition)
     are identical for both dispatch modes and are built once. The only
     difference between modes is how the op is placed:

     * ``self.placement is None`` -> local apply, passes ``device``.
     * otherwise                  -> consistent apply, passes ``placement``
                                     and ``sbp`` instead of ``device``.

     Returns:
         Whatever ``_C.dispatch_coco_reader`` returns for ``self._op``.
     """
     # Shared arguments: kept in one place so the local and consistent paths
     # cannot drift apart when a reader option is added or renamed.
     reader_kwargs = {
         "session_id": current_scope().session_id,
         "annotation_file": self.annotation_file,
         "image_dir": self.image_dir,
         "batch_size": self.batch_size,
         "shuffle_after_epoch": self.shuffle,
         "random_seed": self.random_seed,
         "group_by_ratio": self.group_by_aspect_ratio,
         "remove_images_without_annotations": self.remove_images_without_annotations,
         "stride_partition": self.stride_partition,
     }
     if self.placement is None:
         # local apply
         reader_kwargs["device"] = self.device
     else:
         # consistent apply
         reader_kwargs["placement"] = self.placement
         reader_kwargs["sbp"] = self.sbp
     return _C.dispatch_coco_reader(self._op, **reader_kwargs)
Exemplo n.º 2
0
 def forward(self, x):
     """Verify the scope this block is traced under, then apply ``fc1`` plus the buffer.

     The checks, in order:
       * the current scope's parent is ``self.prev_scope``;
       * the ``checkpointing`` attribute has no oneof value set;
       * ``pipeline_stage_id_hint`` is 1;
       * the scope's op-name prefixes, dot-joined, equal ``name_prefix + name``;
       * the ``dummy_buff`` buffer's scope is a child of the current scope.
     """
     cur_scope = scope_util.current_scope()
     cur_proto = graph_build_util.scope_to_proto(cur_scope)
     test_case.assertEqual(
         cur_proto.parent_scope_symbol_id, self.prev_scope.symbol_id
     )
     attrs = cur_proto.attr_name2attr_value
     test_case.assertEqual(attrs["checkpointing"].WhichOneof("value"), None)
     test_case.assertEqual(attrs["pipeline_stage_id_hint"].at_int64, 1)

     expected_name = self.name_prefix + self.name
     test_case.assertEqual(
         expected_name, ".".join(cur_proto.scope_op_name_prefixes)
     )

     # Read the buffer before inspecting its scope, matching the original order.
     buff = self.dummy_buff
     buff_proto = graph_build_util.scope_to_proto(
         self._buffers["dummy_buff"].scope
     )
     test_case.assertEqual(
         buff_proto.parent_scope_symbol_id, cur_scope.symbol_id
     )
     return self.fc1(x) + buff
Exemplo n.º 3
0
 def forward(self, x):
     """Assert the current scope's ``checkpointing`` attribute is True, then run ``self.model``."""
     proto = graph_build_util.scope_to_proto(scope_util.current_scope())
     checkpointing_on = proto.attr_name2attr_value["checkpointing"].at_bool
     test_case.assertEqual(checkpointing_on, True)
     return self.model(x)
Exemplo n.º 4
0
 def forward(self, x):
     """Check the scope attributes for stage 0, run ``conv1``, and check its weight is lazy.

     Asserts that ``checkpointing`` is True and ``pipeline_stage_id_hint`` is 0
     in the current scope. After calling ``conv1``, asserts that
     ``conv1.weight.is_lazy`` is truthy.
     """
     attrs = graph_build_util.scope_to_proto(
         scope_util.current_scope()
     ).attr_name2attr_value
     test_case.assertEqual(attrs["checkpointing"].at_bool, True)
     test_case.assertEqual(attrs["pipeline_stage_id_hint"].at_int64, 0)
     result = self.conv1(x)
     # The weight is inspected only after conv1 has run, as in the original.
     test_case.assertTrue(self.conv1.weight.is_lazy)
     return result
Exemplo n.º 5
0
 def compile(self, op_list):
     """Add every op in *op_list* to the current job, then complete and rebuild it.

     The session status is first checked against ``SessionStatus.OPEN`` via
     ``self._check_status``. For each op where ``_need_check_device_tag(op_conf)``
     is true and whose ``device_tag`` differs from the current scope's device
     tag, a warning is printed. The op is still added either way.

     Args:
         op_list: iterable of op configurations; each must expose ``name`` and
             ``device_tag`` (as read below) and is passed to
             ``compile_ctx.CurJobAddOp``.
     """
     self._check_status(self.SessionStatus.OPEN)
     scope = scope_util.current_scope()
     device_tag = scope.device_parallel_desc_symbol.device_tag
     for op_conf in op_list:
         if _need_check_device_tag(op_conf) and op_conf.device_tag != device_tag:
             # Mismatch is reported but not fatal: the op is added regardless.
             print(
                 "WARNING: the device_tag of op {} is not equal to the device_tag "
                 "of session's current scope ({} vs. {}), which may cause the op "
                 "graph to be incompatible".format(
                     op_conf.name, op_conf.device_tag, device_tag
                 )
             )
         compile_ctx.CurJobAddOp(op_conf)
     oneflow._oneflow_internal.CurJobBuildAndInferCtx_Complete()
     oneflow._oneflow_internal.CurJobBuildAndInferCtx_Rebuild()
Exemplo n.º 6
0
            def build(self, x):
                """Assert the graph-build context, then return *x* unchanged.

                Checks that lazy mode is enabled, that the default session is a
                ``MultiClientSession`` whose id matches the current scope's
                ``session_id``, and that the current job name equals ``self.name``.
                """
                test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), True)
                import oneflow.framework.scope_util as scope_util
                import oneflow.framework.session_context as session_ctx
                from oneflow.framework.multi_client_session import MultiClientSession

                default_session = session_ctx.GetDefaultSession()
                test_case.assertEqual(type(default_session), MultiClientSession)

                proto = graph_build_util.scope_to_proto(scope_util.current_scope())
                test_case.assertEqual(default_session.id, proto.session_id)

                job_name = oneflow._oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
                test_case.assertEqual(job_name, self.name)
                return x