def next_minibatch(self, num_samples, number_of_workers=1, worker_rank=0, device=None):
    """Return the next minibatch and advance the internal cursor.

    Items are taken consecutively, starting at ``self._cursor``, until adding
    one more would push the largest per-stream sample count past
    ``num_samples`` or push the running total past ``self._max_samples``.
    The first item is always taken, even if it alone exceeds those limits,
    so every call that gets past the initial budget check makes progress.

    Args:
        num_samples: Requested minibatch size in samples. It is compared
            against the largest sample count over all data streams.
        number_of_workers: Step of the strided sub-slice taken from each
            stream's data; 1 means no distributed reading.
        worker_rank: Offset added to the slice bounds for this worker.
        device: Not used by this method.

    Returns:
        A dict mapping each value of ``self.streams`` to a ``MinibatchData``.
        An empty dict is returned once ``self._total_num_samples`` has
        reached ``self._max_samples``.

    Raises:
        ValueError: If a stream's data is a ``Value`` and
            ``number_of_workers`` is not 1.
    """
    if self._total_num_samples >= self._max_samples:
        return {}

    # Determine how many items, starting from the cursor, fit into the
    # requested minibatch size.
    begin = self._cursor
    end = begin
    assert begin < self._num_samples
    actual_num_samples = {name: 0 for name in self._data}
    while end < self._num_samples:
        # Sequence streams contribute their length; other streams count as 1.
        new_num_samples = {
            name: actual_num_samples[name]
            + (MinibatchSourceFromData._get_len(value[end]) if self._is_sequence[name] else 1)
            for name, value in self._data.items()
        }
        max_num_samples = max(new_num_samples.values())
        over_request = max_num_samples > num_samples
        over_budget = self._total_num_samples + max_num_samples > self._max_samples
        # Only stop once at least one item has been taken. Testing the dict
        # of counts for truthiness does not work here: it is non-empty, hence
        # truthy, even while every count is still 0.
        if end > begin and (over_request or over_budget):
            break
        actual_num_samples = new_num_samples
        end += 1
    self._total_num_samples += max(actual_num_samples.values())

    # Build the minibatch data to return.
    result = {}  # stream info -> MinibatchData
    at_end = end == self._num_samples
    sweep_end = at_end or self._total_num_samples >= self._max_samples
    for si in self.streams.values():
        arg = self._data[si.name]
        if isinstance(arg, Value):
            # The entire corpus is one big Value: slice its NDArrayView directly.
            data = arg.data
            sub_shape = data.shape[1:]
            extent = (end - begin,) + sub_shape
            start_offset = (begin,) + tuple(0 for _ in sub_shape)
            if number_of_workers != 1:
                # slice_view presently does not support strides
                raise ValueError('distributed reading from Value objects is not supported')
            mb_data = data.slice_view(start_offset, extent, data.is_read_only)
        else:
            # In case of distributed reading, sub-slice the minibatch.
            # NOTE(review): the stop index is end + worker_rank, so ranks > 0
            # can pick up items at or beyond `end` when the batch length is
            # not a multiple of number_of_workers -- confirm this is intended.
            mb_data = arg[begin + worker_rank:end + worker_rank:number_of_workers]
            if number_of_workers != 1:
                # un-stride it, to avoid performance warning
                mb_data = mb_data.copy()

        if isinstance(mb_data, list):
            # A list needs a CNTK Variable to be turned into a Value; the
            # Variable is created on first use and cached per stream name.
            if si.name not in self._vars:
                from cntk import input_variable
                self._vars[si.name] = input_variable(**self._types[si.name])
            value = Value.create(self._vars[si.name], mb_data)
        else:
            value = Value(mb_data)

        result[si] = MinibatchData(
            value,
            num_sequences=end - begin,
            num_samples=actual_num_samples[si.name],
            sweep_end=sweep_end,
        )

    # Wrap the cursor around at the end of the data.
    self._cursor = 0 if at_end else end
    return result
def test_Value_raises():
    """Expect ValueError when a dense NDArrayView is built and wrapped in a Value.

    The input is a float32 array of shape (1, 1, 2). Both the NDArrayView
    construction and the Value construction run inside the ``pytest.raises``
    scope, so a ValueError from either call satisfies the test.
    """
    from cntk import NDArrayView, Value

    dense = np.asarray([[[4, 5]]], dtype=np.float32)
    with pytest.raises(ValueError):
        Value(NDArrayView.from_dense(dense))