def get_batch(self):
    """Build the current mini-batch and store it on ``self.data`` / ``self.label``.

    Selects the roidb entries for the window starting at ``self.cur``,
    partitions them into per-device slices, runs ``get_fpn_maskrcnn_batch``
    on each slice, and merges the per-slice results with either
    ``_make_data_and_labels`` or ``_make_data_and_labels_afp`` depending on
    ``config.USE_AFP``. Nothing is returned; the results are written to
    ``self.data`` and ``self.label`` as lists of ``mx.nd.array`` objects
    ordered by ``self.data_name`` and ``self.label_name``.

    Raises:
        AssertionError: if the work-load list is not a list with one entry
            per context in ``self.ctx``.
    """
    # Window of roidb entries for this batch, clamped to the dataset size.
    first = self.cur
    last = min(first + self.batch_size, self.size)
    roidb = [self.roidb[self.index[pos]] for pos in range(first, last)]

    # Work out how the batch is divided between devices; an unset work-load
    # list means every context gets equal weight.
    work_load_list = self.work_load_list
    ctx = self.ctx
    if work_load_list is None:
        work_load_list = [1] * len(ctx)
    assert isinstance(work_load_list, list) and len(work_load_list) == len(ctx), \
        "Invalid settings for work load. "
    # NOTE(review): slices are derived from self.batch_size, not len(roidb).
    # If the clamped window above is ever shorter than batch_size, the
    # indexing below would presumably raise IndexError -- confirm that the
    # iterator never yields a partial batch.
    slices = _split_input_slice(self.batch_size, work_load_list)

    im_array_list = []
    levels_data_list = []
    for device_slice in slices:
        device_roidb = [
            roidb[pos] for pos in range(device_slice.start, device_slice.stop)
        ]
        im_array, levels_data = get_fpn_maskrcnn_batch(device_roidb)
        im_array_list.append(im_array)
        levels_data_list.append(levels_data)

    # The AFP configuration uses its own routine to assemble data and labels.
    if config.USE_AFP:
        build = self._make_data_and_labels_afp
    else:
        build = self._make_data_and_labels
    all_data, all_label = build(im_array_list, levels_data_list)

    self.data = [mx.nd.array(all_data[key]) for key in self.data_name]
    self.label = [mx.nd.array(all_label[key]) for key in self.label_name]
def get_batch(self):
    """Build the current mini-batch and store it on ``self.data`` / ``self.label``.

    Selects the roidb entries for the window starting at ``self.cur``,
    partitions them into per-device slices, runs ``get_fpn_maskrcnn_batch``
    on each slice, and merges the per-slice results through
    ``_make_data_and_labels``. Nothing is returned; the results are written
    to ``self.data`` and ``self.label`` as lists of ``mx.nd.array`` objects
    ordered by ``self.data_name`` and ``self.label_name``.

    Raises:
        AssertionError: if the work-load list is not a list with one entry
            per context in ``self.ctx``.
    """
    # Window of roidb entries for this batch, clamped to the dataset size.
    first = self.cur
    last = min(first + self.batch_size, self.size)
    roidb = [self.roidb[self.index[pos]] for pos in range(first, last)]

    # Work out how the batch is divided between devices; an unset work-load
    # list means every context gets equal weight.
    work_load_list = self.work_load_list
    ctx = self.ctx
    if work_load_list is None:
        work_load_list = [1] * len(ctx)
    assert isinstance(work_load_list, list) and len(work_load_list) == len(ctx), \
        "Invalid settings for work load. "
    # NOTE(review): slices are derived from self.batch_size, not len(roidb).
    # If the clamped window above is ever shorter than batch_size, the
    # indexing below would presumably raise IndexError -- confirm that the
    # iterator never yields a partial batch.
    slices = _split_input_slice(self.batch_size, work_load_list)

    im_array_list = []
    levels_data_list = []
    for device_slice in slices:
        device_roidb = [
            roidb[pos] for pos in range(device_slice.start, device_slice.stop)
        ]
        im_array, levels_data = get_fpn_maskrcnn_batch(device_roidb)
        im_array_list.append(im_array)
        levels_data_list.append(levels_data)

    merged = self._make_data_and_labels(im_array_list, levels_data_list)
    all_data, all_label = merged

    self.data = [mx.nd.array(all_data[key]) for key in self.data_name]
    self.label = [mx.nd.array(all_label[key]) for key in self.label_name]