def train(ctx,
          loss,
          trainer,
          datasetName,
          modelName,
          net,
          train_iter,
          valid_iter,
          num_epochs,
          n_retrain_epoch=0):
    '''
    n_retrain_epoch: resume training, counting epochs from n_retrain_epoch.
    '''
    train_metric = metric.Accuracy()
    train_history = TrainingHistory(['training-error', 'validation-error'])
    best_val_score = 0
    modelDir, resultDir = get_result_dirs(datasetName)
    for epoch in range(num_epochs):
        train_l_batch, start = 0.0, time.time()  # start the epoch timer
        train_metric.reset()
        for X, y in train_iter:
            X = X.as_in_context(ctx)
            y = y.as_in_context(ctx).astype('float32')  # the model outputs float32 data
            with autograd.record():  # record operations for autograd
                outputs = net(X)  # forward pass
                l = loss(outputs, y).mean()  # average loss over the batch
            l.backward()  # backpropagation
            trainer.step(1)
            train_l_batch += l.asscalar()  # accumulate the batch loss
            train_metric.update(y, outputs)  # update training accuracy
        _, train_acc = train_metric.get()
        time_s = "time {:.2f} sec".format(time.time() - start)  # stop the timer
        valid_loss = evaluate_loss(valid_iter, net, ctx, loss)  # average loss on the validation set
        _, val_acc = test(valid_iter, net, ctx)  # accuracy on the validation set
        epoch_s = (
            "epoch {:d}, train loss {:.5f}, valid loss {:.5f}, train acc {:.5f}, valid acc {:.5f}, "
            .format(n_retrain_epoch + epoch, train_l_batch, valid_loss,
                    train_acc, val_acc))
        print(epoch_s + time_s)
        train_history.update([1 - train_acc, 1 - val_acc])  # update the plotted error curves
        train_history.plot(
            save_path=f'{resultDir}/{modelName}_history.png')  # refresh the saved plot every epoch
        if val_acc > best_val_score:  # keep the best model so far
            best_val_score = val_acc
            net.save_parameters('{}/{:.4f}-{}-{:d}-best.params'.format(
                modelDir, best_val_score, modelName, n_retrain_epoch + epoch))
    return train_history
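The train function above assumes a few project-specific helpers that are not shown (get_result_dirs, evaluate_loss, test). A minimal sketch of what they might look like, assuming the same (name, value) convention as mx.metric — an illustration, not the original implementations:

# Hypothetical sketches of the helpers used by the train() function above.
import os
from mxnet import metric

def get_result_dirs(datasetName, root='results'):
    # one directory for saved parameters, one for plots
    modelDir = os.path.join(root, datasetName, 'models')
    resultDir = os.path.join(root, datasetName, 'plots')
    os.makedirs(modelDir, exist_ok=True)
    os.makedirs(resultDir, exist_ok=True)
    return modelDir, resultDir

def evaluate_loss(data_iter, net, ctx, loss):
    # average loss over a data iterator
    total, n = 0.0, 0
    for X, y in data_iter:
        X = X.as_in_context(ctx)
        y = y.as_in_context(ctx).astype('float32')
        total += loss(net(X), y).sum().asscalar()
        n += y.shape[0]
    return total / n

def test(data_iter, net, ctx):
    # accuracy over a data iterator, returned as (name, value)
    acc = metric.Accuracy()
    for X, y in data_iter:
        X = X.as_in_context(ctx)
        acc.update(y.as_in_context(ctx), net(X))
    return acc.get()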
Example #2
    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum})
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.RMSE()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)
        train_history = TrainingHistory(['training-error', 'validation-error'])

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate*lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
                lam = np.random.beta(alpha, alpha)
                if epoch >= epochs - 20:
                    lam = 1

                data_1 = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label_1 = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)

                data = [lam*X + (1-lam)*X[::-1] for X in data_1]
                label = []
                for Y in label_1:
                    y1 = label_transform(Y, classes)
                    y2 = label_transform(Y[::-1], classes)
                    label.append(lam*y1 + (1-lam)*y2)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                output_softmax = [nd.SoftmaxActivation(out) for out in output]
                train_metric.update(label, output_softmax)
                name, acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            train_history.update([acc, 1-val_acc])
            train_history.plot(save_path='%s/%s_history.png'%(plot_name, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-cifar-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))

            name, val_acc = test(ctx, val_data)
            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                (epoch, acc, val_acc, train_loss, time.time()-tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epochs-1))
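This snippet (and several later ones) blends labels with a label_transform helper that is not shown. A plausible sketch, assuming it simply one-hot encodes integer class labels so that two batches can be mixed with the weight lam:

# Hypothetical sketch of label_transform: integer labels -> one-hot vectors,
# so that lam*y1 + (1-lam)*y2 in the loop above yields a soft mixup label.
from mxnet import nd

def label_transform(label, classes):
    return nd.one_hot(label.astype('int32'), classes)

This is also why the loss is constructed with sparse_label=False: the targets are dense probability vectors rather than class indices.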
Example #3
    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        if opt.summary:
            summary(net, mx.nd.zeros((1, 3, 32, 32), ctx=ctx[0]))
            sys.exit()

        if opt.dataset == 'cifar10':
            train_data = gluon.data.DataLoader(
                gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
                batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
            val_data = gluon.data.DataLoader(
                gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
                batch_size=batch_size, shuffle=False, num_workers=num_workers)
        elif opt.dataset == 'cifar100':
            train_data = gluon.data.DataLoader(
                gluon.data.vision.CIFAR100(train=True).transform_first(transform_train),
                batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
            val_data = gluon.data.DataLoader(
                gluon.data.vision.CIFAR100(train=False).transform_first(transform_test),
                batch_size=batch_size, shuffle=False, num_workers=num_workers)
        else:
            raise ValueError('Unknown Dataset')

        if opt.no_wd and opt.cosine:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True

        metric = mx.metric.Accuracy()
        train_metric = mx.metric.RMSE()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)
        train_history = TrainingHistory(['training-error', 'validation-error'])

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)

            if not opt.cosine:
                if epoch == lr_decay_epoch[lr_decay_count]:
                    trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                    lr_decay_count += 1

            for i, batch in enumerate(train_data):
                data_1 = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label_1 = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)

                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if (epoch >= epochs - opt.mixup_off_epoch) or not opt.mixup:
                        lam = 1

                    data = [lam * X + (1 - lam) * X[::-1] for X in data_1]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label_1, classes, lam, eta)

                elif opt.label_smoothing:
                    hard_label = label_1
                    label = smooth(label_1, classes)
                    data = data_1
                else:
                    data = data_1
                    label = label_1

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                if opt.mixup:
                    output_softmax = [nd.SoftmaxActivation(out) for out in output]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, output)
                    else:
                        train_metric.update(label, output)

                name, acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            train_history.update([acc, 1 - val_acc])
            train_history.plot(save_path='%s/%s_history.png' % (plot_name, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-%s-best.params' %
                                    (save_dir, best_val_score, model_name))

            name, val_acc = test(ctx, val_data)
            logging.info('[Epoch %d] train=%f val=%f loss=%f lr: %f time: %f' %
                         (epoch, acc, val_acc, train_loss, trainer.learning_rate,
                          time.time() - tic))

        host_name = socket.gethostname()
        with open(opt.dataset + '_' + host_name + '_GPU_' + opt.gpus + '_best_Acc.log', 'a') as f:
            f.write('best Acc: {:.4f}\n'.format(best_val_score))
        print("best_val_score: ", best_val_score)
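Example #3 additionally relies on mixup_transform and smooth for its soft labels; neither is shown. A minimal sketch consistent with how they are called (lists of per-device label batches, smoothing factor eta) — these are assumptions, not the original code:

# Hypothetical sketches of mixup_transform and smooth as called in Example #3.
from mxnet import nd

def smooth(label, classes, eta=0.1):
    # label smoothing: 1 - eta + eta/classes on the true class, eta/classes elsewhere
    if isinstance(label, nd.NDArray):
        label = [label]
    return [nd.one_hot(l.astype('int32'), classes,
                       on_value=1 - eta + eta / classes,
                       off_value=eta / classes) for l in label]

def mixup_transform(label, classes, lam=1.0, eta=0.0):
    # blend each (optionally smoothed) one-hot batch with its reversed copy
    if isinstance(label, nd.NDArray):
        label = [label]
    res = []
    for l in label:
        y1 = nd.one_hot(l.astype('int32'), classes,
                        on_value=1 - eta + eta / classes, off_value=eta / classes)
        y2 = nd.one_hot(l[::-1].astype('int32'), classes,
                        on_value=1 - eta + eta / classes, off_value=eta / classes)
        res.append(lam * y1 + (1 - lam) * y2)
    return res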
Example #4
    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]

        if config.train_cfg.param_init:
            init_func = getattr(mx.init, config.train_cfg.init)
            net.initialize(init_func(), ctx=ctx, force_reinit=True)
        else:
            net.load_parameters(config.train_cfg.param_file, ctx=ctx)

        summary(net, stat_name, nd.uniform(
            shape=(1, 3, imgsize, imgsize), ctx=ctx[0]))
        # net = nn.HybridBlock()
        net.hybridize()

        root = config.dir_cfg.dataset
        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(
                root=root, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(
                root=root, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        trainer_arg = {'learning_rate': config.lr_cfg.lr,
                       'wd': config.lr_cfg.wd, 'lr_scheduler': lr_sch}
        extra_arg = eval(config.lr_cfg.extra_arg)
        trainer_arg.update(extra_arg)
        trainer = gluon.Trainer(net.collect_params(), optimizer, trainer_arg)
        if config.train_cfg.amp:
            amp.init_trainer(trainer)
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.RMSE()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(
            sparse_label=False if config.data_cfg.mixup else True)
        train_history = TrainingHistory(['training-error', 'validation-error'])
        # acc_history = TrainingHistory(['training-acc', 'validation-acc'])
        loss_history = TrainingHistory(['training-loss', 'validation-loss'])

        iteration = 0

        best_val_score = 0

        # print('start training')
        sig_state.emit(1)
        sig_pgbar.emit(0)
        # signal.emit('Training')
        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1
            for i, batch in enumerate(train_data):
                if epoch == 0 and iteration == 1 and config.save_cfg.profiler:
                    profiler.set_state('run')
                    is_profiler_run = True
                if epoch == 0 and iteration == 1 and config.save_cfg.tensorboard:
                    sw.add_graph(net)
                lam = np.random.beta(alpha, alpha)
                if epoch >= epochs - 20 or not config.data_cfg.mixup:
                    lam = 1

                data_1 = gluon.utils.split_and_load(
                    batch[0], ctx_list=ctx, batch_axis=0)
                label_1 = gluon.utils.split_and_load(
                    batch[1], ctx_list=ctx, batch_axis=0)

                if not config.data_cfg.mixup:
                    data = data_1
                    label = label_1
                else:
                    data = [lam*X + (1-lam)*X[::-1] for X in data_1]
                    label = []
                    for Y in label_1:
                        y1 = label_transform(Y, classes)
                        y2 = label_transform(Y[::-1], classes)
                        label.append(lam*y1 + (1-lam)*y2)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                if config.train_cfg.amp:
                    with ag.record():
                        with amp.scale_loss(loss, trainer) as scaled_loss:
                            ag.backward(scaled_loss)
                            # scaled_loss.backward()
                else:
                    for l in loss:
                        l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                output_softmax = [nd.SoftmaxActivation(out) for out in output]
                train_metric.update(label, output_softmax)
                metric.update(label_1, output_softmax)
                name, acc = train_metric.get()
                if config.save_cfg.tensorboard:
                    sw.add_scalar(tag='lr', value=trainer.learning_rate,
                                  global_step=iteration)
                if epoch == 0 and iteration == 1 and config.save_cfg.profiler:
                    nd.waitall()
                    profiler.set_state('stop')
                    profiler.dump()
                iteration += 1
                sig_pgbar.emit(iteration)
                if check_flag()[0]:
                    sig_state.emit(2)
                while(check_flag()[0] or check_flag()[1]):
                    if check_flag()[1]:
                        print('stop')
                        return
                    else:
                        time.sleep(5)
                        print('pausing')

            epoch_time = time.time() - tic
            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            _, train_acc = metric.get()
            name, val_acc, _ = test(ctx, val_data)
            # if config.data_cfg.mixup:
            #     train_history.update([acc, 1-val_acc])
            #     plt.cla()
            #     train_history.plot(save_path='%s/%s_history.png' %
            #                        (plot_name, model_name))
            # else:
            train_history.update([1-train_acc, 1-val_acc])
            plt.cla()
            train_history.plot(save_path='%s/%s_history.png' %
                               (plot_name, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-cifar-%s-%d-best.params' %
                                    (save_dir, best_val_score, model_name, epoch))

            current_lr = trainer.learning_rate
            name, val_acc, val_loss = test(ctx, val_data)

            logging.info('[Epoch %d] loss=%f train_acc=%f train_RMSE=%f\n     val_acc=%f val_loss=%f lr=%f time: %f' %
                         (epoch, train_loss, train_acc, acc, val_acc, val_loss, current_lr, epoch_time))
            loss_history.update([train_loss, val_loss])
            plt.cla()
            loss_history.plot(save_path='%s/%s_loss.png' %
                              (plot_name, model_name), y_lim=(0, 2), legend_loc='best')
            if config.save_cfg.tensorboard:
                sw._add_scalars(tag='Acc',
                                scalar_dict={'train_acc': train_acc, 'test_acc': val_acc}, global_step=epoch)
                sw._add_scalars(tag='Loss',
                                scalar_dict={'train_loss': train_loss, 'test_loss': val_loss}, global_step=epoch)

            sig_table.emit([epoch, train_loss, train_acc,
                            val_loss, val_acc, current_lr, epoch_time])
            csv_writer.writerow([epoch, train_loss, train_acc,
                                 val_loss, val_acc, current_lr, epoch_time])
            csv_file.flush()

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params' %
                                    (save_dir, model_name, epoch))
        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params' %
                                (save_dir, model_name, epochs-1))
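Example #4 (like Example #19 further down) calls a three-value test(ctx, val_data) helper that returns a metric name, validation accuracy, and validation loss; it is not shown. A minimal sketch under that assumption, reusing the module-level net just as the snippets above do:

# Hypothetical sketch of the 3-value test(ctx, val_data) helper assumed in Example #4;
# returns (metric name, accuracy, average validation loss). `net` comes from the enclosing script.
import mxnet as mx
from mxnet import gluon

def test(ctx, val_data):
    metric = mx.metric.Accuracy()
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
    val_loss, n = 0.0, 0
    for batch in val_data:
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
        outputs = [net(X) for X in data]
        val_loss += sum(loss_fn(yhat, y).sum().asscalar()
                        for yhat, y in zip(outputs, label))
        n += batch[0].shape[0]
        metric.update(label, outputs)
    name, acc = metric.get()
    return name, acc, val_loss / n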
Example #5
            btic = time.time()

    name, acc = train_metric.get()

    # test
    #acc_top1_val, acc_top5_val, loss_val = test(ctx, val_data)
    acc_top1_val, acc_top5_val, loss_val, loss_mse, loss_pre = test(
        ctx, val_data)

    # Update history and print metrics
    train_history.update([
        acc, acc_top1_val, acc_top5_val, train_loss / (i + 1), loss_val,
        loss_mse, loss_pre
    ])
    train_history.plot(save_path=os.path.join(opt.save_dir,
                                              'trainlog_acc.jpg'),
                       labels=['training-acc', 'val-top1-acc', 'val-top5-acc'])
    train_history.plot(
        save_path=os.path.join(opt.save_dir, 'trainlog_loss.jpg'),
        labels=['training-loss', 'cross-loss', 'mse-loss', 'pre-loss'])
    logger.info('[Epoch %d] train=%f loss=%f time: %f' %
                (epoch, acc, train_loss / (i + 1), time.time() - tic))
    #logger.info('[Epoch %d] val top1 =%f top5=%f val loss=%f,lr=%f' %
    #   (epoch, acc_top1_val, acc_top5_val, loss_val ,trainer.learning_rate ))
    logger.info(
        '[Epoch %d] val top1=%f top5=%f val loss=%f, mseloss=%f, loss_pre=%f, lr=%f'
        % (epoch, acc_top1_val, acc_top5_val, loss_val, loss_mse, loss_pre,
           trainer.learning_rate))
    if acc_top1_val > best_val_score and epoch > 5:
        best_val_score = acc_top1_val
        net.save_parameters(
Example #6
    
    # Evaluate on Validation data
    #name, val_acc = test(ctx, val_data)
    val_acc_top1, val_acc_top5 = test(ctx, val_data)

    # Update history and print metrics
    train_history.update([1-acc, 1-val_acc_top1])
    train_history2.update([acc, val_acc_top1, val_acc_top5])
    
    print('[Epoch %d] train=%f val_top1=%f val_top5=%f loss=%f time: %f' %
        (epoch, acc, val_acc_top1, val_acc_top5, train_loss, time.time()-tic))

# We can plot the metric scores with:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
train_history.plot(['training-error', 'validation-error'], 
                   save_path="./cifar100_resnet56_v1_{o}_{ep}epochs_errors_{t}.png".format(o=optimizer,
                                                                                           ep=epochs,
                                                                                           t=timestamp))
train_history2.plot(['training-acc', 'val-acc-top1', 'val-acc-top5'],
                   save_path="./cifar100_resnet56_v1_{o}_{ep}epochs_accuracies_{t}.png".format(o=optimizer,
                                                                                               ep=epochs,
                                                                                               t=timestamp))
print("Done.")
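All of these examples drive gluoncv.utils.TrainingHistory the same way: one value per label on each update(), then plot() to a file. A minimal standalone sketch (the error values here are made up):

# Minimal use of TrainingHistory, independent of any training loop.
from gluoncv.utils import TrainingHistory

history = TrainingHistory(['training-error', 'validation-error'])
for train_err, val_err in [(0.60, 0.55), (0.35, 0.40), (0.22, 0.31)]:
    history.update([train_err, val_err])        # append one point per curve
history.plot(['training-error', 'validation-error'],
             save_path='./toy_history.png')     # write the figure to disk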


Example #7
def train(epochs, ctx):
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    net.initialize(mx.init.Xavier(), ctx=ctx)

    train_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
        batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

    val_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
        batch_size=batch_size, shuffle=False, num_workers=num_workers)

    trainer = gluon.Trainer(net.collect_params(), optimizer,
                            {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum})
    metric = mx.metric.Accuracy()
    train_metric = mx.metric.RMSE()
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)
    train_history = TrainingHistory(['training-error', 'validation-error'])

    iteration = 0
    lr_decay_count = 0

    best_val_score = 0

    for epoch in range(epochs):
        tic = time.time()
        train_metric.reset()
        metric.reset()
        train_loss = 0
        num_batch = len(train_data)
        alpha = 1

        if epoch == lr_decay_epoch[lr_decay_count]:
            trainer.set_learning_rate(trainer.learning_rate*lr_decay)
            lr_decay_count += 1

        for i, batch in enumerate(train_data):
            lam = np.random.beta(alpha, alpha)
            if epoch >= epochs - 50:
                lam = 1

            data_1 = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label_1 = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)

            data = [lam*X + (1-lam)*X[::-1] for X in data_1]
            label = []
            for Y in label_1:
                y1 = label_transform(Y, classes)
                y2 = label_transform(Y[::-1], classes)
                label.append(lam*y1 + (1-lam)*y2)

            with ag.record():
                output = [net(X) for X in data]
                loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
            for l in loss:
                l.backward()
            trainer.step(batch_size)
            train_loss += sum([l.sum().asscalar() for l in loss])

            output_softmax = [nd.SoftmaxActivation(out) for out in output]
            train_metric.update(label, output_softmax)
            name, acc = train_metric.get()
            iteration += 1

        train_loss /= batch_size * num_batch
        name, acc = train_metric.get()
        name, val_acc = test(ctx, val_data)
        train_history.update([acc, 1-val_acc])
        train_history.plot(save_path='%s/%s_history.png'%(plot_name, model_name))

        if val_acc > best_val_score and epoch > 200:
            best_val_score = val_acc
            net.save_params('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))

        name, val_acc = test(ctx, val_data)
        logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
            (epoch, acc, val_acc, train_loss, time.time()-tic))

        if save_period and save_dir and (epoch + 1) % save_period == 0:
            net.save_params('%s/cifar10-%s-%d.params'%(save_dir, model_name, epoch))

    if save_period and save_dir:
        net.save_params('%s/cifar10-%s-%d.params'%(save_dir, model_name, epochs-1))
Example #8
                if save_best_val_acc:
                    val_acc_save_m = 'Params saved on epoch {}, new best val acc found'.format(epoch+1)
                    print(val_acc_save_m)
                    if save:
                        log.write(val_acc_save_m)
                    for i in range(classes):
                       per_class_acc = '{}={}'.format(classes_list[i], test_on_single_class(net, val_data, ctx, i)[1])
                       print(per_class_acc)
                       if save:
                           log.write(per_class_acc + '\n')
                else:
                    print('Params saved on epoch {}'.format(epoch+1))
                net.save_parameters(os.path.join(
                        params_path,
                        '{:s}_{:03d}__{}.params'.format(net_name, epoch+1, model_name))
                        )
                if not ssh:
                    train_history.plot(save_path=(os.path.join(
                               params_path,
                               '{:s}_{:03d}__{}.png'.format(net_name, epoch+1, model_name))
                               ))
    if not ssh:
        train_history.plot()

# name, test_acc = test(net, test_data, ctx)
# test_score = '[Finished] Test-acc: {:.3f}'.format(test_acc)
# print(test_score)
# if train:
#     log.write(test_score + '\n')
#     log.close()
Example #9
                '[Epoch %d] [%d | %d] train=%f loss=%f mseloss=%f  pre_loss %f time: %f'
                % (epoch, i, len(train_data), acc, train_loss /
                   (i + 1), mse_loss / (i + 1), pre_loss /
                   (i + 1), time.time() - btic))
            btic = time.time()

    name, acc = train_metric.get()

    # test
    #acc_top1_val, acc_top5_val, loss_val = test(ctx, val_data)
    acc_top1_val, acc_top5_val, loss_val, loss_mse, loss_pre = test(
        ctx, val_data)

    # Update history and print metrics
    train_history.update([acc, acc_top1_val, acc_top5_val])
    train_history.plot(save_path=os.path.join(opt.save_dir, 'trainlog.jpg'))
    train_history_loss.update(
        [train_loss / (i + 1), loss_val, loss_mse, loss_pre])
    train_history_loss.plot(
        save_path=os.path.join(opt.save_dir, 'trainlog_loss.jpg'))
    logger.info('[Epoch %d] train=%f loss=%f time: %f' %
                (epoch, acc, train_loss / (i + 1), time.time() - tic))
    #logger.info('[Epoch %d] val top1 =%f top5=%f val loss=%f,lr=%f' %
    #   (epoch, acc_top1_val, acc_top5_val, loss_val ,trainer.learning_rate ))
    logger.info(
        '[Epoch %d] val top1=%f top5=%f val loss=%f, mseloss=%f, loss_pre=%f, lr=%f'
        % (epoch, acc_top1_val, acc_top5_val, loss_val, loss_mse, loss_pre,
           trainer.learning_rate))
    if acc_top1_val > best_val_score and epoch > 5:
        best_val_score = acc_top1_val
        net.save_parameters(
Example #10
    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.summary:
            net.summary(mx.nd.zeros((1, 3, 32, 32)))

        if opt.dataset == 'cifar10':
            # CIFAR10
            train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
                train=True).transform_first(transform_train),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               last_batch='discard',
                                               num_workers=num_workers)
            val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
                train=False).transform_first(transform_test),
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers)
        elif opt.dataset == 'cifar100':
            # CIFAR100
            train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
                train=True).transform_first(transform_train),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               last_batch='discard',
                                               num_workers=num_workers)
            val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
                train=False).transform_first(transform_test),
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers)
        else:
            raise ValueError('Unknown Dataset')

        if optimizer == 'nag':
            trainer = gluon.Trainer(net.collect_params(), optimizer, {
                'learning_rate': opt.lr,
                'wd': opt.wd,
                'momentum': opt.momentum
            })
        elif optimizer == 'adagrad':
            trainer = gluon.Trainer(net.collect_params(), optimizer, {
                'learning_rate': opt.lr,
                'wd': opt.wd
            })
        elif optimizer == 'adam':
            trainer = gluon.Trainer(net.collect_params(), optimizer, {
                'learning_rate': opt.lr,
                'wd': opt.wd
            })
        else:
            raise ValueError('Unknown optimizer')

        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
        train_history = TrainingHistory(['training-error', 'validation-error'])
        host_name = socket.gethostname()

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)
                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            train_history.update([1 - acc, 1 - val_acc])
            train_history.plot(save_path='%s/%s_history.png' %
                               (plot_path, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                # net.save_parameters('%s/%.4f-cifar-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
                pass

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                # net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epoch))
                pass

            if epoch == epochs - 1:
                with open(
                        opt.dataset + '_' + host_name + '_GPU_' + opt.gpus +
                        '_best_Acc.log', 'a') as f:
                    f.write('best Acc: {:.4f}\n'.format(best_val_score))

        print("best_val_score: ", best_val_score)
        if save_period and save_dir:
            # net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epochs-1))
            pass
Example #11
        # Update metrics
        train_loss += sum([l.sum().asscalar() for l in loss])
        train_metric.update(label, output)

    name, acc = train_metric.get()
    # Evaluate on Validation data
    name, val_acc = test(ctx, val_data)

    # Update history and print metrics
    train_history.update([1-acc, 1-val_acc])
    print('[Epoch %d] train=%f val=%f loss=%f time: %f' %
        (epoch, acc, val_acc, train_loss, time.time()-tic))

# We can plot the metric scores with:

train_history.plot()

################################################################
# If you trained the model for 240 epochs, the plot may look like:
#
# |image-aug|
#
# We can better observe the process of model training with plots.
# For example, one may ask what will happen if there's no data augmentation:
#
# |image-no-aug|
#
# We can see that training error is much lower than validation error.
# After the model reaches 100% accuracy on training data,
# it stops improving on validation data.
# These two plots clearly demonstrate the importance of data augmentation.
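The transform_train and transform_test pipelines referenced throughout these snippets are never shown. A typical CIFAR setup (pad-and-crop, horizontal flip, normalization) might look like the following — an illustrative sketch, not necessarily the exact transforms used above:

# Illustrative CIFAR-style augmentation pipelines; the mean/std values are the
# commonly used CIFAR-10 statistics, and the exact choices may differ per example.
from mxnet.gluon.data.vision import transforms
from gluoncv.data import transforms as gcv_transforms

transform_train = transforms.Compose([
    gcv_transforms.RandomCrop(32, pad=4),      # pad by 4 px, then random 32x32 crop
    transforms.RandomFlipLeftRight(),          # random horizontal flip
    transforms.ToTensor(),                     # HWC uint8 -> CHW float32 in [0, 1]
    transforms.Normalize([0.4914, 0.4822, 0.4465],
                         [0.2023, 0.1994, 0.2010]),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465],
                         [0.2023, 0.1994, 0.2010]),
])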
Example #12
    def train(ctx, batch_size):
        #net.initialize(mx.init.Xavier(), ctx=ctx)
        train_data = DataLoader(ImageDataset(root=default.dataset_path, train=True), \
                                batch_size=batch_size,shuffle=True,num_workers=num_workers)
        val_data = DataLoader(ImageDataset(root=default.dataset_path, train=False), \
                              batch_size=batch_size, shuffle=True,num_workers=num_workers)

        # lr_epoch = [int(epoch) for epoch in args.lr_step.split(',')]
        net.collect_params().reset_ctx(ctx)
        lr = args.lr
        end_lr = args.end_lr
        lr_decay = args.lr_decay
        lr_decay_step = args.lr_decay_step
        all_step = len(train_data)
        schedule = mx.lr_scheduler.FactorScheduler(step=lr_decay_step *
                                                   all_step,
                                                   factor=lr_decay,
                                                   stop_factor_lr=end_lr)
        adam_optimizer = mx.optimizer.Adam(learning_rate=lr,
                                           lr_scheduler=schedule)
        trainer = gluon.Trainer(net.collect_params(), optimizer=adam_optimizer)

        train_metric = CtcMetrics()
        train_history = TrainingHistory(['training-error', 'validation-error'])

        iteration = 0
        best_val_score = 0

        save_period = args.save_period
        save_dir = args.save_dir
        model_name = args.prefix
        plot_path = args.save_dir
        epochs = args.end_epoch
        frequent = args.frequent
        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            train_loss = 0
            num_batch = 0
            tic_b = time.time()
            for datas, labels in train_data:
                data = gluon.utils.split_and_load(nd.array(datas),
                                                  ctx_list=ctx,
                                                  batch_axis=0,
                                                  even_split=False)
                label = gluon.utils.split_and_load(nd.array(labels),
                                                   ctx_list=ctx,
                                                   batch_axis=0,
                                                   even_split=False)
                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, acc = train_metric.get()
                iteration += 1
                num_batch += 1
                if num_batch % frequent == 0:
                    train_loss_b = train_loss / (batch_size * num_batch)
                    logging.info(
                        '[Epoch %d] [num_batch %d] train_acc=%f loss=%f time/batch: %f'
                        % (epoch, num_batch, acc, train_loss_b,
                           (time.time() - tic_b) / num_batch))
            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(val_data, ctx)
            train_history.update([1 - acc, 1 - val_acc])
            train_history.plot(save_path='%s/%s_history.png' %
                               (plot_path, model_name))
            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters(
                    '%s/%.4f-crnn-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))
            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                symbol_file = os.path.join(save_dir, model_name)
                net.export(path=symbol_file, epoch=epoch)
                # net.save_parameters('%s/crnn-%s-%d.params' % (save_dir, model_name, epoch))

        if save_period and save_dir:
            symbol_file = os.path.join(save_dir, model_name)
            net.export(path=symbol_file, epoch=epoch - 1)
Example #13
def train(net, ctx):
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    net.initialize(initializer, ctx=ctx)

    val_dataloader = get_dataloader(DatasetSplit(train=False),
                                    batch_size=100,
                                    train=False)

    metric = mx.metric.Accuracy()
    train_metric = mx.metric.Accuracy()
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
    if use_pillars:
        plc_loss_fn = gluon.loss.L2Loss(weight=w1)
    if use_cpl:
        loss_fn_cpl = gluon.loss.L2Loss(weight=w2)
    train_history = TrainingHistory(['training-error', 'validation-error'])
    timestr = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    param_dir = os.path.join(save_dir, name, timestr)
    if not os.path.exists(param_dir):
        os.makedirs(param_dir)

    param_file_fmt = '%s/cifar10_%s_%d-%d-%d.params'
    training_record_fmt = '[Session %d, Epoch %d] train=%.4f val=%.4f loss=%.4f '
    if use_pillars:
        training_record_fmt += 'plc-loss=%.4f '
    training_record_fmt += 'time: %.2f'

    prev_dataloader, prev_dataset, prev_pillarset, pillarset = None, None, None, None
    record_acc = dict()

    for sess in range(sessions):

        record_acc[sess] = list()
        logging.info("[Session %d] begin training ..." % (sess + 1))
        if sess == 0 and opt.resume_s1:
            _, val_acc = test(net, ctx, val_dataloader)
            record_acc[sess].append(val_acc)
            logging.info('session 1 test acc : %.4f' % val_acc)
            prev_dataset = DatasetSplit(split_id=sess, train=True)
            prev_dataloader = get_dataloader(prev_dataset,
                                             batch_sizes[sess],
                                             train=True)
            continue

        train_dataset = DatasetSplit(split_id=sess, train=True)
        lr_decay_count, best_val_score = 0, 0

        if sess != 0:
            # Sampling data for continuous training
            logging.info(
                "[Session %d] sampling training data and pillars ..." %
                (sess + 1))
            dataloader = get_dataloader(train_dataset,
                                        batch_size=100,
                                        train=False)
            train_dataset = data_sampler.sample_dataset(train_dataset,
                                                        dataloader,
                                                        net,
                                                        loss_fn,
                                                        num_data_samples,
                                                        ctx=ctx)
            if cumulative:
                train_dataset = merge_datasets(prev_dataset, train_dataset)

        train_dataloader = get_dataloader(train_dataset,
                                          batch_sizes[sess],
                                          train=True)
        # Build trainer for net.
        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                optimizer_params[sess])

        for epoch in range(epochs[sess]):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss, train_plc_loss = 0, 0
            num_batch = len(train_dataloader)

            if epoch == lr_decay_epochs[sess][lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_dataloader):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)
                all_loss = list()
                with ag.record():
                    output = [net(X)[1] for X in data]
                    output_feat = [net(X)[0] for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                    all_loss.extend(loss)
                    # Normalize each loss for the trainer with batch_size=1
                    all_loss = [nd.mean(l) for l in all_loss]

                ag.backward(all_loss)
                trainer.step(1, ignore_stale_grad=True)
                train_loss += sum([l.sum().asscalar() for l in loss])
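                # note: plc_loss below is computed by code omitted from this snippet
                # (presumably via plc_loss_fn on pillar features); it is undefined as shown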
                if sess > 0 and use_pillars:
                    train_plc_loss += sum(
                        [al.mean().asscalar() for al in plc_loss])

                train_metric.update(label, output)

            train_loss /= batch_sizes[sess] * num_batch
            _, acc = train_metric.get()
            _, val_acc = test(net, ctx, val_dataloader)
            train_history.update([1 - acc, 1 - val_acc])
            train_history.plot(save_path='%s/%s_history.png' %
                               (plot_path, model_name))
            if epoch >= epochs[sess] - 5:
                record_acc[sess].append(val_acc)

            training_record = [sess + 1, epoch, acc, val_acc, train_loss]
            if use_pillars:
                training_record += [train_plc_loss]
            training_record += [time.time() - tic]
            logging.info(training_record_fmt % tuple(training_record))

            net.save_parameters(
                param_file_fmt %
                (param_dir, model_name, sess, epochs[sess], epoch))
        prev_dataset = train_dataset
        prev_dataloader = train_dataloader
        prev_pillarset = pillarset
        if sess == 0 or sess == 1:
            save_data = get_dataloader(DatasetSplit(split_id=0, train=True),
                                       batch_size=10000,
                                       train=True)
            for i, batch in enumerate(save_data):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)
                outputs = net(data[0])[0]
                np.save('session{}_feats.npy'.format(sess), outputs.asnumpy())
                np.save('session{}_label.npy'.format(sess), label[0].asnumpy())

    for i in range(len(list(record_acc.keys()))):
        mean = np.mean(np.array(record_acc[i]))
        std = np.std(np.array(record_acc[i]))
        print('[Sess %d] Mean=%f Std=%f' % (i + 1, mean, std))
Example #14
    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
            train=True).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)

        val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
            train=False).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        trainer = gluon.Trainer(net.collect_params(), optimizer)
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
        train_history = TrainingHistory(['training-error', 'validation-error'])

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)

            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            train_history.update([1 - acc, 1 - val_acc])
            train_history.plot(save_path='%s/%s_history.png' %
                               (plot_path, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters(
                    '%s/%.4f-cifar-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params' %
                                    (save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params' %
                                (save_dir, model_name, epochs - 1))
Example #15
        # AutoGrad
        with ag.record():
            output = [net(X) for X in data]
            loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]

        # Backpropagation
        for l in loss:
            l.backward()

        # Optimize
        trainer.step(batch_size)

        # Update metrics
        train_loss += sum([l.sum().asscalar() for l in loss])
        train_metric.update(label, output)

    name, acc = train_metric.get()
    # Evaluate on Validation data
    name, val_acc = test(ctx, val_data)

    # Update history and print metrics
    train_history.update([1 - acc, 1 - val_acc])
    print('[Epoch %d] train=%f val=%f loss=%f time: %f' %
          (epoch, acc, val_acc, train_loss, time.time() - tic))

# We can plot the metric scores with:
train_history.plot(['training-error', 'validation-error'],
                   save_path="./cifar100_resnet56_v1_nadam.png")
print("Done.")
Example #16
            output = []
            for _, X in enumerate(data):
                X = X.reshape((-1,) + X.shape[2:])
                pred = net(X)
                output.append(pred)
            loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]

        # Backpropagation
        for l in loss:
            l.backward()

        # Optimize
        trainer.step(batch_size)

        # Update metrics
        train_loss += sum([l.mean().asscalar() for l in loss])
        train_metric.update(label, output)

        if i == 100:
            break

    name, acc = train_metric.get()

    # Update history and print metrics
    train_history.update([acc])
    print('[Epoch %d] train=%f loss=%f time: %f' %
        (epoch, acc, train_loss / (i+1), time.time()-tic))

# We can plot the metric scores with:
train_history.plot()
Example #17
            if gen_update_time > 25:
                iter4G = 5

        # make a prediction
        if ep % pred_per_epoch == 0:
            fake = generator(make_noise(1))[0]
            unique_fake = generator(pred_noise)[0]
            pred_path = 'logs/pred-w1keras'
            pred_unique_path = os.path.join(pred_path, 'unique')
            makedirs(pred_path)
            makedirs(pred_unique_path)
            vis.show_img(fake.transpose((1, 2, 0)), save_path=pred_path)
            vis.show_img(unique_fake.transpose((1, 2, 0)),
                         save_path=pred_unique_path)

        # save checkpoint
        if should_save_checkpoint:
            if ep % save_per_epoch == 0:
                generator.save_parameters(
                    os.path.join(save_dir,
                                 'generator_{:04d}.params'.format(ep)))
                discriminator.save_parameters(
                    os.path.join(save_dir,
                                 'discriminator_{:04d}.params'.format(ep)))

        # save history plot every epoch
        history.plot(history_labels, save_path='logs/historys-w1keras')

history.plot(history_labels, save_path='logs/history-sw1keras')
Example #18
        # Optimize
        trainer.step(batch_size,ignore_stale_grad=True)        

        # Update metrics
        train_loss += sum([l.mean().asscalar() for l in loss])
        train_metric.update(label, output)
        if i % opt.log_interval == 0:
            name, acc = train_metric.get()
            logger.info('[Epoch %d] [%d | %d] train=%f loss=%f time: %f' %
                  (epoch,i,len(train_data), acc, train_loss / (i+1), time.time()-btic) )
            btic = time.time()

    name, acc = train_metric.get()
    
    # test
    acc_top1_val, acc_top5_val, loss_val = test(ctx, val_data)

    # Update history and print metrics
    train_history.update([acc,acc_top1_val,acc_top5_val])
    logger.info('[Epoch %d] train=%f loss=%f time: %f' %
        (epoch, acc, train_loss / (i+1), time.time()-tic))
    logger.info('[Epoch %d] val top1 =%f top5=%f val loss=%f,lr=%f' %
        (epoch, acc_top1_val, acc_top5_val, loss_val ,trainer.learning_rate ))    
    if acc_top1_val > best_val_score and epoch > 5:
        best_val_score = acc_top1_val
        net.save_parameters('%s/%.4f-%s-%s-%03d-best.params'%(opt.save_dir, best_val_score, opt.dataset, opt.model, epoch))
        trainer.save_states('%s/%.4f-%s-%s-%03d-best.states'%(opt.save_dir, best_val_score, opt.dataset, opt.model, epoch))            

# We can plot the metric scores with:
train_history.plot(save_path=os.path.join(opt.save_dir,'trainlog.jpg'))
Example #19
    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        root = os.path.join('..', 'datasets', 'cifar-10')
        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(
                root=root, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(
                root=root, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                {'learning_rate': opt.lr, 'wd': opt.wd,
                                 'momentum': opt.momentum, 'lr_scheduler': lr_sch})
        if opt.amp:
            amp.init_trainer(trainer)
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.RMSE()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(
            sparse_label=False if opt.mixup else True)
        train_history = TrainingHistory(['training-error', 'validation-error'])
        # acc_history = TrainingHistory(['training-acc', 'validation-acc'])
        loss_history = TrainingHistory(['training-loss', 'validation-loss'])

        iteration = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            for i, batch in enumerate(train_data):
                if epoch == 0 and iteration == 1 and opt.profile_mode:
                    profiler.set_state('run')
                lam = np.random.beta(alpha, alpha)
                if epoch >= epochs - 20 or not opt.mixup:
                    lam = 1

                data_1 = gluon.utils.split_and_load(
                    batch[0], ctx_list=ctx, batch_axis=0)
                label_1 = gluon.utils.split_and_load(
                    batch[1], ctx_list=ctx, batch_axis=0)

                if not opt.mixup:
                    data = data_1
                    label = label_1
                else:
                    data = [lam*X + (1-lam)*X[::-1] for X in data_1]
                    label = []
                    for Y in label_1:
                        y1 = label_transform(Y, classes)
                        y2 = label_transform(Y[::-1], classes)
                        label.append(lam*y1 + (1-lam)*y2)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                if opt.amp:
                    with ag.record():
                        with amp.scale_loss(loss, trainer) as scaled_loss:
                            ag.backward(scaled_loss)
                            # scaled_loss.backward()
                else:
                    for l in loss:
                        l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                output_softmax = [nd.SoftmaxActivation(out) for out in output]
                train_metric.update(label, output_softmax)
                metric.update(label_1, output_softmax)
                name, acc = train_metric.get()
                sw.add_scalar(tag='lr', value=trainer.learning_rate,
                              global_step=iteration)
                if epoch == 0 and iteration == 1 and opt.profile_mode:
                    nd.waitall()
                    profiler.set_state('stop')
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            _, train_acc = metric.get()
            name, val_acc, _ = test(ctx, val_data)
            if opt.mixup:
                train_history.update([acc, 1-val_acc])
                plt.cla()
                train_history.plot(save_path='%s/%s_history.png' %
                                   (plot_name, model_name))
            else:
                train_history.update([1-train_acc, 1-val_acc])
                plt.cla()
                train_history.plot(save_path='%s/%s_history.png' %
                                   (plot_name, model_name))
            # acc_history.update([train_acc, val_acc])
            # plt.cla()
            # acc_history.plot(save_path='%s/%s_acc.png' %
            #                  (plot_name, model_name), legend_loc='best')

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-cifar-%s-%d-best.params' %
                                    (save_dir, best_val_score, model_name, epoch))

            current_lr = trainer.learning_rate
            name, val_acc, val_loss = test(ctx, val_data)
            loss_history.update([train_loss, val_loss])
            plt.cla()
            loss_history.plot(save_path='%s/%s_loss.png' %
                              (plot_name, model_name), y_lim=(0, 2), legend_loc='best')
            logging.info('[Epoch %d] loss=%f train_acc=%f train_RMSE=%f\n     val_acc=%f val_loss=%f lr=%f time: %f' %
                         (epoch, train_loss, train_acc, acc, val_acc, val_loss, current_lr, time.time()-tic))
            sw._add_scalars(tag='Acc',
                            scalar_dict={'train_acc': train_acc, 'test_acc': val_acc}, global_step=epoch)
            sw._add_scalars(tag='Loss',
                            scalar_dict={'train_loss': train_loss, 'test_loss': val_loss}, global_step=epoch)
            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params' %
                                    (save_dir, model_name, epoch))
        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params' %
                                (save_dir, model_name, epochs-1))