Example #1
def get_iterators(data_root, data_shape, batch_size, TorV):
    """Build a detection iterator; TorV selects 'train' or 'valid'."""
    print('Loading Data....')
    if TorV == 'train':
        batch_iter = image.ImageDetIter(
            batch_size=batch_size,
            data_shape=(3, data_shape[0], data_shape[1]),
            path_imgrec=os.path.join(data_root, TorV+'.rec'),
            path_imgidx=os.path.join(data_root, TorV+'.idx'),
            shuffle=True,
            brightness=0.5, 
            contrast=0.2, 
            saturation=0.5, 
            hue=1.0,
            )
    elif TorV == 'valid':
        # NOTE: the validation branch also reads the training records
        batch_iter = image.ImageDetIter(
            batch_size=1,
            data_shape=(3, data_shape[0], data_shape[1]),
            path_imgrec=os.path.join(data_root, 'train.rec'),
            path_imgidx=os.path.join(data_root, 'train.idx'),
            shuffle=True,
            #rand_crop=0.2,
            #rand_pad=0.2,
            #area_range=(0.8, 1.2),
            brightness=0.2, 
            #contrast=0.2, 
            saturation=0.5, 
            #hue=1.0,
            )
    else:
        batch_iter = None
    return batch_iter
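
The object returned by get_iterators is a standard MXNet DataIter: batches are pulled with next() or a for loop, and the iterator must be reset between epochs. A minimal consumption sketch (the 'data' directory and 300x300 shape are assumptions for illustration):

import os
from mxnet import image

train_iter = get_iterators('data', (300, 300), batch_size=32, TorV='train')
train_iter.reset()                 # rewind before every epoch
for batch in train_iter:
    x = batch.data[0]              # images, shape (32, 3, 300, 300)
    y = batch.label[0]             # labels, typically (32, max_objects, 5)
    # forward/backward pass goes here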
Example #2
def get_iterators(data_shape, batch_size, data_dir):
    """  获取训练集和验证集迭代器  """
    # lable名和label数
    class_names = ['pikachu']
    num_classes = len(class_names)

    # training set iterator
    train_iter = image.ImageDetIter(
        batch_size=batch_size,
        data_shape=(3, data_shape, data_shape),
        path_imgrec=data_dir+'train.rec',
        path_imgidx=data_dir+'train.idx',
        shuffle=True,
        mean=True,
        rand_crop=1,
        min_object_covered=0.95,
        max_attempts=200
    )

    # validation set iterator
    val_iter = image.ImageDetIter(
        batch_size=batch_size,
        data_shape=(3, data_shape, data_shape),
        path_imgrec=data_dir+'val.rec',
        shuffle=False,
        mean=True,
    )
    return train_iter, val_iter, class_names, num_classes
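
With mean=True, ImageDetIter subtracts the default ImageNet channel means while loading. Each label row is a class id followed by corner coordinates normalized to [0, 1]; a quick inspection sketch (the data directory is an assumption):

train_iter, val_iter, class_names, num_classes = get_iterators(
    256, 32, '../data/pikachu/')
batch = train_iter.next()
labels = batch.label[0]            # shape (32, max_objects, 5)
cls_id, xmin, ymin, xmax, ymax = labels[0, 0].asnumpy()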
Example #3
def load_data_pikachu(batch_size, edge_size=256):  # edge_size: width/height of the output images
    train_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, 'train.rec'),
        path_imgidx=os.path.join(data_dir, 'train.idx'),
        batch_size=batch_size,
        data_shape=(3, edge_size, edge_size),  # shape of the output images
#         shuffle=False,  # read the dataset in random order
#         rand_crop=1,  # probability of random cropping is 1
        min_object_covered=0.95, max_attempts=200)
    val_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, 'val.rec'), batch_size=batch_size,
        data_shape=(3, edge_size, edge_size), shuffle=False)
    return train_iter, val_iter
Example #4
def load_data_uav(data_dir = '../data/uav', batch_size=4, edge_size=256):
    # _download_pikachu(data_dir)
    train_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, 'train.rec'),
        path_imgidx=os.path.join(data_dir, 'train.idx'),
        batch_size=batch_size,
        data_shape=(3, edge_size, edge_size),  # shape of the output images
        shuffle=True,  # read the dataset in random order
        rand_crop=1,  # probability of random cropping is 1
        min_object_covered=0.95, max_attempts=200)
    val_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, 'val.rec'), batch_size=batch_size,
        data_shape=(3, edge_size, edge_size), shuffle=False)
    return train_iter, val_iter
Example #5
def load_data_pikaku(batch_size, edg_size=256):
    data_dir = './data/pikaku'
    download_pikaku(data_dir)
    train_iter = image.ImageDetIter(path_imgrec=os.path.join(data_dir,'train.rec'),
                                    path_imgidx=os.path.join(data_dir,'train.idx'),
                                    batch_size=batch_size,
                                    data_shape=(3, edg_size, edg_size),
                                    shuffle=True,
                                    rand_crop=1,
                                    min_object_covered=0.95,
                                    max_attempts=200)
    val_iter = image.ImageDetIter(path_imgrec=os.path.join(data_dir,'val.rec'),batch_size=batch_size,
                                  data_shape=(3,edg_size,edg_size),shuffle=False)
    return train_iter, val_iter
Example #6
def load_data_pikachu(batch_size, edge_size=256):
    data_dir = '../data/pikachu'
    _download_pikachu(data_dir)
    train_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, 'train.rec'),
        path_imgidx=os.path.join(data_dir, 'train.idx'),
        batch_size=batch_size,
        data_shape=(3, edge_size, edge_size),  # The shape of the output image
        shuffle=True,  # Read the data set in random order
        rand_crop=1,  # The probability of random cropping is 1
        min_object_covered=0.95, max_attempts=200)
    val_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, 'val.rec'), batch_size=batch_size,
        data_shape=(3, edge_size, edge_size), shuffle=False)
    return train_iter, val_iter
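
_download_pikachu is called in several examples but never shown. In the d2l book it fetches the RecordIO files with gluon.utils.download; a minimal sketch along those lines, with the SHA-1 checksums omitted (the mirror URL is an assumption and may have moved):

import os
from mxnet import gluon

def _download_pikachu(data_dir):
    # fetch the train/val RecordIO files if they are not already present
    root_url = ('https://apache-mxnet.s3-accelerate.amazonaws.com/'
                'gluon/dataset/pikachu/')  # assumed d2l dataset mirror
    os.makedirs(data_dir, exist_ok=True)
    for fname in ('train.rec', 'train.idx', 'val.rec'):
        gluon.utils.download(root_url + fname,
                             path=os.path.join(data_dir, fname))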
Example #7
def load_data_pikachu(batch_size, edge_size=256):  # edge_size: width/height of the output images
    data_dir = '~/.mxnet/datasets/pikachu'
    _maybe_download_pikachu(data_dir)
    train_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, 'train.rec'),
        path_imgidx=os.path.join(data_dir, 'train.idx'),
        batch_size=batch_size,
        data_shape=(3, edge_size, edge_size),  # shape of the output images
        shuffle=True,  # read the dataset in random order
        rand_crop=1,  # probability of random cropping is 1
        min_object_covered=0.95, max_attempts=200)
    val_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, 'val.rec'), batch_size=batch_size,
        data_shape=(3, edge_size, edge_size), shuffle=False)
    return train_iter, val_iter
Example #8
    def __init__(self):
        self.net = JanetRes(classes=self.classes, use_bn=True)
        self.net.initialize(ctx=self.ctx)

        self.trainIter = image.ImageDetIter(
            batch_size=self.BATCH_SIZE,
            data_shape=(3, 300, 300),
            path_imgrec='../DataX/annoTrainX.rec',
            path_imgidx='../DataX/annoTrainX.idx',
            path_imglist='../DataX/annoTrainX.lst',
            path_root='../DataX/',
            shuffle=True,
            mean=True,
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
            pca_noise=0.3,
            hue=0.3)

        with autograd.train_mode():
            _, _, anchors = self.net(
                mx.ndarray.zeros(shape=(self.BATCH_SIZE, 3, 300, 300),
                                 ctx=self.ctx))
        self.T = TargetGenV1(anchors=anchors.as_in_context(mx.cpu()),
                             height=300,
                             width=300)

        self.net.collect_params().reset_ctx(self.ctx)
        self.trainer = gluon.Trainer(self.net.collect_params(), 'sgd', {
            'learning_rate': 0.1,
            'wd': 5e-4
        })
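
The gluon.Trainer built at the end is driven by the usual record/backward/step pattern. A self-contained illustration with a stand-in one-layer net (the real model here is JanetRes, whose loss terms are not shown):

from mxnet import autograd, gluon, nd

net = gluon.nn.Dense(1)            # stand-in for the detection network
net.initialize()
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.1, 'wd': 5e-4})
loss_fn = gluon.loss.L2Loss()

x, y = nd.random.normal(shape=(4, 8)), nd.zeros((4, 1))
with autograd.record():
    loss = loss_fn(net(x), y)
loss.backward()
trainer.step(4)                    # gradients are normalized by batch size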
Example #9
def get_iterators(data_shape, batch_size):
    # data_dir, class_names and num_class are assumed module-level globals
    train_iter = image.ImageDetIter(batch_size=batch_size,
                                    data_shape=(3, data_shape, data_shape),
                                    path_imgrec=data_dir + 'train.rec',
                                    path_imgidx=data_dir + 'train.idx',
                                    shuffle=True,
                                    mean=True,
                                    rand_crop=1,
                                    min_object_covered=0.95,
                                    max_attempts=200)
    val_iter = image.ImageDetIter(batch_size=batch_size,
                                  data_shape=(3, data_shape, data_shape),
                                  path_imgrec=data_dir + 'val.rec',
                                  shuffle=False,
                                  mean=True)
    return train_iter, val_iter, class_names, num_class
Example #10
def load_data_pikachu(batch_size, edge_size=256):
    data_dir = '../data/pikachu/'
    _download_pikachu(data_dir)
    train_data = image.ImageDetIter(path_imgrec=data_dir + 'train.rec',
                                    path_imgidx=data_dir + 'train.idx',
                                    batch_size=batch_size,
                                    data_shape=(3, edge_size, edge_size),
                                    shuffle=True,
                                    rand_crop=1,
                                    min_object_covered=0.95,
                                    max_attempts=200)
    val_data = image.ImageDetIter(path_imgrec=data_dir + 'val.rec',
                                  batch_size=batch_size,
                                  data_shape=(3, edge_size, edge_size),
                                  shuffle=False)
    return (train_data, val_data)
Example #11
    def __init__(self, batchSize):
        self.net = Brunette(classes=self.classes)
        self.net.initialize(mx.init.Xavier(magnitude=2), ctx=self.ctx)
        self.batchSize = batchSize
        self.trainIter = image.ImageDetIter(batch_size=self.batchSize,
                                            data_shape=(3, 300, 300),
                                            path_imgrec='utils/TrainY.rec',
                                            path_imgidx='utils/TrainY.idx',
                                            shuffle=True,
                                            mean=True,
                                            brightness=0.3,
                                            contrast=0.3,
                                            saturation=0.3,
                                            pca_noise=0.3,
                                            hue=0.3)

        # with autograd.train_mode():
        #     _, _, anchors = self.net(mx.ndarray.zeros(shape=(self.batchSize, 3, 300, 300), ctx=self.ctx))

        # self.T = TargetGenV2(anchors=anchors.as_in_context(mx.cpu()), height=300, width=300)
        print("2")
        self.net.collect_params().reset_ctx(self.ctx)
        self.trainer = gluon.Trainer(self.net.collect_params(), 'sgd', {
            'learning_rate': 0.1,
            'wd': 3e-4
        })
Example #12
def load_data_test(batch_size, data_dir, fname):
    train_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, '%s.rec' % fname),
        path_imgidx=os.path.join(data_dir, '%s.idx' % fname),
        batch_size=batch_size,
        shuffle=True,  # Read the dataset in random order
        data_shape=(config.channel, config.img_h, config.img_w))
    return train_iter
Example #13
    def __init__(self, batchSize):
        self.net = SSDv1()
        self.net.initialize(init=Xavier(), ctx=self.ctx)
        self.batchSize = batchSize
        self.trainIter = image.ImageDetIter(batch_size=self.batchSize, data_shape=(3, 300, 300),
                                            path_imgrec='utils/TrainY.rec', path_imgidx='utils/TrainY.idx',)

        self.trainer = gluon.Trainer(self.net.collect_params(), 'sgd', {'learning_rate': 0.1, 'wd': 5e-4})
Example #14
def load_data_pikachu(batch_size, edge_size):
    data_dir = "../data/pikachu"
    train_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, "train.rec"),
        path_imgidx=os.path.join(data_dir, "train.idx"),
        batch_size=batch_size,
        data_shape=(3, edge_size, edge_size),
        shuffle=True,
        rand_crop=1,
        min_object_covered=0.95,
        max_attempts=200)
    val_iter = image.ImageDetIter(path_imgrec=os.path.join(
        data_dir, "val.rec"),
                                  batch_size=batch_size,
                                  data_shape=(3, edge_size, edge_size),
                                  shuffle=False)
    return train_iter, val_iter
Example #15
def load_data_pikachu(batch_size, edge_size=256):
    """Download the pikachu dataest and then load into memory."""
    data_dir = '../data/pikachu/'
    _download_pikachu(data_dir)
    train_iter = image.ImageDetIter(path_imgrec=data_dir + 'train.rec',
                                    path_imgidx=data_dir + 'train.idx',
                                    batch_size=batch_size,
                                    data_shape=(3, edge_size, edge_size),
                                    shuffle=True,
                                    rand_crop=1,
                                    min_object_covered=0.95,
                                    max_attempts=200)
    val_iter = image.ImageDetIter(path_imgrec=data_dir + 'val.rec',
                                  batch_size=batch_size,
                                  data_shape=(3, edge_size, edge_size),
                                  shuffle=False)
    return train_iter, val_iter
Example #16
def load_data_uav(data_dir='../data/uav', batch_size=4, edge_size=256):
    # _download_pikachu(data_dir)
    train_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, 'train.rec'),
        path_imgidx=os.path.join(data_dir, 'train.idx'),
        batch_size=batch_size,
        data_shape=(3, edge_size, edge_size),  # shape of the output images
        shuffle=True,  # read the dataset in random order
        rand_crop=1,  # probability of random cropping is 1
        min_object_covered=0.95,
        max_attempts=200)
    val_iter = image.ImageDetIter(path_imgrec=os.path.join(
        data_dir, 'val.rec'),
                                  batch_size=batch_size,
                                  data_shape=(3, edge_size, edge_size),
                                  shuffle=False)
    return train_iter, val_iter
Example #17
def get_iterators(data_shape, batch_size):
    class_names = ['pikachu']
    num_class = len(class_names)
    train_iter = image.ImageDetIter(batch_size=batch_size,
                                    data_shape=(3, data_shape, data_shape),
                                    path_imgrec='./data/pikachu_train.rec',
                                    path_imgidx='./data/pikachu_train.idx',
                                    shuffle=True,
                                    mean=True,
                                    rand_crop=1,
                                    min_object_covered=0.95,
                                    max_attempts=200)
    val_iter = image.ImageDetIter(batch_size=batch_size,
                                  data_shape=(3, data_shape, data_shape),
                                  path_imgrec='./data/pikachu_val.rec',
                                  shuffle=False,
                                  mean=True)
    return train_iter, val_iter, class_names, num_class
Example #18
def get_iterators(data_shape, batch_size):
    data_dir = '/Users/liudiwen/.mxnet/datasets/data/pikachu/'
    class_names = ['pikachu']
    num_class = len(class_names)
    train_iter = image.ImageDetIter(batch_size=batch_size,
                                    data_shape=(3, data_shape, data_shape),
                                    path_imgrec=data_dir + 'train.rec',
                                    path_imgidx=data_dir + 'train.idx',
                                    shuffle=True,
                                    mean=True,
                                    rand_crop=1,
                                    min_object_covered=0.95,
                                    max_attempts=200)
    val_iter = image.ImageDetIter(batch_size=batch_size,
                                  data_shape=(3, data_shape, data_shape),
                                  path_imgrec=data_dir + 'val.rec',
                                  shuffle=False,
                                  mean=True)
    return train_iter, val_iter, class_names, num_class
Example #19
def get_iterators(data_shape, batch_size):
    #    class_names = ["RBC", "WBC", "Platelets"]
    class_names = ["HUMAN", "WBC", "Platelets"]
    num_class = len(class_names)
    train_iter = image.ImageDetIter(batch_size=batch_size,
                                    data_shape=(3, data_shape[0],
                                                data_shape[1]),
                                    path_imgrec=data_dir + 'train.rec',
                                    path_imgidx=data_dir + 'train.idx',
                                    shuffle=True,
                                    mean=True,
                                    rand_crop=1,
                                    min_object_covered=0.95,
                                    max_attempts=200)
    val_iter = image.ImageDetIter(batch_size=batch_size,
                                  data_shape=(3, data_shape[0], data_shape[1]),
                                  path_imgrec=data_dir + 'val.rec',
                                  shuffle=False,
                                  mean=True)
    return train_iter, val_iter, class_names, num_class
Example #20
def data_iterator(batch_size, data_shape, data_root):
    #os.chdir(data_root)
    train_iter = image.ImageDetIter(batch_size=batch_size,
                                    data_shape=(3, data_shape, data_shape),
                                    path_imgrec=data_root + '/train.rec',
                                    path_imgidx=data_root + '/train.idx',
                                    shuffle=True,
                                    mean=True,
                                    std=True,
                                    rand_crop=1,
                                    min_object_covered=0.9,
                                    max_attempts=200)
    val_iter = image.ImageDetIter(batch_size=batch_size,
                                  data_shape=(3, data_shape, data_shape),
                                  path_imgrec=data_root + '/val.rec',
                                  path_imgidx=data_root + '/val.idx',
                                  shuffle=False,
                                  mean=True,
                                  std=True)
    return train_iter, val_iter
Example #21
def load_data_pikachu(batch_size, edge_size=256):
    """Download the pikachu dataest and then load into memory."""
    data_dir = '../data/pikachu'
    _download_pikachu(data_dir)
    #aug = image.CreateDetAugmenter(data_shape  = (3,edge_size,edge_size),mean=True,std= True)
    train_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, 'train.rec'),
        path_imgidx=os.path.join(data_dir, 'train.idx'),
        batch_size=batch_size,
        data_shape=(3, edge_size, edge_size),
        shuffle=True,
        rand_crop=1,
        min_object_covered=0.95,
        max_attempts=200,
        )
    val_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, 'val.rec'),
        batch_size=batch_size,
        data_shape=(3, edge_size, edge_size),
        shuffle=False,
        )
    return train_iter, val_iter
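
The commented-out CreateDetAugmenter line hints at an alternative: build the detection augmenter list explicitly and pass it to ImageDetIter through aug_list instead of individual keyword arguments. A sketch under the same shapes as above:

augs = image.CreateDetAugmenter(data_shape=(3, 256, 256),
                                mean=True, std=True,
                                rand_crop=1, min_object_covered=0.95,
                                max_attempts=200)
train_iter = image.ImageDetIter(
    path_imgrec=os.path.join(data_dir, 'train.rec'),
    path_imgidx=os.path.join(data_dir, 'train.idx'),
    batch_size=batch_size,
    data_shape=(3, 256, 256),
    shuffle=True,
    aug_list=augs)                 # replaces the keyword-driven pipeline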
Example #22
    def load_dataset(self, dataName):
        if dataName == "toy":
            self.trainIter = image.ImageDetIter(
                batch_size=self.batchSize,
                data_shape=(3, self.dataShape, self.dataShape),
                path_imgrec=os.path.join(self.dataRoot, dataName, 'train.rec'),
                path_imgidx=os.path.join(self.dataRoot, dataName, 'train.idx'),
                shuffle=True,
                mean=True,
                std=True,
                rand_crop=1,
                min_object_covered=0.95,
                max_attempts=200)
            self.validIter = image.ImageDetIter(
                batch_size=self.batchSize,
                data_shape=(3, self.dataShape, self.dataShape),
                path_imgrec=os.path.join(self.dataRoot, dataName, 'val.rec'),
                shuffle=False, mean=True, std=True)
            self.classNames = 'pikachu,dummy'.split(',')
            self.numClass = len(self.classNames)
            return True
        return False
Example #23
def get_iterators(rec_prefix, data_shape, batch_size):
    class_names = ['papercup']
    num_class = len(class_names)
    train_iter = image.ImageDetIter(
        batch_size=batch_size,
        data_shape=data_shape,
        path_imgrec=rec_prefix + '_train.rec',
        path_imgidx=rec_prefix + '_train.idx',
        aug_list=None,
        shuffle=True,
        mean=True,
        std=True,
        rand_crop=1,
        rand_gray=0.2,
        rand_mirror=True,
        rand_pad=0.4,
        pad_val=(rgb_mean[0], rgb_mean[1], rgb_mean[2]),
        min_object_covered=0.95,
        max_attempts=200,
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.05,
        aspect_ratio_range=(0.8, 1.2),
        # pca_noise=0.01,
    )

    valid_iter = image.ImageDetIter(
        batch_size=batch_size,
        data_shape=data_shape,
        path_imgrec=rec_prefix + '_val.rec',
        shuffle=False,
        # mean=True,
        # std=True
    )

    return train_iter, valid_iter, class_names, num_class
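
With this many augmentations enabled it is worth eyeballing a batch before training. A small inspection sketch, assuming matplotlib, a hypothetical rec_prefix, and the default ImageNet statistics applied by mean=True/std=True:

import matplotlib.pyplot as plt
from mxnet import nd

train_iter, _, _, _ = get_iterators('data/papercup', (3, 300, 300), 32)
imagenet_mean = nd.array([123.68, 116.28, 103.939]).reshape((3, 1, 1))
imagenet_std = nd.array([58.395, 57.12, 57.375]).reshape((3, 1, 1))

batch = train_iter.next()
img = batch.data[0][0] * imagenet_std + imagenet_mean  # undo normalization
plt.imshow(img.transpose((1, 2, 0)).clip(0, 255).asnumpy().astype('uint8'))
plt.show()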
Example #24
def load_ImageDetIter(path, batch_size, h, w):
    print('Loading ImageDetIter ' + path)
    batch_iter = image.ImageDetIter(batch_size, (3, h, w),
        path_imgrec=path+'.rec',
        path_imgidx=path+'.idx',
        shuffle=True,
        pca_noise=0.1, 
        brightness=0.5,
        saturation=0.5, 
        contrast=0.5, 
        hue=1.0
        #rand_crop=0.2,
        #rand_pad=0.2,
        #area_range=(0.8, 1.2),
        )
    return batch_iter
Example #25
def get_iterators(data_shape, batch_size):
    train_iter = image.ImageDetIter(batch_size=batch_size,
                                    data_shape=(3, data_shape, data_shape),
                                    path_imgrec='./REC_Data/voc2012.rec',
                                    path_imgidx='./REC_Data/voc2012.idx',
                                    shuffle=True,
                                    mean=True,
                                    rand_crop=0,
                                    min_object_covered=0.95,
                                    max_attempts=200)
    # val_iter = image.ImageDetIter(
    #     batch_size=batch_size,
    #     data_shape=(3, data_shape, data_shape),
    #     path_imgrec=data_dir+'val.rec',
    #     shuffle=False,
    #     mean=True)
    return train_iter
Example #26
            # absolute difference between predictions and labels
            output = F.abs(pred - labels)
            # smooth L1 loss on the differences
            smooth_output = F.smooth_l1(output, sigma)
            return F.mean(smooth_output, axis=self.axis)


if __name__ == '__main__':
    # load the crowd dataset
    batch_size = 64
    edge_size = 256
    param_filename = 'ssd_params'
    train_iter = image.ImageDetIter(path_imgrec='../data/mydataset.rec',
                                    path_imgidx='../data/mydataset.idx',
                                    batch_size=batch_size,
                                    data_shape=(3, edge_size, edge_size),
                                    shuffle=True,
                                    rand_crop=1,  # probability of random cropping is 1
                                    min_object_covered=0.95,
                                    max_attempts=200)
    ctx = gb.try_gpu()
    ssd = TinySSD(num_classes=1)
    ssd.initialize(init=init.Xavier(), ctx=ctx)
    trainer = gluon.Trainer(ssd.collect_params(), 'sgd',
                            {'learning_rate': 0.2, 'wd': 5e-4})
    focal_loss = focal_SoftMaxCrossEntropyLoss()
    smooth_l1 = smooth_L1Loss()

   
    while True:
        # unlike a DataLoader, the iterator must be reset before each epoch
        train_iter.reset()
        start = time.time()
Example #27
from SimpleYOLO.YOLO2.metric import LossRecorder
from projects.pikachu.config import Config

batch_size = 32
ctx = mx.gpu(0)

data_h, data_w = Config.size
class_names = Config.classes
num_class = len(Config.classes)
anchors = Config.anchors
train_data = image.ImageDetIter(
    data_shape=(3, data_h, data_w),
    std=np.array([255, 255, 255]),
    path_imgrec=Config.train_rec,
    path_imgidx=Config.train_idx,
    batch_size=batch_size,
    shuffle=True,
    rand_mirror=True,
    rand_crop=1,
    min_object_covered=0.95,
    max_attempts=100
)
test_data = image.ImageDetIter(
    data_shape=(3, data_h, data_w),
    std=np.array([255, 255, 255]),
    path_imgrec=Config.val_rec,
    batch_size=batch_size,
    shuffle=False
)

sce_loss = gluon.loss.SoftmaxCrossEntropyLoss(from_logits=False)
l1_loss = gluon.loss.L1Loss()
Example #28
mean_red, mean_green, mean_blue = 123, 117, 104
rgb_mean = nd.array([123, 117, 104])
rgb_std = nd.array([58.395, 57.12, 57.375])
maxEpoch = 200

data_dir = 'data\\'

# 1 -- create the data iterator
#trainIter = iterator.DetRecordIter(trainRecPath,batchSize,(3, dataShape, dataShape),\
#                                  mean_pixels=(mean_red, mean_green, mean_blue), std_pixel = ( rgb_std[0], rgb_std[1], rgb_std[2]  ))

trainIter = image.ImageDetIter(batch_size=batchSize,
                               data_shape=(3, dataShape, dataShape),
                               path_imgrec=data_dir + 'train.rec',
                               path_imgidx=data_dir + 'train.idx',
                               shuffle=True,
                               mean=True,
                               std=True,
                               rand_crop=1,
                               min_object_covered=0.95,
                               max_attempts=200)

batch = trainIter.next()
label, data = batch.label, batch.data
print(label[0].shape,
      "batchSize, maxObjectNum,classId,xmin,ymin,xmax,ymax,difficult")
print(data[0].shape, "batchSize,C,H,W")

# 2 -- load the network
# 2.1 -- load the pretrained network (feature part)
pretrained = vision.get_model('resnet18_v1', pretrained=True).features
net = nn.HybridSequential()
Example #29
            X, anchors[i], class_preds[i], bbox_preds[i] = blk_forward(
                X, getattr(self, 'blk_%d' % i), sizes[i], ratios[i],
                getattr(self, 'class_predictor_%d' % i),
                getattr(self, 'bbox_predictor_%d' % i))
        return (nd.concat(*anchors, dim=1),
                concat_preds(class_preds).reshape(0, -1, self.num_classes + 1),
                concat_preds(bbox_preds))


batch_size = 32
img_size = 256
train_iter = image.ImageDetIter(
    path_imgrec=r'D:\d2l-zh20200904\data\pikachu\train.rec',
    path_imgidx=r'D:\d2l-zh20200904\data\pikachu\train.idx',
    batch_size=batch_size,
    data_shape=(3, img_size, img_size),
    shuffle=True,
    rand_crop=1,
    min_object_covered=0.95,
    max_attempts=200)
# test_iter = image.ImageDetIter(path_imgrec=r'D:\d2l-zh20200904\data\pikachu\val.rec', batch_size=batch_size,
#                                data_shape=(3, img_size, img_size), shuffle=False)

softmaxloss = gloss.SoftmaxCrossEntropyLoss()
l1loss = gloss.L1Loss()


def ssdLoss(class_preds, bbox_preds, class_labels, bbox_labels, bbox_mask):
    class_loss = softmaxloss(class_preds, class_labels)
    bbox_loss = l1loss(bbox_preds * bbox_mask, bbox_labels * bbox_mask)
    return class_loss + bbox_loss
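
ssdLoss expects per-anchor class and box targets; in the d2l pipeline those come from contrib.nd.MultiBoxTarget, which matches anchors against ground-truth labels. A hedged sketch of one training step built on it (X = batch.data[0], Y = batch.label[0]):

from mxnet import autograd, contrib

def train_step(net, trainer, X, Y, batch_size):
    # label the anchors, compute the combined loss, and update the weights
    with autograd.record():
        anchors, class_preds, bbox_preds = net(X)
        bbox_labels, bbox_mask, class_labels = contrib.nd.MultiBoxTarget(
            anchors, Y, class_preds.transpose((0, 2, 1)))
        loss = ssdLoss(class_preds, bbox_preds,
                       class_labels, bbox_labels, bbox_mask)
    loss.backward()
    trainer.step(batch_size)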