コード例 #1
0
 def next(self):
     """Return the next buffered batch as an ``io.DataBatch`` (pad is always 0).

     Raises:
         StopIteration: when ``self.iter_next()`` reports no further batches.
     """
     if not self.iter_next():
         raise StopIteration
     self.batch_counter += 1
     # Blocking get: waits until a (data, label) pair is available in the buffer
     # (presumably filled by a separate producer — confirm against reset/__init__).
     batch, label = self._data_buffer.get(block=True)
     return io.DataBatch(data=[batch], label=[label], pad=0)
コード例 #2
0
    def next(self):
        """Returns the next batch of data.

        For every sample the image is re-read from disk (``self.path_root``
        joined with the image list entry of the current sequence position),
        passed through ``self.aug_position`` together with its label, converted
        BGR->RGB, validated, augmented and postprocessed into ``batch_data``.

        Samples whose file cannot be read by OpenCV, or that fail
        ``check_valid_image``, are logged at debug level and skipped.

        Returns:
            io.DataBatch: one data array and one label array; ``pad`` is the
            number of unfilled trailing slots (non-zero only for a final
            partial batch).

        Raises:
            StopIteration: if the sample stream is exhausted before any sample
                of this batch was loaded.
        """
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        batch_label = nd.empty(self.provide_label[0][1])
        i = 0
        try:
            while i < batch_size:
                label, s = self.next_sample()
                # NOTE(review): the decoded record is discarded; the image is
                # re-read from disk below. Kept so any error imdecode raises on a
                # corrupt record still surfaces — drop if that is not needed.
                self.imdecode(s)
                filepath = os.path.join(
                    self.path_root, self.imglist[self.seq[self.cur - 1]][1])
                data = cv2.imread(filepath, 1)
                if data is None:
                    # cv2.imread reports missing/unreadable files by returning
                    # None instead of raising; skip like any invalid image.
                    logging.debug('Invalid image, skipping:  %s', filepath)
                    continue
                data, label = self.aug_position(data, label.asnumpy())
                data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
                data = mx.nd.array(data).as_in_context(mx.cpu())
                label = nd.array(label).as_in_context(mx.cpu())
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue
                data = self.augmentation_transform(data)
                assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                batch_data[i] = self.postprocess_data(data)
                batch_label[i] = label
                i += 1
        except StopIteration:
            if not i:
                raise StopIteration

        return io.DataBatch([batch_data], [batch_label], batch_size - i)
コード例 #3
0
 def next(self):
     """Build one batch of decoded images with gender and age labels.

     Each sample's label supplies gender at index 0 and age at index 1. When
     ``self.rand_mirror`` is set, images are flipped along axis 1 with
     probability 1/2. Images whose first dimension is 0 are skipped.

     Raises:
         StopIteration: if the sample stream ends before the batch is full
             (partial batches are dropped, so the returned pad is 0).
     """
     if not self.is_init:
         self.reset()
         self.is_init = True
     self.nbatch += 1
     n = self.batch_size
     c, h, w = self.data_shape
     batch_data = nd.empty((n, c, h, w))
     gender_label = nd.empty(self.provide_label[0][1])
     age_label = nd.empty(self.provide_label[1][1])
     filled = 0
     try:
         while filled < n:
             label, s, _, _ = self.next_sample()
             img = mx.image.imdecode(s)
             if self.rand_mirror and random.randint(0, 1) == 1:
                 img = mx.ndarray.flip(data=img, axis=1)
             if img.shape[0] == 0:
                 logging.debug('Invalid image,skipping')
                 continue
             batch_data[filled][:] = self.postprocess_data(img)
             gender_label[filled][:] = label[0]
             age_label[filled][:] = label[1]
             filled += 1
     except StopIteration:
         if filled < n:
             raise StopIteration
     return io.DataBatch([batch_data], [gender_label, age_label], n - filled)
コード例 #4
0
    def next(self):
        """Returns the next batch of data.

        Loads ``batch_size`` samples, cycling the dataset index over
        ``self.rec_num`` record sources. For each sample it appends a copy whose
        rows from index 56 onward along axis 1 of the CHW image are zeroed. It
        runs all ``2 * batch_size`` images through ``self.mx_model`` in
        inference mode and L2-normalizes each row of output 1.

        Returns:
            io.DataBatch: data is ``[out0_clean, out0_occluded, out1_clean,
            out1_occluded]`` (each the first / second half of the forward
            outputs), label is ``[batch_label]``, and pad is
            ``batch_size - i`` (0, since partial batches raise).

        Raises:
            StopIteration: if the sample stream ends before the batch is full.
        """
        if not self.is_init:
            self.reset()
            self.is_init = True
        self.nbatch += 1
        batch_size = self.batch_size
        c, h, w = self.data_shape
        # Twice the batch: second half holds the occluded copies.
        batch_data = nd.empty((batch_size * 2, c, h, w))
        batch_margin = nd.empty((batch_size * 2,))
        if self.provide_label is not None:
            batch_label = nd.empty(self.provide_label[0][1])
        i = 0
        try:
            dataset_idx = 0
            while i < batch_size:
                batch_margin[i] = self.margin(self.iteration)
                _label, s, bbox, landmark = self.next_sample(dataset_idx)
                if len(self.seq) > 1:
                    # One label slot per record source; -1 presumably marks
                    # "sample not from this source" — confirm with the loss.
                    label = np.ones([self.rec_num, ]) * (-1)
                    label[dataset_idx] = _label
                else:
                    label = _label
                dataset_idx = (dataset_idx + 1) % self.rec_num
                _data = self.imdecode(s)

                data = [_data]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue

                for datum in data:
                    assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                    batch_data[i][:] = self.postprocess_data(datum)
                    batch_label[i][:] = label
                    i += 1
        except StopIteration:
            if i < batch_size:
                raise StopIteration

        # Use a separate index: reusing `i` here made the returned pad 1
        # instead of the number of unfilled slots.
        for j in range(batch_size):
            batch_data[j + batch_size] = batch_data[j].copy()
            batch_data[j + batch_size][:, 56:, :] = 0

        db = mx.io.DataBatch(data=(batch_data,))
        self.mx_model.forward(db, is_train=False)
        net_out = self.mx_model.get_outputs()

        # Row-wise L2 normalization: per-row norms as a (N, 1) column, broadcast
        # over the feature axis (a [1, 1] reshape cannot hold N > 1 norms).
        row_norms = mx.nd.sqrt((net_out[1] * net_out[1]).sum(1)).reshape((-1, 1))
        net_out[1] = mx.nd.broadcast_div(net_out[1], row_norms)

        self.iteration += 1
        return io.DataBatch([net_out[0][:batch_size, :, :, :], net_out[0][batch_size:, :, :, :],
                             net_out[1][:batch_size, :], net_out[1][batch_size:, :]],
                            [batch_label], batch_size - i)
コード例 #5
0
    def forward(self, data_batch, is_train=None):
        """Run the current module; in training, feed its fc1 output to each arcface module.

        In training mode this increments ``self._iter`` and stores the labels
        (moved to ``self._ctx_cpu``) in ``self.global_label``. Each arcface
        module then receives ``fc1`` with labels offset by that module's
        ``self._ctx_class_start`` entry.
        """
        assert self.binded and self.params_initialized
        self._curr_module.forward(data_batch, is_train=is_train)
        if not is_train:
            return
        self._iter += 1
        fc1, label = self._curr_module.get_outputs(merge_multi_context=True)
        self.global_label = label.as_in_context(self._ctx_cpu)
        for module_idx, arc_module in enumerate(self._arcface_modules):
            # Shift labels by this module's class start index.
            shifted_label = self.global_label - self._ctx_class_start[module_idx]
            arc_module.forward(io.DataBatch([fc1], [shifted_label]))  # fc7 with margin
コード例 #6
0
    def next(self):
        """Returns the next batch of data.

        Each sample is decoded, passed through ``self.resize_img``, optionally
        mirrored along axis 1 (p=1/2 when ``rand_mirror``), mean-subtracted and
        scaled by 1/128 when ``nd_mean`` is set, and optionally given a cutout
        region (up to ``cutoff`` wide on each axis) filled with 127.5. Invalid
        images are logged and skipped.

        Raises:
            StopIteration: if the sample stream ends before the batch is full.
        """
        if not self.is_init:
            self.reset()
            self.is_init = True
        self.nbatch += 1
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        if self.provide_label is not None:
            batch_label = nd.empty(self.provide_label[0][1])
        count = 0
        try:
            while count < batch_size:
                label, s, bbox, landmark = self.next_sample()
                img = self.resize_img(self.imdecode(s))
                if self.rand_mirror and random.randint(0, 1) == 1:
                    img = mx.ndarray.flip(data=img, axis=1)
                if self.nd_mean is not None:
                    img = img.astype('float32')
                    img -= self.nd_mean
                    img *= 0.0078125
                if self.cutoff > 0:
                    rows, cols = img.shape[0], img.shape[1]
                    center_r = random.randint(0, rows - 1)
                    center_c = random.randint(0, cols - 1)
                    half = self.cutoff // 2
                    img = img.astype('float32')
                    img[max(0, center_r - half):min(rows, center_r + half),
                        max(0, center_c - half):min(cols, center_c + half), :] = 127.5
                data = [img]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue
                for datum in data:
                    assert count < batch_size, 'Batch size must be multiples of augmenter output length'
                    batch_data[count][:] = self.postprocess_data(datum)
                    batch_label[count][:] = label
                    count += 1
        except StopIteration:
            if count < batch_size:
                raise StopIteration

        return io.DataBatch([batch_data], [batch_label], batch_size - count)
コード例 #7
0
    def next(self):
        """Returns the next batch of data, assembled in the preallocated buffers.

        Each sample is decoded (decode errors are logged and the sample
        skipped) and resized with ``resize_short`` when its height differs from
        ``self.data_shape[1]``. It is converted to a float32 numpy array and
        optionally augmented:
        - mirrored along axis 1 (p=0.5);
        - ``compress_aug`` (p=1/2, only when ``color_jittering > 1``);
        - ``color_aug``;
        - a cutout filled with 128 (p=1/2).
        It is then transposed with axes (2, 0, 1) into ``self.batch_data``.

        Raises:
            StopIteration: if the sample stream ends before the batch is full.
        """
        if not self.is_init:
            self.reset()
            self.is_init = True
        self.nbatch += 1
        batch_size = self.batch_size
        n = 0
        try:
            while n < batch_size:
                label, s, bbox, landmark = self.next_sample()
                try:
                    img = self.imdecode(s)
                except Exception as e:
                    logging.debug('Invalid decoding, skipping:  %s', str(e))
                    continue
                if img.shape[0] != self.data_shape[1]:
                    img = mx.image.resize_short(img, self.data_shape[1])
                img = img.asnumpy().astype(np.float32)
                if self.rand_mirror and np.random.rand() < 0.5:
                    img = img[:, ::-1, :]
                if self.color_jittering > 0:
                    if self.color_jittering > 1 and random.randint(0, 1) == 1:
                        img = self.compress_aug(img)
                    img = self.color_aug(img, 0.125)
                if self.cutoff > 0 and random.randint(0, 1) == 1:
                    rows, cols = img.shape[0], img.shape[1]
                    cy = random.randint(0, rows - 1)
                    cx = random.randint(0, cols - 1)
                    half = self.cutoff // 2
                    img[max(0, cy - half):min(rows, cy + half),
                        max(0, cx - half):min(cols, cx + half), :] = 128
                self.batch_data[n] = img.transpose((2, 0, 1))
                self.batch_label[n] = label
                n += 1
        except StopIteration:
            if n < batch_size:
                raise StopIteration

        return io.DataBatch([nd.array(self.batch_data)],
                            [nd.array(self.batch_label)],
                            batch_size - n)
コード例 #8
0
 def __call__(self, data):
     """Collate train data into batch.

     Element 0 of every sample is stacked with ``nd.stack``. Element 1 is
     handed to ``_pad_arrs_to_max_length`` with ``self._max_gt_box_number``,
     ``pad_axis=0`` and ``pad_val=-1``.
     """
     img_data = nd.stack(*[sample[0] for sample in data])
     box_list = [sample[1] for sample in data]
     gt_bboxes = _pad_arrs_to_max_length(
         box_list, self._max_gt_box_number, pad_axis=0, pad_val=-1)
     return io.DataBatch(data=[img_data], label=[gt_bboxes])
コード例 #9
0
ファイル: data.py プロジェクト: zzmcdc/insightocr
    def next(self):
        """Returns the next batch of data.

        Each sample's image file (``item['image_path']``) is read as bytes and
        decoded with OpenCV: grayscale when ``config.to_gray``, otherwise color
        converted BGR->RGB. It is resized to ``(config.img_height,
        config.img_width)`` when needed and normalized as ``(x - 127.5) / 128``.
        Images OpenCV cannot decode are logged at debug level and skipped.
        When ``config.use_lstm`` is set, ``self.init_state_arrays`` are
        appended to the returned data list.

        Raises:
            StopIteration: if the sample stream ends before the batch is full.
        """
        batch_size = self.batch_size
        batch_data = nd.empty(self.provide_data[0][1])
        batch_label = nd.empty(self.provide_label[0][1])
        i = 0
        try:
            while i < batch_size:
                item = self.next_sample()
                with open(item['image_path'], 'rb') as fin:
                    img = fin.read()
                try:
                    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
                    buf = np.frombuffer(img, np.uint8)
                    if config.to_gray:
                        _data = cv2.imdecode(buf, cv2.IMREAD_GRAYSCALE)
                    else:
                        _data = cv2.imdecode(buf, cv2.IMREAD_COLOR)
                    # cv2.imdecode returns None (rather than raising) on bad
                    # input; route it into the skip path below.
                    if _data is None:
                        raise RuntimeError('cannot decode %s' % item['image_path'])
                    if not config.to_gray:
                        _data = cv2.cvtColor(_data, cv2.COLOR_BGR2RGB)
                    if _data.shape[0] != config.img_height or _data.shape[1] != config.img_width:
                        _data = cv2.resize(_data, (config.img_width, config.img_height))
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue
                _data = mx.nd.array(_data)
                _data = _data.astype('float32')
                _data -= 127.5
                _data *= 0.0078125  # 1/128
                data = [_data]
                label = item['label']
                for datum in data:
                    assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                    batch_data[i][:] = self.postprocess_data(datum)
                    batch_label[i][:] = label
                    i += 1
        except StopIteration:
            if i < batch_size:
                raise StopIteration

        data_all = [batch_data]
        if config.use_lstm:
            data_all += self.init_state_arrays
        return io.DataBatch(data_all, [batch_label], batch_size - i)
コード例 #10
0
    def forward(self, data_batch, is_train=None):
        """Run the backbone; in training, feed its fc1 output to every arcface module.

        In training mode this increments ``self._iter`` and stores the labels
        (moved to ``self._ctx_single_gpu``) in ``self.global_label``. Each
        arcface module then receives ``fc1`` with labels offset by that
        module's ``self._ctx_class_start`` entry.
        """
        assert self.binded and self.params_initialized
        self._backbone_module.forward(data_batch, is_train=is_train)
        if not is_train:
            return
        self._iter += 1
        fc1, label = self._backbone_module.get_outputs(merge_multi_context=True)
        self.global_label = label.as_in_context(self._ctx_single_gpu)
        for module_idx, arc_module in enumerate(self._arcface_modules):
            shifted_label = self.global_label - self._ctx_class_start[module_idx]
            arc_module.forward(io.DataBatch([fc1], [shifted_label]))  # fc7 with margin
コード例 #11
0
    def next(self):
        """Returns the next batch of data.

        Each decoded sample is cast to float32, passed through
        ``image.RandomGrayAug(.2)``, color-jittered with probability 0.2, and
        optionally mirrored along axis 1 (p=1/2 when ``rand_mirror``). It is
        mean-subtracted and scaled by 1/128 when ``nd_mean`` is set. Invalid
        images are logged and skipped.

        Raises:
            StopIteration: if the sample stream ends before the batch is full.
        """
        if not self.is_init:
            self.reset()
            self.is_init = True
        self.nbatch += 1
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        if self.provide_label is not None:
            batch_label = nd.empty(self.provide_label[0][1])
        count = 0
        try:
            while count < batch_size:
                label, s, bbox, landmark = self.next_sample()
                img = self.imdecode(s).astype('float32')
                img = image.RandomGrayAug(.2)(img)
                if random.random() < 0.2:
                    img = image.ColorJitterAug(0.2, 0.2, 0.2)(img)
                if self.rand_mirror and random.randint(0, 1) == 1:
                    img = mx.ndarray.flip(data=img, axis=1)
                if self.nd_mean is not None:
                    img = img.astype('float32')
                    img -= self.nd_mean
                    img *= 0.0078125
                data = [img]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue
                for datum in data:
                    assert count < batch_size, 'Batch size must be multiples of augmenter output length'
                    batch_data[count][:] = self.postprocess_data(datum)
                    batch_label[count][:] = label
                    count += 1
        except StopIteration:
            if count < batch_size:
                raise StopIteration

        return io.DataBatch([batch_data], [batch_label], batch_size - count)
コード例 #12
0
 def next(self):
     """Fetch the next item from ``self.iter`` and wrap it as an ``io.DataBatch``.

     ``item['img']`` is transposed with axes (0, 3, 1, 2) (presumably
     NHWC -> NCHW; confirm with the producer), built as uint8 and cast to
     float32. ``item['label']`` becomes the label, and ``item['index']`` is
     forwarded as the batch index when present.
     """
     item = self.iter.next()
     images = item['img'].transpose(0, 3, 1, 2)
     batch_data = nd.array(images, dtype=np.uint8).astype(np.float32)
     batch_label = nd.array(item['label'])
     batch_index = nd.array(item['index']) if 'index' in item else None
     return io.DataBatch([batch_data], [batch_label], index=batch_index)
コード例 #13
0
 def __call__(self, data):
     """Collate train data into batch.

     Elements 0-5 of every sample are each stacked with ``nd.stack``. Element
     6 is handed to ``_pad_arrs_to_max_length`` with
     ``self._max_gt_box_number``, ``pad_axis=0`` and ``pad_val=-1``.
     The label order is: gt boxes, objectness, center, scale, weights, class.
     """
     def stack_field(idx):
         return nd.stack(*[sample[idx] for sample in data])

     (img_data, center_targets, scale_targets,
      weights, objectness, class_targets) = (stack_field(k) for k in range(6))
     gt_bboxes = _pad_arrs_to_max_length([sample[6] for sample in data],
                                         self._max_gt_box_number,
                                         pad_axis=0, pad_val=-1)
     labels = [gt_bboxes, objectness, center_targets,
               scale_targets, weights, class_targets]
     return io.DataBatch(data=[img_data], label=labels)
コード例 #14
0
ファイル: utils.py プロジェクト: ms-krajesh/squeezeDetMX
 def next(self):
     """Yield the next datum for MXNet to run.

     Reads ``self.batch_size`` image/label pairs. After each pair, when
     ``self.record`` is set, ``self.bytedata`` is refreshed from
     ``self.record.read()``. Labels are converted by ``batch_label_to_mx``.
     """
     batch_images = nd.empty((self.batch_size, *self.img_shape))
     raw_labels = []
     for i in range(self.batch_size):
         batch_images[i][:] = self.image_to_mx(self.read_image())
         raw_labels.append(self.read_label())
         if self.record:
             self.bytedata = self.record.read()
     boxes, classes, scores = self.batch_label_to_mx(raw_labels)
     # Label order intentionally differs from batch_label_to_mx's return order
     # (score before class). The pad expression evaluates to 0 after a full loop.
     return io.DataBatch([batch_images], [boxes, scores, classes],
                         self.batch_size - 1 - i)
コード例 #15
0
    def next(self):
        """Return the next batch of augmented images with age labels (``label[0]``).

        Each decoded image is mirrored along axis 1 with probability 0.5. It
        then goes through random rotation, shear, shift and zoom, each applied
        independently with probability 0.25 in that order. Finally it is
        postprocessed into ``batch_data``.

        Raises:
            StopIteration: if the sample stream ends before the batch is full.
        """
        if not self.is_init:
            self.reset()
            self.is_init = True
        self.nbatch += 1
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        age_label = nd.empty(self.provide_label[0][1])
        axes = dict(row_axis=0, col_axis=1, channel_axis=2)
        # Geometric augmentations, tried in this fixed order.
        geometric_augs = (
            lambda x: self.random_rotation(x=x, rg=20, **axes),
            lambda x: self.random_shear(x=x, intensity=0.2, **axes),
            lambda x: self.random_shift(x=x, wrg=0.2, hrg=0.2, **axes),
            lambda x: self.random_zoom(x=x, zoom_range=[0.8, 1.2], **axes),
        )
        i = 0
        try:
            while i < batch_size:
                label, s, _, _ = self.next_sample()
                # imdecode takes OpenCV-format input; its output is converted BGR => RGB.
                img = mx.image.imdecode(s).asnumpy()
                if np.random.random() > 0.5:
                    img = img[:, ::-1]
                for aug in geometric_augs:
                    if np.random.random() > 0.75:
                        img = aug(img)
                batch_data[i][:] = self.postprocess_data(mx.nd.array(img))
                age_label[i][:] = label[0]
                i += 1
        except StopIteration:
            if i < batch_size:
                raise StopIteration
        return io.DataBatch([batch_data], [age_label], batch_size - i)
コード例 #16
0
def get_input_shape(sym, proto_obj):
    """Helper function to obtain the shape of an array

    Binds ``sym`` on CPU with the input names/shapes listed in
    ``proto_obj.model_metadata['input_tensor_data']``. It loads
    ``proto_obj``'s arg/aux params, runs one inference forward pass on
    all-ones dummy inputs, and returns the shape of the first output.
    """
    tensor_info = proto_obj.model_metadata.get('input_tensor_data')
    input_shapes = [entry[1] for entry in tensor_info]
    input_names = [entry[0] for entry in tensor_info]

    # Dummy all-ones inputs, one per declared input.
    dummy_inputs = [nd.ones(shape=shape) for shape in input_shapes]
    bind_shapes = [(name, arr.shape) for name, arr in zip(input_names, dummy_inputs)]

    mod = module.Module(symbol=sym,
                        data_names=input_names,
                        context=context.cpu(),
                        label_names=None)
    mod.bind(for_training=False, data_shapes=bind_shapes, label_shapes=None)
    mod.set_params(arg_params=proto_obj.arg_dict, aux_params=proto_obj.aux_dict)

    mod.forward(io.DataBatch(list(dummy_inputs)))
    return mod.get_outputs()[0].asnumpy().shape
コード例 #17
0
    def forward(self, data_batch, is_train=None):
        """Run the current module; in training, feed fc1 to every arcface module.

        In training mode this increments ``self._iter`` and stores the labels
        in ``self.global_label``, moved to ``self._ctx_cpu`` and flattened to
        1-D. Each arcface module then receives ``fc1`` with labels offset by
        that module's ``self._ctx_class_start`` entry (flattened again).
        """
        assert self.binded and self.params_initialized
        self._curr_module.forward(data_batch, is_train=is_train)
        if not is_train:
            return
        self._iter += 1
        fc1, label = self._curr_module.get_outputs(merge_multi_context=True)
        self.global_label = label.as_in_context(self._ctx_cpu).reshape([-1])
        for module_idx, arc_module in enumerate(self._arcface_modules):
            shifted_label = (self.global_label
                             - self._ctx_class_start[module_idx]).reshape([-1])
            arc_module.forward(io.DataBatch([fc1], [shifted_label]))  # fc7 with margin
コード例 #18
0
    def next(self):
        """Returns the next batch of data; each sample's label is its ``pos_flag``.

        Samples are decoded, passed through ``self.augs.apply``, validated
        (invalid ones are logged and skipped) and postprocessed.
        ``self.iteration`` is incremented once per returned batch.

        Raises:
            StopIteration: if the sample stream ends before the batch is full.
        """
        if self.need_init:
            self.reset()
            self.need_init = False
        self.nbatch += 1
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        if self.provide_label is not None:
            batch_label = nd.empty(self.provide_label[0][1])

        filled = 0
        try:
            while filled < batch_size:
                _, s, pos_flag, _ = self.next_sample()
                data = [self.augs.apply(self.imdecode(s))]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue
                for datum in data:
                    assert filled < batch_size, 'Batch size must be multiples of augmenter output length'
                    batch_data[filled][:] = self.postprocess_data(datum)
                    batch_label[filled][:] = pos_flag
                    filled += 1
        except StopIteration:
            if filled < batch_size:
                raise StopIteration
        self.iteration += 1
        return io.DataBatch([batch_data], [batch_label], batch_size - filled)
コード例 #19
0
    def next(self):
        """Returns the next batch of data with gender/age label vectors.

        Each label becomes a float32 vector of length 101. Index 0 holds the
        gender and indices 1..age are set to 1, with age 0 treated as 1 and
        ages above 100 clamped to 100. Images are resized with ``resize_short``
        when their height differs from ``data_shape[1]`` and then optionally
        augmented:
        - mirrored (p=1/2 when ``rand_mirror``);
        - ``compress_aug`` (p=1/2 when ``color_jittering > 1``);
        - ``color_aug``;
        - mean-subtracted and scaled by 1/128 when ``nd_mean`` is set;
        - a cutout filled with 128 (p=1/2).

        Raises:
            StopIteration: if the sample stream ends before the batch is full.
        """
        if not self.is_init:
            self.reset()
            self.is_init = True
        self.nbatch += 1
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        if self.provide_label is not None:
            batch_label = nd.empty(self.provide_label[0][1])
        n = 0
        try:
            while n < batch_size:
                raw_label, s, bbox, landmark = self.next_sample()
                gender = int(raw_label[0])
                age = int(raw_label[1])
                assert age >= 0
                age = min(age or 1, 100)
                label = np.zeros(shape=(101, ), dtype=np.float32)
                label[0] = gender
                label[1:age + 1] = 1

                img = self.imdecode(s)
                if img.shape[0] != self.data_shape[1]:
                    img = mx.image.resize_short(img, self.data_shape[1])
                if self.rand_mirror and random.randint(0, 1) == 1:
                    img = mx.ndarray.flip(data=img, axis=1)
                if self.color_jittering > 0:
                    if self.color_jittering > 1 and random.randint(0, 1) == 1:
                        img = self.compress_aug(img)
                    img = img.astype('float32', copy=False)
                    img = self.color_aug(img, 0.125)
                if self.nd_mean is not None:
                    img = img.astype('float32', copy=False)
                    img -= self.nd_mean
                    img *= 0.0078125
                if self.cutoff > 0 and random.randint(0, 1) == 1:
                    rows, cols = img.shape[0], img.shape[1]
                    cy = random.randint(0, rows - 1)
                    cx = random.randint(0, cols - 1)
                    half = self.cutoff // 2
                    img[max(0, cy - half):min(rows, cy + half),
                        max(0, cx - half):min(cols, cx + half), :] = 128
                for datum in [img]:
                    assert n < batch_size, 'Batch size must be multiples of augmenter output length'
                    batch_data[n][:] = self.postprocess_data(datum)
                    batch_label[n][:] = label
                    n += 1
        except StopIteration:
            if n < batch_size:
                raise StopIteration

        return io.DataBatch([batch_data], [batch_label], batch_size - n)
コード例 #20
0
    def next(self):
        if not self.is_init:
            self.reset()
            self.is_init = True
        """Returns the next batch of data."""
        self.nbatch += 1
        c, h, w = self.data_shape
        batch_data_srctar = nd.empty((2 * self.batch_size_src, c, h, w))
        batch_data_t = nd.empty((2 * self.batch_size_src, c, h, w))
        batch_label_srctar = nd.empty(self.provide_label_srctar[0][1])
        batch_label_t = nd.empty(self.provide_label_srctar[0][1])

        batch_data_tar = nd.empty((self.batch_size_src, c, h, w))
        batch_data_adv = nd.empty((self.batch_size_src, c, h, w))
        batch_data = nd.empty((3 * self.batch_size_src, c, h, w))
        batch_label = nd.empty(self.provide_label[0][1])

        arg_t, aux_t = self.model.get_params()
        self.model_adv.set_params(arg_t, aux_t)
        #print("update model_adv params")
        #time_now2 = datetime.datetime.now()
        #print("update params time", time_now2-time_now1)

        i = 0
        try:
            while i < self.batch_size_src:
                label, img = self.next_sample1()

                batch_data_srctar[i][:] = img
                batch_label_srctar[i][:] = label
                i += 1
        except StopIteration:
            if i < self.batch_size_src:
                raise StopIteration
        try:
            while i < 2 * self.batch_size_src:
                label, img = self.next_sample2()
                #print(img)
                #img_show = np.squeeze(img)
                #img_show = img_show.astype(np.uint8)
                #print("img.shape:", img_show.shape)
                #plt.imshow(img_show)
                #plt.show()
                batch_data_srctar[i][:] = img
                batch_label_srctar[i][:] = label
                i += 1
        except StopIteration:
            if i < 2 * self.batch_size_src:
                raise StopIteration
        #print("batch_label_srctar:", batch_label_srctar)
        margin = self.batch_size_src // self.ctx_num
        #print("margin: ",margin)
        for i in xrange(self.ctx_num):
            batch_data_t[2 * i * margin:(2 * i + 1) *
                         margin][:] = batch_data_srctar[i * margin:(i + 1) *
                                                        margin][:]
            batch_data_t[
                (2 * i + 1) * margin:2 * (i + 1) *
                margin][:] = batch_data_srctar[self.batch_size_src +
                                               i * margin:self.batch_size_src +
                                               (i + 1) * margin][:]
        for i in xrange(self.ctx_num):
            batch_label_t[2 * i * margin:(2 * i + 1) *
                          margin][:] = batch_label_srctar[i * margin:(i + 1) *
                                                          margin][:]
            batch_label_t[
                (2 * i + 1) * margin:2 * (i + 1) *
                margin][:] = batch_label_srctar[self.batch_size_src + i *
                                                margin:self.batch_size_src +
                                                (i + 1) * margin][:]

        #print("batch_data_t:", batch_data_t[0][:],batch_data_t)
        batch_data_t_o = batch_data_t
        db = mx.io.DataBatch([batch_data_t])
        self.model_adv.forward(db, is_train=True)
        ori_out = self.model_adv.get_outputs()[-1].asnumpy()
        #print("ori_dis: ", ori_out)
        self.model_adv.backward()
        grad = self.model_adv.get_input_grads()[0]
        #print("grad: ", grad)
        grad = mx.nd.array(grad)
        #print("batch_data_t: ", batch_data_t.asnumpy().shape)

        for i in xrange(self.ctx_num):
            batch_data_t[2 * i * margin:(2 * i + 1) *
                         margin][:] -= 1 / 255 * mx.nd.sign(
                             grad[2 * i * margin:(2 * i + 1) * margin][:])
            batch_data_t[2 * i * margin:(2 * i + 1) *
                         margin][:] = mx.nd.maximum(
                             mx.nd.maximum(
                                 batch_data_t[2 * i * margin:(2 * i + 1) *
                                              margin][:],
                                 batch_data_t_o[2 * i * margin:(2 * i + 1) *
                                                margin][:] - self.sigma),
                             mx.nd.zeros_like(
                                 batch_data_t[2 * i * margin:(2 * i + 1) *
                                              margin][:]))
            batch_data_t[2 * i * margin:(2 * i + 1) *
                         margin][:] = mx.nd.minimum(
                             batch_data_t[2 * i * margin:(2 * i + 1) *
                                          margin][:],
                             mx.nd.minimum(
                                 batch_data_t_o[2 * i * margin:(2 * i + 1) *
                                                margin][:] + self.sigma,
                                 mx.nd.ones_like(
                                     batch_data_t[2 * i * margin:(2 * i + 1) *
                                                  margin][:])))
        #print("first")
        for i in range(0, self.round - 1):
            db = mx.io.DataBatch([batch_data_t])
            self.model_adv.forward(db, is_train=True)
            adv_out = self.model_adv.get_outputs()[-1].asnumpy()
            #print("adv_dis: ", i, adv_out, np.max(adv_out))
            if np.max(adv_out) > self.thd:
                self.model_adv.backward()
                grad = self.model_adv.get_input_grads()[0]
                grad = mx.nd.array(grad)
                for i in xrange(self.ctx_num):
                    batch_data_t[2 * i * margin:(2 * i + 1) *
                                 margin][:] -= 1 / 255 * mx.nd.sign(
                                     grad[2 * i * margin:(2 * i + 1) *
                                          margin][:])
                    batch_data_t[2 * i * margin:(
                        2 * i + 1) * margin][:] = mx.nd.maximum(
                            mx.nd.maximum(
                                batch_data_t[2 * i * margin:(2 * i + 1) *
                                             margin][:],
                                batch_data_t_o[2 * i * margin:(2 * i + 1) *
                                               margin][:] - self.sigma),
                            mx.nd.zeros_like(
                                batch_data_t[2 * i * margin:(2 * i + 1) *
                                             margin][:]))
                    batch_data_t[2 * i * margin:(2 * i + 1) *
                                 margin][:] = mx.nd.minimum(
                                     batch_data_t[2 * i * margin:(2 * i + 1) *
                                                  margin][:],
                                     mx.nd.minimum(
                                         batch_data_t_o[2 * i *
                                                        margin:(2 * i + 1) *
                                                        margin][:] +
                                         self.sigma,
                                         mx.nd.ones_like(
                                             batch_data_t[2 * i *
                                                          margin:(2 * i + 1) *
                                                          margin][:])))
            else:
                #print("adv_dis: ", i)
                break
        db = mx.io.DataBatch([batch_data_t])
        self.model_adv.forward(db, is_train=True)
        adv_out = self.model_adv.get_outputs()[-1].asnumpy()
        #print("adv_dis: ", adv_out)
        '''
        for i in xrange(5):
            imgadv_show = np.squeeze(batch_data_t[i][0][:].asnumpy())
            imgadv_show = imgadv_show.astype(np.uint8)
            print("imgadv_show.type: ", imgadv_show.astype)
            #imgadv_show = np.transpose(imgadv_show, (1, 2, 0))
            plt.imshow(imgadv_show)
            plt.show()
        '''
        for i in xrange(self.ctx_num):
            batch_data_adv[i * margin:(i + 1) *
                           margin][:] = batch_data_t[2 * i *
                                                     margin:(2 * i + 1) *
                                                     margin][:]

        batch_data_src = batch_data_srctar[0:self.batch_size_src][:]
        batch_data_tar = batch_data_srctar[self.batch_size_src:2 *
                                           self.batch_size_src][:]

        #for i in xrange(self.ctx_num):
        #    batch_data_tar[i * margin: (i + 1) * margin][:] = batch_data_t[(2 * i + 1) * margin:2 * (i + 1) * margin][:]

        batch_label_src = batch_label_srctar[0:self.batch_size_src][:]
        batch_label_tar = batch_label_srctar[self.batch_size_src:2 *
                                             self.batch_size_src][:]
        #print("labels: " , batch_label_src , batch_label_tar)

        margin = self.batch_size_src // self.main_ctx_num  # 30
        for i in xrange(self.main_ctx_num):  # 0 1 2 3
            batch_data[margin * 3 * i:margin * 3 * i +
                       margin][:] = batch_data_src[margin * i:margin * i +
                                                   margin][:]
            batch_data[margin * 3 * i + margin:margin * 3 * i +
                       2 * margin][:] = batch_data_tar[margin * i:margin * i +
                                                       margin][:]
            batch_data[margin * 3 * i + 2 * margin:margin * 3 * i +
                       3 * margin][:] = batch_data_adv[margin * i:margin * i +
                                                       margin][:]

        for i in xrange(self.main_ctx_num):
            batch_label[margin * 3 * i:margin * 3 * i +
                        margin][:] = batch_label_src[margin * i:margin * i +
                                                     margin][:]
            batch_label[margin * 3 * i + margin:margin * 3 * i +
                        2 * margin][:] = batch_label_tar[margin *
                                                         i:margin * i +
                                                         margin][:]
            batch_label[margin * 3 * i + 2 * margin:margin * 3 * i +
                        3 * margin][:] = batch_label_src[margin *
                                                         i:margin * i +
                                                         margin][:]

        #print('batch_label: ', batch_label)
        '''
        for i in xrange(2):
            imgadv_show = np.squeeze(batch_data[i][0][:].asnumpy())
            print(imgadv_show)
            #imgadv_show = imgadv_show.astype(np.uint8)
            print(imgadv_show)
            print("imgadv_show.type: ", imgadv_show.astype)
            #imgadv_show = np.transpose(imgadv_show, (1, 2, 0))
            plt.imshow(imgadv_show)
            plt.show()
        '''
        return io.DataBatch([batch_data], [batch_label])
コード例 #21
0
    def next(self):
        """Return the next batch of data.

        On the first call ``self.reset()`` is invoked. Samples are then pulled
        from ``self.next_sample()`` until ``self.batch_size`` images that pass
        ``self.check_valid_image`` have been collected. Each decoded image is:

        * flipped along axis 1 with probability 1/2 when ``self.rand_mirror``;
        * when ``self.cutoff > 0``, cast to float32 and given a square patch of
          side at most ``self.cutoff`` (clipped at the borders, centred on a
          random pixel) filled with 127.5.

        Returns:
            io.DataBatch holding one data array, a one-element label list
            (``None`` when ``self.provide_label`` is ``None``) and pad
            ``batch_size - i``.

        Raises:
            StopIteration: when ``self.next_sample()`` is exhausted before a
                full batch is gathered; partial batches are dropped, so the
                returned pad is always 0.
        """
        if not self.is_init:
            self.reset()
            self.is_init = True
        self.nbatch += 1
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        if self.provide_label is not None:
            batch_label = nd.empty(self.provide_label[0][1])
        i = 0
        try:
            while i < batch_size:
                label, s, bbox, landmark = self.next_sample()
                _data = self.imdecode(s)
                if self.rand_mirror:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        _data = mx.ndarray.flip(data=_data, axis=1)
                if self.cutoff > 0:
                    # Blank a square patch around a random centre pixel; the
                    # slicing below assumes an (H, W, C) layout.
                    centerh = random.randint(0, _data.shape[0] - 1)
                    centerw = random.randint(0, _data.shape[1] - 1)
                    half = self.cutoff // 2
                    starth = max(0, centerh - half)
                    endh = min(_data.shape[0], centerh + half)
                    startw = max(0, centerw - half)
                    endw = min(_data.shape[1], centerw + half)
                    _data = _data.astype('float32')
                    _data[starth:endh, startw:endw, :] = 127.5
                data = [_data]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue
                for datum in data:
                    assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                    batch_data[i][:] = self.postprocess_data(datum)
                    if self.provide_label is not None:
                        batch_label[i][:] = label
                    i += 1
        except StopIteration:
            if i < batch_size:
                raise StopIteration

        _label = None
        if self.provide_label is not None:
            _label = [batch_label]
        return io.DataBatch([batch_data], _label, batch_size - i)
コード例 #22
0
    def next(self):
        """Return the next batch of data.

        On the first call ``self.reset()`` is invoked. Samples are then pulled
        from ``self.next_sample()`` until ``self.batch_size`` images that pass
        ``self.check_valid_image`` have been collected. Per image, in order:

        * ``mx.image.resize_short`` to ``self.data_shape[1]`` when the first
          dimension of the decoded image differs from it;
        * with ``self.rand_mirror``: flip along axis 1 with probability 1/2;
        * with ``self.blur``: with probability 1/2, one randomly chosen imgaug
          blur / noise / compression corruption;
        * with ``self.maxpooling``: with probability 1/2, imgaug ``MaxPooling(2)``;
        * with ``self.color_jittering > 0``: ``self.compress_aug`` with
          probability 1/2 (only when ``color_jittering > 1``), then cast to
          float32 and ``self.color_aug(_data, 0.125)``;
        * when ``self.nd_mean`` is not None: subtract it and scale by
          0.0078125 (1/128);
        * with ``self.cutoff > 0``: with probability 1/2, fill a square patch of
          side at most ``self.cutoff`` (clipped at the borders, centred on a
          random pixel) with 128.

        Returns:
            io.DataBatch holding one data array, a one-element label list
            (``None`` when ``self.provide_label`` is ``None``) and pad
            ``batch_size - i``.

        Raises:
            StopIteration: when ``self.next_sample()`` is exhausted before a
                full batch is gathered (partial batches are dropped).
        """
        if not self.is_init:
            self.reset()
            self.is_init = True
        self.nbatch += 1
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        batch_label = None
        if self.provide_label is not None:
            batch_label = nd.empty(self.provide_label[0][1])

        # The augmenter pipelines do not depend on the individual sample, so
        # build them once per batch instead of once per image.
        aug_blur = None
        if self.blur:
            aug_blur = iaa.Sequential([
                iaa.OneOf([
                    iaa.GaussianBlur(sigma=(0.5, 2.5)),
                    iaa.AverageBlur(k=(2, 5)),
                    iaa.MotionBlur(k=(5, 7)),
                    iaa.BilateralBlur(d=(3, 4),
                                      sigma_color=(10, 250),
                                      sigma_space=(10, 250)),
                    iaa.imgcorruptlike.DefocusBlur(severity=1),
                    iaa.imgcorruptlike.GlassBlur(severity=1),
                    iaa.imgcorruptlike.Pixelate(severity=(1, 3)),
                    iaa.Pepper(0.01),
                    iaa.AdditiveGaussianNoise(scale=(0, 0.1 * 255),
                                              per_channel=True),
                    iaa.imgcorruptlike.SpeckleNoise(severity=1),
                    iaa.imgcorruptlike.JpegCompression(severity=(1, 4)),
                ])
            ])
        maxpool_aug = iaa.MaxPooling(2) if self.maxpooling else None

        i = 0
        try:
            while i < batch_size:
                label, s, bbox, landmark = self.next_sample()
                _data = self.imdecode(s)
                if _data.shape[0] != self.data_shape[1]:
                    _data = mx.image.resize_short(_data, self.data_shape[1])
                if self.rand_mirror:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        _data = mx.ndarray.flip(data=_data, axis=1)
                if aug_blur is not None:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        # NOTE(review): imgaug's `images=` keyword expects a
                        # batch of images; a single image is normally passed
                        # via `image=`. `_data` is whatever self.imdecode
                        # returned (treated as an MXNet NDArray elsewhere in
                        # this method) -- confirm this augments as intended.
                        _data = aug_blur(images=_data)
                if maxpool_aug is not None:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        _data = maxpool_aug(images=_data)
                if self.color_jittering > 0:
                    if self.color_jittering > 1:
                        _rd = random.randint(0, 1)
                        if _rd == 1:
                            _data = self.compress_aug(_data)
                    _data = _data.astype('float32', copy=False)
                    _data = self.color_aug(_data, 0.125)
                if self.nd_mean is not None:
                    _data = _data.astype('float32', copy=False)
                    _data -= self.nd_mean
                    _data *= 0.0078125
                if self.cutoff > 0:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        # Blank a square patch around a random centre pixel;
                        # the slicing assumes an (H, W, C) layout.
                        centerh = random.randint(0, _data.shape[0] - 1)
                        centerw = random.randint(0, _data.shape[1] - 1)
                        half = self.cutoff // 2
                        starth = max(0, centerh - half)
                        endh = min(_data.shape[0], centerh + half)
                        startw = max(0, centerw - half)
                        endw = min(_data.shape[1], centerw + half)
                        _data[starth:endh, startw:endw, :] = 128
                data = [_data]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue
                for datum in data:
                    assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                    batch_data[i][:] = self.postprocess_data(datum)
                    # Labels are only stored when the iterator provides them.
                    if batch_label is not None:
                        batch_label[i][:] = label
                    i += 1
        except StopIteration:
            if i < batch_size:
                raise StopIteration

        _label = [batch_label] if batch_label is not None else None
        return io.DataBatch([batch_data], _label, batch_size - i)
コード例 #23
0
    def next(self):
        """Return the next image batch together with per-image feature targets.

        On the first call ``self.reset()`` is invoked. Samples are then pulled
        from ``self.next_sample()`` until ``self.batch_size`` images that pass
        ``self.check_valid_image`` have been collected. For every image, the
        row ``self.seq[self.cur - 1] - 1`` of ``self.train_features`` is copied
        into ``features_data``. If the image is mirrored, the same row of
        ``self.train_features_flip`` is copied instead. Images then receive
        the optional color jitter, mean subtraction / scaling by 0.0078125
        and 50%-probability cutoff patch (filled with 128).

        Returns:
            tuple ``(io.DataBatch([batch_data]), features_data)`` where
            ``features_data`` is a NumPy array of shape
            ``(batch_size, self.train_features.shape[1])``.

        Raises:
            StopIteration: when ``self.next_sample()`` is exhausted before a
                full batch is gathered (partial batches are dropped).
        """
        if not self.is_init:
            self.reset()
            self.is_init = True
        self.nbatch += 1
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        # Feature width follows the feature table rather than a fixed 512.
        features_data = np.zeros((batch_size, self.train_features.shape[1]))
        batch_label = None
        if self.provide_label is not None:
            batch_label = nd.empty(self.provide_label[0][1])
        i = 0
        try:
            while i < batch_size:
                label, s, bbox, landmark = self.next_sample()
                _data = self.imdecode(s)
                # Feature-table row of the sample just drawn. This relies on
                # next_sample() having advanced self.cur past it; the "- 1"
                # presumably maps 1-based record ids to 0-based rows.
                feat_idx = self.seq[self.cur - 1] - 1
                features_data[i, :] = self.train_features[feat_idx, :]
                if _data.shape[0] != self.data_shape[1]:
                    _data = mx.image.resize_short(_data, self.data_shape[1])
                if self.rand_mirror:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        _data = mx.ndarray.flip(data=_data, axis=1)
                        # Mirrored images use the precomputed mirrored features.
                        features_data[i, :] = self.train_features_flip[
                            feat_idx, :]
                if self.color_jittering > 0:
                    if self.color_jittering > 1:
                        _rd = random.randint(0, 1)
                        if _rd == 1:
                            _data = self.compress_aug(_data)
                    _data = _data.astype('float32', copy=False)
                    _data = self.color_aug(_data, 0.125)
                if self.nd_mean is not None:
                    _data = _data.astype('float32', copy=False)
                    _data -= self.nd_mean
                    _data *= 0.0078125
                if self.cutoff > 0:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        # Blank a square patch around a random centre pixel;
                        # the slicing assumes an (H, W, C) layout.
                        centerh = random.randint(0, _data.shape[0] - 1)
                        centerw = random.randint(0, _data.shape[1] - 1)
                        half = self.cutoff // 2
                        starth = max(0, centerh - half)
                        endh = min(_data.shape[0], centerh + half)
                        startw = max(0, centerw - half)
                        endw = min(_data.shape[1], centerw + half)
                        _data[starth:endh, startw:endw, :] = 128
                data = [_data]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    # features_data[i] is simply overwritten by the next sample.
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue
                for datum in data:
                    assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                    batch_data[i][:] = self.postprocess_data(datum)
                    if batch_label is not None:
                        batch_label[i][:] = label
                    i += 1
        except StopIteration:
            if i < batch_size:
                raise StopIteration

        # NOTE(review): batch_label is filled but not attached to the returned
        # DataBatch -- confirm callers really only need images + features.
        return io.DataBatch([batch_data]), features_data
コード例 #24
0
    def next(self):
        """Return the next batch of data, with optional low-resolution augmentation.

        On the first call ``self.reset()`` is invoked. Samples are then pulled
        from ``self.next_sample()`` until ``self.batch_size`` images that pass
        ``self.check_valid_image`` have been collected. Per image, in order:

        * ``mx.image.resize_short`` to ``self.data_shape[1]`` when the first
          dimension of the decoded image differs from it;
        * with ``self.rand_resize``: force-resize to a random square side. The
          side is drawn from {20, 26, 32} with probability 3/7, otherwise from
          {24, 34, 44, 54}. The image is then force-resized back to the
          ``(w, h)`` of ``self.data_shape``, which simulates low-resolution
          inputs;
        * with ``self.rand_mirror``: flip along axis 1 with probability 1/2;
        * with ``self.color_jittering > 0``: optional ``self.compress_aug``
          (probability 1/2 when ``color_jittering > 1``), then float32 cast and
          ``self.color_aug(_data, 0.125)``;
        * when ``self.nd_mean`` is not None: subtract it and scale by 0.0078125;
        * with ``self.cutoff > 0``: with probability 1/2, fill a square patch of
          side at most ``self.cutoff`` with 128.

        A SamplePairing-style batch mix (not standard mixup) is implemented
        below but disabled by the local ``SP_mixup`` flag.

        Returns:
            io.DataBatch holding one data array, a one-element label list
            (``None`` when ``self.provide_label`` is ``None``) and pad
            ``batch_size - i``.

        Raises:
            StopIteration: when ``self.next_sample()`` is exhausted before a
                full batch is gathered (partial batches are dropped).
        """
        if not self.is_init:
            self.reset()
            self.is_init = True
        self.nbatch += 1
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))  # (batch_size, c, h, w)
        batch_label = None
        mix_lab = None
        if self.provide_label is not None:
            batch_label = nd.empty(self.provide_label[0][1])  # one row per image
            mix_lab = nd.empty(self.provide_label[0][1])

        # Stretches the downscaled image back to the network input size.
        # ForceResizeAug takes its size as (width, height).
        input_resize = None
        if self.rand_resize:
            input_resize = mx.image.ForceResizeAug(size=(w, h))

        i = 0
        try:
            while i < batch_size:
                label, s, bbox, landmark = self.next_sample()
                _data = self.imdecode(s)
                if _data.shape[0] != self.data_shape[1]:
                    _data = mx.image.resize_short(_data, self.data_shape[1])
                if input_resize is not None:
                    _r = random.randint(0, 6)  # inclusive bounds: 0..6
                    if _r < 3:
                        side = 20 + 6 * _r  # 20, 26, 32
                    else:
                        # Offset by 3 so _r = 3..6 maps to 24, 34, 44, 54
                        # (previously 54..84, contradicting the stated intent).
                        side = 24 + 10 * (_r - 3)
                    resolution_resize = mx.image.ForceResizeAug(
                        size=(side, side))
                    _data = input_resize(resolution_resize(_data))
                if self.rand_mirror:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        _data = mx.ndarray.flip(data=_data, axis=1)
                if self.color_jittering > 0:
                    if self.color_jittering > 1:
                        _rd = random.randint(0, 1)
                        if _rd == 1:
                            _data = self.compress_aug(_data)
                    _data = _data.astype('float32', copy=False)
                    _data = self.color_aug(_data, 0.125)
                if self.nd_mean is not None:
                    _data = _data.astype('float32', copy=False)
                    _data -= self.nd_mean
                    _data *= 0.0078125
                if self.cutoff > 0:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        # Blank a square patch around a random centre pixel;
                        # the slicing assumes an (H, W, C) layout.
                        centerh = random.randint(0, _data.shape[0] - 1)
                        centerw = random.randint(0, _data.shape[1] - 1)
                        half = self.cutoff // 2
                        starth = max(0, centerh - half)
                        endh = min(_data.shape[0], centerh + half)
                        startw = max(0, centerw - half)
                        endw = min(_data.shape[1], centerw + half)
                        _data[starth:endh, startw:endw, :] = 128
                data = [_data]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue
                for datum in data:
                    assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                    # postprocess_data presumably converts HWC -> CHW.
                    batch_data[i][:] = self.postprocess_data(datum)
                    if batch_label is not None:
                        batch_label[i][:] = label
                    i += 1
        except StopIteration:
            if i < batch_size:
                raise StopIteration

        # SamplePairing-style mixing of the batch with a shuffled copy of itself.
        alpha = 0.4  # Beta(alpha, alpha) parameter (raised from 0.2 to 0.4)
        SP_mixup = False

        if SP_mixup and mix_lab is not None and random.randint(0, 1):
            if alpha > 0.:
                # A small alpha concentrates the weights near 0 or 1, so most
                # mixed images are dominated by one of the two faces.
                lam = np.random.beta(alpha, alpha, [batch_size, 1])
            else:
                lam = np.ones((batch_size, 1))
            lam_nd = nd.array(np.tile(lam[..., None, None], [1, c, h, w]))

            # A random permutation of the batch acts as the second batch.
            index = np.random.permutation(batch_size)
            mix_img = batch_data * lam_nd + batch_data[index] * (1 - lam_nd)
            # Truncate to integer intensities, then back to float32.
            mix_img = (mix_img.astype('int', copy=False)).astype('float32',
                                                                 copy=False)
            shuffled_label = batch_label[index]
            for ind in range(batch_size):
                # The label follows whichever image dominates the mix; an exact
                # 0.5 tie picks one of the two labels at random.
                if lam[ind] < 0.5:
                    mix_lab[ind][:] = shuffled_label[ind]
                elif lam[ind] > 0.5:
                    mix_lab[ind][:] = batch_label[ind]
                elif random.randint(0, 1) == 0:
                    mix_lab[ind][:] = batch_label[ind]
                else:
                    mix_lab[ind][:] = shuffled_label[ind]
            return io.DataBatch([mix_img], [mix_lab], batch_size - i)

        _label = [batch_label] if batch_label is not None else None
        return io.DataBatch([batch_data], _label, batch_size - i)
    def next(self):
        """Return a batch that interleaves two sample streams per device.

        On the first call ``self.reset()`` is invoked. Rows
        ``0 .. batch_size1 - 1`` are filled from ``self.next_sample()``.
        Rows ``batch_size1 .. batch_size1 + batchsize_id - 1`` are filled from
        ``self.next_sample2()``. The rows are then regrouped so that, for each
        of the ``self.ctx_num`` devices, the output holds ``margin`` rows from
        the first stream followed by ``margin`` rows from the second, where
        ``margin = batch_size1 // ctx_num``.

        NOTE(review): the regrouping reads ``ctx_num * margin`` rows starting
        at ``batch_size1``, so it assumes ``batchsize_id >= ctx_num * margin``.
        Output rows past ``2 * ctx_num * margin`` are never written and keep
        the uninitialised contents of ``nd.empty``. Confirm both sizes are
        configured equal and divisible by ``ctx_num``.

        Returns:
            io.DataBatch with the regrouped data and a one-element label list
            (``None`` when ``self.provide_label`` is ``None``).

        Raises:
            StopIteration: when either sample source is exhausted before its
                share of the batch is gathered.
        """
        if not self.is_init:
            self.reset()
            self.is_init = True
        self.nbatch += 1
        batch_size1 = self.batch_size1
        interclass_size = self.batchsize_id
        total = batch_size1 + interclass_size
        c, h, w = self.data_shape
        batch_data = nd.empty((total, c, h, w))
        batch_data_t = nd.empty((total, c, h, w))
        batch_label = None
        batch_label_t = None
        if self.provide_label is not None:
            batch_label = nd.empty(self.provide_label[0][1])
            batch_label_t = nd.empty(self.provide_label[0][1])

        self._fill_from_stream(self.next_sample, batch_data, batch_label,
                               0, batch_size1)
        self._fill_from_stream(self.next_sample2, batch_data, batch_label,
                               batch_size1, total)

        # Per device k: [margin rows of stream 1][margin rows of stream 2].
        margin = batch_size1 // self.ctx_num
        for k in range(self.ctx_num):
            src_a = slice(k * margin, (k + 1) * margin)
            src_b = slice(batch_size1 + k * margin,
                          batch_size1 + (k + 1) * margin)
            dst_a = slice(2 * k * margin, (2 * k + 1) * margin)
            dst_b = slice((2 * k + 1) * margin, 2 * (k + 1) * margin)
            batch_data_t[dst_a][:] = batch_data[src_a][:]
            batch_data_t[dst_b][:] = batch_data[src_b][:]
            if batch_label_t is not None:
                batch_label_t[dst_a][:] = batch_label[src_a][:]
                batch_label_t[dst_b][:] = batch_label[src_b][:]

        _label = [batch_label_t] if batch_label_t is not None else None
        return io.DataBatch([batch_data_t], _label)

    def _fill_from_stream(self, sample_fn, batch_data, batch_label, start,
                          stop):
        """Fill rows ``start .. stop - 1`` of the batch from ``sample_fn``.

        ``sample_fn`` is called repeatedly and must return
        ``(label, s, bbox, landmark)``. Each decoded image is:

        * cast to float32;
        * passed through ``image.RandomGrayAug(.2)``;
        * color-jittered with probability 0.2;
        * flipped along axis 1 with probability 1/2 when ``self.rand_mirror``;
        * skipped, with a debug log, if ``self.check_valid_image`` rejects it.

        ``batch_label`` may be ``None``, in which case labels are not stored.

        Raises:
            StopIteration: if ``sample_fn`` is exhausted before row
                ``stop - 1`` is filled.
        """
        # RandomGrayAug holds only its probability and a fixed matrix, so one
        # instance can serve every image.
        gray_aug = image.RandomGrayAug(.2)
        i = start
        try:
            while i < stop:
                label, s, bbox, landmark = sample_fn()
                _data = self.imdecode(s)
                _data = _data.astype('float32')
                _data = gray_aug(_data)
                if random.random() < 0.2:
                    _data = image.ColorJitterAug(0.2, 0.2, 0.2)(_data)
                if self.rand_mirror:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        _data = mx.ndarray.flip(data=_data, axis=1)
                data = [_data]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue
                for datum in data:
                    assert i < stop, 'Batch size must be multiples of augmenter output length'
                    batch_data[i][:] = self.postprocess_data(datum)
                    if batch_label is not None:
                        batch_label[i][:] = label
                    i += 1
        except StopIteration:
            if i < stop:
                raise StopIteration
コード例 #26
0
    def next(self):
        """Build the next mixed source / target / adversarial batch.

        Steps performed by this method:
          1. Fill ``batch_data_srctar``: the first ``batch_size_src`` rows come
             from ``next_sample1()`` (the "src" half, per the variable names)
             and the next ``batch_size_src`` rows from ``next_sample2()`` (the
             "tar" half). Each image may be mirrored (``rand_mirror``),
             mean-normalised (``nd_mean``) and cut out (``cutoff``).
          2. Re-order both halves into ``batch_data_t`` so that, for each of
             ``ctx_num`` contexts, a ``margin``-sized run of src images is
             followed by a ``margin``-sized run of tar images.
          3. Copy the current parameters of ``self.model`` into
             ``self.model_adv``. Then perturb ONLY the src runs with
             ``x -= sigma * sign(dOut/dx)``, clamped to [0, 255]. One step is
             always taken. Up to ``self.round - 1`` further steps are taken
             while ``np.max`` of model_adv's last output exceeds ``self.thd``.
          4. Emit ``3 * batch_size_src`` images laid out, per main context
             (``main_ctx_num``), as [src, tar, perturbed src]. The labels are
             [src, tar, src]: perturbed images keep their source labels.

        Returns:
            io.DataBatch with a single data array and a single label array.

        Raises:
            StopIteration: if either sample stream ends before its half of the
                batch is full. Partial batches are never returned.
        """
        if not self.is_init:
            self.reset()
            self.is_init = True
        """Returns the next batch of data."""
        self.nbatch += 1
        c, h, w = self.data_shape
        # Rows [0, batch_size_src) hold src images, rows
        # [batch_size_src, 2 * batch_size_src) hold tar images.
        batch_data_srctar = nd.empty((2 * self.batch_size_src, c, h, w))
        # The same images re-ordered per context; this is what model_adv sees.
        batch_data_t = nd.empty((2 * self.batch_size_src, c, h, w))
        batch_label_srctar = nd.empty(self.provide_label_srctar[0][1])
        batch_label_t = nd.empty(self.provide_label_srctar[0][1])

        # NOTE(review): batch_data_tar is rebound further down to a slice of
        # batch_data_srctar, so this allocation is never used.
        batch_data_tar = nd.empty((self.batch_size_src, c, h, w))
        batch_data_adv = nd.empty((self.batch_size_src, c, h, w))
        # Final output: src + tar + perturbed src.
        batch_data = nd.empty((3 * self.batch_size_src, c, h, w))
        batch_label = nd.empty(self.provide_label[0][1])

        #time_now1 = datetime.datetime.now()
        # Sync the attack model with the current weights of the main model.
        arg_t, aux_t = self.model.get_params()
        self.model_adv.set_params(arg_t, aux_t)
        #print("update model_adv params")
        #time_now2 = datetime.datetime.now()
        #print("update params time", time_now2-time_now1)

        # First half of batch_data_srctar: samples from next_sample1().
        i = 0
        try:
            while i < self.batch_size_src:
                label, s, bbox, landmark = self.next_sample1()
                _data = self.imdecode(s)
                if self.rand_mirror:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        _data = mx.ndarray.flip(data=_data, axis=1)
                if self.nd_mean is not None:
                    _data = _data.astype('float32')
                    _data -= self.nd_mean
                    _data *= 0.0078125
                if self.cutoff > 0:
                    # Cutout: fill a square of side <= cutoff around a random
                    # centre with 127.5.
                    centerh = random.randint(0, _data.shape[0] - 1)
                    centerw = random.randint(0, _data.shape[1] - 1)
                    half = self.cutoff // 2
                    starth = max(0, centerh - half)
                    endh = min(_data.shape[0], centerh + half)
                    startw = max(0, centerw - half)
                    endw = min(_data.shape[1], centerw + half)
                    _data = _data.astype('float32')
                    _data[starth:endh, startw:endw, :] = 127.5
                data = [_data]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue

                for datum in data:
                    assert i < self.batch_size_src, 'Batch size must be multiples of augmenter output length'
                    batch_data_srctar[i][:] = self.postprocess_data(datum)
                    batch_label_srctar[i][:] = label
                    i += 1
        except StopIteration:
            # The exception can only be raised while i < batch_size_src, so a
            # short stream always aborts the whole batch.
            if i < self.batch_size_src:
                raise StopIteration
        # Second half of batch_data_srctar: samples from next_sample2().
        try:
            while i < 2 * self.batch_size_src:
                label, s, bbox, landmark = self.next_sample2()
                _data = self.imdecode(s)
                if self.rand_mirror:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        _data = mx.ndarray.flip(data=_data, axis=1)
                if self.nd_mean is not None:
                    _data = _data.astype('float32')
                    _data -= self.nd_mean
                    _data *= 0.0078125
                if self.cutoff > 0:
                    centerh = random.randint(0, _data.shape[0] - 1)
                    centerw = random.randint(0, _data.shape[1] - 1)
                    half = self.cutoff // 2
                    starth = max(0, centerh - half)
                    endh = min(_data.shape[0], centerh + half)
                    startw = max(0, centerw - half)
                    endw = min(_data.shape[1], centerw + half)
                    _data = _data.astype('float32')
                    _data[starth:endh, startw:endw, :] = 127.5
                data = [_data]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue

                for datum in data:
                    assert i < 2 * self.batch_size_src, 'Batch size must be multiples of augmenter output length'
                    batch_data_srctar[i][:] = self.postprocess_data(datum)
                    batch_label_srctar[i][:] = label
                    i += 1
        except StopIteration:
            if i < 2 * self.batch_size_src:
                raise StopIteration

        #print("batch_label_srctar:", batch_label_srctar)
        # Per-context shard size for the attack model.
        margin = self.batch_size_src // self.ctx_num
        #print("margin: ",margin)
        # Interleave: context k gets [src run k, tar run k]. Presumably this
        # matches how model_adv splits its input across ctx_num devices -
        # TODO confirm.
        for i in xrange(self.ctx_num):
            batch_data_t[2 * i * margin:(2 * i + 1) *
                         margin][:] = batch_data_srctar[i * margin:(i + 1) *
                                                        margin][:]
            batch_data_t[
                (2 * i + 1) * margin:2 * (i + 1) *
                margin][:] = batch_data_srctar[self.batch_size_src +
                                               i * margin:self.batch_size_src +
                                               (i + 1) * margin][:]
        for i in xrange(self.ctx_num):
            batch_label_t[2 * i * margin:(2 * i + 1) *
                          margin][:] = batch_label_srctar[i * margin:(i + 1) *
                                                          margin][:]
            batch_label_t[
                (2 * i + 1) * margin:2 * (i + 1) *
                margin][:] = batch_label_srctar[self.batch_size_src + i *
                                                margin:self.batch_size_src +
                                                (i + 1) * margin][:]

        #print("batch_label_t:", batch_label_t)

        # Initial forward/backward to get the gradient w.r.t. the input.
        # NOTE(review): get_input_grads() presumably requires model_adv to be
        # bound with inputs_need_grad=True - confirm at bind time.
        # ori_out is not used later in this method.
        db = mx.io.DataBatch([batch_data_t])
        self.model_adv.forward(db, is_train=True)
        ori_out = self.model_adv.get_outputs()[-1].asnumpy()
        #print("ori_dis: ", ori_out)
        self.model_adv.backward()
        grad = self.model_adv.get_input_grads()[0]
        #print("grad: ", grad)
        grad = mx.nd.array(grad)
        #print("batch_data_t: ", batch_data_t.asnumpy().shape)

        # First sign-gradient step, applied to the src runs only (tar runs are
        # untouched): x -= sigma * sign(grad), then clamp to [0, 255].
        # NOTE(review): if nd_mean is set, images were mean-subtracted and
        # scaled by 0.0078125 above, so a [0, 255] clamp may not match the
        # data range - confirm.
        for i in xrange(self.ctx_num):
            batch_data_t[2 * i * margin:(2 * i + 1) *
                         margin][:] -= self.sigma * mx.nd.sign(
                             grad[2 * i * margin:(2 * i + 1) * margin][:])
            batch_data_t[2 * i * margin:(2 * i + 1) *
                         margin][:] = mx.nd.maximum(
                             batch_data_t[2 * i * margin:(2 * i + 1) *
                                          margin][:],
                             mx.nd.zeros_like(
                                 batch_data_t[2 * i * margin:(2 * i + 1) *
                                              margin][:]))
            batch_data_t[2 * i * margin:(2 * i + 1) *
                         margin][:] = mx.nd.minimum(
                             batch_data_t[2 * i * margin:(2 * i + 1) *
                                          margin][:],
                             255 * mx.nd.ones_like(
                                 batch_data_t[2 * i * margin:(2 * i + 1) *
                                              margin][:]))
        #print("first")
        # Up to round - 1 further steps, stopping early once the largest value
        # of model_adv's last output is <= thd.
        # NOTE(review): the inner `for i in xrange(...)` reuses the outer loop
        # variable. The outer loop is unaffected because `for` rebinds `i` from
        # range() on every iteration, but the shadowing is confusing.
        for i in range(0, self.round - 1):
            db = mx.io.DataBatch([batch_data_t])
            self.model_adv.forward(db, is_train=True)
            adv_out = self.model_adv.get_outputs()[-1].asnumpy()
            #print("adv_dis: ", i, adv_out, np.max(adv_out))
            if np.max(adv_out) > self.thd:
                self.model_adv.backward()
                grad = self.model_adv.get_input_grads()[0]
                grad = mx.nd.array(grad)
                for i in xrange(self.ctx_num):
                    batch_data_t[2 * i * margin:(2 * i + 1) *
                                 margin][:] -= self.sigma * mx.nd.sign(
                                     grad[2 * i * margin:(2 * i + 1) *
                                          margin][:])
                    batch_data_t[2 * i * margin:(2 * i + 1) *
                                 margin][:] = mx.nd.maximum(
                                     batch_data_t[2 * i * margin:(2 * i + 1) *
                                                  margin][:],
                                     mx.nd.zeros_like(
                                         batch_data_t[2 * i *
                                                      margin:(2 * i + 1) *
                                                      margin][:]))
                    batch_data_t[2 * i * margin:(2 * i + 1) *
                                 margin][:] = mx.nd.minimum(
                                     batch_data_t[2 * i * margin:(2 * i + 1) *
                                                  margin][:],
                                     255 * mx.nd.ones_like(
                                         batch_data_t[2 * i *
                                                      margin:(2 * i + 1) *
                                                      margin][:]))
            else:
                #print("adv_dis: ", i)
                break
        # NOTE(review): the adv_out from this final forward is not used later
        # in this method (its print is commented out).
        db = mx.io.DataBatch([batch_data_t])
        self.model_adv.forward(db, is_train=True)
        adv_out = self.model_adv.get_outputs()[-1].asnumpy()
        #print("adv_dis: ", adv_out)

        #imgadv_show = np.squeeze(batch_data_t[0][:].asnumpy())
        #imgadv_show = imgadv_show.astype(np.uint8)
        # print("imgadv_show.type: ", imgadv_show.astype)
        #imgadv_show = np.transpose(imgadv_show, (1, 2, 0))
        #plt.imshow(imgadv_show)
        #plt.show()

        # Gather the perturbed src runs back into one contiguous array.
        for i in xrange(self.ctx_num):
            batch_data_adv[i * margin:(i + 1) *
                           margin][:] = batch_data_t[2 * i *
                                                     margin:(2 * i + 1) *
                                                     margin][:]

        # Unperturbed halves, taken as slices of batch_data_srctar.
        batch_data_src = batch_data_srctar[0:self.batch_size_src][:]
        batch_data_tar = batch_data_srctar[self.batch_size_src:2 *
                                           self.batch_size_src][:]

        #for i in xrange(self.ctx_num):
        #    batch_data_tar[i * margin: (i + 1) * margin][:] = batch_data_t[(2 * i + 1) * margin:2 * (i + 1) * margin][:]

        batch_label_src = batch_label_srctar[0:self.batch_size_src][:]
        batch_label_tar = batch_label_srctar[self.batch_size_src:2 *
                                             self.batch_size_src][:]
        #print("labels: " , batch_label_src , batch_label_tar)

        # Shard size is recomputed for the main model's contexts.
        margin = self.batch_size_src // self.main_ctx_num  # author's example value: 30
        for i in xrange(self.main_ctx_num):  # author's example: contexts 0 1 2 3
            # Main context i receives [src run, tar run, perturbed src run].
            batch_data[margin * 3 * i:margin * 3 * i +
                       margin][:] = batch_data_src[margin * i:margin * i +
                                                   margin][:]
            batch_data[margin * 3 * i + margin:margin * 3 * i +
                       2 * margin][:] = batch_data_tar[margin * i:margin * i +
                                                       margin][:]
            batch_data[margin * 3 * i + 2 * margin:margin * 3 * i +
                       3 * margin][:] = batch_data_adv[margin * i:margin * i +
                                                       margin][:]

        # Perturbed images keep the labels of the src images they came from.
        for i in xrange(self.main_ctx_num):
            batch_label[margin * 3 * i:margin * 3 * i +
                        margin][:] = batch_label_src[margin * i:margin * i +
                                                     margin][:]
            batch_label[margin * 3 * i + margin:margin * 3 * i +
                        2 * margin][:] = batch_label_tar[margin *
                                                         i:margin * i +
                                                         margin][:]
            batch_label[margin * 3 * i + 2 * margin:margin * 3 * i +
                        3 * margin][:] = batch_label_src[margin *
                                                         i:margin * i +
                                                         margin][:]

        #print("batch labels: ", batch_label)
        return io.DataBatch([batch_data], [batch_label])
コード例 #27
0
    def next(self):
        """Return the next training batch as an ``io.DataBatch``.

        The iterator is reset lazily on first use. Samples are then pulled
        from ``next_sample()`` until ``batch_size`` images have been decoded,
        optionally mirrored (``rand_mirror``) and mean-normalised
        (``nd_mean``), and written into the batch. Images rejected by
        ``check_valid_image`` are skipped.

        Labels are written only when ``provide_label`` is not None:
          * ``coco_mode``: the label is ``(slot % per_batch_size) //
            images_per_identity``.
          * 1-D label buffer: the raw sample label is stored.
          * 2-D label buffer: column 0 is copied verbatim. Every later column
            value ``c2c`` is mapped to ``cos(m) ** 2``, where
            ``m = 0.4 + slope * (c2c - 0.75)`` clamped to [0.4, 0.5].

        If ``data_extra`` is set, it is appended as a second data array.

        Raises:
            StopIteration: when the sample stream ends before the batch is
                full. Partial batches are never returned.
        """
        if not self.is_init:
            self.reset()
            self.is_init = True
        self.nbatch += 1

        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        with_label = self.provide_label is not None
        if with_label:
            batch_label = nd.empty(self.provide_label[0][1])

        # Piecewise-linear margin schedule: (m_hi, m_lo, c2c_hi, c2c_lo).
        m_hi, m_lo, c2c_hi, c2c_lo = 0.5, 0.4, 0.85, 0.75
        slope = (m_lo - m_hi) / (c2c_lo - c2c_hi)

        def margin_target(c2c):
            # Line through (c2c_lo, m_lo) with the slope above, clamped to
            # [m_lo, m_hi]; the target is cos(m) squared.
            margin = min(m_hi, max(m_lo, m_lo + slope * (c2c - c2c_lo)))
            cos_m = math.cos(margin)
            return cos_m * cos_m

        filled = 0
        try:
            while filled < batch_size:
                label, s, bbox, landmark = self.next_sample()
                img = self.imdecode(s)
                if self.rand_mirror and random.randint(0, 1) == 1:
                    img = mx.ndarray.flip(data=img, axis=1)
                if self.nd_mean is not None:
                    img = img.astype('float32')
                    img -= self.nd_mean
                    img *= 0.0078125

                candidates = [img]
                try:
                    self.check_valid_image(candidates)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue

                for datum in candidates:
                    assert filled < batch_size, 'Batch size must be multiples of augmenter output length'
                    batch_data[filled][:] = self.postprocess_data(datum)
                    if with_label:
                        if self.coco_mode:
                            batch_label[filled][:] = (
                                filled % self.per_batch_size
                            ) // self.images_per_identity
                        elif len(batch_label.shape) == 1:
                            batch_label[filled][:] = label
                        else:
                            for col in xrange(batch_label.shape[1]):
                                value = label[col]
                                if col > 0:
                                    value = margin_target(value)
                                batch_label[filled][col] = value
                    filled += 1
        except StopIteration:
            if filled < batch_size:
                raise StopIteration

        labels = [batch_label] if with_label else None
        pad = batch_size - filled
        if self.data_extra is not None:
            return io.DataBatch([batch_data, self.data_extra], labels, pad)
        return io.DataBatch([batch_data], labels, pad)
コード例 #28
0
    def next(self):
        """Return the next training batch as an ``io.DataBatch``.

        The iterator is reset lazily on first use. Samples are pulled from
        ``next_sample()`` until ``batch_size`` valid images are collected.
        Each image is processed in this order:

        * if its height differs from the target height, the short side is
          resized to the target height;
        * it is mirrored horizontally with probability 1/6 (``rand_mirror``);
        * if ``color_jittering > 1``, ``compress_aug`` is applied with
          probability 1/2; if ``color_jittering > 0``, ``color_aug`` is
          applied with probability 1/6;
        * ``nd_mean`` is subtracted and the result scaled by 1/128 (when
          ``nd_mean`` is set);
        * with probability 1/11, 1-5 pixels are trimmed from every border and
          the image is resized back to exactly (w, h) (``cutoff > 0``);
        * with probability 1/11, a 5-10 px rectangle is painted a random
          colour (``shelter``).

        Images rejected by ``check_valid_image`` are skipped.

        Returns:
            io.DataBatch with one data array. The label list is ``None`` when
            ``provide_label`` is None.

        Raises:
            StopIteration: if the sample stream ends before the batch is full.
                Partial batches are never returned.
        """
        if not self.is_init:
            self.reset()
            self.is_init = True
        self.nbatch += 1
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        # Allocate and fill labels only when the iterator exposes them. The
        # label writes used to be unconditional, which raised
        # UnboundLocalError for unlabeled iterators.
        batch_label = None
        if self.provide_label is not None:
            batch_label = nd.empty(self.provide_label[0][1])
        i = 0
        try:
            while i < batch_size:
                label, s, bbox, landmark = self.next_sample()
                _data = self.imdecode(s)
                if _data.shape[0] != self.data_shape[1]:
                    _data = mx.image.resize_short(_data, self.data_shape[1])
                # The probabilities below were hand-tuned by the original
                # author ("change sai").
                if self.rand_mirror:
                    _rd = random.randint(0, 5)
                    if _rd == 1:
                        _data = mx.ndarray.flip(data=_data, axis=1)
                if self.color_jittering > 0:
                    if self.color_jittering > 1:
                        _rd = random.randint(0, 1)
                        if _rd == 1:
                            _data = self.compress_aug(_data)
                    _rd = random.randint(0, 5)
                    if _rd == 1:
                        _data = _data.astype('float32', copy=False)
                        _data = self.color_aug(_data, 0.125)
                if self.nd_mean is not None:
                    _data = _data.astype('float32', copy=False)
                    _data -= self.nd_mean
                    _data *= 0.0078125
                if self.cutoff > 0:
                    _rd = random.randint(0, 10)
                    if _rd == 1:
                        # Trim `rate` pixels off every border. _data is HWC,
                        # so rows (y) are bounded by the height and columns
                        # (x) by the width. The old code used data_shape[1]
                        # (height) for x and data_shape[2] (width) for y,
                        # which only worked for square inputs.
                        rate = random.randint(1, 5)
                        ymax = _data.shape[0] - rate
                        xmax = _data.shape[1] - rate
                        _data = _data[rate:ymax, rate:xmax, :]
                        # Resize back to exactly (w, h). resize_short only
                        # restored the target size for square inputs;
                        # interp=2 matches resize_short's default.
                        _data = mx.image.imresize(_data, w, h, interp=2)
                if self.shelter:
                    _rd = random.randint(0, 10)
                    if _rd == 1:
                        # Occlude a small random rectangle with a random colour.
                        # NOTE(review): the 15-100 offsets assume inputs of
                        # roughly 112 px or larger - confirm for other sizes.
                        xmin = random.randint(15, 100)
                        ymin = random.randint(15, 100)
                        xmax = xmin + random.randint(5, 10)
                        ymax = ymin + random.randint(5, 10)
                        _data = _data.astype('float32')
                        _data[ymin:ymax,
                              xmin:xmax, :] = (random.randint(0, 255),
                                               random.randint(0, 255),
                                               random.randint(0, 255))
                data = [_data]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue
                for datum in data:
                    assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                    batch_data[i][:] = self.postprocess_data(datum)
                    if batch_label is not None:
                        batch_label[i][:] = label
                    i += 1
        except StopIteration:
            if i < batch_size:
                raise StopIteration

        _label = [batch_label] if batch_label is not None else None
        return io.DataBatch([batch_data], _label, batch_size - i)
コード例 #29
0
    def next(self):
        """Build and return the next batch as an ``io.DataBatch``.

        The iterator is reset on first use. Samples are pulled from
        ``next_sample()`` until ``batch_size`` valid images are collected.
        Each image is:
          * decoded;
          * short-side resized if it is not square;
          * randomly mirrored;
          * compress/colour-jittered;
          * mean-normalised;
          * randomly cut out;
        before ``post_process_data`` writes it into the batch.

        Raises:
            StopIteration: if the sample stream ends before the batch is full.
        """
        if not self.is_init:
            self.reset()
            self.is_init = True

        self.n_batch += 1

        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        batch_label = nd.empty(self.provide_label[0][1])

        count = 0
        try:
            while count < batch_size:
                label, image_str = self.next_sample()
                # Decode the raw record into an HWC NDArray (e.g. 112x112x3).
                # Call .asnumpy() on it to view it with plt.imshow().
                img = mx.image.imdecode(image_str)

                # Non-square images get their short side resized to the
                # target height.
                if img.shape[0] != img.shape[1]:
                    img = mx.image.resize_short(img, self.data_shape[1])

                # Random horizontal mirror with probability 1/2.
                if self.rand_mirror and np.random.randint(0, 2) == 1:
                    img = mx.ndarray.flip(data=img, axis=1)

                if self.color_jitter > 0:
                    # Stronger jitter settings also apply compression
                    # augmentation half of the time.
                    if self.color_jitter > 1 and np.random.randint(0, 2) == 1:
                        img = self.compress_aug(img)
                    # Colour augmentation operates on float32 pixels.
                    img = self.color_jitter_aug(img.astype("float32", copy=False))

                if self.nd_mean is not None:
                    img = img.astype('float32', copy=False)
                    img -= self.nd_mean
                    img *= 0.0078125

                # Random cutout (probability 1/2): fill a square of side up to
                # `cutoff` around a random centre with 128.
                if self.cutoff > 0 and np.random.randint(0, 2) == 1:
                    rows, cols = img.shape[0], img.shape[1]
                    cy = np.random.randint(0, rows)
                    cx = np.random.randint(0, cols)
                    half = self.cutoff // 2
                    img[max(0, cy - half):min(rows, cy + half),
                        max(0, cx - half):min(cols, cx + half), :] = 128

                candidates = [img]
                try:
                    # Validate the image data before it enters the batch.
                    self.check_valid_image(candidates)
                except RuntimeError as e:
                    print("Invalid image, skipping: {}".format(e))
                    continue

                for candidate in candidates:
                    assert count < batch_size, 'Batch size must be multiples of augmenter output length'
                    # post_process_data converts HWC to CHW.
                    batch_data[count][:] = self.post_process_data(candidate)
                    batch_label[count][:] = label
                    count += 1
        except StopIteration:
            if count < batch_size:
                raise StopIteration

        return io.DataBatch([batch_data], [batch_label], batch_size - count)
コード例 #30
0
    def next(self):
        """Return the next training batch as an ``io.DataBatch``.

        The iterator is reset lazily on first use. Samples are pulled from
        ``next_sample()`` until ``batch_size`` valid images are collected.
        Each image is processed as follows:
          * random ``(w, h)`` crop;
          * horizontal mirror with probability 1/2 (``rand_mirror``);
          * when ``color_jittering > 0``: ``compress_aug`` with probability
            1/2 (only if ``color_jittering > 1``), then ``color_aug``;
          * mean subtraction and 1/128 scaling (when ``nd_mean`` is set);
          * with probability 1/2, cutout that fills a square of side up to
            ``cutoff`` around a random centre with 128.

        Only the first element of each sample's label is stored.

        Returns:
            io.DataBatch with one data array. The label list is ``None`` when
            ``provide_label`` is None.

        Raises:
            StopIteration: if the sample stream ends before the batch is full.
                Partial batches are never returned.
        """
        if not self.is_init:
            self.reset()
            self.is_init = True
        self.nbatch += 1
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        # Allocate and fill labels only when the iterator exposes them. The
        # label writes used to be unconditional, which raised
        # UnboundLocalError for unlabeled iterators.
        batch_label = None
        if self.provide_label is not None:
            batch_label = nd.empty(self.provide_label[0][1])
        i = 0
        try:
            while i < batch_size:
                label, s, bbox, landmark = self.next_sample()
                _data = self.imdecode(s)
                # A random crop to (w, h) replaces the former resize_short step
                # (author note: "add by yan, 1127").
                _data, _ = mx.image.random_crop(_data, (w, h))
                if self.rand_mirror:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        _data = mx.ndarray.flip(data=_data, axis=1)
                if self.color_jittering > 0:
                    if self.color_jittering > 1:
                        _rd = random.randint(0, 1)
                        if _rd == 1:
                            _data = self.compress_aug(_data)
                    _data = _data.astype('float32', copy=False)
                    _data = self.color_aug(_data, 0.125)
                if self.nd_mean is not None:
                    _data = _data.astype('float32', copy=False)
                    _data -= self.nd_mean
                    _data *= 0.0078125
                if self.cutoff > 0:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        centerh = random.randint(0, _data.shape[0] - 1)
                        centerw = random.randint(0, _data.shape[1] - 1)
                        half = self.cutoff // 2
                        starth = max(0, centerh - half)
                        endh = min(_data.shape[0], centerh + half)
                        startw = max(0, centerw - half)
                        endw = min(_data.shape[1], centerw + half)
                        _data[starth:endh, startw:endw, :] = 128
                data = [_data]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue
                for datum in data:
                    assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                    batch_data[i][:] = self.postprocess_data(datum)
                    if batch_label is not None:
                        batch_label[i][:] = label[0]
                    i += 1
        except StopIteration:
            if i < batch_size:
                raise StopIteration

        _label = [batch_label] if batch_label is not None else None
        return io.DataBatch([batch_data], _label, batch_size - i)