def main():
    """Run the SANAS architecture search loop.

    For every search step the function asks ``SANAS`` for the next candidate
    architecture, builds a training program for it, skips the candidate when
    its computed constraint falls outside ``[min_constraint, max_constraint]``,
    trains it for ``cfg.max_iters`` iterations, saves the final checkpoint,
    optionally evaluates it, and reports the resulting AP back to ``SANAS`` as
    the reward.  After the last step the best tokens found are turned back
    into an architecture.

    Reads the module-level ``FLAGS`` (config, opt, fp16, loss_scale, eval,
    resume_checkpoint) and sets ``FLAGS.dist``.

    Raises:
        ValueError: if the loaded config has no ``architecture`` entry.
    """
    env = os.environ
    # Distributed mode is detected purely from the launcher's environment.
    FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
    if FLAGS.dist:
        trainer_id = int(env['PADDLE_TRAINER_ID'])
        import random
        # Give every trainer its own deterministic seed.
        local_seed = 99 + trainer_id
        random.seed(local_seed)
        np.random.seed(local_seed)

    # Only the first trainer (or a non-distributed run) logs, saves and
    # rewards.  ``trainer_id`` is only bound in distributed mode, hence the
    # short-circuit on ``FLAGS.dist``.
    is_chief = (not FLAGS.dist) or trainer_id == 0

    cfg = load_config(FLAGS.config)
    if 'architecture' in cfg:
        main_arch = cfg.architecture
    else:
        raise ValueError("'architecture' not specified in config file.")

    merge_config(FLAGS.opt)

    if 'log_iter' not in cfg:
        cfg.log_iter = 20

    # check if set use_gpu=True in paddlepaddle cpu version
    check_gpu(cfg.use_gpu)
    # check if paddlepaddle version is satisfied
    check_version()

    if cfg.use_gpu:
        devices_num = fluid.core.get_cuda_device_count()
    else:
        devices_num = int(env.get('CPU_NUM', 1))

    if 'FLAGS_selected_gpus' in env:
        device_id = int(env['FLAGS_selected_gpus'])
    else:
        device_id = 0
    place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    lr_builder = create('LearningRate')
    optim_builder = create('OptimizerBuilder')

    # add NAS
    config = [cfg.search_space]
    server_address = (cfg.server_ip, cfg.server_port)
    # Any falsy value (empty string, None) means "do not resume".
    load_checkpoint = FLAGS.resume_checkpoint or None
    sa_nas = SANAS(
        config,
        server_addr=server_address,
        init_temperature=cfg.init_temperature,
        reduce_rate=cfg.reduce_rate,
        search_steps=cfg.search_steps,
        save_checkpoint=cfg.save_dir,
        load_checkpoint=load_checkpoint,
        is_server=cfg.is_server)

    start_iter = 0
    train_reader = create_reader(
        cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num, cfg)
    eval_reader = create_reader(cfg.EvalReader)

    constraint = create('Constraint')

    # Checkpoint location does not depend on the search step.
    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(cfg.save_dir, cfg_name)

    for step in range(cfg.search_steps):
        logger.info('----->>> search step: {} <<<------'.format(step))
        archs = sa_nas.next_archs()[0]

        # build program
        startup_prog = fluid.Program()
        train_prog = fluid.Program()
        with fluid.program_guard(train_prog, startup_prog):
            with fluid.unique_name.guard():
                model = create(main_arch)
                if FLAGS.fp16:
                    assert (getattr(model.backbone, 'norm_type', None)
                            != 'affine_channel'), \
                        '--fp16 currently does not support affine channel, ' \
                        ' please modify backbone settings to use batch norm'

                with mixed_precision_context(FLAGS.loss_scale,
                                             FLAGS.fp16) as ctx:
                    inputs_def = cfg['TrainReader']['inputs_def']
                    feed_vars, train_loader = model.build_inputs(**inputs_def)
                    train_fetches = archs(feed_vars, 'train', cfg)
                    loss = train_fetches['loss']
                    # Scale the loss for the backward pass, then undo the
                    # scaling so the fetched loss is reported unscaled.
                    if FLAGS.fp16:
                        loss *= ctx.get_loss_scale_var()
                    lr = lr_builder()
                    optimizer = optim_builder(lr)
                    optimizer.minimize(loss)
                    if FLAGS.fp16:
                        loss /= ctx.get_loss_scale_var()

        current_constraint = constraint.compute_constraint(train_prog)
        logger.info('current steps: {}, constraint {}'.format(
            step, current_constraint))

        # Skip candidates outside the configured constraint window; a bound
        # set to None is treated as "no limit".
        too_large = (constraint.max_constraint is not None and
                     current_constraint > constraint.max_constraint)
        too_small = (constraint.min_constraint is not None and
                     current_constraint < constraint.min_constraint)
        if too_large or too_small:
            continue

        # parse train fetches
        train_keys, train_values, _ = parse_fetches(train_fetches)
        # The learning rate is fetched last so it can be read as outs[-1].
        train_values.append(lr)

        if FLAGS.eval:
            eval_prog = fluid.Program()
            with fluid.program_guard(eval_prog, startup_prog):
                with fluid.unique_name.guard():
                    model = create(main_arch)
                    inputs_def = cfg['EvalReader']['inputs_def']
                    feed_vars, eval_loader = model.build_inputs(**inputs_def)
                    fetches = archs(feed_vars, 'eval', cfg)
            eval_prog = eval_prog.clone(True)

            eval_loader.set_sample_list_generator(eval_reader, place)
            extra_keys = ['im_id', 'im_shape', 'gt_bbox']
            eval_keys, eval_values, eval_cls = parse_fetches(
                fetches, eval_prog, extra_keys)

        # compile program for multi-devices
        build_strategy = fluid.BuildStrategy()
        build_strategy.fuse_all_optimizer_ops = False
        build_strategy.fuse_elewise_add_act_ops = True

        exec_strategy = fluid.ExecutionStrategy()
        # iteration number when CompiledProgram tries to drop local execution
        # scopes. Set it to be 1 to save memory usages, so that unused
        # variables in local execution scopes can be deleted after each
        # iteration.
        exec_strategy.num_iteration_per_drop_scope = 1
        if FLAGS.dist:
            dist_utils.prepare_for_multi_process(exe, build_strategy,
                                                 startup_prog, train_prog)
            exec_strategy.num_threads = 1

        exe.run(startup_prog)
        compiled_train_prog = fluid.CompiledProgram(
            train_prog).with_data_parallel(
                loss_name=loss.name,
                build_strategy=build_strategy,
                exec_strategy=exec_strategy)

        if FLAGS.eval:
            compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)

        train_loader.set_sample_list_generator(train_reader, place)

        train_stats = TrainingStats(cfg.log_smooth_window, train_keys)
        train_loader.start()
        end_time = time.time()

        time_stat = deque(maxlen=cfg.log_smooth_window)
        # Reward used when evaluation is disabled.
        ap = 0
        for it in range(start_iter, cfg.max_iters):
            start_time = end_time
            end_time = time.time()
            time_stat.append(end_time - start_time)
            time_cost = np.mean(time_stat)
            eta_sec = (cfg.max_iters - it) * time_cost
            eta = str(datetime.timedelta(seconds=int(eta_sec)))
            outs = exe.run(compiled_train_prog, fetch_list=train_values)
            stats = {
                k: np.array(v).mean()
                for k, v in zip(train_keys, outs[:-1])
            }

            train_stats.update(stats)
            logs = train_stats.log()
            if it % cfg.log_iter == 0 and is_chief:
                strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
                    it, np.mean(outs[-1]), logs, time_cost, eta)
                logger.info(strs)

            # Final iteration of this candidate: save, evaluate and reward.
            # NOTE(review): because of ``it > 0`` this never fires when
            # cfg.max_iters == 1, so no reward is reported in that case.
            if (it > 0 and it == cfg.max_iters - 1) and is_chief:
                # Only the last iteration reaches this branch, so the
                # checkpoint is always the final model.
                save_name = "model_final"
                checkpoint.save(exe, train_prog,
                                os.path.join(save_dir, save_name))
                if FLAGS.eval:
                    # evaluation
                    results = eval_run(exe, compiled_eval_prog, eval_loader,
                                       eval_keys, eval_values, eval_cls)
                    ap = calculate_ap_py(results)

                train_loader.reset()
                # eval_loader only exists when evaluation is enabled.
                if FLAGS.eval:
                    eval_loader.reset()
                logger.info('rewards: ap is {}'.format(ap))
                sa_nas.reward(float(ap))

    # Reported once, after every search step has finished.
    current_best_tokens = sa_nas.current_info()['best_tokens']
    logger.info("All steps end, the best BlazeFace-NAS structure is: ")
    sa_nas.tokens2arch(current_best_tokens)
class TestSANAS(unittest.TestCase):
    """Exercise the public SANAS API on a single MobileNetV2 block space."""

    def setUp(self):
        """Create a SANAS instance listening on a randomly chosen port."""
        self.init_test_case()
        # A random port lowers the chance of clashing with another test run.
        port = np.random.randint(8337, 8773)
        self.sanas = SANAS(
            configs=self.configs, server_addr=("", port), save_checkpoint=None)

    def init_test_case(self):
        """Define the search-space config and the token lookup tables."""
        self.configs = [('MobileNetV2BlockSpace', {'block_mask': [0]})]
        self.filter_num = np.array([
            3, 4, 8, 12, 16, 24, 32, 48, 64, 80, 96, 128, 144, 160, 192, 224,
            256, 320, 384, 512
        ])
        self.k_size = np.array([3, 5])
        self.multiply = np.array([1, 2, 3, 4, 5, 6])
        self.repeat = np.array([1, 2, 3, 4, 5, 6])

    def check_chnum_convnum(self, program):
        """Check ``program`` against the current SANAS tokens.

        The first three current tokens index ``multiply``, ``filter_num`` and
        ``repeat``; the conv count and channel list derived from them must
        match what ``compute_op_num`` reports for ``program``.
        """
        current_tokens = self.sanas.current_info()['current_tokens']
        # Convert NumPy scalars to plain ints so the string comparison below
        # does not depend on NumPy's scalar repr.
        channel_exp = int(self.multiply[current_tokens[0]])
        filter_num = int(self.filter_num[current_tokens[1]])
        repeat_num = int(self.repeat[current_tokens[2]])

        conv_list, ch_pro = compute_op_num(program)

        ### assert conv number: three convs are expected per repeated block
        self.assertEqual(
            repeat_num * 3,
            len(conv_list),
            "the number of conv is NOT match, the number compute from token: {}, actual conv number: {}"
            .format(repeat_num * 3, len(conv_list)))

        ### assert number of channels
        ch_token = []
        init_ch_num = 32
        for _ in range(repeat_num):
            ch_token.append(init_ch_num * channel_exp)
            ch_token.append(init_ch_num * channel_exp)
            ch_token.append(filter_num)
            # The next repeat takes this block's output channels as input.
            init_ch_num = filter_num

        # Compared as strings, as before.
        # NOTE(review): assumes compute_op_num reports plain Python ints -
        # TODO confirm.
        self.assertEqual(
            str(ch_token),
            str(ch_pro),
            "channel num is WRONG, channel num from token is {}, channel num come fom program is {}"
            .format(str(ch_token), str(ch_pro)))

    def test_all_function(self):
        """Cover next_archs, reward, tokens2arch and current_info."""
        ### unittest for next_archs
        next_program = fluid.Program()
        startup_program = fluid.Program()
        token2arch_program = fluid.Program()

        with fluid.program_guard(next_program, startup_program):
            inputs = fluid.data(
                name='input', shape=[None, 3, 32, 32], dtype='float32')
            archs = self.sanas.next_archs()
            # Chain the returned arch builders: each consumes the previous
            # output.
            for arch in archs:
                inputs = arch(inputs)
        self.check_chnum_convnum(next_program)

        ### unittest for reward
        self.assertTrue(self.sanas.reward(1.0), "reward is False")

        ### unittest for tokens2arch
        with fluid.program_guard(token2arch_program, startup_program):
            inputs = fluid.data(
                name='input', shape=[None, 3, 32, 32], dtype='float32')
            token_archs = self.sanas.tokens2arch(
                self.sanas.current_info()['current_tokens'])
            # Build from the tokens2arch result itself; previously the loop
            # re-used the archs from next_archs, so tokens2arch was never
            # actually checked.
            for arch in token_archs:
                inputs = arch(inputs)
        self.check_chnum_convnum(token2arch_program)

        ### unittest for current_info
        current_info = self.sanas.current_info()
        self.assertIsInstance(
            current_info, dict,
            "the type of current info must be dict, but now is {}".format(
                type(current_info)))