def next(self):
    """Return the next :class:`DataBatch`, advancing the bucket cursor.

    Raises
    ------
    StopIteration
        When every (bucket, offset) pair in ``self.idx`` has been consumed.
    """
    if self.curr_idx == len(self.idx):
        raise StopIteration
    bucket, offset = self.idx[self.curr_idx]
    self.curr_idx += 1

    stop = offset + self.batch_size
    if self.major_axis == 1:
        # Time-major ('TN') layout: transpose so time is the leading axis.
        data = self.nddata[bucket][offset:stop].T
        label = self.ndlabel[bucket][offset:stop].T
    else:
        data = self.nddata[bucket][offset:stop]
        label = self.ndlabel[bucket][offset:stop]

    data_descs = [DataDesc(name=self.data_name, shape=data.shape,
                           layout=self.layout)]
    label_descs = [DataDesc(name=self.label_name, shape=label.shape,
                            layout=self.layout)]
    return DataBatch([data], [label], pad=0,
                     bucket_key=self.buckets[bucket],
                     provide_data=data_descs,
                     provide_label=label_descs)
def next(self):
    """Return the next batch of (sentence, character, label) data.

    Raises
    ------
    StopIteration
        Once the bucket/offset index list is exhausted.
    """
    if self.curr_idx == len(self.idx):
        raise StopIteration
    # bucket = which bucket to draw from, start = first record in the slice
    bucket, start = self.idx[self.curr_idx]
    self.curr_idx += 1

    end = start + self.batch_size
    indices = self.ndindex[bucket][start:end]
    sentences = self.ndsent[bucket][start:end]
    characters = self.ndchar[bucket][start:end]
    label = self.ndlabel[bucket][start:end]

    data_descs = [
        DataDesc(name=self.data_names[0], shape=sentences.shape,
                 layout=self.layout),
        DataDesc(name=self.data_names[1], shape=characters.shape,
                 layout=self.layout),
    ]
    label_descs = [
        DataDesc(name=self.label_name, shape=label.shape,
                 layout=self.layout),
    ]
    return DataBatch([sentences, characters], [label], pad=0,
                     index=indices, bucket_key=self.buckets[bucket],
                     provide_data=data_descs,
                     provide_label=label_descs)
def __init__(self, idx_file=None, rec_file=None, record_len=None,
             feature_len=None, seq_len_range=3360, batch_size=None,
             data_name='data', label_name='softmax_label'):
    """Iterator that reads (sequence, label) batches out of an indexed
    RecordIO file.

    :param idx_file: path to the RecordIO index file
    :param rec_file: path to the RecordIO data file
    :param record_len: total number of records in the file
    :param feature_len: per-timestep feature dimension
    :param seq_len_range: maximum sequence length; also used as the
        default bucket key
    :param batch_size: number of records per batch
    :param data_name: name reported for the data array
    :param label_name: name reported for the label array
    """
    super(BucketSentenceIter, self).__init__()
    self.record_len = record_len
    self.read_num = 0
    # Record indices still available for reading in this epoch.
    self.idx_range = list(range(self.record_len))
    self.batch_size = batch_size
    self.record = mx.recordio.MXIndexedRecordIO(idx_file, rec_file, 'r')
    self.data_name = data_name
    self.label_name = label_name
    self.dtype = 'float32'
    self.layout = 'NT'
    self.seq_len_range = seq_len_range
    self.default_bucket_key = seq_len_range
    self.provide_data = [
        DataDesc(name=self.data_name,
                 shape=(batch_size, self.default_bucket_key, feature_len),
                 layout=self.layout)
    ]
    self.provide_label = [
        DataDesc(name=self.label_name,
                 shape=(batch_size, self.default_bucket_key),
                 layout=self.layout)
    ]
    self.cache = {}
    self.reset()
def provide_label(self):
    """Aggregate the label descriptions of all wrapped iterators,
    applying the optional per-iterator rename map."""
    if self.rename_label is None:
        return [desc for it in self.iters for desc in it.provide_label]
    descs = []
    for rename, it in zip(self.rename_label, self.iters):
        for d in it.provide_label:
            if isinstance(d, DataDesc):
                descs.append(DataDesc(rename[d.name], d.shape, d.dtype))
            else:
                # Plain (name, shape) tuple: promote to a DataDesc as-is.
                descs.append(DataDesc(*d))
    return descs
def provide_data(self):
    """The name and shape of data provided by this iterator."""
    if self.rename_data is None:
        return [desc for it in self.iters for desc in it.provide_data]
    descs = []
    for rename, it in zip(self.rename_data, self.iters):
        for d in it.provide_data:
            if isinstance(d, DataDesc):
                descs.append(DataDesc(rename[d.name], d.shape, d.dtype))
            else:
                # Plain (name, shape) tuple: promote to a DataDesc as-is.
                descs.append(DataDesc(*d))
    return descs
def init_misc(self):
    """Record the number of data sources and derive batch-sized
    data/label descriptors from the stored arrays."""
    self.num_source = len(self.data_list)

    def _batched_desc(name, arr):
        # Replace the per-example leading axis with the batch size.
        return DataDesc(name, (self.batch_size,) + tuple(arr.shape[1:]),
                        arr.dtype)

    self.provide_data = [_batched_desc(k, v) for k, v in self.data]
    self.provide_label = [_batched_desc(k, v) for k, v in self.label]
def forward(self, data_batch, is_train=None):
    """Forward computation.

    It supports data batches with different shapes, such as
    different batch sizes or different image sizes.
    If reshaping of data batch relates to modification of symbol or
    module, such as changing image layout ordering or switching from
    training to predicting, module rebinding is required.

    See Also
    ----------
    :meth:`BaseModule.forward`.

    Parameters
    ----------
    data_batch : DataBatch
        Could be anything with similar API implemented. A list of
        DataBatch objects is also accepted; all entries must share the
        same per-input shapes.
    is_train : bool
        Default is ``None``, which means ``is_train`` takes the value
        of ``self.for_training``.
    """
    assert self.binded and self.params_initialized

    curr_data_shapes = tuple(i.shape for i in self._data_shapes)
    if isinstance(data_batch, list):
        # BUG FIX: the original asserted `data_batch is not None`, which a
        # list always satisfies; an empty list then crashed with IndexError
        # on data_batch[0]. Assert non-emptiness instead.
        assert len(data_batch) > 0, "Encountered empty data batch"
        new_data_shapes = []
        for i in range(len(data_batch[0].data)):
            shape = data_batch[0].data[i].shape
            for db in data_batch:
                assert shape == db.data[i].shape, \
                    "All data batches in a list need to have the same shape"
            # The list is concatenated along the batch axis.
            new_batch_size = len(data_batch) * shape[0]
            new_data_shapes.append((new_batch_size, ) + shape[1:])
        new_data_shapes = tuple(new_data_shapes)
    else:
        new_data_shapes = tuple(i.shape for i in data_batch.data)

    if curr_data_shapes != new_data_shapes:
        # Shapes changed: rebind with descriptors taken from the batch when
        # available, otherwise synthesized from the current descriptors.
        if hasattr(data_batch, "provide_data") and data_batch.provide_data:
            new_dshape = data_batch.provide_data
        else:
            new_dshape = [DataDesc(i.name, shape, i.dtype, i.layout) \
                          for i, shape in zip(self._data_shapes, new_data_shapes)]

        if hasattr(data_batch, "provide_label") and data_batch.provide_label:
            new_lshape = data_batch.provide_label
        elif hasattr(data_batch, "label") and data_batch.label:
            new_lshape = [DataDesc(i.name, j.shape, i.dtype, i.layout) \
                          for i, j in zip(self._label_shapes, data_batch.label)]
        else:
            new_lshape = None

        self.reshape(new_dshape, new_lshape)

    self._exec_group.forward(data_batch, is_train)
def decide_slices(self, data_shapes):
    """Work out per-device slices along the batch axis.

    Parameters
    ----------
    data_shapes : list
        list of (name, shape) specifying the shapes for the input
        data or label.

    Returns
    -------
    list of int
        The batch axis (or -1 when there is none) for each entry.
    """
    assert len(data_shapes) > 0
    batch_axes = [DataDesc.get_batch_axis(d.layout) for d in data_shapes]
    for (name, shape), axis in zip(data_shapes, batch_axes):
        if axis == -1:
            # This input has no batch dimension; nothing to check.
            continue
        size = shape[axis]
        if self.batch_size is None:
            self.batch_size = size
        else:
            assert size == self.batch_size, (
                "all data must have the same batch size: "
                + ("batch_size = %d, but " % self.batch_size)
                + ("%s has shape %s" % (name, shape)))
    self.slices = _split_input_slice(self.batch_size, self.workload)
    return batch_axes
def decide_slices(self, data_shapes):
    """Decide how the batch is partitioned across contexts.

    Parameters
    ----------
    data_shapes : list
        list of (name, shape) specifying the shapes for the input
        data or label.

    Returns
    -------
    list of int
        Batch axis per descriptor; -1 marks inputs without one.
    """
    assert len(data_shapes) > 0
    major_axis = [DataDesc.get_batch_axis(desc.layout)
                  for desc in data_shapes]

    for (name, shape), axis in zip(data_shapes, major_axis):
        if axis == -1:
            continue
        observed = shape[axis]
        if self.batch_size is not None:
            # Every input must agree on a single global batch size.
            mismatch_msg = (
                "all data must have the same batch size: "
                + ("batch_size = %d, but " % self.batch_size)
                + ("%s has shape %s" % (name, shape)))
            assert observed == self.batch_size, mismatch_msg
        else:
            self.batch_size = observed

    self.slices = _split_input_slice(self.batch_size, self.workload)
    return major_axis
def next(self):
    """Read one batch from the backing store and wrap it in a DataBatch."""
    raw_data, raw_label = self._read_batch()
    data = mx.nd.array(raw_data)
    label = mx.nd.array(raw_label)
    data_descs = [DataDesc(name=self.data_name, shape=data.shape,
                           layout=self.layout)]
    label_descs = [DataDesc(name=self.label_name, shape=label.shape,
                            layout=self.layout)]
    return DataBatch([data], [label], pad=0,
                     bucket_key=self.seq_len_range,
                     provide_data=data_descs,
                     provide_label=label_descs)
def create_batch():
    """Build a DataBatch holding one all-ones float32 array shaped like
    the enclosing scope's ``input_shape`` (n, c, h, w)."""
    total = input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]
    flat = np.full(total, 1.0, dtype=np.float32)
    batch_frame = [mx.nd.array(flat.reshape(input_shape))]
    batch_shape = [DataDesc('data', batch_frame[0].shape)]
    return DataBatch(data=batch_frame, provide_data=batch_shape)
def provide_data(self):
    """The name and shape of data provided by this iterator."""
    # Round the batch size down to a whole number of groups of _batch_k.
    real_batch_size = self._batch_size - self._batch_size % self._batch_k
    shape = (real_batch_size, 3, self._data_shape, self._data_shape)
    return [DataDesc('data', shape, np.float32)]
def model_fn(path_to_model_files):
    """Load a serialized MXNet module from *path_to_model_files*.

    Expects two files in the directory: ``symbol`` (the network graph)
    and ``params`` (the trained weights). The module is bound for a
    single 1x1x28x28 input named 'data'.
    """
    from mxnet.io import DataDesc
    graph = mx.symbol.load(os.path.join(path_to_model_files, "symbol"))
    module = mx.mod.Module(symbol=graph)
    module.bind([DataDesc("data", (1, 1, 28, 28))])
    module.load_params(os.path.join(path_to_model_files, "params"))
    return module
def create_batch(self, frame):
    """Preprocess one image into a single-sample DataBatch.

    :param frame: an (w,h,channels) numpy array (image)
    :return: DataBatch of (1,channels,data_shape,data_shape)
    """
    resized = mx.nd.array(
        cv2.resize(frame, (self.data_shape[0], self.data_shape[1])))
    # (w, h, channels) -> (channels, w, h)
    chw = mx.nd.transpose(resized, axes=(2, 0, 1))
    # Subtract the per-channel mean pixel values.
    normalized = chw - self.mean_pixels_nd
    # Prepend the batch axis -> (1, channels, w, h)
    batch_frame = [mx.nd.expand_dims(normalized, axis=0)]
    batch_shape = [DataDesc('data', batch_frame[0].shape)]
    return DataBatch(data=batch_frame, provide_data=batch_shape)
def _sliced_shape(self, shapes, i, major_axis):
    """Get the sliced shapes for the i-th executor.

    Parameters
    ----------
    shapes : list of (str, tuple)
        The original (name, shape) pairs.
    i : int
        Which executor we are dealing with.
    major_axis : list of int
        Batch axis per descriptor; negative means "do not slice".
    """
    result = []
    for desc, axis in zip(shapes, major_axis):
        dims = list(desc.shape)
        if axis >= 0:
            # Shrink the batch axis to this executor's slice length.
            device_slice = self.slices[i]
            dims[axis] = device_slice.stop - device_slice.start
        result.append(
            DataDesc(desc.name, tuple(dims), desc.dtype, desc.layout))
    return result
def __init__(self, utterances, intents, batch_size, buckets, data_pad=-1,
             label_pad=-1, data_name='utterance', label_name='intent',
             dtype='float32'):
    """Bucketing iterator over padded utterances with one intent label each.

    :param utterances: list of list of int
    :param intents: list of int
    :param batch_size: number of utterances per batch
    :param buckets: list of bucket lengths (sorted in place)
    :param data_pad: padding value for short utterances
    :param label_pad: padding value for labels (kept for interface parity)
    :param data_name: name reported for the data array
    :param label_name: name reported for the label array
    :param dtype: numpy dtype for the stored arrays
    """
    super(BucketUtteranceIter, self).__init__()
    buckets.sort()
    nslice = 0  # Keep track of how many utterances are sliced
    self.utterances = [[] for _ in buckets]
    self.intents = [[] for _ in buckets]
    self.indices = [[] for _ in buckets]
    for i, utt in enumerate(utterances):
        # Find the index of the smallest bucket that is larger than the
        # sentence length
        buck_idx = bisect.bisect_left(buckets, len(utt))
        # Slice utterances that are too long to the largest bucket size
        if buck_idx == len(buckets):
            buck_idx = buck_idx - 1
            nslice += 1
            utt = utt[:buckets[buck_idx]]
        # Pad utterances that are too short for their bucket
        buff = np.full((buckets[buck_idx]), data_pad, dtype=dtype)
        buff[:len(utt)] = utt
        # Add data/label to bucket
        self.utterances[buck_idx].append(buff)
        self.intents[buck_idx].append(intents[i])
        self.indices[buck_idx].append(i)
    # Convert to list of array
    self.utterances = [np.asarray(i, dtype=dtype) for i in self.utterances]
    self.intents = [np.asarray(i, dtype=dtype) for i in self.intents]
    self.indices = [np.asarray(i, dtype=dtype) for i in self.indices]
    # IDIOM FIX: the original used a conditional expression as a statement
    # (`print(...) if nslice > 0 else None`); a guard clause is clearer.
    if nslice > 0:
        print("\nWarning, {0} utterances sliced to largest bucket size.".
              format(nslice))
    print("Utterances per bucket: {}\nBucket sizes: {}".format(
        [arr.shape[0] for arr in self.utterances], buckets))
    self.data_name = data_name
    self.label_name = label_name
    self.batch_size = batch_size
    self.buckets = buckets
    self.dtype = dtype
    self.data_pad = data_pad
    self.label_pad = label_pad
    self.default_bucket_key = max(buckets)
    self.layout = 'NT'
    self.provide_data = [
        DataDesc(name=self.data_name,
                 shape=(self.batch_size, self.default_bucket_key),
                 layout=self.layout)
    ]
    self.provide_label = [
        DataDesc(name=self.label_name,
                 shape=(self.batch_size, ),
                 layout=self.layout)
    ]
    # create empty list to store batch index values
    self.idx = []
    for i, buck in enumerate(self.utterances):
        self.idx.extend([
            (i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)
        ])
    self.curr_idx = 0
    self.reset()
def __init__(self, symbol, contexts, workload, data_shapes, label_shapes,
             param_names, for_training, inputs_need_grad, shared_group=None,
             logger=logging, fixed_param_names=None, grad_req='write',
             state_names=None):
    """Set up a data-parallel executor group: records the symbol and
    binding configuration, derives a per-argument gradient-request map,
    then binds executors via ``bind_exec``.

    NOTE(review): many comments below state what the visible code does;
    semantics of ``bind_exec`` and the shared-group protocol live
    elsewhere in the project.
    """
    self.param_names = param_names
    self.arg_names = symbol.list_arguments()
    self.aux_names = symbol.list_auxiliary_states()
    self.symbol = symbol
    self.contexts = contexts
    self.workload = workload
    self.for_training = for_training
    self.inputs_need_grad = inputs_need_grad
    self.logger = logger
    #In the future we should have a better way to profile memory per device (haibin)
    # self._total_exec_bytes = 0
    self.fixed_param_names = fixed_param_names
    if self.fixed_param_names is None:
        self.fixed_param_names = []
    self.state_names = state_names
    if self.state_names is None:
        self.state_names = []
    # Inference-only groups never need gradients.
    if not for_training:
        grad_req = 'null'
    # data_shapes = [x if isinstance(x, DataDesc) else DataDesc(*x) for x in data_shapes]
    # if label_shapes is not None:
    #     label_shapes = [x if isinstance(x, DataDesc) else DataDesc(*x) for x in label_shapes]
    # data_shapes appears to be a list of per-device descriptor lists;
    # the first entry supplies the input names.
    data_names = [x.name for x in data_shapes[0]]
    if isinstance(grad_req, str):
        # One request string for everything: parameters get it unless
        # frozen; data inputs get it only when input grads are wanted.
        self.grad_req = {}
        for k in self.arg_names:
            if k in self.param_names:
                self.grad_req[
                    k] = 'null' if k in self.fixed_param_names else grad_req
            elif k in data_names:
                self.grad_req[
                    k] = grad_req if self.inputs_need_grad else 'null'
            else:
                self.grad_req[k] = 'null'
    elif isinstance(grad_req, (list, tuple)):
        # Positional form: one request per symbol argument.
        assert len(grad_req) == len(self.arg_names)
        self.grad_req = dict(list(zip(self.arg_names, grad_req)))
    elif isinstance(grad_req, dict):
        # Dict form: start from 'write' defaults, then overlay the
        # caller-supplied overrides.
        self.grad_req = {}
        for k in self.arg_names:
            if k in self.param_names:
                self.grad_req[
                    k] = 'null' if k in self.fixed_param_names else 'write'
            elif k in data_names:
                self.grad_req[
                    k] = 'write' if self.inputs_need_grad else 'null'
            else:
                self.grad_req[k] = 'null'
        self.grad_req.update(grad_req)
    else:
        raise ValueError(
            "grad_req must be one of str, list, tuple, or dict.")
    if shared_group is not None:
        # Reuse memory from an existing group (e.g. bucketing modules).
        self.shared_data_arrays = shared_group.shared_data_arrays
    else:
        self.shared_data_arrays = [{} for _ in contexts]
    # initialize some instance variables
    # NOTE(review): batch_size is set to the number of shape sets, not a
    # value read off a batch axis — confirm against callers.
    self.batch_size = len(data_shapes)
    self.slices = None
    self.execs = []
    self._default_execs = None
    self.data_arrays = None
    self.label_arrays = None
    self.param_arrays = None
    self.state_arrays = None
    self.grad_arrays = None
    self.aux_arrays = None
    self.input_grad_arrays = None
    self.data_shapes = None
    self.label_shapes = None
    self.data_layouts = None
    self.label_layouts = None
    # Batch axis of each output, read from the symbol's layout attribute.
    self.output_layouts = [
        DataDesc.get_batch_axis(self.symbol[name].attr('__layout__'))
        for name in self.symbol.list_outputs()
    ]
    self.bind_exec(data_shapes, label_shapes, shared_group)
def __init__(self, sentences, labels, batch_size, buckets=None,
             invalid_label=0, data_name='data', label_name='softmax_label',
             dtype='float32', layout='NT'):
    """Bucketing iterator over padded sentences with one label per sentence.

    Parameters
    ----------
    sentences : list of list of int
        Encoded sentences of varying length.
    labels : list
        One label per sentence.
    batch_size : int
        Number of sentences per batch.
    buckets : list of int, optional
        Bucket lengths; derived from sentence-length counts when omitted.
    invalid_label : int
        Padding value for positions past a sentence's end.
    data_name, label_name : str
        Names reported in provide_data / provide_label.
    dtype : str
        Storage dtype.
    layout : str
        'NT' (batch major) or 'TN' (time major).
    """
    super(BucketSentenceIter, self).__init__()
    if not buckets:
        # Default buckets: every length with at least batch_size sentences.
        buckets = [
            i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
            if j >= batch_size
        ]
    buckets.sort()
    ndiscard = 0
    self.data = [[] for _ in buckets]
    self.labels = [[] for _ in buckets]
    for i, sent in enumerate(sentences):
        buck = bisect.bisect_left(buckets, len(sent))
        if buck == len(buckets):
            # Longer than the largest bucket: drop it.
            ndiscard += 1
            continue
        buff = np.full((buckets[buck], ), invalid_label, dtype=dtype)
        buff[:len(sent)] = sent
        self.data[buck].append(buff)
        self.labels[buck].append(labels[i])
    self.data = [np.asarray(i, dtype=dtype) for i in self.data]
    self.labels = [np.asarray(i, dtype=dtype) for i in self.labels]
    print(
        "WARNING: discarded %d sentences longer than the largest bucket."
        % ndiscard)
    self.batch_size = batch_size
    self.buckets = buckets
    self.data_name = data_name
    self.label_name = label_name
    self.dtype = dtype
    self.invalid_label = invalid_label
    self.nddata = []
    self.ndlabel = []
    self.major_axis = layout.find('N')
    self.layout = layout
    self.default_bucket_key = max(buckets)
    if self.major_axis == 0:
        self.provide_data = [
            DataDesc(name=self.data_name,
                     shape=(batch_size, self.default_bucket_key),
                     layout=self.layout)
        ]
        self.provide_label = [
            DataDesc(name=self.label_name,
                     shape=(batch_size, ),
                     layout=self.layout)
        ]
    elif self.major_axis == 1:
        self.provide_data = [
            DataDesc(name=self.data_name,
                     shape=(self.default_bucket_key, batch_size),
                     layout=self.layout)
        ]
        self.provide_label = [
            DataDesc(name=self.label_name,
                     shape=(self.default_bucket_key, batch_size),
                     layout=self.layout)
        ]
    else:
        # BUG FIX: the original message never interpolated %s (the literal
        # "%s" appeared in the error) and contained the typo "Must by".
        raise ValueError(
            "Invalid layout %s: Must be NT (batch major) or TN (time major)"
            % layout)
    self.idx = []
    for i, buck in enumerate(self.data):
        self.idx.extend([
            (i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)
        ])
    self.curr_idx = 0
    self.reset()
def __init__(self, sentences, characters, label, max_token_chars, batch_size,
             buckets=None, data_pad=-1, label_pad=-1,
             data_names=['sentences', 'characters'], label_name='seq_label',
             dtype='float32'):
    """Bucketing iterator for NER-style data: word-level sentences,
    per-word character arrays, and per-word entity labels, grouped into
    length buckets and padded to the bucket size.

    NOTE(review): ``data_names`` uses a mutable default list — safe only
    because it is never mutated here.
    """
    super(BucketNerIter, self).__init__()

    # Create a bucket for every seq length where there are more examples
    # than the batch size
    if not buckets:
        seq_counts = np.bincount([len(s) for s in sentences])
        buckets = [i for i, j in enumerate(seq_counts) if j >= batch_size]
    buckets.sort()
    print("\nBuckets created: ", buckets)
    assert (len(buckets) > 0), "Not enough utterances to create any buckets."

    ###########
    # Sentences
    ###########
    nslice = 0  # how many sentences had to be truncated
    # Create empty nested lists for storing data that falls into each bucket
    self.sentences = [[] for _ in buckets]
    for i, sent in enumerate(sentences):
        # Find the index of the smallest bucket that is larger than the
        # sentence length
        buck_idx = bisect.bisect_left(buckets, len(sent))
        if buck_idx == len(
                buckets):  # If the sentence is larger than the largest bucket
            buck_idx = buck_idx - 1
            nslice += 1
            sent = sent[:buckets[
                buck_idx]]  #Slice sentence to largest bucket size
        buff = np.full(
            (buckets[buck_idx]), data_pad,
            dtype=dtype)  # Create an array filled with 'data_pad'
        buff[:len(sent)] = sent  # Fill with actual values
        self.sentences[buck_idx].append(
            buff)  # Append array to index = bucket index
    self.sentences = [np.asarray(i, dtype=dtype)
                      for i in self.sentences]  # Convert to list of array
    print("Warning, {0} sentences sliced to largest bucket size.".format(
        nslice)) if nslice > 0 else None

    ############
    # Characters
    ############
    # Create empty nested lists for storing data that falls into each bucket
    self.characters = [[] for _ in buckets]
    for i, charsent in enumerate(characters):
        # Find the index of the smallest bucket that is larger than the
        # sentence length
        buck_idx = bisect.bisect_left(buckets, len(charsent))
        if buck_idx == len(
                buckets):  # If the sentence is larger than the largest bucket
            buck_idx = buck_idx - 1
            charsent = charsent[:buckets[
                buck_idx]]  #Slice sentence to largest bucket size
        # NOTE(review): unlike the sentence loop, nslice is not incremented
        # here — presumably because the same sentences were already counted.
        charsent = [word[:max_token_chars]
                    for word in charsent]  # Slice to max length
        charsent = [
            word + [data_pad] * (max_token_chars - len(word))
            for word in charsent
        ]  # Pad to max length
        charsent = np.array(charsent)
        buff = np.full((buckets[buck_idx], max_token_chars), data_pad,
                       dtype=dtype)
        buff[:charsent.shape[0], :] = charsent  # Fill with actual values
        self.characters[buck_idx].append(
            buff)  # Append array to index = bucket index
    self.characters = [
        np.asarray(i, dtype=dtype) for i in self.characters
    ]  # Convert to list of array

    ##########
    # Entities
    ##########
    # Create empty nested lists for storing data that falls into each bucket
    self.label = [[] for _ in buckets]
    self.indices = [[] for _ in buckets]
    for i, entities in enumerate(label):
        # Find the index of the smallest bucket that is larger than the
        # sentence length
        buck_idx = bisect.bisect_left(buckets, len(entities))
        if buck_idx == len(
                buckets):  # If the sentence is larger than the largest bucket
            buck_idx = buck_idx - 1
            entities = entities[:buckets[
                buck_idx]]  # Slice sentence to largest bucket size
        buff = np.full(
            (buckets[buck_idx]), label_pad,
            dtype=dtype)  # Create an array filled with 'label_pad'
        buff[:len(entities)] = entities  # Fill with actual values
        self.label[buck_idx].append(
            buff)  # Append array to index = bucket index
        self.indices[buck_idx].append(i)
    self.label = [np.asarray(i, dtype=dtype)
                  for i in self.label]  # Convert to list of array
    self.indices = [np.asarray(i, dtype=dtype)
                    for i in self.indices]  # Convert to list of array

    self.data_names = data_names
    self.label_name = label_name
    self.batch_size = batch_size
    self.max_token_chars = max_token_chars
    self.buckets = buckets
    self.dtype = dtype
    self.data_pad = data_pad
    self.label_pad = label_pad
    self.default_bucket_key = max(buckets)
    self.layout = 'NT'
    self.provide_data = [
        DataDesc(name=self.data_names[0],
                 shape=(self.batch_size, self.default_bucket_key),
                 layout=self.layout),
        DataDesc(name=self.data_names[1],
                 shape=(self.batch_size, self.default_bucket_key,
                        self.max_token_chars),
                 layout=self.layout)
    ]
    self.provide_label = [
        DataDesc(name=self.label_name,
                 shape=(self.batch_size, self.default_bucket_key),
                 layout=self.layout)
    ]

    #create empty list to store batch index values
    self.idx = []
    #for each bucketarray
    for i, buck in enumerate(self.sentences):
        #extend the list eg output with batch size 5 and 20 training examples
        #in bucket. [(0,0), (0,5), (0,10), (0,15), (1,0), (1,5), (1,10), (1,15)]
        self.idx.extend([
            (i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)
        ])
    self.curr_idx = 0
    self.reset()
def __init__(self, seqs, labels, batch_size, buckets=None,
             min_bucket_key=None, max_bucket_key=None, invalid_label=-1,
             data_name='data', label_name='softmax_label',
             data_dtype='float32', label_dtype='int'):
    """Bucketing iterator over padded sequences with one label each,
    with optional lower/upper bounds on the bucket keys.

    NOTE(review): ``label_dtype`` is stored but not used in this method —
    presumably consumed by ``reset``; confirm.
    """
    super(BucketSeqLabelIter, self).__init__()
    if not buckets:
        # only consider bucket whose len >= batch_size
        buckets = [i for i, j in enumerate(np.bincount([len(s) for s in seqs]))
                   if j >= batch_size]
    buckets.sort()
    # Optionally clamp the bucket range.
    if min_bucket_key:
        buckets = [k for k in buckets if k >= min_bucket_key]
    if max_bucket_key:
        buckets = [k for k in buckets if k <= max_bucket_key]
    ndiscard = 0
    # distribute sequences in defined buckets
    self.data = [[] for _ in buckets]
    self.label = [[] for _ in buckets]
    for i, seq in enumerate(seqs):
        buck_id = bisect.bisect_left(buckets, len(seq))
        if (buck_id == 0 and len(seq) < buckets[0]) or buck_id == len(buckets):
            # sequence longer or shorter than the biggest or smallest
            # bucket is disgarded
            ndiscard += 1
            continue
        buff = np.full((buckets[buck_id],), invalid_label, dtype=data_dtype)
        buff[:len(seq)] = seq
        self.data[buck_id].append(buff)
        self.label[buck_id].append(labels[i])
    # buckets of sequences with padding
    # NOTE(review): only self.data is converted to arrays here; self.label
    # buckets remain Python lists — presumably handled in reset(); confirm.
    self.data = [np.asarray(i) for i in self.data]
    print("WARNING: discarded %d sentences longer than the largest bucket."
          % ndiscard)
    self.batch_size = batch_size
    self.buckets = buckets
    self.data_name = data_name
    self.label_name = label_name
    self.data_dtype = data_dtype
    self.label_dtype = label_dtype
    self.invalid_label = invalid_label
    self.nddata = []
    self.ndlabel = []
    # self.major_axis = layout.find('N')
    # self.layout = layout
    self.default_bucket_key = max(buckets)
    # Batch-major ('NT') descriptors; the time-major variant is disabled.
    # if self.major_axis == 0:
    self.provide_data = [DataDesc(name=self.data_name, shape=(
        batch_size, self.default_bucket_key), layout='NT')]
    self.provide_label = [DataDesc(name=self.label_name,
                                   shape=(batch_size,), layout='NT')]
    # elif self.major_axis == 1:
    #     self.provide_data = [DataDesc(
    #         name=self.data_name, shape=(
    #             self.default_bucket_key, batch_size),
    #         layout=self.layout)]
    #     self.provide_label = [DataDesc(
    #         name=self.label_name, shape=(
    #             self.default_bucket_key, batch_size),
    #         layout=self.layout)]
    # else:
    #     raise ValueError(
    #         "Invalid layout %s: Must by NT (batch major) or TN (time major)")
    # (bucket, offset) pairs for every full batch in every bucket.
    self.idx = []
    for i, buck in enumerate(self.data):
        self.idx.extend([(i, j) for j in range(
            0, len(buck) - batch_size + 1, batch_size)])
    self.curr_idx = 0
    self.reset()
def provide_label(self):
    """Describe the label array: one entry per sample, layout 'N'."""
    shape = (self.total_size(),)
    return [DataDesc(self.label_name, shape, helper.DTYPE, "N")]
def provide_label(self):
    """Label descriptor matching the group-aligned batch size."""
    # Round the batch size down to a whole number of groups of _batch_k.
    real_batch_size = self._batch_size - self._batch_size % self._batch_k
    return [DataDesc('label', (real_batch_size, ), np.int64)]
def __init__(self, symbol, contexts, workload, data_shapes, label_shapes,
             param_names, for_training, inputs_need_grad, shared_group=None,
             logger=logging, fixed_param_names=None, grad_req='write',
             state_names=None):
    """Set up a data-parallel executor group: stores the symbol and
    binding configuration, builds the per-argument gradient-request map,
    then binds executors via ``bind_exec``."""
    self.param_names = param_names
    self.arg_names = symbol.list_arguments()
    self.aux_names = symbol.list_auxiliary_states()
    self.symbol = symbol
    self.contexts = contexts
    self.workload = workload
    self.for_training = for_training
    self.inputs_need_grad = inputs_need_grad
    self.logger = logger
    #In the future we should have a better way to profile memory per device (haibin)
    # self._total_exec_bytes = 0
    self.fixed_param_names = fixed_param_names
    if self.fixed_param_names is None:
        self.fixed_param_names = []
    self.state_names = state_names
    if self.state_names is None:
        self.state_names = []
    # Inference-only groups never need gradients.
    if not for_training:
        grad_req = 'null'
    # data_shapes = [x if isinstance(x, DataDesc) else DataDesc(*x) for x in data_shapes]
    # if label_shapes is not None:
    #     label_shapes = [x if isinstance(x, DataDesc) else DataDesc(*x) for x in label_shapes]
    # data_shapes appears to be a list of per-device descriptor lists;
    # the first entry supplies the input names.
    data_names = [x.name for x in data_shapes[0]]
    if isinstance(grad_req, str):
        # One request string for everything: parameters get it unless
        # frozen; data inputs get it only when input grads are wanted.
        self.grad_req = {}
        for k in self.arg_names:
            if k in self.param_names:
                self.grad_req[k] = 'null' if k in self.fixed_param_names else grad_req
            elif k in data_names:
                self.grad_req[k] = grad_req if self.inputs_need_grad else 'null'
            else:
                self.grad_req[k] = 'null'
    elif isinstance(grad_req, (list, tuple)):
        # Positional form: one request per symbol argument.
        assert len(grad_req) == len(self.arg_names)
        self.grad_req = dict(zip(self.arg_names, grad_req))
    elif isinstance(grad_req, dict):
        # Dict form: start from 'write' defaults, then overlay the
        # caller-supplied overrides.
        self.grad_req = {}
        for k in self.arg_names:
            if k in self.param_names:
                self.grad_req[k] = 'null' if k in self.fixed_param_names else 'write'
            elif k in data_names:
                self.grad_req[k] = 'write' if self.inputs_need_grad else 'null'
            else:
                self.grad_req[k] = 'null'
        self.grad_req.update(grad_req)
    else:
        raise ValueError("grad_req must be one of str, list, tuple, or dict.")
    if shared_group is not None:
        # Reuse memory from an existing group (e.g. bucketing modules).
        self.shared_data_arrays = shared_group.shared_data_arrays
    else:
        self.shared_data_arrays = [{} for _ in contexts]
    # initialize some instance variables
    # NOTE(review): batch_size is set to the number of shape sets, not a
    # value read off a batch axis — confirm against callers.
    self.batch_size = len(data_shapes)
    self.slices = None
    self.execs = []
    self._default_execs = None
    self.data_arrays = None
    self.label_arrays = None
    self.param_arrays = None
    self.state_arrays = None
    self.grad_arrays = None
    self.aux_arrays = None
    self.input_grad_arrays = None
    self.data_shapes = None
    self.label_shapes = None
    self.data_layouts = None
    self.label_layouts = None
    # Batch axis of each output, read from the symbol's layout attribute.
    self.output_layouts = [DataDesc.get_batch_axis(self.symbol[name].attr('__layout__'))
                           for name in self.symbol.list_outputs()]
    self.bind_exec(data_shapes, label_shapes, shared_group)
def provide_data(self):
    """Describe the data array: total_size() samples of the configured
    sample shape, in NCDHW layout."""
    shape = tuple([self.total_size()] + self.info["sample_shape"])
    return [DataDesc(self.data_name, shape, helper.DTYPE, "NCDHW")]
def provide_label(self):
    """The name and shape of label provided by this iterator."""
    descs = []
    for name, arr in self.label:
        # Replace the per-example leading axis with the batch size.
        batched_shape = (self.batch_size,) + tuple(arr.shape[1:])
        descs.append(DataDesc(name, batched_shape, arr.dtype))
    return descs