def decide_slices(self, data_shapes):
        """Work out the batch axis of every input and fix the per-context slices.

        The first input that has a batch axis determines ``self.batch_size``
        (unless it is already set) and ``self.slices``, which is computed by
        ``_split_input_slice`` from that batch size and ``self.workload``.
        Every further input with a batch axis must agree with the recorded
        batch size.

        Parameters
        ----------
        data_shapes : list
            Descriptions of the input data or labels. Each entry must unpack
            as ``(name, shape)`` and expose a ``layout`` attribute.

        Returns
        -------
        list
            One batch axis per entry of `data_shapes`, as reported by
            ``DataDesc.get_batch_axis``. An axis of ``-1`` marks an entry that
            is skipped when checking batch sizes.

        Raises
        ------
        AssertionError
            If `data_shapes` is empty, or if an entry's size along its batch
            axis differs from ``self.batch_size``.
        """
        assert len(data_shapes) > 0

        # Resolve every axis up front so that a failure here leaves the
        # instance state untouched.
        batch_axes = []
        for desc in data_shapes:
            batch_axes.append(DataDesc.get_batch_axis(desc.layout))

        for desc, axis in zip(data_shapes, batch_axes):
            name, shape = desc
            if axis == -1:
                continue

            current = shape[axis]
            if self.batch_size is None:
                # First input with a batch axis: it defines the batch size and
                # how the batch is divided.
                self.batch_size = current
                self.slices = _split_input_slice(self.batch_size, self.workload)
            else:
                assert current == self.batch_size, (
                    "all data must have the same batch size: "
                    + ("batch_size = %d, but " % self.batch_size)
                    + ("%s has shape %s" % (name, shape)))

        return batch_axes
示例#2
0
    def decide_slices(self, data_shapes):
        """Determine the batch axes of the inputs and the slices per context.

        If ``self.batch_size`` is still ``None``, the first input that has a
        batch axis sets it, and ``self.slices`` is then obtained from
        ``_split_input_slice(self.batch_size, self.workload)``. All other
        inputs with a batch axis are checked against that batch size.

        Parameters
        ----------
        data_shapes : list
            list of (name, shape) specifying the shapes for the input data or
            label. Each entry must also expose a ``layout`` attribute.

        Returns
        -------
        list
            The batch axis of each entry (``DataDesc.get_batch_axis`` applied
            to its layout); ``-1`` entries are ignored by the size check.

        Raises
        ------
        AssertionError
            When `data_shapes` is empty or the batch sizes are inconsistent.
        """
        assert len(data_shapes) > 0
        major_axis = [DataDesc.get_batch_axis(entry.layout) for entry in data_shapes]

        for index, (name, shape) in enumerate(data_shapes):
            axis = major_axis[index]
            if axis == -1:
                continue

            size_on_axis = shape[axis]
            if self.batch_size is None:
                # Nothing recorded yet: adopt this size and derive the slices.
                self.batch_size = size_on_axis
                self.slices = _split_input_slice(self.batch_size,
                                                 self.workload)
                continue

            assert size_on_axis == self.batch_size, (
                "all data must have the same batch size: " +
                ("batch_size = %d, but " % self.batch_size) +
                ("%s has shape %s" % (name, shape)))

        return major_axis
    def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names,
                 for_training, inputs_need_grad, shared_group=None, logger=logging,
                 fixed_param_names=None, grad_req='write', state_names=None):
        """Store the configuration, derive gradient requests and bind executors.

        Parameters
        ----------
        symbol : object
            Must provide ``list_arguments()``, ``list_auxiliary_states()``,
            ``list_outputs()`` and indexing by output name.
        contexts : list
            Stored as ``self.contexts``. When no `shared_group` is given, one
            empty dict per entry is created in ``self.shared_data_arrays``.
        workload : object
            Stored as ``self.workload``.
        data_shapes : list
            ``data_shapes[0]`` is iterated and each item's ``name`` is taken as
            an input name; ``len(data_shapes)`` becomes ``self.batch_size``.
            Passed on to ``bind_exec``.
        label_shapes : object
            Passed on to ``bind_exec`` unchanged.
        param_names : list
            Argument names that are treated as parameters.
        for_training : bool
            When false, `grad_req` is replaced by ``'null'``.
        inputs_need_grad : bool
            Whether input arguments receive a gradient request.
        shared_group : object, optional
            If given, its ``shared_data_arrays`` is reused.
        logger : object
            Stored as ``self.logger``.
        fixed_param_names : list, optional
            Parameters whose default gradient request is ``'null'``.
        grad_req : str, list, tuple or dict
            A str is the request applied to non-fixed parameters (and to inputs
            when `inputs_need_grad`). A list/tuple gives one request per
            argument. A dict overrides the ``'write'``-based defaults.
        state_names : list, optional
            Stored as ``self.state_names`` (empty list when ``None``).

        Raises
        ------
        ValueError
            If `grad_req` is not a str, list, tuple or dict.
        AssertionError
            If a list/tuple `grad_req` does not match the argument count.
        """
        self.param_names = param_names
        self.arg_names = symbol.list_arguments()
        self.aux_names = symbol.list_auxiliary_states()

        self.symbol = symbol
        self.contexts = contexts
        self.workload = workload
        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad
        self.logger = logger

        # In the future we should have a better way to profile memory per device (haibin)
        self.fixed_param_names = [] if fixed_param_names is None else fixed_param_names
        self.state_names = [] if state_names is None else state_names

        # Without training, no gradients are requested at all.
        if not for_training:
            grad_req = 'null'

        # NOTE: the shape descriptions are used as given; they are not
        # converted to DataDesc here.
        input_names = [desc.name for desc in data_shapes[0]]

        def _default_request(arg, base):
            # Parameters get `base` unless fixed; inputs get `base` only when
            # input gradients are wanted; everything else gets 'null'.
            if arg in self.param_names:
                return 'null' if arg in self.fixed_param_names else base
            if arg in input_names:
                return base if self.inputs_need_grad else 'null'
            return 'null'

        if isinstance(grad_req, str):
            self.grad_req = {arg: _default_request(arg, grad_req)
                             for arg in self.arg_names}
        elif isinstance(grad_req, (list, tuple)):
            assert len(grad_req) == len(self.arg_names)
            self.grad_req = dict(zip(self.arg_names, grad_req))
        elif isinstance(grad_req, dict):
            self.grad_req = {arg: _default_request(arg, 'write')
                             for arg in self.arg_names}
            # Explicit entries take precedence over the derived defaults.
            self.grad_req.update(grad_req)
        else:
            raise ValueError("grad_req must be one of str, list, tuple, or dict.")

        if shared_group is None:
            self.shared_data_arrays = [{} for _ in contexts]
        else:
            self.shared_data_arrays = shared_group.shared_data_arrays

        # Shape and layout descriptions start out unset.
        self.data_shapes = None
        self.label_shapes = None
        self.data_layouts = None
        self.label_layouts = None

        # Executors and the arrays associated with them.
        self.execs = []
        self._default_execs = None
        self.data_arrays = None
        self.label_arrays = None
        self.param_arrays = None
        self.state_arrays = None
        self.grad_arrays = None
        self.aux_arrays = None
        self.input_grad_arrays = None

        self.batch_size = len(data_shapes)
        self.slices = None

        self.output_layouts = [
            DataDesc.get_batch_axis(self.symbol[out_name].attr('__layout__'))
            for out_name in self.symbol.list_outputs()]
        self.bind_exec(data_shapes, label_shapes, shared_group)
示例#4
0
    def __init__(self,
                 symbol,
                 contexts,
                 workload,
                 data_shapes,
                 label_shapes,
                 param_names,
                 for_training,
                 inputs_need_grad,
                 shared_group=None,
                 logger=logging,
                 fixed_param_names=None,
                 grad_req='write',
                 state_names=None):
        """Record the settings, build the gradient-request table and bind.

        Parameters
        ----------
        symbol : object
            Queried through ``list_arguments()``, ``list_auxiliary_states()``
            and ``list_outputs()``, and indexed by output name.
        contexts : list
            Kept as ``self.contexts``; also sizes the fresh
            ``self.shared_data_arrays`` when `shared_group` is ``None``.
        workload : object
            Kept as ``self.workload``.
        data_shapes : list
            The ``name`` of every item in ``data_shapes[0]`` is treated as an
            input name, and ``len(data_shapes)`` is stored as
            ``self.batch_size``. Forwarded to ``bind_exec``.
        label_shapes : object
            Forwarded to ``bind_exec``.
        param_names : list
            Names of the arguments regarded as parameters.
        for_training : bool
            If false, `grad_req` is forced to ``'null'``.
        inputs_need_grad : bool
            Whether inputs get a gradient request other than ``'null'``.
        shared_group : object, optional
            Source of ``shared_data_arrays`` when provided.
        logger : object
            Kept as ``self.logger``.
        fixed_param_names : list, optional
            Parameters that default to the ``'null'`` request.
        grad_req : str, list, tuple or dict
            str: request used for trainable parameters (and inputs if
            `inputs_need_grad`). list/tuple: one request per argument.
            dict: overrides applied on top of ``'write'``-based defaults.
        state_names : list, optional
            Kept as ``self.state_names``; ``None`` becomes an empty list.

        Raises
        ------
        ValueError
            If `grad_req` has an unsupported type.
        AssertionError
            If a list/tuple `grad_req` has the wrong length.
        """
        self.param_names = param_names
        self.arg_names = symbol.list_arguments()
        self.aux_names = symbol.list_auxiliary_states()

        self.symbol = symbol
        self.contexts = contexts
        self.workload = workload

        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad

        self.logger = logger
        # In the future we should have a better way to profile memory per device (haibin)

        if fixed_param_names is None:
            fixed_param_names = []
        self.fixed_param_names = fixed_param_names

        if state_names is None:
            state_names = []
        self.state_names = state_names

        # Inference-only groups never request gradients.
        if not for_training:
            grad_req = 'null'

        # NOTE: data_shapes / label_shapes are taken as provided; no
        # conversion to DataDesc happens in this constructor.
        data_names = [item.name for item in data_shapes[0]]

        def build_requests(base):
            # Request table for every argument: `base` for trainable
            # parameters and (optionally) inputs, 'null' for the rest.
            table = {}
            for arg in self.arg_names:
                if arg in self.param_names:
                    is_fixed = arg in self.fixed_param_names
                    table[arg] = 'null' if is_fixed else base
                elif arg in data_names:
                    table[arg] = base if self.inputs_need_grad else 'null'
                else:
                    table[arg] = 'null'
            return table

        if isinstance(grad_req, str):
            self.grad_req = build_requests(grad_req)
        elif isinstance(grad_req, (list, tuple)):
            assert len(grad_req) == len(self.arg_names)
            self.grad_req = dict(zip(self.arg_names, grad_req))
        elif isinstance(grad_req, dict):
            requests = build_requests('write')
            # Caller-supplied entries win over the defaults.
            requests.update(grad_req)
            self.grad_req = requests
        else:
            raise ValueError(
                "grad_req must be one of str, list, tuple, or dict.")

        if shared_group is not None:
            self.shared_data_arrays = shared_group.shared_data_arrays
        else:
            self.shared_data_arrays = [dict() for _ in contexts]

        # initialize some instance variables
        self.batch_size = len(data_shapes)
        self.slices = None
        self.execs = []
        self._default_execs = None

        self.data_arrays = self.label_arrays = None
        self.param_arrays = self.state_arrays = None
        self.grad_arrays = self.aux_arrays = None
        self.input_grad_arrays = None

        self.data_shapes = self.label_shapes = None
        self.data_layouts = self.label_layouts = None

        # Batch axis of every output, read from its '__layout__' attribute.
        layouts = []
        for output_name in self.symbol.list_outputs():
            layout_attr = self.symbol[output_name].attr('__layout__')
            layouts.append(DataDesc.get_batch_axis(layout_attr))
        self.output_layouts = layouts

        self.bind_exec(data_shapes, label_shapes, shared_group)