Example 1
0
    def next_minibatch(self, num_samples, number_of_workers=1, worker_rank=0, device=None):
        """Return the next minibatch as a ``{stream_info: MinibatchData}`` dict.

        Fills the minibatch with whole sequences starting at ``self._cursor``
        until adding one more sequence would exceed ``num_samples`` (counted
        on the longest stream) or the global ``self._max_samples`` budget;
        at least one sequence is always included, even if it alone exceeds
        ``num_samples``.  Returns ``{}`` once ``self._max_samples`` samples
        have been served in total.  When ``number_of_workers > 1`` each
        worker receives an interleaved sub-slice of the minibatch (not
        supported when a stream is stored as one big ``Value``; raises
        ``ValueError`` in that case).  ``device`` is accepted for interface
        compatibility but is not used here.
        """
        if self._total_num_samples >= self._max_samples:
            return {}
        # determine how many samples, starting from self._cursor, will fit into the requested minibatch size of num_samples
        begin = self._cursor
        end = self._cursor
        assert begin < self._num_samples
        # running per-stream sample counts for the window [begin, end)
        actual_num_samples = { name: 0 for name in self._data.keys() }
        while end < self._num_samples:
            # per-stream count if we also took the sequence at index 'end':
            # sequence streams contribute their length, non-sequence streams contribute 1
            new_num_samples = { name: actual_num_samples[name] + (MinibatchSourceFromData._get_len(value[end]) if self._is_sequence[name] else 1)
                                for name, value in self._data.items() }
            # return up to requested number of samples. but at least one even if longer
            # also stop if we hit the maximum requested number of samples
            max_num_samples = max(new_num_samples.values())
            # BUGFIX: this guard previously tested 'actual_num_samples', a dict that is
            # always non-empty (hence always truthy) for non-empty data, so a first
            # sequence longer than num_samples yielded an empty minibatch and left the
            # cursor stuck. Testing 'end > begin' guarantees at least one sequence is
            # taken, as the comment above intends.
            if end > begin and (max_num_samples > num_samples or self._total_num_samples + max_num_samples > self._max_samples):
                break
            actual_num_samples = new_num_samples
            end += 1

        self._total_num_samples += max(actual_num_samples.values())

        # the minibatch data to return
        result = {}  # [stream_info] -> MinibatchData
        at_end = (end == self._num_samples)
        for si in self.streams.values():
            arg = self._data[si.name]
            if isinstance(arg, Value):  # if entire corpus is one big Value, then slice NDArrayView directly
                data = arg.data
                sub_shape = data.shape[1:]
                extent = (end - begin,) + sub_shape
                start_offset = (begin,) + tuple(0 for _ in sub_shape)
                if number_of_workers != 1: # slice_view presently does not support strides
                    raise ValueError('distributed reading from Value objects is not supported')
                mb_data = data.slice_view(start_offset, extent, data.is_read_only)
            else:
                # in case of distributed reading, we sub-slice the minibatch:
                # each worker takes every number_of_workers-th sequence, offset by its rank
                mb_data = arg[begin+worker_rank:end+worker_rank:number_of_workers]
                if number_of_workers != 1:
                    mb_data = mb_data.copy() # un-stride it, to avoid performance warning
            if isinstance(mb_data, list): # create a Value object
                if si.name not in self._vars: # this case is more complex, we need a CNTK Variable
                    # lazily create an input Variable matching the declared type of this stream
                    # (the unused 'device' name was dropped from this import; it shadowed the parameter)
                    from cntk import input_variable
                    self._vars[si.name] = input_variable(**self._types[si.name])
                value = Value.create(self._vars[si.name], mb_data)
            else:
                value = Value(mb_data)
            result[si] = MinibatchData(value, num_sequences=end - begin, num_samples=actual_num_samples[si.name],
                                       sweep_end=at_end or (self._total_num_samples >= self._max_samples))

        # wrap around the cursor
        self._cursor = 0 if at_end else end

        return result
Example 2
0
def test_Value_raises():
    """Value() must raise ValueError when given a dense 3-axis NDArrayView."""
    from cntk import NDArrayView, Value
    # Build the NDArrayView OUTSIDE the raises-block so an unexpected
    # ValueError from from_dense() cannot make the test pass spuriously:
    # only the Value() constructor itself is allowed to raise here.
    nd = NDArrayView.from_dense(np.asarray([[[4, 5]]], dtype=np.float32))
    with pytest.raises(ValueError):
        val = Value(nd)