def coach_adc(model, dataset, arch, optimizer_data, validate_fn, save_checkpoint_fn, train_fn):
    """Run Coach's RL-driven automated deep compression (ADC/AMC) loop on `model`.

    Builds the Coach task configuration, hands the model plus the provided
    service callbacks over to the Distiller environment (via the graph
    manager's simulator parameters), configures the agent's exploration
    schedule, and then lets Coach build the graph and train the agent.

    :param model: the model to compress
    :param dataset: dataset name, forwarded to the Distiller environment
    :param arch: architecture name, forwarded to the Distiller environment
    :param optimizer_data: optimizer configuration, forwarded to the environment
    :param validate_fn: callback the environment uses to evaluate the model
    :param save_checkpoint_fn: callback the environment uses to save checkpoints
    :param train_fn: callback the environment uses to (re)train the model
    """
    task_parameters = TaskParameters(experiment_path=logger.get_experiment_path('adc'))
    conv_cnt = count_conv_layer(model)

    # Service callbacks and application arguments that Coach will hand over
    # to the Distiller environment once it creates it.
    services = distiller.utils.MutableNamedTuple({
        'validate_fn': validate_fn,
        'save_checkpoint_fn': save_checkpoint_fn,
        'train_fn': train_fn})

    app_args = distiller.utils.MutableNamedTuple({
        'dataset': dataset,
        'arch': arch,
        'optimizer_data': optimizer_data})

    # AMC configuration.  The reward trades accuracy (top1) off against the
    # network's compute cost: -1 * (1 - top1/100) * log(total_macs).
    # NOTE(review): an alternative configuration (action_range (0.10, 0.95),
    # desired_reduction 1.5e8, reward = top1/100) was removed here because it
    # sat in the dead `else` branch of a constant `if True:` and could never
    # execute.
    amc_cfg = distiller.utils.MutableNamedTuple({
        'action_range': (0.20, 0.80),
        'onehot_encoding': False,
        'normalize_obs': True,
        'desired_reduction': None,
        'reward_fn': lambda top1, top5, vloss, total_macs: -1 * (1 - top1 / 100) * math.log(total_macs),
        'conv_cnt': conv_cnt,
        'max_reward': -1000})

    # These parameters are passed to the Distiller environment
    graph_manager.env_params.additional_simulator_parameters = {
        'model': model,
        'app_args': app_args,
        'amc_cfg': amc_cfg,
        'services': services}

    # Exploration schedule: constant noise for the first 100 episodes, then
    # exponential decay towards exploitation for the next 300 episodes.
    # One episode visits every convolutional layer once, hence
    # steps_per_episode == conv_cnt.
    exploration_noise = 0.5
    exploitation_decay = 0.996
    steps_per_episode = conv_cnt
    agent_params.exploration.noise_percentage_schedule = PieceWiseSchedule([
        (ConstantSchedule(exploration_noise), EnvironmentSteps(100 * steps_per_episode)),
        (ExponentialSchedule(exploration_noise, 0, exploitation_decay), EnvironmentSteps(300 * steps_per_episode))])

    graph_manager.create_graph(task_parameters)
    graph_manager.improve()
def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
    """
    Returns a Namespace object with all the user-specified configuration options needed to launch.

    This implementation uses argparse to take arguments from the CLI, but this can be over-ridden
    by another method that gets its configuration from elsewhere.  An equivalent method however
    must return an identically structured Namespace object, which conforms to the structure
    defined by get_argument_parser.

    This method parses the arguments that the user entered, does some basic validation, and
    modification of user-specified values in short form to be more explicit.

    Note: several validation failures below call `screen.error(...)`, which (per its use as a
    fatal-error reporter throughout this method) is presumably terminating — TODO confirm.

    :param parser: a parser object which implicitly defines the format of the Namespace that
                   is expected to be returned.
    :return: the parsed arguments as a Namespace
    """
    args = parser.parse_args()

    # Disable colored terminal output when requested.
    if args.nocolor:
        screen.set_use_colors(False)

    # if no arg is given
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    # list available presets
    if args.list:
        self.display_all_presets_and_exit()

    # Read args from config file for distributed Coach.
    # The ConfigParser is seeded with defaults so missing keys fall back cleanly.
    if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
        coach_config = ConfigParser({
            'image': '',
            'memory_backend': 'redispubsub',
            'data_store': 's3',
            's3_end_point': 's3.amazonaws.com',
            's3_bucket_name': '',
            's3_creds_file': ''
        })
        try:
            coach_config.read(args.distributed_coach_config_path)
            args.image = coach_config.get('coach', 'image')
            args.memory_backend = coach_config.get('coach', 'memory_backend')
            args.data_store = coach_config.get('coach', 'data_store')
            if args.data_store == 's3':
                args.s3_end_point = coach_config.get('coach', 's3_end_point')
                args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name')
                args.s3_creds_file = coach_config.get('coach', 's3_creds_file')
        except Error as e:
            # `Error` here is configparser's base exception class.
            screen.error("Error when reading distributed Coach config file: {}".format(e))

        if args.image == '':
            screen.error("Image cannot be empty.")

        # Validate the data store / memory backend against the supported sets;
        # warn with the offending value first, then error out with the choices.
        data_store_choices = ['s3', 'nfs']
        if args.data_store not in data_store_choices:
            screen.warning("{} data store is unsupported.".format(args.data_store))
            screen.error("Supported data stores are {}.".format(data_store_choices))

        memory_backend_choices = ['redispubsub']
        if args.memory_backend not in memory_backend_choices:
            screen.warning("{} memory backend is not supported.".format(args.memory_backend))
            screen.error("Supported memory backends are {}.".format(memory_backend_choices))

        if args.data_store == 's3':
            if args.s3_bucket_name == '':
                screen.error("S3 bucket name cannot be empty.")
            # An empty creds-file path means "use default credentials".
            if args.s3_creds_file == '':
                args.s3_creds_file = None

    if args.play and args.distributed_coach:
        screen.error("Playing is not supported in distributed Coach.")

    # replace a short preset name with the full path
    if args.preset is not None:
        args.preset = self.expand_preset(args.preset)

    # validate the checkpoints args
    if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
        screen.error("The requested checkpoint folder to load from does not exist.")

    # validate the checkpoints args
    # (glob with a '*' suffix so checkpoint file prefixes also match)
    if args.checkpoint_restore_file is not None and not glob(args.checkpoint_restore_file + '*'):
        screen.error("The requested checkpoint file to load from does not exist.")

    # no preset was given. check if the user requested to play some environment on its own
    if args.preset is None and args.play and not args.environment_type:
        screen.error('When no preset is given for Coach to run, and the user requests human control over '
                     'the environment, the user is expected to input the desired environment_type and level.'
                     '\nAt least one of these parameters was not given.')
    elif args.preset and args.play:
        screen.error("Both the --preset and the --play flags were set. These flags can not be used together. "
                     "For human control, please use the --play flag together with the environment type flag (-et)")
    elif args.preset is None and not args.play:
        screen.error("Please choose a preset using the -p flag or use the --play flag together with choosing an "
                     "environment type (-et) in order to play the game.")

    # get experiment name and path
    args.experiment_name = logger.get_experiment_name(args.experiment_name)
    args.experiment_path = logger.get_experiment_path(args.experiment_name)

    # Human play only makes sense with a single worker.
    if args.play and args.num_workers > 1:
        screen.warning("Playing the game as a human is only available with a single worker. "
                       "The number of workers will be reduced to 1")
        args.num_workers = 1

    # Map the (case-insensitive) framework name onto the Frameworks enum.
    args.framework = Frameworks[args.framework.lower()]

    # checkpoints
    # Only create a checkpoint save directory when periodic saving was requested.
    args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None

    if args.export_onnx_graph and not args.checkpoint_save_secs:
        screen.warning("Exporting ONNX graphs requires setting the --checkpoint_save_secs flag. "
                       "The --export_onnx_graph will have no effect.")

    return args
def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
    """
    Parse and validate the arguments that the user entered.

    Performs basic validation (preset existence and instantiability,
    checkpoint paths, play-mode constraints), expands short preset names
    to full `path:variable` form, and resolves derived values (experiment
    name/path, framework enum, checkpoint save directory).

    :param parser: the argparse command line parser
    :return: the parsed arguments as a Namespace
    """
    args = parser.parse_args()

    # if no arg is given, print the usage and stop.
    # Fixed: use sys.exit instead of the bare `exit` builtin — `exit` is
    # injected by the `site` module and is not guaranteed to exist (e.g.
    # under `python -S` or in frozen/embedded interpreters).
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    # list available presets
    preset_names = list_all_presets()
    if args.list:
        screen.log_title("Available Presets:")
        for preset in sorted(preset_names):
            print(preset)
        sys.exit(0)

    # replace a short preset name with the full path
    if args.preset is not None:
        if args.preset.lower() in [p.lower() for p in preset_names]:
            args.preset = "{}.py:graph_manager".format(os.path.join(get_base_dir(), 'presets', args.preset))
        else:
            args.preset = "{}".format(args.preset)
            # if a graph manager variable was not specified, try the default of :graph_manager
            if len(args.preset.split(":")) == 1:
                args.preset += ":graph_manager"

        # verify that the preset exists
        preset_path = args.preset.split(":")[0]
        if not os.path.exists(preset_path):
            screen.error("The given preset ({}) cannot be found.".format(args.preset))

        # verify that the preset can be instantiated
        try:
            short_dynamic_import(args.preset, ignore_module_case=True)
        except TypeError as e:
            traceback.print_exc()
            screen.error('Internal Error: ' + str(e) + "\n\nThe given preset ({}) cannot be instantiated.".format(args.preset))

    # validate the checkpoints args
    if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
        screen.error("The requested checkpoint folder to load from does not exist.")

    # no preset was given. check if the user requested to play some environment on its own
    if args.preset is None and args.play:
        if args.environment_type:
            args.agent_type = 'Human'
        else:
            screen.error('When no preset is given for Coach to run, and the user requests human control over '
                         'the environment, the user is expected to input the desired environment_type and level.'
                         '\nAt least one of these parameters was not given.')
    elif args.preset and args.play:
        screen.error("Both the --preset and the --play flags were set. These flags can not be used together. "
                     "For human control, please use the --play flag together with the environment type flag (-et)")
    elif args.preset is None and not args.play:
        screen.error("Please choose a preset using the -p flag or use the --play flag together with choosing an "
                     "environment type (-et) in order to play the game.")

    # get experiment name and path
    args.experiment_name = logger.get_experiment_name(args.experiment_name)
    args.experiment_path = logger.get_experiment_path(args.experiment_name)

    # Human play only makes sense with a single worker.
    if args.play and args.num_workers > 1:
        screen.warning("Playing the game as a human is only available with a single worker. "
                       "The number of workers will be reduced to 1")
        args.num_workers = 1

    # Map the (case-insensitive) framework name onto the Frameworks enum.
    args.framework = Frameworks[args.framework.lower()]

    # checkpoints
    # Only create a checkpoint save directory when periodic saving was requested.
    args.save_checkpoint_dir = os.path.join(args.experiment_path, 'checkpoint') if args.save_checkpoint_secs is not None else None

    return args