Example #1
    if grid.ndim == 3 and grid.shape[2] == 1:
        grid = grid.squeeze()
    return grid


# %%
# Fix the seed
# mx.random.seed(42)

# %%
# -- Get some data to train on

mnist_train = gluon.data.vision.datasets.MNIST(train=True)
mnist_valid = gluon.data.vision.datasets.MNIST(train=False)

transformer = transforms.Compose([transforms.ToTensor()])

# %%
# -- Create a data iterator that feeds batches

train_data = gluon.data.DataLoader(mnist_train.transform_first(transformer),
                                   batch_size=BATCH_SIZE,
                                   shuffle=True,
                                   last_batch="rollover",
                                   num_workers=N_WORKERS)

eval_data = gluon.data.DataLoader(mnist_valid.transform_first(transformer),
                                  batch_size=BATCH_SIZE,
                                  shuffle=False,
                                  last_batch="rollover",
                                  num_workers=N_WORKERS)
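
# As a quick sanity check, one batch can be pulled from the loader to confirm
# the expected shapes. This is only a sketch; BATCH_SIZE and N_WORKERS are
# assumed to be defined earlier in the script.
for data, label in train_data:
    # MNIST batches come out as (BATCH_SIZE, 1, 28, 28) images and (BATCH_SIZE,) labels
    print(data.shape, label.shape)
    break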
Example #2

    def forward(self, x):
        x = mx.image.copyMakeBorder(x, self.padding, self.padding,
                                    self.padding, self.padding)
        # print(x.shape)
        x, _ = mx.image.random_crop(x, (self.size, self.size))
        return x


normalize = T.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])

train_transform = T.Compose([
    RandomCrop(32, padding=4),
    # T.RandomResizedCrop(32),
    T.RandomFlipLeftRight(),
    T.ToTensor(),
    normalize
])

val_transform = T.Compose([T.ToTensor(), normalize])

trainset = datasets.CIFAR10('./data',
                            train=True).transform_first(train_transform)
trainloader = gluon.data.DataLoader(trainset,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    num_workers=2)

testset = datasets.CIFAR10('./data',
                           train=False).transform_first(val_transform)
testloader = gluon.data.DataLoader(testset,
Example #3
Read with GluonCV
-----------------

The prepared dataset can be loaded directly with the utility class
:py:class:`gluoncv.data.ImageNet`. Here is an example that randomly reads
128 images at a time and performs random resizing and cropping.
"""

from gluoncv.data import ImageNet
from mxnet.gluon.data import DataLoader
from mxnet.gluon.data.vision import transforms

train_trans = transforms.Compose(
    [transforms.RandomResizedCrop(224),
     transforms.ToTensor()])

# You need to specify ``root`` for ImageNet if you extracted the images into
# a different folder
train_data = DataLoader(ImageNet(train=True).transform_first(train_trans),
                        batch_size=128,
                        shuffle=True)

#########################################################################
for x, y in train_data:
    print(x.shape, y.shape)
    break

#########################################################################
# Plot some validation images
from gluoncv.utils import viz
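
# The plotting code itself is cut off in this excerpt. A minimal sketch of how
# viz.plot_image could be used on the validation split (the sample indices are
# arbitrary assumptions) looks like:
val_dataset = ImageNet(train=False)
# each dataset item is an (image, label) pair; plot the image part
viz.plot_image(val_dataset[1234][0])
viz.plot_image(val_dataset[4567][0])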
Example #4

def run_mnist(
    hook=None,
    set_modes=False,
    num_steps_train=None,
    num_steps_eval=None,
    epochs=2,
    save_interval=None,
    save_path="./saveParams",
):
    batch_size = 4
    normalize_mean = 0.13
    mnist_train = datasets.FashionMNIST(train=True)

    X, y = mnist_train[0]
    ("X shape: ", X.shape, "X dtype", X.dtype, "y:", y)

    text_labels = [
        "t-shirt",
        "trouser",
        "pullover",
        "dress",
        "coat",
        "sandal",
        "shirt",
        "sneaker",
        "bag",
        "ankle boot",
    ]
    transformer = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(normalize_mean, 0.31)])

    mnist_train = mnist_train.transform_first(transformer)
    mnist_valid = gluon.data.vision.FashionMNIST(train=False)

    train_data = gluon.data.DataLoader(mnist_train,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=4)
    valid_data = gluon.data.DataLoader(
        mnist_valid.transform_first(transformer),
        batch_size=batch_size,
        num_workers=4)

    # Create Model in Gluon
    net = nn.HybridSequential()
    net.add(
        nn.Conv2D(channels=6, kernel_size=5, activation="relu"),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Conv2D(channels=16, kernel_size=3, activation="relu"),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Flatten(),
        nn.Dense(120, activation="relu"),
        nn.Dense(84, activation="relu"),
        nn.Dense(10),
    )
    net.initialize(init=init.Xavier(), ctx=mx.cpu())

    if hook is not None:
        # Register the forward Hook
        hook.register_hook(net)

    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(net.collect_params(), "sgd",
                            {"learning_rate": 0.1})
    if hook is not None:
        hook.register_hook(softmax_cross_entropy)

    # Start the training.
    for epoch in range(epochs):
        train_loss, train_acc, valid_acc = 0.0, 0.0, 0.0
        tic = time.time()
        if set_modes:
            hook.set_mode(modes.TRAIN)

        i = 0
        for data, label in train_data:
            data = data.as_in_context(mx.cpu(0))
            # forward + backward
            with autograd.record():
                output = net(data)
                loss = softmax_cross_entropy(output, label)
            loss.backward()
            # update parameters
            trainer.step(batch_size)
            # calculate training metrics
            train_loss += loss.mean().asscalar()
            train_acc += acc(output, label)
            i += 1
            if num_steps_train is not None and i >= num_steps_train:
                break
        # calculate validation accuracy
        if set_modes:
            hook.set_mode(modes.EVAL)
        i = 0
        for data, label in valid_data:
            data = data.as_in_context(mx.cpu(0))
            val_output = net(data)
            valid_acc += acc(val_output, label)
            loss = softmax_cross_entropy(val_output, label)
            i += 1
            if num_steps_eval is not None and i >= num_steps_eval:
                break
        print(
            "Epoch %d: loss %.3f, train acc %.3f, test acc %.3f, in %.1f sec" %
            (
                epoch,
                train_loss / len(train_data),
                train_acc / len(train_data),
                valid_acc / len(valid_data),
                time.time() - tic,
            ))
        if save_interval is not None and (epoch % save_interval) == 0:
            net.save_parameters("{0}/params_{1}.params".format(
                save_path, epoch))
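
# run_mnist relies on an acc helper that is not part of this excerpt.
# One common definition, given here only as a plausible sketch, is:
def acc(output, label):
    # fraction of samples whose arg-max class matches the integer label
    return (output.argmax(axis=1) == label.astype('float32')).mean().asscalar()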
Example #5
def train_cifar10(config):
    args = config.pop("args")
    vars(args).update(config)
    np.random.seed(args.seed)
    random.seed(args.seed)
    mx.random.seed(args.seed)

    # Set Hyper-params
    batch_size = args.batch_size * max(args.num_gpus, 1)
    ctx = [mx.gpu(i)
           for i in range(args.num_gpus)] if args.num_gpus > 0 else [mx.cpu()]

    # Define DataLoader
    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010]),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010]),
    ])

    train_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
        batch_size=batch_size,
        shuffle=True,
        last_batch="discard",
        num_workers=args.num_workers,
    )

    test_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
        batch_size=batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

    # Load the model architecture and initialize the net with a pretrained model
    finetune_net = get_model(args.model, pretrained=True)
    with finetune_net.name_scope():
        finetune_net.fc = nn.Dense(args.classes)
    finetune_net.fc.initialize(init.Xavier(), ctx=ctx)
    finetune_net.collect_params().reset_ctx(ctx)
    finetune_net.hybridize()

    # Define trainer
    trainer = gluon.Trainer(
        finetune_net.collect_params(),
        "sgd",
        {
            "learning_rate": args.lr,
            "momentum": args.momentum,
            "wd": args.wd
        },
    )
    L = gluon.loss.SoftmaxCrossEntropyLoss()
    metric = mx.metric.Accuracy()

    def train(epoch):
        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0,
                                              even_split=False)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0,
                                               even_split=False)
            with ag.record():
                outputs = [finetune_net(X) for X in data]
                loss = [L(yhat, y) for yhat, y in zip(outputs, label)]
            for ls in loss:
                ls.backward()

            trainer.step(batch_size)
        mx.nd.waitall()

    def test():
        test_loss = 0
        for i, batch in enumerate(test_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0,
                                              even_split=False)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0,
                                               even_split=False)
            outputs = [finetune_net(X) for X in data]
            loss = [L(yhat, y) for yhat, y in zip(outputs, label)]

            test_loss += sum(ls.mean().asscalar() for ls in loss) / len(loss)
            metric.update(label, outputs)

        _, test_acc = metric.get()
        test_loss /= len(test_data)
        return test_loss, test_acc

    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test_loss, test_acc = test()
        tune.report(mean_loss=test_loss, mean_accuracy=test_acc)
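
# train_cifar10 reports its metrics through Ray Tune's tune.report, so it is
# meant to run as a Tune trainable. A hedged sketch of such a launch (the
# search-space keys and ranges are assumptions) could be:
from ray import tune

# `args` is the parsed command-line namespace that train_cifar10 expects
analysis = tune.run(
    train_cifar10,
    config={
        "args": args,
        "lr": tune.loguniform(1e-3, 1e-1),
        "momentum": tune.uniform(0.8, 0.99),
    },
    num_samples=4)
print(analysis.get_best_config(metric="mean_accuracy", mode="max"))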
Example #6
def main():
    opt = parse_args()
    
    batch_size = opt.batch_size
    classes = 10

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes,
                'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)
    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx = context)
    optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    plot_path = opt.save_plot_dir

    logging.basicConfig(level=logging.INFO)
    logging.info(opt)

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum})
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
        train_history = TrainingHistory(['training-error', 'validation-error'])

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate*lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            train_history.update([1-acc, 1-val_acc])
            train_history.plot(save_path='%s/%s_history.png'%(plot_path, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-cifar-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                (epoch, acc, val_acc, train_loss, time.time()-tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epochs-1))



    if opt.mode == 'hybrid':
        net.hybridize()
    train(opt.num_epochs, context)
Example #7
net = get_model(model_name, **kwargs)
net.cast(opt.dtype)
if opt.params_file:
    net.load_params(opt.params_file, ctx=ctx)
net.hybridize()

acc_top1 = mx.metric.Accuracy()
acc_top5 = mx.metric.TopKAccuracy(5)

normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

transform_test = transforms.Compose([
    transforms.Resize(256, keep_ratio=True),
    transforms.CenterCrop(224),
    transforms.ToTensor(), normalize
])


def test(ctx, val_data, mode='image'):
    acc_top1.reset()
    acc_top5.reset()
    if not opt.rec_dir:
        num_batch = len(val_data)
    for i, batch in enumerate(val_data):
        if mode == 'image':
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
Example #8
        txts.append(txt)
    plt.savefig(output_path)


def inf_train_gen(loader):
    """
    Yield batches from the loader indefinitely, so the network can be
    trained by iteration count rather than by epoch.
    :param loader: DataLoader
    :return: batch of data
    """
    while True:
        for batch in loader:
            yield batch


transform_test = transforms.Compose([transforms.ToTensor()])

_transform_train = transforms.Compose([
    transforms.RandomBrightness(0.3),
    transforms.RandomContrast(0.3),
    transforms.RandomSaturation(0.3),
    transforms.RandomFlipLeftRight(),
    transforms.ToTensor()
])


def transform_train(data, label):
    im = _transform_train(data)
    return im, label
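
# Because transform_train takes both the image and the label, it would be
# applied with Dataset.transform rather than transform_first. A hypothetical
# usage sketch (the dataset choice is an assumption, since the example's
# actual dataset is not shown):
trainset = gluon.data.vision.CIFAR10(train=True).transform(transform_train)
trainloader = gluon.data.DataLoader(trainset, batch_size=32, shuffle=True)

valset = gluon.data.vision.CIFAR10(train=False).transform_first(transform_test)
valloader = gluon.data.DataLoader(valset, batch_size=32, shuffle=False)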

Example #9
        #tmp = nd.zeros(shape=(self.n_frame, self.crop_size, self.crop_size, 3), dtype='float32')
        if current_length < self.n_frame:
            # repeat the last frame and concatenate so the clip reaches n_frame frames
            extra_data = nd.tile(nd_image_list[-1],
                                 reps=(self.n_frame - current_length, 1, 1, 1))
            extra_data = extra_data.transpose((1, 0, 2, 3))
            cthw_data = nd.concat(cthw_data, extra_data, dim=1)
        # begin to construct the label
        label_nd = np.zeros(shape=(self.max_label), dtype=np.float32)
        for tag_index in labels:
            label_nd[tag_index - 1] = 1
        return cthw_data, label_nd


train_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def get_simple_meitu_dataloader(datadir,
                                batch_size=4,
                                n_frame=32,
                                crop_size=112,
                                scale_h=128,
                                scale_w=171,
                                num_workers=6):
    """construct the dataset and then set the datasetloader"""
    train_dataset = SimpleMeitu(datadir,
                                n_frame,
                                crop_size,
Example #10
def main():
    opt = parse_args()

    filehandler = logging.FileHandler(opt.logging_file, mode='a+')
    # streamhandler = logging.StreamHandler()

    logger = logging.getLogger('ImageNet')
    logger.setLevel(level=logging.DEBUG)
    logger.addHandler(filehandler)
    # logger.addHandler(streamhandler)

    logger.info(opt)

    if opt.amp:
        amp.init()

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167
    num_validating_samples = 50000

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers
    accumulate = opt.accumulate

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = num_training_samples // batch_size

    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=opt.lr,
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode,
                    base_lr=opt.lr,
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])

    model_name = opt.model

    kwargs = {'ctx': context, 'pretrained': opt.use_pretrained}
    if opt.use_gn:
        kwargs['norm_layer'] = gcv.nn.GroupNorm
    if model_name.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm
    elif model_name.startswith('resnext'):
        kwargs['use_se'] = opt.use_se

    if opt.last_gamma:
        kwargs['last_gamma'] = True

    optimizer = 'sgd'
    optimizer_params = {
        'wd': opt.wd,
        'momentum': opt.momentum,
        'lr_scheduler': lr_scheduler,
        'begin_num_update': num_batches * opt.resume_epoch
    }
    # if opt.dtype != 'float32':
    #     optimizer_params['multi_precision'] = True

    # net = get_model(model_name, **kwargs)
    if opt.model_backend == 'gluoncv':
        net = glcv_get_model(model_name, **kwargs)
    elif opt.model_backend == 'gluoncv2':
        net = glcv2_get_model(model_name, **kwargs)
    else:
        raise ValueError(f'Unknown backend: {opt.model_backend}')
    # net.cast(opt.dtype)
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context, cast_dtype=True)

    # teacher model for distillation training
    if opt.teacher is not None and opt.hard_weight < 1.0:
        teacher_name = opt.teacher
        if opt.teacher_backend == 'gluoncv':
            teacher = glcv_get_model(teacher_name, **kwargs)
        elif opt.teacher_backend == 'gluoncv2':
            teacher = glcv2_get_model(teacher_name, **kwargs)
        else:
            raise ValueError(f'Unknown backend: {opt.teacher_backend}')
        # teacher = glcv2_get_model(teacher_name, pretrained=True, ctx=context)
        # teacher.cast(opt.dtype)
        teacher.collect_params().setattr('grad_req', 'null')
        distillation = True
    else:
        distillation = False

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_val):
        rec_train = os.path.expanduser(rec_train)
        rec_val = os.path.expanduser(rec_val)

        # mean_rgb = [123.68, 116.779, 103.939]
        # std_rgb = [58.393, 57.12, 57.375]

        train_dataset = ImageRecordDataset(filename=rec_train, flag=1)
        val_dataset = ImageRecordDataset(filename=rec_val, flag=1)
        return train_dataset, val_dataset

    def get_data_loader(data_dir):
        train_dataset = ImageNet(data_dir, train=True)
        val_dataset = ImageNet(data_dir, train=False)
        return train_dataset, val_dataset

    def batch_fn(batch, ctx):
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch[1],
                                           ctx_list=ctx,
                                           batch_axis=0)
        return data, label

    if opt.use_rec:
        train_dataset, val_dataset = get_data_rec(opt.rec_train, opt.rec_val)
    else:
        train_dataset, val_dataset = get_data_loader(opt.data_dir)

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    jitter_param = 0.4
    lighting_param = 0.1
    if not opt.multi_scale:
        train_dataset = train_dataset.transform_first(
            transforms.Compose([
                transforms.RandomResizedCrop(opt.input_size),
                transforms.RandomFlipLeftRight(),
                transforms.RandomColorJitter(brightness=jitter_param,
                                             contrast=jitter_param,
                                             saturation=jitter_param),
                transforms.RandomLighting(lighting_param),
                transforms.ToTensor(), normalize
            ]))
        train_data = gluon.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           pin_memory=True,
                                           last_batch='rollover',
                                           num_workers=num_workers)
    else:
        train_data = RandomTransformDataLoader(
            [
                Transform(
                    transforms.Compose([
                        # transforms.RandomResizedCrop(opt.input_size),
                        transforms.RandomResizedCrop(x * 32),
                        transforms.RandomFlipLeftRight(),
                        transforms.RandomColorJitter(brightness=jitter_param,
                                                     contrast=jitter_param,
                                                     saturation=jitter_param),
                        transforms.RandomLighting(lighting_param),
                        transforms.ToTensor(),
                        normalize
                    ])) for x in range(10, 20)
            ],
            train_dataset,
            interval=10 * opt.accumulate,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=True,
            last_batch='rollover',
            num_workers=num_workers)
    val_dataset = val_dataset.transform_first(
        transforms.Compose([
            transforms.Resize(opt.input_size, keep_ratio=True),
            transforms.CenterCrop(opt.input_size),
            transforms.ToTensor(), normalize
        ]))
    val_data = gluon.data.DataLoader(val_dataset,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     pin_memory=True,
                                     last_batch='keep',
                                     num_workers=num_workers)

    if opt.mixup:
        train_metric = mx.metric.RMSE()
    else:
        train_metric = mx.metric.Accuracy()
    train_loss_metric = mx.metric.Loss()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        if opt.wandb:
            save_dir = wandb.run.dir
        else:
            save_dir = opt.save_dir
            makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def mixup_transform(label, classes, lam=1, eta=0.0):
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes,
                           on_value=1 - eta + eta / classes,
                           off_value=eta / classes)
            y2 = l[::-1].one_hot(classes,
                                 on_value=1 - eta + eta / classes,
                                 off_value=eta / classes)
            res.append(lam * y1 + (1 - lam) * y2)
        return res

    def smooth(label, classes, eta=0.1):
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes,
                            on_value=1 - eta + eta / classes,
                            off_value=eta / classes)
            smoothed.append(res)
        return smoothed

    def test(ctx, val_data):
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in tqdm.tqdm(enumerate(val_data),
                                  desc='Validating',
                                  total=num_validating_samples // batch_size):
            data, label = batch_fn(batch, ctx)
            # outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            outputs = [net(X) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return 1 - top1, 1 - top5

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            import warnings
            with warnings.catch_warnings(record=True) as w:
                net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        if accumulate > 1:
            logger.info(f'accumulate: {accumulate}, using "add" grad_req')
            import warnings
            with warnings.catch_warnings(record=True) as w:
                net.collect_params().setattr('grad_req', 'add')

        trainer = gluon.Trainer(net.collect_params(),
                                optimizer,
                                optimizer_params,
                                update_on_kvstore=False if opt.amp else None)
        if opt.amp:
            amp.init_trainer(trainer)
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True
        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(
                temperature=opt.temperature,
                hard_weight=opt.hard_weight,
                sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(
                sparse_label=sparse_label_loss)

        best_val_score = 1

        err_top1_val, err_top5_val = test(ctx, val_data)
        logger.info('initial validation: err-top1=%f err-top5=%f' %
                    (err_top1_val, err_top5_val))

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            train_metric.reset()
            train_loss_metric.reset()
            btic = time.time()
            pbar = tqdm.tqdm(total=num_batches,
                             desc=f'Training [{epoch}]',
                             leave=True)
            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)

                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                        lam = 1
                    data = [lam * X + (1 - lam) * X[::-1] for X in data]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label, classes, lam, eta)

                elif opt.label_smoothing:
                    hard_label = label
                    label = smooth(label, classes)

                if distillation:
                    # teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature) \
                    #                 for X in data]
                    with ag.predict_mode():
                        teacher_prob = [
                            nd.softmax(
                                teacher(
                                    nd.transpose(
                                        nd.image.resize(
                                            nd.transpose(X, (0, 2, 3, 1)),
                                            size=opt.teacher_imgsize),
                                        (0, 3, 1, 2))) / opt.temperature)
                            for X in data
                        ]

                with ag.record():
                    # outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                    outputs = [net(X) for X in data]
                    if distillation:
                        # loss = [L(yhat.astype('float32', copy=False),
                        #           y.astype('float32', copy=False),
                        #           p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
                        # print([outputs, label, teacher_prob])
                        loss = [
                            L(yhat, y, p)
                            for yhat, y, p in zip(outputs, label, teacher_prob)
                        ]
                    else:
                        # loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
                        loss = [L(yhat, y) for yhat, y in zip(outputs, label)]
                    if opt.amp:
                        with amp.scale_loss(loss, trainer) as scaled_loss:
                            ag.backward(scaled_loss)
                    else:
                        ag.backward(loss)
                if accumulate > 1:
                    if (i + 1) % accumulate == 0:
                        trainer.step(batch_size * accumulate)
                        net.collect_params().zero_grad()
                else:
                    trainer.step(batch_size)

                train_loss_metric.update(0, loss)

                if opt.mixup:
                    output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                                      for out in outputs]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, outputs)
                    else:
                        train_metric.update(label, outputs)

                _, loss_score = train_loss_metric.get()
                train_metric_name, train_metric_score = train_metric.get()
                samplers_per_sec = batch_size / (time.time() - btic)
                postfix = f'{samplers_per_sec:.1f} imgs/sec, ' \
                          f'loss: {loss_score:.4f}, ' \
                          f'acc: {train_metric_score * 100:.2f}, ' \
                          f'lr: {trainer.learning_rate:.4e}'
                if opt.multi_scale:
                    postfix += f', size: {data[0].shape[-1]}'
                pbar.set_postfix_str(postfix)
                pbar.update()
                btic = time.time()
                if opt.log_interval and not (i + 1) % opt.log_interval:
                    step = epoch * num_batches + i
                    wandb.log(
                        {
                            'samplers_per_sec': samplers_per_sec,
                            train_metric_name: train_metric_score,
                            'lr': trainer.learning_rate,
                            'loss': loss_score
                        },
                        step=step)
                    logger.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'
                        % (epoch, i, samplers_per_sec, train_metric_name,
                           train_metric_score, trainer.learning_rate))

            pbar.close()
            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i / (time.time() - tic))

            err_top1_val, err_top5_val = test(ctx, val_data)
            wandb.log({
                'err1': err_top1_val,
                'err5': err_top5_val
            },
                      step=epoch * num_batches)

            logger.info('[Epoch %d] training: %s=%f' %
                        (epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' %
                        (epoch, throughput, time.time() - tic))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f' %
                        (epoch, err_top1_val, err_top5_val))

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                net.save_parameters(
                    '%s/%.4f-imagenet-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))
                trainer.save_states(
                    '%s/%.4f-imagenet-%s-%d-best.states' %
                    (save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch +
                                                1) % save_frequency == 0:
                net.save_parameters('%s/imagenet-%s-%d.params' %
                                    (save_dir, model_name, epoch))
                trainer.save_states('%s/imagenet-%s-%d.states' %
                                    (save_dir, model_name, epoch))

        if save_frequency and save_dir:
            net.save_parameters('%s/imagenet-%s-%d.params' %
                                (save_dir, model_name, opt.num_epochs - 1))
            trainer.save_states('%s/imagenet-%s-%d.states' %
                                (save_dir, model_name, opt.num_epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=not opt.multi_scale)
        if distillation:
            teacher.hybridize(static_alloc=True,
                              static_shape=not opt.multi_scale)
    train(context)
Example #11
from mxnet.image import imread

logging.basicConfig(level=logging.WARNING, format='%(asctime)s %(message)s')
import requests
import json
import binascii
import numpy as np
from pymongo import MongoClient
from requests import ReadTimeout
from pprint import pprint

#image transform
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
test_transform = T.Compose(
    [T.Resize(256), T.CenterCrop(224),
     T.ToTensor(), normalize])


# define mongodb connect
def get_db():
    mongdb = {}
    mongdb['host'] = 'dds-bp10da4305cf39f41.mongodb.rds.aliyuncs.com'
    mongdb['port'] = 3717
    client = MongoClient(host=mongdb['host'], port=mongdb['port'])
    dev = client.get_database('dev')
    dev.authenticate(name='nnsearch', password='******')
    return dev


@asyncio.coroutine
def download(url, session, semaphore, chunk_size=1 << 15):
Example #12
def save_and_load_sequential_example():
    # Use GPU if one exists, else use CPU.
    ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()

    # MNIST images are 28x28. Total pixels in input layer is 28x28 = 784.
    #num_inputs = 784
    # Classify the images into one of the 10 digits.
    #num_outputs = 10
    # 64 images in a batch.
    batch_size = 64

    # Load the training data.
    train_data = gluon.data.DataLoader(gluon.data.vision.MNIST(
        train=True).transform_first(transforms.ToTensor()),
                                       batch_size,
                                       shuffle=True)
    valid_data = gluon.data.DataLoader(gluon.data.vision.MNIST(
        train=False).transform_first(transforms.ToTensor()),
                                       batch_size,
                                       shuffle=True)

    # Define a model.
    net = build_lenet(gluon.nn.Sequential())

    # Initialize the parameters with Xavier initializer.
    net.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
    # Use cross entropy loss.
    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
    # Use Adam optimizer.
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': .001})

    train_model(net,
                trainer,
                softmax_cross_entropy,
                train_data,
                valid_data,
                num_epochs=10)

    #--------------------
    if True:
        model_param_filepath = './lenet.params'

        # Save model parameters to file.
        net.save_parameters(model_param_filepath)

        # Define a model.
        new_net = build_lenet(gluon.nn.Sequential())
        # Load model parameters from file.
        new_net.load_parameters(model_param_filepath, ctx=ctx)
    else:
        # NOTE [info] >> Sequential models may not be serialized as JSON files.

        model_filepath = './lenet.json'
        model_param_filepath = './lenet.params'

        # Save model architecture to file.
        sym_json = net(mx.sym.var('data')).tojson()
        sym_json = json.loads(sym_json)
        with open(model_filepath, 'w') as fd:
            json.dump(sym_json, fd, indent='\t')
        # Save model parameters to file.
        net.save_parameters(model_param_filepath)

        # Load model architecture from file.
        with open(model_filepath, 'r') as fd:
            sym_json = json.load(fd)
        sym_json = json.dumps(sym_json)
        new_net = gluon.nn.SymbolBlock(outputs=mx.sym.load_json(sym_json),
                                       inputs=mx.sym.var('data'))
        # Load model parameters from file.
        new_net.load_parameters(model_param_filepath, ctx=ctx)

    verify_loaded_model(new_net, ctx=ctx)
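
# Both branches above build the network with a build_lenet helper that is not
# included in this excerpt. A plausible sketch (the exact layer sizes are
# assumptions) is:
def build_lenet(net):
    # a LeNet-style stack added onto the passed-in (Hybrid)Sequential container
    with net.name_scope():
        net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'),
                gluon.nn.MaxPool2D(pool_size=2, strides=2),
                gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'),
                gluon.nn.MaxPool2D(pool_size=2, strides=2),
                gluon.nn.Flatten(),
                gluon.nn.Dense(512, activation='relu'),
                gluon.nn.Dense(10))
    return net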
Example #13
def train_neural_network_example():
    mnist_train = datasets.FashionMNIST(train=True)
    X, y = mnist_train[0]
    print('X shape:', X.shape, 'X dtype:', X.dtype, 'y:', y)

    text_labels = [
        't-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt',
        'sneaker', 'bag', 'ankle boot'
    ]
    X, y = mnist_train[0:6]

    # Plot images.
    _, figs = plt.subplots(1, X.shape[0], figsize=(15, 15))
    for f, x, yi in zip(figs, X, y):
        # 3D -> 2D by removing the last channel dim.
        f.imshow(x.reshape((28, 28)).asnumpy())
        ax = f.axes
        ax.set_title(text_labels[int(yi)])
        ax.title.set_fontsize(20)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

    transformer = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(0.13, 0.31)])
    mnist_train = mnist_train.transform_first(transformer)

    batch_size = 256
    train_data = gluon.data.DataLoader(mnist_train,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=4)

    for data, label in train_data:
        print(data.shape, label.shape)
        break

    mnist_valid = gluon.data.vision.FashionMNIST(train=False)
    valid_data = gluon.data.DataLoader(
        mnist_valid.transform_first(transformer),
        batch_size=batch_size,
        num_workers=4)

    # Define a model.
    net = build_network(nn.Sequential())
    net.initialize(init=init.Xavier())
    #net.collect_params().initialize(init.Xavier(), ctx=ctx)

    # Define the loss function and optimization method for training.
    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': 0.1})
    #trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': .001})

    train_model(net,
                trainer,
                softmax_cross_entropy,
                train_data,
                valid_data,
                num_epochs=10)

    # Save the model parameters.
    net.save_parameters('./net.params')
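
# Both this example and the previous one call a train_model helper that is not
# shown. A minimal sketch of what such a loop might look like (context handling
# and metrics are simplified, and everything beyond the argument names is an
# assumption):
from mxnet import autograd

def train_model(net, trainer, loss_fn, train_data, valid_data, num_epochs=10):
    for epoch in range(num_epochs):
        train_loss = 0.0
        for data, label in train_data:
            with autograd.record():
                output = net(data)
                loss = loss_fn(output, label)
            loss.backward()
            trainer.step(data.shape[0])
            train_loss += loss.mean().asscalar()
        print('epoch %d, avg loss %.3f' % (epoch + 1, train_loss / len(train_data)))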
Example #14
 def __init__(self, args):
     self.args = args
     # image transform
     input_transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
     ])
     # dataset and dataloader
     trainset = get_mxnet_dataset(args.dataset,
                                  split='train',
                                  transform=input_transform)
     valset = get_mxnet_dataset(args.dataset,
                                split='val',
                                transform=input_transform)
     self.train_data = gluon.data.DataLoader(trainset,
                                             args.batch_size,
                                             shuffle=True,
                                             last_batch='rollover',
                                             num_workers=args.workers)
     self.eval_data = gluon.data.DataLoader(valset,
                                            args.test_batch_size,
                                            last_batch='keep',
                                            num_workers=args.workers)
     # create network
     model = get_segmentation_model(model=args.model,
                                    dataset=args.dataset,
                                    backbone=args.backbone,
                                    norm_layer=args.norm_layer,
                                    aux=args.aux)
     print(model)
     self.net = DataParallelModel(model, args.ctx, args.syncbn)
     self.evaluator = DataParallelModel(SegEvalModel(model), args.ctx)
     # resume checkpoint if needed
     if args.resume is not None:
         if os.path.isfile(args.resume):
             model.load_params(args.resume, ctx=args.ctx)
         else:
             raise RuntimeError("=> no checkpoint found at '{}'" \
                 .format(args.resume))
     # create criterion
     criterion = SoftmaxCrossEntropyLossWithAux(args.aux)
     self.criterion = DataParallelCriterion(criterion, args.ctx,
                                            args.syncbn)
     # optimizer and lr scheduling
     self.lr_scheduler = LRScheduler(mode='poly',
                                     baselr=args.lr,
                                     niters=len(self.train_data),
                                     nepochs=args.epochs)
     kv = mx.kv.create(args.kvstore)
     self.optimizer = gluon.Trainer(self.net.module.collect_params(),
                                    'sgd', {
                                        'lr_scheduler': self.lr_scheduler,
                                        'wd': args.weight_decay,
                                        'momentum': args.momentum,
                                        'multi_precision': True
                                    },
                                    kvstore=kv)
     time_str = time.strftime("%Y-%m-%d___%H-%M-%S", time.localtime())
     net_name = args.model
     dataset_name = args.dataset
     note = args.backbone
     log_dir = os.path.join(self.args.logdir, net_name, dataset_name, note,
                            time_str)
     os.makedirs(log_dir, exist_ok=True)
     self.writer = SummaryWriter(log_dir=log_dir)
     config_str = json.dumps(self.args, indent=2, sort_keys=True).replace(
         '\n', '\n\n').replace('  ', '\t')
     self.writer.add_text(tag='config', text_string=config_str)
Example #15
if not os.path.exists(save_dir):
    makedirs(save_dir)
name = opt.name

plot_path = opt.save_plot_dir

logging.basicConfig(level=logging.INFO)
logging.info(opt)

trans, aug_trans = list(), list()
if opt.data_aug:
    aug_trans = [
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight()
    ]
trans.append(transforms.ToTensor())
trans.append(
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]))

transform_train = transforms.Compose(aug_trans + trans)
transform_test = transforms.Compose(trans)
data_sampler = eval('DataSampler%d' % opt.data_sampler)()
num_data_samples = opt.data_k
use_som = opt.use_som
use_pillars = opt.plc_loss
pillar_sampler = eval('PillarSampler%d' % opt.pillar_sampler)()
num_pillar_samples = opt.pillar_k
w1 = opt.w1
w2 = opt.w2

cumulative = opt.cumulative
Example #16
def train(net, ctx):
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    if not opt.resume_from:
        net.initialize(mx.init.Xavier(), ctx=ctx)
        if opt.dataset == 'NC_CIFAR100':
            n = mx.nd.zeros(shape=(1, 3, 32, 32), ctx=ctx[0])  #####init CNN
        else:
            raise KeyError('dataset keyerror')
        for m in range(9):
            net(n, m)

    def makeSchedule(start_lr, base_lr, length, step, factor):
        schedule = mx.lr_scheduler.MultiFactorScheduler(step=step,
                                                        factor=factor)
        schedule.base_lr = base_lr
        schedule = LinearWarmUp(schedule, start_lr=start_lr, length=length)
        return schedule


# ==========================================================================

    sesses = list(np.arange(opt.sess_num))
    epochs = [opt.epoch] * opt.sess_num
    lrs = [opt.base_lrs] + [opt.lrs] * (opt.sess_num - 1)
    lr_decay = opt.lr_decay
    base_decay_epoch = [int(i)
                        for i in opt.base_decay_epoch.split(',')] + [np.inf]
    lr_decay_epoch = [base_decay_epoch
                      ] + [[opt.inc_decay_epoch, np.inf]] * (opt.sess_num - 1)

    AL_weight = opt.AL_weight
    min_weight = opt.min_weight
    oce_weight = opt.oce_weight
    pdl_weight = opt.pdl_weight
    max_weight = opt.max_weight
    temperature = opt.temperature

    use_AL = opt.use_AL  # anchor loss
    use_ng_min = opt.use_ng_min  # Neural Gas min loss
    use_ng_max = opt.use_ng_max  # Neural Gas max loss
    ng_update = opt.ng_update  # Neural Gas update node
    use_oce = opt.use_oce  # old samples cross entropy loss
    use_pdl = opt.use_pdl  # probability distillation loss
    use_nme = opt.use_nme  # Similarity loss
    use_warmUp = opt.use_warmUp
    use_ng = opt.use_ng  # Neural Gas
    fix_conv = opt.fix_conv  # fix cnn to train novel classes
    fix_epoch = opt.fix_epoch
    c_way = opt.c_way
    k_shot = opt.k_shot
    base_acc = opt.base_acc  # base model acc
    select_best_method = opt.select_best  # select from _best, _best2, _best3
    init_class = 60
    anchor_num = 400
    # ==========================================================================
    acc_dict = {}
    all_best_e = []
    if model_name[-7:] != 'maxhead':
        net.fc3.initialize(mx.init.Normal(sigma=0.001),
                           ctx=ctx,
                           force_reinit=True)
        net.fc4.initialize(mx.init.Normal(sigma=0.001),
                           ctx=ctx,
                           force_reinit=True)
        net.fc5.initialize(mx.init.Normal(sigma=0.001),
                           ctx=ctx,
                           force_reinit=True)
        net.fc6.initialize(mx.init.Normal(sigma=0.001),
                           ctx=ctx,
                           force_reinit=True)
        net.fc7.initialize(mx.init.Normal(sigma=0.001),
                           ctx=ctx,
                           force_reinit=True)
        net.fc8.initialize(mx.init.Normal(sigma=0.001),
                           ctx=ctx,
                           force_reinit=True)
        net.fc9.initialize(mx.init.Normal(sigma=0.001),
                           ctx=ctx,
                           force_reinit=True)
        net.fc10.initialize(mx.init.Normal(sigma=0.001),
                            ctx=ctx,
                            force_reinit=True)

    for sess in sesses:
        logger.info('session : %d' % sess)
        schedule = makeSchedule(start_lr=0,
                                base_lr=lrs[sess],
                                length=5,
                                step=lr_decay_epoch[sess],
                                factor=lr_decay)

        # prepare the first anchor batch
        if sess == 0 and opt.resume_from:
            acc_dict[str(sess)] = list()
            acc_dict[str(sess)].append([base_acc, 0])
            all_best_e.append(0)
            continue
        # quick_cnn: fully unfreeze the CNN and skip data augmentation
        if sess == 1 and model_name == 'quick_cnn' and use_AL:
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.4914, 0.4822, 0.4465],
                                     [0.2023, 0.1994, 0.2010])
            ])

            transform_test = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.4914, 0.4822, 0.4465],
                                     [0.2023, 0.1994, 0.2010])
            ])
            anchor_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.4914, 0.4822, 0.4465],
                                     [0.2023, 0.1994, 0.2010])
            ])
        else:
            transform_train = transforms.Compose([
                gcv_transforms.RandomCrop(32, pad=4),
                transforms.RandomFlipLeftRight(),
                transforms.ToTensor(),
                transforms.Normalize([0.5071, 0.4866, 0.4409],
                                     [0.2009, 0.1984, 0.2023])
            ])

            transform_test = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.5071, 0.4866, 0.4409],
                                     [0.2009, 0.1984, 0.2023])
            ])

            anchor_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.5071, 0.4866, 0.4409],
                                     [0.2009, 0.1984, 0.2023])
            ])
        # ng_init and ng_update
        if use_AL or use_nme or use_pdl or use_oce:
            if sess != 0:
                if ng_update == True:
                    if sess == 1:
                        update_anchor1, bmu, variances= \
                            prepare_anchor(DATASET,logger,anchor_trans,num_workers,feature_size,net,ctx,use_ng,init_class)
                        update_anchor_data = DataLoader(
                            update_anchor1,
                            anchor_trans,
                            update_anchor1.__len__(),
                            num_workers,
                            shuffle=False)
                        if opt.ng_var:
                            idx_1 = np.where(variances.asnumpy() > 0.5)
                            idx_2 = np.where(variances.asnumpy() < 0.5)
                            variances[idx_1] = 0.9
                            variances[idx_2] = 1
                    else:
                        base_class = init_class + (sess - 1) * 5
                        new_class = list(init_class + (sess - 1) * 5 +
                                         (np.arange(5)))
                        new_set = DATASET(train=True,
                                          fine_label=True,
                                          fix_class=new_class,
                                          base_class=base_class,
                                          logger=logger)
                        update_anchor2 = merge_datasets(
                            update_anchor1, new_set)
                        update_anchor_data = DataLoader(
                            update_anchor2,
                            anchor_trans,
                            update_anchor2.__len__(),
                            num_workers,
                            shuffle=False)
                elif (sess == 1):
                    update_anchor, bmu, variances =  \
                        prepare_anchor(DATASET,logger,anchor_trans,num_workers,feature_size,net,ctx,use_ng,init_class)
                    update_anchor_data = DataLoader(update_anchor,
                                                    anchor_trans,
                                                    update_anchor.__len__(),
                                                    num_workers,
                                                    shuffle=False)

                    if opt.ng_var:
                        idx_1 = np.where(variances.asnumpy() > 0.5)
                        idx_2 = np.where(variances.asnumpy() < 0.5)
                        variances[idx_1] = 0.9
                        variances[idx_2] = 1

                for batch in update_anchor_data:
                    anc_data = gluon.utils.split_and_load(batch[0],
                                                          ctx_list=[ctx[0]],
                                                          batch_axis=0)
                    anc_label = gluon.utils.split_and_load(batch[1],
                                                           ctx_list=[ctx[0]],
                                                           batch_axis=0)
                    with ag.pause():
                        anchor_feat, anchor_logit = net(anc_data[0], sess - 1)
                    anchor_feat = [anchor_feat]
                    anchor_logit = [anchor_logit]

        trainer = gluon.Trainer(net.collect_params(), optimizer, {
            'learning_rate': lrs[sess],
            'wd': opt.wd,
            'momentum': opt.momentum
        })

        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        # ==========================================================================
        # all loss init
        if use_nme:

            def loss_fn_disG(f1, f2, weight):
                f1 = f1.reshape(anchor_num, -1)
                f2 = f2.reshape(anchor_num, -1)
                similar = mx.nd.sum(f1 * f2, 1)
                return (1 - similar) * weight

            digG_weight = opt.nme_weight
        if use_AL:
            if model_name == 'quick_cnn':
                AL_w = [120, 75, 120, 100, 50, 60, 90, 90]
                AL_weight = AL_w[sess - 1]
            else:
                AL_weight = opt.AL_weight
            if opt.ng_var:

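                # Variance-weighted anchor loss: the squared feature differences
                # are scaled by the variance weights prepared above (0.9 or 1)
                # before being averaged per sample.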
                def l2lossVar(feat, anc, weight, var):
                    dim = feat.shape[1]
                    feat = feat.reshape(-1, dim)
                    anc = anc.reshape(-1, dim)
                    loss = mx.nd.square(feat - anc)
                    loss = loss * weight * var
                    return mx.nd.mean(loss, axis=0, exclude=True)

                loss_fn_AL = l2lossVar
            else:
                loss_fn_AL = gluon.loss.L2Loss(weight=AL_weight)

        if use_pdl:
            loss_fn_pdl = DistillationSoftmaxCrossEntropyLoss(
                temperature=temperature, hard_weight=0, weight=pdl_weight)
        if use_oce:
            loss_fn_oce = gluon.loss.SoftmaxCrossEntropyLoss(weight=oce_weight)
        if use_ng_max:
            loss_fn_max = NG_Max_Loss(lmbd=max_weight, margin=0.5)
        if use_ng_min:
            min_loss = NG_Min_Loss(
                num_classes=opt.c_way,
                feature_size=feature_size,
                lmbd=min_weight,  # center weight = 0.01 in the paper
                ctx=ctx[0])
            min_loss.initialize(mx.init.Xavier(magnitude=2.24),
                                ctx=ctx,
                                force_reinit=True)  # init matrix.
            center_trainer = gluon.Trainer(
                min_loss.collect_params(),
                optimizer="sgd",
                optimizer_params={"learning_rate":
                                  opt.ng_min_lr})  # alpha=0.1 in the paper.
        # ==========================================================================
        lr_decay_count = 0

        # dataloader
        if opt.cum and sess == 1:
            base_class = list(np.arange(init_class))
            joint_data = DATASET(train=True,
                                 fine_label=True,
                                 c_way=init_class,
                                 k_shot=500,
                                 fix_class=base_class,
                                 logger=logger)

        if sess == 0:
            base_class = list(np.arange(init_class))
            new_class = list(init_class + (np.arange(5)))
            base_data = DATASET(train=True,
                                fine_label=True,
                                c_way=init_class,
                                k_shot=500,
                                fix_class=base_class,
                                logger=logger)
            bc_val_data = DataLoader(DATASET(train=False,
                                             fine_label=True,
                                             fix_class=base_class,
                                             logger=logger),
                                     transform_test,
                                     100,
                                     num_workers,
                                     shuffle=False)
            nc_val_data = DataLoader(DATASET(train=False,
                                             fine_label=True,
                                             fix_class=new_class,
                                             base_class=len(base_class),
                                             logger=logger),
                                     transform_test,
                                     100,
                                     num_workers,
                                     shuffle=False)
        else:
            base_class = list(np.arange(init_class + (sess - 1) * 5))
            new_class = list(init_class + (sess - 1) * 5 + (np.arange(5)))
            train_data_nc = DATASET(train=True,
                                    fine_label=True,
                                    c_way=c_way,
                                    k_shot=k_shot,
                                    fix_class=new_class,
                                    base_class=len(base_class),
                                    logger=logger)
            bc_val_data = DataLoader(DATASET(train=False,
                                             fine_label=True,
                                             fix_class=base_class,
                                             logger=logger),
                                     transform_test,
                                     100,
                                     num_workers,
                                     shuffle=False)
            nc_val_data = DataLoader(DATASET(train=False,
                                             fine_label=True,
                                             fix_class=new_class,
                                             base_class=len(base_class),
                                             logger=logger),
                                     transform_test,
                                     100,
                                     num_workers,
                                     shuffle=False)

        if sess == 0:
            train_data = DataLoader(base_data,
                                    transform_train,
                                    min(batch_size, base_data.__len__()),
                                    num_workers,
                                    shuffle=True)
        else:
            if opt.cum:  # cumulative : merge base and novel dataset.
                joint_data = merge_datasets(joint_data, train_data_nc)
                train_data = DataLoader(joint_data,
                                        transform_train,
                                        min(batch_size, joint_data.__len__()),
                                        num_workers,
                                        shuffle=True)

            elif opt.use_all_novel:  # use all novel data

                if sess == 1:
                    novel_data = train_data_nc
                else:
                    novel_data = merge_datasets(novel_data, train_data_nc)

                train_data = DataLoader(novel_data,
                                        transform_train,
                                        min(batch_size, novel_data.__len__()),
                                        num_workers,
                                        shuffle=True)
            else:  # basic method
                train_data = DataLoader(train_data_nc,
                                        transform_train,
                                        min(batch_size,
                                            train_data_nc.__len__()),
                                        num_workers,
                                        shuffle=True)

        for epoch in range(epochs[sess]):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss, train_anchor_loss, train_oce_loss = 0, 0, 0
            train_disg_loss, train_pdl_loss, train_min_loss = 0, 0, 0
            train_max_loss = 0
            num_batch = len(train_data)

            if use_warmUp:
                lr = schedule(epoch)
                trainer.set_learning_rate(lr)
            else:
                lr = trainer.learning_rate
                if epoch == lr_decay_epoch[sess][lr_decay_count]:
                    trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                    lr_decay_count += 1

            if sess != 0 and epoch < fix_epoch:
                fix_cnn = fix_conv
            else:
                fix_cnn = False

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)
                all_loss = list()
                with ag.record():
                    output_feat, output = net(data[0], sess, fix_cnn)
                    output_feat = [output_feat]
                    output = [output]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                    all_loss.extend(loss)

                    if use_nme:
                        anchor_h = [net(X, sess, fix_cnn)[0] for X in anc_data]
                        disg_loss = [
                            loss_fn_disG(a_h, a, weight=digG_weight)
                            for a_h, a in zip(anchor_h, anchor_feat)
                        ]
                        all_loss.extend(disg_loss)

                    if sess > 0 and use_ng_max:
                        max_loss = [
                            loss_fn_max(feat, label, feature_size, epoch, sess,
                                        init_class)
                            for feat, label in zip(output_feat, label)
                        ]
                        all_loss.extend(max_loss[0])

                    if sess > 0 and use_AL:  # For anchor loss
                        anchor_h = [net(X, sess, fix_cnn)[0] for X in anc_data]
                        if opt.ng_var:
                            anchor_loss = [
                                loss_fn_AL(anchor_h[0], anchor_feat[0],
                                           AL_weight, variances)
                            ]
                            all_loss.extend(anchor_loss)
                        else:
                            anchor_loss = [
                                loss_fn_AL(a_h, a)
                                for a_h, a in zip(anchor_h, anchor_feat)
                            ]
                            all_loss.extend(anchor_loss)
                    if sess > 0 and use_ng_min:
                        loss_min = min_loss(output_feat[0], label[0])
                        all_loss.extend(loss_min)

                    if sess > 0 and use_pdl:
                        anchor_l = [net(X, sess, fix_cnn)[1] for X in anc_data]
                        anchor_l = [anchor_l[0][:, :60 + (sess - 1) * 5]]
                        soft_label = [
                            mx.nd.softmax(anchor_logit[0][:, :60 +
                                                          (sess - 1) * 5] /
                                          temperature)
                        ]
                        pdl_loss = [
                            loss_fn_pdl(a_h, a,
                                        soft_a) for a_h, a, soft_a in zip(
                                            anchor_l, anc_label, soft_label)
                        ]
                        all_loss.extend(pdl_loss)

                    if sess > 0 and use_oce:
                        anchorp = [net(X, sess, fix_cnn)[1] for X in anc_data]
                        oce_Loss = [
                            loss_fn_oce(ap, a)
                            for ap, a in zip(anchorp, anc_label)
                        ]
                        all_loss.extend(oce_Loss)

                    all_loss = [nd.mean(l) for l in all_loss]

                ag.backward(all_loss)
                trainer.step(1, ignore_stale_grad=True)
                if use_ng_min:
                    center_trainer.step(opt.c_way * opt.k_shot)
                train_loss += sum([l.sum().asscalar() for l in loss])
                if sess > 0 and use_AL:
                    train_anchor_loss += sum(
                        [al.mean().asscalar() for al in anchor_loss])
                if sess > 0 and use_oce:
                    train_oce_loss += sum(
                        [al.mean().asscalar() for al in oce_Loss])
                if sess > 0 and use_nme:
                    train_disg_loss += sum(
                        [al.mean().asscalar() for al in disg_loss])
                if sess > 0 and use_pdl:
                    train_pdl_loss += sum(
                        [al.mean().asscalar() for al in pdl_loss])
                if sess > 0 and use_ng_min:
                    train_min_loss += sum(
                        [al.mean().asscalar() for al in loss_min])
                if sess > 0 and use_ng_max:
                    train_max_loss += sum(
                        [al.mean().asscalar() for al in max_loss[0]])
                train_metric.update(label, output)

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, bc_val_acc = test(ctx, bc_val_data, net, sess)
            name, nc_val_acc = test(ctx, nc_val_data, net, sess)

            if epoch == 0:
                acc_dict[str(sess)] = list()
            acc_dict[str(sess)].append([bc_val_acc, nc_val_acc])

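            # Overall accuracy: average of base and novel accuracy weighted by
            # the number of classes each split contributes in this session.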
            if sess == 0:
                overall = bc_val_acc
            else:
                overall = (bc_val_acc * (init_class + (sess - 1) * 5) +
                           nc_val_acc * 5) / (init_class + sess * 5)
            logger.info(
                '[Epoch %d] lr=%.4f train=%.4f | val(base)=%.4f val(novel)=%.4f overall=%.4f | loss=%.8f anc loss=%.8f '
                'pdl loss:%.8f oce loss: %.8f time: %.8f' %
                (epoch, lr, acc, bc_val_acc, nc_val_acc, overall, train_loss,
                 train_anchor_loss / AL_weight, train_pdl_loss / pdl_weight,
                 train_oce_loss / oce_weight, time.time() - tic))
            if use_nme:
                logger.info('digG loss:%.8f' % (train_disg_loss / digG_weight))
            if use_ng_min:
                logger.info('min_loss:%.8f' % (train_min_loss / min_weight))
            if use_ng_max:
                logger.info('max_loss:%.8f' % (train_max_loss / max_weight))
            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/sess-%s-%d.params' %
                                    (save_dir, model_name, epoch))

            select = eval(select_best_method)
            best_e = select(acc_dict, sess)
            logger.info('best select : base: %f novel: %f ' %
                        (acc_dict[str(sess)][best_e][0],
                         acc_dict[str(sess)][best_e][1]))
            if use_AL and model_name == 'quick_cnn':
                reload_path = '%s/sess-%s-%d.params' % (save_dir, model_name,
                                                        best_e)
                net.load_parameters(reload_path, ctx=context)
        all_best_e.append(best_e)
        reload_path = '%s/sess-%s-%d.params' % (save_dir, model_name, best_e)
        net.load_parameters(reload_path, ctx=context)

        with open('%s/acc_dict.json' % save_dir, 'w') as json_file:
            json.dump(acc_dict, json_file)

        plot_pr(acc_dict, sess, save_dir)
    plot_all_sess(acc_dict, save_dir, all_best_e)
Example #17
def main(im_dir, ckpt_path, ctx, out_dir='result', write_images=True):
    east_model = east.EAST(nclass=2, text_scale=1024)
    ctx = mx.cpu() if ctx < 0 else mx.gpu()
    # east_model.hybridize()
    east_model.load_parameters(ckpt_path, ctx)
    east_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225])
    ])

    imlist = glob.glob1(im_dir, '*g')
    for i, im_name in enumerate(imlist):
        im_path = os.path.join(im_dir, im_name)
        start_time = time.time()
        im = cv2.imread(im_path)
        im_resized, (ratio_h, ratio_w) = resize_image(im, max_side_len=2048)
        timer = {'net': 0, 'restore': 0, 'nms': 0}
        start = time.time()
        im_resized = east_transform(
            mx.nd.array(np.array(im_resized).astype('float32')))

        f_score, f_geometry = east_model.forward(
            im_resized.expand_dims(axis=0))
        timer['net'] = time.time() - start

        score_map = f_score.asnumpy().transpose((0, 2, 3, 1))
        cv2.imwrite("score_map{}.png".format(i), score_map[0, :, :, 0] * 255)
        geo_map = f_geometry.asnumpy().transpose((0, 2, 3, 1))
        boxes, timer = detect(score_map, geo_map, timer)
        print('{} : net {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms'.format(
            im_name, timer['net'] * 1000, timer['restore'] * 1000,
            timer['nms'] * 1000))

        if boxes is not None:
            boxes = boxes[:, :8].reshape((-1, 4, 2))
            boxes[:, :, 0] /= ratio_h
            boxes[:, :, 1] /= ratio_w

        duration = time.time() - start_time
        print('[timing] {}'.format(duration))

        # save to file
        if boxes is not None:
            res_file = os.path.join(
                out_dir,
                '{}.txt'.format(os.path.basename(im_path).split('.')[0]))
            print("num_Boxes:{}".format(len(boxes)))
            with open(res_file, 'w') as f:
                for box in boxes:
                    # to avoid submitting errors
                    box = sort_poly(box.astype(np.int32))
                    # if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
                    #     continue
                    f.write('{},{},{},{},{},{},{},{}\r\n'.format(
                        box[0, 0],
                        box[0, 1],
                        box[1, 0],
                        box[1, 1],
                        box[2, 0],
                        box[2, 1],
                        box[3, 0],
                        box[3, 1],
                    ))
                    cv2.polylines(im[:, :, ::-1].astype(np.uint8),
                                  [box.astype(np.int32).reshape((-1, 1, 2))],
                                  True,
                                  color=(255, 255, 0),
                                  thickness=2)
                if write_images:
                    img_path = os.path.join(out_dir, im_name)
                    cv2.imwrite(img_path, im)
Example #18
def main():
    mnist_train = datasets.FashionMNIST(train=True)
    X, y = mnist_train[0]
    #print('X ',X)
    #print('y ', y)

    print('X shape: ', X.shape, 'X dtype ', X.dtype, 'y:', y)

    text_labels = [
        't-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt',
        'sneaker', 'bag', 'ankle boot'
    ]
    X, y = mnist_train[0:10]
    # plot images
    #  display.set_matplotlib_formats('svg')
    #  _, figs = plt.subplots(1, X.shape[0], figsize=(15, 15))
    #  for f,x,yi in zip(figs, X,y):
    #    # 3D->2D by removing the last channel dim
    #    f.imshow(x.reshape((28,28)).asnumpy())
    #    ax = f.axes
    #    ax.set_title(text_labels[int(yi)])
    #    ax.title.set_fontsize(14)
    #    ax.get_xaxis().set_visible(False)
    #    ax.get_yaxis().set_visible(False)
    #  plt.show()
    #
    transformer = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(0.13, 0.31)])

    mnist_train = mnist_train.transform_first(transformer)

    batch_size = 64

    train_data = gluon.data.DataLoader(mnist_train,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=4)

    mnist_valid = gluon.data.vision.FashionMNIST(train=False)
    valid_data = gluon.data.DataLoader(
        mnist_valid.transform_first(transformer),
        batch_size=batch_size,
        num_workers=4)

    net = nn.Sequential()

    net.add(nn.Conv2D(channels=6, kernel_size=5, activation='relu'),
            nn.MaxPool2D(pool_size=2, strides=2),
            nn.Conv2D(channels=16, kernel_size=3, activation='relu'),
            nn.MaxPool2D(pool_size=2, strides=2), nn.Flatten(),
            nn.Dense(120, activation='relu'), nn.Dense(84, activation='relu'),
            nn.Dense(10))
    net.initialize(init=init.Xavier())

    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

    trainer_opt = gluon.Trainer(net.collect_params(), 'sgd',
                                {'learning_rate': 0.1})

    for epoch in range(10):
        train_loss, train_acc, valid_acc = 0., 0., 0.
        tic = time.time()

        for data, label in train_data:
            # forward and backward
            with autograd.record():
                output = net(data)
                loss = softmax_cross_entropy(output, label)
            loss.backward()

            # update parameters
            trainer_opt.step(batch_size)

            # calculate training metrics
            train_loss += loss.mean().asscalar()
            train_acc += acc(output, label)

        for data, label in valid_data:
            valid_acc += acc(net(data), label)

        print(
            "epoch %d: loss %.3f, train acc %.3f, test acc %.3f, in %.1f sec" %
            (epoch, train_loss / len(train_data), train_acc / len(train_data),
             valid_acc / len(valid_data), time.time() - tic))

        net.save_parameters('net.params')
Example #19
def test(model):
    if args.dataset == 'cifar10':
        transform_train = transforms.Compose([
            gcv_transforms.RandomCrop(32, pad=4),
            transforms.RandomFlipLeftRight(),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465],
                                 [0.2023, 0.1994, 0.2010])
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465],
                                 [0.2023, 0.1994, 0.2010])
        ])
        train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
            train=True).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=args.num_workers)

        val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
            train=False).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=args.num_workers)
    elif args.dataset == 'cifar100':
        transform_train = transforms.Compose([
            gcv_transforms.RandomCrop(32, pad=4),
            transforms.RandomFlipLeftRight(),
            transforms.ToTensor(),
            transforms.Normalize([0.5071, 0.4865, 0.4409],
                                 [0.2673, 0.2564, 0.2762])
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5071, 0.4865, 0.4409],
                                 [0.2673, 0.2564, 0.2762])
        ])
        train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=True).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)

        val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=False).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

    metric = mx.metric.Accuracy()
    metric.reset()
    for i, batch in enumerate(val_data):
        data = gluon.utils.split_and_load(batch[0],
                                          ctx_list=context,
                                          batch_axis=0)
        label = gluon.utils.split_and_load(batch[1],
                                           ctx_list=context,
                                           batch_axis=0)
        outputs = [model(X.astype(args.dtype, copy=False)) for X in data]
        metric.update(label, outputs)
        _, acc = metric.get()
    print('\nTest-set Accuracy: ', acc)
    return acc
Example #20
def get_built_in_dataset(name,
                         train=True,
                         input_size=224,
                         batch_size=256,
                         num_workers=32,
                         shuffle=True,
                         fine_label=False,
                         **kwargs):
    """Returns built-in popular image classification dataset based on provided string name ('cifar10', 'cifar100','mnist','imagenet').
    """
    logger.info(f'get_built_in_dataset {name}')
    name = name.lower()
    if name in ('cifar10', 'cifar'):
        import gluoncv.data.transforms as gcv_transforms
        if train:
            transform_split = transforms.Compose([
                gcv_transforms.RandomCrop(32, pad=4),
                transforms.RandomFlipLeftRight(),
                transforms.ToTensor(),
                transforms.Normalize([0.4914, 0.4822, 0.4465],
                                     [0.2023, 0.1994, 0.2010])
            ])
        else:
            transform_split = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.4914, 0.4822, 0.4465],
                                     [0.2023, 0.1994, 0.2010])
            ])
        return gluon.data.vision.CIFAR10(
            train=train).transform_first(transform_split)
    elif name == 'cifar100':
        import gluoncv.data.transforms as gcv_transforms
        if train:
            transform_split = transforms.Compose([
                gcv_transforms.RandomCrop(32, pad=4),
                transforms.RandomFlipLeftRight(),
                transforms.ToTensor(),
                transforms.Normalize([0.4914, 0.4822, 0.4465],
                                     [0.2023, 0.1994, 0.2010])
            ])
        else:
            transform_split = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.4914, 0.4822, 0.4465],
                                     [0.2023, 0.1994, 0.2010])
            ])
        return gluon.data.vision.CIFAR100(
            train=train,
            fine_label=fine_label).transform_first(transform_split)
    elif name == 'mnist':

        def transform(data, label):
            return nd.transpose(data.astype(np.float32),
                                (2, 0, 1)) / 255, label.astype(np.float32)

        return gluon.data.vision.MNIST(train=train, transform=transform)
    elif name == 'fashionmnist':

        def transform(data, label):
            return nd.transpose(data.astype(np.float32),
                                (2, 0, 1)) / 255, label.astype(np.float32)

        return gluon.data.vision.FashionMNIST(train=train, transform=transform)
    elif name == 'imagenet':
        print("Loading the imagenet from ~/.mxnet/imagenet/")
        # Please setup the ImageNet dataset following the tutorial from GluonCV
        if train:
            rec_file = '~/.mxnet/imagenet/train.rec'
            rec_file_idx = '~/.mxnet/imagenet/train.idx'
        else:
            rec_file = '~/.mxnet/imagenet/val.rec'
            rec_file_idx = '~/.mxnet/imagenet/val.idx'
        data_loader = get_data_rec(input_size,
                                   0.875,
                                   rec_file,
                                   rec_file_idx,
                                   batch_size,
                                   num_workers,
                                   train,
                                   shuffle=shuffle,
                                   **kwargs)
        return data_loader
    else:
        raise NotImplementedError
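
# Hypothetical usage sketch (not part of the original snippet): assuming this
# module is imported with MXNet/Gluon available, the returned dataset plugs
# straight into a standard Gluon DataLoader.
if __name__ == '__main__':
    cifar_train = get_built_in_dataset('cifar10', train=True)
    loader = gluon.data.DataLoader(cifar_train, batch_size=256,
                                   shuffle=True, num_workers=4)
    for data, label in loader:
        print(data.shape, label.shape)  # expect (256, 3, 32, 32) and (256,)
        break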
Example #21
def train():
    if config.restart_training:
        shutil.rmtree(config.output_dir, ignore_errors=True)
    if config.output_dir is None:
        config.output_dir = 'output'
    if not os.path.exists(config.output_dir):
        os.makedirs(config.output_dir)
    logger = setup_logger(os.path.join(config.output_dir, 'train_log'))
    logger.info('train with gpu %s and mxnet %s' %
                (config.gpu_id, mx.__version__))

    ctx = mx.gpu(config.gpu_id)
    # Set the random seed
    mx.random.seed(2)
    mx.random.seed(2, ctx=ctx)

    train_transforms = transforms.Compose(
        [transforms.RandomBrightness(0.5),
         transforms.ToTensor()])
    train_dataset = ImageDataset(config.trainfile,
                                 (config.img_h, config.img_w),
                                 3,
                                 80,
                                 config.alphabet,
                                 phase='train')
    train_data_loader = DataLoader(
        train_dataset.transform_first(train_transforms),
        config.train_batch_size,
        shuffle=True,
        last_batch='keep',
        num_workers=config.workers)
    test_dataset = ImageDataset(config.testfile, (config.img_h, config.img_w),
                                3,
                                80,
                                config.alphabet,
                                phase='test')
    test_data_loader = DataLoader(test_dataset.transform_first(
        transforms.ToTensor()),
                                  config.eval_batch_size,
                                  shuffle=True,
                                  last_batch='keep',
                                  num_workers=config.workers)
    net = CRNN(len(config.alphabet), hidden_size=config.nh)
    net.hybridize()
    if not config.restart_training and config.checkpoint != '':
        logger.info('load pretrained net from {}'.format(config.checkpoint))
        net.load_parameters(config.checkpoint, ctx=ctx)
    else:
        net.initialize(ctx=ctx)

    criterion = gluon.loss.CTCLoss()

    all_step = len(train_data_loader)
    logger.info('each epoch contains {} steps'.format(all_step))
    schedule = mx.lr_scheduler.FactorScheduler(step=config.lr_decay_step *
                                               all_step,
                                               factor=config.lr_decay,
                                               stop_factor_lr=config.end_lr)
    # schedule = mx.lr_scheduler.MultiFactorScheduler(step=[15 * all_step, 30 * all_step, 60 * all_step,80 * all_step],
    #                                                 factor=0.1)
    adam_optimizer = mx.optimizer.Adam(learning_rate=config.lr,
                                       lr_scheduler=schedule)
    trainer = gluon.Trainer(net.collect_params(), optimizer=adam_optimizer)

    sw = SummaryWriter(logdir=config.output_dir)
    for epoch in range(config.start_epoch, config.end_epoch):
        loss = .0
        train_acc = .0
        tick = time.time()
        cur_step = 0
        for i, (data, label) in enumerate(train_data_loader):
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)

            with autograd.record():
                output = net(data)
                loss_ctc = criterion(output, label)
            loss_ctc.backward()
            trainer.step(data.shape[0])

            loss_c = loss_ctc.mean()
            cur_step = epoch * all_step + i
            sw.add_scalar(tag='ctc_loss',
                          value=loss_c.asscalar(),
                          global_step=cur_step // 2)
            sw.add_scalar(tag='lr',
                          value=trainer.learning_rate,
                          global_step=cur_step // 2)
            loss += loss_c
            acc = accuracy(output, label, config.alphabet)
            train_acc += acc
            if (i + 1) % config.display_interval == 0:
                acc /= len(label)
                sw.add_scalar(tag='train_acc', value=acc, global_step=cur_step)
                batch_time = time.time() - tick
                logger.info(
                    '[{}/{}], [{}/{}],step: {}, Speed: {:.3f} samples/sec, ctc loss: {:.4f},acc: {:.4f}, lr:{},'
                    ' time:{:.4f} s'.format(
                        epoch, config.end_epoch, i, all_step, cur_step,
                        config.display_interval * config.train_batch_size /
                        batch_time,
                        loss.asscalar() / config.display_interval, acc,
                        trainer.learning_rate, batch_time))
                loss = .0
                tick = time.time()
                nd.waitall()
        if epoch == 0:
            sw.add_graph(net)
        logger.info('start val ....')
        train_acc /= train_dataset.__len__()
        validation_accuracy = evaluate_accuracy(
            net, test_data_loader, ctx,
            config.alphabet) / test_dataset.__len__()
        sw.add_scalar(tag='val_acc',
                      value=validation_accuracy,
                      global_step=cur_step)
        logger.info("Epoch {},train_acc {:.4f}, val_acc {:.4f}".format(
            epoch, train_acc, validation_accuracy))
        net.save_parameters("{}/{}_{:.4f}_{:.4f}.params".format(
            config.output_dir, epoch, train_acc, validation_accuracy))
    sw.close()
Example #22
def train_mnist():
    mnist_train = get_train_data()
    print("type(mnist_train)=", type(mnist_train))

    transformer = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.13, 0.31)])

    mnist_train = mnist_train.transform_first(transformer)
    batch_size = 256

    train_data = gluon.data.DataLoader(
        mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)

    mnist_valid = gluon.data.vision.FashionMNIST(train=False)
    valid_data = gluon.data.DataLoader(
        mnist_valid.transform_first(transformer),
        batch_size=batch_size, num_workers=4)

    # for data, label in train_data:
    #     print(data.shape, label.shape)
    #     break

    net = nn.Sequential()
    net.add(nn.Conv2D(channels=6, kernel_size=5, activation='relu'),
            nn.MaxPool2D(pool_size=2, strides=2),
            nn.Conv2D(channels=16, kernel_size=3, activation='relu'),
            nn.MaxPool2D(pool_size=2, strides=2),
            nn.Flatten(),
            nn.Dense(120, activation="relu"),
            nn.Dense(84, activation="relu"),
            nn.Dense(10))
    net.initialize(init=init.Xavier())

    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})

    for epoch in range(10):
        train_loss, train_acc, valid_acc = 0., 0., 0.
        tic = time.time()
        for data, label in train_data:
            # forward + backward
            with autograd.record():
                output = net(data)
                loss = softmax_cross_entropy(output, label)
            loss.backward()
            # update parameters
            trainer.step(batch_size)
            # calculate training metrics
            train_loss += loss.mean().asscalar()
            train_acc += acc(output, label)

        # calculate validation accuracy
        for data, label in valid_data:
            valid_acc += acc(net(data), label)
        print("Epoch %d: loss %.3f, train acc %.3f, test acc %.3f, in %.1f sec" % (
            epoch, train_loss / len(train_data), train_acc / len(train_data),
            valid_acc / len(valid_data), time.time() - tic))

        net.save_parameters('./output_weight/net_%s.params'%(epoch))
Example #23
    def __call__(self, x=None):
        x = T.Resize((self.w, self.h))(x)
        x = T.ToTensor()(x)
        x = T.Normalize(mean=(0.485, 0.456, 0.406),
                        std=(0.229, 0.224, 0.225))(x)
        return x
Example #24
import numpy as np
import mxnet as mx
from mxnet import nd
import time
from mxnet.gluon.data.vision import transforms

gpu_total_time = 0.0
transformer = transforms.ToTensor()
count = 1000

for i in range(count):
    image = mx.nd.random.uniform(0, 255, (1, 512, 512, 3),
                                 ctx=mx.gpu(0)).astype(np.uint8)
    tic = time.time()
    #image.as_in_context(mx.gpu(0))
    res = transformer(image)
    # To force the calculation
    #a = res.shape
    res.wait_to_read()
    tac = time.time()
    gpu_total_time += (tac - tic) * 1000

print("Total time for GPU ToTensor - ", gpu_total_time)
print("Average time per ToTensor 1,512,512,3 - ", gpu_total_time / count)

cpu_total_time = 0.0
transformer = transforms.ToTensor()
count = 1000

for i in range(count):
    image = mx.nd.random.uniform(0, 255, (1, 512, 512, 3)).astype(np.uint8)
    tic = time.time()
    res = transformer(image)
    # To force the calculation
    res.wait_to_read()
    tac = time.time()
    cpu_total_time += (tac - tic) * 1000

print("Total time for CPU ToTensor - ", cpu_total_time)
print("Average time per ToTensor 1,512,512,3 - ", cpu_total_time / count)
Example #25
from mxnet.gluon.data import DataLoader, Dataset
from mxnet import nd
from mxnet.image import imread
import os
import numpy as np
import mxnet as mx
import mxnet.gluon.data.vision.transforms as T
from collections import Counter

normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
default_transform = T.Compose([
    T.Resize(256),
    T.RandomResizedCrop(size=224, scale=(1.0, 1.0),
                        ratio=(1.0, 1.0)),  # crop only, no rescaling
    T.RandomFlipLeftRight(),
    T.ToTensor(),  # applied last, converts HWC images to CHW tensors
    normalize
])

test_transform = T.Compose(
    [T.Resize(256), T.CenterCrop(224),
     T.ToTensor(), normalize])
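
# Hypothetical sanity check (not from the original source): pushing a dummy
# HxWxC uint8 image through the pipelines shows the channel swap and the
# 224x224 crop mentioned in the comments above.
if __name__ == '__main__':
    dummy = mx.nd.random.uniform(0, 255, shape=(300, 300, 3)).astype('uint8')
    print(default_transform(dummy).shape)  # (3, 224, 224)
    print(test_transform(dummy).shape)     # (3, 224, 224)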


# like the
class DeepInClassFashion(Dataset):
    """
    The DeepInClassFashion dataset. Reads data from list_item_inshop.txt.

    """
    def __init__(self,
Example #26
    def __init__(self, args, logger):
        self.args = args
        self.logger = logger

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split=args.train_split,
                                            mode='train',
                                            **data_kwargs)
        valset = get_segmentation_dataset(args.dataset,
                                          split='val',
                                          mode='val',
                                          **data_kwargs)
        self.train_data = gluon.data.DataLoader(trainset,
                                                args.batch_size,
                                                shuffle=True,
                                                last_batch='rollover',
                                                num_workers=args.workers)
        self.eval_data = gluon.data.DataLoader(valset,
                                               args.test_batch_size,
                                               last_batch='rollover',
                                               num_workers=args.workers)

        # create network
        if args.model_zoo is not None:
            model = get_model(args.model_zoo,
                              norm_layer=args.norm_layer,
                              norm_kwargs=args.norm_kwargs,
                              aux=args.aux,
                              base_size=args.base_size,
                              crop_size=args.crop_size,
                              pretrained=args.pretrained)
        else:
            model = get_segmentation_model(model=args.model,
                                           dataset=args.dataset,
                                           backbone=args.backbone,
                                           norm_layer=args.norm_layer,
                                           norm_kwargs=args.norm_kwargs,
                                           aux=args.aux,
                                           base_size=args.base_size,
                                           crop_size=args.crop_size)
        # for resnest use only
        from gluoncv.nn.dropblock import set_drop_prob
        from functools import partial
        apply_drop_prob = partial(set_drop_prob, 0.0)
        model.apply(apply_drop_prob)

        model.cast(args.dtype)
        logger.info(model)

        self.net = DataParallelModel(model, args.ctx, args.syncbn)
        self.evaluator = DataParallelModel(SegEvalModel(model), args.ctx)
        # resume checkpoint if needed
        if args.resume is not None:
            if os.path.isfile(args.resume):
                model.load_parameters(args.resume, ctx=args.ctx)
            else:
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))

        # create criterion
        if 'icnet' in args.model:
            criterion = ICNetLoss(crop_size=args.crop_size)
        elif 'danet' in args.model or (args.model_zoo
                                       and 'danet' in args.model_zoo):
            criterion = SegmentationMultiLosses()
        else:
            criterion = MixSoftmaxCrossEntropyLoss(args.aux,
                                                   aux_weight=args.aux_weight)
        self.criterion = DataParallelCriterion(criterion, args.ctx,
                                               args.syncbn)

        # optimizer and lr scheduling
        self.lr_scheduler = LRSequential([
            LRScheduler('linear',
                        base_lr=0,
                        target_lr=args.lr,
                        nepochs=args.warmup_epochs,
                        iters_per_epoch=len(self.train_data)),
            LRScheduler(mode='poly',
                        base_lr=args.lr,
                        nepochs=args.epochs - args.warmup_epochs,
                        iters_per_epoch=len(self.train_data),
                        power=0.9)
        ])
        kv = mx.kv.create(args.kvstore)

        if args.optimizer == 'sgd':
            optimizer_params = {
                'lr_scheduler': self.lr_scheduler,
                'wd': args.weight_decay,
                'momentum': args.momentum,
                'learning_rate': args.lr
            }
        elif args.optimizer == 'adam':
            optimizer_params = {
                'lr_scheduler': self.lr_scheduler,
                'wd': args.weight_decay,
                'learning_rate': args.lr
            }
        else:
            raise ValueError('Unsupported optimizer {} used'.format(
                args.optimizer))

        if args.dtype == 'float16':
            optimizer_params['multi_precision'] = True

        if args.no_wd:
            for k, v in self.net.module.collect_params(
                    '.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        self.optimizer = gluon.Trainer(self.net.module.collect_params(),
                                       args.optimizer,
                                       optimizer_params,
                                       kvstore=kv)
        # evaluation metrics
        self.metric = gluoncv.utils.metrics.SegmentationMetric(
            trainset.num_class)
Example #27
def test(args):
    # output folder
    outdir = 'outdir'
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    # dataset and dataloader
    if args.eval:
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='testval',
                                           transform=input_transform)
        total_inter, total_union, total_correct, total_label = \
            np.int64(0), np.int64(0), np.int64(0), np.int64(0)
    else:
        testset = get_segmentation_dataset(args.dataset,
                                           split='test',
                                           mode='test',
                                           transform=input_transform)
    test_data = gluon.data.DataLoader(testset,
                                      args.test_batch_size,
                                      last_batch='keep',
                                      batchify_fn=ms_batchify_fn,
                                      num_workers=args.workers)
    # create network
    if args.model_zoo is not None:
        model = get_model(args.model_zoo, pretrained=True)
    else:
        model = get_segmentation_model(model=args.model,
                                       dataset=args.dataset,
                                       ctx=args.ctx,
                                       backbone=args.backbone,
                                       norm_layer=args.norm_layer)
        # load pretrained weight
        assert args.resume is not None, '=> Please provide the checkpoint using --resume'
        if os.path.isfile(args.resume):
            model.load_params(args.resume, ctx=args.ctx)
        else:
            raise RuntimeError("=> no checkpoint found at '{}'" \
                .format(args.resume))
    print(model)
    evaluator = MultiEvalModel(model, testset.num_class, ctx_list=args.ctx)

    tbar = tqdm(test_data)
    for i, (data, dsts) in enumerate(tbar):
        if args.eval:
            targets = dsts
            predicts = evaluator.parallel_forward(data)
            for predict, target in zip(predicts, targets):
                target = target.as_in_context(predict[0].context)
                correct, labeled = batch_pix_accuracy(predict[0], target)
                inter, union = batch_intersection_union(
                    predict[0], target, testset.num_class)
                total_correct += correct.astype('int64')
                total_label += labeled.astype('int64')
                total_inter += inter.astype('int64')
                total_union += union.astype('int64')
            pixAcc = np.float64(1.0) * total_correct / (
                np.spacing(1, dtype=np.float64) + total_label)
            IoU = np.float64(1.0) * total_inter / (
                np.spacing(1, dtype=np.float64) + total_union)
            mIoU = IoU.mean()
            tbar.set_description('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))
        else:
            im_paths = dsts
            predicts = evaluator.parallel_forward(data)
            for predict, impath in zip(predicts, im_paths):
                predict = mx.nd.squeeze(mx.nd.argmax(
                    predict[0], 1)).asnumpy() + testset.pred_offset
                mask = get_color_pallete(predict, args.dataset)
                outname = os.path.splitext(impath)[0] + '.png'
                mask.save(os.path.join(outdir, outname))
Example #28
def train():
    logging.info('Start Training for Task: %s\n' % (task))

    # Initialize the net with pretrained model
    pretrained_net = gluon.model_zoo.vision.get_model(model_name,
                                                      pretrained=True)

    finetune_net = gluon.model_zoo.vision.get_model(model_name,
                                                    classes=task_num_class)
    finetune_net.features = pretrained_net.features
    finetune_net.output.initialize(init.Xavier(), ctx=ctx)
    finetune_net.collect_params().reset_ctx(ctx)
    finetune_net.hybridize()

    # Carefully set the 'scale' parameter to enable multi-scale training and multi-scale testing
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(448,
                                     scale=(0.76, 1.0),
                                     ratio=(0.999, 1.001)),
        transforms.RandomFlipLeftRight(),
        transforms.RandomBrightness(0.20),
        #transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
        #                             saturation=jitter_param),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    train_dataset = my_customdataset.my_custom_dataset(
        imgroot=os.path.join(".", 'train_valid_allset', task, 'train'),
        labelmasterpath='label_master.csv')
    train_data = gluon.data.DataLoader(
        train_dataset.transform_first(train_transform),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        last_batch='discard')

    val_transform = transforms.Compose([
        transforms.Resize(480),
        transforms.CenterCrop(448),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    val_dataset = my_customdataset.my_custom_dataset(
        imgroot=os.path.join(".", 'train_valid_allset', task, 'val'),
        labelmasterpath='label_master.csv')
    val_data = gluon.data.DataLoader(
        val_dataset.transform_first(val_transform),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers)

    # Define the Trainer; use Adam to make the model converge quickly
    trainer = gluon.Trainer(finetune_net.collect_params(), 'adam',
                            {'learning_rate': lr})
    metric = mx.metric.Accuracy()
    L = gluon.loss.SoftmaxCrossEntropyLoss()
    lr_counter = 0
    num_batch = len(train_data)

    # Start Training
    best_AP = 0
    best_acc = 0
    for epoch in range(epochs):
        train_acc = 0.
        #### Load the best model when moving to the next training stage
        if epoch == lr_steps[lr_counter]:
            finetune_net.collect_params().load(best_path, ctx=ctx)
            trainer.set_learning_rate(trainer.learning_rate * lr_factor)
            lr_counter += 1

        tic = time.time()
        train_loss = 0
        metric.reset()
        AP = 0.
        AP_cnt = 0

        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0,
                                              even_split=False)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0,
                                               even_split=False)
            with ag.record():
                outputs = [finetune_net(X) for X in data]
                loss = []
                ###### Handle 'm' label by soft-softmax function ######
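                # Label rows appear to be padded with 99: y[0] is the primary
                # label and any following 'maybe' labels share a smaller part
                # of the loss weight.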
                for yhat, y in zip(outputs[0], label[0]):
                    loss_1 = 0
                    if y[1] == 99:  # only have y [4,0,0,0,0]
                        loss_1 += L(yhat, y[0])
                    elif y[2] == 99:  #have one m [4,1,0,0,0]
                        loss_1 = 0.8 * L(yhat, y[0]) + 0.2 * L(yhat, y[1])
                    elif y[3] == 99:  #have two m [4,1,3,0,0]
                        loss_1 = 0.7 * L(yhat, y[0]) + 0.15 * L(
                            yhat, y[1]) + 0.15 * L(yhat, y[2])
                    else:  # have many m [4,1,3,2,0]
                        loss_1 = 0.6 * L(yhat, y[0]) + 0.13 * L(
                            yhat, y[1]) + 0.13 * L(yhat, y[2]) + 0.13 * L(
                                yhat, y[3])

                    loss += [loss_1]

                #loss = [L(yhat, y) for yhat, y in zip(outputs, label)
            # for l in loss:
            #     l.backward()
            ag.backward(loss)  # for soft-softmax

            trainer.step(batch_size)
            train_loss += sum([l.mean().asscalar() for l in loss]) / len(loss)
            #train_acc += accuracy(outputs, label)
            metric.update([label[0][:, 0]], outputs)
            #ap, cnt = calculate_ap(label, outputs)
            #AP += ap
            #AP_cnt += cnt
            #progressbar(i, num_batch-1)

        #train_map = AP / AP_cnt
        _, train_acc = metric.get()
        train_loss /= num_batch

        val_acc, val_loss = validate(finetune_net, val_data, ctx)

        logging.info(
            '[Epoch %d] Train-acc: %.3f, loss: %.3f | Val-acc: %.3f, loss: %.3f | time: %.1f | learning_rate %.6f'
            % (epoch, train_acc, train_loss, val_acc, val_loss,
               time.time() - tic, trainer.learning_rate))
        f_val.writelines(
            '[Epoch %d] Train-acc: %.3f, loss: %.3f | Val-acc: %.3f, loss: %.3f | time: %.1f | learning_rate %.6f\n'
            % (epoch, train_acc, train_loss, val_acc, val_loss,
               time.time() - tic, trainer.learning_rate))
        ### Save the best model every stage
        if val_acc > best_acc:
            #best_AP = this_AP
            best_acc = val_acc
            if not os.path.exists(os.path.join('.', 'models')):
                os.makedirs(os.path.join('.', 'models'))
            if not os.path.exists(
                    os.path.join(
                        '.', 'models', '%s_%s_%s_%s_staging.params' %
                        (task, model_name, epoch, best_acc))):
                f = open(
                    os.path.join(
                        '.', 'models', '%s_%s_%s_%s_staging.params' %
                        (task, model_name, epoch, best_acc)), 'w')
                f.close()
            best_path = os.path.join(
                '.', 'models', '%s_%s_%s_%s_staging.params' %
                (task, model_name, epoch, best_acc))
            finetune_net.collect_params().save(best_path)

    logging.info('\n')
    finetune_net.collect_params().load(best_path, ctx=ctx)
    f_val.writelines(
        'Best val acc is :[Epoch %d] Train-acc: %.3f, loss: %.3f | Best-val-acc: %.3f, loss: %.3f | time: %.1f | learning_rate %.6f\n'
        % (epoch, train_acc, train_loss, best_acc, val_loss, time.time() - tic,
           trainer.learning_rate))
    return (finetune_net)
Example #29
plot_name = opt.save_plot_dir

logging_handlers = [logging.StreamHandler()]
if opt.logging_dir:
    logging_dir = opt.logging_dir
    makedirs(logging_dir)
    logging_handlers.append(logging.FileHandler('%s/train_cifar10_%s.log'%(logging_dir, model_name)))

logging.basicConfig(level=logging.INFO, handlers=logging_handlers)
logging.info(opt)

transform_train = transforms.Compose([
    gcv_transforms.RandomCrop(32, pad=4),
    transforms.RandomFlipLeftRight(),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])

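# Convert a batch of integer class labels into one-hot rows of length `classes`.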
def label_transform(label, classes):
    ind = label.astype('int')
    res = nd.zeros((ind.shape[0], classes), ctx=label.context)
    res[nd.arange(ind.shape[0], ctx=label.context), ind] = 1
    return res

def test(ctx, val_data):
    pass  # stub only; the full test() used for validation is defined inside main()
def main():
    opt = parse_args()
    batch_size = opt.batch_size
    classes = 10

    log_dir = os.path.join(opt.save_dir, "logs")
    model_dir = os.path.join(opt.save_dir, "params")
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Init dataloader
    jitter_param = 0.4
    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.RandomBrightness(jitter_param),
        transforms.RandomColorJitter(jitter_param),
        transforms.RandomContrast(jitter_param),
        transforms.RandomSaturation(jitter_param),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    train_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
        batch_size=batch_size,
        shuffle=True,
        last_batch='discard',
        num_workers=opt.num_workers)

    val_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
        batch_size=batch_size,
        shuffle=False,
        num_workers=opt.num_workers)

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes, 'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)

    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)
    optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    def test(ctx, val_loader):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_loader):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(train_data, val_data, epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]

        net.hybridize()
        net.initialize(mx.init.Xavier(), ctx=ctx)
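        # A single dummy forward pass builds the hybridized (cached) graph so
        # that SummaryWriter.add_graph below has a graph to export.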
        net.forward(mx.nd.ones((1, 3, 30, 30), ctx=ctx[0]))
        with SummaryWriter(logdir=log_dir, verbose=False) as sw:
            sw.add_graph(net)

        trainer = gluon.Trainer(net.collect_params(), optimizer, {
            'learning_rate': opt.lr,
            'wd': opt.wd,
            'momentum': opt.momentum
        })
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0
        global_step = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                lr_decay_count += 1

            tbar = tqdm(train_data)

            for i, batch in enumerate(tbar):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)

                with ag.record():
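                    # Forward pass and loss run under autograd recording so
                    # backward() can compute gradients for every device slice.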
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
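                # step(batch_size) rescales the accumulated gradients by
                # 1/batch_size before applying the NAG update.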
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, acc = train_metric.get()
                iteration += 1
                global_step += len(loss)

            train_loss /= batch_size * num_batch
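            # train_loss is now the average per-sample loss over the epoch.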
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('{}/{}-{}-{:04.3f}-best.params'.format(
                    model_dir, model_name, epoch, best_val_score))

            with SummaryWriter(logdir=log_dir, verbose=False) as sw:
                sw.add_scalar(tag="TrainLos",
                              value=train_loss,
                              global_step=global_step)
                sw.add_scalar(tag="TrainAcc",
                              value=acc,
                              global_step=global_step)
                sw.add_scalar(tag="ValAcc",
                              value=val_acc,
                              global_step=global_step)
                sw.add_graph(net)

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('{}/{}-{}.params'.format(
                    save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('{}/{}-{}.params'.format(
                save_dir, model_name, epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize()
    train(train_data, val_data, opt.num_epochs, context)
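
# The excerpt defines main() but never calls it; a conventional entry point
# (an assumption, not shown in the original) would be:
if __name__ == '__main__':
    main()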