Пример #1
0
def coach_adc(model, dataset, arch, optimizer_data, validate_fn, save_checkpoint_fn, train_fn,
              desired_reduction=None):
    """Launch a Coach-driven automated compression (AMC/ADC) search over *model*.

    Configures the module-level ``graph_manager`` and ``agent_params`` (defined
    outside this function) with the model, application arguments, AMC
    configuration and service callbacks. It then builds the Coach graph and
    calls ``graph_manager.improve()``. Nothing is returned.

    :param model: the network handed to the environment via
        ``additional_simulator_parameters['model']``.
    :param dataset: forwarded to the environment in ``app_args.dataset``.
    :param arch: forwarded to the environment in ``app_args.arch``.
    :param optimizer_data: forwarded to the environment in ``app_args.optimizer_data``.
    :param validate_fn: callback forwarded in ``services.validate_fn``.
    :param save_checkpoint_fn: callback forwarded in ``services.save_checkpoint_fn``.
    :param train_fn: callback forwarded in ``services.train_fn``.
    :param desired_reduction: selects the AMC configuration.
        ``None`` (default) selects action range (0.20, 0.80) and the reward
        ``-(1 - top1/100) * log(total_macs)``, which trades accuracy against
        compute with no fixed target.
        Any other value selects action range (0.10, 0.95) and the reward
        ``top1/100``. The value is also forwarded as ``amc_cfg.desired_reduction``.
        Its units are interpreted by the environment; a value of 1.5e8 was used
        historically, presumably a MAC count — confirm.
    """
    task_parameters = TaskParameters(experiment_path=logger.get_experiment_path('adc'))
    conv_cnt = count_conv_layer(model)

    # Parameters that Coach hands over to CNNEnvironment once it creates it.
    services = distiller.utils.MutableNamedTuple({
        'validate_fn': validate_fn,
        'save_checkpoint_fn': save_checkpoint_fn,
        'train_fn': train_fn})

    app_args = distiller.utils.MutableNamedTuple({
        'dataset': dataset,
        'arch': arch,
        'optimizer_data': optimizer_data})

    if desired_reduction is None:
        # No fixed target: scale log(total MACs) by the top-1 error
        # (top1 presumably a percentage, hence the /100).
        action_range = (0.20, 0.80)

        def reward_fn(top1, top5, vloss, total_macs):
            return -1 * (1 - top1 / 100) * math.log(total_macs)
    else:
        # Fixed reduction target: reward is top-1 accuracy scaled by 1/100.
        action_range = (0.10, 0.95)

        def reward_fn(top1, top5, vloss, total_macs):
            return top1 / 100

    amc_cfg = distiller.utils.MutableNamedTuple({
        'action_range': action_range,
        'onehot_encoding': False,
        'normalize_obs': True,
        'desired_reduction': desired_reduction,
        'reward_fn': reward_fn,
        'conv_cnt': conv_cnt,
        'max_reward': -1000})

    # These parameters are passed to the Distiller environment
    graph_manager.env_params.additional_simulator_parameters = {'model': model,
                                                                'app_args': app_args,
                                                                'amc_cfg': amc_cfg,
                                                                'services': services}

    # Exploration noise schedule (the code treats one episode as conv_cnt steps):
    # - hold the noise at 0.5 for 100 * steps_per_episode environment steps;
    # - then switch to an ExponentialSchedule from 0.5 towards 0 (decay 0.996)
    #   for 300 * steps_per_episode steps.
    exploration_noise = 0.5
    exploitation_decay = 0.996
    steps_per_episode = conv_cnt
    agent_params.exploration.noise_percentage_schedule = PieceWiseSchedule([
        (ConstantSchedule(exploration_noise), EnvironmentSteps(100 * steps_per_episode)),
        (ExponentialSchedule(exploration_noise, 0, exploitation_decay),
         EnvironmentSteps(300 * steps_per_episode))])
    graph_manager.create_graph(task_parameters)
    graph_manager.improve()
Пример #2
0
    def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
        """
        Parse, validate and normalize the user's configuration for a Coach run.

        This implementation reads the options from the command line through
        ``parser.parse_args()``. An override may source them elsewhere, but it
        must return a Namespace with the same layout that ``get_argument_parser``
        defines.

        Steps, in order:
        - Apply ``--nocolor``.
        - Print help and exit when the program was run with no arguments.
        - Handle ``--list``.
        - When running as the distributed-Coach orchestrator, read the
          ``[coach]`` section of the distributed config file and validate it.
        - Reject unsupported flag combinations.
        - Expand a short preset name and check the checkpoint paths.
        - Fill in the derived experiment name/path, framework enum and
          checkpoint save directory.

        :param parser: parser whose arguments implicitly define the layout of the
                       Namespace that is returned.
        :return: the parsed and post-processed arguments as a Namespace
        """
        args = parser.parse_args()

        if args.nocolor:
            screen.set_use_colors(False)

        # Bare invocation: show usage and stop with a non-zero status.
        if len(sys.argv) == 1:
            parser.print_help()
            sys.exit(1)

        if args.list:
            self.display_all_presets_and_exit()

        # Only the orchestrator of a distributed run reads the config file.
        is_orchestrator = (args.distributed_coach
                           and args.distributed_coach_run_type == RunType.ORCHESTRATOR)
        if is_orchestrator:
            # Fallback values for options missing from the [coach] section.
            fallbacks = {
                'image': '',
                'memory_backend': 'redispubsub',
                'data_store': 's3',
                's3_end_point': 's3.amazonaws.com',
                's3_bucket_name': '',
                's3_creds_file': ''
            }
            cfg = ConfigParser(fallbacks)
            try:
                cfg.read(args.distributed_coach_config_path)
                for option in ('image', 'memory_backend', 'data_store'):
                    setattr(args, option, cfg.get('coach', option))
                # S3 settings are only relevant for the S3 data store.
                if args.data_store == 's3':
                    for option in ('s3_end_point', 's3_bucket_name', 's3_creds_file'):
                        setattr(args, option, cfg.get('coach', option))
            except Error as err:
                screen.error(
                    "Error when reading distributed Coach config file: {}".format(err))

            if args.image == '':
                screen.error("Image cannot be empty.")

            supported_stores = ['s3', 'nfs']
            if args.data_store not in supported_stores:
                screen.warning("{} data store is unsupported.".format(args.data_store))
                screen.error("Supported data stores are {}.".format(supported_stores))

            supported_backends = ['redispubsub']
            if args.memory_backend not in supported_backends:
                screen.warning(
                    "{} memory backend is not supported.".format(args.memory_backend))
                screen.error("Supported memory backends are {}.".format(supported_backends))

            if args.data_store == 's3':
                if args.s3_bucket_name == '':
                    screen.error("S3 bucket name cannot be empty.")
                # An empty credentials entry is normalized to "no credentials file".
                if args.s3_creds_file == '':
                    args.s3_creds_file = None

        if args.play and args.distributed_coach:
            screen.error("Playing is not supported in distributed Coach.")

        # A short preset name is expanded to its full path.
        if args.preset is not None:
            args.preset = self.expand_preset(args.preset)

        restore_dir = args.checkpoint_restore_dir
        if restore_dir is not None and not os.path.exists(restore_dir):
            screen.error("The requested checkpoint folder to load from does not exist.")

        # The restore file is treated as a prefix: any path starting with it counts.
        restore_file = args.checkpoint_restore_file
        if restore_file is not None and not glob(restore_file + '*'):
            screen.error("The requested checkpoint file to load from does not exist.")

        # Exactly one of "run a preset" or "play an environment by hand" is allowed.
        if args.play:
            if args.preset is None:
                if not args.environment_type:
                    screen.error(
                        'When no preset is given for Coach to run, and the user requests '
                        'human control over the environment, the user is expected to input '
                        'the desired environment_type and level.\nAt least one of these '
                        'parameters was not given.')
            elif args.preset:
                screen.error(
                    "Both the --preset and the --play flags were set. These flags can not be "
                    "used together. For human control, please use the --play flag together "
                    "with the environment type flag (-et)")
        elif args.preset is None:
            screen.error(
                "Please choose a preset using the -p flag or use the --play flag together "
                "with choosing an environment type (-et) in order to play the game.")

        args.experiment_name = logger.get_experiment_name(args.experiment_name)
        args.experiment_path = logger.get_experiment_path(args.experiment_name)

        if args.play and args.num_workers > 1:
            screen.warning(
                "Playing the game as a human is only available with a single worker. "
                "The number of workers will be reduced to 1")
            args.num_workers = 1

        args.framework = Frameworks[args.framework.lower()]

        # A checkpoint directory exists only when periodic checkpointing is enabled.
        if args.checkpoint_save_secs is None:
            args.checkpoint_save_dir = None
        else:
            args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint')

        if args.export_onnx_graph and not args.checkpoint_save_secs:
            screen.warning(
                "Exporting ONNX graphs requires setting the --checkpoint_save_secs flag. "
                "The --export_onnx_graph will have no effect.")

        return args
Пример #3
0
def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
    """
    Parse and validate the arguments that the user entered.

    Steps:
    - Print help and exit when the program was run without arguments.
    - Print the available presets and exit for ``--list``.
    - Expand a short preset name to ``<base dir>/presets/<name>.py:graph_manager``.
      The expanded preset must exist on disk; the function then tries to import
      it, and a ``TypeError`` during the import is reported through
      ``screen.error``.
    - Validate the checkpoint and preset/play flag combinations.
    - Fill in the derived experiment name and path, framework and checkpoint
      directory.

    :param parser: the argparse command line parser
    :return: the parsed arguments
    """
    args = parser.parse_args()

    # if no arg is given
    if len(sys.argv) == 1:
        parser.print_help()
        # sys.exit, not the site-module ``exit`` helper: the latter is missing
        # under ``python -S`` and in frozen/embedded interpreters.
        sys.exit(0)

    # list available presets
    preset_names = list_all_presets()
    if args.list:
        screen.log_title("Available Presets:")
        for preset in sorted(preset_names):
            print(preset)
        sys.exit(0)

    # replace a short preset name with the full path
    if args.preset is not None:
        # Short names match case-insensitively, but the path must be built from
        # the preset's real spelling. On a case-sensitive file system, the
        # user's spelling would otherwise fail the existence check below.
        if args.preset in preset_names:
            known_preset = args.preset
        else:
            known_preset = {p.lower(): p for p in preset_names}.get(args.preset.lower())

        if known_preset is not None:
            args.preset = "{}.py:graph_manager".format(
                os.path.join(get_base_dir(), 'presets', known_preset))
        elif ":" not in args.preset:
            # if a graph manager variable was not specified, try the default of :graph_manager
            args.preset += ":graph_manager"

        # verify that the preset exists
        preset_path = args.preset.split(":")[0]
        if not os.path.exists(preset_path):
            screen.error("The given preset ({}) cannot be found.".format(
                args.preset))

        # verify that the preset can be instantiated
        try:
            short_dynamic_import(args.preset, ignore_module_case=True)
        except TypeError as e:
            traceback.print_exc()
            screen.error('Internal Error: ' + str(e) +
                         "\n\nThe given preset ({}) cannot be instantiated.".
                         format(args.preset))

    # validate the checkpoints args
    if args.checkpoint_restore_dir is not None and not os.path.exists(
            args.checkpoint_restore_dir):
        screen.error(
            "The requested checkpoint folder to load from does not exist.")

    # no preset was given. check if the user requested to play some environment on its own
    if args.preset is None and args.play:
        if args.environment_type:
            args.agent_type = 'Human'
        else:
            screen.error(
                'When no preset is given for Coach to run, and the user requests human control over '
                'the environment, the user is expected to input the desired environment_type and level.'
                '\nAt least one of these parameters was not given.')
    elif args.preset and args.play:
        screen.error(
            "Both the --preset and the --play flags were set. These flags can not be used together. "
            "For human control, please use the --play flag together with the environment type flag (-et)"
        )
    elif args.preset is None and not args.play:
        screen.error(
            "Please choose a preset using the -p flag or use the --play flag together with choosing an "
            "environment type (-et) in order to play the game.")

    # get experiment name and path
    args.experiment_name = logger.get_experiment_name(args.experiment_name)
    args.experiment_path = logger.get_experiment_path(args.experiment_name)

    if args.play and args.num_workers > 1:
        screen.warning(
            "Playing the game as a human is only available with a single worker. "
            "The number of workers will be reduced to 1")
        args.num_workers = 1

    args.framework = Frameworks[args.framework.lower()]

    # checkpoints: a directory is only defined when periodic saving is enabled
    args.save_checkpoint_dir = os.path.join(
        args.experiment_path,
        'checkpoint') if args.save_checkpoint_secs is not None else None

    return args