예제 #1
0
    def __init__(self):
        """Set up hyperparameters, GPU context, the network container and
        the training ``ImageRecordIter``.

        Bails out early (leaving ``self.rec_path``/``self.save_path``
        unset) when the record directory ``../REC`` does not exist.
        """
        self.batch_size = 16
        self.learning_rate = 1e-4
        self.epochs = 5
        self.num_classes = 2
        self.dropout = 0.7

        self.ctx = mx.gpu()

        self.net = gluon.nn.Sequential()

        # Resolve each path once instead of recomputing the same
        # abspath/join expression for every check.
        base_dir = os.path.dirname(__file__)
        rec_path = os.path.abspath(os.path.join(base_dir, '..', REC))
        if os.path.exists(rec_path):
            self.rec_path = rec_path
        else:
            # No record directory: abort initialisation, matching the
            # original early-return behaviour.
            return

        save_path = os.path.abspath(os.path.join(base_dir, '..', 'models', 'mxnet'))
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        self.save_path = save_path

        # Identity comparison against None (was `!= None`).
        if self.rec_path is not None:
            self.train_data = ImageRecordIter(
                path_imgrec=os.path.join(self.rec_path, REAL_REC),
                path_imgidx=os.path.join(self.rec_path, REAL_IDX),
                data_shape=(3, 384, 384),
                shuffle=True,
                batch_size=self.batch_size
            )
예제 #2
0
def load_imagenet_record(batch_size=None, path=None):
    """Build ImageNet training/validation ``ImageRecordIter`` objects.

    When *path* is not given it falls back to the ``IMAGENETPATH``
    environment variable. Returns ``(training_data, validation_data)``.
    """
    import os
    from mxnet.io import ImageRecordIter

    if path is None:
        path = os.environ['IMAGENETPATH']
    training_record = '%s/training.record' % path
    validation_record = '%s/validation.record' % path

    r_mean, g_mean, b_mean = 123.680, 116.779, 103.939
    mean = int((r_mean + g_mean + b_mean) / 3)
    scale = 1 / 59.4415

    # Both iterators share every option except the record file.
    common = dict(
        batch_size=batch_size,
        data_name='data',
        data_shape=(3, 224, 224),
        fill_value=mean,
        label_name='softmax_label',
        label_width=1,
        mean_r=r_mean,
        mean_g=g_mean,
        mean_b=b_mean,
        pad=4,
        preprocess_threads=16,
        rand_crop=True,
        rand_mirror=True,
        scale=scale,
        shuffle=True,
        verbose=False,
    )

    training_data = ImageRecordIter(path_imgrec=training_record, **common)
    validation_data = ImageRecordIter(path_imgrec=validation_record, **common)

    return training_data, validation_data
예제 #3
0
def load_cifar10_data_iter(batch_size=None, path=''):
    """Create CIFAR-10 train/validation record iterators rooted at *path*."""
    from mxnet.io import ImageRecordIter

    r_mean, g_mean, b_mean = 123.680, 116.779, 103.939
    mean = int(sum((r_mean, g_mean, b_mean)) / 3)
    scale = 1 / 59.4415

    # Options shared by both iterators.
    shared = dict(
        batch_size=batch_size,
        data_name='data',
        data_shape=(3, 32, 32),
        label_name='softmax_label',
        label_width=1,
        mean_r=r_mean,
        mean_g=g_mean,
        mean_b=b_mean,
        preprocess_threads=16,
        scale=scale,
        verbose=False,
    )

    # Training: padded random crops, mirroring, and shuffling.
    train_data = ImageRecordIter(
        path_imgrec=path + 'cifar10_train.rec',
        fill_value=mean,
        pad=4,
        rand_crop=True,
        rand_mirror=True,
        shuffle=True,
        **shared
    )

    # Validation: no augmentation; reads the first of two shards.
    val_data = ImageRecordIter(
        path_imgrec=path + 'cifar10_val.rec',
        num_parts=2,
        part_index=0,
        **shared
    )

    return train_data, val_data
예제 #4
0
def get_iterators(batch_size, num_classes, data_shape):
    """Return (train, val, test) ``ImageRecordIter`` triple over the
    WIDER record/list files."""
    # Options common to all three iterators.
    shared = dict(
        batch_size=batch_size,
        data_shape=data_shape,
        preprocess_threads=4,
        mean_r=104,
        mean_g=117,
        mean_b=123,
        label_width=num_classes,
        shuffle=False,
        round_batch=False,
    )
    # Training additionally resizes, then takes random crops and mirrors.
    train = ImageRecordIter(path_imgrec='wider_records/training_list.rec',
                            path_imglist='wider_records/training_list.lst',
                            resize=256,
                            max_crop_size=224,
                            min_crop_size=128,
                            rand_crop=True,
                            rand_mirror=True,
                            **shared)
    val = ImageRecordIter(path_imgrec='wider_records/valid_list.rec',
                          path_imglist='wider_records/valid_list.lst',
                          **shared)
    test = ImageRecordIter(path_imgrec='wider_records/testing_list.rec',
                           path_imglist='wider_records/testing_list.lst',
                           **shared)
    return train, val, test
예제 #5
0
from PIL import Image
import cv2
import numpy as np

# Plain sequential reader over the "arrow" training record: all random
# augmentation parameters are pinned to their neutral values, so batches
# come back in file order and unmodified apart from the resize.
data_iter = ImageRecordIter(
    path_imgrec = "/data/deeplearning/dataset/arrow/train_0301.rec",
    path_imglist="/data/deeplearning/dataset/arrow/train_0301.lst",
    label_width=1,
    data_name='data',
    label_name='softmax_label',
    resize=256,
    data_shape=(3, 200, 200),
    batch_size=512,
    pad=0,
    fill_value=127,  # only used when pad is valid
    rand_crop=False,
    max_random_scale=1.0,  # 480 with imagnet and vggface, 384 with msface, 32 with cifar10
    min_random_scale=1.0,  # 256.0/480.0=0.533, 256.0/384.0=0.667
    max_aspect_ratio=0,
    random_h=0,  # 0.4*90
    random_s=0,  # 0.4*127
    random_l=0,  # 0.4*127
    max_rotate_angle=0,
    max_shear_ratio=0,
    rand_mirror=False,
    shuffle=False,
)

# Rewind and prime the first batch before the consuming loop below.
data_iter.reset()
next_batch = data_iter.iter_next()
while next_batch:
예제 #6
0
def load_cifar10_record(batch_size=None, path=None):
    """Build CIFAR-10 train/validation/test ``ImageRecordIter`` objects.

    *path* defaults to the ``CIFAR10PATH`` environment variable. The
    validation record is read as two shards which serve as the
    validation and test sets respectively.
    """
    import os
    from mxnet.io import ImageRecordIter

    if path is None:
        path = os.environ['CIFAR10PATH']
    training_record = '%s/training-record' % path
    validation_record = '%s/validation-record' % path

    r_mean, g_mean, b_mean = 123.680, 116.779, 103.939
    mean = int(sum((r_mean, g_mean, b_mean)) / 3)
    scale = 1 / 59.4415

    # Options shared by all three iterators.
    shared = dict(
        batch_size=batch_size,
        data_name='data',
        data_shape=(3, 32, 32),
        label_name='softmax_label',
        label_width=1,
        mean_r=r_mean,
        mean_g=g_mean,
        mean_b=b_mean,
        preprocess_threads=16,
        scale=scale,
        verbose=False,
    )

    # Training shard: padded random crops, mirroring, and shuffling.
    training_data = ImageRecordIter(
        path_imgrec=training_record,
        fill_value=mean,
        pad=4,
        rand_crop=True,
        rand_mirror=True,
        shuffle=True,
        **shared
    )
    # Validation/test: the same record split into two halves.
    validation_data = ImageRecordIter(
        path_imgrec=validation_record,
        num_parts=2,
        part_index=0,
        **shared
    )
    test_data = ImageRecordIter(
        path_imgrec=validation_record,
        num_parts=2,
        part_index=1,
        **shared
    )

    return training_data, validation_data, test_data  # TODO validation/test distinction
예제 #7
0
# Single-image (batch_size=1) reader over a sign-dataset validation record;
# `record` opens the same file through the low-level MXRecordIO API,
# presumably for raw-record comparison — verify against later use.
rec_file = "/data/deeplearning/dataset/sign/record/20180427_val.rec"
record = mx.recordio.MXRecordIO(rec_file, 'r')

data_iter = ImageRecordIter(
    path_imgrec=rec_file,
    label_width=1,
    data_name='data',
    label_name='softmax_label',
    resize=128,
    data_shape=(3, 112, 112),
    batch_size=1,
    # max_img_size=128,
    # min_img_size=128,
    pad=0,
    fill_value=127,  # only used when pad is valid
    rand_crop=True,
    # max_random_scale=0.8,
    # min_random_scale=0.5,
    # max_aspect_ratio=0.667,
    # min_aspect_ratio=0.375,
    random_h=0,
    random_s=0,
    random_l=0,
    max_rotate_angle=0,
    max_shear_ratio=0,
    rand_mirror=False,
    shuffle=False,
)

# Rewind before consuming; `n` counts processed samples below.
data_iter.reset()
n = 0
예제 #8
0
directly. Here is an example that randomly reads 128 images each time and
performs randomized resizing and cropping.
"""

import os
from mxnet import nd
from mxnet.io import ImageRecordIter

# Default location gluoncv uses for pre-packed ImageNet record files.
rec_path = os.path.expanduser('~/.mxnet/datasets/imagenet/rec/')

# You need to specify ``root`` for ImageNet if you extracted the images into
# a different folder
train_data = ImageRecordIter(
    path_imgrec = os.path.join(rec_path, 'train.rec'),
    path_imgidx = os.path.join(rec_path, 'train.idx'),
    data_shape  = (3, 224, 224),
    batch_size  = 32,
    shuffle     = True
)

#########################################################################
# Peek at the first batch to confirm data/label shapes, then stop.
for batch in train_data:
    print(batch.data[0].shape, batch.label[0].shape)
    break

#########################################################################
# Plot some validation images
from gluoncv.utils import viz
val_data = ImageRecordIter(
    path_imgrec = os.path.join(rec_path, 'val.rec'),
    path_imgidx = os.path.join(rec_path, 'val.idx'),
예제 #9
0
class MxNetTrainer:
    """Small Gluon CNN trainer fed by an ``ImageRecordIter`` record file.

    Reads records from ``../REC``, trains a conv/dense stack on GPU and
    saves the parameters under ``../models/mxnet``.
    """

    def __init__(self):
        """Configure hyperparameters, context, network container and the
        training iterator. Returns early (leaving ``rec_path``/``save_path``
        unset) when the record directory is missing."""
        self.batch_size = 16
        self.learning_rate = 1e-4
        self.epochs = 5
        self.num_classes = 2
        self.dropout = 0.7

        self.ctx = mx.gpu()

        self.net = gluon.nn.Sequential()

        # Resolve each path once instead of recomputing the same
        # abspath/join expression for every check.
        base_dir = os.path.dirname(__file__)
        rec_path = os.path.abspath(os.path.join(base_dir, '..', REC))
        if os.path.exists(rec_path):
            self.rec_path = rec_path
        else:
            # No record directory: abort initialisation (original behaviour).
            return

        save_path = os.path.abspath(os.path.join(base_dir, '..', 'models', 'mxnet'))
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        self.save_path = save_path

        # Identity comparison against None (was `!= None`).
        if self.rec_path is not None:
            self.train_data = ImageRecordIter(
                path_imgrec=os.path.join(self.rec_path, REAL_REC),
                path_imgidx=os.path.join(self.rec_path, REAL_IDX),
                data_shape=(3, 384, 384),
                shuffle=True,
                batch_size=self.batch_size
            )

    def model(self):
        """Assemble the conv/pool/dropout/dense stack inside ``self.net``."""
        with self.net.name_scope():
            self.net.add(gluon.nn.Conv2D(channels=32, kernel_size=5, activation='relu'))
            self.net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))

            self.net.add(gluon.nn.Conv2D(channels=64, kernel_size=5, activation='relu'))
            self.net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))

            self.net.add(gluon.nn.Dropout(self.dropout))

            self.net.add(gluon.nn.Conv2D(channels=128, kernel_size=5, activation='relu'))
            self.net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))

            self.net.add(gluon.nn.Dropout(self.dropout))

            self.net.add(gluon.nn.Conv2D(channels=256, kernel_size=5, activation='relu'))
            self.net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))

            self.net.add(gluon.nn.Flatten())

            self.net.add(gluon.nn.Dense(512, 'relu'))
            self.net.add(gluon.nn.Dense(256, 'relu'))
            # NOTE(review): 65 looks like a typo for 64 — confirm before changing.
            self.net.add(gluon.nn.Dense(65, 'relu'))
            self.net.add(gluon.nn.Dense(32, 'relu'))

            self.net.add(gluon.nn.Dense(self.num_classes, 'sigmoid'))

    def evaluate_accuracy(self, data, label):
        """Return the accuracy of ``self.net`` on one (data, label) batch."""
        acc = mx.metric.Accuracy()
        output = self.net(data)
        predictions = nd.argmax(output, axis=1)
        acc.update(preds=predictions, labels=label)

        return acc.get()[1]

    def train(self):
        """Train for ``self.epochs`` epochs over ``self.train_data`` and
        save the parameters to ``<save_path>/model.params``."""
        self.model()

        self.net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=self.ctx)

        smoothing_constant = .01

        softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

        trainer = gluon.Trainer(self.net.collect_params(), 'adam', {'learning_rate': self.learning_rate})

        for e in range(self.epochs):
            i = 0
            self.train_data.reset()
            while self.train_data.iter_next():
                # Scale raw pixel values into [0, 1].
                d = self.train_data.getdata() / 255.
                l = self.train_data.getlabel()

                data = d.as_in_context(self.ctx)
                label = l.as_in_context(self.ctx)

                step = data.shape[0]
                with autograd.record():
                    output = self.net(data)
                    loss = softmax_cross_entropy(output, label)

                loss.backward()

                trainer.step(step)
                ##########################
                #  Keep a moving average of the losses
                ##########################
                curr_loss = nd.mean(loss).asscalar()
                moving_loss = (curr_loss if ((i == 0) and (e == 0))
                    else (1 - smoothing_constant) * moving_loss + smoothing_constant * curr_loss)

                # Accuracy on the current training batch (not a held-out set).
                acc = self.evaluate_accuracy(data, label)
                print("Epoch {:03d} ... Dataset {:04d} ... ".format(e+1, i), "Loss = {:.4f}".format(curr_loss), " Moving Loss = {:.4f}".format(moving_loss), " Accuracy = {:.4f}".format(acc))

                i += 1

        # NOTE(review): this clobbers self.save_path with the file path, so
        # a second train() call would nest 'model.params' — kept as-is.
        self.save_path = os.path.join(self.save_path, 'model.params')
        self.net.save_parameters(self.save_path)

        print(self.net)
예제 #10
0
File: train_vggs.py  Project: daniaokuye/Dy
def train_vgg(gpu=0, lr=0.001, root='', param_file='', Isinit=True, FLF=False):
    """Train (and periodically validate) a VGG network on ImageNet records.

    Modernised from Python 2 to Python 3 for consistency with the rest of
    the file: ``print`` statements became function calls, ``dict.keys()``
    indexing is materialised via ``list()``, and ``reduce`` over ``+`` is
    replaced by the built-in ``sum``.

    Args:
        gpu: GPU index to train on.
        lr: learning rate for the Adam trainer.
        root: directory holding model/parameter files.
        param_file: parameter file to load when ``Isinit`` is False.
        Isinit: initialise weights from scratch instead of loading.
        FLF: add the kernel loss term from ``loss_kernel`` to the CE loss.
    """
    num = 1
    batch_size = 40 * num  # 32
    num_workers = 2 * num
    lr = lr
    epoch = 8
    ratio_train_test = 0.8  # fraction of each epoch used for training

    ctx = [mx.gpu(i) for i in [gpu]]  # range(8 - gpu, 8)
    net = vgg()

    if Isinit:
        # special_init(net, ctx=ctx)
        special_initMX(net, ctx=ctx, skip=gls.initp[0], layers=gls.initp[1])
        file = os.path.join(root, 'vgg.models')
    else:
        file = os.path.join(root, param_file)
        net.load_parameters(file, ctx=ctx)
    loaded_param = file.replace('.model', '_mask.pkl')
    global_param.load_param(loaded_param, ctx)

    data_iter = ImageRecordIter(batch_size=batch_size, data_shape=(3, crop_width, crop_height), shuffle=True,
                                path_imgrec="/home1/ImageNet_ILSVRC2012/train_label.rec",
                                path_imgidx="/home1/ImageNet_ILSVRC2012/train_label.idx",
                                aug_list=transform, preprocess_threads=num_workers)

    # Only the first gls.initp[1] parameter groups are trained.
    # (dict views are not indexable in Python 3, hence the list().)
    params = {}
    param_keys = list(net.net.collect_params().keys())
    for i in range(gls.initp[1]):
        pkey = param_keys[i]
        params[pkey] = net.net.collect_params()[pkey]
    trainer = gluon.Trainer(params, 'adam', {'learning_rate': lr})

    CEloss = gluon.loss.SoftmaxCrossEntropyLoss(axis=-1, sparse_label=True)

    epoch_train = 0  # total records, learned from the first full pass
    valid = 0
    for epochs in range(epoch):
        j = epochs * epoch_train
        t = time.time()
        i = 0
        for contain in data_iter:
            i += 1
            global_param.iter = i + j
            batch, label = (contain.data[0], contain.label[0])
            batch = gluon.utils.split_and_load(batch, ctx)
            label = gluon.utils.split_and_load(label, ctx)
            if i < ratio_train_test * epoch_train or epoch_train == 0:
                # ---- training portion of the epoch ----
                with autograd.record():
                    losses = [CEloss(net(X), Y) for X, Y in zip(batch, label)]
                    losses = [mx.nd.sum(X) for X in losses]
                    if FLF:
                        loss_k = loss_kernel(net, ctx)
                        loss_all = [X + Y for X, Y in zip(losses, loss_k)]
                    else:
                        loss_all = losses
                for loss in loss_all:
                    loss.backward()
                trainer.step(batch_size)
                value = sum(X.asscalar() for X in losses) / batch_size
                sw.add_scalar(tag='Loss', value=value, global_step=i + j)
                if i % 200 == 0:
                    print('iter:%d,loss:%4.5f,time:%4.5fs' % (i + j, value, time.time() - t))
                    t = time.time()

            else:
                # ---- validation portion of the epoch ----
                out = [net(X) for X in batch]
                out = [mx.nd.softmax(X, axis=1) for X in out]
                tops = [(mx.nd.topk(X, ret_typ='both'), Y)
                        for X, Y in zip(out, label)]
                disc, cont = 0, 0
                for (value_, idices_), label_ in tops:
                    # disc: top-1 hit count; cont: hits weighted by confidence.
                    real = idices_.reshape(-1) == label_.astype(idices_.dtype)
                    disc += mx.nd.sum(real).asscalar()
                    cont += mx.nd.sum(real * value_.T).asscalar()

                discroc = disc / batch_size
                controc = cont / batch_size
                sw.add_scalar(tag='RocDisc', value=discroc, global_step=valid)
                sw.add_scalar(tag='RocCont', value=controc, global_step=valid)
                valid += 1
                if i % 200 == 0:
                    print('RocDisc', discroc)

        data_iter.reset()
        # Learn the epoch length from the first full pass over the data.
        if i > epoch_train: epoch_train = i
        print('epcoah length:', epoch_train, 'and i:', i, 'time:', time.time() - t)
        online_check(epochs, sw, net.collect_params())
        # save model
        net.save_parameters(file)
        global_param.save_param(loaded_param)
        print('*' * 30)
    def __init__(self, hemorrhage_type):
        """Configure paths, hyperparameters and the training iterator for
        one hemorrhage type.

        Args:
            hemorrhage_type: one of 'epidural', 'intraparenchymal',
                'intraventricular', 'subarachnoid', 'subdural'; the record
                files are named ``<type>_rec.rec`` / ``<type>_rec.idx``.
        """
        self.hemorrhage_type = hemorrhage_type

        # Resolve each directory once instead of repeating the
        # abspath/join expressions.
        base_dir = os.path.dirname(__file__)

        rec_path = os.path.abspath(os.path.join(base_dir, '..', 'rec'))
        self.rec_path = rec_path if os.path.exists(rec_path) else None

        self.save_path = None

        save_path = os.path.abspath(
            os.path.join(base_dir, '..', 'models', 'mxnet', hemorrhage_type))
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        self.save_path = save_path

        self.test_dir = os.path.abspath(
            os.path.join(base_dir, '..', 'stage_1_test_images'))

        self.net = gluon.nn.Sequential()

        self.ctx = mx.gpu()
        self.num_outputs = 2

        self.epochs = 6
        self.learning_rate = 1e-3
        self.batch_size = 32

        # The five elif branches differed only in the record file stem,
        # so build the iterator once from the type name.
        known_types = ('epidural', 'intraparenchymal', 'intraventricular',
                       'subarachnoid', 'subdural')
        if self.rec_path is not None and self.hemorrhage_type in known_types:
            stem = '%s_rec' % self.hemorrhage_type
            self.train_data = ImageRecordIter(
                path_imgrec=os.path.join(self.rec_path, stem + '.rec'),
                path_imgidx=os.path.join(self.rec_path, stem + '.idx'),
                data_shape=(3, 384, 384),
                batch_size=self.batch_size,
                shuffle=True)
class MxNetTrainer:
    """Per-hemorrhage-type Gluon CNN trainer over RSNA record files.

    Trains a small conv/dense stack on records from ``../rec`` and
    saves/loads parameters under ``../models/mxnet/<type>``.
    """

    def __init__(self, hemorrhage_type):
        """Configure paths, hyperparameters and the training iterator for
        one hemorrhage type.

        Args:
            hemorrhage_type: one of 'epidural', 'intraparenchymal',
                'intraventricular', 'subarachnoid', 'subdural'; the record
                files are named ``<type>_rec.rec`` / ``<type>_rec.idx``.
        """
        self.hemorrhage_type = hemorrhage_type

        # Resolve each directory once instead of repeating the
        # abspath/join expressions.
        base_dir = os.path.dirname(__file__)

        rec_path = os.path.abspath(os.path.join(base_dir, '..', 'rec'))
        self.rec_path = rec_path if os.path.exists(rec_path) else None

        self.save_path = None

        save_path = os.path.abspath(
            os.path.join(base_dir, '..', 'models', 'mxnet', hemorrhage_type))
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        self.save_path = save_path

        self.test_dir = os.path.abspath(
            os.path.join(base_dir, '..', 'stage_1_test_images'))

        self.net = gluon.nn.Sequential()

        self.ctx = mx.gpu()
        self.num_outputs = 2

        self.epochs = 6
        self.learning_rate = 1e-3
        self.batch_size = 32

        # The five elif branches differed only in the record file stem,
        # so build the iterator once from the type name.
        known_types = ('epidural', 'intraparenchymal', 'intraventricular',
                       'subarachnoid', 'subdural')
        if self.rec_path is not None and self.hemorrhage_type in known_types:
            stem = '%s_rec' % self.hemorrhage_type
            self.train_data = ImageRecordIter(
                path_imgrec=os.path.join(self.rec_path, stem + '.rec'),
                path_imgidx=os.path.join(self.rec_path, stem + '.idx'),
                data_shape=(3, 384, 384),
                batch_size=self.batch_size,
                shuffle=True)

    def model(self):
        """Assemble the conv/pool/dense stack inside ``self.net``."""
        with self.net.name_scope():
            self.net.add(
                gluon.nn.Conv2D(channels=32, kernel_size=5, activation='relu'))
            self.net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))

            self.net.add(
                gluon.nn.Conv2D(channels=64, kernel_size=5, activation='relu'))
            self.net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))

            self.net.add(
                gluon.nn.Conv2D(channels=128, kernel_size=5,
                                activation='relu'))
            self.net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))

            self.net.add(gluon.nn.Flatten())

            self.net.add(gluon.nn.Dense(256, 'relu'))
            self.net.add(gluon.nn.Dense(64, 'relu'))
            self.net.add(gluon.nn.Dense(32, 'relu'))

            self.net.add(gluon.nn.Dense(self.num_outputs))

    def evaluate_accuracy(self, data, label):
        """Return the accuracy of ``self.net`` on one (data, label) batch."""
        acc = mx.metric.Accuracy()
        output = self.net(data)
        predictions = nd.argmax(output, axis=1)
        acc.update(preds=predictions, labels=label)

        return acc.get()[1]

    def train(self):
        """Train for ``self.epochs`` epochs and save the parameters to
        ``<save_path>/model.params``."""
        self.model()

        self.net.collect_params().initialize(mx.init.Xavier(magnitude=2.24),
                                             ctx=self.ctx)

        smoothing_constant = .01

        softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

        trainer = gluon.Trainer(self.net.collect_params(), 'adam',
                                {'learning_rate': self.learning_rate})

        for e in range(self.epochs):
            i = 0
            self.train_data.reset()
            while self.train_data.iter_next():
                # Scale raw pixel values into [0, 1].
                d = self.train_data.getdata() / 255.
                l = self.train_data.getlabel()

                data = d.as_in_context(self.ctx)
                label = l.as_in_context(self.ctx)

                step = data.shape[0]
                with autograd.record():
                    output = self.net(data)
                    loss = softmax_cross_entropy(output, label)

                loss.backward()

                trainer.step(step)
                ##########################
                #  Keep a moving average of the losses
                ##########################
                curr_loss = nd.mean(loss).asscalar()
                moving_loss = (curr_loss if ((i == 0) and (e == 0)) else
                               (1 - smoothing_constant) * moving_loss +
                               smoothing_constant * curr_loss)

                # Accuracy on the current training batch (not a held-out set).
                acc = self.evaluate_accuracy(data, label)
                print("Epoch {:03d} ... Dataset {:03d} ... ".format(e + 1, i),
                      "Loss = {:.4f}".format(curr_loss),
                      " Moving Loss = {:.4f}".format(moving_loss),
                      " Accuracy = {:.4f}".format(acc))

                i += 1

        # Fix: write to a local file path instead of clobbering
        # self.save_path — previously a train() followed by inference()
        # joined 'model.params' onto the file path and looked in the
        # wrong location.
        model_file = os.path.join(self.save_path, 'model.params')
        self.net.save_parameters(model_file)

    def inference(self):
        """Load saved parameters and run the network on the first image in
        ``../test``, printing the raw output."""
        model_path = os.path.join(self.save_path, 'model.params')

        if os.path.exists(model_path):
            test_images_dir = os.path.join(os.path.dirname(__file__), '..',
                                           'test')
            img = cv2.imread(
                os.path.join(test_images_dir,
                             os.listdir(test_images_dir)[0]))

            self.model()

            self.net.load_parameters(model_path, ctx=self.ctx)

            img = nd.array(cv2.resize(img, (384, 384)) / 255.)

            # Fix: the previous code wrapped the list in np.array() and then
            # called .as_in_context() on it, which raises AttributeError
            # (numpy arrays have no such method). Stack into an NDArray batch
            # instead.
            # NOTE(review): the image stays HWC here, matching the original
            # code — confirm the network really expects that layout.
            data = nd.stack(img).as_in_context(self.ctx)

            with autograd.record():
                output = self.net(data)

                print(output)