Example #1
def convert():
    architecture = [0, 0, 0, 1, 0, 0, 1, 0, 3, 2, 0, 1, 2, 2, 1, 2, 0, 0, 2, 0]
    scale_ids = [8, 7, 6, 8, 5, 7, 3, 4, 2, 4, 2, 3, 4, 5, 6, 6, 3, 3, 4, 6]
    net = get_shufflenas_oneshot(architecture=architecture, scale_ids=scale_ids, use_se=True,
                                 last_conv_after_pooling=True, shuffle_by_conv=True)

    # load params (the checkpoint is stored in float16: cast down, load, then cast back to float32)
    net._initialize(force_reinit=True, dtype='float32')
    net.cast('float16')
    net.load_parameters('../models/oneshot-s+model-0000.params', allow_missing=True)
    net.cast('float32')

    # save both gluon model and symbols
    test_data = nd.ones([5, 3, 224, 224], dtype='float32')
    _ = net(test_data)
    net.summary(test_data)
    net.hybridize()

    if not os.path.exists('./symbols'):
        os.makedirs('./symbols')
    if not os.path.exists('./params'):
        os.makedirs('./params')
    net.cast('float16')
    net.load_parameters('../models/oneshot-s+model-0000.params', allow_missing=True)
    net.cast('float32')
    net.hybridize()
    net(test_data)
    net.save_parameters('./params/ShuffleNas_fixArch_shuffleByConv-0000.params')
    net.export("./symbols/ShuffleNas_fixArch_shuffleByConv", epoch=0)
    flops, model_size = get_flops(symbol_path='./symbols/ShuffleNas_fixArch_shuffleByConv-symbol.json')
    print("Last conv after pooling: {}, use se: {}".format(True, True))
    print("FLOPS: {}M, # parameters: {}M".format(flops, model_size))
def convert_from_shufflenas(architecture,
                            scale_ids,
                            image_shape,
                            model_name='ShuffleNas_fixArch',
                            use_se=True,
                            last_conv_after_pooling=True,
                            logger=None,
                            pretrained=True):
    '''
    architecture = [0, 0, 0, 0, 0, 0, 1, 1, 2, 0, 1, 1, 0, 0, 1, 2, 2, 0, 2, 0]
    scale_ids = [8, 6, 5, 7, 6, 7, 3, 4, 2, 4, 2, 3, 4, 3, 6, 7, 5, 3, 4, 6]
    '''
    dir_path = os.path.dirname(os.path.realpath(__file__))
    if logger is not None:
        logger.info(
            'Converting model from Gluon ShuffleNas with blocks {}, channels {}'
            .format(architecture, scale_ids))
    net = get_shufflenas_oneshot(
        architecture=architecture,
        scale_ids=scale_ids,
        use_se=use_se,
        last_conv_after_pooling=last_conv_after_pooling)
    if pretrained:
        net.cast('float16')
        net.load_parameters('../models/oneshot-s+model-0000.params')
        net.cast('float32')
    else:
        net.initialize(mx.init.MSRAPrelu())
        net(nd.ones((1, 3, 224, 224)))
    net.hybridize()
    x = mx.sym.var('data')
    y = net(x)
    y = mx.sym.SoftmaxOutput(data=y, name='softmax')
    symnet = mx.symbol.load_json(y.tojson())
    params = net.collect_params()
    args = {}
    auxs = {}
    for param in params.values():
        v = param._reduce()
        k = param.name
        if 'running' in k:
            auxs[k] = v
        else:
            args[k] = v
    mod = mx.mod.Module(symbol=symnet,
                        context=mx.cpu(),
                        label_names=['softmax_label'])
    mod.bind(for_training=False,
             data_shapes=[
                 ('data',
                  (1, ) + tuple([int(i) for i in image_shape.split(',')]))
             ])
    mod.set_params(arg_params=args, aux_params=auxs)
    dst_dir = os.path.join(dir_path, 'model')
    prefix = os.path.join(dir_path, 'model', model_name)
    if not os.path.isdir(dst_dir):
        os.mkdir(dst_dir)
    mod.save_checkpoint(prefix, 0)
    return prefix
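
A hypothetical invocation of convert_from_shufflenas, reusing the choices from its docstring; image_shape is the comma-separated 'C,H,W' string the function splits internally (this assumes the pretrained float16 checkpoint exists at ../models/oneshot-s+model-0000.params):

prefix = convert_from_shufflenas(
    architecture=[0, 0, 0, 0, 0, 0, 1, 1, 2, 0, 1, 1, 0, 0, 1, 2, 2, 0, 2, 0],
    scale_ids=[8, 6, 5, 7, 6, 7, 3, 4, 2, 4, 2, 3, 4, 3, 6, 7, 5, 3, 4, 6],
    image_shape='3,224,224')
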
Example #3
def main(
        num_gpus=4,
        supernet_params='./params/ShuffleNasOneshot-imagenet-supernet.params'):
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    net = get_shufflenas_oneshot()
    net.load_parameters(supernet_params, ctx=context)
    print(net)
    search_supernet(net, search_iters=10, bn_iters=1, num_gpus=num_gpus)
def clip_weights():
    architecture = parse_str_list(args.block_choices)
    scale_ids = parse_str_list(args.channel_choices)
    net = get_shufflenas_oneshot(
        architecture=architecture,
        n_class=1000,
        scale_ids=scale_ids,
        last_conv_after_pooling=args.last_conv_after_pooling,
        use_se=args.use_se,
        channels_layout=args.channels_layout)

    net.cast(args.dtype)
    net.load_parameters(args.param_file)
    param_dict = net.collect_params()
    modified_count = 0

    # At most one clipping mode may be active; the other threshold must stay -1.
    assert args.clip_to == -1 or args.clip_from == -1
    for param_name in param_dict:
        if 'running' in param_name:
            continue
        param = param_dict[param_name].data().asnumpy()
        param_bak = copy.deepcopy(param)
        if args.clip_to != -1:
            mask = np.abs(param) < args.clip_to
            neg_mask = (param < 0) * mask
            local_count = np.sum(mask)
            if local_count == 0:
                continue
            modified_count += local_count
            param[mask] = args.clip_to
            param[neg_mask] *= -1
            print(param_name)
            print("before clipping")
            print(param_bak[mask])
            print("after clipping")
            print(param[mask])
        if args.clip_from != -1:
            mask = np.abs(param) < args.clip_from
            local_count = np.sum(mask)
            if local_count == 0:
                continue
            modified_count += local_count
            param[mask] = 0
            print(param_name)
            print("before clipping")
            print(param_bak[mask])
            print("after clipping")
            print(param[mask])
        param_dict[param_name].set_data(param)

    print("Totally modified {} weights.".format(modified_count))
    orig_file = args.param_file
    save_file = (orig_file[:-7] + '-clip-to-{}.params'.format(args.clip_to)
                 if args.clip_to != -1
                 else orig_file[:-7] + '-clip-from-{}.params'.format(args.clip_from))
    print(save_file)
    net.save_parameters(save_file)
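
The two clipping modes above treat small-magnitude weights differently: clip_to pushes them out to +/-clip_to (sign preserved), while clip_from zeroes them. A standalone numpy sketch with illustrative values:

import numpy as np

param = np.array([-0.003, -0.05, 0.001, 0.2])
threshold = 0.01

# clip_to mode: |w| < threshold becomes +/-threshold
mask = np.abs(param) < threshold
neg_mask = (param < 0) * mask
clipped = param.copy()
clipped[mask] = threshold
clipped[neg_mask] *= -1
print(clipped)  # [-0.01 -0.05  0.01  0.2 ]

# clip_from mode: |w| < threshold is zeroed instead
zeroed = param.copy()
zeroed[np.abs(param) < threshold] = 0
print(zeroed)   # [ 0.   -0.05  0.    0.2 ]
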
def get_flop_param_score(block_choices,
                         channel_choices,
                         comparison_model='SinglePathOneShot',
                         use_se=False,
                         last_conv_after_pooling=False,
                         channels_layout='OneShot'):
    """ Return the flops and num of params """
    # build fix_arch network and calculate flop
    fixarch_net = get_shufflenas_oneshot(
        block_choices,
        channel_choices,
        use_se=use_se,
        last_conv_after_pooling=last_conv_after_pooling,
        channels_layout=channels_layout)
    fixarch_net._initialize()
    if not os.path.exists('./symbols'):
        os.makedirs('./symbols')
    fixarch_net.hybridize()

    # calculate flops and num of params
    dummy_data = nd.ones([1, 3, 224, 224])
    fixarch_net(dummy_data)
    fixarch_net.export("./symbols/ShuffleNas_fixArch", epoch=1)

    flops, model_size = get_flops(
        symbol_path="./symbols/ShuffleNas_fixArch-symbol.json"
    )  # both in Millions

    # confirms the ShuffleNet series' FLOP count matches the Google paper's
    if comparison_model == 'MobileNetV3_large':
        flops_constraint = 217
        parameter_number_constraint = 5.4

    # the MicroNet challenge counts roughly twice the FLOPs the Google paper reports
    elif comparison_model == 'MobileNetV2_1.4':
        flops_constraint = 585
        parameter_number_constraint = 6.9

    elif comparison_model == 'SinglePathOneShot':
        flops_constraint = 328
        parameter_number_constraint = 3.4

    # confirms our calculation matches the ShuffleNet series' and the Google paper's
    elif comparison_model == 'ShuffleNetV2+_medium':
        flops_constraint = 222
        parameter_number_constraint = 5.6

    else:
        raise ValueError(
            "Unrecognized comparison model: {}".format(comparison_model))

    flop_score = flops / flops_constraint
    model_size_score = model_size / parameter_number_constraint

    return flops, model_size, flop_score, model_size_score
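
A usage sketch (choices reused from the convert_from_shufflenas docstring above); the two scores are ratios of the candidate's cost to the comparison model's budget, so values below 1 mean the candidate fits within SinglePathOneShot's 328 MFLOPs / 3.4M parameters:

flops, model_size, flop_score, size_score = get_flop_param_score(
    block_choices=[0, 0, 0, 0, 0, 0, 1, 1, 2, 0, 1, 1, 0, 0, 1, 2, 2, 0, 2, 0],
    channel_choices=[8, 6, 5, 7, 6, 7, 3, 4, 2, 4, 2, 3, 4, 3, 6, 7, 5, 3, 4, 6],
    comparison_model='SinglePathOneShot')
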
Example #6
def main():
    args = parse_args()
    net = get_shufflenas_oneshot(use_se=args.use_se, last_conv_after_pooling=args.last_conv_after_pooling,
                                 channels_layout=args.channels_layout)

    print(args)
    # 180 ~= 1280000 / (7 x 2 x 8)
    m = Maintainer(net, sample_counts=180, flops_cuts=7)

    ''' Plot all '''
    # flop_list, model_size_list = m.get_single_flops_params(2, find_max_param=True, max_flop=210, upper_model_size=6.0)
    # flop_list, model_size_list = m.get_all_flops_params()
    # plot(flop_list, model_size_list)

    ''' Plot step by step '''
    # plot_each_strolling_step(m)

    ''' Plot random choice '''
    # flop_list = []
    # model_size_list = []
    # count = 0
    # while count < 1120:
    #     _, block_choices = net.random_block_choices(return_choice_list=True)
    #     _, channel_choices = net.random_channel_mask()
    #     flops, model_size = get_flop_params(block_choices, channel_choices, m.lookup_table)
    #     # count += 1
    #     if flops < 190 or flops > 330 \
    #             or model_size < 2.8 or model_size > 5.0:
    #         continue
    #     flop_list.append(flops)
    #     model_size_list.append(model_size)
    #     count += 1
    # plot(flop_list, model_size_list, title='Flops Param Distribution Random Selection',
    #      save_file='supernet_flops_params_dist_full_random.png', show=True)  # , x_max=600, y_max=6.5)

    ''' Check step '''
    manager = multiprocessing.Manager()
    cand_pool = manager.list()
    p_lock = manager.Lock()
    finished = Value(c_bool, False)
    m = Maintainer(net, sample_counts=1, flops_cuts=7)
    pool_process = multiprocessing.Process(target=m.step, args=[cand_pool, p_lock, finished, 'float32'])
    pool_process.start()
    step = 0
    while step < 112:
        if len(cand_pool) > 1:
            cand_pool.pop()
            step += 1
        else:
            time.sleep(0.5)
    time.sleep(5)
    finished.value = True
    print('sequential finished')
    pool_process.join()
Example #7
def search_supernet(net, search_iters=2000, bn_iters=50000, num_gpus=0):
    # TODO: use a heapq here to store top-5 models
    train_data, val_data, batch_fn = get_data(num_gpus=num_gpus)
    best_acc = 0
    best_block_choices = None
    best_channel_choices = None
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)
    for _ in range(search_iters):
        block_choices = net.random_block_choices(select_predefined_block=False,
                                                 dtype='float32')
        full_channel_mask, channel_choices = net.random_channel_mask(
            select_all_channels=False, dtype='float32')
        # Update BN
        # for _ in range(bn_iters):
        #     data, _ = generate_random_data_label()
        #     net(data, block_choices, full_channel_mask)
        # Get validation accuracy
        val_acc = get_accuracy(net,
                               val_data,
                               batch_fn,
                               block_choices,
                               full_channel_mask,
                               acc_top1=acc_top1,
                               acc_top5=acc_top5,
                               ctx=context)
        if val_acc > best_acc:
            best_acc = val_acc
            best_block_choices = copy.deepcopy(block_choices)
            best_channel_choices = copy.deepcopy(channel_choices)
        # build fix_arch network and calculate flop
        fixarch_net = get_shufflenas_oneshot(block_choices.asnumpy(),
                                             channel_choices)
        fixarch_net.initialize()
        if not os.path.exists('./symbols'):
            os.makedirs('./symbols')
        fixarch_net.hybridize()
        dummy_data = nd.ones([1, 3, 224, 224])
        fixarch_net(dummy_data)
        fixarch_net.export("./symbols/ShuffleNas_fixArch", epoch=1)
        flops, model_size = get_flops()
        print('-' * 40)
        print("Val accuracy: {}".format(val_acc))
        print('flops: ', str(flops), ' MFLOPS')
        print('model size: ', str(model_size), ' M params')

    print('-' * 40)
    print("Best val accuracy: {}".format(best_acc))
    print("Best block choices: {}".format(best_block_choices.asnumpy()))
    print("Best channel choices: {}".format(best_channel_choices))
def merge_bn(
    param_file='../params_shufflenas_oneshot+_genetic/0.2448-imagenet-ShuffleNas_fixArch-357-best.params'
):
    architecture = [0, 0, 0, 1, 0, 0, 1, 0, 3, 2, 0, 1, 2, 2, 1, 2, 0, 0, 2, 0]
    scale_ids = [8, 7, 6, 8, 5, 7, 3, 4, 2, 4, 2, 3, 4, 5, 6, 6, 3, 3, 4, 6]
    net = get_shufflenas_oneshot(architecture=architecture,
                                 scale_ids=scale_ids,
                                 use_se=True,
                                 last_conv_after_pooling=True)
    net.cast('float16')
    net.load_parameters(param_file)
    param_dict = net.collect_params()
    nobn_net = get_noBN_shufflenas_oneshot(architecture=architecture,
                                           scale_ids=scale_ids,
                                           use_se=True,
                                           last_conv_after_pooling=True)
    nobn_net.initialize()
    nobn_param_dict = nobn_net.collect_params()

    merge_list = []
    param_list = list(param_dict.items())
    nobn_param_list = list(nobn_param_dict.keys())
    for i, key in enumerate(param_dict):
        if 'gamma' in key:
            merge_list.append({
                'conv_name': param_list[i - 1][0],
                'bn_name': key[:key.rfind('_')],
                'gamma': param_list[i][1].data(),
                'beta': param_list[i + 1][1].data(),
                'running_mean': param_list[i + 2][1].data(),
                'running_var': param_list[i + 3][1].data(),
            })
        if 'batchnorm' not in key:
            nobn_param_dict[key.replace('fix0', 'fix1')].set_data(
                param_dict[key].data())
            nobn_param_list.remove(key.replace('fix0', 'fix1'))

    for info in merge_list:
        new_w, new_b = merge(param_dict[info['conv_name']].data(),
                             info['gamma'], info['beta'], info['running_mean'],
                             info['running_var'])
        nobn_param_dict[info['conv_name'].replace('fix0',
                                                  'fix1')].set_data(new_w)
        nobn_param_dict[info['conv_name'][:-6].replace('fix0', 'fix1') +
                        'bias'].set_data(new_b)

    nobn_net.save_parameters('./ShuffleNas-fixArch-noBN.params')
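
The merge helper used above is not shown in this snippet. A standard batch-norm-folding implementation, assuming NCHW conv weights of shape (out_ch, in_ch, kh, kw) and the BatchNorm default eps of 1e-5, would be:

from mxnet import nd

def merge(weight, gamma, beta, running_mean, running_var, eps=1e-5):
    # Fold y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
    # into a single conv with rescaled weights and an absorbed bias.
    scale = gamma / nd.sqrt(running_var + eps)     # one factor per out channel
    new_w = weight * scale.reshape((-1, 1, 1, 1))  # rescale each filter
    new_b = beta - running_mean * scale            # absorbed bias
    return new_w, new_b
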
Example #9
def get_block(block_mode='just-conv', act_mode='relu', use_se=False):
    if block_mode == 'just-conv':
        net = gluon.nn.HybridSequential()
        net.add(
            nn.Conv2D(16, kernel_size=3, strides=2,
                      padding=1, use_bias=False, prefix='1st_conv_'),
            nn.BatchNorm(momentum=0.1),
            Activation(act_mode)
        )
        if use_se:
            net.add(SE(16))
        net.add(
            nn.Conv2D(32, in_channels=16, kernel_size=3, strides=2,
                      padding=1, use_bias=False, prefix='2nd_conv_'),
            nn.BatchNorm(momentum=0.1),
            Activation(act_mode)
        )
    elif block_mode == 'SNB':
        net = gluon.nn.HybridSequential()
        net.add(
            nn.Conv2D(16, kernel_size=3, strides=2,
                      padding=1, use_bias=False, prefix='1st_conv_'),
            nn.BatchNorm(momentum=0.1),
            Activation(act_mode)
        )
        if use_se:
            net.add(SE(16))
        net.add(
            ShuffleNetBlock(16, 32, 16, bn=nn.BatchNorm,
                            block_mode='ShuffleNetV2', ksize=3, stride=1,
                            use_se=use_se, act_name=act_mode)
        )
    elif block_mode == 'SNB-x':
        net = gluon.nn.HybridSequential()
        net.add(
            nn.Conv2D(16, kernel_size=3, strides=2,
                      padding=1, use_bias=False, prefix='1st_conv_'),
            nn.BatchNorm(momentum=0.1),
            Activation(act_mode)
        )
        if use_se:
            net.add(SE(16))
        net.add(
            ShuffleNetBlock(16, 32, 16, bn=nn.BatchNorm,
                            block_mode='ShuffleXception', ksize=3, stride=1,
                            use_se=use_se, act_name=act_mode)
        )
    elif block_mode == 'ShuffleNas_fixArch':
        architecture = [0, 0, 0, 0, 0, 0, 1, 1, 2, 0, 1, 1, 0, 0, 1, 2, 2, 0, 2, 0]
        scale_ids = [8, 6, 5, 7, 6, 7, 3, 4, 2, 4, 2, 3, 4, 3, 6, 7, 5, 3, 4, 6]
        net = get_shufflenas_oneshot(architecture=architecture, scale_ids=scale_ids,
                                     use_se=True, last_conv_after_pooling=True)
    else:
        raise ValueError("Unrecognized mode: {}".format(block_mode))

    if block_mode != 'ShuffleNas_fixArch':
        net.add(nn.GlobalAvgPool2D(),
                nn.Conv2D(10, in_channels=32, kernel_size=1, strides=1,
                          padding=0, use_bias=True),
                nn.Flatten()
                )
    else:
        net.output = nn.HybridSequential(prefix='output_')
        with net.output.name_scope():
            net.output.add(
                nn.Conv2D(10, in_channels=1024, kernel_size=1, strides=1,
                          padding=0, use_bias=True),
                nn.Flatten()
            )
    return net
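
A usage sketch feeding CIFAR-sized dummy input through the plain-conv variant (assuming the SE and Activation blocks from this snippet's imports are available); with two stride-2 convs, the 32x32 input shrinks to 8x8 before global pooling:

import mxnet as mx
from mxnet import nd

net = get_block(block_mode='just-conv', act_mode='relu', use_se=True)
net.initialize(mx.init.MSRAPrelu())
out = net(nd.ones((2, 3, 32, 32)))
print(out.shape)  # (2, 10): GlobalAvgPool -> 1x1 conv over 10 classes -> Flatten
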
Example #10
def main():
    opt = parse_args()

    filehandler = logging.FileHandler(opt.logging_file)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = num_training_samples // batch_size

    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=opt.lr,
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode,
                    base_lr=opt.lr,
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])

    model_name = opt.model

    kwargs = {
        'ctx': context,
        'pretrained': opt.use_pretrained,
        'classes': classes
    }
    if opt.use_gn:
        from gluoncv.nn import GroupNorm
        kwargs['norm_layer'] = GroupNorm
    if model_name.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm
    elif model_name.startswith('resnext'):
        kwargs['use_se'] = opt.use_se

    if opt.last_gamma:
        kwargs['last_gamma'] = True

    optimizer = 'nag'
    optimizer_params = {
        'wd': opt.wd,
        'momentum': opt.momentum,
        'lr_scheduler': lr_scheduler
    }
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    if model_name == 'ShuffleNas_fixArch':
        architecture = [
            0, 0, 3, 1, 1, 1, 0, 0, 2, 0, 2, 1, 1, 0, 2, 0, 2, 1, 3, 2
        ]
        scale_ids = [
            6, 5, 3, 5, 2, 6, 3, 4, 2, 5, 7, 5, 4, 6, 7, 4, 4, 5, 4, 3
        ]
        net = get_shufflenas_oneshot(architecture, scale_ids)
    elif model_name == 'ShuffleNas':
        net = get_shufflenas_oneshot()
    else:
        net = get_model(model_name, **kwargs)

    net.cast(opt.dtype)
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context)

    # teacher model for distillation training
    if opt.teacher is not None and opt.hard_weight < 1.0:
        teacher_name = opt.teacher
        teacher = get_model(teacher_name,
                            pretrained=True,
                            classes=classes,
                            ctx=context)
        teacher.cast(opt.dtype)
        distillation = True
    else:
        distillation = False

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx,
                     batch_size, num_workers):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec=rec_train,
            path_imgidx=rec_train_idx,
            preprocess_threads=num_workers,
            shuffle=True,
            batch_size=batch_size,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
            rand_mirror=True,
            random_resized_crop=True,
            max_aspect_ratio=4. / 3.,
            min_aspect_ratio=3. / 4.,
            max_random_area=1,
            min_random_area=0.08,
            brightness=jitter_param,
            saturation=jitter_param,
            contrast=jitter_param,
            pca_noise=lighting_param,
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec=rec_val,
            path_imgidx=rec_val_idx,
            preprocess_threads=num_workers,
            shuffle=False,
            batch_size=batch_size,
            resize=resize,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
        )
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param,
                                         contrast=jitter_param,
                                         saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(), normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(), normalize
        ])

        train_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=True).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)
        val_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=False).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        return train_data, val_data, batch_fn

    if opt.use_rec:
        train_data, val_data, batch_fn = get_data_rec(opt.rec_train,
                                                      opt.rec_train_idx,
                                                      opt.rec_val,
                                                      opt.rec_val_idx,
                                                      batch_size, num_workers)
    else:
        train_data, val_data, batch_fn = get_data_loader(
            opt.data_dir, batch_size, num_workers)

    if opt.mixup:
        train_metric = mx.metric.RMSE()
    else:
        train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def mixup_transform(label, classes, lam=1, eta=0.0):
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes,
                           on_value=1 - eta + eta / classes,
                           off_value=eta / classes)
            y2 = l[::-1].one_hot(classes,
                                 on_value=1 - eta + eta / classes,
                                 off_value=eta / classes)
            res.append(lam * y1 + (1 - lam) * y2)
        return res

    def smooth(label, classes, eta=0.1):
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes,
                            on_value=1 - eta + eta / classes,
                            off_value=eta / classes)
            smoothed.append(res)
        return smoothed

    def test(ctx, val_data):
        if opt.use_rec:
            val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            if model_name == 'ShuffleNas':
                block_choices = net.random_block_choices(
                    select_predefined_block=False, dtype=opt.dtype)
                full_channel_mask, _ = net.random_channel_mask(
                    select_all_channels=False, dtype=opt.dtype)
                outputs = [
                    net(X.astype(opt.dtype, copy=False), block_choices,
                        full_channel_mask) for X in data
                ]
            else:
                outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return (1 - top1, 1 - top5)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            if 'ShuffleNas' in model_name:
                net._initialize(ctx=ctx)
            else:
                net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                optimizer_params)
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True
        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(
                temperature=opt.temperature,
                hard_weight=opt.hard_weight,
                sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(
                sparse_label=sparse_label_loss)

        best_val_score = 1

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()
            btic = time.time()

            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)

                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                        lam = 1
                    data = [lam * X + (1 - lam) * X[::-1] for X in data]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label, classes, lam, eta)

                elif opt.label_smoothing:
                    hard_label = label
                    label = smooth(label, classes)

                if distillation:
                    teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature) \
                                    for X in data]

                with ag.record():
                    if model_name == 'ShuffleNas':
                        block_choices = net.random_block_choices(
                            select_predefined_block=False, dtype=opt.dtype)
                        full_channel_mask, _ = net.random_channel_mask(
                            select_all_channels=False, dtype=opt.dtype)
                        outputs = [
                            net(X.astype(opt.dtype, copy=False), block_choices,
                                full_channel_mask) for X in data
                        ]
                    else:
                        outputs = [
                            net(X.astype(opt.dtype, copy=False)) for X in data
                        ]
                    if distillation:
                        loss = [
                            L(yhat.astype('float32', copy=False),
                              y.astype('float32', copy=False),
                              p.astype('float32', copy=False))
                            for yhat, y, p in zip(outputs, label, teacher_prob)
                        ]
                    else:
                        loss = [
                            L(yhat, y.astype(opt.dtype, copy=False))
                            for yhat, y in zip(outputs, label)
                        ]
                for l in loss:
                    l.backward()
                trainer.step(batch_size, ignore_stale_grad=True)

                if opt.mixup:
                    output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                                    for out in outputs]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, outputs)
                    else:
                        train_metric.update(label, outputs)

                if opt.log_interval and not (i + 1) % opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'
                        % (epoch, i, batch_size * opt.log_interval /
                           (time.time() - btic), train_metric_name,
                           train_metric_score, trainer.learning_rate))
                    btic = time.time()

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i / (time.time() - tic))

            err_top1_val, err_top5_val = test(ctx, val_data)

            logger.info('[Epoch %d] training: %s=%f' %
                        (epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' %
                        (epoch, throughput, time.time() - tic))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f' %
                        (epoch, err_top1_val, err_top5_val))

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                net.save_parameters(
                    '%s/%.4f-imagenet-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))
                trainer.save_states(
                    '%s/%.4f-imagenet-%s-%d-best.states' %
                    (save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch +
                                                1) % save_frequency == 0:
                net.save_parameters('%s/imagenet-%s-%d.params' %
                                    (save_dir, model_name, epoch))
                trainer.save_states('%s/imagenet-%s-%d.states' %
                                    (save_dir, model_name, epoch))

        if save_frequency and save_dir:
            net.save_parameters('%s/imagenet-%s-%d.params' %
                                (save_dir, model_name, opt.num_epochs - 1))
            trainer.save_states('%s/imagenet-%s-%d.states' %
                                (save_dir, model_name, opt.num_epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)
        if distillation:
            teacher.hybridize(static_alloc=True, static_shape=True)
    train(context)
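
A quick sanity check of the smooth helper above: with classes=1000 and eta=0.1, the true class gets 1 - 0.1 + 0.1/1000 = 0.9001, every other class gets 0.0001, and each row still sums to 1:

from mxnet import nd

label = nd.array([3])
smoothed = label.one_hot(1000, on_value=1 - 0.1 + 0.1 / 1000, off_value=0.1 / 1000)
print(smoothed[0, 3].asscalar())  # 0.9001
print(smoothed.sum().asscalar())  # ~1.0
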
Example #11
def main():
    context = [mx.gpu(i) for i in range(args.num_gpus)
               ] if args.num_gpus > 0 else [mx.cpu()]
    net = get_shufflenas_oneshot(
        use_se=args.use_se,
        last_conv_after_pooling=args.last_conv_after_pooling)
    net.cast(args.dtype)
    net.load_parameters(args.supernet_params, ctx=context)
    net.cast('float32')
    print(net)

    filehandler = logging.FileHandler('./search_supernet_{}.log'.format(
        args.comparison_model))
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(args)

    data_kwargs = {
        "rec_train": args.rec_train,
        "rec_train_idx": args.rec_train_idx,
        "rec_val": args.rec_val,
        "rec_val_idx": args.rec_val_idx,
        "input_size": args.input_size,
        "crop_ratio": args.crop_ratio,
        "num_workers": args.num_workers,
        "shuffle_train": args.shuffle_train
    }

    if args.search_mode == 'random':
        random_search(net,
                      dtype='float32',
                      logger=logger,
                      ctx=context,
                      search_iters=100,
                      comparison_model=args.comparison_model,
                      update_bn_images=args.update_bn_images,
                      batch_size=args.batch_size,
                      topk=args.topk,
                      **data_kwargs)
    elif args.search_mode == 'genetic':
        genetic_search(net,
                       dtype='float32',
                       logger=logger,
                       ctx=context,
                       search_iters=args.search_iters,
                       comparison_model=args.comparison_model,
                       update_bn_images=args.update_bn_images,
                       batch_size=args.batch_size,
                       topk=args.topk,
                       population_size=args.population_size,
                       retain_length=args.retain_length,
                       random_select=args.random_select,
                       mutate_chance=args.mutate_chance,
                       search_target=args.search_target,
                       **data_kwargs)
    else:
        raise ValueError("Unrecognized search mode: {}".format(
            args.search_mode))
Example #12
def main():
    opt = parse_args()

    filehandler = logging.FileHandler(opt.logging_file)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    # logger.setLevel(logging.DEBUG)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    # Before epoch_start_cs all channels are used; from that epoch on, channel selection is enabled.
    if opt.epoch_start_cs != -1:
        opt.use_all_channels = True

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = num_training_samples // batch_size // opt.reduced_dataset_scale

    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=opt.lr,
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode,
                    base_lr=opt.lr,
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])

    model_name = opt.model

    kwargs = {
        'ctx': context,
        'pretrained': opt.use_pretrained,
        'classes': classes
    }
    if opt.use_gn:
        from gluoncv.nn import GroupNorm
        kwargs['norm_layer'] = GroupNorm
    if model_name.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm
    elif model_name.startswith('resnext'):
        kwargs['use_se'] = opt.use_se

    if opt.last_gamma:
        kwargs['last_gamma'] = True

    optimizer = 'nag'
    optimizer_params = {
        'wd': opt.wd,
        'momentum': opt.momentum,
        'lr_scheduler': lr_scheduler
    }
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    if model_name == 'ShuffleNas_fixArch':
        architecture = parse_str_list(opt.block_choices)
        scale_ids = parse_str_list(opt.channel_choices)
        net = get_shufflenas_oneshot(
            architecture=architecture,
            n_class=classes,
            scale_ids=scale_ids,
            use_se=opt.use_se,
            last_conv_after_pooling=opt.last_conv_after_pooling,
            channels_layout=opt.channels_layout)
    elif model_name == 'ShuffleNas':
        net = get_shufflenas_oneshot(
            n_class=classes,
            use_all_blocks=opt.use_all_blocks,
            use_se=opt.use_se,
            last_conv_after_pooling=opt.last_conv_after_pooling,
            channels_layout=opt.channels_layout)
    else:
        net = get_model(model_name, **kwargs)

    net.cast(opt.dtype)
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context)

    # teacher model for distillation training
    if opt.teacher is not None and opt.hard_weight < 1.0:
        teacher_name = opt.teacher
        teacher = get_model(teacher_name,
                            pretrained=True,
                            classes=classes,
                            ctx=context)
        teacher.cast(opt.dtype)
        distillation = True
    else:
        distillation = False

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx,
                     batch_size, num_workers):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec=rec_train,
            path_imgidx=rec_train_idx,
            preprocess_threads=num_workers,
            shuffle=True,
            batch_size=batch_size,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
            rand_mirror=True,
            random_resized_crop=True,
            max_aspect_ratio=4. / 3.,
            min_aspect_ratio=3. / 4.,
            max_random_area=1,
            min_random_area=0.08,
            brightness=jitter_param,
            saturation=jitter_param,
            contrast=jitter_param,
            pca_noise=lighting_param,
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec=rec_val,
            path_imgidx=rec_val_idx,
            preprocess_threads=num_workers,
            shuffle=False,
            batch_size=batch_size,
            resize=resize,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
        )
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param,
                                         contrast=jitter_param,
                                         saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(), normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(), normalize
        ])

        train_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=True).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)
        val_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=False).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        return train_data, val_data, batch_fn

    if opt.use_rec:
        train_data, val_data, batch_fn = get_data_rec(opt.rec_train,
                                                      opt.rec_train_idx,
                                                      opt.rec_val,
                                                      opt.rec_val_idx,
                                                      batch_size, num_workers)
    else:
        train_data, val_data, batch_fn = get_data_loader(
            opt.data_dir, batch_size, num_workers)

    if opt.mixup:
        train_metric = mx.metric.RMSE()
    else:
        train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def mixup_transform(label, classes, lam=1, eta=0.0):
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes,
                           on_value=1 - eta + eta / classes,
                           off_value=eta / classes)
            y2 = l[::-1].one_hot(classes,
                                 on_value=1 - eta + eta / classes,
                                 off_value=eta / classes)
            res.append(lam * y1 + (1 - lam) * y2)
        return res

    def smooth(label, classes, eta=0.1):
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes,
                            on_value=1 - eta + eta / classes,
                            off_value=eta / classes)
            smoothed.append(res)
        return smoothed

    def make_divisible(x, divisible_by=8):
        return int(np.ceil(x * 1. / divisible_by) * divisible_by)

    def set_nas_bn(net, inference_update_stat=False):
        if isinstance(net, NasBatchNorm):
            net.inference_update_stat = inference_update_stat
        elif len(net._children) != 0:
            for k, v in net._children.items():
                set_nas_bn(v, inference_update_stat=inference_update_stat)
        else:
            return

    def update_bn(net,
                  batch_fn,
                  train_data,
                  block_choices,
                  full_channel_mask,
                  ctx=[mx.cpu()],
                  dtype='float32',
                  batch_size=256,
                  update_bn_images=20000):
        train_data.reset()
        # Updating BN statistics requires the model to be in float32
        net.cast('float32')
        full_channel_masks = [
            full_channel_mask.as_in_context(ctx_i) for ctx_i in ctx
        ]
        set_nas_bn(net, inference_update_stat=True)
        for i, batch in enumerate(train_data):
            if (i + 1) * batch_size * len(ctx) >= update_bn_images:
                break
            data, _ = batch_fn(batch, ctx)
            _ = [
                net(X.astype('float32', copy=False),
                    block_choices.astype('float32', copy=False),
                    channel_mask.astype('float32', copy=False))
                for X, channel_mask in zip(data, full_channel_masks)
            ]
        set_nas_bn(net, inference_update_stat=False)
        net.cast(dtype)

    def test(ctx, val_data, epoch):
        if model_name == 'ShuffleNas':
            # For evaluating validation accuracy, randomly select blocks and channels and update BN stats
            block_choices = net.random_block_choices(
                select_predefined_block=False, dtype=opt.dtype)
            if opt.cs_warm_up:
                # TODO: note in the issue, README and Medium article that
                #  BN stats need to be updated before verifying val acc
                full_channel_mask, channel_choices = net.random_channel_mask(
                    select_all_channels=False,
                    epoch_after_cs=epoch - opt.epoch_start_cs,
                    dtype=opt.dtype,
                    ignore_first_two_cs=opt.ignore_first_two_cs)
            else:
                full_channel_mask, _ = net.random_channel_mask(
                    select_all_channels=False,
                    dtype=opt.dtype,
                    ignore_first_two_cs=opt.ignore_first_two_cs)
            update_bn(net,
                      batch_fn,
                      train_data,
                      block_choices,
                      full_channel_mask,
                      ctx,
                      dtype=opt.dtype,
                      batch_size=batch_size)
        else:
            block_choices, full_channel_mask = None, None

        if opt.use_rec:
            val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            if model_name == 'ShuffleNas':
                full_channel_masks = [
                    full_channel_mask.as_in_context(ctx_i) for ctx_i in ctx
                ]
                outputs = [
                    net(X.astype(opt.dtype, copy=False), block_choices,
                        channel_mask)
                    for X, channel_mask in zip(data, full_channel_masks)
                ]
            else:
                outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return 1 - top1, 1 - top5

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            if 'ShuffleNas' in model_name:
                net._initialize(ctx=ctx)
            else:
                net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                optimizer_params)
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True
        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(
                temperature=opt.temperature,
                hard_weight=opt.hard_weight,
                sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(
                sparse_label=sparse_label_loss)

        best_val_score = 1

        def train_epoch(pool=None,
                        pool_lock=None,
                        shared_finished_flag=None,
                        use_pool=False):
            btic = time.time()
            for i, batch in enumerate(train_data):
                if i == num_batches:
                    if use_pool:
                        shared_finished_flag.value = True
                    return
                data, label = batch_fn(batch, ctx)

                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                        lam = 1
                    data = [lam * X + (1 - lam) * X[::-1] for X in data]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label, classes, lam, eta)

                elif opt.label_smoothing:
                    hard_label = label
                    label = smooth(label, classes)

                if distillation:
                    teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature) \
                                    for X in data]

                with ag.record():
                    if model_name == 'ShuffleNas' and use_pool:
                        cand = None
                        while cand is None:
                            if len(pool) > 0:
                                with pool_lock:
                                    cand = pool.pop()
                                    if i % opt.log_interval == 0:
                                        logger.debug('[Trainer] ' + '-' * 40)
                                        logger.debug(
                                            "[Trainer] Time: {}".format(
                                                time.time()))
                                        logger.debug(
                                            "[Trainer] Block choice: {}".
                                            format(cand['block_list']))
                                        logger.debug(
                                            "[Trainer] Channel choice: {}".
                                            format(cand['channel_list']))
                                        logger.debug(
                                            "[Trainer] Flop: {}M, param: {}M".
                                            format(cand['flops'],
                                                   cand['model_size']))
                            else:
                                time.sleep(1)

                        full_channel_masks = [
                            cand['channel'].as_in_context(ctx_i)
                            for ctx_i in ctx
                        ]
                        outputs = [
                            net(X.astype(opt.dtype, copy=False), cand['block'],
                                channel_mask) for X, channel_mask in zip(
                                    data, full_channel_masks)
                        ]
                    elif model_name == 'ShuffleNas':
                        block_choices = net.random_block_choices(
                            select_predefined_block=False, dtype=opt.dtype)
                        if opt.cs_warm_up:
                            full_channel_mask, channel_choices = net.random_channel_mask(
                                select_all_channels=opt.use_all_channels,
                                epoch_after_cs=epoch - opt.epoch_start_cs,
                                dtype=opt.dtype,
                                ignore_first_two_cs=opt.ignore_first_two_cs)
                        else:
                            full_channel_mask, channel_choices = net.random_channel_mask(
                                select_all_channels=opt.use_all_channels,
                                dtype=opt.dtype,
                                ignore_first_two_cs=opt.ignore_first_two_cs)

                        full_channel_masks = [
                            full_channel_mask.as_in_context(ctx_i)
                            for ctx_i in ctx
                        ]
                        outputs = [
                            net(X.astype(opt.dtype, copy=False), block_choices,
                                channel_mask) for X, channel_mask in zip(
                                    data, full_channel_masks)
                        ]
                    else:
                        outputs = [
                            net(X.astype(opt.dtype, copy=False)) for X in data
                        ]

                    if distillation:
                        loss = [
                            L(yhat.astype('float32', copy=False),
                              y.astype('float32', copy=False),
                              p.astype('float32', copy=False))
                            for yhat, y, p in zip(outputs, label, teacher_prob)
                        ]
                    else:
                        loss = [
                            L(yhat, y.astype(opt.dtype, copy=False))
                            for yhat, y in zip(outputs, label)
                        ]
                for l in loss:
                    l.backward()
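                # Only the sampled path produced gradients this step; the
                # remaining branches hold stale gradients, so let the trainer
                # skip them instead of raising an error.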
                trainer.step(batch_size, ignore_stale_grad=True)

                if opt.mixup:
                    output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                                    for out in outputs]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, outputs)
                    else:
                        train_metric.update(label, outputs)

                if opt.log_interval and not (i + 1) % opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'
                        % (epoch, i, batch_size * opt.log_interval /
                           (time.time() - btic), train_metric_name,
                           train_metric_score, trainer.learning_rate))
                    btic = time.time()
            return

        def maintain_random_pool(pool,
                                 pool_lock,
                                 shared_finished_flag,
                                 upper_flops=sys.maxsize,
                                 upper_params=sys.maxsize,
                                 bottom_flops=0,
                                 bottom_params=0):
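            # Producer loop, run in a separate process: keep up to 5 random
            # candidates whose FLOPs/params fall inside the given band so the
            # trainer never stalls waiting for a valid architecture.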
            lookup_table = None
            if opt.flop_param_method == 'lookup_table':
                lookup_table = load_lookup_table(opt.use_se,
                                                 opt.last_conv_after_pooling,
                                                 opt.channels_layout,
                                                 nas_root='./')
            while True:
                if shared_finished_flag.value:
                    break
                if len(pool) < 5:
                    candidate = dict()
                    block_choices, block_choices_list = net.random_block_choices(
                        select_predefined_block=False,
                        dtype=opt.dtype,
                        return_choice_list=True)
                    if opt.cs_warm_up:
                        full_channel_mask, channel_choices_list = net.random_channel_mask(
                            select_all_channels=opt.use_all_channels,
                            epoch_after_cs=epoch - opt.epoch_start_cs,
                            dtype=opt.dtype,
                            ignore_first_two_cs=opt.ignore_first_two_cs,
                        )
                    else:
                        full_channel_mask, channel_choices_list = net.random_channel_mask(
                            select_all_channels=opt.use_all_channels,
                            dtype=opt.dtype,
                            ignore_first_two_cs=opt.ignore_first_two_cs)

                    if opt.flop_param_method == 'symbol':
                        flops, model_size, _, _ = \
                            get_flop_param_forward(block_choices_list, channel_choices_list,
                                                 use_se=opt.use_se, last_conv_after_pooling=opt.last_conv_after_pooling,
                                                 channels_layout=opt.channels_layout)
                    elif opt.flop_param_method == 'lookup_table':
                        flops, model_size = get_flop_param_lookup(
                            block_choices_list, channel_choices_list,
                            lookup_table)
                    else:
                        raise ValueError(
                            'Unrecognized flop param calculation method: {}'.
                            format(opt.flop_param_method))

                    candidate['block'] = block_choices
                    candidate['channel'] = full_channel_mask
                    candidate['block_list'] = block_choices_list
                    candidate['channel_list'] = channel_choices_list
                    candidate['flops'] = flops
                    candidate['model_size'] = model_size

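                    # Discard candidates outside the FLOPs/params band.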
                    if flops > upper_flops or model_size > upper_params or \
                        flops < bottom_flops or model_size < bottom_params:
                        continue

                    with pool_lock:
                        pool.append(candidate)
                        logger.debug(
                            "[Maintainer] Added one qualified candidate; current pool size: {}"
                            .format(len(pool)))
                else:
                    # Avoid busy-waiting when the pool is already full.
                    time.sleep(0.3)

        if opt.train_constraint_method == 'evolution':
            evolve_maintainer = Maintainer(net,
                                           num_batches=num_batches,
                                           nas_root='./')
        else:
            evolve_maintainer = None
        manager = multiprocessing.Manager()
        cand_pool = manager.list()
        p_lock = manager.Lock()
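        # The candidate pool and lock are manager-backed so they can be shared
        # with the maintainer process spawned below.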

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            if epoch >= opt.epoch_start_cs:
                opt.use_all_channels = False
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()

            if opt.train_constraint_method is None:
                logger.debug("===== DEBUG ======\n"
                             "Train SuperNet with no constraint")
                train_epoch()
            else:
                upper_constraints = opt.train_upper_constraints.split('-')
                # e.g. opt.train_upper_constraints = 'flops-330-params-5.0'
                assert len(upper_constraints) == 4 and upper_constraints[0] == 'flops' \
                       and upper_constraints[2] == 'params'
                upper_flops = float(upper_constraints[1]) if float(
                    upper_constraints[1]) != 0 else sys.maxsize
                upper_params = float(upper_constraints[3]) if float(
                    upper_constraints[3]) != 0 else sys.maxsize

                bottom_constraints = opt.train_bottom_constraints.split('-')
                assert len(bottom_constraints) == 4 and bottom_constraints[0] == 'flops' \
                       and bottom_constraints[2] == 'params'
                bottom_flops = float(bottom_constraints[1]) if float(
                    bottom_constraints[1]) != 0 else 0
                bottom_params = float(bottom_constraints[3]) if float(
                    bottom_constraints[3]) != 0 else 0

                if opt.train_constraint_method == 'random':
                    finished = Value(c_bool, False)
                    logger.debug(
                        "===== DEBUG ======\n"
                        "Train SuperNet with FLOPs less than {} and greater than {}, "
                        "params less than {} and greater than {}.\nRandom sampling."
                        .format(upper_flops, bottom_flops, upper_params,
                                bottom_params))
                    pool_process = multiprocessing.Process(
                        target=maintain_random_pool,
                        args=[
                            cand_pool, p_lock, finished, upper_flops,
                            upper_params, bottom_flops, bottom_params
                        ])
                    pool_process.start()
                    train_epoch(pool=cand_pool,
                                pool_lock=p_lock,
                                shared_finished_flag=finished,
                                use_pool=True)
                    pool_process.join()
                elif opt.train_constraint_method == 'evolution':
                    finished = Value(c_bool, False)
                    logger.debug(
                        "===== DEBUG ======\n"
                        "Train SuperNet with FLOPs less than {} and greater than {}, "
                        "params less than {} and greater than {}.\n"
                        "Using strolling evolution to sample."
                        .format(upper_flops, bottom_flops, upper_params,
                                bottom_params))
                    pool_process = multiprocessing.Process(
                        target=evolve_maintainer.maintain,
                        args=[cand_pool, p_lock, finished, opt.dtype, logger])
                    pool_process.start()
                    train_epoch(pool=cand_pool,
                                pool_lock=p_lock,
                                shared_finished_flag=finished,
                                use_pool=True)
                    pool_process.join()
                else:
                    raise ValueError(
                        "Unrecognized training constraint method: {}".format(
                            opt.train_constraint_method))

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * num_batches / (time.time() - tic))

            err_top1_val, err_top5_val = test(ctx, val_data, epoch)

            logger.info('[Epoch %d] training: %s=%f' %
                        (epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' %
                        (epoch, throughput, time.time() - tic))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f' %
                        (epoch, err_top1_val, err_top5_val))

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                net.save_parameters(
                    '%s/%.4f-imagenet-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))
                trainer.save_states(
                    '%s/%.4f-imagenet-%s-%d-best.states' %
                    (save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch +
                                                1) % save_frequency == 0:
                net.save_parameters('%s/imagenet-%s-%d.params' %
                                    (save_dir, model_name, epoch))
                trainer.save_states('%s/imagenet-%s-%d.states' %
                                    (save_dir, model_name, epoch))

        if save_frequency and save_dir:
            net.save_parameters('%s/imagenet-%s-%d.params' %
                                (save_dir, model_name, opt.num_epochs - 1))
            trainer.save_states('%s/imagenet-%s-%d.states' %
                                (save_dir, model_name, opt.num_epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)
        if distillation:
            teacher.hybridize(static_alloc=True, static_shape=True)
    print(net)
    train(context)
def get_distribution():
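    """Sample subnets and plot their FLOPs/params distribution; with
    args.compare, contrast the plain SuperNet against the SE +
    last-conv-after-pooling variant."""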
    net = get_shufflenas_oneshot(
        use_se=args.use_se,
        last_conv_after_pooling=args.last_conv_after_pooling,
        channels_layout=args.channels_layout)
    if args.compare:
        net = get_shufflenas_oneshot()
    print(net)
    # TODO: find out why the argparse flag is not picked up; force comparison mode for now.
    args.compare = True
    print(args)
    flop_list = []
    param_list = []
    se_flop_list = []
    se_param_list = []
    pool = []
    with open('../models/lookup_table_OneShot.json', 'r') as fp:
        lookup_table = json.load(fp)
    with open('../models/lookup_table_se_lastConvAfterPooling_OneShot.json',
              'r') as fp:
        se_lookup_table = json.load(fp)

    for i in range(args.sample_count):
        candidate = dict()
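        # Cold start: sample uniformly at random until the pool holds 10
        # candidates, then switch to crossover + mutation over pool members.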
        if not args.use_evolution or len(pool) < 10:
            _, block_choices = net.random_block_choices(
                select_predefined_block=False, return_choice_list=True)
            _, channel_choices = net.random_channel_mask(
                select_all_channels=False)

        elif len(pool) < 20:
            # randomly select parents from current pool
            mother = random.choice(pool)
            father = random.choice(pool)

            # make sure mother and father are different
            while father is mother:
                mother = random.choice(pool)

            # breed block choice
            block_choices = [0] * len(father['block'])
            # Use j here to avoid shadowing the outer sample-loop variable i.
            for j in range(len(block_choices)):
                block_choices[j] = random.choice(
                    [mother['block'][j], father['block'][j]])
                # Mutation: randomly mutate some of the children.
                if random.random() < 0.3:
                    block_choices[j] = random.choice(PARAM_DICT['block'])

            # breed channel choice
            channel_choices = [0] * len(father['channel'])
            for j in range(len(channel_choices)):
                channel_choices[j] = random.choice(
                    [mother['channel'][j], father['channel'][j]])
                # Mutation: randomly mutate some of the children.
                if random.random() < 0.2:
                    channel_choices[j] = random.choice(PARAM_DICT['channel'])
            pool.pop(0)

        candidate['block'] = block_choices
        candidate['channel'] = channel_choices

        if args.compare:
            if FAST_LOOKUP:
                flops, model_size = lookup_flop_params(block_choices,
                                                       channel_choices,
                                                       lookup_table)
                se_flops, se_model_size = lookup_flop_params(
                    block_choices, channel_choices, se_lookup_table)
            else:
                flops, model_size, _, _ = \
                    get_flop_param_score(block_choices, channel_choices, use_se=False, last_conv_after_pooling=False,
                                         channels_layout=args.channels_layout)

                se_flops, se_model_size, _, _ = \
                    get_flop_param_score(block_choices, channel_choices, use_se=True, last_conv_after_pooling=True,
                                         channels_layout=args.channels_layout)

            flop_list.append(flops)
            param_list.append(model_size)
            se_flop_list.append(se_flops)
            se_param_list.append(se_model_size)
        else:
            flops, model_size, _, _ = \
                get_flop_param_score(block_choices, channel_choices, use_se=args.use_se,
                                     last_conv_after_pooling=args.last_conv_after_pooling,
                                     channels_layout=args.channels_layout)

            flop_list.append(flops)
            param_list.append(model_size)

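        # Only candidates in the target regime (<= 300 MFLOPs, <= 4.5M params)
        # join the evolution pool; when args.use_evolution is off, none do.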
        if flops > 300 or model_size > 4.5 or not args.use_evolution:
            continue
        pool.append(candidate)

    # plot
    if args.compare:
        plt.style.use("ggplot")
        fig, (axs1, axs2) = plt.subplots(1, 2, sharex=True, sharey=True)
        axs1.scatter(flop_list,
                     param_list,
                     alpha=0.8,
                     c='mediumaquamarine',
                     s=50,
                     label='subnet')
        axs1.set_title('Original SuperNet Distribution')
        axs1.legend(loc="lower right")
        axs2.scatter(se_flop_list,
                     se_param_list,
                     alpha=0.8,
                     c='mediumaquamarine',
                     s=50,
                     label='subnet')
        axs2.set_title('SE-SuperNet Distribution')
        axs2.legend(loc="lower right")
        for axs in [axs1, axs2]:
            axs.set(xlabel='FLOPs (M)', ylabel='Params (M)')
            # Hide x labels and tick labels for top plots and y ticks for right plots.
            axs.label_outer()

        plt.savefig('../images/supernet_flops_params_dist_compare.png')
        plt.show()
        plt.close()
    else:
        plt.style.use("ggplot")
        plt.figure()
        plt.scatter(flop_list,
                    param_list,
                    alpha=0.8,
                    c='mediumaquamarine',
                    s=50,
                    label='subnet')
        plt.title('FLOPs-Params Distribution')
        plt.xlabel("FLOPs (M)")
        plt.ylabel("Params (M)")
        plt.legend(loc="lower right")
        plt.savefig('../images/supernet_flops_params_dist.png')
        plt.show()
        plt.close()