step=steps,
        factor=args.lr_decay,
        base_lr=(args.lr * num_workers),
        warmup_steps=(args.warmup_epochs * epoch_size),
        warmup_begin_lr=args.warmup_lr)
elif args.lr_mode == 'poly':
    lr_sched = lr_scheduler.PolyScheduler(args.num_epochs * epoch_size,
                                          base_lr=(args.lr * num_workers),
                                          pwr=2,
                                          warmup_steps=(args.warmup_epochs *
                                                        epoch_size),
                                          warmup_begin_lr=args.warmup_lr)
elif args.lr_mode == 'cosine':
    lr_sched = lr_scheduler.CosineScheduler(args.num_epochs * epoch_size,
                                            base_lr=(args.lr * num_workers),
                                            warmup_steps=(args.warmup_epochs *
                                                          epoch_size),
                                            warmup_begin_lr=args.warmup_lr)
else:
    raise ValueError('Invalid lr mode')
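
# A scheduler built this way is usually handed to the Trainer through
# optimizer_params, as the later examples do. A minimal sketch, assuming
# `net` and `args.opt` from the surrounding (truncated) script:
#
#   trainer = gluon.Trainer(net.collect_params(), args.opt,
#                           {'learning_rate': args.lr * num_workers,
#                            'lr_scheduler': lr_sched})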


# Function for reading data from record files
# For more details about data loading in MXNet, please refer to
# https://mxnet.incubator.apache.org/tutorials/basic/data.html?highlight=imagerecorditer
def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size,
                 data_nthreads):
    rec_train = os.path.expanduser(rec_train)
    rec_train_idx = os.path.expanduser(rec_train_idx)
    rec_val = os.path.expanduser(rec_val)
    rec_val_idx = os.path.expanduser(rec_val_idx)
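
    # The original listing is truncated here. A minimal sketch of how such a
    # function usually continues, assuming the standard mx.io.ImageRecordIter
    # arguments (data_shape and augmentation flags are illustrative only):
    #
    #   train_data = mx.io.ImageRecordIter(
    #       path_imgrec=rec_train, path_imgidx=rec_train_idx,
    #       preprocess_threads=data_nthreads, shuffle=True,
    #       batch_size=batch_size, data_shape=(3, 224, 224), rand_mirror=True)
    #   val_data = mx.io.ImageRecordIter(
    #       path_imgrec=rec_val, path_imgidx=rec_val_idx,
    #       preprocess_threads=data_nthreads, shuffle=False,
    #       batch_size=batch_size, data_shape=(3, 224, 224))
    #   return train_data, val_data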
Example #2
def Training(Args):
    logging.basicConfig()

    args = Args

    # instantiate datasets
    classes = [
        "sacs", "pots", "meubles", "cartons", "matelas", "palette", "caddy",
        "bidon", "chaise", "electromenager"
    ]

    train_dataset = GTDataset(split='train')
    validation_dataset = GTDataset(split='val')
    test_dataset = GTDataset(split='test')
    logging.info(
        "There is {} training images, {} validation images, {} testing images".
        format(len(train_dataset), len(validation_dataset), len(test_dataset)))

    # instantiate model
    batch_size = args.batch
    image_size = 512
    num_workers = args.workers
    num_epochs = args.epochs
    ctx = [mx.gpu(0)] if mx.context.num_gpus() > 0 else [mx.cpu()]

    logging.info('using context ' + str(ctx))

    net = gcv.model_zoo.get_model(args.basemodel, pretrained=True)
    net.reset_class(classes)

    # instantiate training iterator
    with autograd.train_mode():
        _, _, anchors = net(mx.nd.zeros((1, 3, image_size, image_size)))
    train_transform = SSDDefaultTrainTransform(image_size, image_size, anchors)
    batchify_fn = Tuple(Stack(), Stack(),
                        Stack())  # stack image, cls_targets, box_targets
    train_data = gluon.data.DataLoader(
        train_dataset.transform(train_transform),
        batch_size,
        shuffle=True,
        batchify_fn=batchify_fn,
        last_batch='rollover',
        num_workers=num_workers)

    # instantiate val iterator
    val_transform = SSDDefaultValTransform(image_size, image_size)
    batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
    val_data = gluon.data.DataLoader(
        validation_dataset.transform(val_transform),
        batch_size,
        shuffle=False,
        batchify_fn=batchify_fn,
        last_batch='keep',
        num_workers=num_workers)

    # learning rate scheduler
    scheduler = lr_scheduler.CosineScheduler(
        max_update=int(args.max_update * args.epochs),
        base_lr=args.base_lr,
        final_lr=args.final_lr,
        warmup_steps=int(args.warmup_steps * args.epochs),
        warmup_begin_lr=args.warmup_begin_lr,
        warmup_mode='linear')

    # gluon trainer
    net.collect_params().reset_ctx(ctx)
    trainer = gluon.Trainer(net.collect_params(), args.opt, {
        'learning_rate': args.base_lr,
        'wd': 0.0004,
        'momentum': args.momentum
    })

    # SSD losses
    mbox_loss = gcv.loss.SSDMultiBoxLoss()

    # training loop
    best_val = 0
    best_epoch = 0
    best_tab = [0.0] * (len(classes) + 1)  # per-class APs plus mAP

    for epoch in range(num_epochs):

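        # LRScheduler objects are callables mapping an update count to a
        # learning rate; here the schedule is queried once per epoch rather
        # than once per batch.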
        trainer.set_learning_rate(scheduler(epoch))
        net.hybridize(static_alloc=True, static_shape=True)

        tic = time.time()

        for batch in train_data:

            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            cls_targets = gluon.utils.split_and_load(batch[1],
                                                     ctx_list=ctx,
                                                     batch_axis=0)
            box_targets = gluon.utils.split_and_load(batch[2],
                                                     ctx_list=ctx,
                                                     batch_axis=0)

            with autograd.record():
                cls_preds, box_preds = [], []
                for x in data:
                    cls_pred, box_pred, _ = net(x)
                    cls_preds.append(cls_pred)
                    box_preds.append(box_pred)
                sum_loss, cls_loss, box_loss = mbox_loss(
                    cls_preds, box_preds, cls_targets, box_targets)
                autograd.backward(sum_loss)

            trainer.step(args.batch)

        name, val = validate(net, val_data, ctx, classes, image_size)

        # mAP is the last element of the per-class results array
        meanAP = val[10]

        print('[Epoch {}] Training cost: {:.3f}, Learning rate {}, mAP={:.3f}'.
              format(epoch, (time.time() - tic), trainer.learning_rate,
                     meanAP))

        # If validation accuracy improves, save the parameters
        if meanAP > best_val:
            net.save_parameters(args.model_dir + '/ssd_resnet.trash.params')
            best_val = meanAP
            best_epoch = epoch
            best_tab = val
            print('Saving the parameters, best mAP {}'.format(best_val))

    net.save_parameters(args.model_dir + '/ssd_resnet.trashai' + '-mAP_' +
                        str(best_val) + '_Lr' + str(args.base_lr) +
                        '-BestEpoch_' + str(best_epoch) + '.params')
    print("Best mAP : {}".format(best_val))
    print("Best Epoch : " + str(best_epoch))
    print("Sac : " + str(best_tab[0]))
    print("Pot : " + str(best_tab[1]))
    print("Meuble : " + str(best_tab[2]))
    print("Carton : " + str(best_tab[3]))
    print("Matelas : " + str(best_tab[4]))
    print("Palette : " + str(best_tab[5]))
    print("Caddy : " + str(best_tab[6]))
    print("Bidon : " + str(best_tab[7]))
    print("Chaise : " + str(best_tab[8]))
    print("Electromeganer : " + str(best_tab[9]))
Example #3
def train_net(net, config, check_flag, logger, sig_state, sig_pgbar, sig_table):
    print(config)
    sig_pgbar.emit(-1)
    mx.random.seed(1)
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    classes = 10
    num_epochs = config.train_cfg.epoch
    batch_size = config.train_cfg.batchsize
    optimizer = config.lr_cfg.optimizer
    lr = config.lr_cfg.lr
    num_gpus = config.train_cfg.gpu
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = config.data_cfg.worker

    warmup = config.lr_cfg.warmup
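    # 50000 = size of the CIFAR-10 training split, so 50000//batch_size is
    # the number of iterations per epoch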
    if config.lr_cfg.decay == 'cosine':
        lr_sch = lr_scheduler.CosineScheduler((50000//batch_size)*num_epochs,
                                              base_lr=lr,
                                              warmup_steps=warmup *
                                              (50000//batch_size),
                                              final_lr=1e-5)
    else:
        lr_sch = lr_scheduler.FactorScheduler((50000//batch_size)*config.lr_cfg.factor_epoch,
                                              factor=config.lr_cfg.factor,
                                              base_lr=lr,
                                              warmup_steps=warmup*(50000//batch_size))

    model_name = config.net_cfg.name

    if config.data_cfg.mixup:
        model_name += '_mixup'
    if config.train_cfg.amp:
        model_name += '_amp'

    base_dir = './'+model_name
    if os.path.exists(base_dir):
        base_dir = base_dir + '-' + \
            time.strftime("%m-%d-%H.%M.%S", time.localtime())
    makedirs(base_dir)

    if config.save_cfg.tensorboard:
        logdir = base_dir+'/tb/'+model_name
        if os.path.exists(logdir):
            logdir = logdir + '-' + \
                time.strftime("%m-%d-%H.%M.%S", time.localtime())
        sw = SummaryWriter(logdir=logdir, flush_secs=5, verbose=False)
        with open(base_dir+'/tb.bat', mode='w') as cmd_file:
            cmd_file.write('tensorboard --logdir=./')

    save_period = 10
    save_dir = base_dir+'/'+'params'
    makedirs(save_dir)

    plot_name = base_dir+'/'+'plot'
    makedirs(plot_name)

    stat_name = base_dir+'/'+'stat.txt'

    csv_name = base_dir+'/'+'data.csv'
    if os.path.exists(csv_name):
        csv_name = base_dir+'/'+'data-' + \
            time.strftime("%m-%d-%H.%M.%S", time.localtime())+'.csv'
    csv_file = open(csv_name, mode='w', newline='')
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(['Epoch', 'train_loss', 'train_acc',
                         'valid_loss', 'valid_acc', 'lr', 'time'])

    logging_handlers = [logging.StreamHandler(), logger]
    logging_handlers.append(logging.FileHandler(
        '%s/train_cifar10_%s.log' % (base_dir, model_name)))

    logging.basicConfig(level=logging.INFO, handlers=logging_handlers)
    logging.info(config)

    if config.train_cfg.amp:
        amp.init()

    if config.save_cfg.profiler:
        profiler.set_config(profile_all=True,
                            aggregate_stats=True,
                            continuous_dump=True,
                            filename=base_dir+'/%s_profile.json' % model_name)

    trans_list = []
    imgsize = config.data_cfg.size
    if config.data_cfg.crop:
        trans_list.append(gcv_transforms.RandomCrop(
            32, pad=config.data_cfg.crop_pad))
    if config.data_cfg.cutout:
        trans_list.append(CutOut(config.data_cfg.cutout_size))
    if config.data_cfg.flip:
        trans_list.append(transforms.RandomFlipLeftRight())
    if config.data_cfg.erase:
        trans_list.append(gcv_transforms.block.RandomErasing(s_max=0.25))
    trans_list.append(transforms.Resize(imgsize))
    trans_list.append(transforms.ToTensor())
    trans_list.append(transforms.Normalize([0.4914, 0.4822, 0.4465],
                                           [0.2023, 0.1994, 0.2010]))

    transform_train = transforms.Compose(trans_list)

    transform_test = transforms.Compose([
        transforms.Resize(imgsize),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    def label_transform(label, classes):
        ind = label.astype('int')
        res = nd.zeros((ind.shape[0], classes), ctx=label.context)
        res[nd.arange(ind.shape[0], ctx=label.context), ind] = 1
        return res
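    # e.g. label_transform(nd.array([2, 0]), 4) gives
    # [[0, 0, 1, 0], [1, 0, 0, 0]] -- one-hot targets that mixup can blend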

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
        num_batch = len(val_data)
        test_loss = 0
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(
                batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(
                batch[1], ctx_list=ctx, batch_axis=0)
            outputs = [net(X) for X in data]
            loss = [loss_fn(yhat, y) for yhat, y in zip(outputs, label)]
            metric.update(label, outputs)
            test_loss += sum([l.sum().asscalar() for l in loss])
        test_loss /= batch_size * num_batch
        name, val_acc = metric.get()
        return name, val_acc, test_loss

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]

        if config.train_cfg.param_init:
            init_func = getattr(mx.init, config.train_cfg.init)
            net.initialize(init_func(), ctx=ctx, force_reinit=True)
        else:
            net.load_parameters(config.train_cfg.param_file, ctx=ctx)

        summary(net, stat_name, nd.uniform(
            shape=(1, 3, imgsize, imgsize), ctx=ctx[0]))
        # net = nn.HybridBlock()
        net.hybridize()

        root = config.dir_cfg.dataset
        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(
                root=root, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(
                root=root, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        trainer_arg = {'learning_rate': config.lr_cfg.lr,
                       'wd': config.lr_cfg.wd, 'lr_scheduler': lr_sch}
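        # extra_arg comes from a config string; ast.literal_eval would be a
        # safer alternative to eval() when the string is a plain dict literal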
        extra_arg = eval(config.lr_cfg.extra_arg)
        trainer_arg.update(extra_arg)
        trainer = gluon.Trainer(net.collect_params(), optimizer, trainer_arg)
        if config.train_cfg.amp:
            amp.init_trainer(trainer)
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.RMSE()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(
            sparse_label=False if config.data_cfg.mixup else True)
        train_history = TrainingHistory(['training-error', 'validation-error'])
        # acc_history = TrainingHistory(['training-acc', 'validation-acc'])
        loss_history = TrainingHistory(['training-loss', 'validation-loss'])

        iteration = 0

        best_val_score = 0

        # print('start training')
        sig_state.emit(1)
        sig_pgbar.emit(0)
        # signal.emit('Training')
        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1
            for i, batch in enumerate(train_data):
                if epoch == 0 and iteration == 1 and config.save_cfg.profiler:
                    profiler.set_state('run')
                if epoch == 0 and iteration == 1 and config.save_cfg.tensorboard:
                    sw.add_graph(net)
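                # mixup: blend sample pairs with a Beta(alpha, alpha)
                # coefficient; lam is forced to 1 (mixup off) for the last
                # 20 epochs so training finishes on clean labels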
                lam = np.random.beta(alpha, alpha)
                if epoch >= epochs - 20 or not config.data_cfg.mixup:
                    lam = 1

                data_1 = gluon.utils.split_and_load(
                    batch[0], ctx_list=ctx, batch_axis=0)
                label_1 = gluon.utils.split_and_load(
                    batch[1], ctx_list=ctx, batch_axis=0)

                if not config.data_cfg.mixup:
                    data = data_1
                    label = label_1
                else:
                    data = [lam*X + (1-lam)*X[::-1] for X in data_1]
                    label = []
                    for Y in label_1:
                        y1 = label_transform(Y, classes)
                        y2 = label_transform(Y[::-1], classes)
                        label.append(lam*y1 + (1-lam)*y2)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                if config.train_cfg.amp:
                    with ag.record():
                        with amp.scale_loss(loss, trainer) as scaled_loss:
                            ag.backward(scaled_loss)
                            # scaled_loss.backward()
                else:
                    for l in loss:
                        l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                output_softmax = [nd.SoftmaxActivation(out) for out in output]
                train_metric.update(label, output_softmax)
                metric.update(label_1, output_softmax)
                name, acc = train_metric.get()
                if config.save_cfg.tensorboard:
                    sw.add_scalar(tag='lr', value=trainer.learning_rate,
                                  global_step=iteration)
                if epoch == 0 and iteration == 1 and config.save_cfg.profiler:
                    nd.waitall()
                    profiler.set_state('stop')
                    profiler.dump()
                iteration += 1
                sig_pgbar.emit(iteration)
                if check_flag()[0]:
                    sig_state.emit(2)
                while check_flag()[0] or check_flag()[1]:
                    if check_flag()[1]:
                        print('stop')
                        return
                    else:
                        time.sleep(5)
                        print('pausing')

            epoch_time = time.time() - tic
            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            _, train_acc = metric.get()
            name, val_acc, _ = test(ctx, val_data)
            # if config.data_cfg.mixup:
            #     train_history.update([acc, 1-val_acc])
            #     plt.cla()
            #     train_history.plot(save_path='%s/%s_history.png' %
            #                        (plot_name, model_name))
            # else:
            train_history.update([1-train_acc, 1-val_acc])
            plt.cla()
            train_history.plot(save_path='%s/%s_history.png' %
                               (plot_name, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-cifar-%s-%d-best.params' %
                                    (save_dir, best_val_score, model_name, epoch))

            current_lr = trainer.learning_rate
            name, val_acc, val_loss = test(ctx, val_data)

            logging.info('[Epoch %d] loss=%f train_acc=%f train_RMSE=%f\n     val_acc=%f val_loss=%f lr=%f time: %f' %
                         (epoch, train_loss, train_acc, acc, val_acc, val_loss, current_lr, epoch_time))
            loss_history.update([train_loss, val_loss])
            plt.cla()
            loss_history.plot(save_path='%s/%s_loss.png' %
                              (plot_name, model_name), y_lim=(0, 2), legend_loc='best')
            if config.save_cfg.tensorboard:
                sw._add_scalars(tag='Acc',
                                scalar_dict={'train_acc': train_acc, 'test_acc': val_acc}, global_step=epoch)
                sw._add_scalars(tag='Loss',
                                scalar_dict={'train_loss': train_loss, 'test_loss': val_loss}, global_step=epoch)

            sig_table.emit([epoch, train_loss, train_acc,
                            val_loss, val_acc, current_lr, epoch_time])
            csv_writer.writerow([epoch, train_loss, train_acc,
                                 val_loss, val_acc, current_lr, epoch_time])
            csv_file.flush()

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params' %
                                    (save_dir, model_name, epoch))
        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params' %
                                (save_dir, model_name, epochs-1))

    train(num_epochs, context)
    if config.save_cfg.tensorboard:
        sw.close()

    for ctx in context:
        ctx.empty_cache()

    csv_file.close()
    logging.shutdown()
    reload(logging)
    sig_state.emit(0)
Example #4
def main():
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    opt = parse_args()
    batch_size = opt.batch_size
    classes = 10

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_sch = lr_scheduler.CosineScheduler((50000//batch_size)*opt.num_epochs,
                                          base_lr=opt.lr,
                                          warmup_steps=5*(50000//batch_size),
                                          final_lr=1e-5)
    # lr_sch = lr_scheduler.FactorScheduler((50000//batch_size)*20,
    #                                       factor=0.2, base_lr=opt.lr,
    #                                       warmup_steps=5*(50000//batch_size))
    # lr_sch = LRScheduler('cosine',opt.lr, niters=(50000//batch_size)*opt.num_epochs,)

    model_name = opt.model
    net = SKT_Lite()
    # if model_name.startswith('cifar_wideresnet'):
    #     kwargs = {'classes': classes,
    #             'drop_rate': opt.drop_rate}
    # else:
    #     kwargs = {'classes': classes}
    # net = get_model(model_name, **kwargs)
    if opt.mixup:
        model_name += '_mixup'
    if opt.amp:
        model_name += '_amp'

    makedirs('./'+model_name)
    os.chdir('./'+model_name)
    sw = SummaryWriter(
        logdir=os.path.join('.', 'tb', model_name), flush_secs=5, verbose=False)
    makedirs(opt.save_plot_dir)

    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)
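    # 'nag' is MXNet's Nesterov-accelerated SGD; momentum and weight decay
    # are supplied through the Trainer below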
    optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    plot_name = opt.save_plot_dir

    logging_handlers = [logging.StreamHandler()]
    if opt.logging_dir:
        logging_dir = opt.logging_dir
        makedirs(logging_dir)
        logging_handlers.append(logging.FileHandler(
            '%s/train_cifar10_%s.log' % (logging_dir, model_name)))

    logging.basicConfig(level=logging.INFO, handlers=logging_handlers)
    logging.info(opt)

    if opt.amp:
        amp.init()

    if opt.profile_mode:
        profiler.set_config(profile_all=True,
                            aggregate_stats=True,
                            continuous_dump=True,
                            filename='%s_profile.json' % model_name)

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        CutOut(8),
        # gcv_transforms.block.RandomErasing(s_max=0.25),
        transforms.RandomFlipLeftRight(),
        # transforms.RandomFlipTopBottom(),
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    def label_transform(label, classes):
        ind = label.astype('int')
        res = nd.zeros((ind.shape[0], classes), ctx=label.context)
        res[nd.arange(ind.shape[0], ctx=label.context), ind] = 1
        return res

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
        num_batch = len(val_data)
        test_loss = 0
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(
                batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(
                batch[1], ctx_list=ctx, batch_axis=0)
            outputs = [net(X) for X in data]
            loss = [loss_fn(yhat, y) for yhat, y in zip(outputs, label)]
            metric.update(label, outputs)
            test_loss += sum([l.sum().asscalar() for l in loss])
        test_loss /= batch_size * num_batch
        name, val_acc = metric.get()
        return name, val_acc, test_loss

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        root = os.path.join('..', 'datasets', 'cifar-10')
        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(
                root=root, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(
                root=root, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                {'learning_rate': opt.lr, 'wd': opt.wd,
                                 'momentum': opt.momentum, 'lr_scheduler': lr_sch})
        if opt.amp:
            amp.init_trainer(trainer)
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.RMSE()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(
            sparse_label=False if opt.mixup else True)
        train_history = TrainingHistory(['training-error', 'validation-error'])
        # acc_history = TrainingHistory(['training-acc', 'validation-acc'])
        loss_history = TrainingHistory(['training-loss', 'validation-loss'])

        iteration = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            for i, batch in enumerate(train_data):
                if epoch == 0 and iteration == 1 and opt.profile_mode:
                    profiler.set_state('run')
                lam = np.random.beta(alpha, alpha)
                if epoch >= epochs - 20 or not opt.mixup:
                    lam = 1

                data_1 = gluon.utils.split_and_load(
                    batch[0], ctx_list=ctx, batch_axis=0)
                label_1 = gluon.utils.split_and_load(
                    batch[1], ctx_list=ctx, batch_axis=0)

                if not opt.mixup:
                    data = data_1
                    label = label_1
                else:
                    data = [lam*X + (1-lam)*X[::-1] for X in data_1]
                    label = []
                    for Y in label_1:
                        y1 = label_transform(Y, classes)
                        y2 = label_transform(Y[::-1], classes)
                        label.append(lam*y1 + (1-lam)*y2)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                if opt.amp:
                    with ag.record():
                        with amp.scale_loss(loss, trainer) as scaled_loss:
                            ag.backward(scaled_loss)
                            # scaled_loss.backward()
                else:
                    for l in loss:
                        l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                output_softmax = [nd.SoftmaxActivation(out) for out in output]
                train_metric.update(label, output_softmax)
                metric.update(label_1, output_softmax)
                name, acc = train_metric.get()
                sw.add_scalar(tag='lr', value=trainer.learning_rate,
                              global_step=iteration)
                if epoch == 0 and iteration == 1 and opt.profile_mode:
                    nd.waitall()
                    profiler.set_state('stop')
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            _, train_acc = metric.get()
            name, val_acc, _ = test(ctx, val_data)
            if opt.mixup:
                train_history.update([acc, 1-val_acc])
                plt.cla()
                train_history.plot(save_path='%s/%s_history.png' %
                                   (plot_name, model_name))
            else:
                train_history.update([1-train_acc, 1-val_acc])
                plt.cla()
                train_history.plot(save_path='%s/%s_history.png' %
                                   (plot_name, model_name))
            # acc_history.update([train_acc, val_acc])
            # plt.cla()
            # acc_history.plot(save_path='%s/%s_acc.png' %
            #                  (plot_name, model_name), legend_loc='best')

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-cifar-%s-%d-best.params' %
                                    (save_dir, best_val_score, model_name, epoch))

            current_lr = trainer.learning_rate
            name, val_acc, val_loss = test(ctx, val_data)
            loss_history.update([train_loss, val_loss])
            plt.cla()
            loss_history.plot(save_path='%s/%s_loss.png' %
                              (plot_name, model_name), y_lim=(0, 2), legend_loc='best')
            logging.info('[Epoch %d] loss=%f train_acc=%f train_RMSE=%f\n     val_acc=%f val_loss=%f lr=%f time: %f' %
                         (epoch, train_loss, train_acc, acc, val_acc, val_loss, current_lr, time.time()-tic))
            sw._add_scalars(tag='Acc',
                            scalar_dict={'train_acc': train_acc, 'test_acc': val_acc}, global_step=epoch)
            sw._add_scalars(tag='Loss',
                            scalar_dict={'train_loss': train_loss, 'test_loss': val_loss}, global_step=epoch)
            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params' %
                                    (save_dir, model_name, epoch))
        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params' %
                                (save_dir, model_name, epochs-1))

    if opt.mode == 'hybrid':
        net.hybridize()
    train(opt.num_epochs, context)
    if opt.profile_mode:
        profiler.dump(finished=False)
    sw.close()