def start_client(configs, addr, port):
    """Run a short SANAS client session against the server at ``(addr, port)``.

    A SANAS instance is created in client mode (``is_server=False``) with
    checkpoint saving disabled.  The client then performs two rounds; in each
    round it requests the next architectures, sleeps for one second and
    reports a fixed reward of ``0.1``.

    Args:
        configs: search-space configuration, forwarded unchanged to ``SANAS``.
        addr: address of the SANAS server.
        port: port of the SANAS server.
    """
    server_endpoint = (addr, port)
    nas_client = SANAS(configs=configs,
                       server_addr=server_endpoint,
                       save_checkpoint=None,
                       is_server=False)

    rounds_done = 0
    while rounds_done < 2:
        # The sampled architecture itself is not used; only the
        # request/reward round trip matters here.
        _first_arch = nas_client.next_archs()[0]
        time.sleep(1)
        nas_client.reward(0.1)
        rounds_done += 1
    def test_nas(self):
        """Check that every SANAS-sampled architecture has fewer FLOPs than the baseline.

        The baseline is the architecture returned by ``token2arch()`` (called
        without tokens) on a search space built from the MobileNetV2Space
        config alone.  SANAS is then started as a server over both search
        spaces and, for each search step, the first sampled architecture is
        traced on a ``[None, 3, 224, 224]`` float32 input, a reward of 1 is
        reported, and its FLOPs are asserted to be below the baseline.
        """
        mobilenet_cfg = {'input_size': 224, 'output_size': 7, 'block_num': 5}
        resnet_cfg = {'input_size': 7, 'output_size': 1, 'block_num': 2}
        configs = [('MobileNetV2Space', mobilenet_cfg),
                   ('ResNetSpace', resnet_cfg)]

        def trace_arch(arch):
            # Build `arch` into a fresh main/startup program pair and return
            # the main program so that its FLOPs can be counted.
            traced_program = fluid.Program()
            traced_startup = fluid.Program()
            with fluid.program_guard(traced_program, traced_startup):
                image = fluid.data(name="input",
                                   shape=[None, 3, 224, 224],
                                   dtype="float32")
                arch(image)
            return traced_program

        factory = SearchSpaceFactory()
        space = factory.get_search_space([('MobileNetV2Space', mobilenet_cfg)])
        origin_arch = space.token2arch()[0]
        base_flops = flops(trace_arch(origin_arch))

        search_steps = 3
        sa_nas = SANAS(configs,
                       search_steps=search_steps,
                       server_addr=("", 0),
                       is_server=True)

        for _ in range(search_steps):
            archs = sa_nas.next_archs()
            main_program = trace_arch(archs[0])
            sa_nas.reward(1)
            self.assertTrue(flops(main_program) < base_flops)
# Example #3
# 0
def main():
    """Run a SANAS architecture search that trains and evaluates each candidate.

    Settings come from the module-level ``FLAGS`` and from the config file
    named by ``FLAGS.config``.  For each of ``cfg.search_steps`` steps the
    function asks SANAS for a candidate architecture, builds a training
    program around it and computes its constraint.  Candidates whose
    constraint lies outside ``[min_constraint, max_constraint]`` are skipped
    without reporting a reward.  Other candidates are trained for
    ``cfg.max_iters`` iterations, optionally evaluated after the last
    iteration, and their AP (0 when no evaluation ran) is reported to SANAS
    as the reward.

    Raises:
        ValueError: if the config does not define ``architecture``.
    """
    env = os.environ
    # Distributed mode is on only when both trainer variables are set.
    FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
    if FLAGS.dist:
        trainer_id = int(env['PADDLE_TRAINER_ID'])
        import random
        # Seed each trainer differently but deterministically.
        local_seed = (99 + trainer_id)
        random.seed(local_seed)
        np.random.seed(local_seed)

    cfg = load_config(FLAGS.config)
    if 'architecture' in cfg:
        main_arch = cfg.architecture
    else:
        raise ValueError("'architecture' not specified in config file.")

    merge_config(FLAGS.opt)

    if 'log_iter' not in cfg:
        cfg.log_iter = 20

    # check if set use_gpu=True in paddlepaddle cpu version
    check_gpu(cfg.use_gpu)
    # check if paddlepaddle version is satisfied
    check_version()

    if cfg.use_gpu:
        devices_num = fluid.core.get_cuda_device_count()
    else:
        devices_num = int(os.environ.get('CPU_NUM', 1))

    if 'FLAGS_selected_gpus' in env:
        device_id = int(env['FLAGS_selected_gpus'])
    else:
        device_id = 0
    place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    lr_builder = create('LearningRate')
    optim_builder = create('OptimizerBuilder')

    # add NAS
    config = ([(cfg.search_space)])
    server_address = (cfg.server_ip, cfg.server_port)
    load_checkpoint = FLAGS.resume_checkpoint if FLAGS.resume_checkpoint else None
    sa_nas = SANAS(config,
                   server_addr=server_address,
                   init_temperature=cfg.init_temperature,
                   reduce_rate=cfg.reduce_rate,
                   search_steps=cfg.search_steps,
                   save_checkpoint=cfg.save_dir,
                   load_checkpoint=load_checkpoint,
                   is_server=cfg.is_server)
    start_iter = 0
    # The readers are created once and shared by every search step.
    train_reader = create_reader(cfg.TrainReader,
                                 (cfg.max_iters - start_iter) * devices_num,
                                 cfg)
    eval_reader = create_reader(cfg.EvalReader)

    constraint = create('Constraint')
    for step in range(cfg.search_steps):
        logger.info('----->>> search step: {} <<<------'.format(step))
        archs = sa_nas.next_archs()[0]

        # build program
        startup_prog = fluid.Program()
        train_prog = fluid.Program()
        with fluid.program_guard(train_prog, startup_prog):
            with fluid.unique_name.guard():
                model = create(main_arch)
                if FLAGS.fp16:
                    assert (getattr(model.backbone, 'norm_type', None)
                            != 'affine_channel'), \
                        '--fp16 currently does not support affine channel, ' \
                        ' please modify backbone settings to use batch norm'

                with mixed_precision_context(FLAGS.loss_scale,
                                             FLAGS.fp16) as ctx:
                    inputs_def = cfg['TrainReader']['inputs_def']
                    feed_vars, train_loader = model.build_inputs(**inputs_def)
                    train_fetches = archs(feed_vars, 'train', cfg)
                    loss = train_fetches['loss']
                    # Under fp16 the loss is scaled before minimize() and
                    # scaled back afterwards.
                    if FLAGS.fp16:
                        loss *= ctx.get_loss_scale_var()
                    lr = lr_builder()
                    optimizer = optim_builder(lr)
                    optimizer.minimize(loss)
                    if FLAGS.fp16:
                        loss /= ctx.get_loss_scale_var()

        current_constraint = constraint.compute_constraint(train_prog)
        logger.info('current steps: {}, constraint {}'.format(
            step, current_constraint))

        # Skip candidates outside the configured bounds; a bound of None
        # means "unbounded".  No reward is reported for a skipped candidate.
        if (constraint.max_constraint is not None
                and current_constraint > constraint.max_constraint) or (
                    constraint.min_constraint is not None
                    and current_constraint < constraint.min_constraint):
            continue

        # parse train fetches
        train_keys, train_values, _ = parse_fetches(train_fetches)
        # The learning rate is fetched last so it can be read as outs[-1].
        train_values.append(lr)

        if FLAGS.eval:
            eval_prog = fluid.Program()
            with fluid.program_guard(eval_prog, startup_prog):
                with fluid.unique_name.guard():
                    model = create(main_arch)
                    inputs_def = cfg['EvalReader']['inputs_def']
                    feed_vars, eval_loader = model.build_inputs(**inputs_def)
                    fetches = archs(feed_vars, 'eval', cfg)
            eval_prog = eval_prog.clone(True)

            eval_loader.set_sample_list_generator(eval_reader, place)
            extra_keys = ['im_id', 'im_shape', 'gt_bbox']
            eval_keys, eval_values, eval_cls = parse_fetches(
                fetches, eval_prog, extra_keys)

        # compile program for multi-devices
        build_strategy = fluid.BuildStrategy()
        build_strategy.fuse_all_optimizer_ops = False
        build_strategy.fuse_elewise_add_act_ops = True

        exec_strategy = fluid.ExecutionStrategy()
        # iteration number when CompiledProgram tries to drop local execution scopes.
        # Set it to be 1 to save memory usages, so that unused variables in
        # local execution scopes can be deleted after each iteration.
        exec_strategy.num_iteration_per_drop_scope = 1
        if FLAGS.dist:
            dist_utils.prepare_for_multi_process(exe, build_strategy,
                                                 startup_prog, train_prog)
            exec_strategy.num_threads = 1

        exe.run(startup_prog)
        compiled_train_prog = fluid.CompiledProgram(
            train_prog).with_data_parallel(loss_name=loss.name,
                                           build_strategy=build_strategy,
                                           exec_strategy=exec_strategy)
        if FLAGS.eval:
            compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)

        train_loader.set_sample_list_generator(train_reader, place)

        train_stats = TrainingStats(cfg.log_smooth_window, train_keys)
        train_loader.start()
        end_time = time.time()

        cfg_name = os.path.basename(FLAGS.config).split('.')[0]
        save_dir = os.path.join(cfg.save_dir, cfg_name)
        time_stat = deque(maxlen=cfg.log_smooth_window)
        # Reward reported to SANAS; stays 0 unless evaluation runs below.
        ap = 0
        for it in range(start_iter, cfg.max_iters):
            start_time = end_time
            end_time = time.time()
            time_stat.append(end_time - start_time)
            time_cost = np.mean(time_stat)
            eta_sec = (cfg.max_iters - it) * time_cost
            eta = str(datetime.timedelta(seconds=int(eta_sec)))
            outs = exe.run(compiled_train_prog, fetch_list=train_values)
            stats = {
                k: np.array(v).mean()
                for k, v in zip(train_keys, outs[:-1])
            }

            train_stats.update(stats)
            logs = train_stats.log()
            if it % cfg.log_iter == 0 and (not FLAGS.dist or trainer_id == 0):
                strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
                    it, np.mean(outs[-1]), logs, time_cost, eta)
                logger.info(strs)

            # Save (and optionally evaluate) only after the final iteration,
            # and only on trainer 0 when running distributed.
            if (it > 0 and it == cfg.max_iters - 1) and (not FLAGS.dist
                                                         or trainer_id == 0):
                # Given the condition above this always resolves to
                # "model_final".
                save_name = str(
                    it) if it != cfg.max_iters - 1 else "model_final"
                checkpoint.save(exe, train_prog,
                                os.path.join(save_dir, save_name))
                if FLAGS.eval:
                    # evaluation
                    results = eval_run(exe, compiled_eval_prog, eval_loader,
                                       eval_keys, eval_values, eval_cls)
                    ap = calculate_ap_py(results)

        train_loader.reset()
        # eval_loader is only created when FLAGS.eval is set; resetting it
        # unconditionally raised NameError when evaluation was disabled.
        if FLAGS.eval:
            eval_loader.reset()
        logger.info('rewards: ap is {}'.format(ap))
        sa_nas.reward(float(ap))
    current_best_tokens = sa_nas.current_info()['best_tokens']
    logger.info("All steps end, the best BlazeFace-NAS structure  is: ")
    # NOTE(review): the return value of tokens2arch is discarded here, so
    # nothing follows the log line above unless tokens2arch itself reports
    # the structure -- confirm the intent.
    sa_nas.tokens2arch(current_best_tokens)
def search(config, args, image_size, is_server=True):
    """Run a SANAS search loop that trains and validates each sampled architecture.

    For every search step the first architecture returned by
    ``sa_nas.next_archs()`` is built into a training program.  Candidates
    whose parameter size (as reported by ``count_parameters_in_MB``) exceeds
    3.77 are skipped without reporting a reward.  Other candidates are
    trained and validated for ``args.retain_epoch`` epochs, and the mean of
    the last two validation top-1 values (or the single value when only one
    epoch ran) is reported to SANAS as the reward.

    Args:
        config: search-space configuration, forwarded unchanged to ``SANAS``.
        args: options object; this function reads ``server_address``,
            ``port``, ``search_steps``, ``use_gpu``, ``batch_size`` and
            ``retain_epoch`` and forwards ``args`` to the helper functions.
        image_size: height and width of the square, 3-channel input image.
        is_server: when true start SANAS as a server, otherwise as a client.

    Raises:
        ValueError: if a candidate is trained with ``args.retain_epoch < 1``,
            so that no validation result exists to build a reward from.
    """
    if is_server:
        ### start a server and a client
        sa_nas = SANAS(
            config,
            server_addr=(args.server_address, args.port),
            search_steps=args.search_steps,
            is_server=True)
    else:
        ### start a client
        # NOTE(review): `init_temperature` is not defined inside this
        # function; it must exist at module level or this branch raises
        # NameError -- confirm.  Unlike the server branch, search_steps is
        # not passed here.
        sa_nas = SANAS(
            config,
            server_addr=(args.server_address, args.port),
            init_temperature=init_temperature,
            is_server=False)

    image_shape = [3, image_size, image_size]
    for step in range(args.search_steps):
        archs = sa_nas.next_archs()[0]

        train_program = fluid.Program()
        test_program = fluid.Program()
        startup_program = fluid.Program()
        train_fetch_list, train_loader = build_program(
            train_program,
            startup_program,
            image_shape,
            archs,
            args,
            is_train=True)

        current_params = count_parameters_in_MB(
            train_program.global_block().all_parameters(), 'cifar10')
        _logger.info('step: {}, current_params: {}M'.format(step,
                                                            current_params))
        # Candidates over the parameter budget are not trained.
        if current_params > 3.77:
            continue

        test_fetch_list, test_loader = build_program(
            test_program,
            startup_program,
            image_shape,
            archs,
            args,
            is_train=False)
        test_program = test_program.clone(for_test=True)

        place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_program)

        train_reader = reader.train_valid(
            batch_size=args.batch_size, is_train=True, is_shuffle=True)
        test_reader = reader.train_valid(
            batch_size=args.batch_size, is_train=False, is_shuffle=False)

        train_loader.set_batch_generator(train_reader, places=place)
        test_loader.set_batch_generator(test_reader, places=place)

        build_strategy = fluid.BuildStrategy()
        # The first training fetch is used as the loss for data parallelism.
        train_compiled_program = fluid.CompiledProgram(
            train_program).with_data_parallel(
                loss_name=train_fetch_list[0].name,
                build_strategy=build_strategy)

        valid_top1_list = []
        for epoch_id in range(args.retain_epoch):
            train_top1 = train(train_compiled_program, exe, epoch_id,
                               train_loader, train_fetch_list, args)
            _logger.info("TRAIN: step: {}, Epoch {}, train_acc {:.6f}".format(
                step, epoch_id, train_top1))
            valid_top1 = valid(test_program, exe, epoch_id, test_loader,
                               test_fetch_list, args)
            _logger.info("TEST: Epoch {}, valid_acc {:.6f}".format(epoch_id,
                                                                   valid_top1))
            valid_top1_list.append(valid_top1)

        # Reward is the mean of the last two validation accuracies.  With a
        # single epoch there is only one value; indexing [-2] used to raise
        # IndexError in that case.
        if len(valid_top1_list) >= 2:
            step_reward = float(valid_top1_list[-1] + valid_top1_list[-2]) / 2
        elif valid_top1_list:
            step_reward = float(valid_top1_list[-1])
        else:
            raise ValueError(
                'args.retain_epoch must be at least 1 to compute a reward, '
                'got {}'.format(args.retain_epoch))
        sa_nas.reward(step_reward)
def search_mobilenetv2_block(config, args, image_size):
    """Search MobileNetV2 blocks with SANAS, training each candidate briefly.

    For every search step the first architecture from ``sa_nas.next_archs()``
    is placed between a fixed stem convolution and a fixed head (1x1
    convolution, global average pooling, fully connected layer).  Candidates
    whose FLOPs exceed 321208544 are skipped without reporting a reward.
    Other candidates are trained for ``args.retain_epoch`` epochs and then
    evaluated; the mean top-1 accuracy over the test batches is reported to
    SANAS as the reward.

    Args:
        config: search-space configuration, forwarded unchanged to ``SANAS``.
        args: options object; this function reads ``is_server``,
            ``server_address``, ``port``, ``search_steps``, ``class_dim``,
            ``use_gpu``, ``data``, ``batch_size`` and ``retain_epoch``.
        image_size: height and width of the square, 3-channel input image.

    Raises:
        ValueError: if a candidate is about to be trained and ``args.data``
            is neither ``'cifar10'`` nor ``'imagenet'``.
    """
    image_shape = [3, image_size, image_size]
    # Server and client are configured identically apart from is_server.
    sa_nas = SANAS(config,
                   server_addr=(args.server_address, args.port),
                   search_steps=args.search_steps,
                   is_server=bool(args.is_server))

    # Candidates above this FLOPs budget are not trained.
    max_flops = 321208544

    for step in range(args.search_steps):
        archs = sa_nas.next_archs()[0]

        train_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(train_program, startup_program):
            train_loader, data, label = create_data_loader(image_shape)
            data = conv_bn_layer(input=data,
                                 num_filters=32,
                                 filter_size=3,
                                 stride=2,
                                 padding='SAME',
                                 act='relu6',
                                 name='mobilenetv2_conv1')
            data = archs(data)[0]
            data = conv_bn_layer(input=data,
                                 num_filters=1280,
                                 filter_size=1,
                                 stride=1,
                                 padding='SAME',
                                 act='relu6',
                                 name='mobilenetv2_last_conv')
            data = fluid.layers.pool2d(input=data,
                                       pool_size=7,
                                       pool_stride=1,
                                       pool_type='avg',
                                       global_pooling=True,
                                       name='mobilenetv2_last_pool')
            output = fluid.layers.fc(
                input=data,
                size=args.class_dim,
                param_attr=ParamAttr(name='mobilenetv2_fc_weights'),
                bias_attr=ParamAttr(name='mobilenetv2_fc_offset'))

            softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)
            cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
            avg_cost = fluid.layers.mean(cost)
            acc_top1 = fluid.layers.accuracy(input=softmax_out,
                                             label=label,
                                             k=1)
            acc_top5 = fluid.layers.accuracy(input=softmax_out,
                                             label=label,
                                             k=5)
            # Clone for testing before the optimizer ops are appended to the
            # training program.
            test_program = train_program.clone(for_test=True)

            optimizer = fluid.optimizer.Momentum(
                learning_rate=0.1,
                momentum=0.9,
                regularization=fluid.regularizer.L2Decay(1e-4))
            optimizer.minimize(avg_cost)

        current_flops = flops(train_program)
        print('step: {}, current_flops: {}'.format(step, current_flops))
        if current_flops > max_flops:
            continue

        place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_program)

        if args.data == 'cifar10':
            train_reader = paddle.fluid.io.batch(paddle.reader.shuffle(
                paddle.dataset.cifar.train10(cycle=False), buf_size=1024),
                                                 batch_size=args.batch_size,
                                                 drop_last=True)

            test_reader = paddle.fluid.io.batch(
                paddle.dataset.cifar.test10(cycle=False),
                batch_size=args.batch_size,
                drop_last=False)
        elif args.data == 'imagenet':
            train_reader = paddle.fluid.io.batch(imagenet_reader.train(),
                                                 batch_size=args.batch_size,
                                                 drop_last=True)
            test_reader = paddle.fluid.io.batch(imagenet_reader.val(),
                                                batch_size=args.batch_size,
                                                drop_last=False)
        else:
            # Previously an unknown dataset left the readers unbound and
            # failed below with an opaque UnboundLocalError.
            raise ValueError(
                "Unsupported dataset {!r}; expected 'cifar10' or "
                "'imagenet'.".format(args.data))

        test_loader, _, _ = create_data_loader(image_shape)
        train_loader.set_sample_list_generator(
            train_reader,
            places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
        test_loader.set_sample_list_generator(test_reader, places=place)

        build_strategy = fluid.BuildStrategy()
        train_compiled_program = fluid.CompiledProgram(
            train_program).with_data_parallel(loss_name=avg_cost.name,
                                              build_strategy=build_strategy)

        train_fetches = [avg_cost.name]
        for epoch_id in range(args.retain_epoch):
            for batch_id, feed_batch in enumerate(train_loader()):
                s_time = time.time()
                outs = exe.run(train_compiled_program,
                               feed=feed_batch,
                               fetch_list=train_fetches)[0]
                batch_time = time.time() - s_time
                if batch_id % 10 == 0:
                    _logger.info(
                        'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'
                        .format(step, epoch_id, batch_id, outs[0], batch_time))

        test_fetches = [avg_cost.name, acc_top1.name, acc_top5.name]
        reward = []
        for batch_id, feed_batch in enumerate(test_loader()):
            batch_reward = exe.run(test_program,
                                   feed=feed_batch,
                                   fetch_list=test_fetches)
            # One mean per fetched value, in test_fetches order.
            reward_avg = np.mean(np.array(batch_reward), axis=1)
            reward.append(reward_avg)

            _logger.info(
                'TEST: step: {}, batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'
                .format(step, batch_id, batch_reward[0], batch_reward[1],
                        batch_reward[2]))

        # Average the per-batch means over all test batches.
        finally_reward = np.mean(np.array(reward), axis=0)
        _logger.info(
            'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'.format(
                finally_reward[0], finally_reward[1], finally_reward[2]))

        # Index 1 is the top-1 accuracy.
        sa_nas.reward(float(finally_reward[1]))
def search_mobilenetv2(config, args, image_size, is_server=True):
    """Search MobileNetV2 architectures with SANAS, training each candidate briefly.

    For every search step the first architecture from ``sa_nas.next_archs()``
    is built into a training program via ``build_program``.  Candidates whose
    FLOPs exceed 321208544 are skipped without reporting a reward.  Other
    candidates are trained for ``args.retain_epoch`` epochs and then
    evaluated; the mean top-1 accuracy over the test batches is reported to
    SANAS as the reward.

    Args:
        config: search-space configuration, forwarded unchanged to ``SANAS``.
        args: options object; this function reads ``server_address``,
            ``port``, ``search_steps``, ``use_gpu``, ``data``, ``batch_size``
            and ``retain_epoch`` and forwards ``args`` to ``build_program``.
        image_size: height and width of the square, 3-channel input image.
        is_server: when true start SANAS as a server, otherwise as a client.

    Raises:
        ValueError: if a candidate is about to be trained and ``args.data``
            is neither ``'cifar10'`` nor ``'imagenet'``.
    """
    # Server and client are configured identically apart from is_server.
    sa_nas = SANAS(config,
                   server_addr=(args.server_address, args.port),
                   search_steps=args.search_steps,
                   is_server=bool(is_server))

    # Candidates above this FLOPs budget are not trained.
    max_flops = 321208544

    image_shape = [3, image_size, image_size]
    for step in range(args.search_steps):
        archs = sa_nas.next_archs()[0]

        train_program = fluid.Program()
        test_program = fluid.Program()
        startup_program = fluid.Program()
        train_loader, avg_cost, acc_top1, acc_top5 = build_program(
            train_program, startup_program, image_shape, archs, args)

        current_flops = flops(train_program)
        print('step: {}, current_flops: {}'.format(step, current_flops))
        if current_flops > max_flops:
            continue

        test_loader, test_avg_cost, test_acc_top1, test_acc_top5 = build_program(
            test_program,
            startup_program,
            image_shape,
            archs,
            args,
            is_test=True)
        test_program = test_program.clone(for_test=True)

        place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_program)

        if args.data == 'cifar10':
            train_reader = paddle.batch(paddle.reader.shuffle(
                paddle.dataset.cifar.train10(cycle=False), buf_size=1024),
                                        batch_size=args.batch_size,
                                        drop_last=True)

            test_reader = paddle.batch(
                paddle.dataset.cifar.test10(cycle=False),
                batch_size=args.batch_size,
                drop_last=False)
        elif args.data == 'imagenet':
            train_reader = paddle.batch(imagenet_reader.train(),
                                        batch_size=args.batch_size,
                                        drop_last=True)
            test_reader = paddle.batch(imagenet_reader.val(),
                                       batch_size=args.batch_size,
                                       drop_last=False)
        else:
            # Previously an unknown dataset left the readers unbound and
            # failed below with an opaque UnboundLocalError.
            raise ValueError(
                "Unsupported dataset {!r}; expected 'cifar10' or "
                "'imagenet'.".format(args.data))

        train_loader.set_sample_list_generator(
            train_reader,
            places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
        test_loader.set_sample_list_generator(test_reader, places=place)

        build_strategy = fluid.BuildStrategy()
        train_compiled_program = fluid.CompiledProgram(
            train_program).with_data_parallel(loss_name=avg_cost.name,
                                              build_strategy=build_strategy)

        train_fetches = [avg_cost.name]
        for epoch_id in range(args.retain_epoch):
            for batch_id, feed_batch in enumerate(train_loader()):
                s_time = time.time()
                outs = exe.run(train_compiled_program,
                               feed=feed_batch,
                               fetch_list=train_fetches)[0]
                batch_time = time.time() - s_time
                if batch_id % 10 == 0:
                    _logger.info(
                        'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'
                        .format(step, epoch_id, batch_id, outs[0], batch_time))

        test_fetches = [
            test_avg_cost.name, test_acc_top1.name, test_acc_top5.name
        ]
        reward = []
        for batch_id, feed_batch in enumerate(test_loader()):
            batch_reward = exe.run(test_program,
                                   feed=feed_batch,
                                   fetch_list=test_fetches)
            # One mean per fetched value, in test_fetches order.
            reward_avg = np.mean(np.array(batch_reward), axis=1)
            reward.append(reward_avg)

            _logger.info(
                'TEST: step: {}, batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'
                .format(step, batch_id, batch_reward[0], batch_reward[1],
                        batch_reward[2]))

        # Average the per-batch means over all test batches.
        finally_reward = np.mean(np.array(reward), axis=0)
        _logger.info(
            'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'.format(
                finally_reward[0], finally_reward[1], finally_reward[2]))

        # Index 1 is the top-1 accuracy.
        sa_nas.reward(float(finally_reward[1]))
# Example #7
# 0
class TestSANAS(unittest.TestCase):
    """Tests for the SANAS interface on a single-block MobileNetV2 search space.

    Each test starts a SANAS instance on a random port and checks that the
    networks built by ``next_archs`` and ``tokens2arch`` match the
    convolution count and channel numbers implied by the current tokens.
    """

    def setUp(self):
        """Create the SANAS instance under test on a random port."""
        self.init_test_case()
        port = np.random.randint(8337, 8773)
        self.sanas = SANAS(configs=self.configs,
                           server_addr=("", port),
                           save_checkpoint=None)

    def init_test_case(self):
        """Define the search-space config and the token lookup tables."""
        self.configs = [('MobileNetV2BlockSpace', {'block_mask': [0]})]
        # Tables indexed by token value in check_chnum_convnum.
        self.filter_num = np.array([
            3, 4, 8, 12, 16, 24, 32, 48, 64, 80, 96, 128, 144, 160, 192, 224,
            256, 320, 384, 512
        ])
        self.k_size = np.array([3, 5])
        self.multiply = np.array([1, 2, 3, 4, 5, 6])
        self.repeat = np.array([1, 2, 3, 4, 5, 6])

    def check_chnum_convnum(self, program):
        """Assert that `program` matches the structure implied by the current tokens.

        Token 0 selects the channel expansion factor, token 1 the filter
        number and token 2 the repeat count.  Three convolutions are expected
        per repeat, with channel numbers (expanded, expanded, filter_num).

        Args:
            program: the program that the architecture was built into.
        """
        current_tokens = self.sanas.current_info()['current_tokens']
        channel_exp = self.multiply[current_tokens[0]]
        filter_num = self.filter_num[current_tokens[1]]
        repeat_num = self.repeat[current_tokens[2]]

        conv_list, ch_pro = compute_op_num(program)
        ### assert conv number
        self.assertTrue((repeat_num * 3) == len(
            conv_list
        ), "the number of conv is NOT match, the number compute from token: {}, actual conv number: {}"
                        .format(repeat_num * 3, len(conv_list)))

        ### assert number of channels
        ch_token = []
        # The first repeat expands from 32 input channels; each later repeat
        # expands from the previous repeat's filter number.
        init_ch_num = 32
        for _ in range(repeat_num):
            ch_token.append(init_ch_num * channel_exp)
            ch_token.append(init_ch_num * channel_exp)
            ch_token.append(filter_num)
            init_ch_num = filter_num

        self.assertTrue(
            str(ch_token) == str(ch_pro),
            "channel num is WRONG, channel num from token is {}, channel num come fom program is {}"
            .format(str(ch_token), str(ch_pro)))

    def test_all_function(self):
        """Exercise next_archs, reward, tokens2arch and current_info in turn."""
        ### unittest for next_archs
        next_program = fluid.Program()
        startup_program = fluid.Program()
        token2arch_program = fluid.Program()

        with fluid.program_guard(next_program, startup_program):
            inputs = fluid.data(name='input',
                                shape=[None, 3, 32, 32],
                                dtype='float32')
            archs = self.sanas.next_archs()
            for arch in archs:
                output = arch(inputs)
                inputs = output
        self.check_chnum_convnum(next_program)

        ### unittest for reward
        self.assertTrue(self.sanas.reward(float(1.0)), "reward is False")

        ### unittest for tokens2arch
        with fluid.program_guard(token2arch_program, startup_program):
            inputs = fluid.data(name='input',
                                shape=[None, 3, 32, 32],
                                dtype='float32')
            # Build from the tokens2arch result.  Previously the result was
            # bound to `arch` and immediately overwritten by the loop over
            # the stale `archs` from next_archs, so tokens2arch was never
            # actually checked.
            # NOTE(review): assumes tokens2arch returns an iterable of
            # architecture callables, like next_archs -- confirm.
            token_archs = self.sanas.tokens2arch(
                self.sanas.current_info()['current_tokens'])
            for arch in token_archs:
                output = arch(inputs)
                inputs = output
        self.check_chnum_convnum(token2arch_program)

        ### unittest for current_info
        current_info = self.sanas.current_info()
        self.assertTrue(
            isinstance(current_info, dict),
            "the type of current info must be dict, but now is {}".format(
                type(current_info)))
def search_mobilenetv2(config, args, image_size, is_server=True):
    """Search MobileNetV2 architectures with SANAS, training each candidate briefly.

    The dataset named by ``args.data`` is loaded once.  For every search step
    the first architecture from ``sa_nas.next_archs()`` is built into a
    training program via ``build_program``.  Candidates whose FLOPs exceed
    321208544 are skipped without reporting a reward.  Other candidates are
    trained for ``args.retain_epoch`` epochs and then evaluated; the mean
    top-1 accuracy over the test batches is reported to SANAS as the reward.

    Args:
        config: search-space configuration, forwarded unchanged to ``SANAS``.
        args: options object; this function reads ``data``, ``use_gpu``,
            ``server_address``, ``port``, ``search_steps`` and
            ``retain_epoch`` and forwards ``args`` to ``build_program``.
        image_size: height and width of the square, 3-channel input image.
        is_server: when true start SANAS as a server, otherwise as a client.

    Raises:
        ValueError: if ``args.data`` is neither ``'cifar10'`` nor
            ``'imagenet'``.
    """
    image_shape = [3, image_size, image_size]
    if args.data == 'cifar10':
        transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
        train_dataset = paddle.vision.datasets.Cifar10(mode='train',
                                                       transform=transform,
                                                       backend='cv2')
        val_dataset = paddle.vision.datasets.Cifar10(mode='test',
                                                     transform=transform,
                                                     backend='cv2')

    elif args.data == 'imagenet':
        train_dataset = imagenet_reader.ImageNetDataset(mode='train')
        val_dataset = imagenet_reader.ImageNetDataset(mode='val')
    else:
        # Previously an unknown dataset left the datasets unbound and failed
        # later, inside the search loop, with an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'cifar10' or "
            "'imagenet'.".format(args.data))

    places = static.cuda_places() if args.use_gpu else static.cpu_places()
    # Training is given every place; evaluation and the executor use only
    # the first one.
    place = places[0]

    # Server and client are configured identically apart from is_server.
    sa_nas = SANAS(config,
                   server_addr=(args.server_address, args.port),
                   search_steps=args.search_steps,
                   is_server=bool(is_server))

    # Candidates above this FLOPs budget are not trained.
    max_flops = 321208544

    for step in range(args.search_steps):
        archs = sa_nas.next_archs()[0]

        train_program = static.Program()
        test_program = static.Program()
        startup_program = static.Program()
        train_loader, avg_cost, acc_top1, acc_top5 = build_program(
            train_program, startup_program, image_shape, train_dataset, archs,
            args, places)

        current_flops = flops(train_program)
        print('step: {}, current_flops: {}'.format(step, current_flops))
        if current_flops > max_flops:
            continue

        test_loader, test_avg_cost, test_acc_top1, test_acc_top5 = build_program(
            test_program,
            startup_program,
            image_shape,
            val_dataset,
            archs,
            args,
            place,
            is_test=True)
        test_program = test_program.clone(for_test=True)

        exe = static.Executor(place)
        exe.run(startup_program)

        build_strategy = static.BuildStrategy()
        train_compiled_program = static.CompiledProgram(
            train_program).with_data_parallel(loss_name=avg_cost.name,
                                              build_strategy=build_strategy)

        train_fetches = [avg_cost.name]
        for epoch_id in range(args.retain_epoch):
            for batch_id, feed_batch in enumerate(train_loader()):
                s_time = time.time()
                outs = exe.run(train_compiled_program,
                               feed=feed_batch,
                               fetch_list=train_fetches)[0]
                batch_time = time.time() - s_time
                if batch_id % 10 == 0:
                    _logger.info(
                        'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'
                        .format(step, epoch_id, batch_id, outs[0], batch_time))

        test_fetches = [
            test_avg_cost.name, test_acc_top1.name, test_acc_top5.name
        ]
        reward = []
        for batch_id, feed_batch in enumerate(test_loader()):
            batch_reward = exe.run(test_program,
                                   feed=feed_batch,
                                   fetch_list=test_fetches)
            # One mean per fetched value, in test_fetches order.
            reward_avg = np.mean(np.array(batch_reward), axis=1)
            reward.append(reward_avg)

            _logger.info(
                'TEST: step: {}, batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'
                .format(step, batch_id, batch_reward[0], batch_reward[1],
                        batch_reward[2]))

        # Average the per-batch means over all test batches.
        finally_reward = np.mean(np.array(reward), axis=0)
        _logger.info(
            'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'.format(
                finally_reward[0], finally_reward[1], finally_reward[2]))

        # Index 1 is the top-1 accuracy.
        sa_nas.reward(float(finally_reward[1]))
def search_mobilenetv2_block(config, args, image_size):
    """Search MobileNetV2 block architectures with SANAS and score each one.

    For every search step an architecture is requested from SANAS, wrapped
    between a fixed stem convolution and a fixed 1x1 conv + pooling + FC head,
    trained for ``args.retain_epoch`` epochs, evaluated on the validation set,
    and its mean top-1 accuracy is reported back through ``sa_nas.reward``.

    Args:
        config: Search-space configuration, passed unchanged to ``SANAS``.
        args: Namespace-like object. Attributes read here: ``data``,
            ``use_gpu``, ``is_server``, ``server_address``, ``port``,
            ``search_steps``, ``batch_size``, ``class_dim``, ``retain_epoch``.
        image_size: Height and width of the (square) 3-channel input image.

    Raises:
        ValueError: If ``args.data`` is neither ``'cifar10'`` nor
            ``'imagenet'``.
    """
    # Fail early with a clear message; otherwise the dataset variables below
    # would stay unbound and only blow up when the data loaders are built.
    if args.data not in ('cifar10', 'imagenet'):
        raise ValueError(
            "unsupported dataset {!r}; expected 'cifar10' or 'imagenet'".
            format(args.data))

    image_shape = [3, image_size, image_size]
    transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
    if args.data == 'cifar10':
        train_dataset = paddle.vision.datasets.Cifar10(mode='train',
                                                       transform=transform,
                                                       backend='cv2')
        val_dataset = paddle.vision.datasets.Cifar10(mode='test',
                                                     transform=transform,
                                                     backend='cv2')
    else:
        train_dataset = imagenet_reader.ImageNetDataset(mode='train')
        val_dataset = imagenet_reader.ImageNetDataset(mode='val')

    places = static.cuda_places() if args.use_gpu else static.cpu_places()
    place = places[0]

    # Server and client instances differ only in the ``is_server`` flag.
    sa_nas = SANAS(config,
                   server_addr=(args.server_address, args.port),
                   search_steps=args.search_steps,
                   is_server=bool(args.is_server))

    # NOTE(review): hard-coded FLOPs budget; the origin of this value is not
    # visible in this file section -- confirm before changing.
    max_flops = 321208544

    for step in range(args.search_steps):
        archs = sa_nas.next_archs()[0]

        train_program = static.Program()
        test_program = static.Program()
        startup_program = static.Program()
        with static.program_guard(train_program, startup_program):
            data_shape = [None] + image_shape
            data = static.data(name='data', shape=data_shape, dtype='float32')
            label = static.data(name='label', shape=[None, 1], dtype='int64')
            if args.data == 'cifar10':
                paddle.assign(paddle.reshape(label, [-1, 1]), label)
            train_loader = paddle.io.DataLoader(train_dataset,
                                                places=places,
                                                feed_list=[data, label],
                                                drop_last=True,
                                                batch_size=args.batch_size,
                                                return_list=False,
                                                shuffle=True,
                                                use_shared_memory=True,
                                                num_workers=4)
            val_loader = paddle.io.DataLoader(val_dataset,
                                              places=place,
                                              feed_list=[data, label],
                                              drop_last=False,
                                              batch_size=args.batch_size,
                                              return_list=False,
                                              shuffle=False)

            # Fixed stem convolution in front of the searched blocks.
            data = conv_bn_layer(input=data,
                                 num_filters=32,
                                 filter_size=3,
                                 stride=2,
                                 padding='SAME',
                                 act='relu6',
                                 name='mobilenetv2_conv1')
            # Searched part of the network.
            data = archs(data)[0]
            # Fixed head: 1x1 conv, global average pooling, classifier.
            data = conv_bn_layer(input=data,
                                 num_filters=1280,
                                 filter_size=1,
                                 stride=1,
                                 padding='SAME',
                                 act='relu6',
                                 name='mobilenetv2_last_conv')
            data = F.adaptive_avg_pool2d(data,
                                         output_size=[1, 1],
                                         name='mobilenetv2_last_pool')
            output = static.nn.fc(
                x=data,
                size=args.class_dim,
                weight_attr=ParamAttr(name='mobilenetv2_fc_weights'),
                bias_attr=ParamAttr(name='mobilenetv2_fc_offset'))

            softmax_out = F.softmax(output)
            # F.cross_entropy applies softmax itself by default, so it must be
            # fed the raw logits; passing softmax_out would apply softmax
            # twice and flatten the loss.
            cost = F.cross_entropy(output, label=label)
            avg_cost = paddle.mean(cost)
            acc_top1 = paddle.metric.accuracy(input=softmax_out,
                                              label=label,
                                              k=1)
            acc_top5 = paddle.metric.accuracy(input=softmax_out,
                                              label=label,
                                              k=5)
            # Clone before the optimizer ops are appended so the test program
            # contains only the forward pass.
            test_program = train_program.clone(for_test=True)

            optimizer = paddle.optimizer.Momentum(
                learning_rate=0.1,
                momentum=0.9,
                weight_decay=paddle.regularizer.L2Decay(1e-4))
            optimizer.minimize(avg_cost)

        current_flops = flops(train_program)
        print('step: {}, current_flops: {}'.format(step, current_flops))
        if current_flops > max_flops:
            # Over budget: skip training; no reward is reported for this step.
            continue

        exe = static.Executor(place)
        exe.run(startup_program)

        build_strategy = static.BuildStrategy()
        train_compiled_program = static.CompiledProgram(
            train_program).with_data_parallel(loss_name=avg_cost.name,
                                              build_strategy=build_strategy)

        train_fetches = [avg_cost.name]
        for epoch_id in range(args.retain_epoch):
            for batch_id, feed_data in enumerate(train_loader()):
                s_time = time.time()
                outs = exe.run(train_compiled_program,
                               feed=feed_data,
                               fetch_list=train_fetches)[0]
                # The log line labels this value as milliseconds.
                batch_time = (time.time() - s_time) * 1000
                if batch_id % 10 == 0:
                    _logger.info(
                        'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'
                        .format(step, epoch_id, batch_id, outs[0], batch_time))

        test_fetches = [avg_cost.name, acc_top1.name, acc_top5.name]
        reward = []
        for batch_id, feed_data in enumerate(val_loader()):
            batch_reward = exe.run(test_program,
                                   feed=feed_data,
                                   fetch_list=test_fetches)
            # One averaged value per fetched metric: (cost, top1, top5).
            reward_avg = np.mean(np.array(batch_reward), axis=1)
            reward.append(reward_avg)

            _logger.info(
                'TEST: step: {}, batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'
                .format(step, batch_id, batch_reward[0], batch_reward[1],
                        batch_reward[2]))

        # Average the per-batch metrics over the whole validation set.
        finally_reward = np.mean(np.array(reward), axis=0)
        _logger.info(
            'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'.format(
                finally_reward[0], finally_reward[1], finally_reward[2]))

        # Top-1 accuracy is the score fed back to the search.
        sa_nas.reward(float(finally_reward[1]))