def __init__(self, *args):
        lr, lr_mode, num_epochs, num_batches, lr_decay_epoch, \
        lr_decay, lr_decay_period, warmup_epochs, warmup_lr = args
        self._num_batches = num_batches
        self._lr_decay = lr_decay
        self._lr_decay_period = lr_decay_period
        self._warmup_epochs = warmup_epochs
        self._warmup_lr = warmup_lr
        self._num_epochs = num_epochs
        self._lr = lr
        self._lr_mode = lr_mode
        if lr_decay_period > 0:
            lr_decay_epoch = list(
                range(lr_decay_period, num_epochs, lr_decay_period))
        else:
            lr_decay_epoch = [int(i) for i in lr_decay_epoch.split(',')]
        self._lr_decay_epoch = [e - warmup_epochs for e in lr_decay_epoch]

        self._lr_scheduler = LRSequential([
            LRScheduler('linear',
                        base_lr=self._warmup_lr,
                        target_lr=lr,
                        nepochs=warmup_epochs,
                        iters_per_epoch=self._num_batches),
            LRScheduler(lr_mode,
                        base_lr=lr,
                        target_lr=0,
                        nepochs=num_epochs - warmup_epochs,
                        iters_per_epoch=self._num_batches,
                        step_epoch=self._lr_decay_epoch,
                        step_factor=lr_decay,
                        power=2)
        ])
Example 2
 def setup_trainer(self, params, train_dataset=None):
     """ Set up optimizer needed for training.
         Network must first be initialized before this.
     """
     optimizer_opts = {'learning_rate': params['learning_rate'],
                       'wd': params['weight_decay'],
                       'clip_gradient': params['clip_gradient']}
     if 'lr_scheduler' in params and params['lr_scheduler'] is not None:
         if train_dataset is None:
             raise ValueError("train_dataset cannot be None when lr_scheduler is specified.")
         base_lr = params.get('base_lr', 1e-6)
         target_lr = params.get('target_lr', 1.0)
         warmup_epochs = params.get('warmup_epochs', 10)
         lr_decay = params.get('lr_decay', 0.1)
         lr_mode = params['lr_scheduler']
         num_batches = train_dataset.num_examples // params['batch_size']
         lr_decay_epoch = [max(warmup_epochs, int(params['num_epochs']/3)), max(warmup_epochs+1, int(params['num_epochs']/2)),
                           max(warmup_epochs+2, int(2*params['num_epochs']/3))]
         lr_scheduler = LRSequential([
             LRScheduler('linear', base_lr=base_lr, target_lr=target_lr, nepochs=warmup_epochs, iters_per_epoch=num_batches),
             LRScheduler(lr_mode, base_lr=target_lr, target_lr=base_lr, nepochs=params['num_epochs'] - warmup_epochs,
                         iters_per_epoch=num_batches, step_epoch=lr_decay_epoch, step_factor=lr_decay, power=2)
         ])
         optimizer_opts['lr_scheduler'] = lr_scheduler
     if params['optimizer'] == 'sgd':
         if 'momentum' in params:
             optimizer_opts['momentum'] = params['momentum']
         optimizer = gluon.Trainer(self.model.collect_params(), 'sgd', optimizer_opts)
     elif params['optimizer'] == 'adam':  # TODO: Can we try AdamW?
         optimizer = gluon.Trainer(self.model.collect_params(), 'adam', optimizer_opts)
     else:
         raise ValueError("Unknown optimizer specified: %s" % params['optimizer'])
     return optimizer
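
# A hypothetical call site for setup_trainer above (not part of the original
# snippet). Every key in this params dict is read somewhere in the function;
# `task` and `train_dataset` stand in for an object holding the initialized
# network and a dataset exposing `num_examples`.
example_params = {
    'learning_rate': 0.01, 'weight_decay': 1e-4, 'clip_gradient': 5.0,
    'optimizer': 'sgd', 'momentum': 0.9,
    'batch_size': 64, 'num_epochs': 30,
    'lr_scheduler': 'cosine',   # enables the warmup + decay LRSequential branch
    'base_lr': 1e-6, 'target_lr': 1.0, 'warmup_epochs': 10, 'lr_decay': 0.1,
}
# trainer = task.setup_trainer(example_params, train_dataset=train_dataset)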
Example 3
def getOptimizer(net, size, args):
    # warmup_epochs is not defined in this snippet; assume it is supplied via args
    warmup_epochs = args.warmup_epochs
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=args.lr,
                    nepochs=warmup_epochs,
                    iters_per_epoch=size),
        LRScheduler(mode='poly',
                    base_lr=args.lr,
                    nepochs=args.epochs - warmup_epochs,
                    iters_per_epoch=size,
                    power=0.9)
    ])
    kvstore = 'device'
    kv = mx.kv.create(kvstore)

    weight_decay = 1e-4
    momentum = 0.9
    optimizer_params = {
        'lr_scheduler': lr_scheduler,
        'wd': weight_decay,
        'momentum': momentum,
        'learning_rate': args.lr
    }
    # optimizer_name is not defined in this snippet; SGD matches the
    # momentum/weight-decay parameters assembled above
    optimizer_name = 'sgd'
    optimizer = gluon.Trainer(net.module.collect_params(),
                              optimizer_name,
                              optimizer_params,
                              kvstore=kv)
    return optimizer
Example 4
def get_lr_scheduler(args):
    if args.lr_decay_period > 0:
        lr_decay_epoch = list(range(args.lr_decay_period, args.epochs, args.lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')]
    num_batches = args.num_samples // args.batch_size
    if args.meta_arch == 'yolo3':
        return LRSequential([
            LRScheduler('linear', base_lr=0, target_lr=args.lr,
                        nepochs=args.warmup_epochs, iters_per_epoch=num_batches),
            LRScheduler(args.lr_mode, base_lr=args.lr, nepochs=args.epochs - args.warmup_epochs,
                        iters_per_epoch=num_batches, step_epoch=lr_decay_epoch,
                        step_factor=args.lr_decay, power=2)])
    elif args.meta_arch == 'faster_rcnn':
        return LRSequential([
            LRScheduler('linear', base_lr=args.lr * args.warmup_factor, target_lr=args.lr,
                        niters=args.warmup_iters, iters_per_epoch=num_batches),
            LRScheduler(args.lr_mode, base_lr=args.lr, nepochs=args.epochs,
                        iters_per_epoch=num_batches, step_epoch=lr_decay_epoch,
                        step_factor=args.lr_decay, power=2)])
Example 5
def test_composed_method():
    N = 1000
    constant = LRScheduler('constant', base_lr=0, target_lr=1, niters=N)
    linear = LRScheduler('linear', base_lr=1, target_lr=2, niters=N)
    cosine = LRScheduler('cosine', base_lr=3, target_lr=1, niters=N)
    poly = LRScheduler('poly', base_lr=1, target_lr=0, niters=N, power=2)
    # components with niters=0 will be ignored
    null_cosine = LRScheduler('cosine', base_lr=3, target_lr=1, niters=0)
    null_poly = LRScheduler('poly', base_lr=3, target_lr=1, niters=0, power=2)
    step = LRScheduler('step',
                       base_lr=1,
                       target_lr=0,
                       niters=N,
                       step_iter=[100, 500],
                       step_factor=0.1)
    arr = LRSequential(
        [constant, null_cosine, linear, cosine, null_poly, poly, step])
    # constant
    for i in range(N):
        compare(arr, i, 0)
    # linear
    for i in range(N, 2 * N):
        expect_linear = 2 + (1 - 2) * (1 - (i - N) / (N - 1))
        compare(arr, i, expect_linear)
    # cosine
    for i in range(2 * N, 3 * N):
        expect_cosine = 1 + (3 - 1) * ((1 + cos(pi * (i - 2 * N) /
                                                (N - 1))) / 2)
        compare(arr, i, expect_cosine)
    # poly
    for i in range(3 * N, 4 * N):
        expect_poly = 0 + (1 - 0) * (pow(1 - (i - 3 * N) / (N - 1), 2))
        compare(arr, i, expect_poly)
    for i in range(4 * N, 5 * N):
        if i - 4 * N < 100:
            expect_step = 1
        elif i - 4 * N < 500:
            expect_step = 0.1
        else:
            expect_step = 0.01
        compare(arr, i, expect_step)
    # out-of-bound index
    compare(arr, 10 * N, 0.01)
    compare(arr, -1, 0)
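
# The test above calls a `compare` helper that is not shown in this snippet.
# A minimal sketch, assuming the schedulers follow MXNet's callable
# lr-scheduler protocol (schedule(num_update) returns the learning rate at
# that iteration); in a real test module it would be defined before
# test_composed_method, and the test body also needs `cos` and `pi` from math.
from math import cos, pi, isclose  # cos/pi are used inside test_composed_method

def compare(schedule, iteration, expected, tol=1e-6):
    # Probe the composed schedule at the given iteration and check the lr.
    lr = schedule(iteration)
    assert isclose(lr, expected, rel_tol=tol, abs_tol=tol), \
        'iteration %d: got lr=%f, expected %f' % (iteration, lr, expected)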
Example 6
def _get_lr_scheduler(opt):
    lr_factor = opt.lr_factor
    lr_steps = [int(i) for i in opt.lr_steps.split(',')]
    lr_steps = [e - opt.warmup_epochs for e in lr_steps]
    # batch_size is not defined in this snippet; assume it comes from opt
    num_batches = get_num_examples(opt.dataset) // opt.batch_size

    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=opt.lr,
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode,
                    base_lr=opt.lr,
                    target_lr=0,
                    nepochs=opt.epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_steps,
                    step_factor=lr_factor,
                    power=2)
    ])
    return lr_scheduler
Example 7
def main():
    opt = parse_args()

    filehandler = logging.FileHandler(opt.logging_file)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = num_training_samples // batch_size

    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=opt.lr,
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode,
                    base_lr=opt.lr,
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])

    optimizer = 'nag'
    optimizer_params = {
        'wd': opt.wd,
        'momentum': opt.momentum,
        'lr_scheduler': lr_scheduler
    }
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    model_name = opt.model
    if 'lite' in model_name:
        net, input_size = get_efficientnet_lite(model_name,
                                                num_classes=classes)
    else:
        net, input_size = get_efficientnet(model_name, num_classes=classes)
    assert input_size == opt.input_size
    net.cast(opt.dtype)
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context)

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx,
                     batch_size, num_workers):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec=rec_train,
            path_imgidx=rec_train_idx,
            preprocess_threads=num_workers,
            shuffle=True,
            batch_size=batch_size,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
            rand_mirror=True,
            random_resized_crop=True,
            max_aspect_ratio=4. / 3.,
            min_aspect_ratio=3. / 4.,
            max_random_area=1,
            min_random_area=0.08,
            brightness=jitter_param,
            saturation=jitter_param,
            contrast=jitter_param,
            pca_noise=lighting_param,
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec=rec_val,
            path_imgidx=rec_val_idx,
            preprocess_threads=num_workers,
            shuffle=False,
            batch_size=batch_size,
            resize=resize,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
        )
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param,
                                         contrast=jitter_param,
                                         saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(), normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(), normalize
        ])

        train_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=True).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)
        val_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=False).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        return train_data, val_data, batch_fn

    if opt.use_rec:
        train_data, val_data, batch_fn = get_data_rec(opt.rec_train,
                                                      opt.rec_train_idx,
                                                      opt.rec_val,
                                                      opt.rec_val_idx,
                                                      batch_size, num_workers)
    else:
        train_data, val_data, batch_fn = get_data_loader(
            opt.data_dir, batch_size, num_workers)

    train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def test(ctx, val_data):
        if opt.use_rec:
            val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return (1 - top1, 1 - top5)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                optimizer_params)
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=True)

        best_val_score = 1

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()
            btic = time.time()

            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)
                with ag.record():
                    outputs = [
                        net(X.astype(opt.dtype, copy=False)) for X in data
                    ]
                    loss = [
                        L(yhat, y.astype(opt.dtype, copy=False))
                        for yhat, y in zip(outputs, label)
                    ]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)

                train_metric.update(label, outputs)

                if opt.log_interval and not (i + 1) % opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'
                        % (epoch, i, batch_size * opt.log_interval /
                           (time.time() - btic), train_metric_name,
                           train_metric_score, trainer.learning_rate))
                    btic = time.time()

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i / (time.time() - tic))

            err_top1_val, err_top5_val = test(ctx, val_data)

            logger.info('[Epoch %d] training: %s=%f' %
                        (epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' %
                        (epoch, throughput, time.time() - tic))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f' %
                        (epoch, err_top1_val, err_top5_val))

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                net.save_parameters(
                    '%s/%.4f-imagenet-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))
                trainer.save_states(
                    '%s/%.4f-imagenet-%s-%d-best.states' %
                    (save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch +
                                                1) % save_frequency == 0:
                net.save_parameters('%s/imagenet-%s-%d.params' %
                                    (save_dir, model_name, epoch))
                trainer.save_states('%s/imagenet-%s-%d.states' %
                                    (save_dir, model_name, epoch))

        if save_frequency and save_dir:
            net.save_parameters('%s/imagenet-%s-%d.params' %
                                (save_dir, model_name, opt.num_epochs - 1))
            trainer.save_states('%s/imagenet-%s-%d.states' %
                                (save_dir, model_name, opt.num_epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)

    train(context)
Example 8
def main():
    opt = parse_args()

    batch_size = opt.batch_size
    if opt.dataset == 'cifar10':
        classes = 10
    elif opt.dataset == 'cifar100':
        classes = 100
    else:
        raise ValueError('Unknown Dataset')

    if len(mx.test_utils.list_gpus()) == 0:
        context = [mx.cpu()]
    else:
        context = [mx.gpu(int(i)) for i in opt.gpus.split(',') if i.strip()]
        context = context if context else [mx.cpu()]
    print("context: ", context)
    num_gpus = len(context)
    batch_size *= max(1, num_gpus)
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        # lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = 50000 // batch_size

    lr_scheduler = LRSequential([
        LRScheduler('linear', base_lr=0, target_lr=opt.lr,
                    nepochs=opt.warmup_epochs, iters_per_epoch=num_batches),
        LRScheduler('cosine', base_lr=opt.lr, target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay, power=2)
    ])

    optimizer = 'nag'
    if opt.cosine:
        optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum, 'lr_scheduler': lr_scheduler}
    else:
        optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum}

    layers = [opt.blocks] * 3
    channels = [x * opt.channel_times for x in [16, 16, 32, 64]]
    start_layer = opt.start_layer
    askc_type = opt.askc_type

    if channels[0] == 64:
        cardinality = 32
    elif channels[0] == 128:
        cardinality = 64
    bottleneck_width = 4

    print("model: ", opt.model)
    print("askc_type: ", opt.askc_type)
    print("layers: ", layers)
    print("channels: ", channels)
    print("start_layer: ", start_layer)
    print("classes: ", classes)
    print("deep_stem: ", opt.deep_stem)

    model_prefix = opt.dataset + '-' + askc_type
    model_suffix = '-c-' + str(opt.channel_times) + '-s-' + str(opt.start_layer)
    if opt.model == 'resnet':
        net = CIFARAFFResNet(askc_type=askc_type, start_layer=start_layer, layers=layers,
                             channels=channels, classes=classes, deep_stem=opt.deep_stem)
        model_name = model_prefix + '-resnet-' + str(sum(layers) * 2 + 2) + model_suffix
    elif opt.model == 'resnext':
        net = CIFARAFFResNeXt(askc_type=askc_type, start_layer=start_layer, layers=layers,
                              channels=channels, cardinality=cardinality,
                              bottleneck_width=bottleneck_width, classes=classes,
                              deep_stem=opt.deep_stem, use_se=False)
        model_name = model_prefix + '-resneXt-' + str(sum(layers) * 3 + 2) + '-' + \
                     str(cardinality) + 'x' + str(bottleneck_width) + model_suffix
    else:
        raise ValueError('Unknown opt.model')

    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context, ignore_extra=True)

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    plot_name = opt.save_plot_dir

    logging_handlers = [logging.StreamHandler()]
    if opt.logging_dir:
        logging_dir = opt.logging_dir
        makedirs(logging_dir)
        logging_handlers.append(logging.FileHandler('%s/train_%s.log' %
                                                    (logging_dir, model_name)))

    logging.basicConfig(level=logging.INFO, handlers=logging_handlers)
    logging.info(opt)

    transform_train = []
    if opt.auto_aug:
        print('Using AutoAugment')
        from autogluon.utils.augment import AugmentationBlock, autoaug_cifar10_policies
        transform_train.append(AugmentationBlock(autoaug_cifar10_policies()))

    transform_train.extend([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])
    transform_train = transforms.Compose(transform_train)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    def label_transform(label, classes):
        ind = label.astype('int')
        res = nd.zeros((ind.shape[0], classes), ctx=label.context)
        res[nd.arange(ind.shape[0], ctx=label.context), ind] = 1
        return res

    def mixup_transform(label, classes, lam=1, eta=0.0):
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes, on_value=1 - eta + eta/classes, off_value=eta/classes)
            y2 = l[::-1].one_hot(classes, on_value=1 - eta + eta/classes, off_value=eta/classes)
            res.append(lam*y1 + (1-lam)*y2)
        return res

    def smooth(label, classes, eta=0.1):
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes, on_value=1 - eta + eta/classes, off_value = eta/classes)
            smoothed.append(res)
        return smoothed

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        if opt.summary:
            summary(net, mx.nd.zeros((1, 3, 32, 32), ctx=ctx[0]))
            sys.exit()

        if opt.dataset == 'cifar10':
            train_data = gluon.data.DataLoader(
                gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
                batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
            val_data = gluon.data.DataLoader(
                gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
                batch_size=batch_size, shuffle=False, num_workers=num_workers)
        elif opt.dataset == 'cifar100':
            train_data = gluon.data.DataLoader(
                gluon.data.vision.CIFAR100(train=True).transform_first(transform_train),
                batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
            val_data = gluon.data.DataLoader(
                gluon.data.vision.CIFAR100(train=False).transform_first(transform_test),
                batch_size=batch_size, shuffle=False, num_workers=num_workers)
        else:
            raise ValueError('Unknown Dataset')

        if opt.no_wd and opt.cosine:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True

        metric = mx.metric.Accuracy()
        train_metric = mx.metric.RMSE()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)
        train_history = TrainingHistory(['training-error', 'validation-error'])

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)

            if not opt.cosine:
                if epoch == lr_decay_epoch[lr_decay_count]:
                    trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                    lr_decay_count += 1

            for i, batch in enumerate(train_data):
                data_1 = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label_1 = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
                # Default to the raw batch: the mixup branch overrides data, and the
                # mixup/label-smoothing branches override label; without this, data and
                # label are undefined when neither option is enabled.
                data, label = data_1, label_1

                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if (epoch >= epochs - opt.mixup_off_epoch) or not opt.mixup:
                        lam = 1

                    data = [lam * X + (1 - lam) * X[::-1] for X in data_1]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label_1, classes, lam, eta)

                elif opt.label_smoothing:
                    hard_label = label_1
                    label = smooth(label_1, classes)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                if opt.mixup:
                    output_softmax = [nd.SoftmaxActivation(out) for out in output]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, output)
                    else:
                        train_metric.update(label, output)

                name, acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            train_history.update([acc, 1 - val_acc])
            train_history.plot(save_path='%s/%s_history.png' % (plot_name, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-%s-best.params' %
                                    (save_dir, best_val_score, model_name))

            logging.info('[Epoch %d] train=%f val=%f loss=%f lr: %f time: %f' %
                         (epoch, acc, val_acc, train_loss, trainer.learning_rate,
                          time.time() - tic))

        host_name = socket.gethostname()
        with open(opt.dataset + '_' + host_name + '_GPU_' + opt.gpus + '_best_Acc.log', 'a') as f:
            f.write('best Acc: {:.4f}\n'.format(best_val_score))
        print("best_val_score: ", best_val_score)

    if not opt.summary:
        if opt.mode == 'hybrid':
            net.hybridize()
    train(opt.num_epochs, context)
Example 9
def main():
    def batch_func(batch, ctx):
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        return data, label

    opt = parse_args()

    filehandler = logging.FileHandler(opt.logging_file)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)
    if opt.data_backend[0:5] == 'dali-':
        opt = parse_args(logger)   # Adding DALI parameters

    if opt.data_backend == 'dali-gpu' and opt.num_gpus == 0:
        stream = os.popen('nvidia-smi -L | wc -l')
        opt.num_gpus = int(stream.read())
        logger.info("When '--data-backend' is equal to 'dali-gpu', then `--num-gpus` should NOT be 0\n" \
                    "For now '--num-gpus' will be set to the number of GPUs installed: %d" % opt.num_gpus)

    logger.info(opt)

    classes = 1000
    num_gpus = opt.num_gpus
    batch_size = opt.batch_size * max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    opt.gpus = [i for i in range(num_gpus)]

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = opt.num_training_samples // batch_size

    lr_scheduler = LRSequential([
        LRScheduler('linear', base_lr=0, target_lr=opt.lr,
                    nepochs=opt.warmup_epochs, iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode, base_lr=opt.lr, target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay, power=2)
    ])

    model_name = opt.model

    kwargs = {'ctx': context, 'pretrained': opt.use_pretrained, 'classes': classes}
    if opt.use_gn:
        kwargs['norm_layer'] = gcv.nn.GroupNorm
    if model_name.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm
    elif model_name.startswith('resnext'):
        kwargs['use_se'] = opt.use_se

    if opt.last_gamma:
        kwargs['last_gamma'] = True

    optimizer = 'nag'
    optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum, 'lr_scheduler': lr_scheduler}
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    net = get_model(model_name, **kwargs)
    net.cast(opt.dtype)
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx = context)

    # teacher model for distillation training
    distillation = opt.teacher is not None and opt.hard_weight < 1.0
    if distillation:
        teacher = get_model(opt.teacher, pretrained=True, classes=classes, ctx=context)
        teacher.cast(opt.dtype)

    # Two functions for reading data from record file or raw images
    def get_data_rec(args):
        rec_train = os.path.expanduser(args.rec_train)
        rec_train_idx = os.path.expanduser(args.rec_train_idx)
        rec_val = os.path.expanduser(args.rec_val)
        rec_val_idx = os.path.expanduser(args.rec_val_idx)
        num_gpus = args.num_gpus
        batch_size = args.batch_size * max(1, num_gpus)

        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        train_data = mx.io.ImageRecordIter(
            path_imgrec         = rec_train,
            path_imgidx         = rec_train_idx,
            preprocess_threads  = args.num_workers,
            shuffle             = True,
            batch_size          = batch_size,
            data_shape          = (3, input_size, input_size),
            mean_r              = args.rgb_mean[0],
            mean_g              = args.rgb_mean[1],
            mean_b              = args.rgb_mean[2],
            std_r               = args.rgb_std[0],
            std_g               = args.rgb_std[1],
            std_b               = args.rgb_std[2],
            rand_mirror         = True,
            random_resized_crop = True,
            max_aspect_ratio    = 4. / 3.,
            min_aspect_ratio    = 3. / 4.,
            max_random_area     = 1,
            min_random_area     = 0.08,
            brightness          = jitter_param,
            saturation          = jitter_param,
            contrast            = jitter_param,
            pca_noise           = lighting_param,
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec         = rec_val,
            path_imgidx         = rec_val_idx,
            preprocess_threads  = args.num_workers,
            shuffle             = False,
            batch_size          = batch_size,

            resize              = resize,
            data_shape          = (3, input_size, input_size),
            mean_r              = args.rgb_mean[0],
            mean_g              = args.rgb_mean[1],
            mean_b              = args.rgb_mean[2],
            std_r               = args.rgb_std[0],
            std_g               = args.rgb_std[1],
            std_b               = args.rgb_std[2],
        )
        return train_data, val_data, batch_func

    def get_data_rec_transfomed(args):
        data_dir = args.data_dir
        num_workers = args.num_workers
        batch_size = args.batch_size * max(1, args.num_gpus)
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
                                        saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(),
            normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize
        ])

        train_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
        val_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        return train_data, val_data, batch_fn

    def get_data_loader(args):
        if args.data_backend == 'mxnet':
            return get_data_rec if args.use_rec else get_data_rec_transfomed

        # Check if DALI is installed:
        if args.data_backend[0:5] == 'dali-':
            import dali
        if args.data_backend == 'dali-gpu':
            return (lambda *args, **kwargs: dali.get_rec_iter(*args, **kwargs, batch_fn=batch_func, dali_cpu=False))
        if args.data_backend == 'dali-cpu':
            return (lambda *args, **kwargs: dali.get_rec_iter(*args, **kwargs, batch_fn=batch_func, dali_cpu=True))
        raise ValueError('Wrong data backend')


    train_data, val_data, batch_fn = get_data_loader(opt)(opt)
    train_metric = mx.metric.RMSE() if opt.mixup else mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def mixup_transform(label, classes, lam=1, eta=0.0):
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            y2 = l[::-1].one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            res.append(lam*y1 + (1-lam)*y2)
        return res

    def smooth(label, classes, eta=0.1):
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            smoothed.append(res)
        return smoothed

    def test(ctx, val_data, val_batch):
        if opt.use_rec:
            val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            for j in range(val_batch):
                data, label = batch_fn(batch[j], ctx) if type(batch) == list else batch_fn(batch, ctx)
                outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                acc_top1.update(label, outputs)
                acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return (1-top1, 1-top5)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        sparse_label_loss = not (opt.label_smoothing or opt.mixup)
        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(temperature=opt.temperature,
                                                                 hard_weight=opt.hard_weight,
                                                                 sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)

        best_val_score = 1
        eta = 0.1 if opt.label_smoothing else 0.0
        start_time = time.time()
        val_batch = len(ctx) if opt.data_backend != 'mxnet' else 1
        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()
            btic = time.time()

            for i, batch in enumerate(train_data):
                for j in range(val_batch):
                    data, label = batch_fn(batch[j], ctx) if type(batch) == list else batch_fn(batch, ctx)
                    if opt.mixup:
                        lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                        if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                            lam = 1
                        data = [lam*X + (1-lam)*X[::-1] for X in data]
                        label = mixup_transform(label, classes, lam, eta)

                    elif opt.label_smoothing:
                        hard_label = label
                        label = smooth(label, classes)

                    if distillation:
                        teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature) \
                                        for X in data]

                    with ag.record():
                        outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                        if distillation:
                            loss = [L(yhat.astype('float32', copy=False),
                                      y.astype('float32', copy=False),
                                      p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
                        else:
                            loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
                    for l in loss:
                        l.backward()
                    trainer.step(batch_size)

                    if opt.mixup:
                        output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) for out in outputs]
                        train_metric.update(label, output_softmax)
                    else:
                        if opt.label_smoothing:
                            train_metric.update(hard_label, outputs)
                        else:
                            train_metric.update(label, outputs)

                if opt.log_interval and not (i+1)%opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'%(
                                epoch, i+1, batch_size*opt.log_interval*val_batch/(time.time()-btic),
                                train_metric_name, train_metric_score, trainer.learning_rate))
                    btic = time.time()

            train_metric_name, train_metric_score = train_metric.get()
            if opt.log_interval and i % opt.log_interval:
                # We did NOT report the speed on the last iteration of the loop. Let's do it now
                logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'%(
                            epoch, i, batch_size*(i%opt.log_interval)*val_batch/(time.time()-btic),
                            train_metric_name, train_metric_score, trainer.learning_rate))

            epoch_time = time.time() - tic
            throughput = int(batch_size * i * val_batch / epoch_time)
            err_top1_val, err_top5_val = test(ctx, val_data, val_batch)

            logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, epoch_time))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f'%(epoch, err_top1_val, err_top5_val))

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
                trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
                net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
                trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))

        if save_frequency and save_dir:
            net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))
            trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1))

        logger.info('Training time for %d epochs: %f sec.'%(opt.num_epochs, time.time() - start_time))

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)
        if distillation:
            teacher.hybridize(static_alloc=True, static_shape=True)
    train(context)
Example 10
print('Load %d training samples.' % len(train_dataset))
train_data = gluon.data.DataLoader(train_dataset, batch_size=batch_size,
                                   shuffle=True, num_workers=num_workers)

################################################################
# Optimizer, Loss and Metric
# --------------------------

lr_decay = 0.1
warmup_epoch = 34
total_epoch = 196
num_batches = len(train_data)
lr_scheduler = LRSequential([
    LRScheduler('linear', base_lr=0.01, target_lr=0.1,
                nepochs=warmup_epoch, iters_per_epoch=num_batches),
    LRScheduler('cosine', base_lr=0.1, target_lr=0,
                nepochs=total_epoch - warmup_epoch,
                iters_per_epoch=num_batches,
                step_factor=lr_decay, power=2)
])

# Stochastic gradient descent
optimizer = 'sgd'
# Set parameters
optimizer_params = {'learning_rate': 0.01, 'wd': 0.0001, 'momentum': 0.9}
optimizer_params['lr_scheduler'] = lr_scheduler

# Define our trainer for net
trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)

################################################################
# In order to optimize our model, we need a loss function.
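
# A minimal sketch of the usual next step in this kind of Gluon training
# script (the tutorial's own choice may differ): softmax cross-entropy as the
# classification loss, plus an accuracy metric to track training progress.
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
train_metric = mx.metric.Accuracy()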
Example 11
def main():
    opt = parse_args()

    filehandler = logging.FileHandler(opt.logging_file)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    # epoch_start_cs controls when channel selection starts: before this epoch all
    # channels are used; from this epoch on, channel selection is applied.
    if opt.epoch_start_cs != -1:
        opt.use_all_channels = True

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = num_training_samples // batch_size // opt.reduced_dataset_scale

    lr_scheduler = LRSequential([
        LRScheduler('linear', base_lr=0, target_lr=opt.lr,
                    nepochs=opt.warmup_epochs, iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode, base_lr=opt.lr, target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay, power=2)
    ])

    model_name = opt.model

    kwargs = {'ctx': context, 'pretrained': opt.use_pretrained, 'classes': classes}
    if opt.use_gn:
        from gluoncv.nn import GroupNorm
        kwargs['norm_layer'] = GroupNorm
    if model_name.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm
    elif model_name.startswith('resnext'):
        kwargs['use_se'] = opt.use_se

    if opt.last_gamma:
        kwargs['last_gamma'] = True

    optimizer = 'nag'
    optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum, 'lr_scheduler': lr_scheduler}
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    if model_name == 'ShuffleNas_fixArch':
        architecture = parse_str_list(opt.block_choices)
        scale_ids = parse_str_list(opt.channel_choices)
        net = get_shufflenas_oneshot(architecture=architecture, n_class=classes, scale_ids=scale_ids, use_se=opt.use_se,
                                     last_conv_after_pooling=opt.last_conv_after_pooling,
                                     channels_layout=opt.channels_layout)
    elif model_name == 'ShuffleNas':
        net = get_shufflenas_oneshot(n_class=classes, use_all_blocks=opt.use_all_blocks, use_se=opt.use_se,
                                     last_conv_after_pooling=opt.last_conv_after_pooling,
                                     channels_layout=opt.channels_layout)
    else:
        net = get_model(model_name, **kwargs)

    net.cast(opt.dtype)
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context)

    # teacher model for distillation training
    if opt.teacher is not None and opt.hard_weight < 1.0:
        teacher_name = opt.teacher
        teacher = get_model(teacher_name, pretrained=True, classes=classes, ctx=context)
        teacher.cast(opt.dtype)
        distillation = True
    else:
        distillation = False

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, num_workers):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec         = rec_train,
            path_imgidx         = rec_train_idx,
            preprocess_threads  = num_workers,
            shuffle             = True,
            batch_size          = batch_size,

            data_shape          = (3, input_size, input_size),
            mean_r              = mean_rgb[0],
            mean_g              = mean_rgb[1],
            mean_b              = mean_rgb[2],
            std_r               = std_rgb[0],
            std_g               = std_rgb[1],
            std_b               = std_rgb[2],
            rand_mirror         = True,
            random_resized_crop = True,
            max_aspect_ratio    = 4. / 3.,
            min_aspect_ratio    = 3. / 4.,
            max_random_area     = 1,
            min_random_area     = 0.08,
            brightness          = jitter_param,
            saturation          = jitter_param,
            contrast            = jitter_param,
            pca_noise           = lighting_param,
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec         = rec_val,
            path_imgidx         = rec_val_idx,
            preprocess_threads  = num_workers,
            shuffle             = False,
            batch_size          = batch_size,

            resize              = resize,
            data_shape          = (3, input_size, input_size),
            mean_r              = mean_rgb[0],
            mean_g              = mean_rgb[1],
            mean_b              = mean_rgb[2],
            std_r               = std_rgb[0],
            std_g               = std_rgb[1],
            std_b               = std_rgb[2],
        )
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
                                        saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(),
            normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize
        ])

        train_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
        val_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        return train_data, val_data, batch_fn

    if opt.use_rec:
        train_data, val_data, batch_fn = get_data_rec(opt.rec_train, opt.rec_train_idx,
                                                    opt.rec_val, opt.rec_val_idx,
                                                    batch_size, num_workers)
    else:
        train_data, val_data, batch_fn = get_data_loader(opt.data_dir, batch_size, num_workers)

    if opt.mixup:
        train_metric = mx.metric.RMSE()
    else:
        train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def mixup_transform(label, classes, lam=1, eta=0.0):
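        # Build soft labels for mixup: each label is one-hot encoded (optionally
        # label-smoothed via eta) and blended with the one-hot labels of the
        # reversed batch using the same lam that mixes the images below.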
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            y2 = l[::-1].one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            res.append(lam*y1 + (1-lam)*y2)
        return res

    def smooth(label, classes, eta=0.1):
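        # Standard label smoothing: the true class gets probability
        # 1 - eta + eta/classes and every other class gets eta/classes; the loss
        # is therefore built with sparse_label=False whenever smoothing or mixup
        # is enabled.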
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            smoothed.append(res)
        return smoothed

    def make_divisible(x, divisible_by=8):
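        # Round x up to the nearest multiple of divisible_by (default 8),
        # e.g. make_divisible(30) -> 32 and make_divisible(24) -> 24, so the
        # sampled channel counts stay hardware-friendly.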
        return int(np.ceil(x * 1. / divisible_by) * divisible_by)

    def test(ctx, val_data, epoch):
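        # Evaluate top-1/top-5 accuracy on val_data and return the error rates
        # (1 - accuracy), so lower is better.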
        if opt.use_rec:
            val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            if model_name == 'ShuffleNas':
                # For evaluating validation accuracy, randomly select blocks and channels.
                block_choices = net.random_block_choices(select_predefined_block=False, dtype=opt.dtype)
                if opt.cs_warm_up:
                    # TODO: edit in the issue, readme and medium article that all channels need to be selected in
                    #       testing.
                    full_channel_mask, channel_choices = net.random_channel_mask(select_all_channels=True,
                                                                                 epoch_after_cs=epoch - opt.epoch_start_cs,
                                                                                 dtype=opt.dtype,
                                                                                 ignore_first_two_cs=opt.ignore_first_two_cs)

                else:
                    full_channel_mask, channel_choices = net.random_channel_mask(select_all_channels=True,
                                                                   dtype=opt.dtype,
                                                                   ignore_first_two_cs=opt.ignore_first_two_cs)

                # TODO: remove debug code
                stage_repeats = net.stage_repeats
                stage_out_channels = net.stage_out_channels
                candidate_scales = net.candidate_scales
                channel_scales = []
                for c in range(len(channel_choices)):
                    # scale_ids = [6, 5, 3, 5, 2, 6, 3, 4, 2, 5, 7, 5, 4, 6, 7, 4, 4, 5, 4, 3]
                    channel_scales.append(candidate_scales[channel_choices[c]])
                channels = [stage_out_channels[0]] * stage_repeats[0] + \
                           [stage_out_channels[1]] * stage_repeats[1] + \
                           [stage_out_channels[2]] * stage_repeats[2] + \
                           [stage_out_channels[3]] * stage_repeats[3]
                channels = [make_divisible(channel // 2 * channel_scales[j]) for (j, channel) in
                            enumerate(channels)]
                if i == 0:
                    print("Val batch channel choices: {}".format(channel_choices))
                    print("Val batch channels: {}".format(channels))
                # TODO: end of debug code

                outputs = [net(X.astype(opt.dtype, copy=False), block_choices, full_channel_mask) for X in data]
            else:
                outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return 1-top1, 1-top5

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            if 'ShuffleNas' in model_name:
                net._initialize(ctx=ctx)
            else:
                net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True
        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(temperature=opt.temperature,
                                                                 hard_weight=opt.hard_weight,
                                                                 sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)

        best_val_score = 1
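        # best_val_score tracks the lowest top-1 validation error seen so far,
        # hence the initial value of 1 (100% error).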

        def train_epoch():
            btic = time.time()
            for i, batch in enumerate(train_data):
                if i == num_batches:
                    return i
                data, label = batch_fn(batch, ctx)

                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                        lam = 1
                    data = [lam*X + (1-lam)*X[::-1] for X in data]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label, classes, lam, eta)

                elif opt.label_smoothing:
                    hard_label = label
                    label = smooth(label, classes)

                if distillation:
                    teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature) \
                                    for X in data]

                with ag.record():
                    if model_name == 'ShuffleNas':
                        block_choices = net.random_block_choices(select_predefined_block=False, dtype=opt.dtype)
                        if opt.cs_warm_up:
                            full_channel_mask, channel_choices = net.random_channel_mask(select_all_channels=opt.use_all_channels,
                                                                           epoch_after_cs=epoch - opt.epoch_start_cs,
                                                                           dtype=opt.dtype,
                                                                           ignore_first_two_cs=opt.ignore_first_two_cs)
                        else:
                            full_channel_mask, channel_choices = net.random_channel_mask(select_all_channels=opt.use_all_channels,
                                                                           dtype=opt.dtype,
                                                                           ignore_first_two_cs=opt.ignore_first_two_cs)
                        # TODO: remove debug code
                        stage_repeats = net.stage_repeats
                        stage_out_channels = net.stage_out_channels
                        candidate_scales = net.candidate_scales
                        channel_scales = []
                        for c in range(len(channel_choices)):
                            # scale_ids = [6, 5, 3, 5, 2, 6, 3, 4, 2, 5, 7, 5, 4, 6, 7, 4, 4, 5, 4, 3]
                            channel_scales.append(candidate_scales[channel_choices[c]])
                        channels = [stage_out_channels[0]] * stage_repeats[0] + \
                                   [stage_out_channels[1]] * stage_repeats[1] + \
                                   [stage_out_channels[2]] * stage_repeats[2] + \
                                   [stage_out_channels[3]] * stage_repeats[3]
                        channels = [make_divisible(channel / 2 * channel_scales[j]) for (j, channel)
                                    in enumerate(channels)]
                        if i < 5:
                            print("Train batch channel choices: {}".format(channel_choices))
                            print("Train batch channels: {}".format(channels))
                        # TODO: end of debug code

                        outputs = [net(X.astype(opt.dtype, copy=False), block_choices, full_channel_mask) for X in data]
                    else:
                        outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                    if distillation:
                        loss = [L(yhat.astype('float32', copy=False),
                                  y.astype('float32', copy=False),
                                  p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
                    else:
                        loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size, ignore_stale_grad=True)

                if opt.mixup:
                    output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                                    for out in outputs]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, outputs)
                    else:
                        train_metric.update(label, outputs)

                if opt.log_interval and not (i+1)%opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'%(
                                epoch, i, batch_size*opt.log_interval/(time.time()-btic),
                                train_metric_name, train_metric_score, trainer.learning_rate))
                    btic = time.time()
            return i

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            if epoch >= opt.epoch_start_cs:
                opt.use_all_channels = False
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()

            i = train_epoch()
            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i / (time.time() - tic))

            err_top1_val, err_top5_val = test(ctx, val_data, epoch)

            logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, time.time()-tic))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f'%(epoch, err_top1_val, err_top5_val))

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
                trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
                net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
                trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))

        if save_frequency and save_dir:
            net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))
            trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1))

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)
        if distillation:
            teacher.hybridize(static_alloc=True, static_shape=True)
    print(net)
    train(context)
Example 12
def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline"""
    net.collect_params().reset_ctx(ctx)
    if args.no_wd:
        for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
            v.wd_mult = 0.0

    if args.label_smooth:
        net._target_generator._label_smooth = True

    if args.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(args.lr_decay_period, args.epochs, args.lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - args.warmup_epochs for e in lr_decay_epoch]
    num_batches = args.num_samples // args.batch_size
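    # Two-phase schedule: linear warmup from 0 to args.lr over args.warmup_epochs,
    # then the args.lr_mode scheduler takes over for the remaining epochs,
    # applying args.lr_decay at every (warmup-shifted) epoch in lr_decay_epoch.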
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=args.lr,
                    nepochs=args.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(args.lr_mode,
                    base_lr=args.lr,
                    nepochs=args.epochs - args.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=args.lr_decay,
                    power=2),
    ])

    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(net.collect_params(), 'sgd', {
            'wd': args.wd,
            'momentum': args.momentum,
            'lr_scheduler': lr_scheduler
        })
    else:
        trainer = gluon.Trainer(
            net.collect_params(),
            'sgd', {
                'wd': args.wd,
                'momentum': args.momentum,
                'lr_scheduler': lr_scheduler
            },
            kvstore='local',
            update_on_kvstore=(False if args.amp else None))

    if args.amp:
        amp.init_trainer(trainer)

    # targets
    sigmoid_ce = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    l1_loss = gluon.loss.L1Loss()

    # metrics
    obj_metrics = mx.metric.Loss('ObjLoss')
    center_metrics = mx.metric.Loss('BoxCenterLoss')
    scale_metrics = mx.metric.Loss('BoxScaleLoss')
    cls_metrics = mx.metric.Loss('ClassLoss')

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        if args.mixup:
            # TODO(zhreshold): more elegant way to control mixup during runtime
            try:
                train_data._dataset.set_mixup(np.random.beta, 1.5, 1.5)
            except AttributeError:
                train_data._dataset._data.set_mixup(np.random.beta, 1.5, 1.5)
            if epoch >= args.epochs - args.no_mixup_epochs:
                try:
                    train_data._dataset.set_mixup(None)
                except AttributeError:
                    train_data._dataset._data.set_mixup(None)

        tic = time.time()
        btic = time.time()
        mx.nd.waitall()
        net.hybridize()
        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            # objectness, center_targets, scale_targets, weights, class_targets
            fixed_targets = [
                gluon.utils.split_and_load(batch[it],
                                           ctx_list=ctx,
                                           batch_axis=0) for it in range(1, 6)
            ]
            gt_boxes = gluon.utils.split_and_load(batch[6],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
            sum_losses = []
            obj_losses = []
            center_losses = []
            scale_losses = []
            cls_losses = []
            with autograd.record():
                for ix, x in enumerate(data):
                    obj_loss, center_loss, scale_loss, cls_loss = net(
                        x, gt_boxes[ix], *[ft[ix] for ft in fixed_targets])
                    sum_losses.append(obj_loss + center_loss + scale_loss +
                                      cls_loss)
                    obj_losses.append(obj_loss)
                    center_losses.append(center_loss)
                    scale_losses.append(scale_loss)
                    cls_losses.append(cls_loss)
                if args.amp:
                    with amp.scale_loss(sum_losses, trainer) as scaled_loss:
                        autograd.backward(scaled_loss)
                else:
                    autograd.backward(sum_losses)
            trainer.step(batch_size)
            if (not args.horovod or hvd.rank() == 0):
                obj_metrics.update(0, obj_losses)
                center_metrics.update(0, center_losses)
                scale_metrics.update(0, scale_losses)
                cls_metrics.update(0, cls_losses)
                if args.log_interval and not (i + 1) % args.log_interval:
                    name1, loss1 = obj_metrics.get()
                    name2, loss2 = center_metrics.get()
                    name3, loss3 = scale_metrics.get()
                    name4, loss4 = cls_metrics.get()
                    logger.info(
                        '[Epoch {}][Batch {}], LR: {:.2E}, Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'
                        .format(epoch, i, trainer.learning_rate,
                                args.batch_size / (time.time() - btic), name1,
                                loss1, name2, loss2, name3, loss3, name4,
                                loss4))
                btic = time.time()

        if (not args.horovod or hvd.rank() == 0):
            name1, loss1 = obj_metrics.get()
            name2, loss2 = center_metrics.get()
            name3, loss3 = scale_metrics.get()
            name4, loss4 = cls_metrics.get()
            logger.info(
                '[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'
                .format(epoch, (time.time() - tic), name1, loss1, name2, loss2,
                        name3, loss3, name4, loss4))
            if not (epoch + 1) % args.val_interval:
                # consider reducing the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
                val_msg = '\n'.join(
                    ['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(
                    epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            save_params(net, best_map, current_map, epoch, args.save_interval,
                        args.save_prefix)
Example 13
def main():
    opt = parse_args()

    hvd.init()

    logging_file = 'train_imagenet_%s.log' % (opt.trainer)

    filehandler = logging.FileHandler(logging_file)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    if hvd.rank() == 0:
        logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167

    context = [mx.gpu(hvd.local_rank())]
    num_workers = opt.num_workers

    optimizer = opt.optimizer

    warmup_epochs = opt.warmup_epochs

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - warmup_epochs for e in lr_decay_epoch]
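    # Each Horovod worker draws batch_size samples per step, so the global batch
    # is batch_size * hvd.size() and every worker performs correspondingly fewer
    # optimizer updates per epoch.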
    num_batches = num_training_samples // (batch_size * hvd.size())

    lr_scheduler = LRSequential([
        LRScheduler('linear', base_lr=opt.warmup_lr, target_lr=opt.lr,
                    nepochs=warmup_epochs, iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode, base_lr=opt.lr, target_lr=0,
                    nepochs=opt.num_epochs - warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay, power=2)
    ])


    model_name = opt.model

    kwargs = {'ctx': context, 'pretrained': opt.use_pretrained, 'classes': classes}
    if opt.use_gn:
        kwargs['norm_layer'] = gcv.nn.GroupNorm
    if model_name.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm
    elif model_name.startswith('resnext'):
        kwargs['use_se'] = opt.use_se

    if opt.last_gamma:
        kwargs['last_gamma'] = True

    optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum, 'lr_scheduler': lr_scheduler}
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    net = get_model(model_name, **kwargs)
    net.cast(opt.dtype)
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx = context)

    # teacher model for distillation training
    if opt.teacher is not None and opt.hard_weight < 1.0:
        teacher_name = opt.teacher
        teacher = get_model(teacher_name, pretrained=True, classes=classes, ctx=context)
        teacher.cast(opt.dtype)
        distillation = True
    else:
        distillation = False

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, num_workers):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec         = rec_train,
            path_imgidx         = rec_train_idx,
            preprocess_threads  = num_workers,
            shuffle             = True,
            batch_size          = batch_size,
            round_batch         = False,
            data_shape          = (3, input_size, input_size),
            mean_r              = mean_rgb[0],
            mean_g              = mean_rgb[1],
            mean_b              = mean_rgb[2],
            std_r               = std_rgb[0],
            std_g               = std_rgb[1],
            std_b               = std_rgb[2],
            rand_mirror         = True,
            random_resized_crop = True,
            max_aspect_ratio    = 4. / 3.,
            min_aspect_ratio    = 3. / 4.,
            max_random_area     = 1,
            min_random_area     = 0.08,
            brightness          = jitter_param,
            saturation          = jitter_param,
            contrast            = jitter_param,
            pca_noise           = lighting_param,
            num_parts           = hvd.size(),
            part_index          = hvd.rank(),
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec         = rec_val,
            path_imgidx         = rec_val_idx,
            preprocess_threads  = num_workers,
            shuffle             = False,
            batch_size          = batch_size,

            resize              = resize,
            data_shape          = (3, input_size, input_size),
            mean_r              = mean_rgb[0],
            mean_g              = mean_rgb[1],
            mean_b              = mean_rgb[2],
            std_r               = std_rgb[0],
            std_g               = std_rgb[1],
            std_b               = std_rgb[2],
        )
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
                                        saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(),
            normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize
        ])

        train_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
        val_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        return train_data, val_data, batch_fn

    if opt.use_rec:
        train_data, val_data, batch_fn = get_data_rec(opt.rec_train, opt.rec_train_idx,
                                                    opt.rec_val, opt.rec_val_idx,
                                                    batch_size, num_workers)
    else:
        train_data, val_data, batch_fn = get_data_loader(opt.data_dir, batch_size, num_workers)

    if opt.mixup:
        train_metric = mx.metric.RMSE()
    else:
        train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def mixup_transform(label, classes, lam=1, eta=0.0):
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            y2 = l[::-1].one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            res.append(lam*y1 + (1-lam)*y2)
        return res

    def smooth(label, classes, eta=0.1):
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            smoothed.append(res)
        return smoothed

    def test(ctx, val_data, val=True):
        if opt.use_rec:
            if val:
                val_data.reset()
            else:
                train_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return (1-top1, 1-top5)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        hvd.broadcast_parameters(net.collect_params(), root_rank=0)

        # trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)

        # trainer = hvd.DistributedTrainer(
        #     net.collect_params(),  
        #     optimizer,
        #     optimizer_params)

        if opt.trainer == 'sgd':
            trainer = SGDTrainer(
                net.collect_params(),  
                optimizer, optimizer_params)
        elif opt.trainer == 'efsgd':
            trainer = EFSGDTrainerV1(
                net.collect_params(),  
                'EFSGDV1', optimizer_params, 
                input_sparse_ratio=1./opt.input_sparse_1, 
                output_sparse_ratio=1./opt.output_sparse_1, 
                layer_sparse_ratio=1./opt.layer_sparse_1)
        elif opt.trainer == 'qsparselocalsgd':
            trainer = QSparseLocalSGDTrainerV1(
                net.collect_params(),  
                optimizer, optimizer_params, 
                input_sparse_ratio=1./opt.input_sparse_1, 
                output_sparse_ratio=1./opt.output_sparse_1, 
                layer_sparse_ratio=1./opt.layer_sparse_1,
                local_sgd_interval=opt.local_sgd_interval)
        elif opt.trainer == 'ersgd':
            trainer = ERSGDTrainerV2(
                net.collect_params(),  
                optimizer, optimizer_params, 
                input_sparse_ratio=1./opt.input_sparse_1, 
                output_sparse_ratio=1./opt.output_sparse_1, 
                layer_sparse_ratio=1./opt.layer_sparse_1)
        elif opt.trainer == 'partiallocalsgd':
            trainer = PartialLocalSGDTrainerV1(
                net.collect_params(),  
                optimizer, optimizer_params, 
                input_sparse_ratio=1./opt.input_sparse_1, 
                output_sparse_ratio=1./opt.output_sparse_1, 
                layer_sparse_ratio=1./opt.layer_sparse_1,
                local_sgd_interval=opt.local_sgd_interval)
        elif opt.trainer == 'ersgd2':
            trainer = ERSGD2TrainerV2(
                net.collect_params(),  
                optimizer, optimizer_params, 
                input_sparse_ratio_1=1./opt.input_sparse_1, 
                output_sparse_ratio_1=1./opt.output_sparse_1, 
                layer_sparse_ratio_1=1./opt.layer_sparse_1,
                input_sparse_ratio_2=1./opt.input_sparse_2, 
                output_sparse_ratio_2=1./opt.output_sparse_2, 
                layer_sparse_ratio_2=1./opt.layer_sparse_2,
                local_sgd_interval=opt.local_sgd_interval)
        else:
            trainer = SGDTrainer(
                net.collect_params(),  
                optimizer, optimizer_params)

        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True
        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(temperature=opt.temperature,
                                                                 hard_weight=opt.hard_weight,
                                                                 sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)

        best_val_score = 1

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            # train_metric.reset()
            train_loss = 0
            btic = time.time()

            # test speed
            if opt.test_speed > 0:
                n_repeats = opt.test_speed
            elif opt.test_speed == 0:
                n_repeats = 1
            else:
                n_repeats = 0

            for i, batch in enumerate(train_data):
                
                # test speed
                if n_repeats == 0 and not (i+1)%opt.log_interval:
                    print('[Epoch %d] # batch: %d'%(epoch, i))
                    continue

                data, label = batch_fn(batch, ctx)

                for j in range(n_repeats):

                    if opt.mixup:
                        lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                        if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                            lam = 1
                        data = [lam*X + (1-lam)*X[::-1] for X in data]

                        if opt.label_smoothing:
                            eta = 0.1
                        else:
                            eta = 0.0
                        label = mixup_transform(label, classes, lam, eta)

                    elif opt.label_smoothing:
                        hard_label = label
                        label = smooth(label, classes)

                    if distillation:
                        teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature) \
                                        for X in data]

                    with ag.record():
                        outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                        if distillation:
                            loss = [L(yhat.astype('float32', copy=False),
                                    y.astype('float32', copy=False),
                                    p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
                        else:
                            loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
                    for l in loss:
                        l.backward()
                    trainer.step(batch_size)

                    # if opt.mixup:
                    #     output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                    #                     for out in outputs]
                    #     train_metric.update(label, output_softmax)
                    # else:
                    #     if opt.label_smoothing:
                    #         train_metric.update(hard_label, outputs)
                    #     else:
                    #         train_metric.update(label, outputs)

                    step_loss = sum([l.sum().asscalar() for l in loss])

                    train_loss += step_loss

                    if opt.log_interval and not (i+j+1)%opt.log_interval:
                        # train_metric_name, train_metric_score = train_metric.get()
                        if hvd.rank() == 0:
                            # logger.info('Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f lr=%f comm=%f'%(
                            #             epoch, i, batch_size*hvd.size()*opt.log_interval/(time.time()-btic),
                            #             train_metric_name, train_metric_score, trainer.learning_rate, trainer._comm_counter/1e6))
                            # print('Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f lr=%f comm=%f'%(
                            #             epoch, i, batch_size*hvd.size()*opt.log_interval/(time.time()-btic),
                            #             train_metric_name, train_metric_score, trainer.learning_rate, trainer._comm_counter/1e6))
                            print('Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f lr=%f comm=%f'%(
                                        epoch, i, batch_size*hvd.size()*opt.log_interval/(time.time()-btic),
                                        'loss', step_loss/batch_size, trainer.learning_rate, trainer._comm_counter/1e6))
                        btic = time.time()

            mx.nd.waitall()
            toc = time.time()

            if n_repeats == 0:
                allreduce_array_nd = mx.nd.array([i])
                hvd.allreduce_(allreduce_array_nd, name='allreduce_array', average=True)
                mx.nd.waitall()
                print('[Epoch %d] # total batch: %d'%(epoch, i))
                continue

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i /(toc - tic) * hvd.size())

            train_loss /= (batch_size * i)

            if opt.trainer in ('ersgd', 'qsparselocalsgd', 'ersgd2', 'partiallocalsgd'):
                allreduce_for_val = True
            else:
                allreduce_for_val = False

            if allreduce_for_val:
                trainer.pre_test()
            # err_train_tic = time.time()
            # err_top1_train, err_top5_train = test(ctx, train_data, val=False)
            err_val_tic = time.time()
            err_top1_val, err_top5_val = test(ctx, val_data, val=True)
            err_val_toc = time.time()
            if allreduce_for_val:
                trainer.post_test()

            mx.nd.waitall()

            # allreduce the results
            allreduce_array_nd = mx.nd.array([train_loss, err_top1_val, err_top5_val])
            hvd.allreduce_(allreduce_array_nd, name='allreduce_array', average=True)
            allreduce_array_np = allreduce_array_nd.asnumpy()
            train_loss = np.asscalar(allreduce_array_np[0])
            err_top1_val = np.asscalar(allreduce_array_np[1])
            err_top5_val = np.asscalar(allreduce_array_np[2])

            if hvd.rank() == 0:
                # logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score))
                logger.info('[Epoch %d] training: loss=%f'%(epoch, train_loss))
                logger.info('[Epoch %d] speed: %d samples/sec training-time: %f comm: %f'%(epoch, throughput, toc-tic, trainer._comm_counter/1e6))
                logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f err-time=%f'%(epoch, err_top1_val, err_top5_val, err_val_toc - err_val_tic))
                trainer._comm_counter = 0

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                # if hvd.local_rank() == 0:
                #     net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
                #     trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
                if hvd.local_rank() == 0:
                    net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
                    trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))

        # if save_frequency and save_dir:
        #     if hvd.local_rank() == 0:
        #         net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))
        #         trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1))


    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)
        if distillation:
            teacher.hybridize(static_alloc=True, static_shape=True)
    train(context)
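
A schedule like the ones built above can also be constructed and inspected on its
own. The following is a minimal sketch (the epoch counts and the 'cosine' mode are
illustrative, and it assumes the GluonCV scheduler can be called with a global
update index, which is how MXNet trainers drive it):

from gluoncv.utils import LRScheduler, LRSequential

iters_per_epoch = 100  # hypothetical number of batches per epoch
schedule = LRSequential([
    # 5 warmup epochs: ramp the learning rate linearly from 0 to 0.1
    LRScheduler('linear', base_lr=0, target_lr=0.1,
                nepochs=5, iters_per_epoch=iters_per_epoch),
    # remaining 95 epochs: cosine decay from 0.1 back to 0
    LRScheduler('cosine', base_lr=0.1, target_lr=0,
                nepochs=95, iters_per_epoch=iters_per_epoch),
])
for it in (0, 5 * iters_per_epoch, 50 * iters_per_epoch, 100 * iters_per_epoch - 1):
    print(it, schedule(it))  # learning rate planned for global update `it`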
Example 14
    def __init__(self, args):
        self.args = args

        filehandler = logging.FileHandler(args.logging_file)
        streamhandler = logging.StreamHandler()

        self.logger = logging.getLogger('')
        self.logger.setLevel(logging.INFO)
        self.logger.addHandler(filehandler)
        self.logger.addHandler(streamhandler)

        self.logger.info(args)


        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),  # Default mean and std
        ])

        ################################# dataset and dataloader #################################
        if platform.system() == "Darwin":
            data_root = os.path.join('~', 'Nutstore Files', 'Dataset')  # Mac
        elif platform.system() == "Linux":
            data_root = os.path.join('~', 'datasets')  # Laplace or HPC
            if args.colab:
                data_root = '/content/datasets'  # Colab
        else:
            raise ValueError('Unknown platform; please set the dataset root path manually')

        data_kwargs = {'base_size': args.base_size, 'transform': input_transform,
                       'crop_size': args.crop_size, 'root': data_root,
                       'base_dir': args.dataset}
        trainset = IceContrast(split=args.train_split, mode='train',   **data_kwargs)
        valset = IceContrast(split=args.val_split, mode='testval', **data_kwargs)
        self.train_data = gluon.data.DataLoader(trainset, args.batch_size, shuffle=True,
                                                last_batch='rollover', num_workers=args.workers)
        self.eval_data = gluon.data.DataLoader(valset, args.test_batch_size,
                                               last_batch='rollover', num_workers=args.workers)

        layers = [args.blocks] * 3
        channels = [x * args.channel_times for x in [8, 16, 32, 64]]
        if args.model == 'ResNetFPN':
            model = ASKCResNetFPN(layers=layers, channels=channels, fuse_mode=args.fuse_mode,
                                  tiny=args.tiny, classes=trainset.NUM_CLASS)
        elif args.model == 'ResUNet':
            model = ASKCResUNet(layers=layers, channels=channels, fuse_mode=args.fuse_mode,
                                tiny=args.tiny, classes=trainset.NUM_CLASS)
        else:
            raise ValueError('Unknown model {}'.format(args.model))
        print("layers: ", layers)
        print("channels: ", channels)
        print("fuse_mode: ", args.fuse_mode)
        print("tiny: ", args.tiny)
        print("classes: ", trainset.NUM_CLASS)

        if args.host == 'xxx':
            self.host_name = socket.gethostname()  # automatic
        else:
            self.host_name = args.host             # Puma needs to be specified
        self.save_prefix = '_'.join([args.model, args.fuse_mode, args.dataset, self.host_name,
                            'GPU', args.gpus])

        model.cast(args.dtype)
        # self.logger.info(model)

        # resume checkpoint if needed
        if args.resume is not None:
            if os.path.isfile(args.resume):
                model.load_parameters(args.resume, ctx=args.ctx)
            else:
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        else:
            model.initialize(init=init.MSRAPrelu(), ctx=args.ctx, force_reinit=True)
            print("Model Initializing")
            print("args.ctx: ", args.ctx)

        self.net = model
        if args.summary:
            summary(self.net, mx.nd.zeros((1, 3, args.crop_size, args.crop_size), ctx=args.ctx[0]))
            sys.exit()

        # create criterion
        self.criterion = SoftIoULoss()

        # optimizer and lr scheduling
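        # Linear warmup from 0 to args.lr over warmup_epochs, then polynomial
        # decay towards 0 (roughly base_lr * (1 - t/T)**0.9, where t is the
        # current iteration and T the total number of post-warmup iterations).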
        self.lr_scheduler = LRSequential([
                LRScheduler('linear', base_lr=0, target_lr=args.lr,
                            nepochs=args.warmup_epochs, iters_per_epoch=len(self.train_data)),
                LRScheduler(mode='poly', base_lr=args.lr,
                            nepochs=args.epochs-args.warmup_epochs,
                            iters_per_epoch=len(self.train_data),
                            power=0.9)
            ])
        kv = mx.kv.create(args.kvstore)

        if args.optimizer == 'sgd':
            optimizer_params = {'lr_scheduler': self.lr_scheduler,
                                'wd': args.weight_decay,
                                'momentum': args.momentum,
                                'learning_rate': args.lr}
        elif args.optimizer == 'adam':
            optimizer_params = {'lr_scheduler': self.lr_scheduler,
                                'wd': args.weight_decay,
                                'learning_rate': args.lr}
        elif args.optimizer == 'adagrad':
            optimizer_params = {
                'wd': args.weight_decay,
                'learning_rate': args.lr
            }
        else:
            raise ValueError('Unsupported optimizer {} used'.format(args.optimizer))

        if args.dtype == 'float16':
            optimizer_params['multi_precision'] = True

        if args.no_wd:
            for k, v in self.net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        self.optimizer = gluon.Trainer(self.net.collect_params(), args.optimizer,
                                       optimizer_params, kvstore=kv)

        ################################# evaluation metrics #################################

        self.iou_metric = SigmoidMetric(1)
        self.nIoU_metric = SamplewiseSigmoidMetric(1, score_thresh=self.args.score_thresh)
        self.best_iou = 0
        self.best_nIoU = 0
        self.is_best = False
Example 15
def train(net, async_net, ctx, args):
    """Training pipeline"""
    net.collect_params().reset_ctx(ctx)
    if args.no_wd:
        for k, v in net.collect_params(".*beta|.*gamma|.*bias").items():
            v.wd_mult = 0.0

    if args.label_smooth:
        net._target_generator._label_smooth = True

    if args.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(args.lr_decay_period, args.epochs, args.lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')]

    lr_scheduler = LRSequential([
        LRScheduler("linear",
                    base_lr=0,
                    target_lr=args.lr,
                    nepochs=args.warmup_epochs,
                    iters_per_epoch=args.batch_size),
        LRScheduler(args.lr_mode,
                    base_lr=args.lr,
                    nepochs=args.epochs - args.warmup_epochs,
                    iters_per_epoch=args.batch_size,
                    step_epoch=lr_decay_epoch,
                    step_factor=args.lr_decay,
                    power=2),
    ])
    if (args.optimizer == "sgd"):
        trainer = gluon.Trainer(net.collect_params(),
                                args.optimizer, {
                                    "wd": args.wd,
                                    "momentum": args.momentum,
                                    "lr_scheduler": lr_scheduler
                                },
                                kvstore="local")
    elif (args.optimizer == "adam"):
        trainer = gluon.Trainer(net.collect_params(),
                                args.optimizer, {"lr_scheduler": lr_scheduler},
                                kvstore="local")
    else:
        trainer = gluon.Trainer(net.collect_params(),
                                args.optimizer,
                                kvstore="local")

    # targets
    #sigmoid_ce = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    #l1_loss = gluon.loss.L1Loss()

    # Intermediate Metrics:
    train_metrics = (
        mx.metric.Loss("ObjLoss"),
        mx.metric.Loss("BoxCenterLoss"),
        mx.metric.Loss("BoxScaleLoss"),
        mx.metric.Loss("ClassLoss"),
        mx.metric.Loss("TotalLoss"),
    )
    train_metric_ixs = range(len(train_metrics))
    target_metric_ix = -1  # Train towards TotalLoss (the last one)

    # Evaluation Metrics:
    val_metric = VOC07MApMetric(iou_thresh=0.5)

    # Data transformations:
    train_dataset = gluon_pipe_mode.AugmentedManifestDetection(
        args.train,
        length=args.num_samples_train,
    )
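    # Batchify: stack the image together with the five fixed-size YOLO target
    # arrays (batch entries 0-5) and pad the variable-length ground-truth boxes
    # (entry 6) with -1 so they can form a batch.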
    train_batchify_fn = batchify.Tuple(
        *([batchify.Stack() for _ in range(6)] +
          [batchify.Pad(axis=0, pad_val=-1) for _ in range(1)]))
    if args.no_random_shape:
        logger.debug("Creating train DataLoader without random transform")
        train_transforms = YOLO3DefaultTrainTransform(args.data_shape,
                                                      args.data_shape,
                                                      net=async_net,
                                                      mixup=args.mixup)
        train_dataloader = gluon.data.DataLoader(
            train_dataset.transform(train_transforms),
            batch_size=args.batch_size,
            batchify_fn=train_batchify_fn,
            last_batch="discard",
            num_workers=args.num_workers,
            shuffle=False,  # Note that shuffle *cannot* be used with AugmentedManifestDetection
        )
    else:
        logger.debug("Creating train DataLoader with random transform")
        train_transforms = [
            YOLO3DefaultTrainTransform(x * 32,
                                       x * 32,
                                       net=async_net,
                                       mixup=args.mixup)
            for x in range(10, 20)
        ]
        train_dataloader = RandomTransformDataLoader(
            train_transforms,
            train_dataset,
            interval=10,
            batch_size=args.batch_size,
            batchify_fn=train_batchify_fn,
            last_batch="discard",
            num_workers=args.num_workers,
            shuffle=False,  # Note that shuffle *cannot* be used with AugmentedManifestDetection
        )
    validation_dataset = None
    validation_dataloader = None
    if args.validation:
        validation_dataset = gluon_pipe_mode.AugmentedManifestDetection(
            args.validation,
            length=args.num_samples_validation,
        )
        validation_dataloader = gluon.data.DataLoader(
            validation_dataset.transform(
                YOLO3DefaultValTransform(args.data_shape, args.data_shape), ),
            args.batch_size,
            shuffle=False,
            batchify_fn=batchify.Tuple(batchify.Stack(),
                                       batchify.Pad(pad_val=-1)),
            last_batch="keep",
            num_workers=args.num_workers,
        )

    # Prepare the inference-time configuration for our model's setup:
    # (This will be saved alongside our network structure/params)
    inference_config = config.InferenceConfig(image_size=args.data_shape)

    logger.info(args)
    logger.info(f"Start training from [Epoch {args.start_epoch}]")
    prev_best_score = float("-inf")
    best_epoch = args.start_epoch
    logger.info("Sleeping for 3s in case training data file not yet ready")
    time.sleep(3)
    for epoch in range(args.start_epoch, args.start_epoch + args.epochs):
        #         if args.mixup:
        #             # TODO(zhreshold): more elegant way to control mixup during runtime
        #             try:
        #                 train_data._dataset.set_mixup(np.random.beta, 1.5, 1.5)
        #             except AttributeError:
        #                 train_data._dataset._data.set_mixup(np.random.beta, 1.5, 1.5)
        #             if epoch >= args.epochs - args.no_mixup_epochs:
        #                 try:
        #                     train_data._dataset.set_mixup(None)
        #                 except AttributeError:
        #                     train_data._dataset._data.set_mixup(None)

        tic = time.time()
        btic = time.time()
        mx.nd.waitall()
        net.hybridize()

        logger.debug(
            f"Input data dir contents: {os.listdir('/opt/ml/input/data/')}")
        for i, batch in enumerate(train_dataloader):
            logger.debug(f"Epoch {epoch}, minibatch {i}")

            batch_size = batch[0].shape[0]
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0,
                                              even_split=False)
            # objectness, center_targets, scale_targets, weights, class_targets
            fixed_targets = [
                gluon.utils.split_and_load(batch[it],
                                           ctx_list=ctx,
                                           batch_axis=0,
                                           even_split=False)
                for it in range(1, 6)
            ]
            gt_boxes = gluon.utils.split_and_load(batch[6],
                                                  ctx_list=ctx,
                                                  batch_axis=0,
                                                  even_split=False)
            loss_trackers = tuple([] for metric in train_metrics)
            with autograd.record():
                for ix, x in enumerate(data):
                    losses_raw = net(x, gt_boxes[ix],
                                     *[ft[ix] for ft in fixed_targets])
                    # net outputs: [obj_loss, center_loss, scale_loss, cls_loss]
                    # Each is an mx.ndarray of 1 x batch_size. This is the same order as our
                    # train_metrics, so we just need to add a total vector:
                    total_loss = sum(losses_raw)
                    losses = losses_raw + [total_loss]

                    # If any sample's total loss is non-finite, the sum will be non-finite too:
                    if not isfinite(sum(total_loss)):
                        logger.error(
                            f"[Epoch {epoch}][Minibatch {i}] got non-finite losses: {losses_raw}"
                        )
                        # TODO: Terminate training if losses or gradient go infinite?

                    for ix in train_metric_ixs:
                        loss_trackers[ix].append(losses[ix])

                autograd.backward(loss_trackers[target_metric_ix])
            trainer.step(batch_size)
            for ix in train_metric_ixs:
                train_metrics[ix].update(0, loss_trackers[ix])

            if args.log_interval and not (i + 1) % args.log_interval:
                train_metrics_current = map(lambda metric: metric.get(),
                                            train_metrics)
                metrics_msg = "; ".join([
                    f"{name}={val:.3f}" for name, val in train_metrics_current
                ])
                logger.info(
                    f"[Epoch {epoch}][Minibatch {i}] LR={trainer.learning_rate:.2E}; "
                    f"Speed={batch_size/(time.time()-btic):.3f} samples/sec; {metrics_msg};"
                )
            btic = time.time()

        train_metrics_current = map(lambda metric: metric.get(), train_metrics)
        metrics_msg = "; ".join(
            [f"{name}={val:.3f}" for name, val in train_metrics_current])
        logger.info(
            f"[Epoch {epoch}] TrainingCost={time.time()-tic:.3f}; {metrics_msg};"
        )

        if not (epoch + 1) % args.val_interval:
            logger.info(f"Validating [Epoch {epoch}]")

            metric_names, metric_values = validate(
                net, validation_dataloader, epoch, ctx,
                VOC07MApMetric(iou_thresh=0.5), args)
            if isinstance(metric_names, list):
                val_msg = "; ".join(
                    [f"{k}={v}" for k, v in zip(metric_names, metric_values)])
                current_score = float(metric_values[-1])
            else:
                val_msg = f"{metric_names}={metric_values}"
                current_score = metric_values
            logger.info(f"[Epoch {epoch}] Validation: {val_msg};")
        else:
            current_score = float("-inf")

        save_progress(
            net,
            inference_config,
            current_score,
            prev_best_score,
            args.model_dir,
            epoch,
            args.checkpoint_interval,
            args.checkpoint_dir,
        )
        if current_score > prev_best_score:
            prev_best_score = current_score
            best_epoch = epoch

        if (args.early_stopping and epoch >= args.early_stopping_min_epochs
                and (epoch - best_epoch) >= args.early_stopping_patience):
            logger.info(
                f"[Epoch {epoch}] No improvement since epoch {best_epoch}: Stopping early"
            )
            break
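
The loop above keeps the best validation score and the epoch it was reached, and stops once no improvement has been seen for early_stopping_patience epochs (after early_stopping_min_epochs). A minimal, self-contained sketch of that bookkeeping, with illustrative names rather than this script's actual helpers:

def should_stop(history, min_epochs=10, patience=5):
    """history: per-epoch validation scores (higher is better)."""
    if len(history) <= min_epochs:
        return False
    best_epoch = max(range(len(history)), key=lambda e: history[e])
    return (len(history) - 1) - best_epoch >= patience

# Example: best score at epoch 7, then five epochs without improvement -> stop.
scores = [0.40, 0.52, 0.60, 0.61, 0.63, 0.64, 0.66, 0.70, 0.69, 0.68, 0.70, 0.69, 0.68]
assert should_stop(scores, min_epochs=10, patience=5)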
Esempio n. 16
0
def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline"""
    net.collect_params().reset_ctx(ctx)
    if args.no_wd:
        for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
            v.wd_mult = 0.0

    if args.label_smooth:
        net._target_generator._label_smooth = True

    if args.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(args.lr_decay_period, args.epochs, args.lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - args.warmup_epochs for e in lr_decay_epoch]
    num_batches = args.num_samples // args.batch_size
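    # Two-phase schedule: linear warm-up from 0 to args.lr over warmup_epochs, followed
    # by args.lr_mode decay over the remaining epochs (with drops of factor args.lr_decay
    # at lr_decay_epoch when using step-style decay).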
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=args.lr,
                    nepochs=args.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(args.lr_mode,
                    base_lr=args.lr,
                    nepochs=args.epochs - args.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=args.lr_decay,
                    power=2),
    ])

    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(net.collect_params(), 'sgd', {
            'wd': args.wd,
            'momentum': args.momentum,
            'lr_scheduler': lr_scheduler
        })
    else:
        trainer = gluon.Trainer(
            net.collect_params(),
            'sgd', {
                'wd': args.wd,
                'momentum': args.momentum,
                'lr_scheduler': lr_scheduler
            },
            kvstore='local',
            update_on_kvstore=(False if args.amp else None))
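        # AMP requires the trainer (rather than the kvstore) to apply updates so that
        # gradients can be loss-scaled, hence update_on_kvstore is disabled when args.amp is set.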

    if args.amp:
        amp.init_trainer(trainer)

    # targets
    sigmoid_ce = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    l1_loss = gluon.loss.L1Loss()
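    # These reference losses are not used below: in training mode the YOLOv3 network
    # returns its loss components directly from the forward pass. Note that everything
    # inside the triple-quoted string that follows is a string literal (i.e. commented
    # out), so only the inference demo at the end of this function actually runs.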
    """
    # metrics
    obj_metrics = mx.metric.Loss('ObjLoss')
    center_metrics = mx.metric.Loss('BoxCenterLoss')
    scale_metrics = mx.metric.Loss('BoxScaleLoss')
    cls_metrics = mx.metric.Loss('ClassLoss')   

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        if args.mixup:
            # TODO(zhreshold): more elegant way to control mixup during runtime
            try:
                train_data._dataset.set_mixup(np.random.beta, 1.5, 1.5)
            except AttributeError:
                train_data._dataset._data.set_mixup(np.random.beta, 1.5, 1.5)
            if epoch >= args.epochs - args.no_mixup_epochs:
                try:
                    train_data._dataset.set_mixup(None)
                except AttributeError:
                    train_data._dataset._data.set_mixup(None)

        tic = time.time()
        btic = time.time()
        mx.nd.waitall()
        net.hybridize()
        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            # objectness, center_targets, scale_targets, weights, class_targets
            fixed_targets = [gluon.utils.split_and_load(batch[it], ctx_list=ctx, batch_axis=0) for it in range(1, 6)]
            gt_boxes = gluon.utils.split_and_load(batch[6], ctx_list=ctx, batch_axis=0)
            sum_losses = []
            obj_losses = []
            center_losses = []
            scale_losses = []
            cls_losses = []
            with autograd.record():
                for ix, x in enumerate(data):
                    obj_loss, center_loss, scale_loss, cls_loss = net(x, gt_boxes[ix], *[ft[ix] for ft in fixed_targets])
                    sum_losses.append(obj_loss + center_loss + scale_loss + cls_loss)
                    obj_losses.append(obj_loss)
                    center_losses.append(center_loss)
                    scale_losses.append(scale_loss)
                    cls_losses.append(cls_loss)
                if args.amp:
                    with amp.scale_loss(sum_losses, trainer) as scaled_loss:
                        autograd.backward(scaled_loss)
                else:
                    autograd.backward(sum_losses)
            trainer.step(batch_size)
            if (not args.horovod or hvd.rank() == 0):
                obj_metrics.update(0, obj_losses)
                center_metrics.update(0, center_losses)
                scale_metrics.update(0, scale_losses)
                cls_metrics.update(0, cls_losses)
                if args.log_interval and not (i + 1) % args.log_interval:
                    name1, loss1 = obj_metrics.get()
                    name2, loss2 = center_metrics.get()
                    name3, loss3 = scale_metrics.get()
                    name4, loss4 = cls_metrics.get()
                    logger.info('[Epoch {}][Batch {}], LR: {:.2E}, Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'.format(
                        epoch, i, trainer.learning_rate, args.batch_size/(time.time()-btic), name1, loss1, name2, loss2, name3, loss3, name4, loss4))
                btic = time.time()

        if (not args.horovod or hvd.rank() == 0):
            name1, loss1 = obj_metrics.get()
            name2, loss2 = center_metrics.get()
            name3, loss3 = scale_metrics.get()
            name4, loss4 = cls_metrics.get()
            logger.info('[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'.format(
                epoch, (time.time()-tic), name1, loss1, name2, loss2, name3, loss3, name4, loss4))
            if not (epoch + 1) % args.val_interval:
                # consider reducing the validation frequency to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
                val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix)

if __name__ == '__main__':
    args = parse_args()

    if args.amp:
        amp.init()

    if args.horovod:
        if hvd is None:
            raise SystemExit("Horovod not found, please check if you installed it correctly.")
        hvd.init()

    # fix seed for mxnet, numpy and python builtin random generator.
    gutils.random.seed(args.seed)

    # training contexts
    # check gpu or cpu
    if args.horovod:
        ctx = [mx.gpu(hvd.local_rank())]
    else:
        ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
        if ctx:
            print("Using GPU")
        else:
            ctx = [mx.cpu()]


    # network
    net_name = '_'.join(('yolo3', args.network, args.dataset))
    
    args.save_prefix += net_name
    # use sync bn if specified
    if args.syncbn and len(ctx) > 1:
        net = get_model(net_name, pretrained_base=True, norm_layer=gluon.contrib.nn.SyncBatchNorm,
                        norm_kwargs={'num_devices': len(ctx)})
        async_net = get_model(net_name, pretrained_base=True)  # used by cpu worker
    else:
        net = get_model(net_name, pretrained_base=True)
        async_net = net
    if args.resume.strip():
        net.load_parameters(args.resume.strip())
        async_net.load_parameters(args.resume.strip())
    else:
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            net.initialize()
            async_net.initialize()


    # training data
    batch_size = (args.batch_size // hvd.size()) if args.horovod else args.batch_size
    train_dataset, val_dataset, eval_metric = get_dataset(args.dataset, args)
    train_data, val_data = get_dataloader(
        async_net, train_dataset, val_dataset, args.data_shape, batch_size, args.num_workers, args)

    start_time = time.time()

    # training
    train(net, train_data, val_data, eval_metric, ctx, args)

    elapsed_time = time.time() - start_time

    print("time to train: ", elapsed_time)

    file_name = "yolo3_mobilenet1.0_voc_last.params"
    net.save_parameters(file_name)

    
    """

    net = get_model("yolo3_mobilenet1.0_voc")
    net.load_parameters("yolo3_mobilenet1.0_voc_best.params")

    # These demo helpers come from GluonCV, not MXNet itself.
    from gluoncv.utils import download, viz
    from gluoncv.data.transforms.presets.yolo import load_test
    im_fname = download(
        'https://raw.githubusercontent.com/zhreshold/' +
        'mxnet-ssd/master/data/demo/dog.jpg',
        path='dog.jpg')
    x, img = load_test(im_fname, short=512)
    print('Shape of pre-processed image:', x.shape)
    class_IDs, scores, bounding_boxs = net(x)

    ax = viz.plot_bbox(img,
                       bounding_boxs[0],
                       scores[0],
                       class_IDs[0],
                       class_names=net.classes)
    plt.show()
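
For reference, the warm-up plus decay behaviour that the LRSequential above encodes can be written out in plain Python. This is an illustrative re-implementation of the schedule's overall shape (cosine and poly-with-power-2 cases shown), not GluonCV's exact scheduler:

import math

def lr_at(it, base_lr, warmup_epochs, total_epochs, iters_per_epoch, mode='cosine'):
    """Learning rate at global iteration `it`: linear warm-up, then decay to 0."""
    warmup_iters = warmup_epochs * iters_per_epoch
    total_iters = total_epochs * iters_per_epoch
    if it < warmup_iters:  # linear warm-up from 0 to base_lr
        return base_lr * it / max(1, warmup_iters)
    t = (it - warmup_iters) / max(1, total_iters - warmup_iters)  # progress in [0, 1]
    if mode == 'cosine':   # cosine decay from base_lr down to 0
        return 0.5 * base_lr * (1.0 + math.cos(math.pi * t))
    return base_lr * (1.0 - t) ** 2  # 'poly' decay with power=2

# Peak LR right after warm-up, close to zero near the end of training:
print(lr_at(0, 1e-3, 2, 200, 100), lr_at(200, 1e-3, 2, 200, 100), lr_at(19999, 1e-3, 2, 200, 100))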
Esempio n. 17
0
def main():
    opt = parse_args()

    filehandler = logging.FileHandler(opt.logging_file)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    # logger.setLevel(logging.DEBUG)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167

    num_gpus = opt.num_gpus
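    # opt.batch_size is the per-device batch size; scale it up to the global batch size.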
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    # epoch_start_cs: train with all channels before this epoch, then switch to channel selection.
    if opt.epoch_start_cs != -1:
        opt.use_all_channels = True

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
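    # Iterations per epoch; dividing by reduced_dataset_scale trains each epoch on only
    # a fraction of the ImageNet training set.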
    num_batches = num_training_samples // batch_size // opt.reduced_dataset_scale

    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=opt.lr,
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode,
                    base_lr=opt.lr,
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])

    model_name = opt.model

    kwargs = {
        'ctx': context,
        'pretrained': opt.use_pretrained,
        'classes': classes
    }
    if opt.use_gn:
        from gluoncv.nn import GroupNorm
        kwargs['norm_layer'] = GroupNorm
    if model_name.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm
    elif model_name.startswith('resnext'):
        kwargs['use_se'] = opt.use_se

    if opt.last_gamma:
        kwargs['last_gamma'] = True

    optimizer = 'nag'
    optimizer_params = {
        'wd': opt.wd,
        'momentum': opt.momentum,
        'lr_scheduler': lr_scheduler
    }
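    # For float16 training, keep a float32 master copy of the weights in the optimizer.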
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    if model_name == 'ShuffleNas_fixArch':
        architecture = parse_str_list(opt.block_choices)
        scale_ids = parse_str_list(opt.channel_choices)
        net = get_shufflenas_oneshot(
            architecture=architecture,
            n_class=classes,
            scale_ids=scale_ids,
            use_se=opt.use_se,
            last_conv_after_pooling=opt.last_conv_after_pooling,
            channels_layout=opt.channels_layout)
    elif model_name == 'ShuffleNas':
        net = get_shufflenas_oneshot(
            n_class=classes,
            use_all_blocks=opt.use_all_blocks,
            use_se=opt.use_se,
            last_conv_after_pooling=opt.last_conv_after_pooling,
            channels_layout=opt.channels_layout)
    else:
        net = get_model(model_name, **kwargs)

    net.cast(opt.dtype)
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context)

    # teacher model for distillation training
    if opt.teacher is not None and opt.hard_weight < 1.0:
        teacher_name = opt.teacher
        teacher = get_model(teacher_name,
                            pretrained=True,
                            classes=classes,
                            ctx=context)
        teacher.cast(opt.dtype)
        distillation = True
    else:
        distillation = False

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx,
                     batch_size, num_workers):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
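        # Shorter-side resize used by the validation iterator, e.g. ceil(224 / 0.875) = 256,
        # before the final center crop of input_size.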
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec=rec_train,
            path_imgidx=rec_train_idx,
            preprocess_threads=num_workers,
            shuffle=True,
            batch_size=batch_size,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
            rand_mirror=True,
            random_resized_crop=True,
            max_aspect_ratio=4. / 3.,
            min_aspect_ratio=3. / 4.,
            max_random_area=1,
            min_random_area=0.08,
            brightness=jitter_param,
            saturation=jitter_param,
            contrast=jitter_param,
            pca_noise=lighting_param,
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec=rec_val,
            path_imgidx=rec_val_idx,
            preprocess_threads=num_workers,
            shuffle=False,
            batch_size=batch_size,
            resize=resize,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
        )
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param,
                                         contrast=jitter_param,
                                         saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(), normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(), normalize
        ])

        train_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=True).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)
        val_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=False).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        return train_data, val_data, batch_fn

    if opt.use_rec:
        train_data, val_data, batch_fn = get_data_rec(opt.rec_train,
                                                      opt.rec_train_idx,
                                                      opt.rec_val,
                                                      opt.rec_val_idx,
                                                      batch_size, num_workers)
    else:
        train_data, val_data, batch_fn = get_data_loader(
            opt.data_dir, batch_size, num_workers)

    if opt.mixup:
        train_metric = mx.metric.RMSE()
    else:
        train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def mixup_transform(label, classes, lam=1, eta=0.0):
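        # Soft labels for mixup: blend each sample's one-hot label with that of its mixup
        # partner (the batch in reverse order) using weight lam; eta adds label smoothing.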
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes,
                           on_value=1 - eta + eta / classes,
                           off_value=eta / classes)
            y2 = l[::-1].one_hot(classes,
                                 on_value=1 - eta + eta / classes,
                                 off_value=eta / classes)
            res.append(lam * y1 + (1 - lam) * y2)
        return res

    def smooth(label, classes, eta=0.1):
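        # Standard label smoothing: the true class receives 1 - eta + eta / classes,
        # every other class receives eta / classes.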
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes,
                            on_value=1 - eta + eta / classes,
                            off_value=eta / classes)
            smoothed.append(res)
        return smoothed

    def make_divisible(x, divisible_by=8):
        return int(np.ceil(x * 1. / divisible_by) * divisible_by)

    def set_nas_bn(net, inference_update_stat=False):
        if isinstance(net, NasBatchNorm):
            net.inference_update_stat = inference_update_stat
        elif len(net._children) != 0:
            for k, v in net._children.items():
                set_nas_bn(v, inference_update_stat=inference_update_stat)
        else:
            return

    def update_bn(net,
                  batch_fn,
                  train_data,
                  block_choices,
                  full_channel_mask,
                  ctx=[mx.cpu()],
                  dtype='float32',
                  batch_size=256,
                  update_bn_images=20000):
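        # Re-estimate BatchNorm running statistics for the sampled sub-network by forwarding
        # roughly update_bn_images training images with inference_update_stat enabled, then
        # restore the original dtype.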
        train_data.reset()
        # Updating bn needs the model to be float32
        net.cast('float32')
        full_channel_masks = [
            full_channel_mask.as_in_context(ctx_i) for ctx_i in ctx
        ]
        set_nas_bn(net, inference_update_stat=True)
        for i, batch in enumerate(train_data):
            if (i + 1) * batch_size * len(ctx) >= update_bn_images:
                break
            data, _ = batch_fn(batch, ctx)
            _ = [
                net(X.astype('float32', copy=False),
                    block_choices.astype('float32', copy=False),
                    channel_mask.astype('float32', copy=False))
                for X, channel_mask in zip(data, full_channel_masks)
            ]
        set_nas_bn(net, inference_update_stat=False)
        net.cast(dtype)

    def test(ctx, val_data, epoch):
        if model_name == 'ShuffleNas':
            # For validation accuracy, randomly sample block and channel choices and update BN stats
            block_choices = net.random_block_choices(
                select_predefined_block=False, dtype=opt.dtype)
            if opt.cs_warm_up:
                # TODO: edit in the issue, readme and medium article that
                #  bn stat needs to be updated before verifying val acc
                full_channel_mask, channel_choices = net.random_channel_mask(
                    select_all_channels=False,
                    epoch_after_cs=epoch - opt.epoch_start_cs,
                    dtype=opt.dtype,
                    ignore_first_two_cs=opt.ignore_first_two_cs)
            else:
                full_channel_mask, _ = net.random_channel_mask(
                    select_all_channels=False,
                    dtype=opt.dtype,
                    ignore_first_two_cs=opt.ignore_first_two_cs)
            update_bn(net,
                      batch_fn,
                      train_data,
                      block_choices,
                      full_channel_mask,
                      ctx,
                      dtype=opt.dtype,
                      batch_size=batch_size)
        else:
            block_choices, full_channel_mask = None, None

        if opt.use_rec:
            val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            if model_name == 'ShuffleNas':
                full_channel_masks = [
                    full_channel_mask.as_in_context(ctx_i) for ctx_i in ctx
                ]
                outputs = [
                    net(X.astype(opt.dtype, copy=False), block_choices,
                        channel_mask)
                    for X, channel_mask in zip(data, full_channel_masks)
                ]
            else:
                outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return 1 - top1, 1 - top5

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            if 'ShuffleNas' in model_name:
                net._initialize(ctx=ctx)
            else:
                net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                optimizer_params)
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True
        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(
                temperature=opt.temperature,
                hard_weight=opt.hard_weight,
                sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(
                sparse_label=sparse_label_loss)

        best_val_score = 1

        def train_epoch(pool=None,
                        pool_lock=None,
                        shared_finished_flag=None,
                        use_pool=False):
            btic = time.time()
            for i, batch in enumerate(train_data):
                if i == num_batches:
                    if use_pool:
                        shared_finished_flag.value = True
                    return
                data, label = batch_fn(batch, ctx)

                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                        lam = 1
                    data = [lam * X + (1 - lam) * X[::-1] for X in data]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label, classes, lam, eta)

                elif opt.label_smoothing:
                    hard_label = label
                    label = smooth(label, classes)

                if distillation:
                    teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature) \
                                    for X in data]

                with ag.record():
                    if model_name == 'ShuffleNas' and use_pool:
                        cand = None
                        while cand is None:
                            if len(pool) > 0:
                                with pool_lock:
                                    cand = pool.pop()
                                    if i % opt.log_interval == 0:
                                        logger.debug('[Trainer] ' + '-' * 40)
                                        logger.debug(
                                            "[Trainer] Time: {}".format(
                                                time.time()))
                                        logger.debug(
                                            "[Trainer] Block choice: {}".
                                            format(cand['block_list']))
                                        logger.debug(
                                            "[Trainer] Channel choice: {}".
                                            format(cand['channel_list']))
                                        logger.debug(
                                            "[Trainer] Flop: {}M, param: {}M".
                                            format(cand['flops'],
                                                   cand['model_size']))
                            else:
                                time.sleep(1)

                        full_channel_masks = [
                            cand['channel'].as_in_context(ctx_i)
                            for ctx_i in ctx
                        ]
                        outputs = [
                            net(X.astype(opt.dtype, copy=False), cand['block'],
                                channel_mask) for X, channel_mask in zip(
                                    data, full_channel_masks)
                        ]
                    elif model_name == 'ShuffleNas':
                        block_choices = net.random_block_choices(
                            select_predefined_block=False, dtype=opt.dtype)
                        if opt.cs_warm_up:
                            full_channel_mask, channel_choices = net.random_channel_mask(
                                select_all_channels=opt.use_all_channels,
                                epoch_after_cs=epoch - opt.epoch_start_cs,
                                dtype=opt.dtype,
                                ignore_first_two_cs=opt.ignore_first_two_cs)
                        else:
                            full_channel_mask, channel_choices = net.random_channel_mask(
                                select_all_channels=opt.use_all_channels,
                                dtype=opt.dtype,
                                ignore_first_two_cs=opt.ignore_first_two_cs)

                        full_channel_masks = [
                            full_channel_mask.as_in_context(ctx_i)
                            for ctx_i in ctx
                        ]
                        outputs = [
                            net(X.astype(opt.dtype, copy=False), block_choices,
                                channel_mask) for X, channel_mask in zip(
                                    data, full_channel_masks)
                        ]
                    else:
                        outputs = [
                            net(X.astype(opt.dtype, copy=False)) for X in data
                        ]

                    if distillation:
                        loss = [
                            L(yhat.astype('float32', copy=False),
                              y.astype('float32', copy=False),
                              p.astype('float32', copy=False))
                            for yhat, y, p in zip(outputs, label, teacher_prob)
                        ]
                    else:
                        loss = [
                            L(yhat, y.astype(opt.dtype, copy=False))
                            for yhat, y in zip(outputs, label)
                        ]
                for l in loss:
                    l.backward()
                trainer.step(batch_size, ignore_stale_grad=True)

                if opt.mixup:
                    output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                                    for out in outputs]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, outputs)
                    else:
                        train_metric.update(label, outputs)

                if opt.log_interval and not (i + 1) % opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'
                        % (epoch, i, batch_size * opt.log_interval /
                           (time.time() - btic), train_metric_name,
                           train_metric_score, trainer.learning_rate))
                    btic = time.time()
            return

        def maintain_random_pool(pool,
                                 pool_lock,
                                 shared_finished_flag,
                                 upper_flops=sys.maxsize,
                                 upper_params=sys.maxsize,
                                 bottom_flops=0,
                                 bottom_params=0):
            lookup_table = None
            if opt.flop_param_method == 'lookup_table':
                lookup_table = load_lookup_table(opt.use_se,
                                                 opt.last_conv_after_pooling,
                                                 opt.channels_layout,
                                                 nas_root='./')
            while True:
                if shared_finished_flag.value:
                    break
                if len(pool) < 5:
                    candidate = dict()
                    block_choices, block_choices_list = net.random_block_choices(
                        select_predefined_block=False,
                        dtype=opt.dtype,
                        return_choice_list=True)
                    if opt.cs_warm_up:
                        full_channel_mask, channel_choices_list = net.random_channel_mask(
                            select_all_channels=opt.use_all_channels,
                            epoch_after_cs=epoch - opt.epoch_start_cs,
                            dtype=opt.dtype,
                            ignore_first_two_cs=opt.ignore_first_two_cs,
                        )
                    else:
                        full_channel_mask, channel_choices_list = net.random_channel_mask(
                            select_all_channels=opt.use_all_channels,
                            dtype=opt.dtype,
                            ignore_first_two_cs=opt.ignore_first_two_cs)

                    if opt.flop_param_method == 'symbol':
                        flops, model_size, _, _ = \
                            get_flop_param_forward(block_choices_list, channel_choices_list,
                                                 use_se=opt.use_se, last_conv_after_pooling=opt.last_conv_after_pooling,
                                                 channels_layout=opt.channels_layout)
                    elif opt.flop_param_method == 'lookup_table':
                        flops, model_size = get_flop_param_lookup(
                            block_choices_list, channel_choices_list,
                            lookup_table)
                    else:
                        raise ValueError(
                            'Unrecognized flop param calculation method: {}'.
                            format(opt.flop_param_method))

                    candidate['block'] = block_choices
                    candidate['channel'] = full_channel_mask
                    candidate['block_list'] = block_choices_list
                    candidate['channel_list'] = channel_choices_list
                    candidate['flops'] = flops
                    candidate['model_size'] = model_size

                    if flops > upper_flops or model_size > upper_params or \
                        flops < bottom_flops or model_size < bottom_params:
                        continue

                    with pool_lock:
                        pool.append(candidate)
                        logger.debug(
                            "[Maintainer] Add one good candidate. currently pool size: {}"
                            .format(len(pool)))

        if opt.train_constraint_method == 'evolution':

            evolve_maintainer = Maintainer(net,
                                           num_batches=num_batches,
                                           nas_root='./')
        else:
            evolve_maintainer = None
        manager = multiprocessing.Manager()
        cand_pool = manager.list()
        p_lock = manager.Lock()

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            if epoch >= opt.epoch_start_cs:
                opt.use_all_channels = False
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()

            if opt.train_constraint_method is None:
                logger.debug("===== DEBUG ======\n"
                             "Train SuperNet with no constraint")
                train_epoch()
            else:
                upper_constraints = opt.train_upper_constraints.split('-')
                # opt.train_upper_constraints = 'flops-330-params-5.0'
                assert len(upper_constraints) == 4 and upper_constraints[0] == 'flops' \
                       and upper_constraints[2] == 'params'
                upper_flops = float(upper_constraints[1]) if float(
                    upper_constraints[1]) != 0 else sys.maxsize
                upper_params = float(upper_constraints[3]) if float(
                    upper_constraints[3]) != 0 else sys.maxsize

                bottom_constraints = opt.train_bottom_constraints.split('-')
                assert len(bottom_constraints) == 4 and bottom_constraints[0] == 'flops' \
                       and bottom_constraints[2] == 'params'
                bottom_flops = float(bottom_constraints[1]) if float(
                    bottom_constraints[1]) != 0 else 0
                bottom_params = float(bottom_constraints[3]) if float(
                    bottom_constraints[3]) != 0 else 0

                if opt.train_constraint_method == 'random':
                    finished = Value(c_bool, False)
                    logger.debug(
                        "===== DEBUG ======\n"
                        "Train SuperNet with Flops less than {}, greater than {}, "
                        "params less than {}, greater than {}\n Random sample."
                        .format(upper_flops, bottom_flops, upper_params,
                                bottom_params))
                    pool_process = multiprocessing.Process(
                        target=maintain_random_pool,
                        args=[
                            cand_pool, p_lock, finished, upper_flops,
                            upper_params, bottom_flops, bottom_params
                        ])
                    pool_process.start()
                    train_epoch(pool=cand_pool,
                                pool_lock=p_lock,
                                shared_finished_flag=finished,
                                use_pool=True)
                    pool_process.join()
                elif opt.train_constraint_method == 'evolution':
                    finished = Value(c_bool, False)
                    logger.debug(
                        "===== DEBUG ======\n"
                        "Train SuperNet with Flops less than {}, greater than {}, "
                        "params less than {}, greater than {}\n Using strolling evolution to sample."
                        .format(upper_flops, bottom_flops, upper_params,
                                bottom_params))
                    pool_process = multiprocessing.Process(
                        target=evolve_maintainer.maintain,
                        args=[cand_pool, p_lock, finished, opt.dtype, logger])
                    pool_process.start()
                    train_epoch(pool=cand_pool,
                                pool_lock=p_lock,
                                shared_finished_flag=finished,
                                use_pool=True)
                    pool_process.join()
                else:
                    raise ValueError(
                        "Unrecognized training constraint method: {}".format(
                            opt.train_constraint_method))

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * num_batches / (time.time() - tic))

            err_top1_val, err_top5_val = test(ctx, val_data, epoch)

            logger.info('[Epoch %d] training: %s=%f' %
                        (epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' %
                        (epoch, throughput, time.time() - tic))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f' %
                        (epoch, err_top1_val, err_top5_val))

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                net.save_parameters(
                    '%s/%.4f-imagenet-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))
                trainer.save_states(
                    '%s/%.4f-imagenet-%s-%d-best.states' %
                    (save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch +
                                                1) % save_frequency == 0:
                net.save_parameters('%s/imagenet-%s-%d.params' %
                                    (save_dir, model_name, epoch))
                trainer.save_states('%s/imagenet-%s-%d.states' %
                                    (save_dir, model_name, epoch))

        if save_frequency and save_dir:
            net.save_parameters('%s/imagenet-%s-%d.params' %
                                (save_dir, model_name, opt.num_epochs - 1))
            trainer.save_states('%s/imagenet-%s-%d.states' %
                                (save_dir, model_name, opt.num_epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)
        if distillation:
            teacher.hybridize(static_alloc=True, static_shape=True)
    print(net)
    train(context)
Esempio n. 18
0
def main():
    opt = parse_args()
    makedirs(opt.log_dir)
    filehandler = logging.FileHandler(opt.log_dir + '/' + opt.logging_file)
    streamhandler = logging.StreamHandler()
    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)
    logger.info(opt)
    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167
    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers
    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = num_training_samples // batch_size

    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=opt.lr,
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode,
                    base_lr=opt.lr,
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])

    sw = SummaryWriter(logdir=opt.log_dir, flush_secs=5, verbose=False)
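    # mxboard SummaryWriter: writes TensorBoard event files to opt.log_dir; per-iteration
    # scalars and the network graph are added during training below.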
    optimizer = 'sgd'
    optimizer_params = {
        'wd': opt.wd,
        'momentum': opt.momentum,
        'lr_scheduler': lr_scheduler
    }
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True
    net = ShuffleNetV2_OneShot(input_size=224,
                               n_class=1000,
                               architecture=cand[0],
                               channels_idx=cand[1],
                               act_type='relu',
                               search=False)  # define a specific subnet

    # compute the FLOPs and params of the subnet once, then log both
    subnet_flops_params = get_cand_flops_params(cand[0], cand[1])
    logger.info("The FLOPs of the subnet is {}".format(subnet_flops_params[0]))
    logger.info("The Params of the subnet is {}".format(subnet_flops_params[1]))

    net.cast(opt.dtype)
    net.hybridize()

    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context)

    # teacher model for distillation training
    if opt.teacher is not None and opt.hard_weight < 1.0:
        teacher_name = opt.teacher
        teacher = get_model(teacher_name,
                            pretrained=True,
                            classes=classes,
                            ctx=context)
        teacher.cast(opt.dtype)
        distillation = True
    else:
        distillation = False

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx,
                     batch_size, num_workers, seed):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec=rec_train,
            path_imgidx=rec_train_idx,
            preprocess_threads=num_workers,
            shuffle=True,
            batch_size=batch_size,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
            rand_mirror=True,
            random_resized_crop=True,
            max_aspect_ratio=4. / 3.,
            min_aspect_ratio=3. / 4.,
            max_random_area=1,
            min_random_area=0.08,
            brightness=jitter_param,
            saturation=jitter_param,
            contrast=jitter_param,
            pca_noise=lighting_param,
            seed=seed,
            seed_aug=seed,
            shuffle_chunk_seed=seed,
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec=rec_val,
            path_imgidx=rec_val_idx,
            preprocess_threads=num_workers,
            shuffle=False,
            batch_size=batch_size,
            resize=resize,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
        )
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param,
                                         contrast=jitter_param,
                                         saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(), normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(), normalize
        ])

        train_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=True).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)
        val_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=False).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        return train_data, val_data, batch_fn

    if opt.use_rec:
        if opt.use_dali:
            train_data = dali.get_data_rec((3, opt.input_size, opt.input_size),
                                           opt.crop_ratio,
                                           opt.rec_train,
                                           opt.rec_train_idx,
                                           opt.batch_size,
                                           num_workers=2,
                                           train=True,
                                           shuffle=True,
                                           backend='dali-gpu',
                                           gpu_ids=[0, 1],
                                           kv_store='nccl',
                                           dtype=opt.dtype,
                                           input_layout='NCHW')
            val_data = dali.get_data_rec((3, opt.input_size, opt.input_size),
                                         opt.crop_ratio,
                                         opt.rec_val,
                                         opt.rec_val_idx,
                                         opt.batch_size,
                                         num_workers=2,
                                         train=False,
                                         shuffle=False,
                                         backend='dali-gpu',
                                         gpu_ids=[0, 1],
                                         kv_store='nccl',
                                         dtype=opt.dtype,
                                         input_layout='NCHW')

            def batch_fn(batch, ctx):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)
                return data, label
        else:
            train_data, val_data, batch_fn = get_data_rec(
                opt.rec_train, opt.rec_train_idx, opt.rec_val, opt.rec_val_idx,
                batch_size, num_workers, opt.random_seed)
    else:
        train_data, val_data, batch_fn = get_data_loader(
            opt.data_dir, batch_size, num_workers)

    if opt.mixup:
        train_metric = mx.metric.RMSE()
    else:
        train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def mixup_transform(label, classes, lam=1, eta=0.0):
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes,
                           on_value=1 - eta + eta / classes,
                           off_value=eta / classes)
            y2 = l[::-1].one_hot(classes,
                                 on_value=1 - eta + eta / classes,
                                 off_value=eta / classes)
            res.append(lam * y1 + (1 - lam) * y2)
        return res

    def smooth(label, classes, eta=0.1):
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes,
                            on_value=1 - eta + eta / classes,
                            off_value=eta / classes)
            smoothed.append(res)
        return smoothed

    def test(net, batch_fn, ctx, val_data):
        if opt.use_rec:
            val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            # The subnet above was built with a fixed architecture (search=False),
            # so its forward pass takes only the image batch.
            outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)
        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return (top1, top5)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            net._initialize(ctx=ctx, force_reinit=True)
        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                optimizer_params)
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True

        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(
                temperature=opt.temperature,
                hard_weight=opt.hard_weight,
                sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(
                sparse_label=sparse_label_loss)

        best_val_score = 0
        iteration = 0

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()
            btic = time.time()

            for i, batch in enumerate(train_data):

                data, label = batch_fn(batch, ctx)
                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                        lam = 1
                    data = [lam * X + (1 - lam) * X[::-1] for X in data]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label, classes, lam, eta)

                elif opt.label_smoothing:
                    hard_label = label
                    label = smooth(label, classes)

                if distillation:
                    teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature) \
                                    for X in data]

                with ag.record():
                    outputs = [
                        net(X.astype(opt.dtype, copy=False)) for X in data
                    ]
                    if distillation:
                        loss = [
                            L(yhat.astype('float32', copy=False),
                              y.astype('float32', copy=False),
                              p.astype('float32', copy=False))
                            for yhat, y, p in zip(outputs, label, teacher_prob)
                        ]
                    else:
                        loss = [
                            L(yhat, y.astype(opt.dtype, copy=False))
                            for yhat, y in zip(outputs, label)
                        ]
                for l in loss:
                    l.backward()
                sw.add_scalar(tag='train_loss',
                              value=sum([l.sum().asscalar()
                                         for l in loss]) / len(loss),
                              global_step=iteration)

                trainer.step(batch_size)

                if opt.mixup:
                    output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                                    for out in outputs]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, outputs)
                    else:
                        train_metric.update(label, outputs)
                train_metric_name, train_metric_score = train_metric.get()
                sw.add_scalar(
                    tag='train_{}_curves'.format(train_metric_name),
                    value=('train_{}_value'.format(train_metric_name),
                           train_metric_score),
                    global_step=iteration)

                if opt.log_interval and not (i + 1) % opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'
                        % (epoch, i, batch_size * opt.log_interval /
                           (time.time() - btic), train_metric_name,
                           train_metric_score, trainer.learning_rate))
                    btic = time.time()
                iteration += 1
            if epoch == 0:
                sw.add_graph(net)

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i / (time.time() - tic))

            top1_val_acc, top5_val_acc = test(net, batch_fn, ctx, val_data)
            sw.add_scalar(tag='val_acc_curves',
                          value=('valid_acc_value', top1_val_acc),
                          global_step=epoch)
            logger.info('Epoch [%d] training: %s=%f' %
                        (epoch, train_metric_name, train_metric_score))
            logger.info('Epoch [%d] speed: %d samples/sec\ttime cost: %f' %
                        (epoch, throughput, time.time() - tic))
            logger.info('Epoch [%d] validation: top1_acc=%f top5_acc=%f' %
                        (epoch, top1_val_acc, top5_val_acc))

            if top1_val_acc > best_val_score:
                best_val_score = top1_val_acc
                net.collect_params().save(
                    '%s/%.4f-subnet_imagenet-%d-best.params' %
                    (save_dir, best_val_score, epoch))
                trainer.save_states('%s/%.4f-subnet_imagenet-%d-best.states' %
                                    (save_dir, best_val_score, epoch))

            if save_frequency and save_dir and (epoch +
                                                1) % save_frequency == 0:
                net.collect_params().save('%s/subnet_imagenet-%d.params' %
                                          (save_dir, epoch))
                trainer.save_states('%s/subnet_imagenet-%d.states' %
                                    (save_dir, epoch))

        sw.close()
        if save_frequency and save_dir:
            net.collect_params().save('%s/subnet_imagenet-%d.params' %
                                      (save_dir, opt.num_epochs - 1))
            trainer.save_states('%s/subnet_imagenet-%d.states' %
                                (save_dir, opt.num_epochs - 1))

    net.hybridize(static_alloc=True, static_shape=True)
    if distillation:
        teacher.hybridize(static_alloc=True, static_shape=True)
    train(context)
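As a side note, the warmup-plus-decay schedule built at the top of this script can be inspected in isolation. The sketch below is illustrative only: it assumes gluoncv's LRScheduler/LRSequential (imported as in these scripts) can be called with a global iteration index to return the current learning rate, and all epoch/batch counts are made up.

from gluoncv.utils import LRScheduler, LRSequential

num_batches = 100       # assumed iterations per epoch
warmup_epochs = 5
num_epochs = 120
base_lr = 0.1

schedule = LRSequential([
    # phase 1: linear ramp from 0 up to base_lr during warmup
    LRScheduler('linear', base_lr=0, target_lr=base_lr,
                nepochs=warmup_epochs, iters_per_epoch=num_batches),
    # phase 2: step decay by 10x at epochs 30/60/90, counted after warmup
    LRScheduler('step', base_lr=base_lr, target_lr=0,
                nepochs=num_epochs - warmup_epochs,
                iters_per_epoch=num_batches,
                step_epoch=[e - warmup_epochs for e in (30, 60, 90)],
                step_factor=0.1)
])

for it in (0, 250, 500, 3000, 6000, 9500):
    print(it, schedule(it))   # learning rate at a few global iterations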
Esempio n. 19
0
    def __init__(self, options, logger):
        # configuration setting
        self.opt = options
        self.logger = logger
        self.log_path = os.path.join(self.opt.log_dir, self.opt.model_zoo)

        # checking height and width are multiples of 32
        assert self.opt.height % 32 == 0, "'height' must be a multiple of 32"
        assert self.opt.width % 32 == 0, "'width' must be a multiple of 32"

        ################### model initialization ###################
        self.num_scales = len(self.opt.scales)
        self.num_input_frames = len(self.opt.frame_ids)
        self.num_pose_frames = 2 if self.opt.pose_model_input == "pairs" else self.num_input_frames

        assert self.opt.frame_ids[0] == 0, "frame_ids must start with 0"

        self.use_pose_net = not (self.opt.use_stereo and self.opt.frame_ids == [0])

        if self.opt.use_stereo:
            self.opt.frame_ids.append("s")

        # create network
        if self.opt.model_zoo is not None:
            self.model = get_model(self.opt.model_zoo, pretrained_base=self.opt.pretrained_base,
                                   scales=self.opt.scales, ctx=self.opt.ctx)
        else:
            assert "Must choose a model from model_zoo, " \
                   "please provide the model_zoo using --model_zoo"
        self.logger.info(self.model)

        # resume checkpoint if needed
        if self.opt.resume is not None:
            if os.path.isfile(self.opt.resume):
                logger.info('Resume model: %s' % self.opt.resume)
                self.model.load_parameters(self.opt.resume, ctx=self.opt.ctx)
            else:
                raise RuntimeError("=> no checkpoint found at '{}'".format(self.opt.resume))

        self.parameters_to_train = self.model.collect_params()

        if self.opt.hybridize:
            self.model.hybridize()

        ######################### dataloader #########################
        datasets_dict = {"kitti": KITTIRAWDataset,
                         "kitti_odom": KITTIOdomDataset}
        self.dataset = datasets_dict[self.opt.dataset]

        fpath = os.path.join(os.path.expanduser("~"), ".mxnet/datasets/kitti",
                             "splits", self.opt.split, "{}_files.txt")
        train_filenames = readlines(fpath.format("train"))
        val_filenames = readlines(fpath.format("val"))
        img_ext = '.png' if self.opt.png else '.jpg'

        num_train_samples = len(train_filenames)
        self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs

        train_dataset = self.dataset(
            self.opt.data_path, train_filenames, self.opt.height, self.opt.width,
            self.opt.frame_ids, num_scales=4, is_train=True, img_ext=img_ext)
        self.train_loader = gluon.data.DataLoader(
            train_dataset, batch_size=self.opt.batch_size, shuffle=True,
            batchify_fn=dict_batchify_fn, num_workers=self.opt.num_workers,
            pin_memory=True, last_batch='discard')

        val_dataset = self.dataset(
            self.opt.data_path, val_filenames, self.opt.height, self.opt.width,
            self.opt.frame_ids, num_scales=4, is_train=False, img_ext=img_ext)
        self.val_loader = gluon.data.DataLoader(
            val_dataset, batch_size=self.opt.batch_size, shuffle=False,
            batchify_fn=dict_batchify_fn, num_workers=self.opt.num_workers,
            pin_memory=True, last_batch='discard')

        ################### optimization setting ###################
        self.lr_scheduler = LRSequential([
            LRScheduler('step', base_lr=self.opt.learning_rate,
                        nepochs=self.opt.num_epochs - self.opt.warmup_epochs,
                        iters_per_epoch=len(train_dataset),
                        step_epoch=[self.opt.scheduler_step_size - self.opt.warmup_epochs])
        ])
        optimizer_params = {'lr_scheduler': self.lr_scheduler,
                            'learning_rate': self.opt.learning_rate}

        self.optimizer = gluon.Trainer(self.parameters_to_train, 'adam', optimizer_params)

        print("Training model named:\n  ", self.opt.model_zoo)
        print("Models are saved to:\n  ", self.opt.log_dir)
        print("Training is using:\n  ", "CPU" if self.opt.ctx[0] is mx.cpu() else "GPU")

        ################### loss function ###################
        if not self.opt.no_ssim:
            self.ssim = SSIM()

        self.backproject_depth = {}
        self.project_3d = {}
        for scale in self.opt.scales:
            h = self.opt.height // (2 ** scale)
            w = self.opt.width // (2 ** scale)

            self.backproject_depth[scale] = BackprojectDepth(
                self.opt.batch_size, h, w, ctx=self.opt.ctx[0])
            self.project_3d[scale] = Project3D(self.opt.batch_size, h, w)

        ################### metrics ###################
        self.depth_metric_names = [
            "de/abs_rel", "de/sq_rel", "de/rms", "de/log_rms", "da/a1", "da/a2", "da/a3"]

        print("Using split:\n  ", self.opt.split)
        print("There are {:d} training items and {:d} validation items\n".format(
            len(train_dataset), len(val_dataset)))

        self.save_opts()

        # for save best model
        self.best_delta1 = 0
        self.best_model = self.model
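The two asserts on height and width at the top of this trainer exist because the encoder downsamples the input by a factor of 32 overall, so every pyramid scale must stay an integer size. A tiny illustrative check with an assumed KITTI-style crop (values below are not taken from the script):

height, width = 192, 640        # illustrative crop size, must be divisible by 32
assert height % 32 == 0 and width % 32 == 0

for scale in range(5):          # per-scale feature map sizes, scales 0..4
    print('scale %d -> %d x %d' % (scale, height // (2 ** scale), width // (2 ** scale)))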
Esempio n. 20
0
def train(net, train_data, val_data, eval_metric, ctx, args):
    net.collect_params().reset_ctx(ctx)
    if args.no_wd:
        for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
            v.wd_mult = 0.0

    if args.label_smooth:
        net._target_generator._label_smooth = True

    if args.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(args.lr_decay_period, args.epochs, args.lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')]

    lr_decay_epoch = [e - args.warmup_epochs for e in lr_decay_epoch]
    num_batches = args.num_samples // args.batch_size

    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=args.lr,
                    nepochs=args.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(args.mode,
                    base_lr=args.lr,
                    nepochs=args.epochs - args.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=args.lr_decay,
                    power=2)
    ])

    trainer = gluon.Trainer(net.collect_params(),
                            'sgd', {
                                'wd': args.wd,
                                'momentum': args.momentum,
                                'lr_scheduler': lr_scheduler
                            },
                            kvstore='local')

    obj_metrics = mx.metric.Loss("ObjLoss")
    center_metrics = mx.metric.Loss("BoxCenterLoss")
    scale_metrics = mx.metric.Loss("BoxScaleLoss")
    cls_metrics = mx.metric.Loss("ClassLoss")

    # logger set_up
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    logger.info("Start training from [Epoch {}]".format(args.start_epoch))
    best_map = [0]

    for epoch in range(args.start_epoch, args.epochs):
        if args.mixup:
            try:
                train_data._dataset.set_mixup(np.random.beta, 1.5, 1.5)
            except AttributeError:
                train_data._dataset._data.set_mixup(np.random.beta, 1.5, 1.5)
            if epoch >= args.epochs - args.no_mixup_epochs:
                try:
                    train_data._dataset.set_mixup(None)
                except AttributeError:
                    train_data._dataset._data.set_mixup(None)

        tic = time.time()
        btic = time.time()
        mx.nd.waitall()
        net.hybridize()

        for i, batch in enumerate(train_data):
            batch_size = batch[0].shape[0]
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            fixed_targets = [
                gluon.utils.split_and_load(batch[it],
                                           ctx_list=ctx,
                                           batch_axis=0) for it in range(1, 6)
            ]
            # objectness, center_targets, scale_targets, weights, class_targets
            # what do objectness, center_targets and scale_targets represent?
            gt_boxes = gluon.utils.split_and_load(batch[6],
                                                  ctx_list=ctx,
                                                  batch_axis=0)

            with autograd.record():
                sum_losses = []
                obj_losses, center_losses, scale_losses, cls_losses = [], [], [], []
                for ix, x in enumerate(data):
                    obj_loss, center_loss, scale_loss, cls_loss = net(
                        x, gt_boxes[ix], *[ft[ix] for ft in fixed_targets])
                    sum_losses.append(obj_loss + center_loss + scale_loss + cls_loss)
                    obj_losses.append(obj_loss)
                    center_losses.append(center_loss)
                    scale_losses.append(scale_loss)
                    cls_losses.append(cls_loss)
                autograd.backward(sum_losses)

            trainer.step(batch_size)
            obj_metrics.update(0, obj_losses)
            center_metrics.update(0, center_losses)
            scale_metrics.update(0, scale_losses)
            cls_metrics.update(0, cls_losses)

            if args.log_interval and (i + 1) % args.log_interval == 0:
                name1, loss1 = obj_metrics.get()
                name2, loss2 = center_metrics.get()
                name3, loss3 = scale_metrics.get()
                name4, loss4 = cls_metrics.get()
                logger.info(
                    '[Epoch {}][Batch {}], LR: {:.2E}, Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}, '
                    '{}={:.3f}, {}={:.3f}'.format(
                        epoch, i, trainer.learning_rate,
                        batch_size / (time.time() - btic), name1, loss1, name2,
                        loss2, name3, loss3, name4, loss4))
            btic = time.time()

        name1, loss1 = obj_metrics.get()
        name2, loss2 = center_metrics.get()
        name3, loss3 = scale_metrics.get()
        name4, loss4 = cls_metrics.get()
        logger.info(
            '[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'
            .format(epoch, (time.time() - tic), name1, loss1, name2, loss2,
                    name3, loss3, name4, loss4))

        if ((epoch + 1) % args.val_interval
                == 0) or (args.save_interval
                          and epoch % args.save_interval == 0):
            map_name, mean_map = validate(net, val_data, ctx, eval_metric)
            val_msg = "\n".join('{}={}'.format(k, v)
                                for k, v in zip(map_name, mean_map))
            logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
            current_map = float(mean_map[-1])
        else:
            current_map = 0.
        save_parameters(net, best_map, current_map, epoch, args.save_interval,
                        args.save_prefix)
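One detail worth spelling out is the `e - args.warmup_epochs` shift above: the decay scheduler only starts counting epochs once warmup is over, so decay epochs given on the command line as absolute epochs have to be moved back by the warmup length. A worked example with assumed values:

lr_decay_epoch_arg = '160,180'   # assumed --lr-decay-epoch value
warmup_epochs = 2

lr_decay_epoch = [int(i) for i in lr_decay_epoch_arg.split(',')]
lr_decay_epoch = [e - warmup_epochs for e in lr_decay_epoch]
print(lr_decay_epoch)            # [158, 178] -> absolute epochs 160 and 180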
Esempio n. 21
0
def main():
    opt = parse_args()

    filehandler = logging.FileHandler(opt.logging_file)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = num_training_samples // batch_size

    lr_scheduler = LRSequential([
        LRScheduler('linear', base_lr=0, target_lr=opt.lr,
                    nepochs=opt.warmup_epochs, iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode, base_lr=opt.lr, target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay, power=2)
    ])

    model_name = 'ResNet50_v1b_' + opt.act_type

    kwargs = {'ctx': context, 'pretrained': opt.use_pretrained, 'classes': classes}
    if opt.use_gn:
        kwargs['norm_layer'] = gcv.nn.GroupNorm
    if model_name.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm
    elif model_name.startswith('resnext'):
        kwargs['use_se'] = opt.use_se

    if opt.last_gamma:
        kwargs['last_gamma'] = True

    optimizer = 'nag'
    optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum, 'lr_scheduler': lr_scheduler}
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    ############### Our Own ################

    layers = [3, 4, 6, 3]
    act_type = opt.act_type  # 'ChaATAC', 'SeqATAC'
    r = opt.r
    act_layers = opt.act_layers
    net = ResNet50_v1bATAC(act_type=act_type, r=r, act_layers=act_layers, layers=layers, **kwargs)

    print('act_type: ', act_type)
    print('r: ', r)
    print('act_layers: ', act_layers)
    print('layers: ', layers)

    net.cast(opt.dtype)
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx = context)

    # teacher model for distillation training
    if opt.teacher is not None and opt.hard_weight < 1.0:
        teacher_name = opt.teacher
        teacher = get_model(teacher_name, pretrained=True, classes=classes, ctx=context)
        teacher.cast(opt.dtype)
        distillation = True
    else:
        distillation = False

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, num_workers):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec         = rec_train,
            path_imgidx         = rec_train_idx,
            preprocess_threads  = num_workers,
            shuffle             = True,
            batch_size          = batch_size,

            data_shape          = (3, input_size, input_size),
            mean_r              = mean_rgb[0],
            mean_g              = mean_rgb[1],
            mean_b              = mean_rgb[2],
            std_r               = std_rgb[0],
            std_g               = std_rgb[1],
            std_b               = std_rgb[2],
            rand_mirror         = True,
            random_resized_crop = True,
            max_aspect_ratio    = 4. / 3.,
            min_aspect_ratio    = 3. / 4.,
            max_random_area     = 1,
            min_random_area     = 0.08,
            brightness          = jitter_param,
            saturation          = jitter_param,
            contrast            = jitter_param,
            pca_noise           = lighting_param,
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec         = rec_val,
            path_imgidx         = rec_val_idx,
            preprocess_threads  = num_workers,
            shuffle             = False,
            batch_size          = batch_size,

            resize              = resize,
            data_shape          = (3, input_size, input_size),
            mean_r              = mean_rgb[0],
            mean_g              = mean_rgb[1],
            mean_b              = mean_rgb[2],
            std_r               = std_rgb[0],
            std_g               = std_rgb[1],
            std_b               = std_rgb[2],
        )
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
                                        saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(),
            normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize
        ])

        train_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
        val_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        return train_data, val_data, batch_fn

    if opt.use_rec:
        train_data, val_data, batch_fn = get_data_rec(opt.rec_train, opt.rec_train_idx,
                                                    opt.rec_val, opt.rec_val_idx,
                                                    batch_size, num_workers)
    else:
        train_data, val_data, batch_fn = get_data_loader(opt.data_dir, batch_size, num_workers)

    if opt.mixup:
        train_metric = mx.metric.RMSE()
    else:
        train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def mixup_transform(label, classes, lam=1, eta=0.0):
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            y2 = l[::-1].one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            res.append(lam*y1 + (1-lam)*y2)
        return res

    def smooth(label, classes, eta=0.1):
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            smoothed.append(res)
        return smoothed

    def test(ctx, val_data):
        if opt.use_rec:
            val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return (1-top1, 1-top5)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.summary:
            # net.summary(mx.nd.zeros((1, 3, opt.input_size, opt.input_size), ctx=ctx[0]))
            summary(net, mx.nd.zeros((1, 3, opt.input_size, opt.input_size), ctx=ctx[0]))
            sys.exit()

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True
        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(temperature=opt.temperature,
                                                                 hard_weight=opt.hard_weight,
                                                                 sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)

        best_val_score = 1

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()
            btic = time.time()

            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)

                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                        lam = 1
                    data = [lam*X + (1-lam)*X[::-1] for X in data]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label, classes, lam, eta)

                elif opt.label_smoothing:
                    hard_label = label
                    label = smooth(label, classes)

                if distillation:
                    teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature) \
                                    for X in data]

                with ag.record():
                    outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                    if distillation:
                        loss = [L(yhat.astype('float32', copy=False),
                                  y.astype('float32', copy=False),
                                  p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
                    else:
                        loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)

                if opt.mixup:
                    output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                                    for out in outputs]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, outputs)
                    else:
                        train_metric.update(label, outputs)

                if opt.log_interval and not (i+1)%opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'%(
                                epoch, i, batch_size*opt.log_interval/(time.time()-btic),
                                train_metric_name, train_metric_score, trainer.learning_rate))
                    btic = time.time()

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i /(time.time() - tic))

            err_top1_val, err_top5_val = test(ctx, val_data)

            logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, time.time()-tic))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f'%(epoch, err_top1_val, err_top5_val))

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
                trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
                net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
                trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))

        if save_frequency and save_dir:
            net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))
            trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1))


    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)
        if distillation:
            teacher.hybridize(static_alloc=True, static_shape=True)
    train(context)
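The `smooth` helper above is plain label smoothing via one-hot encoding. A toy sketch of the values it produces (the class count and eta below are illustrative, not taken from the script):

import mxnet as mx

labels = mx.nd.array([2, 0])
classes, eta = 5, 0.1
smoothed = labels.one_hot(classes,
                          on_value=1 - eta + eta / classes,
                          off_value=eta / classes)
print(smoothed)
# row 0 -> [0.02, 0.02, 0.92, 0.02, 0.02]
# row 1 -> [0.92, 0.02, 0.02, 0.02, 0.02]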
def train(net, train_data, val_data, eval_metric, ctx, args):

    import gluoncv as gcv

    gcv.utils.check_version("0.6.0")
    from gluoncv import data as gdata
    from gluoncv import utils as gutils
    from gluoncv.data.batchify import Pad, Stack, Tuple
    from gluoncv.data.dataloader import RandomTransformDataLoader
    from gluoncv.data.transforms.presets.yolo import (
        YOLO3DefaultTrainTransform,
        YOLO3DefaultValTransform,
    )
    from gluoncv.model_zoo import get_model
    from gluoncv.utils import LRScheduler, LRSequential
    from gluoncv.utils.metrics.coco_detection import COCODetectionMetric
    from gluoncv.utils.metrics.voc_detection import VOC07MApMetric

    """Training pipeline"""
    net.collect_params().reset_ctx(ctx)
    if args.no_wd:
        for k, v in net.collect_params(".*beta|.*gamma|.*bias").items():
            v.wd_mult = 0.0

    if args.label_smooth:
        net._target_generator._label_smooth = True

    if args.lr_decay_period > 0:
        lr_decay_epoch = list(range(args.lr_decay_period, args.epochs, args.lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(",")]
    lr_decay_epoch = [e - args.warmup_epochs for e in lr_decay_epoch]
    num_batches = args.num_samples // args.batch_size
    lr_scheduler = LRSequential(
        [
            LRScheduler(
                "linear",
                base_lr=0,
                target_lr=args.lr,
                nepochs=args.warmup_epochs,
                iters_per_epoch=num_batches,
            ),
            LRScheduler(
                args.lr_mode,
                base_lr=args.lr,
                nepochs=args.epochs - args.warmup_epochs,
                iters_per_epoch=num_batches,
                step_epoch=lr_decay_epoch,
                step_factor=args.lr_decay,
                power=2,
            ),
        ]
    )

    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_params(),
            "sgd",
            {"wd": args.wd, "momentum": args.momentum, "lr_scheduler": lr_scheduler},
        )
    else:
        trainer = gluon.Trainer(
            net.collect_params(),
            "sgd",
            {"wd": args.wd, "momentum": args.momentum, "lr_scheduler": lr_scheduler},
            kvstore="local",
            update_on_kvstore=(False if args.amp else None),
        )

    if args.amp:
        amp.init_trainer(trainer)

    # targets
    sigmoid_ce = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    l1_loss = gluon.loss.L1Loss()

    # metrics
    obj_metrics = mx.metric.Loss("ObjLoss")
    center_metrics = mx.metric.Loss("BoxCenterLoss")
    scale_metrics = mx.metric.Loss("BoxScaleLoss")
    cls_metrics = mx.metric.Loss("ClassLoss")

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + "_train.log"
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    logger.info("Start training from [Epoch {}]".format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.num_epochs):
        if args.mixup:
            # TODO(zhreshold): more elegant way to control mixup during runtime
            try:
                train_data._dataset.set_mixup(np.random.beta, 1.5, 1.5)
            except AttributeError:
                train_data._dataset._data.set_mixup(np.random.beta, 1.5, 1.5)
            if epoch >= args.num_epochs - args.no_mixup_epochs:
                try:
                    train_data._dataset.set_mixup(None)
                except AttributeError:
                    train_data._dataset._data.set_mixup(None)

        tic = time.time()
        btic = time.time()
        mx.nd.waitall()
        net.hybridize()
        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            # objectness, center_targets, scale_targets, weights, class_targets
            fixed_targets = [
                gluon.utils.split_and_load(batch[it], ctx_list=ctx, batch_axis=0)
                for it in range(1, 6)
            ]
            gt_boxes = gluon.utils.split_and_load(batch[6], ctx_list=ctx, batch_axis=0)
            sum_losses = []
            obj_losses = []
            center_losses = []
            scale_losses = []
            cls_losses = []
            with autograd.record():
                for ix, x in enumerate(data):
                    obj_loss, center_loss, scale_loss, cls_loss = net(
                        x, gt_boxes[ix], *[ft[ix] for ft in fixed_targets]
                    )
                    sum_losses.append(obj_loss + center_loss + scale_loss + cls_loss)
                    obj_losses.append(obj_loss)
                    center_losses.append(center_loss)
                    scale_losses.append(scale_loss)
                    cls_losses.append(cls_loss)
                if args.amp:
                    with amp.scale_loss(sum_losses, trainer) as scaled_loss:
                        autograd.backward(scaled_loss)
                else:
                    autograd.backward(sum_losses)
            trainer.step(batch_size)
            if not args.horovod or hvd.rank() == 0:
                obj_metrics.update(0, obj_losses)
                center_metrics.update(0, center_losses)
                scale_metrics.update(0, scale_losses)
                cls_metrics.update(0, cls_losses)
                if args.log_interval and not (i + 1) % args.log_interval:
                    name1, loss1 = obj_metrics.get()
                    name2, loss2 = center_metrics.get()
                    name3, loss3 = scale_metrics.get()
                    name4, loss4 = cls_metrics.get()
                    logger.info(
                        "[Epoch {}][Batch {}], LR: {:.2E}, Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}".format(
                            epoch,
                            i,
                            trainer.learning_rate,
                            args.batch_size / (time.time() - btic),
                            name1,
                            loss1,
                            name2,
                            loss2,
                            name3,
                            loss3,
                            name4,
                            loss4,
                        )
                    )
                btic = time.time()

        if not args.horovod or hvd.rank() == 0:
            name1, loss1 = obj_metrics.get()
            name2, loss2 = center_metrics.get()
            name3, loss3 = scale_metrics.get()
            name4, loss4 = cls_metrics.get()
            logger.info(
                "[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}".format(
                    epoch,
                    (time.time() - tic),
                    name1,
                    loss1,
                    name2,
                    loss2,
                    name3,
                    loss3,
                    name4,
                    loss4,
                )
            )
            if not (epoch + 1) % args.val_interval:
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
                val_msg = "\n".join(["{}={}".format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info("[Epoch {}] Validation: \n{}".format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.0
            save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix)

    # save model
    net.set_nms(nms_thresh=0.45, nms_topk=400, post_nms=100)
    net(mx.nd.ones((1, 3, args.data_shape, args.data_shape), ctx=ctx[0]))
    net.export("%s/model" % os.environ["SM_MODEL_DIR"])
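The per-device lists (`data`, `fixed_targets`, `gt_boxes`) in the loop above come from `gluon.utils.split_and_load`, which shards a batch along axis 0 onto each context. A minimal standalone sketch, using CPU contexts as stand-ins for GPUs:

import mxnet as mx
from mxnet import gluon

ctx = [mx.cpu(0), mx.cpu(1)]                 # illustrative stand-ins for two GPUs
batch = mx.nd.arange(8 * 3).reshape((8, 3))  # toy batch of 8 samples

shards = gluon.utils.split_and_load(batch, ctx_list=ctx, batch_axis=0)
print([s.shape for s in shards])             # [(4, 3), (4, 3)], one shard per context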
Esempio n. 23
0
    def __init__(self, args, logger):
        self.args = args
        self.logger = logger

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split=args.train_split,
                                            mode='train',
                                            **data_kwargs)
        valset = get_segmentation_dataset(args.dataset,
                                          split='val',
                                          mode='val',
                                          **data_kwargs)
        self.train_data = gluon.data.DataLoader(trainset,
                                                args.batch_size,
                                                shuffle=True,
                                                last_batch='rollover',
                                                num_workers=args.workers)
        self.eval_data = gluon.data.DataLoader(valset,
                                               args.test_batch_size,
                                               last_batch='rollover',
                                               num_workers=args.workers)

        # create network
        if args.model_zoo is not None:
            model = get_model(args.model_zoo,
                              norm_layer=args.norm_layer,
                              norm_kwargs=args.norm_kwargs,
                              aux=args.aux,
                              base_size=args.base_size,
                              crop_size=args.crop_size,
                              pretrained=args.pretrained)
        else:
            model = get_segmentation_model(model=args.model,
                                           dataset=args.dataset,
                                           backbone=args.backbone,
                                           norm_layer=args.norm_layer,
                                           norm_kwargs=args.norm_kwargs,
                                           aux=args.aux,
                                           base_size=args.base_size,
                                           crop_size=args.crop_size)
        # for resnest use only
        from gluoncv.nn.dropblock import set_drop_prob
        from functools import partial
        apply_drop_prob = partial(set_drop_prob, 0.0)
        model.apply(apply_drop_prob)

        model.cast(args.dtype)
        logger.info(model)

        self.net = DataParallelModel(model, args.ctx, args.syncbn)
        self.evaluator = DataParallelModel(SegEvalModel(model), args.ctx)
        # resume checkpoint if needed
        if args.resume is not None:
            if os.path.isfile(args.resume):
                model.load_parameters(args.resume, ctx=args.ctx)
            else:
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))

        # create criterion
        if 'icnet' in args.model:
            criterion = ICNetLoss(crop_size=args.crop_size)
        else:
            criterion = MixSoftmaxCrossEntropyLoss(args.aux,
                                                   aux_weight=args.aux_weight)
        self.criterion = DataParallelCriterion(criterion, args.ctx,
                                               args.syncbn)

        # optimizer and lr scheduling
        self.lr_scheduler = LRSequential([
            LRScheduler('linear',
                        base_lr=0,
                        target_lr=args.lr,
                        nepochs=args.warmup_epochs,
                        iters_per_epoch=len(self.train_data)),
            LRScheduler(mode='poly',
                        base_lr=args.lr,
                        nepochs=args.epochs - args.warmup_epochs,
                        iters_per_epoch=len(self.train_data),
                        power=0.9)
        ])
        kv = mx.kv.create(args.kvstore)

        if args.optimizer == 'sgd':
            optimizer_params = {
                'lr_scheduler': self.lr_scheduler,
                'wd': args.weight_decay,
                'momentum': args.momentum,
                'learning_rate': args.lr
            }
        elif args.optimizer == 'adam':
            optimizer_params = {
                'lr_scheduler': self.lr_scheduler,
                'wd': args.weight_decay,
                'learning_rate': args.lr
            }
        else:
            raise ValueError('Unsupported optimizer {} used'.format(
                args.optimizer))

        if args.dtype == 'float16':
            optimizer_params['multi_precision'] = True

        if args.no_wd:
            for k, v in self.net.module.collect_params(
                    '.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        self.optimizer = gluon.Trainer(self.net.module.collect_params(),
                                       args.optimizer,
                                       optimizer_params,
                                       kvstore=kv)
        # evaluation metrics
        self.metric = gluoncv.utils.metrics.SegmentationMetric(
            trainset.num_class)
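The 'poly' mode used after warmup above is the usual polynomial decay for segmentation. A rough sketch of the formula it is generally understood to follow (gluoncv's exact internals may differ slightly in offsets):

base_lr, power = 0.01, 0.9           # illustrative values
epochs, iters_per_epoch = 50, 500
max_iter = epochs * iters_per_epoch

def poly_lr(i):
    # lr(i) = base_lr * (1 - i / max_iter) ** power
    return base_lr * (1 - i / max_iter) ** power

for i in (0, max_iter // 2, max_iter - 1):
    print(i, poly_lr(i))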
Esempio n. 24
0
def train(net, train_samples, train_data, val_data, eval_metric, ctx, args):
    """Training pipline"""
    net.collect_params().reset_ctx(ctx)
    # training_patterns = '.*vgg'
    # net.collect_params(training_patterns).setattr('lr_mult', 0.1)

    num_batches = train_samples // args.batch_size

    if args.start_epoch == 0:
        lr_scheduler = LRSequential([
            LRScheduler('linear', base_lr=0, target_lr=args.lr,
                        nepochs=args.warmup_epochs, iters_per_epoch=num_batches),
            LRScheduler('cosine', base_lr=args.lr, target_lr=0,
                        nepochs=args.epochs - args.warmup_epochs,
                        iters_per_epoch=num_batches)])
    else:
        offset = args.start_epoch
        lr_scheduler = LRSequential([
            LRScheduler('cosine', base_lr=args.lr, target_lr=0,
                        nepochs=args.epochs - offset,
                        iters_per_epoch=num_batches)
        ])

    opt_params = {'learning_rate': args.lr, 'momentum': args.momentum, 'wd': args.wd,
                  'lr_scheduler': lr_scheduler}

    trainer = gluon.Trainer(
        net.collect_params(),
        'nag',
        opt_params)
    # lr decay policy
    # lr_decay = float(args.lr_decay)
    # lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    # lr_warmup = float(args.lr_warmup)

    face_mbox_loss = gcv.loss.SSDMultiBoxLoss(rho=1.0, lambd=0.5)
    head_mbox_loss = gcv.loss.SSDMultiBoxLoss(rho=1.0, lambd=0.5)
    body_mbox_loss = gcv.loss.SSDMultiBoxLoss(rho=1.0, lambd=0.5)
    face_ce_metric = mx.metric.Loss('FaceCrossEntropy')
    face_smoothl1_metric = mx.metric.Loss('FaceSmoothL1')
    head_ce_metric = mx.metric.Loss('HeadCrossEntropy')
    head_smoothl1_metric = mx.metric.Loss('HeadSmoothL1')
    body_ce_metric = mx.metric.Loss('BodyCrossEntropy')
    body_smoothl1_metric = mx.metric.Loss('BodySmoothL1')
    metrics = [face_ce_metric, face_smoothl1_metric,
               head_ce_metric, head_smoothl1_metric,
               body_ce_metric, body_smoothl1_metric]

    # set up logger
    logger = logging.getLogger()  # formatter = logging.Formatter('[%(asctime)s] %(message)s')
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if os.path.exists(log_file_path) and args.start_epoch == 0:
        os.remove(log_file_path)
    fh = logging.FileHandler(log_file_path)

    logger.addHandler(fh)
    logger.info(args)
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]

    total_batch = 0
    base_lr = trainer.learning_rate
    # names, maps = validate(net, val_data, ctx, eval_metric)

    for epoch in range(args.start_epoch, args.epochs + 1):
        # while lr_steps and epoch >= lr_steps[0]:
        #     new_lr = trainer.learning_rate * lr_decay
        #     lr_steps.pop(0)
        #     trainer.set_learning_rate(new_lr)
        #     logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
        # every epoch learning rate decay
        # if args.start_epoch != 0 or total_batch >= lr_warmup:
        #     new_lr = trainer.learning_rate * lr_decay
        #     # lr_steps.pop(0)
        #     trainer.set_learning_rate(new_lr)
        #     logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))

        for m in metrics:
            m.reset()
        tic = time.time()
        btic = time.time()

        for i, batch in tqdm.tqdm(enumerate(train_data)):

            # if args.start_epoch == 0 and total_batch <= lr_warmup:
            #     # adjust based on real percentage
            #     new_lr = base_lr * get_lr_at_iter((total_batch + 1) / lr_warmup)
            #     if new_lr != trainer.learning_rate:
            #         if i % args.log_interval == 0:
            #             logger.info(
            #                 '[Epoch {} Iteration {}] Set learning rate to {}'.format(epoch, total_batch, new_lr))
            #         trainer.set_learning_rate(new_lr)
            total_batch += 1
            batch_size = batch[0].shape[0]
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            face_cls_target = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            head_cls_target = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)
            body_cls_target = gluon.utils.split_and_load(batch[3], ctx_list=ctx, batch_axis=0)

            face_box_target = gluon.utils.split_and_load(batch[4], ctx_list=ctx, batch_axis=0)
            head_box_target = gluon.utils.split_and_load(batch[5], ctx_list=ctx, batch_axis=0)
            body_box_target = gluon.utils.split_and_load(batch[6], ctx_list=ctx, batch_axis=0)

            with autograd.record():
                face_cls_preds = []
                face_box_preds = []
                head_cls_preds = []
                head_box_preds = []
                body_cls_preds = []
                body_box_preds = []

                for x in data:
                    face_cls_predict, face_box_predict, _, \
                    head_cls_predict, head_box_predict, _, \
                    body_cls_predict, body_box_predict, _ = net(x)
                    face_cls_preds.append(face_cls_predict)
                    face_box_preds.append(face_box_predict)
                    head_cls_preds.append(head_cls_predict)
                    head_box_preds.append(head_box_predict)
                    body_cls_preds.append(body_cls_predict)
                    body_box_preds.append(body_box_predict)
                # calculate the loss
                face_sum_loss, face_cls_loss, face_box_loss = face_mbox_loss(
                    face_cls_preds, face_box_preds, face_cls_target, face_box_target)

                head_sum_loss, head_cls_loss, head_box_loss = head_mbox_loss(
                    head_cls_preds, head_box_preds, head_cls_target, head_box_target)

                body_sum_loss, body_cls_loss, body_box_loss = body_mbox_loss(
                    body_cls_preds, body_box_preds, body_cls_target, body_box_target)

                # combine the three losses with weights 1 : 0.5 : 0.1 before backward
                # totalloss = [face_sum_loss,head_sum_loss,body_sum_loss]
                totalloss = [f + 0.5 * h + 0.1 * b for f, h, b in zip(face_sum_loss, head_sum_loss, body_sum_loss)]
                # totalloss = face_sum_loss+head_sum_loss
                # autograd.backward(totalloss)

                autograd.backward(totalloss)
            # since we have already normalized the loss, we don't want to normalize
            # by batch-size anymore

            trainer.step(1)
            # logger training info
            face_ce_metric.update(0, [l * batch_size for l in face_cls_loss])
            face_smoothl1_metric.update(0, [l * batch_size for l in face_box_loss])
            head_ce_metric.update(0, [l * batch_size for l in head_cls_loss])
            head_smoothl1_metric.update(0, [l * batch_size for l in head_box_loss])
            body_ce_metric.update(0, [l * batch_size for l in body_cls_loss])
            body_smoothl1_metric.update(0, [l * batch_size for l in body_box_loss])

            # save_params(net, logger, [0], 0, None, total_batch, args.save_interval, args.save_prefix)

            if args.log_interval and not (i + 1) % args.log_interval:
                info = ','.join(['{}={:.4f}'.format(*metric.get()) for metric in metrics])
                # print(info)
                logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec , lr: {:.5f} ,{:s}'.format(
                    epoch, i, batch_size / (time.time() - btic), trainer.learning_rate, info))
            btic = time.time()
        info = ','.join(['{}={:.4f}'.format(*metric.get()) for metric in metrics])
        logger.info('[Epoch {}] lr: {:.5f} Training cost: {:s}'.format(epoch, trainer.learning_rate, info))

        if args.val_interval and not (epoch + 1) % args.val_interval:
            # consider reduce the frequency of validation to save time
            vtic = time.time()
            names, maps = validate(net, val_data, ctx, eval_metric)
            val_msg = '\n'.join(['{:7}MAP = {}'.format(k, v) for k, v in zip(names, maps)])
            logger.info('[Epoch {}] Validation: {:.3f}\n{}'.format(epoch, (time.time() - vtic), val_msg))
            current_map = sum(maps) / len(maps)
            save_params(net, logger, best_map, current_map, maps, epoch, args.save_interval, args.save_prefix)
        else:
            current_map = 0.
            maps = None
        save_params(net, logger, best_map, current_map, maps, epoch, args.save_interval, args.save_prefix)
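The weighted combination of the three detection branches above (face : head : body at 1 : 0.5 : 0.1) can be illustrated with placeholder per-device loss values:

import mxnet as mx

face_sum_loss = [mx.nd.array([2.0])]   # illustrative per-device losses
head_sum_loss = [mx.nd.array([1.0])]
body_sum_loss = [mx.nd.array([4.0])]

totalloss = [f + 0.5 * h + 0.1 * b
             for f, h, b in zip(face_sum_loss, head_sum_loss, body_sum_loss)]
print(totalloss[0].asscalar())         # 2.0 + 0.5 + 0.4 = 2.9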
Esempio n. 25
0
# Learning rate decay factor
lr_decay = opt.lr_decay
# Epochs where learning rate decays
lr_decay_epoch = opt.lr_decay_epoch

# Stochastic gradient descent
optimizer = 'sgd'
# Set parameters
optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum}

num_batches = len(train_data)
lr_scheduler = LRSequential([
    LRScheduler('linear', base_lr=opt.warmup_lr, target_lr=opt.lr,
                nepochs=opt.warmup_epochs, iters_per_epoch=num_batches),
    LRScheduler(opt.lr_mode, base_lr=opt.lr, target_lr=0,
                nepochs=opt.epochs - opt.warmup_epochs,
                iters_per_epoch=num_batches,
                step_epoch=lr_decay_epoch,
                step_factor=lr_decay, power=2)
])
optimizer_params['lr_scheduler'] = lr_scheduler

if opt.partial_bn:
    train_patterns = None
    if 'inceptionv3' in opt.model:
        train_patterns = '.*weight|.*bias|inception30_batchnorm0_gamma|inception30_batchnorm0_beta|inception30_batchnorm0_running_mean|inception30_batchnorm0_running_var'
    else:
        logger.info('Current model does not support partial batch normalization.')

    trainer = gluon.Trainer(net.collect_params(train_patterns), optimizer, optimizer_params, update_on_kvstore=False)
elif not opt.partial_bn and opt.use_train_patterns:
Esempio n. 26
0
def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline"""
    net.collect_params().reset_ctx(ctx)
    if args.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(args.lr_decay_period, args.epochs, args.lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - args.warmup_epochs for e in lr_decay_epoch]
    num_batches = args.num_samples // args.batch_size
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=args.lr,
                    nepochs=args.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(args.lr_mode,
                    base_lr=args.lr,
                    nepochs=args.epochs - args.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=args.lr_decay,
                    power=2),
    ])
    trainer = gluon.Trainer(net.collect_params(),
                            'sgd', {
                                'lr_scheduler': lr_scheduler,
                                'wd': args.wd,
                                'momentum': args.momentum
                            },
                            update_on_kvstore=(False if args.amp else None))

    if args.amp:
        amp.init_trainer(trainer)

    print("train_efficientdet.py-148 train classes=", classes, len(classes))
    cls_box_loss = EfficientDetLoss(len(classes) + 1, rho=0.1, lambd=50.0)
    ce_metric = mx.metric.Loss('FocalLoss')
    smoothl1_metric = mx.metric.Loss('SmoothL1')

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]

    for epoch in range(args.start_epoch + 1, args.epochs + 1):
        logger.info("[Epoch {}] Set learning rate to {}".format(
            epoch, trainer.learning_rate))
        ce_metric.reset()
        smoothl1_metric.reset()
        tic = time.time()
        btic = time.time()
        net.hybridize()
        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            cls_targets = gluon.utils.split_and_load(batch[1],
                                                     ctx_list=ctx,
                                                     batch_axis=0)
            box_targets = gluon.utils.split_and_load(batch[2],
                                                     ctx_list=ctx,
                                                     batch_axis=0)

            with autograd.record():
                cls_preds = []
                box_preds = []
                for x in data:
                    cls_pred, box_pred, _ = net(x)
                    cls_preds.append(cls_pred)
                    box_preds.append(box_pred)
                sum_loss, cls_loss, box_loss = cls_box_loss(
                    cls_preds, box_preds, cls_targets, box_targets)
                if args.amp:
                    with amp.scale_loss(sum_loss, trainer) as scaled_loss:
                        autograd.backward(scaled_loss)
                else:
                    autograd.backward(sum_loss)
            # since we have already normalized the loss, we don't want to normalize
            # by batch-size anymore
            trainer.step(1)

            local_batch_size = int(args.batch_size)
            ce_metric.update(0, [l * local_batch_size for l in cls_loss])
            smoothl1_metric.update(0, [l * local_batch_size for l in box_loss])
            if args.log_interval and not (i + 1) % args.log_interval:
                name1, loss1 = ce_metric.get()
                name2, loss2 = smoothl1_metric.get()
                logger.info(
                    '[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'
                    .format(epoch, i, args.batch_size / (time.time() - btic),
                            name1, loss1, name2, loss2))
            btic = time.time()

        name1, loss1 = ce_metric.get()
        name2, loss2 = smoothl1_metric.get()
        logger.info(
            '[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}'.format(
                epoch, (time.time() - tic), name1, loss1, name2, loss2))
        if (epoch % args.val_interval
                == 0) or (args.save_interval
                          and epoch % args.save_interval == 0):
            # consider reducing the frequency of validation to save time
            map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
            val_msg = '\n'.join(
                ['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
            logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
            current_map = float(mean_ap[-1])
        else:
            current_map = 0.
        save_params(net, best_map, current_map, epoch, args.save_interval,
                    args.save_prefix)
Example 27
0
def main():
    opt = parse_args()

    makedirs(opt.save_dir)

    filehandler = logging.FileHandler(
        os.path.join(opt.save_dir, opt.logging_file))
    streamhandler = logging.StreamHandler()
    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)
    logger.info(opt)

    sw = SummaryWriter(logdir=opt.save_dir, flush_secs=5, verbose=False)

    if opt.kvstore is not None:
        kv = mx.kvstore.create(opt.kvstore)
        logger.info(
            'Distributed training with %d workers and current rank is %d' %
            (kv.num_workers, kv.rank))
    if opt.use_amp:
        amp.init()

    batch_size = opt.batch_size
    classes = opt.num_classes

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    logger.info('Total batch size is set to %d on %d GPUs' %
                (batch_size, num_gpus))
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]

    optimizer = 'sgd'
    if opt.clip_grad > 0:
        optimizer_params = {
            'learning_rate': opt.lr,
            'wd': opt.wd,
            'momentum': opt.momentum,
            'clip_gradient': opt.clip_grad
        }
    else:
        optimizer_params = {
            'learning_rate': opt.lr,
            'wd': opt.wd,
            'momentum': opt.momentum
        }

    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    model_name = opt.model
    net = get_model(name=model_name,
                    nclass=classes,
                    pretrained=opt.use_pretrained,
                    use_tsn=opt.use_tsn,
                    num_segments=opt.num_segments,
                    partial_bn=opt.partial_bn)
    net.cast(opt.dtype)
    net.collect_params().reset_ctx(context)
    logger.info(net)

    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context)

    if opt.kvstore is not None:
        train_data, val_data, batch_fn = get_data_loader(
            opt, batch_size, num_workers, logger, kv)
    else:
        train_data, val_data, batch_fn = get_data_loader(
            opt, batch_size, num_workers, logger)

    num_batches = len(train_data)
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=opt.lr,
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode,
                    base_lr=opt.lr,
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])
    optimizer_params['lr_scheduler'] = lr_scheduler

    train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

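    # evaluate top-1/top-5 accuracy and mean cross-entropy loss on the validation set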
    def test(ctx, val_data, kvstore=None):
        acc_top1.reset()
        acc_top5.reset()
        L = gluon.loss.SoftmaxCrossEntropyLoss()
        num_test_iter = len(val_data)
        val_loss_epoch = 0
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = []
            for _, X in enumerate(data):
                X = X.reshape((-1, ) + X.shape[2:])
                pred = net(X.astype(opt.dtype, copy=False))
                outputs.append(pred)

            loss = [
                L(yhat, y.astype(opt.dtype, copy=False))
                for yhat, y in zip(outputs, label)
            ]

            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

            val_loss_epoch += sum([l.mean().asscalar()
                                   for l in loss]) / len(loss)

            if opt.log_interval and not (i + 1) % opt.log_interval:
                logger.info('Batch [%04d]/[%04d]: evaluated' %
                            (i, num_test_iter))

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        val_loss = val_loss_epoch / num_test_iter

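        # in distributed mode, aggregate metrics across workers via the kvstore and average by worker count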
        if kvstore is not None:
            top1_nd = nd.zeros(1)
            top5_nd = nd.zeros(1)
            val_loss_nd = nd.zeros(1)
            kvstore.push(111111, nd.array(np.array([top1])))
            kvstore.pull(111111, out=top1_nd)
            kvstore.push(555555, nd.array(np.array([top5])))
            kvstore.pull(555555, out=top5_nd)
            kvstore.push(999999, nd.array(np.array([val_loss])))
            kvstore.pull(999999, out=val_loss_nd)
            top1 = top1_nd.asnumpy() / kvstore.num_workers
            top5 = top5_nd.asnumpy() / kvstore.num_workers
            val_loss = val_loss_nd.asnumpy() / kvstore.num_workers

        return (top1, top5, val_loss)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        if opt.partial_bn:
            train_patterns = None
            if 'inceptionv3' in opt.model:
                train_patterns = '.*weight|.*bias|inception30_batchnorm0_gamma|inception30_batchnorm0_beta|inception30_batchnorm0_running_mean|inception30_batchnorm0_running_var'
            else:
                logger.info(
                    'Current model does not support partial batch normalization.'
                )

            if opt.kvstore is not None:
                trainer = gluon.Trainer(net.collect_params(train_patterns),
                                        optimizer,
                                        optimizer_params,
                                        kvstore=kv,
                                        update_on_kvstore=False)
            else:
                trainer = gluon.Trainer(net.collect_params(train_patterns),
                                        optimizer,
                                        optimizer_params,
                                        update_on_kvstore=False)
        else:
            if opt.kvstore is not None:
                trainer = gluon.Trainer(net.collect_params(),
                                        optimizer,
                                        optimizer_params,
                                        kvstore=kv,
                                        update_on_kvstore=False)
            else:
                trainer = gluon.Trainer(net.collect_params(),
                                        optimizer,
                                        optimizer_params,
                                        update_on_kvstore=False)

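        # gradient accumulation: switch grad_req to 'add' so gradients sum across iterations between updates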
        if opt.accumulate > 1:
            params = [
                p for p in net.collect_params().values()
                if p.grad_req != 'null'
            ]
            for p in params:
                p.grad_req = 'add'

        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.use_amp:
            amp.init_trainer(trainer)

        L = gluon.loss.SoftmaxCrossEntropyLoss()

        best_val_score = 0
        lr_decay_count = 0

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            train_metric.reset()
            btic = time.time()
            num_train_iter = len(train_data)
            train_loss_epoch = 0
            train_loss_iter = 0

            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)

                with ag.record():
                    outputs = []
                    for _, X in enumerate(data):
                        X = X.reshape((-1, ) + X.shape[2:])
                        pred = net(X.astype(opt.dtype, copy=False))
                        outputs.append(pred)
                    loss = [
                        L(yhat, y.astype(opt.dtype, copy=False))
                        for yhat, y in zip(outputs, label)
                    ]

                    if opt.use_amp:
                        with amp.scale_loss(loss, trainer) as scaled_loss:
                            ag.backward(scaled_loss)
                    else:
                        ag.backward(loss)

                if opt.accumulate > 1:
                    # with grad_req='add' the gradients keep summing across
                    # iterations; only step and clear them at accumulation boundaries
                    if (i + 1) % opt.accumulate == 0:
                        if opt.kvstore is not None:
                            trainer.step(batch_size * kv.num_workers * opt.accumulate)
                        else:
                            trainer.step(batch_size * opt.accumulate)
                        net.collect_params().zero_grad()
                else:
                    if opt.kvstore is not None:
                        trainer.step(batch_size * kv.num_workers)
                    else:
                        trainer.step(batch_size)

                train_metric.update(label, outputs)
                train_loss_iter = sum([l.mean().asscalar()
                                       for l in loss]) / len(loss)
                train_loss_epoch += train_loss_iter

                train_metric_name, train_metric_score = train_metric.get()
                sw.add_scalar(tag='train_acc_top1_iter',
                              value=train_metric_score * 100,
                              global_step=epoch * num_train_iter + i)
                sw.add_scalar(tag='train_loss_iter',
                              value=train_loss_iter,
                              global_step=epoch * num_train_iter + i)
                sw.add_scalar(tag='learning_rate_iter',
                              value=trainer.learning_rate,
                              global_step=epoch * num_train_iter + i)

                if opt.log_interval and not (i + 1) % opt.log_interval:
                    logger.info(
                        'Epoch[%03d] Batch [%04d]/[%04d]\tSpeed: %f samples/sec\t %s=%f\t loss=%f\t lr=%f'
                        % (epoch, i, num_train_iter,
                           batch_size * opt.log_interval /
                           (time.time() - btic), train_metric_name,
                           train_metric_score * 100, train_loss_epoch /
                           (i + 1), trainer.learning_rate))
                    btic = time.time()

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i / (time.time() - tic))
            mx.ndarray.waitall()

            if opt.kvstore is not None and epoch == opt.resume_epoch:
                kv.init(111111, nd.zeros(1))
                kv.init(555555, nd.zeros(1))
                kv.init(999999, nd.zeros(1))

            if opt.kvstore is not None:
                acc_top1_val, acc_top5_val, loss_val = test(ctx, val_data, kv)
            else:
                acc_top1_val, acc_top5_val, loss_val = test(ctx, val_data)

            logger.info('[Epoch %03d] training: %s=%f\t loss=%f' %
                        (epoch, train_metric_name, train_metric_score * 100,
                         train_loss_epoch / num_train_iter))
            logger.info('[Epoch %03d] speed: %d samples/sec\ttime cost: %f' %
                        (epoch, throughput, time.time() - tic))
            logger.info(
                '[Epoch %03d] validation: acc-top1=%f acc-top5=%f loss=%f' %
                (epoch, acc_top1_val * 100, acc_top5_val * 100, loss_val))

            sw.add_scalar(tag='train_loss_epoch',
                          value=train_loss_epoch / num_train_iter,
                          global_step=epoch)
            sw.add_scalar(tag='val_loss_epoch',
                          value=loss_val,
                          global_step=epoch)
            sw.add_scalar(tag='val_acc_top1_epoch',
                          value=acc_top1_val * 100,
                          global_step=epoch)

            if acc_top1_val > best_val_score:
                best_val_score = acc_top1_val
                net.save_parameters('%s/%.4f-%s-%s-%03d-best.params' %
                                    (opt.save_dir, best_val_score, opt.dataset,
                                     model_name, epoch))
                trainer.save_states('%s/%.4f-%s-%s-%03d-best.states' %
                                    (opt.save_dir, best_val_score, opt.dataset,
                                     model_name, epoch))
            else:
                if opt.save_frequency and opt.save_dir and (
                        epoch + 1) % opt.save_frequency == 0:
                    net.save_parameters(
                        '%s/%s-%s-%03d.params' %
                        (opt.save_dir, opt.dataset, model_name, epoch))
                    trainer.save_states(
                        '%s/%s-%s-%03d.states' %
                        (opt.save_dir, opt.dataset, model_name, epoch))

        # save the last model
        net.save_parameters(
            '%s/%s-%s-%03d.params' %
            (opt.save_dir, opt.dataset, model_name, opt.num_epochs - 1))
        trainer.save_states(
            '%s/%s-%s-%03d.states' %
            (opt.save_dir, opt.dataset, model_name, opt.num_epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)

    train(context)
    sw.close()
Example 28
0
def train(net, train_data, train_dataset, val_data, eval_metric, ctx, save_prefix, start_epoch, num_samples):
    """Training pipeline"""
    net.collect_params().reset_ctx(ctx)
    if FLAGS.no_wd:
        for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
            v.wd_mult = 0.0

    if FLAGS.label_smooth:
        net._target_generator._label_smooth = True

    if FLAGS.lr_decay_period > 0:
        lr_decay_epoch = list(range(FLAGS.lr_decay_period, FLAGS.epochs, FLAGS.lr_decay_period))
    else:
        lr_decay_epoch = FLAGS.lr_decay_epoch

    # for handling reloading from past epoch
    lr_decay_epoch_tmp = list()
    for e in lr_decay_epoch:
        if int(e) <= start_epoch:
            FLAGS.lr = FLAGS.lr * FLAGS.lr_decay
        else:
            lr_decay_epoch_tmp.append(int(e) - start_epoch - FLAGS.warmup_epochs)
    lr_decay_epoch = lr_decay_epoch_tmp

    num_batches = num_samples // FLAGS.batch_size
    lr_scheduler = LRSequential([
        LRScheduler('linear', base_lr=0, target_lr=FLAGS.lr,
                    nepochs=FLAGS.warmup_epochs, iters_per_epoch=num_batches),
        LRScheduler(FLAGS.lr_mode, base_lr=FLAGS.lr,
                    nepochs=FLAGS.epochs - FLAGS.warmup_epochs - start_epoch,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=FLAGS.lr_decay, power=2),
    ])

    trainer = gluon.Trainer(
        net.collect_params(), 'sgd',
        {'wd': FLAGS.wd, 'momentum': FLAGS.momentum, 'lr_scheduler': lr_scheduler},
        kvstore='local')

    # targets
    sigmoid_ce = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    l1_loss = gluon.loss.L1Loss()

    # metrics
    obj_metrics = mx.metric.Loss('ObjLoss')
    center_metrics = mx.metric.Loss('BoxCenterLoss')
    scale_metrics = mx.metric.Loss('BoxScaleLoss')
    cls_metrics = mx.metric.Loss('ClassLoss')

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    # logger.info(FLAGS)

    # set up tensorboard summary writer
    tb_sw = SummaryWriter(log_dir=os.path.join(log_dir, 'tb'), comment=FLAGS.save_prefix)

    # Check if wanting to resume
    logger.info('Start training from [Epoch {}]'.format(start_epoch))
    if FLAGS.resume.strip() and os.path.exists(save_prefix+'_best_map.log'):
        with open(save_prefix+'_best_map.log', 'r') as f:
            lines = [line.split()[1] for line in f.readlines()]
            best_map = [float(lines[-1])]
    else:
        best_map = [0]

    # Training loop
    num_batches = int(len(train_dataset)/FLAGS.batch_size)
    for epoch in range(start_epoch, FLAGS.epochs+1):

        st = time.time()
        if FLAGS.mixup:
            # TODO(zhreshold): more elegant way to control mixup during runtime
            try:
                train_data._dataset.set_mixup(np.random.beta, 1.5, 1.5)
            except AttributeError:
                train_data._dataset._data.set_mixup(np.random.beta, 1.5, 1.5)
            if epoch >= FLAGS.epochs - FLAGS.no_mixup_epochs:
                try:
                    train_data._dataset.set_mixup(None)
                except AttributeError:
                    train_data._dataset._data.set_mixup(None)

        tic = time.time()
        btic = time.time()
        if not FLAGS.nd_only:
            net.hybridize()
        for i, batch in enumerate(train_data):
            batch_size = batch[0].shape[0]

            if FLAGS.max_epoch_time > 0 and (time.time()-st)/60 > FLAGS.max_epoch_time:
                logger.info('Max epoch time of %d minutes reached after completing %d%% of epoch. '
                            'Moving on to next epoch' % (FLAGS.max_epoch_time, int(100*(i/num_batches))))
                break

            if FLAGS.features_dir is not None:
                f1 = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                f2 = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
                f3 = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)
                # objectness, center_targets, scale_targets, weights, class_targets
                fixed_targets = [gluon.utils.split_and_load(batch[it], ctx_list=ctx, batch_axis=0) for it in range(3, 8)]
                gt_boxes = gluon.utils.split_and_load(batch[8], ctx_list=ctx, batch_axis=0)
            else:
                data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                # objectness, center_targets, scale_targets, weights, class_targets
                fixed_targets = [gluon.utils.split_and_load(batch[it], ctx_list=ctx, batch_axis=0) for it in range(1, 6)]
                gt_boxes = gluon.utils.split_and_load(batch[6], ctx_list=ctx, batch_axis=0)
            sum_losses = []
            obj_losses = []
            center_losses = []
            scale_losses = []
            cls_losses = []
            if FLAGS.features_dir is not None:
                with autograd.record():
                    for ix, (x1, x2, x3) in enumerate(zip(f1, f2, f3)):
                        obj_loss, center_loss, scale_loss, cls_loss = net(x1, x2, x3, gt_boxes[ix], *[ft[ix] for ft in fixed_targets])
                        sum_losses.append(obj_loss + center_loss + scale_loss + cls_loss)
                        obj_losses.append(obj_loss)
                        center_losses.append(center_loss)
                        scale_losses.append(scale_loss)
                        cls_losses.append(cls_loss)
                    autograd.backward(sum_losses)
            else:
                with autograd.record():
                    for ix, x in enumerate(data):
                        obj_loss, center_loss, scale_loss, cls_loss = net(x, gt_boxes[ix], *[ft[ix] for ft in fixed_targets])
                        sum_losses.append(obj_loss + center_loss + scale_loss + cls_loss)
                        obj_losses.append(obj_loss)
                        center_losses.append(center_loss)
                        scale_losses.append(scale_loss)
                        cls_losses.append(cls_loss)
                    autograd.backward(sum_losses)

            if FLAGS.motion_stream is None:
                trainer.step(batch_size)
            else:
                trainer.step(batch_size, ignore_stale_grad=True)  # we don't use all layers of each stream
            obj_metrics.update(0, obj_losses)
            center_metrics.update(0, center_losses)
            scale_metrics.update(0, scale_losses)
            cls_metrics.update(0, cls_losses)

            if FLAGS.log_interval and not (i + 1) % FLAGS.log_interval:
                name1, loss1 = obj_metrics.get()
                name2, loss2 = center_metrics.get()
                name3, loss3 = scale_metrics.get()
                name4, loss4 = cls_metrics.get()
                logger.info('[Epoch {}][Batch {}/{}], LR: {:.2E}, Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}, '
                            '{}={:.3f}, {}={:.3f}'.format(epoch, i, num_batches, trainer.learning_rate,
                                                          batch_size/(time.time()-btic),
                                                          name1, loss1, name2, loss2, name3, loss3, name4, loss4))
                tb_sw.add_scalar(tag='Training_' + name1, scalar_value=loss1, global_step=(epoch * len(train_data) + i))
                tb_sw.add_scalar(tag='Training_' + name2, scalar_value=loss2, global_step=(epoch * len(train_data) + i))
                tb_sw.add_scalar(tag='Training_' + name3, scalar_value=loss3, global_step=(epoch * len(train_data) + i))
                tb_sw.add_scalar(tag='Training_' + name4, scalar_value=loss4, global_step=(epoch * len(train_data) + i))
            btic = time.time()

        name1, loss1 = obj_metrics.get()
        name2, loss2 = center_metrics.get()
        name3, loss3 = scale_metrics.get()
        name4, loss4 = cls_metrics.get()
        logger.info('[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'.format(
            epoch, (time.time()-tic), name1, loss1, name2, loss2, name3, loss3, name4, loss4))
        if not (epoch + 1) % FLAGS.val_interval:
            # consider reducing the frequency of validation to save time

            logger.info('End Epoch {}: # samples: {}, seconds: {}, samples/sec: {:.2f}'.format(
                epoch, len(train_data)*batch_size, time.time() - st, (len(train_data)*batch_size)/(time.time() - st)))
            st = time.time()
            map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
            logger.info('End Val: # samples: {}, seconds: {}, samples/sec: {:.2f}'.format(
                len(val_data)*batch_size, time.time() - st, (len(val_data) * batch_size)/(time.time() - st)))

            val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
            tb_sw.add_scalar(tag='Validation_mAP', scalar_value=float(mean_ap[-1]),
                             global_step=(epoch * len(train_data) + i))
            logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
            current_map = float(mean_ap[-1])
        else:
            current_map = 0.
        save_params(net, best_map, current_map, epoch, FLAGS.save_interval, save_prefix)
Example 29
0
def main():
    opt = parse_args()

    bps.init()

    gpu_name = subprocess.check_output(
        ['nvidia-smi', '--query-gpu=gpu_name', '--format=csv'])
    gpu_name = gpu_name.decode('utf8').split('\n')[-2]
    gpu_name = '-'.join(gpu_name.split())
    filename = "cifar100-%d-%s-%s.log" % (bps.size(), gpu_name,
                                          opt.logging_file)
    filehandler = logging.FileHandler(filename)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 100

    num_gpus = opt.num_gpus
    # batch_size *= max(1, num_gpus)
    context = mx.gpu(bps.local_rank()) if num_gpus > 0 else mx.cpu(
        bps.local_rank())
    num_workers = opt.num_workers
    nworker = bps.size()
    rank = bps.rank()

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]

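    # 50,000 CIFAR-100 training images split across workers; the peak learning rate is scaled up for distributed training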
    num_batches = 50000 // (opt.batch_size * nworker)
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=opt.warmup_lr,
                    target_lr=opt.lr * nworker / bps.local_size(),
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler('step',
                    base_lr=opt.lr * nworker / bps.local_size(),
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes, 'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)
    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)

    if opt.compressor:
        optimizer = 'sgd'
    else:
        optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    # from https://github.com/weiaicunzai/pytorch-cifar/blob/master/conf/global_settings.py
    CIFAR100_TRAIN_MEAN = [
        0.5070751592371323, 0.48654887331495095, 0.4409178433670343
    ]
    CIFAR100_TRAIN_STD = [
        0.2673342858792401, 0.2564384629170883, 0.27615047132568404
    ]

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=True).shard(nworker, rank).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)

        val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=False).shard(nworker, rank).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        params = net.collect_params()

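        # gradient-compression settings forwarded to the BytePS DistributedTrainer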
        compression_params = {
            "compressor": opt.compressor,
            "ef": opt.ef,
            "momentum": opt.compress_momentum,
            "scaling": opt.onebit_scaling,
            "k": opt.k,
            "fp16": opt.fp16_pushpull
        }

        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'wd': opt.wd,
            'momentum': opt.momentum
        }

        trainer = bps.DistributedTrainer(params,
                                         optimizer,
                                         optimizer_params,
                                         compression_params=compression_params)
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        iteration = 0
        best_val_score = 0
        bps.byteps_declare_tensor("acc")
        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, train_acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, train_acc = train_metric.get()
            throughput = int(batch_size * nworker * i / (time.time() - tic))

            logger.info(
                '[Epoch %d] speed: %d samples/sec\ttime cost: %f lr=%f' %
                (epoch, throughput, time.time() - tic, trainer.learning_rate))

            name, val_acc = test(ctx, val_data)
            acc = mx.nd.array([train_acc, val_acc], ctx=ctx[0])
            bps.byteps_push_pull(acc, name="acc", is_average=False)
            acc /= bps.size()
            train_acc, val_acc = acc[0].asscalar(), acc[1].asscalar()
            if bps.rank() == 0:
                logger.info('[Epoch %d] training: %s=%f' %
                            (epoch, name, train_acc))
                logger.info('[Epoch %d] validation: %s=%f' %
                            (epoch, name, val_acc))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters(
                    '%s/%.4f-cifar-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar100-%s-%d.params' %
                                    (save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar100-%s-%d.params' %
                                (save_dir, model_name, epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize()
    train(opt.num_epochs, context)
Example 30
0
def main():
    opt = parse_args()

    logger = Logger(join(opt.save_dir, "log.txt"))

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = num_training_samples // batch_size

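    # linear warmup to opt.lr followed by opt.lr_mode decay over the remaining epochs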
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=opt.lr,
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode,
                    base_lr=opt.lr,
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])

    model_name = opt.model

    kwargs = {
        'ctx': context,
        'pretrained': opt.use_pretrained,
        'classes': classes
    }
    if opt.use_gn:
        from gluoncv.nn import GroupNorm
        kwargs['norm_layer'] = GroupNorm
    if model_name.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm
    elif model_name.startswith('resnext'):
        kwargs['use_se'] = opt.use_se

    if opt.last_gamma:
        kwargs['last_gamma'] = True

    optimizer = 'nag'
    optimizer_params = {
        'wd': opt.wd,
        'momentum': opt.momentum,
        'lr_scheduler': lr_scheduler
    }
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    net = get_model(model_name, **kwargs)
    net.cast(opt.dtype)
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context)

    # Helper for reading training/validation data from RecordIO (.rec) files
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx,
                     batch_size, num_workers):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec=rec_train,
            path_imgidx=rec_train_idx,
            preprocess_threads=num_workers,
            shuffle=True,
            batch_size=batch_size,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
            rand_mirror=True,
            random_resized_crop=True,
            max_aspect_ratio=4. / 3.,
            min_aspect_ratio=3. / 4.,
            max_random_area=1,
            min_random_area=0.08,
            brightness=jitter_param,
            saturation=jitter_param,
            contrast=jitter_param,
            pca_noise=lighting_param,
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec=rec_val,
            path_imgidx=rec_val_idx,
            preprocess_threads=num_workers,
            shuffle=False,
            batch_size=batch_size,
            resize=resize,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
        )
        return train_data, val_data, batch_fn

    train_data, val_data, batch_fn = get_data_rec(
        join(opt.rec_dir, "train.rec"), join(opt.rec_dir, "train.idx"),
        join(opt.rec_dir, "val.rec"), join(opt.rec_dir, "val.idx"), batch_size,
        num_workers)

    train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

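    # validation helper returns (top-1 error, top-5 error), so lower is better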
    def test(ctx, val_data):
        val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return (1 - top1, 1 - top5)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                optimizer_params)
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=True)

        best_val_score = 1

        err_top1_val, err_top5_val = test(ctx, val_data)
        logger.info("Initial testing. Top1 %f, top5 %f" %
                    (1 - err_top1_val, 1 - err_top5_val))

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            train_data.reset()
            train_metric.reset()
            btic = time.time()

            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)

                with ag.record():
                    outputs = [
                        net(X.astype(opt.dtype, copy=False)) for X in data
                    ]
                    loss = [
                        L(yhat, y.astype(opt.dtype, copy=False))
                        for yhat, y in zip(outputs, label)
                    ]

                for l in loss:
                    l.backward()

                trainer.step(batch_size)
                train_metric.update(label, outputs)

                if opt.log_interval and not (i + 1) % opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'
                        % (epoch, i, batch_size * opt.log_interval /
                           (time.time() - btic), train_metric_name,
                           train_metric_score, trainer.learning_rate))
                    btic = time.time()

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i / (time.time() - tic))

            err_top1_val, err_top5_val = test(ctx, val_data)

            logger.info('[Epoch %d] training: %s=%f' %
                        (epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' %
                        (epoch, throughput, time.time() - tic))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f' %
                        (epoch, err_top1_val, err_top5_val))

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                net.save_parameters(
                    '%s/%.4f-imagenet-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))
                trainer.save_states(
                    '%s/%.4f-imagenet-%s-%d-best.states' %
                    (save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch +
                                                1) % save_frequency == 0:
                net.save_parameters('%s/imagenet-%s-%d.params' %
                                    (save_dir, model_name, epoch))
                trainer.save_states('%s/imagenet-%s-%d.states' %
                                    (save_dir, model_name, epoch))

        if save_frequency and save_dir:
            net.save_parameters('%s/imagenet-%s-%d.params' %
                                (save_dir, model_name, opt.num_epochs - 1))
            trainer.save_states('%s/imagenet-%s-%d.states' %
                                (save_dir, model_name, opt.num_epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)

    train(context)