Example #1
def get_built_in_dataset(name, train=True, input_size=224, batch_size=256, num_workers=32,
                         shuffle=True, **kwargs):
    """Returns built-in popular image classification dataset based on provided string name ('cifar10', 'cifar100','mnist','imagenet').
    """
    logger.info('get_built_in_dataset {}'.format(name))
    name = name.lower()
    if name in ['cifar10', 'cifar']:
        import gluoncv.data.transforms as gcv_transforms
        transform_split = transforms.Compose([
            gcv_transforms.RandomCrop(32, pad=4),
            transforms.RandomFlipLeftRight(),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
        ]) if train else transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
        ])
        return gluon.data.vision.CIFAR10(train=train).transform_first(transform_split)
    elif name == 'cifar100':
        import gluoncv.data.transforms as gcv_transforms
        transform_split = transforms.Compose([
            gcv_transforms.RandomCrop(32, pad=4),
            transforms.RandomFlipLeftRight(),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
        ]) if train else transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
        ])
        return gluon.data.vision.CIFAR100(train=train).transform_first(transform_split)
    elif name == 'mnist':
        def transform(data, label):
            return nd.transpose(data.astype(np.float32), (2,0,1))/255, label.astype(np.float32)
        return gluon.data.vision.MNIST(train=train, transform=transform)
    elif name == 'fashionmnist':
        def transform(data, label):
            return nd.transpose(data.astype(np.float32), (2,0,1))/255, label.astype(np.float32)
        return gluon.data.vision.FashionMNIST(train=train, transform=transform)
    elif name == 'imagenet':
        # Please setup the ImageNet dataset following the tutorial from GluonCV
        if train:
            rec_file = '/media/ramdisk/rec/train.rec'
            rec_file_idx = '/media/ramdisk/rec/train.idx'
        else:
            rec_file = '/media/ramdisk/rec/val.rec'
            rec_file_idx = '/media/ramdisk/rec/val.idx'
        data_loader = get_data_rec(input_size, 0.875, rec_file, rec_file_idx,
                                   batch_size, num_workers, train, shuffle=shuffle,
                                   **kwargs)
        return data_loader
    else:
        raise NotImplementedError
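A minimal usage sketch (assuming the function above is importable in scope): the CIFAR/MNIST branches return a Gluon Dataset that still needs to be wrapped in a DataLoader, while the 'imagenet' branch already returns a record-file loader.

from mxnet import gluon

train_set = get_built_in_dataset('cifar10', train=True)   # Dataset with transforms applied
train_loader = gluon.data.DataLoader(train_set, batch_size=128, shuffle=True,
                                     last_batch='discard', num_workers=4)
for data, label in train_loader:
    # data: (128, 3, 32, 32) float32, label: (128,)
    break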
Example #2
def get_data_loader(data_dir, batch_size, num_workers):
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    jitter_param = 0.4
    lighting_param = 0.1
    input_size = opt.input_size
    crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
    resize = int(math.ceil(input_size / crop_ratio))

    def batch_fn(batch, ctx):
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch[1],
                                           ctx_list=ctx,
                                           batch_axis=0)
        return data, label

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(input_size, pad=4),
        transforms.RandomFlipLeftRight(),
        # transforms.RandomColorJitter(
        #     brightness=jitter_param,
        #     contrast=jitter_param,
        #     saturation=jitter_param),
        # transforms.RandomLighting(lighting_param),
        transforms.ToTensor(),
        normalize
    ])
    transform_test = transforms.Compose([
        # transforms.Resize(resize, keep_ratio=True),
        # transforms.CenterCrop(input_size),
        transforms.Resize(input_size),
        transforms.ToTensor(),
        normalize
    ])

    train_data = gluon.data.DataLoader(ImageNet32(
        data_dir, train=True).transform_first(transform_train),
                                       batch_size=batch_size,
                                       shuffle=True,
                                       last_batch='discard',
                                       num_workers=num_workers)
    val_data = gluon.data.DataLoader(ImageNet32(
        data_dir, train=False).transform_first(transform_test),
                                     batch_size=batch_size,
                                     shuffle=False,
                                     num_workers=num_workers)

    return train_data, val_data, batch_fn
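A short sketch of how the returned loaders and batch_fn are meant to be used together (it assumes the module-level opt and the ImageNet32 dataset class referenced above are defined; net is a hypothetical model):

import mxnet as mx

ctx = [mx.gpu(0)] if mx.context.num_gpus() > 0 else [mx.cpu()]
train_data, val_data, batch_fn = get_data_loader('./data', batch_size=128, num_workers=4)
for batch in train_data:
    data, label = batch_fn(batch, ctx)   # lists with one NDArray slice per context
    outputs = [net(X) for X in data]     # hypothetical network defined elsewhere
    break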
Example #3
    def get_data(self, opt):
        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            gcv_transforms.RandomCrop(32, pad=4),
            transforms.RandomFlipLeftRight(),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465],
                                 [0.2023, 0.1994, 0.2010])
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465],
                                 [0.2023, 0.1994, 0.2010])
        ])

        train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=True).transform_first(transform_train),
                                           batch_size=opt.batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=opt.num_workers)

        val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=False).transform_first(transform_test),
                                         batch_size=opt.batch_size,
                                         shuffle=False,
                                         num_workers=opt.num_workers)
        return train_data, val_data, batch_fn
Example #4
def main():
    opt = parse_args()

    batch_size = opt.batch_size
    if opt.dataset == 'cifar10':
        classes = 10
    elif opt.dataset == 'cifar100':
        classes = 100
    else:
        raise ValueError('Unknown Dataset')

    if len(mx.test_utils.list_gpus()) == 0:
        context = [mx.cpu()]
    else:
        context = [mx.gpu(int(i)) for i in opt.gpus.split(',') if i.strip()]
        context = context if context else [mx.cpu()]
    print("context: ", context)
    num_gpus = len(context)
    batch_size *= max(1, num_gpus)
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        # lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = 50000 // batch_size

    lr_scheduler = LRSequential([
        LRScheduler('linear', base_lr=0, target_lr=opt.lr,
                    nepochs=opt.warmup_epochs, iters_per_epoch=num_batches),
        LRScheduler('cosine', base_lr=opt.lr, target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay, power=2)
    ])

    optimizer = 'nag'
    if opt.cosine:
        optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum, 'lr_scheduler': lr_scheduler}
    else:
        optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum}

    layers = [opt.blocks] * 3
    channels = [x * opt.channel_times for x in [16, 16, 32, 64]]
    start_layer = opt.start_layer
    askc_type = opt.askc_type

    if channels[0] == 64:
        cardinality = 32
    elif channels[0] == 128:
        cardinality = 64
    bottleneck_width = 4

    print("model: ", opt.model)
    print("askc_type: ", opt.askc_type)
    print("layers: ", layers)
    print("channels: ", channels)
    print("start_layer: ", start_layer)
    print("classes: ", classes)
    print("deep_stem: ", opt.deep_stem)

    model_prefix = opt.dataset + '-' + askc_type
    model_suffix = '-c-' + str(opt.channel_times) + '-s-' + str(opt.start_layer)
    if opt.model == 'resnet':
        net = CIFARAFFResNet(askc_type=askc_type, start_layer=start_layer, layers=layers,
                             channels=channels, classes=classes, deep_stem=opt.deep_stem)
        model_name = model_prefix + '-resnet-' + str(sum(layers) * 2 + 2) + model_suffix
    elif opt.model == 'resnext':
        net = CIFARAFFResNeXt(askc_type=askc_type, start_layer=start_layer, layers=layers,
                              channels=channels, cardinality=cardinality,
                              bottleneck_width=bottleneck_width, classes=classes,
                              deep_stem=opt.deep_stem, use_se=False)
        model_name = model_prefix + '-resneXt-' + str(sum(layers) * 3 + 2) + '-' + \
                     str(cardinality) + 'x' + str(bottleneck_width) + model_suffix
    else:
        raise ValueError('Unknown opt.model')

    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context, ignore_extra=True)

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    plot_name = opt.save_plot_dir

    logging_handlers = [logging.StreamHandler()]
    if opt.logging_dir:
        logging_dir = opt.logging_dir
        makedirs(logging_dir)
        logging_handlers.append(logging.FileHandler('%s/train_%s.log' %
                                                    (logging_dir, model_name)))

    logging.basicConfig(level=logging.INFO, handlers=logging_handlers)
    logging.info(opt)

    transform_train = []
    if opt.auto_aug:
        print('Using AutoAugment')
        from autogluon.utils.augment import AugmentationBlock, autoaug_cifar10_policies
        transform_train.append(AugmentationBlock(autoaug_cifar10_policies()))

    transform_train.extend([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])
    transform_train = transforms.Compose(transform_train)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    def label_transform(label, classes):
        ind = label.astype('int')
        res = nd.zeros((ind.shape[0], classes), ctx=label.context)
        res[nd.arange(ind.shape[0], ctx=label.context), ind] = 1
        return res

    def mixup_transform(label, classes, lam=1, eta=0.0):
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes, on_value=1 - eta + eta/classes, off_value=eta/classes)
            y2 = l[::-1].one_hot(classes, on_value=1 - eta + eta/classes, off_value=eta/classes)
            res.append(lam*y1 + (1-lam)*y2)
        return res

    def smooth(label, classes, eta=0.1):
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes, on_value=1 - eta + eta/classes, off_value=eta/classes)
            smoothed.append(res)
        return smoothed
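    # For example, with classes=10 and eta=0.1, smooth() maps label 3 to a vector with
    # 0.91 at index 3 and 0.01 elsewhere; mixup_transform() blends those smoothed
    # one-hot vectors for a batch and its reversed copy with weight lam.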

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        if opt.summary:
            summary(net, mx.nd.zeros((1, 3, 32, 32), ctx=ctx[0]))
            sys.exit()

        if opt.dataset == 'cifar10':
            train_data = gluon.data.DataLoader(
                gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
                batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
            val_data = gluon.data.DataLoader(
                gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
                batch_size=batch_size, shuffle=False, num_workers=num_workers)
        elif opt.dataset == 'cifar100':
            train_data = gluon.data.DataLoader(
                gluon.data.vision.CIFAR100(train=True).transform_first(transform_train),
                batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
            val_data = gluon.data.DataLoader(
                gluon.data.vision.CIFAR100(train=False).transform_first(transform_test),
                batch_size=batch_size, shuffle=False, num_workers=num_workers)
        else:
            raise ValueError('Unknown Dataset')

        if opt.no_wd and opt.cosine:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True

        metric = mx.metric.Accuracy()
        train_metric = mx.metric.RMSE()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)
        train_history = TrainingHistory(['training-error', 'validation-error'])

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)

            if not opt.cosine:
                if epoch == lr_decay_epoch[lr_decay_count]:
                    trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                    lr_decay_count += 1

            for i, batch in enumerate(train_data):
                data_1 = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label_1 = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)

                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if epoch >= epochs - opt.mixup_off_epoch:
                        lam = 1

                    data = [lam * X + (1 - lam) * X[::-1] for X in data_1]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label_1, classes, lam, eta)

                elif opt.label_smoothing:
                    data = data_1
                    hard_label = label_1
                    label = smooth(label_1, classes)
                else:
                    data = data_1
                    label = label_1

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                if opt.mixup:
                    output_softmax = [nd.SoftmaxActivation(out) for out in output]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, output)
                    else:
                        train_metric.update(label, output)

                name, acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            train_history.update([acc, 1 - val_acc])
            train_history.plot(save_path='%s/%s_history.png' % (plot_name, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-%s-best.params' %
                                    (save_dir, best_val_score, model_name))

            name, val_acc = test(ctx, val_data)
            logging.info('[Epoch %d] train=%f val=%f loss=%f lr: %f time: %f' %
                         (epoch, acc, val_acc, train_loss, trainer.learning_rate,
                          time.time() - tic))

        host_name = socket.gethostname()
        with open(opt.dataset + '_' + host_name + '_GPU_' + opt.gpus + '_best_Acc.log', 'a') as f:
            f.write('best Acc: {:.4f}\n'.format(best_val_score))
        print("best_val_score: ", best_val_score)

    if not opt.summary:
        if opt.mode == 'hybrid':
            net.hybridize()
    train(opt.num_epochs, context)
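The LRSequential above implements a linear warm-up followed by cosine decay. A minimal NumPy sketch of the per-iteration learning-rate curve it produces (parameter values here are illustrative, not taken from opt):

import numpy as np

def lr_at(it, base_lr=0.1, warmup_epochs=5, num_epochs=200, iters_per_epoch=390):
    # Linear warm-up from 0 to base_lr, then cosine decay from base_lr to 0.
    warmup_iters = warmup_epochs * iters_per_epoch
    total_iters = num_epochs * iters_per_epoch
    if it < warmup_iters:
        return base_lr * it / warmup_iters
    progress = (it - warmup_iters) / max(1, total_iters - warmup_iters)
    return 0.5 * base_lr * (1 + np.cos(np.pi * progress))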
Example #5
def train_net(net, config, check_flag, logger, sig_state, sig_pgbar, sig_table):
    print(config)
    # config = Configs()
    # matplotlib.use('Agg')
    # import matplotlib.pyplot as plt
    sig_pgbar.emit(-1)
    mx.random.seed(1)
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    classes = 10
    num_epochs = config.train_cfg.epoch
    batch_size = config.train_cfg.batchsize
    optimizer = config.lr_cfg.optimizer
    lr = config.lr_cfg.lr
    num_gpus = config.train_cfg.gpu
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = config.data_cfg.worker

    warmup = config.lr_cfg.warmup
    if config.lr_cfg.decay == 'cosine':
        lr_sch = lr_scheduler.CosineScheduler((50000//batch_size)*num_epochs,
                                              base_lr=lr,
                                              warmup_steps=warmup *
                                              (50000//batch_size),
                                              final_lr=1e-5)
    else:
        lr_sch = lr_scheduler.FactorScheduler((50000//batch_size)*config.lr_cfg.factor_epoch,
                                              factor=config.lr_cfg.factor,
                                              base_lr=lr,
                                              warmup_steps=warmup*(50000//batch_size))

    model_name = config.net_cfg.name

    if config.data_cfg.mixup:
        model_name += '_mixup'
    if config.train_cfg.amp:
        model_name += '_amp'

    base_dir = './'+model_name
    if os.path.exists(base_dir):
        base_dir = base_dir + '-' + \
            time.strftime("%m-%d-%H.%M.%S", time.localtime())
    makedirs(base_dir)

    if config.save_cfg.tensorboard:
        logdir = base_dir+'/tb/'+model_name
        if os.path.exists(logdir):
            logdir = logdir + '-' + \
                time.strftime("%m-%d-%H.%M.%S", time.localtime())
        sw = SummaryWriter(logdir=logdir, flush_secs=5, verbose=False)
        cmd_file = open(base_dir+'/tb.bat', mode='w')
        cmd_file.write('tensorboard --logdir=./')
        cmd_file.close()

    save_period = 10
    save_dir = base_dir+'/'+'params'
    makedirs(save_dir)

    plot_name = base_dir+'/'+'plot'
    makedirs(plot_name)

    stat_name = base_dir+'/'+'stat.txt'

    csv_name = base_dir+'/'+'data.csv'
    if os.path.exists(csv_name):
        csv_name = base_dir+'/'+'data-' + \
            time.strftime("%m-%d-%H.%M.%S", time.localtime())+'.csv'
    csv_file = open(csv_name, mode='w', newline='')
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(['Epoch', 'train_loss', 'train_acc',
                         'valid_loss', 'valid_acc', 'lr', 'time'])

    logging_handlers = [logging.StreamHandler(), logger]
    logging_handlers.append(logging.FileHandler(
        '%s/train_cifar10_%s.log' % (base_dir, model_name)))

    logging.basicConfig(level=logging.INFO, handlers=logging_handlers)
    logging.info(config)

    if config.train_cfg.amp:
        amp.init()

    if config.save_cfg.profiler:
        profiler.set_config(profile_all=True,
                            aggregate_stats=True,
                            continuous_dump=True,
                            filename=base_dir+'/%s_profile.json' % model_name)
        is_profiler_run = False

    trans_list = []
    imgsize = config.data_cfg.size
    if config.data_cfg.crop:
        trans_list.append(gcv_transforms.RandomCrop(
            32, pad=config.data_cfg.crop_pad))
    if config.data_cfg.cutout:
        trans_list.append(CutOut(config.data_cfg.cutout_size))
    if config.data_cfg.flip:
        trans_list.append(transforms.RandomFlipLeftRight())
    if config.data_cfg.erase:
        trans_list.append(gcv_transforms.block.RandomErasing(s_max=0.25))
    trans_list.append(transforms.Resize(imgsize))
    trans_list.append(transforms.ToTensor())
    trans_list.append(transforms.Normalize([0.4914, 0.4822, 0.4465],
                                           [0.2023, 0.1994, 0.2010]))

    transform_train = transforms.Compose(trans_list)

    transform_test = transforms.Compose([
        transforms.Resize(imgsize),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    def label_transform(label, classes):
        ind = label.astype('int')
        res = nd.zeros((ind.shape[0], classes), ctx=label.context)
        res[nd.arange(ind.shape[0], ctx=label.context), ind] = 1
        return res

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
        num_batch = len(val_data)
        test_loss = 0
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(
                batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(
                batch[1], ctx_list=ctx, batch_axis=0)
            outputs = [net(X) for X in data]
            loss = [loss_fn(yhat, y) for yhat, y in zip(outputs, label)]
            metric.update(label, outputs)
            test_loss += sum([l.sum().asscalar() for l in loss])
        test_loss /= batch_size * num_batch
        name, val_acc = metric.get()
        return name, val_acc, test_loss

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]

        if config.train_cfg.param_init:
            init_func = getattr(mx.init, config.train_cfg.init)
            net.initialize(init_func(), ctx=ctx, force_reinit=True)
        else:
            net.load_parameters(config.train_cfg.param_file, ctx=ctx)

        summary(net, stat_name, nd.uniform(
            shape=(1, 3, imgsize, imgsize), ctx=ctx[0]))
        # net = nn.HybridBlock()
        net.hybridize()

        root = config.dir_cfg.dataset
        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(
                root=root, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(
                root=root, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        trainer_arg = {'learning_rate': config.lr_cfg.lr,
                       'wd': config.lr_cfg.wd, 'lr_scheduler': lr_sch}
        extra_arg = eval(config.lr_cfg.extra_arg)
        trainer_arg.update(extra_arg)
        trainer = gluon.Trainer(net.collect_params(), optimizer, trainer_arg)
        if config.train_cfg.amp:
            amp.init_trainer(trainer)
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.RMSE()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(
            sparse_label=False if config.data_cfg.mixup else True)
        train_history = TrainingHistory(['training-error', 'validation-error'])
        # acc_history = TrainingHistory(['training-acc', 'validation-acc'])
        loss_history = TrainingHistory(['training-loss', 'validation-loss'])

        iteration = 0

        best_val_score = 0

        # print('start training')
        sig_state.emit(1)
        sig_pgbar.emit(0)
        # signal.emit('Training')
        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1
            for i, batch in enumerate(train_data):
                if epoch == 0 and iteration == 1 and config.save_cfg.profiler:
                    profiler.set_state('run')
                    is_profiler_run = True
                if epoch == 0 and iteration == 1 and config.save_cfg.tensorboard:
                    sw.add_graph(net)
                lam = np.random.beta(alpha, alpha)
                if epoch >= epochs - 20 or not config.data_cfg.mixup:
                    lam = 1

                data_1 = gluon.utils.split_and_load(
                    batch[0], ctx_list=ctx, batch_axis=0)
                label_1 = gluon.utils.split_and_load(
                    batch[1], ctx_list=ctx, batch_axis=0)

                if not config.data_cfg.mixup:
                    data = data_1
                    label = label_1
                else:
                    data = [lam*X + (1-lam)*X[::-1] for X in data_1]
                    label = []
                    for Y in label_1:
                        y1 = label_transform(Y, classes)
                        y2 = label_transform(Y[::-1], classes)
                        label.append(lam*y1 + (1-lam)*y2)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                if config.train_cfg.amp:
                    with ag.record():
                        with amp.scale_loss(loss, trainer) as scaled_loss:
                            ag.backward(scaled_loss)
                            # scaled_loss.backward()
                else:
                    for l in loss:
                        l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                output_softmax = [nd.SoftmaxActivation(out) for out in output]
                train_metric.update(label, output_softmax)
                metric.update(label_1, output_softmax)
                name, acc = train_metric.get()
                if config.save_cfg.tensorboard:
                    sw.add_scalar(tag='lr', value=trainer.learning_rate,
                                  global_step=iteration)
                if epoch == 0 and iteration == 1 and config.save_cfg.profiler:
                    nd.waitall()
                    profiler.set_state('stop')
                    profiler.dump()
                iteration += 1
                sig_pgbar.emit(iteration)
                if check_flag()[0]:
                    sig_state.emit(2)
                while(check_flag()[0] or check_flag()[1]):
                    if check_flag()[1]:
                        print('stop')
                        return
                    else:
                        time.sleep(5)
                        print('pausing')

            epoch_time = time.time() - tic
            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            _, train_acc = metric.get()
            name, val_acc, _ = test(ctx, val_data)
            # if config.data_cfg.mixup:
            #     train_history.update([acc, 1-val_acc])
            #     plt.cla()
            #     train_history.plot(save_path='%s/%s_history.png' %
            #                        (plot_name, model_name))
            # else:
            train_history.update([1-train_acc, 1-val_acc])
            plt.cla()
            train_history.plot(save_path='%s/%s_history.png' %
                               (plot_name, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-cifar-%s-%d-best.params' %
                                    (save_dir, best_val_score, model_name, epoch))

            current_lr = trainer.learning_rate
            name, val_acc, val_loss = test(ctx, val_data)

            logging.info('[Epoch %d] loss=%f train_acc=%f train_RMSE=%f\n     val_acc=%f val_loss=%f lr=%f time: %f' %
                         (epoch, train_loss, train_acc, acc, val_acc, val_loss, current_lr, epoch_time))
            loss_history.update([train_loss, val_loss])
            plt.cla()
            loss_history.plot(save_path='%s/%s_loss.png' %
                              (plot_name, model_name), y_lim=(0, 2), legend_loc='best')
            if config.save_cfg.tensorboard:
                sw._add_scalars(tag='Acc',
                                scalar_dict={'train_acc': train_acc, 'test_acc': val_acc}, global_step=epoch)
                sw._add_scalars(tag='Loss',
                                scalar_dict={'train_loss': train_loss, 'test_loss': val_loss}, global_step=epoch)

            sig_table.emit([epoch, train_loss, train_acc,
                            val_loss, val_acc, current_lr, epoch_time])
            csv_writer.writerow([epoch, train_loss, train_acc,
                                 val_loss, val_acc, current_lr, epoch_time])
            csv_file.flush()

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params' %
                                    (save_dir, model_name, epoch))
        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params' %
                                (save_dir, model_name, epochs-1))

    train(num_epochs, context)
    if config.save_cfg.tensorboard:
        sw.close()

    for ctx in context:
        ctx.empty_cache()

    csv_file.close()
    logging.shutdown()
    reload(logging)
    sig_state.emit(0)
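A self-contained NumPy sketch of the mixup step used in the training loop above: the batch is blended with its reversed copy, and labels are blended as one-hot vectors with the same weight lam (classes and alpha are illustrative defaults):

import numpy as np

def mixup_batch(data, label, classes=10, alpha=1.0):
    # data: (N, C, H, W) float array, label: (N,) integer array
    lam = np.random.beta(alpha, alpha)
    mixed = lam * data + (1 - lam) * data[::-1]
    one_hot = np.eye(classes)[label]
    mixed_label = lam * one_hot + (1 - lam) * one_hot[::-1]
    return mixed, mixed_label, lam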
Example #6
def get_built_in_dataset(name,
                         train=True,
                         input_size=224,
                         batch_size=256,
                         num_workers=32,
                         shuffle=True,
                         **kwargs):
    """AutoGluonFunction
    """
    print('get_built_in_dataset', name)
    if name == 'cifar10' or name == 'cifar':
        transform_split = transforms.Compose([
            gcv_transforms.RandomCrop(32, pad=4),
            transforms.RandomFlipLeftRight(),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465],
                                 [0.2023, 0.1994, 0.2010])
        ]) if train else transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465],
                                 [0.2023, 0.1994, 0.2010])
        ])
        return gluon.data.vision.CIFAR10(
            train=train).transform_first(transform_split)
    elif name == 'mnist':

        def transform(data, label):
            return nd.transpose(data.astype(np.float32),
                                (2, 0, 1)) / 255, label.astype(np.float32)

        return gluon.data.vision.MNIST(train=train, transform=transform)
    elif name == 'cifar100':
        transform_split = transforms.Compose([
            gcv_transforms.RandomCrop(32, pad=4),
            transforms.RandomFlipLeftRight(),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465],
                                 [0.2023, 0.1994, 0.2010])
        ]) if train else transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465],
                                 [0.2023, 0.1994, 0.2010])
        ])
        return gluon.data.vision.CIFAR100(
            train=train).transform_first(transform_split)
    elif name == 'imagenet':
        # Please setup the ImageNet dataset following the tutorial from GluonCV
        if train:
            rec_file = '/media/ramdisk/rec/train.rec'
            rec_file_idx = '/media/ramdisk/rec/train.idx'
        else:
            rec_file = '/media/ramdisk/rec/val.rec'
            rec_file_idx = '/media/ramdisk/rec/val.idx'
        data_loader = get_data_rec(input_size,
                                   0.875,
                                   rec_file,
                                   rec_file_idx,
                                   batch_size,
                                   num_workers,
                                   train,
                                   shuffle=shuffle,
                                   **kwargs)
        return data_loader
    else:
        raise NotImplementedError
Example #7
def main():
    opt = parse_args()

    bps.init()

    gpu_name = subprocess.check_output(
        ['nvidia-smi', '--query-gpu=gpu_name', '--format=csv'])
    gpu_name = gpu_name.decode('utf8').split('\n')[-2]
    gpu_name = '-'.join(gpu_name.split())
    filename = "cifar100-%d-%s-%s.log" % (bps.size(), gpu_name,
                                          opt.logging_file)
    filehandler = logging.FileHandler(filename)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 100

    num_gpus = opt.num_gpus
    # batch_size *= max(1, num_gpus)
    context = mx.gpu(bps.local_rank()) if num_gpus > 0 else mx.cpu(
        bps.local_rank())
    num_workers = opt.num_workers
    nworker = bps.size()
    rank = bps.rank()

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]

    num_batches = 50000 // (opt.batch_size * nworker)
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=opt.warmup_lr,
                    target_lr=opt.lr * nworker / bps.local_size(),
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler('step',
                    base_lr=opt.lr * nworker / bps.local_size(),
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes, 'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)
    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)

    if opt.compressor:
        optimizer = 'sgd'
    else:
        optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    # from https://github.com/weiaicunzai/pytorch-cifar/blob/master/conf/global_settings.py
    CIFAR100_TRAIN_MEAN = [
        0.5070751592371323, 0.48654887331495095, 0.4409178433670343
    ]
    CIFAR100_TRAIN_STD = [
        0.2673342858792401, 0.2564384629170883, 0.27615047132568404
    ]

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=True).shard(nworker, rank).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)

        val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=False).shard(nworker, rank).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        params = net.collect_params()

        compression_params = {
            "compressor": opt.compressor,
            "ef": opt.ef,
            "momentum": opt.compress_momentum,
            "scaling": opt.onebit_scaling,
            "k": opt.k,
            "fp16": opt.fp16_pushpull
        }

        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'wd': opt.wd,
            'momentum': opt.momentum
        }

        trainer = bps.DistributedTrainer(params,
                                         optimizer,
                                         optimizer_params,
                                         compression_params=compression_params)
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        iteration = 0
        best_val_score = 0
        bps.byteps_declare_tensor("acc")
        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, train_acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, train_acc = train_metric.get()
            throughput = int(batch_size * nworker * i / (time.time() - tic))

            logger.info(
                '[Epoch %d] speed: %d samples/sec\ttime cost: %f lr=%f' %
                (epoch, throughput, time.time() - tic, trainer.learning_rate))

            name, val_acc = test(ctx, val_data)
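            # Average train/val accuracy across BytePS workers: push_pull with
            # is_average=False sums the tensor over workers, and dividing by bps.size()
            # yields the global mean that rank 0 logs below.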
            acc = mx.nd.array([train_acc, val_acc], ctx=ctx[0])
            bps.byteps_push_pull(acc, name="acc", is_average=False)
            acc /= bps.size()
            train_acc, val_acc = acc[0].asscalar(), acc[1].asscalar()
            if bps.rank() == 0:
                logger.info('[Epoch %d] training: %s=%f' %
                            (epoch, name, train_acc))
                logger.info('[Epoch %d] validation: %s=%f' %
                            (epoch, name, val_acc))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters(
                    '%s/%.4f-cifar-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar100-%s-%d.params' %
                                    (save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar100-%s-%d.params' %
                                (save_dir, model_name, epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize()
    train(opt.num_epochs, context)
Example #8
def train_cifar(args, reporter):
    print('args', args)
    batch_size = args.batch_size

    num_gpus = args.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = args.num_workers

    model_name = args.model
    net = get_model(model_name, classes=10)

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
            train=True).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)

        val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
            train=False).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        lr_scheduler = LRScheduler(mode='cosine',
                                   base_lr=args.lr,
                                   nepochs=args.epochs,
                                   iters_per_epoch=len(train_data))
        trainer = gluon.Trainer(net.collect_params(), 'sgd', {
            'lr_scheduler': lr_scheduler,
            'wd': args.wd,
            'momentum': args.momentum
        })
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        iteration = 0
        best_val_score = 0

        start_epoch = 0

        for epoch in range(start_epoch, epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)

                with mx.autograd.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            reporter(epoch=epoch, accuracy=val_acc)
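            # The reporter callback streams per-epoch validation accuracy back to the
            # hyperparameter-search scheduler (e.g. AutoGluon) that invoked train_cifar,
            # so it can rank and early-stop trials.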

    train(args.epochs, context)
Example #9
def main():
    opt = parse_args()

    batch_size = opt.batch_size
    classes = 10

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes,
                  'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)
    model_name += '_mixup'
    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx = context)
    optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    plot_name = opt.save_plot_dir

    logging_handlers = [logging.StreamHandler()]
    if opt.logging_dir:
        logging_dir = opt.logging_dir
        makedirs(logging_dir)
        logging_handlers.append(logging.FileHandler('%s/train_cifar10_%s.log'%(logging_dir, model_name)))

    logging.basicConfig(level=logging.INFO, handlers = logging_handlers)
    logging.info(opt)

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    def label_transform(label, classes):
        ind = label.astype('int')
        res = nd.zeros((ind.shape[0], classes), ctx = label.context)
        res[nd.arange(ind.shape[0], ctx = label.context), ind] = 1
        return res

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum})
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.RMSE()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)
        train_history = TrainingHistory(['training-error', 'validation-error'])

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate*lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
                lam = np.random.beta(alpha, alpha)
                if epoch >= epochs - 20:
                    lam = 1

                data_1 = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label_1 = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)

                data = [lam*X + (1-lam)*X[::-1] for X in data_1]
                label = []
                for Y in label_1:
                    y1 = label_transform(Y, classes)
                    y2 = label_transform(Y[::-1], classes)
                    label.append(lam*y1 + (1-lam)*y2)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                output_softmax = [nd.SoftmaxActivation(out) for out in output]
                train_metric.update(label, output_softmax)
                name, acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            train_history.update([acc, 1-val_acc])
            train_history.plot(save_path='%s/%s_history.png'%(plot_name, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-cifar-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))

            name, val_acc = test(ctx, val_data)
            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                (epoch, acc, val_acc, train_loss, time.time()-tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epochs-1))

    if opt.mode == 'hybrid':
        net.hybridize()
    train(opt.num_epochs, context)
Example #10
def test(model):
    if args.dataset == 'cifar10':
        transform_train = transforms.Compose([
            gcv_transforms.RandomCrop(32, pad=4),
            transforms.RandomFlipLeftRight(),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465],
                                 [0.2023, 0.1994, 0.2010])
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465],
                                 [0.2023, 0.1994, 0.2010])
        ])
        train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
            train=True).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=args.num_workers)

        val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
            train=False).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=args.num_workers)
    elif args.dataset == 'cifar100':
        transform_train = transforms.Compose([
            gcv_transforms.RandomCrop(32, pad=4),
            transforms.RandomFlipLeftRight(),
            transforms.ToTensor(),
            transforms.Normalize([0.5071, 0.4865, 0.4409],
                                 [0.2673, 0.2564, 0.2762])
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5071, 0.4865, 0.4409],
                                 [0.2673, 0.2564, 0.2762])
        ])
        train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=True).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=args.num_workers)

        val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=False).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=args.num_workers)

    metric = mx.metric.Accuracy()
    metric.reset()
    for i, batch in enumerate(val_data):
        data = gluon.utils.split_and_load(batch[0],
                                          ctx_list=context,
                                          batch_axis=0)
        label = gluon.utils.split_and_load(batch[1],
                                           ctx_list=context,
                                           batch_axis=0)
        outputs = [model(X.astype(args.dtype, copy=False)) for X in data]
        metric.update(label, outputs)
        _, acc = metric.get()
    print('\nTest-set Accuracy: ', acc)
    return acc
Example #11
save_dir = save_dir + opt.save_name

makedirs(save_dir)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_save_dir = os.path.join(save_dir, 'log.txt')
fh = logging.FileHandler(log_save_dir)
fh.setLevel(logging.INFO)
logger.addHandler(fh)
logger.info(opt)

if opt.dataset == 'NC_CUB200':
    transform_train = transforms.Compose([
        # transforms.RandomResizedCrop(84),
        transforms.Resize(256),
        gcv_transforms.RandomCrop(224, pad=0),
        # gcv_transforms.RandomCrop(224, pad=8),
        transforms.RandomFlipLeftRight(),
        # transforms.RandomColorJitter(brightness=0.4, contrast=0.4,
        #                              saturation=0.4),
        # transforms.RandomLighting(0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
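A minimal sketch of how the transforms above are typically attached to a Gluon dataset; the ImageFolderDataset root and batch size below are placeholders, not values from this example.

# Hypothetical usage sketch: feed the CUB200-style transform_test defined above
# into a DataLoader. The dataset root and batch size are placeholder values.
from mxnet import gluon
from mxnet.gluon.data.vision import ImageFolderDataset

# transform_first applies the transform to the image only and leaves the label alone
val_set = ImageFolderDataset('./data/CUB_200_2011/val').transform_first(transform_test)
val_loader = gluon.data.DataLoader(val_set, batch_size=32, shuffle=False, num_workers=4)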
Example #12
def train(net, ctx):
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    if not opt.resume_from:
        net.initialize(mx.init.Xavier(), ctx=ctx)
        if opt.dataset == 'NC_CIFAR100':
            n = mx.nd.zeros(shape=(1, 3, 32, 32), ctx=ctx[0])  # dummy input to init the CNN
        else:
            raise KeyError('dataset keyerror')
        for m in range(9):
            net(n, m)

    def makeSchedule(start_lr, base_lr, length, step, factor):
        schedule = mx.lr_scheduler.MultiFactorScheduler(step=step,
                                                        factor=factor)
        schedule.base_lr = base_lr
        schedule = LinearWarmUp(schedule, start_lr=start_lr, length=length)
        return schedule


    # ==========================================================================

    sesses = list(np.arange(opt.sess_num))
    epochs = [opt.epoch] * opt.sess_num
    lrs = [opt.base_lrs] + [opt.lrs] * (opt.sess_num - 1)
    lr_decay = opt.lr_decay
    base_decay_epoch = [int(i)
                        for i in opt.base_decay_epoch.split(',')] + [np.inf]
    lr_decay_epoch = [base_decay_epoch
                      ] + [[opt.inc_decay_epoch, np.inf]] * (opt.sess_num - 1)

    AL_weight = opt.AL_weight
    min_weight = opt.min_weight
    oce_weight = opt.oce_weight
    pdl_weight = opt.pdl_weight
    max_weight = opt.max_weight
    temperature = opt.temperature

    use_AL = opt.use_AL  # anchor loss
    use_ng_min = opt.use_ng_min  # Neural Gas min loss
    use_ng_max = opt.use_ng_max  # Neural Gas max loss
    ng_update = opt.ng_update  # Neural Gas update node
    use_oce = opt.use_oce  # old samples cross entropy loss
    use_pdl = opt.use_pdl  # probability distillation loss
    use_nme = opt.use_nme  # Similarity loss
    use_warmUp = opt.use_warmUp
    use_ng = opt.use_ng  # Neural Gas
    fix_conv = opt.fix_conv  # fix cnn to train novel classes
    fix_epoch = opt.fix_epoch
    c_way = opt.c_way
    k_shot = opt.k_shot
    base_acc = opt.base_acc  # base model acc
    select_best_method = opt.select_best  # select from _best, _best2, _best3
    init_class = 60
    anchor_num = 400
    # ==========================================================================
    acc_dict = {}
    all_best_e = []
    if model_name[-7:] != 'maxhead':
        # re-initialize all incremental-session classifier heads (fc3 .. fc10)
        for fc_name in ['fc3', 'fc4', 'fc5', 'fc6', 'fc7', 'fc8', 'fc9', 'fc10']:
            getattr(net, fc_name).initialize(mx.init.Normal(sigma=0.001),
                                             ctx=ctx,
                                             force_reinit=True)

    for sess in sesses:
        logger.info('session : %d' % sess)
        schedule = makeSchedule(start_lr=0,
                                base_lr=lrs[sess],
                                length=5,
                                step=lr_decay_epoch[sess],
                                factor=lr_decay)

        # prepare the first anchor batch
        if sess == 0 and opt.resume_from:
            acc_dict[str(sess)] = list()
            acc_dict[str(sess)].append([base_acc, 0])
            all_best_e.append(0)
            continue
        # quick cnn totally unfix, not use data augmentation
        if sess == 1 and model_name == 'quick_cnn' and use_AL:
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.4914, 0.4822, 0.4465],
                                     [0.2023, 0.1994, 0.2010])
            ])

            transform_test = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.4914, 0.4822, 0.4465],
                                     [0.2023, 0.1994, 0.2010])
            ])
            anchor_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.4914, 0.4822, 0.4465],
                                     [0.2023, 0.1994, 0.2010])
            ])
        else:
            transform_train = transforms.Compose([
                gcv_transforms.RandomCrop(32, pad=4),
                transforms.RandomFlipLeftRight(),
                transforms.ToTensor(),
                transforms.Normalize([0.5071, 0.4866, 0.4409],
                                     [0.2009, 0.1984, 0.2023])
            ])

            transform_test = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.5071, 0.4866, 0.4409],
                                     [0.2009, 0.1984, 0.2023])
            ])

            anchor_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.5071, 0.4866, 0.4409],
                                     [0.2009, 0.1984, 0.2023])
            ])
        # ng_init and ng_update
        if use_AL or use_nme or use_pdl or use_oce:
            if sess != 0:
                if ng_update == True:
                    if sess == 1:
                        update_anchor1, bmu, variances= \
                            prepare_anchor(DATASET,logger,anchor_trans,num_workers,feature_size,net,ctx,use_ng,init_class)
                        update_anchor_data = DataLoader(
                            update_anchor1,
                            anchor_trans,
                            update_anchor1.__len__(),
                            num_workers,
                            shuffle=False)
                        if opt.ng_var:
                            idx_1 = np.where(variances.asnumpy() > 0.5)
                            idx_2 = np.where(variances.asnumpy() < 0.5)
                            variances[idx_1] = 0.9
                            variances[idx_2] = 1
                    else:
                        base_class = init_class + (sess - 1) * 5
                        new_class = list(init_class + (sess - 1) * 5 +
                                         (np.arange(5)))
                        new_set = DATASET(train=True,
                                          fine_label=True,
                                          fix_class=new_class,
                                          base_class=base_class,
                                          logger=logger)
                        update_anchor2 = merge_datasets(
                            update_anchor1, new_set)
                        update_anchor_data = DataLoader(
                            update_anchor2,
                            anchor_trans,
                            update_anchor2.__len__(),
                            num_workers,
                            shuffle=False)
                elif (sess == 1):
                    update_anchor, bmu, variances =  \
                        prepare_anchor(DATASET,logger,anchor_trans,num_workers,feature_size,net,ctx,use_ng,init_class)
                    update_anchor_data = DataLoader(update_anchor,
                                                    anchor_trans,
                                                    update_anchor.__len__(),
                                                    num_workers,
                                                    shuffle=False)

                    if opt.ng_var:
                        idx_1 = np.where(variances.asnumpy() > 0.5)
                        idx_2 = np.where(variances.asnumpy() < 0.5)
                        variances[idx_1] = 0.9
                        variances[idx_2] = 1

                for batch in update_anchor_data:
                    anc_data = gluon.utils.split_and_load(batch[0],
                                                          ctx_list=[ctx[0]],
                                                          batch_axis=0)
                    anc_label = gluon.utils.split_and_load(batch[1],
                                                           ctx_list=[ctx[0]],
                                                           batch_axis=0)
                    with ag.pause():
                        anchor_feat, anchor_logit = net(anc_data[0], sess - 1)
                    anchor_feat = [anchor_feat]
                    anchor_logit = [anchor_logit]

        trainer = gluon.Trainer(net.collect_params(), optimizer, {
            'learning_rate': lrs[sess],
            'wd': opt.wd,
            'momentum': opt.momentum
        })

        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        # ==========================================================================
        # all loss init
        if use_nme:

            def loss_fn_disG(f1, f2, weight):
                f1 = f1.reshape(anchor_num, -1)
                f2 = f2.reshape(anchor_num, -1)
                similar = mx.nd.sum(f1 * f2, 1)
                return (1 - similar) * weight

            digG_weight = opt.nme_weight
        if use_AL:
            if model_name == 'quick_cnn':
                AL_w = [120, 75, 120, 100, 50, 60, 90, 90]
                AL_weight = AL_w[sess - 1]
            else:
                AL_weight = opt.AL_weight
            if opt.ng_var:

                def l2lossVar(feat, anc, weight, var):
                    dim = feat.shape[1]
                    feat = feat.reshape(-1, dim)
                    anc = anc.reshape(-1, dim)
                    loss = mx.nd.square(feat - anc)
                    loss = loss * weight * var
                    return mx.nd.mean(loss, axis=0, exclude=True)

                loss_fn_AL = l2lossVar
            else:
                loss_fn_AL = gluon.loss.L2Loss(weight=AL_weight)

        if use_pdl:
            loss_fn_pdl = DistillationSoftmaxCrossEntropyLoss(
                temperature=temperature, hard_weight=0, weight=pdl_weight)
        if use_oce:
            loss_fn_oce = gluon.loss.SoftmaxCrossEntropyLoss(weight=oce_weight)
        if use_ng_max:
            loss_fn_max = NG_Max_Loss(lmbd=max_weight, margin=0.5)
        if use_ng_min:
            min_loss = NG_Min_Loss(
                num_classes=opt.c_way,
                feature_size=feature_size,
                lmbd=min_weight,  # center weight = 0.01 in the paper
                ctx=ctx[0])
            min_loss.initialize(mx.init.Xavier(magnitude=2.24),
                                ctx=ctx,
                                force_reinit=True)  # init matrix.
            center_trainer = gluon.Trainer(
                min_loss.collect_params(),
                optimizer="sgd",
                optimizer_params={"learning_rate":
                                  opt.ng_min_lr})  # alpha=0.1 in the paper.
        # ==========================================================================
        lr_decay_count = 0

        # dataloader
        if opt.cum and sess == 1:
            base_class = list(np.arange(init_class))
            joint_data = DATASET(train=True,
                                 fine_label=True,
                                 c_way=init_class,
                                 k_shot=500,
                                 fix_class=base_class,
                                 logger=logger)

        if sess == 0:
            base_class = list(np.arange(init_class))
            new_class = list(init_class + (np.arange(5)))
            base_data = DATASET(train=True,
                                fine_label=True,
                                c_way=init_class,
                                k_shot=500,
                                fix_class=base_class,
                                logger=logger)
            bc_val_data = DataLoader(DATASET(train=False,
                                             fine_label=True,
                                             fix_class=base_class,
                                             logger=logger),
                                     transform_test,
                                     100,
                                     num_workers,
                                     shuffle=False)
            nc_val_data = DataLoader(DATASET(train=False,
                                             fine_label=True,
                                             fix_class=new_class,
                                             base_class=len(base_class),
                                             logger=logger),
                                     transform_test,
                                     100,
                                     num_workers,
                                     shuffle=False)
        else:
            base_class = list(np.arange(init_class + (sess - 1) * 5))
            new_class = list(init_class + (sess - 1) * 5 + (np.arange(5)))
            train_data_nc = DATASET(train=True,
                                    fine_label=True,
                                    c_way=c_way,
                                    k_shot=k_shot,
                                    fix_class=new_class,
                                    base_class=len(base_class),
                                    logger=logger)
            bc_val_data = DataLoader(DATASET(train=False,
                                             fine_label=True,
                                             fix_class=base_class,
                                             logger=logger),
                                     transform_test,
                                     100,
                                     num_workers,
                                     shuffle=False)
            nc_val_data = DataLoader(DATASET(train=False,
                                             fine_label=True,
                                             fix_class=new_class,
                                             base_class=len(base_class),
                                             logger=logger),
                                     transform_test,
                                     100,
                                     num_workers,
                                     shuffle=False)

        if sess == 0:
            train_data = DataLoader(base_data,
                                    transform_train,
                                    min(batch_size, base_data.__len__()),
                                    num_workers,
                                    shuffle=True)
        else:
            if opt.cum:  # cumulative : merge base and novel dataset.
                joint_data = merge_datasets(joint_data, train_data_nc)
                train_data = DataLoader(joint_data,
                                        transform_train,
                                        min(batch_size, joint_data.__len__()),
                                        num_workers,
                                        shuffle=True)

            elif opt.use_all_novel:  # use all novel data

                if sess == 1:
                    novel_data = train_data_nc
                else:
                    novel_data = merge_datasets(novel_data, train_data_nc)

                train_data = DataLoader(novel_data,
                                        transform_train,
                                        min(batch_size, novel_data.__len__()),
                                        num_workers,
                                        shuffle=True)
            else:  # basic method
                train_data = DataLoader(train_data_nc,
                                        transform_train,
                                        min(batch_size,
                                            train_data_nc.__len__()),
                                        num_workers,
                                        shuffle=True)

        for epoch in range(epochs[sess]):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss, train_anchor_loss, train_oce_loss = 0, 0, 0
            train_disg_loss, train_pdl_loss, train_min_loss = 0, 0, 0
            train_max_loss = 0
            num_batch = len(train_data)

            if use_warmUp:
                lr = schedule(epoch)
                trainer.set_learning_rate(lr)
            else:
                lr = trainer.learning_rate
                if epoch == lr_decay_epoch[sess][lr_decay_count]:
                    trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                    lr_decay_count += 1

            if sess != 0 and epoch < fix_epoch:
                fix_cnn = fix_conv
            else:
                fix_cnn = False

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)
                all_loss = list()
                with ag.record():
                    output_feat, output = net(data[0], sess, fix_cnn)
                    output_feat = [output_feat]
                    output = [output]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                    all_loss.extend(loss)

                    if use_nme:
                        anchor_h = [net(X, sess, fix_cnn)[0] for X in anc_data]
                        disg_loss = [
                            loss_fn_disG(a_h, a, weight=digG_weight)
                            for a_h, a in zip(anchor_h, anchor_feat)
                        ]
                        all_loss.extend(disg_loss)

                    if sess > 0 and use_ng_max:
                        max_loss = [
                            loss_fn_max(feat, label, feature_size, epoch, sess,
                                        init_class)
                            for feat, label in zip(output_feat, label)
                        ]
                        all_loss.extend(max_loss[0])

                    if sess > 0 and use_AL:  # For anchor loss
                        anchor_h = [net(X, sess, fix_cnn)[0] for X in anc_data]
                        if opt.ng_var:
                            anchor_loss = [
                                loss_fn_AL(anchor_h[0], anchor_feat[0],
                                           AL_weight, variances)
                            ]
                            all_loss.extend(anchor_loss)
                        else:
                            anchor_loss = [
                                loss_fn_AL(a_h, a)
                                for a_h, a in zip(anchor_h, anchor_feat)
                            ]
                            all_loss.extend(anchor_loss)
                    if sess > 0 and use_ng_min:
                        loss_min = min_loss(output_feat[0], label[0])
                        all_loss.extend(loss_min)

                    if sess > 0 and use_pdl:
                        anchor_l = [net(X, sess, fix_cnn)[1] for X in anc_data]
                        anchor_l = [anchor_l[0][:, :60 + (sess - 1) * 5]]
                        soft_label = [
                            mx.nd.softmax(anchor_logit[0][:, :60 +
                                                          (sess - 1) * 5] /
                                          temperature)
                        ]
                        pdl_loss = [
                            loss_fn_pdl(a_h, a,
                                        soft_a) for a_h, a, soft_a in zip(
                                            anchor_l, anc_label, soft_label)
                        ]
                        all_loss.extend(pdl_loss)

                    if sess > 0 and use_oce:
                        anchorp = [net(X, sess, fix_cnn)[1] for X in anc_data]
                        oce_Loss = [
                            loss_fn_oce(ap, a)
                            for ap, a in zip(anchorp, anc_label)
                        ]
                        all_loss.extend(oce_Loss)

                    all_loss = [nd.mean(l) for l in all_loss]

                ag.backward(all_loss)
                trainer.step(1, ignore_stale_grad=True)
                if use_ng_min:
                    center_trainer.step(opt.c_way * opt.k_shot)
                train_loss += sum([l.sum().asscalar() for l in loss])
                if sess > 0 and use_AL:
                    train_anchor_loss += sum(
                        [al.mean().asscalar() for al in anchor_loss])
                if sess > 0 and use_oce:
                    train_oce_loss += sum(
                        [al.mean().asscalar() for al in oce_Loss])
                if sess > 0 and use_nme:
                    train_disg_loss += sum(
                        [al.mean().asscalar() for al in disg_loss])
                if sess > 0 and use_pdl:
                    train_pdl_loss += sum(
                        [al.mean().asscalar() for al in pdl_loss])
                if sess > 0 and use_ng_min:
                    train_min_loss += sum(
                        [al.mean().asscalar() for al in loss_min])
                if sess > 0 and use_ng_max:
                    train_max_loss += sum(
                        [al.mean().asscalar() for al in max_loss[0]])
                train_metric.update(label, output)

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, bc_val_acc = test(ctx, bc_val_data, net, sess)
            name, nc_val_acc = test(ctx, nc_val_data, net, sess)

            if epoch == 0:
                acc_dict[str(sess)] = list()
            acc_dict[str(sess)].append([bc_val_acc, nc_val_acc])

            if sess == 0:
                overall = bc_val_acc
            else:
                overall = (bc_val_acc * (init_class + (sess - 1) * 5) +
                           nc_val_acc * 5) / (init_class + sess * 5)
            logger.info(
                '[Epoch %d] lr=%.4f train=%.4f | val(base)=%.4f val(novel)=%.4f overall=%.4f | loss=%.8f anc loss=%.8f '
                'pdl loss:%.8f oce loss: %.8f time: %.8f' %
                (epoch, lr, acc, bc_val_acc, nc_val_acc, overall, train_loss,
                 train_anchor_loss / AL_weight, train_pdl_loss / pdl_weight,
                 train_oce_loss / oce_weight, time.time() - tic))
            if use_nme:
                logger.info('digG loss:%.8f' % (train_disg_loss / digG_weight))
            if use_ng_min:
                logger.info('min_loss:%.8f' % (train_min_loss / min_weight))
            if use_ng_max:
                logger.info('max_loss:%.8f' % (train_max_loss / max_weight))
            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/sess-%s-%d.params' %
                                    (save_dir, model_name, epoch))

            select = eval(select_best_method)
            best_e = select(acc_dict, sess)
            logger.info('best select : base: %f novel: %f ' %
                        (acc_dict[str(sess)][best_e][0],
                         acc_dict[str(sess)][best_e][1]))
            if use_AL and model_name == 'quick_cnn':
                reload_path = '%s/sess-%s-%d.params' % (save_dir, model_name,
                                                        best_e)
                net.load_parameters(reload_path, ctx=context)
        all_best_e.append(best_e)
        reload_path = '%s/sess-%s-%d.params' % (save_dir, model_name, best_e)
        net.load_parameters(reload_path, ctx=context)

        with open('%s/acc_dict.json' % save_dir, 'w') as json_file:
            json.dump(acc_dict, json_file)

        plot_pr(acc_dict, sess, save_dir)
    plot_all_sess(acc_dict, save_dir, all_best_e)
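makeSchedule above wraps MXNet's MultiFactorScheduler in a LinearWarmUp helper from this codebase and is queried once per epoch via schedule(epoch). Below is a rough, self-contained sketch of the same idea; the inline warm-up class is a hypothetical stand-in, since the real LinearWarmUp is not shown here, and the step/learning-rate values are placeholders.

# Rough sketch of the warm-up + step-decay schedule driven by makeSchedule above.
# LinearWarmUp here is a hypothetical stand-in for the helper used in this code base.
import mxnet as mx

class LinearWarmUp:
    def __init__(self, schedule, start_lr, length):
        self.schedule = schedule           # scheduler used after warm-up
        self.start_lr = start_lr           # learning rate at epoch 0
        self.length = length               # number of warm-up epochs
        self.finish_lr = schedule.base_lr  # target lr at the end of warm-up

    def __call__(self, epoch):
        if epoch < self.length:
            # linear ramp from start_lr to finish_lr
            return self.start_lr + (self.finish_lr - self.start_lr) * epoch / self.length
        return self.schedule(epoch - self.length)

base = mx.lr_scheduler.MultiFactorScheduler(step=[30, 60], factor=0.1)
base.base_lr = 0.1
schedule = LinearWarmUp(base, start_lr=0.0, length=5)
for epoch in range(0, 71, 10):
    print(epoch, schedule(epoch))  # 0.0 during warm-up, then 0.1 -> 0.01 -> 0.001 as steps pass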
Example #13
def main():
    opt = parse_args()
    batch_size = opt.batch_size
    classes = 10

    log_dir = os.path.join(opt.save_dir, "logs")
    model_dir = os.path.join(opt.save_dir, "params")
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Init dataloader
    jitter_param = 0.4
    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.RandomBrightness(jitter_param),
        transforms.RandomColorJitter(jitter_param),
        transforms.RandomContrast(jitter_param),
        transforms.RandomSaturation(jitter_param),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    train_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
        batch_size=batch_size,
        shuffle=True,
        last_batch='discard',
        num_workers=opt.num_workers)

    val_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
        batch_size=batch_size,
        shuffle=False,
        num_workers=opt.num_workers)

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes, 'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)

    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)
    optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    def test(ctx, val_loader):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_loader):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(train_data, val_data, epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]

        net.hybridize()
        net.initialize(mx.init.Xavier(), ctx=ctx)
        net.forward(mx.nd.ones((1, 3, 30, 30), ctx=ctx[0]))
        with SummaryWriter(logdir=log_dir, verbose=False) as sw:
            sw.add_graph(net)

        trainer = gluon.Trainer(net.collect_params(), optimizer, {
            'learning_rate': opt.lr,
            'wd': opt.wd,
            'momentum': opt.momentum
        })
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0
        global_step = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                lr_decay_count += 1

            tbar = tqdm(train_data)

            for i, batch in enumerate(tbar):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, acc = train_metric.get()
                iteration += 1
                global_step += len(loss)

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('{}/{}-{}-{:04.3f}-best.params'.format(
                    model_dir, model_name, epoch, best_val_score))

            with SummaryWriter(logdir=log_dir, verbose=False) as sw:
                sw.add_scalar(tag="TrainLos",
                              value=train_loss,
                              global_step=global_step)
                sw.add_scalar(tag="TrainAcc",
                              value=acc,
                              global_step=global_step)
                sw.add_scalar(tag="ValAcc",
                              value=val_acc,
                              global_step=global_step)
                sw.add_graph(net)

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('{}/{}-{}.params'.format(
                    save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('{}/{}-{}.params'.format(
                save_dir, model_name, epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize()
    train(train_data, val_data, opt.num_epochs, context)
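The training loop above pushes loss/accuracy curves and the network graph to TensorBoard through mxboard. A minimal stand-alone sketch of that logging pattern follows; the log directory, toy model, and scalar values are placeholders.

# Minimal mxboard logging sketch, mirroring the add_scalar / add_graph calls above.
import mxnet as mx
from mxboard import SummaryWriter

net = mx.gluon.nn.Dense(10)
net.initialize()
net.hybridize()
net(mx.nd.ones((1, 3)))  # run once so the symbolic graph exists for add_graph

with SummaryWriter(logdir='./logs', verbose=False) as sw:
    sw.add_graph(net)
    for step, (loss, acc) in enumerate([(2.1, 0.3), (1.4, 0.6), (0.9, 0.8)]):
        sw.add_scalar(tag='TrainLoss', value=loss, global_step=step)
        sw.add_scalar(tag='TrainAcc', value=acc, global_step=step)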
Example #14
def main():
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    opt = parse_args()
    batch_size = opt.batch_size
    classes = 10

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_sch = lr_scheduler.CosineScheduler((50000//batch_size)*opt.num_epochs,
                                          base_lr=opt.lr,
                                          warmup_steps=5*(50000//batch_size),
                                          final_lr=1e-5)
    # lr_sch = lr_scheduler.FactorScheduler((50000//batch_size)*20,
    #                                       factor=0.2, base_lr=opt.lr,
    #                                       warmup_steps=5*(50000//batch_size))
    # lr_sch = LRScheduler('cosine',opt.lr, niters=(50000//batch_size)*opt.num_epochs,)

    model_name = opt.model
    net = SKT_Lite()
    # if model_name.startswith('cifar_wideresnet'):
    #     kwargs = {'classes': classes,
    #             'drop_rate': opt.drop_rate}
    # else:
    #     kwargs = {'classes': classes}
    # net = get_model(model_name, **kwargs)
    if opt.mixup:
        model_name += '_mixup'
    if opt.amp:
        model_name += '_amp'

    makedirs('./'+model_name)
    os.chdir('./'+model_name)
    sw = SummaryWriter(
        logdir=os.path.join('tb', model_name), flush_secs=5, verbose=False)
    makedirs(opt.save_plot_dir)

    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)
    optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    plot_name = opt.save_plot_dir

    logging_handlers = [logging.StreamHandler()]
    if opt.logging_dir:
        logging_dir = opt.logging_dir
        makedirs(logging_dir)
        logging_handlers.append(logging.FileHandler(
            '%s/train_cifar10_%s.log' % (logging_dir, model_name)))

    logging.basicConfig(level=logging.INFO, handlers=logging_handlers)
    logging.info(opt)

    if opt.amp:
        amp.init()

    if opt.profile_mode:
        profiler.set_config(profile_all=True,
                            aggregate_stats=True,
                            continuous_dump=True,
                            filename='%s_profile.json' % model_name)

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        CutOut(8),
        # gcv_transforms.block.RandomErasing(s_max=0.25),
        transforms.RandomFlipLeftRight(),
        # transforms.RandomFlipTopBottom(),
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    def label_transform(label, classes):
        ind = label.astype('int')
        res = nd.zeros((ind.shape[0], classes), ctx=label.context)
        res[nd.arange(ind.shape[0], ctx=label.context), ind] = 1
        return res

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
        num_batch = len(val_data)
        test_loss = 0
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(
                batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(
                batch[1], ctx_list=ctx, batch_axis=0)
            outputs = [net(X) for X in data]
            loss = [loss_fn(yhat, y) for yhat, y in zip(outputs, label)]
            metric.update(label, outputs)
            test_loss += sum([l.sum().asscalar() for l in loss])
        test_loss /= batch_size * num_batch
        name, val_acc = metric.get()
        return name, val_acc, test_loss

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        root = os.path.join('..', 'datasets', 'cifar-10')
        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(
                root=root, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(
                root=root, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                {'learning_rate': opt.lr, 'wd': opt.wd,
                                 'momentum': opt.momentum, 'lr_scheduler': lr_sch})
        if opt.amp:
            amp.init_trainer(trainer)
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.RMSE()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(
            sparse_label=False if opt.mixup else True)
        train_history = TrainingHistory(['training-error', 'validation-error'])
        # acc_history = TrainingHistory(['training-acc', 'validation-acc'])
        loss_history = TrainingHistory(['training-loss', 'validation-loss'])

        iteration = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            for i, batch in enumerate(train_data):
                if epoch == 0 and iteration == 1 and opt.profile_mode:
                    profiler.set_state('run')
                lam = np.random.beta(alpha, alpha)
                if epoch >= epochs - 20 or not opt.mixup:
                    lam = 1

                data_1 = gluon.utils.split_and_load(
                    batch[0], ctx_list=ctx, batch_axis=0)
                label_1 = gluon.utils.split_and_load(
                    batch[1], ctx_list=ctx, batch_axis=0)

                if not opt.mixup:
                    data = data_1
                    label = label_1
                else:
                    data = [lam*X + (1-lam)*X[::-1] for X in data_1]
                    label = []
                    for Y in label_1:
                        y1 = label_transform(Y, classes)
                        y2 = label_transform(Y[::-1], classes)
                        label.append(lam*y1 + (1-lam)*y2)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                if opt.amp:
                    with ag.record():
                        with amp.scale_loss(loss, trainer) as scaled_loss:
                            ag.backward(scaled_loss)
                            # scaled_loss.backward()
                else:
                    for l in loss:
                        l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                output_softmax = [nd.SoftmaxActivation(out) for out in output]
                train_metric.update(label, output_softmax)
                metric.update(label_1, output_softmax)
                name, acc = train_metric.get()
                sw.add_scalar(tag='lr', value=trainer.learning_rate,
                              global_step=iteration)
                if epoch == 0 and iteration == 1 and opt.profile_mode:
                    nd.waitall()
                    profiler.set_state('stop')
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            _, train_acc = metric.get()
            name, val_acc, _ = test(ctx, val_data)
            if opt.mixup:
                train_history.update([acc, 1-val_acc])
                plt.cla()
                train_history.plot(save_path='%s/%s_history.png' %
                                   (plot_name, model_name))
            else:
                train_history.update([1-train_acc, 1-val_acc])
                plt.cla()
                train_history.plot(save_path='%s/%s_history.png' %
                                   (plot_name, model_name))
            # acc_history.update([train_acc, val_acc])
            # plt.cla()
            # acc_history.plot(save_path='%s/%s_acc.png' %
            #                  (plot_name, model_name), legend_loc='best')

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-cifar-%s-%d-best.params' %
                                    (save_dir, best_val_score, model_name, epoch))

            current_lr = trainer.learning_rate
            name, val_acc, val_loss = test(ctx, val_data)
            loss_history.update([train_loss, val_loss])
            plt.cla()
            loss_history.plot(save_path='%s/%s_loss.png' %
                              (plot_name, model_name), y_lim=(0, 2), legend_loc='best')
            logging.info('[Epoch %d] loss=%f train_acc=%f train_RMSE=%f\n     val_acc=%f val_loss=%f lr=%f time: %f' %
                         (epoch, train_loss, train_acc, acc, val_acc, val_loss, current_lr, time.time()-tic))
            sw._add_scalars(tag='Acc',
                            scalar_dict={'train_acc': train_acc, 'test_acc': val_acc}, global_step=epoch)
            sw._add_scalars(tag='Loss',
                            scalar_dict={'train_loss': train_loss, 'test_loss': val_loss}, global_step=epoch)
            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params' %
                                    (save_dir, model_name, epoch))
        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params' %
                                (save_dir, model_name, epochs-1))

    if opt.mode == 'hybrid':
        net.hybridize()
    train(opt.num_epochs, context)
    if opt.profile_mode:
        profiler.dump(finished=False)
    sw.close()
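When opt.mixup is set, the loop above blends each batch with its own reverse along the batch axis and mixes one-hot labels with the same Beta-sampled coefficient, then trains with sparse_label=False. A small self-contained sketch of that mixing step on a synthetic batch; shapes, class count, and alpha are placeholders.

# Self-contained mixup sketch matching the label_transform / lam mixing above.
import numpy as np
import mxnet as mx
from mxnet import nd

def label_transform(label, classes):
    # one-hot encoding, as in the example above
    ind = label.astype('int')
    res = nd.zeros((ind.shape[0], classes), ctx=label.context)
    res[nd.arange(ind.shape[0], ctx=label.context), ind] = 1
    return res

classes, alpha = 10, 1.0
data = nd.random.uniform(shape=(4, 3, 32, 32))
label = nd.array([1, 7, 3, 9])

lam = np.random.beta(alpha, alpha)
mixed_data = lam * data + (1 - lam) * data[::-1]          # blend batch with its reverse
mixed_label = lam * label_transform(label, classes) + \
              (1 - lam) * label_transform(label[::-1], classes)
# train with SoftmaxCrossEntropyLoss(sparse_label=False) on mixed_label, as in the example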
Example #15
        opt.dataset, model_name), time.localtime())
save_dir = save_dir + opt.save_name

makedirs(save_dir)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_save_dir = os.path.join(save_dir, 'log.txt')
fh = logging.FileHandler(log_save_dir)
fh.setLevel(logging.INFO)
logger.addHandler(fh)
logger.info(opt)

if opt.dataset == 'NC_MINI_IMAGENET':
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(84),
        gcv_transforms.RandomCrop(84, pad=8),
        transforms.RandomFlipLeftRight(),
        transforms.RandomColorJitter(brightness=0.4,
                                     contrast=0.4,
                                     saturation=0.4),
        transforms.RandomLighting(0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    anchor_trans = transforms.Compose([
Example #16
def main():
    opt = parse_args()

    hvd.init()
    
    batch_size = opt.batch_size
    classes = 100

    # num_gpus = opt.num_gpus
    # batch_size *= max(1, num_gpus)
    context = [mx.gpu(hvd.local_rank())]
    num_workers = hvd.size()
    rank = hvd.rank()

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes,
                  'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)
    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)
    # optimizer = 'nag'
    optimizer = opt.optimizer

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    # plot_path = opt.save_plot_dir

    logging.basicConfig(level=logging.INFO,
                    filename="train_cifar100_qsparselocalsgd_{}_{}_{}_{}.log".format(opt.model, opt.optimizer, opt.batch_size, opt.lr),
                    filemode='a')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)

    if rank == 0:
        logging.info(opt)

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum}

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            # outputs = [net(X) for X in data]
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        # if opt.print_tensor_shape and rank == 0:
        #     print(net)

        train_dataset = gluon.data.vision.CIFAR100(train=True).transform_first(transform_train)

        train_data = gluon.data.DataLoader(
            train_dataset,
            sampler=SplitSampler(len(train_dataset), num_parts=num_workers, part_index=rank),
            batch_size=batch_size, last_batch='discard', num_workers=opt.num_workers)

        # val_dataset = gluon.data.vision.CIFAR100(train=False).transform_first(transform_test)
        # val_data = gluon.data.DataLoader(
        #     val_dataset,
        #     sampler=SplitSampler(len(val_dataset), num_parts=num_workers, part_index=rank),
        #     batch_size=batch_size, num_workers=opt.num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR100(train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=opt.num_workers)

        hvd.broadcast_parameters(net.collect_params(), root_rank=0)

        trainer = QSparseLocalSGDTrainerV1(
            net.collect_params(),  
            'nag', optimizer_params, 
            input_sparse_ratio=1./opt.input_sparse, 
            output_sparse_ratio=1./opt.output_sparse, 
            layer_sparse_ratio=1./opt.layer_sparse,
            local_sgd_interval=opt.local_sgd_interval)

        # trainer = gluon.Trainer(net.collect_params(), optimizer,
                                # {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum})
        
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
        train_history = TrainingHistory(['training-error', 'validation-error'])

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        lr = opt.lr

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            if epoch == lr_decay_epoch[lr_decay_count]:
                lr *= lr_decay
                trainer.set_learning_rate(lr)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]

                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, acc = train_metric.get()
                iteration += 1

            mx.nd.waitall()
            toc = time.time()
            
            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            # name, val_acc = test(ctx, val_data)

            trainer.pre_test()
            name, val_acc = test(ctx, val_data)
            trainer.post_test()
            
            train_history.update([1-acc, 1-val_acc])
            # train_history.plot(save_path='%s/%s_history.png'%(plot_path, model_name))

            # allreduce the results
            allreduce_array_nd = mx.nd.array([train_loss, acc, val_acc])
            hvd.allreduce_(allreduce_array_nd, name='allreduce_array', average=True)
            allreduce_array_np = allreduce_array_nd.asnumpy()
            train_loss = np.asscalar(allreduce_array_np[0])
            acc = np.asscalar(allreduce_array_np[1])
            val_acc = np.asscalar(allreduce_array_np[2])

            if val_acc > best_val_score:
                best_val_score = val_acc
                # net.save_parameters('%s/%.4f-cifar-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))

            if rank == 0:
                logging.info('[Epoch %d] train=%f val=%f loss=%f comm=%.2f time: %f' %
                    (epoch, acc, val_acc, train_loss, trainer._comm_counter/1e6, toc-tic))

                if save_period and save_dir and (epoch + 1) % save_period == 0:
                    net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epoch))

            trainer._comm_counter = 0.

        if rank == 0:
            if save_period and save_dir:
                net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epochs-1))



    if opt.mode == 'hybrid':
        net.hybridize()
    train(opt.num_epochs, context)
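The Horovod run above shards CIFAR-100 across workers with a SplitSampler imported from elsewhere in this codebase. A rough equivalent is sketched below to show what the DataLoader call relies on; the real class may shuffle or pad the parts differently, so treat this as an approximation.

# Rough equivalent of the SplitSampler used above: each worker (part_index)
# iterates only over its own shard of the dataset indices.
import random
from mxnet import gluon

class SimpleSplitSampler(gluon.data.sampler.Sampler):
    def __init__(self, length, num_parts, part_index, shuffle=True):
        part_len = length // num_parts
        self._start = part_len * part_index   # first index owned by this worker
        self._end = self._start + part_len    # one past the last owned index
        self._shuffle = shuffle

    def __iter__(self):
        indices = list(range(self._start, self._end))
        if self._shuffle:
            random.shuffle(indices)
        return iter(indices)

    def __len__(self):
        return self._end - self._start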
Example #17
def main():
    opt = parse_args()

    batch_size = opt.batch_size
    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    for ctx in context:
        mx.random.seed(seed_state=opt.random_seed, ctx=ctx)
    np.random.seed(opt.random_seed)
    random.seed(opt.random_seed)
    num_workers = opt.num_workers
    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]
    save_period = opt.save_period
    log_dir = opt.log_dir
    save_dir = opt.save_dir
    makedirs(save_dir)
    makedirs(log_dir)

    if opt.dataset == 'cifar10':
        classes = 10
        transform_train = transforms.Compose([
            gcv_transforms.RandomCrop(32, pad=4),
            transforms.RandomFlipLeftRight(),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
        ])
        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)
    elif opt.dataset == 'cifar100':
        classes = 100
        transform_train = transforms.Compose([
            gcv_transforms.RandomCrop(32, pad=4),
            transforms.RandomFlipLeftRight(),
            transforms.ToTensor(),
            transforms.Normalize([0.5071, 0.4865, 0.4409], [0.2673, 0.2564, 0.2762])
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5071, 0.4865, 0.4409], [0.2673, 0.2564, 0.2762])
        ])
        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR100(train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR100(train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)


    model = models.__dict__[opt.arch](dataset=opt.dataset, depth=opt.depth)
    model_name = opt.arch + '_' + str(opt.depth)
    if opt.resume:
        model.load_parameters(opt.resume, ctx=context)

    sw = SummaryWriter(logdir=log_dir, flush_secs=5, verbose=False)


    logging.basicConfig(level=logging.INFO)
    logging.info(opt)

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            outputs = [model(X.astype(opt.dtype, copy=False)) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        model.initialize(mx.init.Xavier(), ctx=ctx)
        trainer = gluon.Trainer(model.collect_params(), 'sgd',
                                {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum})
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            btic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)

                with ag.record():
                    output = [model(X.astype(opt.dtype, copy=False)) for X in data]
                    loss = [loss_fn(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])
                sw.add_scalar(tag='train_loss', value=train_loss / len(loss), global_step=iteration)

                train_metric.update(label, output)
                name, acc = train_metric.get()
                sw.add_scalar(tag='train_{}_curves'.format(name),
                              value=('train_{}_value'.format(name), acc),
                              global_step=iteration)

                if opt.log_interval and not (i + 1) % opt.log_interval:
                    logging.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f' % (
                        epoch, i, batch_size * opt.log_interval / (time.time() - btic),
                        name, acc, trainer.learning_rate))
                    btic = time.time()

                iteration += 1
            if epoch == 0:
                sw.add_graph(model)
            train_loss /= batch_size * num_batch
            _, acc = train_metric.get()
            _, val_acc = test(ctx, val_data)
            sw.add_scalar(tag='val_acc_curves', value=('valid_acc_value', val_acc), global_step=epoch)

            if val_acc > best_val_score:
                best_val_score = val_acc
                model.save_parameters('%s/%.4f-%s-%s-%d-best.params' % (save_dir, best_val_score, opt.dataset, model_name, epoch))
                trainer.save_states('%s/%.4f-%s-%s-%d-best.states' % (save_dir, best_val_score, opt.dataset, model_name, epoch))

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                model.save_parameters('%s/%s-%s-%d.params' % (save_dir, opt.dataset, model_name, epoch))
                trainer.save_states('%s/%s-%s-%d.states' % (save_dir, opt.dataset, model_name, epoch))

        if save_period and save_dir:
            model.save_parameters('%s/%s-%s-%d.params' % (save_dir, opt.dataset, model_name, epochs - 1))
            trainer.save_states('%s/%s-%s-%d.states' % (save_dir, opt.dataset, model_name, epochs - 1))

    if opt.mode == 'hybrid':
        model.hybridize()
    train(opt.num_epochs, context)
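Several of these examples share the same step-decay idiom: parse --lr-decay-epoch into a list terminated by np.inf and bump lr_decay_count whenever the current epoch matches an entry. A tiny isolated sketch of that logic; the epoch string and base learning rate are placeholder values.

# Isolated sketch of the lr_decay_epoch / lr_decay_count idiom used above.
import numpy as np

lr, lr_decay = 0.1, 0.1
lr_decay_epoch = [int(i) for i in '30,60'.split(',')] + [np.inf]  # trailing inf: never decay again
lr_decay_count = 0

for epoch in range(80):
    if epoch == lr_decay_epoch[lr_decay_count]:
        lr *= lr_decay
        lr_decay_count += 1
# lr is 0.1 for epochs 0-29, 0.01 for 30-59, and 0.001 from epoch 60 onward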
Example #18
def main():
    opt = parse_args()

    batch_size = opt.batch_size
    if opt.dataset == 'cifar10':
        classes = 10
    elif opt.dataset == 'cifar100':
        classes = 100
    else:
        raise ValueError('Unknown Dataset')

    if len(mx.test_utils.list_gpus()) == 0:
        context = [mx.cpu()]
    else:
        context = [mx.gpu(int(i)) for i in opt.gpus.split(',') if i.strip()]
        context = context if context else [mx.cpu()]
    print("context: ", context)
    num_gpus = len(context)
    batch_size *= max(1, num_gpus)
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]

    # layers = [18, 18, 18]

    if opt.model == 'resnet':
        layers = [opt.blocks] * 3
        channels = [x * opt.channel_times for x in [16, 16, 32, 64]]
        start_layer = opt.start_layer

        net = ResNet110V2ASKC(in_askc_type=opt.askc_type,
                              layers=layers,
                              channels=channels,
                              classes=classes,
                              start_layer=start_layer)
        model_name = 'resnet_' + opt.askc_type + '_blocks_' + str(opt.blocks) + '_channel_times_' + \
                     str(opt.channel_times) + '_start_layer_' + str(start_layer) + '_mixup'
        print("opt.askc_type: ", opt.askc_type)
        print("layers: ", layers)
        print("channels: ", channels)
        print("classes: ", classes)
        print("start_layer: ", start_layer)
    elif opt.model == 'resnext29_32x4d':
        num_layers = 29
        layer = (num_layers - 2) // 9
        layers = [layer] * 3
        cardinality = 32
        bottleneck_width = 4
        net = CIFARResNextASKC(layers,
                               cardinality,
                               bottleneck_width,
                               classes,
                               use_se=False)
        model_name = 'resnext29_32x4d_askc_mixup'
    elif opt.model == 'resnext38_32x4d':
        num_layers = 38
        layer = (num_layers - 2) // 9
        layers = [layer] * 3
        cardinality = 32
        bottleneck_width = 4
        net = CIFARResNextASKC(layers,
                               cardinality,
                               bottleneck_width,
                               classes,
                               use_se=False)
        model_name = 'resnext38_32x4d_iaff_mixup'
    elif opt.model == 'resnext47_32x4d':
        num_layers = 47
        layer = (num_layers - 2) // 9
        layers = [layer] * 3
        cardinality = 32
        bottleneck_width = 4
        net = CIFARResNextASKC(layers,
                               cardinality,
                               bottleneck_width,
                               classes,
                               use_se=False)
        model_name = 'resnext47_32x4d_aff_mixup'
    elif opt.model == 'se_resnext38_32x4d':
        num_layers = 38
        layer = (num_layers - 2) // 9
        layers = [layer] * 3
        cardinality = 32
        bottleneck_width = 4
        net = CIFARResNextASKC(layers,
                               cardinality,
                               bottleneck_width,
                               classes,
                               use_se=True)
        model_name = 'se_resnext38_32x4d_askc_mixup'
    else:
        raise ValueError('Unknown Model')

    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)
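    # 'nag' = SGD with Nesterov accelerated momentum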
    optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    plot_name = opt.save_plot_dir

    logging_handlers = [logging.StreamHandler()]
    if opt.logging_dir:
        logging_dir = opt.logging_dir
        makedirs(logging_dir)
        logging_handlers.append(
            logging.FileHandler('%s/train_%s_%s.log' %
                                (logging_dir, opt.dataset, model_name)))

    logging.basicConfig(level=logging.INFO, handlers=logging_handlers)
    logging.info(opt)

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

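    # one-hot encode integer class labels (needed for the soft mixup targets below)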
    def label_transform(label, classes):
        ind = label.astype('int')
        res = nd.zeros((ind.shape[0], classes), ctx=label.context)
        res[nd.arange(ind.shape[0], ctx=label.context), ind] = 1
        return res

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

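        # with --summary, print the network summary on a dummy 32x32 input and exit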
        if opt.summary:
            summary(net, mx.nd.zeros((1, 3, 32, 32), ctx=ctx[0]))
            sys.exit()

        if opt.dataset == 'cifar10':
            # CIFAR10
            train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
                train=True).transform_first(transform_train),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               last_batch='discard',
                                               num_workers=num_workers)
            val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
                train=False).transform_first(transform_test),
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers)
        elif opt.dataset == 'cifar100':
            # CIFAR100
            train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
                train=True).transform_first(transform_train),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               last_batch='discard',
                                               num_workers=num_workers)
            val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
                train=False).transform_first(transform_test),
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers)
        else:
            raise ValueError('Unknown Dataset')

        trainer = gluon.Trainer(net.collect_params(), optimizer, {
            'learning_rate': opt.lr,
            'wd': opt.wd,
            'momentum': opt.momentum
        })
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.RMSE()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)
        train_history = TrainingHistory(['training-error', 'validation-error'])

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
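            # alpha parameterizes the mixup Beta(alpha, alpha) distribution;
            # alpha = 1 draws the mixing coefficient lam uniformly from [0, 1]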
            alpha = 1

            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
                lam = np.random.beta(alpha, alpha)
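                # disable mixup (lam = 1) for the final 20 epochs so training ends on clean labels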
                if epoch >= epochs - 20:
                    lam = 1

                data_1 = gluon.utils.split_and_load(batch[0],
                                                    ctx_list=ctx,
                                                    batch_axis=0)
                label_1 = gluon.utils.split_and_load(batch[1],
                                                     ctx_list=ctx,
                                                     batch_axis=0)

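                # mixup: blend each sample with the batch in reverse order using lam,
                # and blend the corresponding one-hot labels with the same weights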
                data = [lam * X + (1 - lam) * X[::-1] for X in data_1]
                label = []
                for Y in label_1:
                    y1 = label_transform(Y, classes)
                    y2 = label_transform(Y[::-1], classes)
                    label.append(lam * y1 + (1 - lam) * y2)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

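                # track RMSE between the softmax outputs and the soft (mixed) labels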
                output_softmax = [nd.SoftmaxActivation(out) for out in output]
                train_metric.update(label, output_softmax)
                name, acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            train_history.update([acc, 1 - val_acc])
            train_history.plot(save_path='%s/%s_history.png' %
                               (plot_name, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters(
                    '%s/%.4f-%s-%s-best.params' %
                    (save_dir, best_val_score, opt.dataset, model_name))

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))

        host_name = socket.gethostname()
        with open(
                opt.dataset + '_' + host_name + '_GPU_' + opt.gpus +
                '_best_Acc.log', 'a') as f:
            f.write('best Acc: {:.4f}\n'.format(best_val_score))

    if not opt.summary:
        if opt.mode == 'hybrid':
            net.hybridize()
    train(opt.num_epochs, context)
Example #19
plot_name = opt.save_plot_dir

logging_handlers = [logging.StreamHandler()]
if opt.logging_dir:
    logging_dir = opt.logging_dir
    makedirs(logging_dir)
    logging_handlers.append(
        logging.FileHandler('%s/train_cifar10_%s.log' %
                            (logging_dir, model_name)))

logging.basicConfig(level=logging.INFO, handlers=logging_handlers)
logging.info(opt)

transform_train = transforms.Compose([
    gcv_transforms.RandomCrop(32, pad=4),
    transforms.RandomFlipLeftRight(),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])


def label_transform(label, classes):
    ind = label.astype('int')
    res = nd.zeros((ind.shape[0], classes), ctx=label.context)
    res[nd.arange(ind.shape[0], ctx=label.context), ind] = 1
    return res
Example #20
import numpy as np

import mxnet as mx
from mxnet import nd, autograd, gluon, context
from mxnet.gluon.data.vision import transforms, CIFAR10
from gluoncv.data import transforms as gcv_transforms

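# NOTE: `cfg` and `LOG` come from elsewhere in the original project and are not shown
# in this snippet; the stubs below are only an illustrative assumption so the code can
# run standalone (the mean/std values mirror the CIFAR-10 statistics used in the other
# examples on this page).
import logging

class cfg:
    IMAGE_SIZE = 32
    CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
    CIFAR10_STD = [0.2023, 0.1994, 0.2010]

INFO = logging.INFO

def LOG(level, msg):
    logging.log(level, msg)
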
# ## 2. Prepare Data

# ### 2.1 Define Data Transformers

# In[3]:

train_transformer = transforms.Compose([
    gcv_transforms.RandomCrop(cfg.IMAGE_SIZE, pad=4),
    transforms.RandomFlipLeftRight(),
    transforms.ToTensor(),
    transforms.Normalize(cfg.CIFAR10_MEAN, cfg.CIFAR10_STD)
])

test_transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cfg.CIFAR10_MEAN, cfg.CIFAR10_STD)
])

LOG(INFO, 'Data Transformers defining done')

# ### 2.2 Load Dataset

# In[4]:
Example #21
    def __init__(self,
                 batch_size,
                 good_sample_ratio,
                 sub_sampling,
                 dataset,
                 mixedmodel=False):

        self.transform_train = transforms.Compose([
            gcv_transforms.RandomCrop(32, pad=4),
            transforms.RandomFlipLeftRight(),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465],
                                 [0.2023, 0.1994, 0.2010])
        ])

        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465],
                                 [0.2023, 0.1994, 0.2010])
        ])

        self.dataset_name = dataset
        self.dataset = {}
        self.label_size = 10
        if dataset == 'mnist':
            dataset_handle = MNIST
        elif dataset == 'fashionmnist':
            dataset_handle = FashionMNIST
        elif dataset == 'cifar10':
            dataset_handle = CIFAR10
        else:
            raise NotImplementedError

        self.good_sample_ratio = good_sample_ratio
        self.batch_size = batch_size

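        # save_data_process: path to a pickled cache of the processed dataset
        # (defined outside this snippet); rebuild the cache only when it is missing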
        if not os.path.isfile(save_data_process):
            print('using handle')
            train = dataset_handle(train=True, )
            test = dataset_handle(train=False, )
            print('finished building train and test')

            size_train = train._label.shape[0]
            print('size of dataset:', size_train)

            if sub_sampling < 1.0:
                size_ptrain = int(size_train * sub_sampling)
                subsample_idx = np.random.choice(size_train,
                                                 size_ptrain,
                                                 replace=False)
                train._data = train._data[subsample_idx]
                train._label = train._label[subsample_idx]
            print('finish subsampling')

            self.dataset['train'] = copy.deepcopy(train)
            self.dataset['test'] = copy.deepcopy(test)
            self.dataset['test-bad'] = copy.deepcopy(test)

            print('finish deepcopy')
            if self.dataset_name != 'cifar10':
                self.dataset['train']._data = mx.nd.transpose(
                    train._data, axes=(0, 3, 1, 2)).astype('float32') / 255.0
                self.dataset['test']._data = mx.nd.transpose(
                    test._data, axes=(0, 3, 1, 2)).astype('float32') / 255.0

            self.num_train = self.dataset['train']._label.shape[0]
            self.num_test = self.dataset['test']._label.shape[0]

            print('start flipping labels')
            print('making bad training set')
            cnt_label = {}
            for idx, label in enumerate(self.dataset['train']._label):
                cnt_label[label] = cnt_label.get(label, 0) + 1
            cnt_good_label_tgt = {}

            for k, v in cnt_label.items():
                cnt_good_label_tgt[k] = int(v * self.good_sample_ratio)

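            # corrupt the labels of every sample beyond the per-class "good" quota:
            # flip to a random different class, or (mixedmodel) shift cyclically to the next class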
            manipulate_label = {}
            good_idx_set = []
            for idx, label in enumerate(self.dataset['train']._label):
                manipulate_label[label] = manipulate_label.get(label, 0) + 1
                if manipulate_label[label] > cnt_good_label_tgt[label]:
                    if not mixedmodel:
                        p = np.random.randint(0, self.label_size)
                        while True:
                            if p != label:
                                self.dataset['train']._label[idx] = p
                                break
                            p = np.random.randint(0, self.label_size)
                    else:
                        p = label + 1 if label < self.label_size - 1 else 0
                        self.dataset['train']._label[idx] = p
                else:
                    good_idx_set.append(idx)
            self.good_idx_set = good_idx_set

            print('making bad validation set')
            cnt_label_val = {}
            for idx, label in enumerate(self.dataset['test-bad']._label):
                cnt_label_val[label] = cnt_label_val.get(label, 0) + 1
            cnt_good_label_tgt_val = {}

            for k, v in cnt_label_val.items():
                cnt_good_label_tgt_val[k] = int(v * self.good_sample_ratio)

            manipulate_label_val = {}
            good_idx_set_val = []
            for idx, label in enumerate(self.dataset['test-bad']._label):
                manipulate_label_val[label] = manipulate_label_val.get(
                    label, 0) + 1
                if manipulate_label_val[label] > cnt_good_label_tgt_val[label]:
                    if not mixedmodel:
                        p = np.random.randint(0, self.label_size)
                        while True:
                            if p != label:
                                self.dataset['test-bad']._label[idx] = p
                                break
                            p = np.random.randint(0, self.label_size)
                    else:
                        p = label + 1 if label < self.label_size - 1 else 0
                        self.dataset['test-bad']._label[idx] = p
                else:
                    good_idx_set_val.append(idx)
            self.good_idx_set_val = good_idx_set_val

            print('finish flipping labels')
            self.good_idx_array = np.array(self.good_idx_set)
            self.all_idx_array = np.arange(len(self.dataset['train']._label))
            self.bad_idx_array = np.setdiff1d(self.all_idx_array,
                                              self.good_idx_array)
            self.dataset['train']._data = mx.nd.concat(
                self.dataset['train']._data[self.good_idx_array],
                self.dataset['train']._data[self.bad_idx_array],
                dim=0)
            self.dataset['train']._label = np.concatenate(
                (self.dataset['train']._label[self.good_idx_array],
                 self.dataset['train']._label[self.bad_idx_array]),
                axis=0)
            self.good_idx_array = np.arange(len(self.good_idx_array))
            self.bad_idx_array = np.setdiff1d(self.all_idx_array,
                                              self.good_idx_array)

            save = {}
            save['train_data'] = self.dataset['train']
            save['test_data'] = self.dataset['test']
            save['test_data_bad'] = self.dataset['test-bad']
            save['good_idx_array'] = self.good_idx_array
            save['bad_idx_array'] = self.bad_idx_array
            save['all_idx_array'] = self.all_idx_array
            with open(save_data_process, "wb") as f:
                pickle.dump(save, f)
        else:
            with open(save_data_process, "rb") as f:
                save = pickle.load(f)
            self.dataset['train'] = save['train_data']
            self.dataset['test'] = save['test_data']
            self.good_idx_array = save['good_idx_array']
            self.bad_idx_array = save['bad_idx_array']
            self.all_idx_array = save['all_idx_array']
            self.dataset['test-bad'] = save['test_data_bad']
        if self.dataset_name == 'cifar10':
            self.val_data = gluon.data.DataLoader(
                self.dataset['test'].transform_first(self.transform_test),
                batch_size=batch_size,
                shuffle=False)
        else:
            self.val_data = gluon.data.DataLoader(self.dataset['test'],
                                                  batch_size=batch_size,
                                                  shuffle=False)
        if self.dataset_name == 'cifar10':
            self.val_data_bad = gluon.data.DataLoader(
                self.dataset['test-bad'].transform_first(self.transform_test),
                batch_size=batch_size,
                shuffle=False)
        else:
            self.val_data_bad = gluon.data.DataLoader(self.dataset['test-bad'],
                                                      batch_size=batch_size,
                                                      shuffle=False)
Example #22
def main():
    opt = parse_args()

    batch_size = opt.batch_size
    if opt.dataset == 'cifar10':
        classes = 10
    elif opt.dataset == 'cifar100':
        classes = 100
    else:
        raise ValueError('Unknown Dataset')

    if len(mx.test_utils.list_gpus()) == 0:
        context = [mx.cpu()]
    else:
        context = [mx.gpu(int(i)) for i in opt.gpus.split(',') if i.strip()]
        context = context if context else [mx.cpu()]
    print("context: ", context)
    num_gpus = len(context)
    batch_size *= max(1, num_gpus)
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]

    model_name = 'ResNet20_b_' + str(opt.blocks) + '_' + opt.act_type
    print("model_name", model_name)

    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes, 'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}

    # scenario = 'ATAC'

    # main config
    layers = [opt.blocks] * 3
    channels = [x * 1 for x in [16, 16, 32, 64]]
    act_type = opt.act_type  # relu, prelu, elu, selu, gelu, swish, xUnit, ChaATAC
    r = opt.r

    # spatial scope
    skernel = 3
    dilation = 1
    act_dilation = 1  # (8, 16), 4

    # ablation study
    useReLU = opt.useReLU
    useGlobal = opt.useGlobal
    asBackbone = False
    act_layers = opt.act_layers
    replace_act = 'relu'
    act_order = 'bac'  # 'pre', 'bac'

    print("model: ", opt.model)
    print("r: ", opt.r)
    if opt.model == 'atac':
        net = ResNet20V2ATAC(layers=layers,
                             channels=channels,
                             classes=classes,
                             act_type=act_type,
                             r=r,
                             skernel=skernel,
                             dilation=dilation,
                             useReLU=useReLU,
                             useGlobal=useGlobal,
                             act_layers=act_layers,
                             replace_act=replace_act,
                             act_order=act_order,
                             asBackbone=asBackbone)

        print("layers: ", layers)
        print("channels: ", channels)
        print("act_type: ", act_type)

        print("skernel: ", skernel)
        print("dilation: ", dilation)
        print("act_dilation: ", act_dilation)

        print("useReLU: ", useReLU)
        print("useGlobal: ", useGlobal)
        print("asBackbone: ", asBackbone)
        print("act_layers: ", act_layers)
        print("replace_act: ", replace_act)
        print("act_order: ", act_order)
    else:
        raise ValueError('Unknown Model')

    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)
    optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    plot_path = opt.save_plot_dir

    logging.basicConfig(level=logging.INFO)
    logging.info(opt)

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.summary:
            net.summary(mx.nd.zeros((1, 3, 32, 32), ctx=ctx[0]))

        if opt.dataset == 'cifar10':
            # CIFAR10
            train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
                train=True).transform_first(transform_train),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               last_batch='discard',
                                               num_workers=num_workers)
            val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
                train=False).transform_first(transform_test),
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers)
        elif opt.dataset == 'cifar100':
            # CIFAR100
            train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
                train=True).transform_first(transform_train),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               last_batch='discard',
                                               num_workers=num_workers)
            val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
                train=False).transform_first(transform_test),
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers)
        else:
            raise ValueError('Unknown Dataset')

        if optimizer == 'nag':
            trainer = gluon.Trainer(net.collect_params(), optimizer, {
                'learning_rate': opt.lr,
                'wd': opt.wd,
                'momentum': opt.momentum
            })
        elif optimizer == 'adagrad':
            trainer = gluon.Trainer(net.collect_params(), optimizer, {
                'learning_rate': opt.lr,
                'wd': opt.wd
            })
        elif optimizer == 'adam':
            trainer = gluon.Trainer(net.collect_params(), optimizer, {
                'learning_rate': opt.lr,
                'wd': opt.wd
            })
        else:
            raise ValueError('Unknown optimizer')

        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
        train_history = TrainingHistory(['training-error', 'validation-error'])
        host_name = socket.gethostname()

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)
                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            train_history.update([1 - acc, 1 - val_acc])
            train_history.plot(save_path='%s/%s_history.png' %
                               (plot_path, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                # net.save_parameters('%s/%.4f-cifar-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
                pass

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                # net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epoch))
                pass

            if epoch == epochs - 1:
                with open(
                        opt.dataset + '_' + host_name + '_GPU_' + opt.gpus +
                        '_best_Acc.log', 'a') as f:
                    f.write('best Acc: {:.4f}\n'.format(best_val_score))

        print("best_val_score: ", best_val_score)
        if save_period and save_dir:
            # net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epochs-1))
            pass

    if opt.mode == 'hybrid':
        net.hybridize()
    train(opt.num_epochs, context)
def main():
    opt = parse_args()

    batch_size = opt.batch_size
    classes = 10

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes, 'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)
    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)
    optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    plot_path = opt.save_plot_dir

    logging.basicConfig(level=logging.INFO)
    logging.info(opt)

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

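        # only a small subset of CIFAR-10 is used here (256 training / 64 test images)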
        train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
            train=True).take(256).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)

        val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
            train=False).take(64).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        trainer = gluon.Trainer(net.collect_params(), optimizer, {
            'learning_rate': opt.lr,
            'wd': opt.wd,
            'momentum': opt.momentum
        })
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters(
                    '%s/%.4f-cifar-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params' %
                                    (save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params' %
                                (save_dir, model_name, epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize()
    train(opt.num_epochs, context)
def train_cifar10(args, config, reporter):
    vars(args).update(config)
    np.random.seed(args.seed)
    random.seed(args.seed)
    mx.random.seed(args.seed)

    # Set Hyper-params
    batch_size = args.batch_size * max(args.num_gpus, 1)
    ctx = [mx.gpu(i)
           for i in range(args.num_gpus)] if args.num_gpus > 0 else [mx.cpu()]

    # Define DataLoader
    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    train_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
        batch_size=batch_size,
        shuffle=True,
        last_batch="discard",
        num_workers=args.num_workers)

    test_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
        batch_size=batch_size,
        shuffle=False,
        num_workers=args.num_workers)

    # Load model architecture and Initialize the net with pretrained model
    finetune_net = get_model(args.model, pretrained=True)
    with finetune_net.name_scope():
        finetune_net.fc = nn.Dense(args.classes)
    finetune_net.fc.initialize(init.Xavier(), ctx=ctx)
    finetune_net.collect_params().reset_ctx(ctx)
    finetune_net.hybridize()

    # Define trainer
    trainer = gluon.Trainer(finetune_net.collect_params(), "sgd", {
        "learning_rate": args.lr,
        "momentum": args.momentum,
        "wd": args.wd
    })
    L = gluon.loss.SoftmaxCrossEntropyLoss()
    metric = mx.metric.Accuracy()

    def train(epoch):
        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0,
                                              even_split=False)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0,
                                               even_split=False)
            with ag.record():
                outputs = [finetune_net(X) for X in data]
                loss = [L(yhat, y) for yhat, y in zip(outputs, label)]
            for l in loss:
                l.backward()

            trainer.step(batch_size)
        mx.nd.waitall()

    def test():
        test_loss = 0
        for i, batch in enumerate(test_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0,
                                              even_split=False)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0,
                                               even_split=False)
            outputs = [finetune_net(X) for X in data]
            loss = [L(yhat, y) for yhat, y in zip(outputs, label)]

            test_loss += sum(l.mean().asscalar() for l in loss) / len(loss)
            metric.update(label, outputs)

        _, test_acc = metric.get()
        test_loss /= len(test_data)
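        # report the averaged test loss/accuracy back to the hyperparameter-search
        # scheduler via the `reporter` callback (e.g. from AutoGluon)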
        reporter(mean_loss=test_loss, mean_accuracy=test_acc)

    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test()