Example #1
0
    def get_config_args(self,
                        parser: argparse.ArgumentParser) -> argparse.Namespace:
        """Overrides the default CLI parsing for a SageMaker run.

        Coach's own ``parser`` is only used to obtain a Namespace filled with its
        default values (it is parsed against an empty argument list). The real
        settings come from ``self.sagemaker_argparser()``, from environment
        variables and from hard-coded ``/opt/ml/output`` paths. Every argument
        the SageMaker parser does not recognise must come as a ``--name value``
        pair and is forwarded to ``self.map_hyperparameter(name, value)``.

        Note, this does not support the "play" mode.

        :param parser: Coach's argument parser; only its defaults are used.
        :return: the populated Namespace.
        :raises ValueError: if the unrecognised arguments are not ``--name value`` pairs.
        :raises KeyError: if ``COACH_BACKEND`` (lower-cased) is not a ``Frameworks`` member name.
        """
        # Start from Coach's defaults: parse an explicit empty list so sys.argv is ignored.
        args, _ = parser.parse_known_args(args=[])
        sage_args, unknown = self.sagemaker_argparser().parse_known_args()

        # Now fill in the args that we care about.
        sagemaker_job_name = os.environ.get("sagemaker_job_name",
                                            "sagemaker-experiment")
        args.experiment_name = logger.get_experiment_name(sagemaker_job_name)

        # Override experiment_path used for outputs
        args.experiment_path = "/opt/ml/output/intermediate"
        rl_coach.logger.experiment_path = "/opt/ml/output/intermediate"  # for gifs

        args.checkpoint_save_dir = "/opt/ml/output/data/checkpoint"
        args.checkpoint_save_secs = 10  # TODO: should avoid hardcoding

        # Normalise the backend name so "MXNet"/"TensorFlow" resolve like their
        # lower-case forms, matching how the CLI launcher handles --framework.
        backend = os.getenv("COACH_BACKEND", "tensorflow").lower()
        # onnx for deployment for mxnet (not tensorflow)
        if sage_args.save_model == 1 and backend == "mxnet":
            args.export_onnx_graph = True

        args.no_summary = True

        args.num_workers = sage_args.num_workers
        args.framework = Frameworks[backend]
        args.preset = sage_args.RLCOACH_PRESET
        # args.apply_stop_condition = True # uncomment for old coach behaviour

        self.hyperparameters = CoachConfigurationList()
        if len(unknown) % 2 == 1:
            raise ValueError(
                "Odd number of command-line arguments specified. Key without value."
            )

        # Unrecognised arguments are consumed as (--name, value) pairs, in order.
        for name, val in zip(unknown[0::2], unknown[1::2]):
            if not name.startswith("--"):
                raise ValueError("Unknown command-line argument %s" % name)
            self.map_hyperparameter(name[2:], val)

        return args
    def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
        """Overrides the default CLI parsing for a SageMaker run.

        Coach's own ``parser`` is only used to obtain a Namespace filled with its
        default values (it is parsed against an empty argument list). The real
        settings come from ``self.sagemaker_argparser()``, from environment
        variables and from hard-coded ``/opt/ml/output`` paths. Every argument
        the SageMaker parser does not recognise must come as a ``--name value``
        pair and is forwarded to ``self.map_hyperparameter(name, value)``.

        Note, this does not support the "play" mode.

        :param parser: Coach's argument parser; only its defaults are used.
        :return: the populated Namespace.
        :raises ValueError: if the unrecognised arguments are not ``--name value`` pairs.
        :raises KeyError: if ``COACH_BACKEND`` (lower-cased) is not a ``Frameworks`` member name.
        """
        # Start from Coach's defaults: parse an explicit empty list so sys.argv is ignored.
        args, _ = parser.parse_known_args(args=[])
        sage_args, unknown = self.sagemaker_argparser().parse_known_args()

        # Now fill in the args that we care about.
        sagemaker_job_name = os.environ.get("sagemaker_job_name", "sagemaker-experiment")
        args.experiment_name = logger.get_experiment_name(sagemaker_job_name)

        # Override experiment_path used for outputs
        args.experiment_path = '/opt/ml/output/intermediate'
        rl_coach.logger.experiment_path = '/opt/ml/output/intermediate'  # for gifs

        args.checkpoint_save_dir = '/opt/ml/output/data/checkpoint'
        args.checkpoint_save_secs = 10  # TODO: should avoid hardcoding

        # Normalise the backend name so "MXNet"/"TensorFlow" resolve like their
        # lower-case forms, matching how the CLI launcher handles --framework.
        backend = os.getenv('COACH_BACKEND', 'tensorflow').lower()
        # onnx for deployment for mxnet (not tensorflow)
        if sage_args.save_model == 1 and backend == "mxnet":
            args.export_onnx_graph = True

        args.no_summary = True

        args.num_workers = sage_args.num_workers
        args.framework = Frameworks[backend]
        args.preset = sage_args.RLCOACH_PRESET
        # args.apply_stop_condition = True # uncomment for old coach behaviour

        self.hyperparameters = CoachConfigurationList()
        if len(unknown) % 2 == 1:
            raise ValueError("Odd number of command-line arguments specified. Key without value.")

        # Unrecognised arguments are consumed as (--name, value) pairs, in order.
        for name, val in zip(unknown[0::2], unknown[1::2]):
            if not name.startswith("--"):
                raise ValueError("Unknown command-line argument %s" % name)
            self.map_hyperparameter(name[2:], val)

        return args
Example #3
0
    def get_config_args(self,
                        parser: argparse.ArgumentParser) -> argparse.Namespace:
        """
        Parse the command line into the Namespace used to launch Coach, validating and
        normalising it along the way.

        The shape of the returned Namespace is implied by ``parser`` (see
        get_argument_parser). Overriding implementations may obtain the configuration
        from somewhere other than the CLI, but must return an identically structured
        Namespace.

        Processing steps, in order:
          * parse the arguments and apply ``--nocolor``;
          * with no arguments at all, print the help and exit with status 1;
          * ``--list`` delegates to ``self.display_all_presets_and_exit()``;
          * a distributed-Coach orchestrator reads image / memory backend / data store
            settings from the ``coach`` section of ``args.distributed_coach_config_path``
            and validates them;
          * preset, checkpoint and play-mode arguments are validated, with problems
            reported through ``screen.error`` / ``screen.warning``;
          * the experiment name/path, framework enum and checkpoint dir are derived.

        :param parser: a parser object which implicitly defines the format of the Namespace that
                       is expected to be returned.
        :return: the parsed arguments as a Namespace
        """
        args = parser.parse_args()

        if args.nocolor:
            screen.set_use_colors(False)

        # Nothing was passed on the command line: show usage instead of running.
        if len(sys.argv) == 1:
            parser.print_help()
            sys.exit(1)

        if args.list:
            self.display_all_presets_and_exit()

        is_orchestrator = (args.distributed_coach and
                           args.distributed_coach_run_type == RunType.ORCHESTRATOR)
        if is_orchestrator:
            # Values used when the config file does not provide an option.
            config_defaults = {
                'image': '',
                'memory_backend': 'redispubsub',
                'data_store': 's3',
                's3_end_point': 's3.amazonaws.com',
                's3_bucket_name': '',
                's3_creds_file': ''
            }
            coach_config = ConfigParser(config_defaults)
            try:
                coach_config.read(args.distributed_coach_config_path)
                for option in ('image', 'memory_backend', 'data_store'):
                    setattr(args, option, coach_config.get('coach', option))
                # The S3 options are only copied over when S3 is the data store.
                if args.data_store == 's3':
                    for option in ('s3_end_point', 's3_bucket_name', 's3_creds_file'):
                        setattr(args, option, coach_config.get('coach', option))
            except Error as err:
                screen.error("Error when reading distributed Coach config file: {}".format(err))

            if args.image == '':
                screen.error("Image cannot be empty.")

            supported_stores = ['s3', 'nfs']
            if args.data_store not in supported_stores:
                screen.warning("{} data store is unsupported.".format(args.data_store))
                screen.error("Supported data stores are {}.".format(supported_stores))

            supported_backends = ['redispubsub']
            if args.memory_backend not in supported_backends:
                screen.warning("{} memory backend is not supported.".format(args.memory_backend))
                screen.error("Supported memory backends are {}.".format(supported_backends))

            if args.data_store == 's3':
                if args.s3_bucket_name == '':
                    screen.error("S3 bucket name cannot be empty.")
                # An empty creds path in the config file means "no credentials file".
                if args.s3_creds_file == '':
                    args.s3_creds_file = None

        if args.play and args.distributed_coach:
            screen.error("Playing is not supported in distributed Coach.")

        # Turn a short preset name into its full path.
        if args.preset is not None:
            args.preset = self.expand_preset(args.preset)

        restore_dir = args.checkpoint_restore_dir
        if restore_dir is not None and not os.path.exists(restore_dir):
            screen.error("The requested checkpoint folder to load from does not exist.")

        # The restore file is treated as a prefix: any file starting with it counts.
        restore_file = args.checkpoint_restore_file
        if restore_file is not None and not glob(restore_file + '*'):
            screen.error("The requested checkpoint file to load from does not exist.")

        # Exactly one of "run a preset" or "play an environment by hand" is allowed.
        if args.preset is None:
            if not args.play:
                screen.error(
                    "Please choose a preset using the -p flag or use the --play flag together with choosing an "
                    "environment type (-et) in order to play the game.")
            elif not args.environment_type:
                screen.error(
                    'When no preset is given for Coach to run, and the user requests human control over '
                    'the environment, the user is expected to input the desired environment_type and level.'
                    '\nAt least one of these parameters was not given.')
        elif args.preset and args.play:
            screen.error(
                "Both the --preset and the --play flags were set. These flags can not be used together. "
                "For human control, please use the --play flag together with the environment type flag (-et)"
            )

        args.experiment_name = logger.get_experiment_name(args.experiment_name)
        args.experiment_path = logger.get_experiment_path(args.experiment_name)

        if args.play and args.num_workers > 1:
            screen.warning(
                "Playing the game as a human is only available with a single worker. "
                "The number of workers will be reduced to 1")
            args.num_workers = 1

        args.framework = Frameworks[args.framework.lower()]

        # Checkpoints are only written under the experiment dir when a save interval is set.
        if args.checkpoint_save_secs is None:
            args.checkpoint_save_dir = None
        else:
            args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint')

        if args.export_onnx_graph and not args.checkpoint_save_secs:
            screen.warning(
                "Exporting ONNX graphs requires setting the --checkpoint_save_secs flag. "
                "The --export_onnx_graph will have no effect.")

        return args
Example #4
0
def _canonical_preset_name(preset, preset_names):
    """Return the entry of ``preset_names`` equal to ``preset`` ignoring case, or None."""
    wanted = preset.lower()
    for name in preset_names:
        if name.lower() == wanted:
            return name
    return None


def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
    """
    Parse the arguments that the user entered, validate them and fill in derived values.

    With no command-line arguments the help is printed and the process exits with
    status 0; ``--list`` prints the available presets and exits with status 0.
    A short preset name is expanded to ``<base_dir>/presets/<Name>.py:graph_manager``.
    Otherwise ``:graph_manager`` is appended when no graph-manager variable is given.
    Invalid combinations are reported through ``screen.error``.

    :param parser: the argparse command line parser
    :return: the parsed arguments
    """
    args = parser.parse_args()

    # if no arg is given
    if len(sys.argv) == 1:
        parser.print_help()
        # sys.exit, not the site-injected exit() builtin, which is absent under
        # `python -S` and in some embedded/frozen interpreters.
        sys.exit(0)

    # list available presets
    preset_names = list_all_presets()
    if args.list:
        screen.log_title("Available Presets:")
        for preset in sorted(preset_names):
            print(preset)
        sys.exit(0)

    # replace a short preset name with the full path
    if args.preset is not None:
        builtin_name = _canonical_preset_name(args.preset, preset_names)
        if builtin_name is not None:
            # Build the path from the preset's real spelling: the lookup is
            # case-insensitive, but os.path.exists below is case-sensitive on
            # most Linux filesystems.
            args.preset = "{}.py:graph_manager".format(
                os.path.join(get_base_dir(), 'presets', builtin_name))
        elif ":" not in args.preset:
            # if a graph manager variable was not specified, try the default of :graph_manager
            args.preset += ":graph_manager"

        # verify that the preset exists
        preset_path = args.preset.split(":")[0]
        if not os.path.exists(preset_path):
            screen.error("The given preset ({}) cannot be found.".format(
                args.preset))

        # verify that the preset can be instantiated
        try:
            short_dynamic_import(args.preset, ignore_module_case=True)
        except TypeError as e:
            traceback.print_exc()
            screen.error('Internal Error: ' + str(e) +
                         "\n\nThe given preset ({}) cannot be instantiated.".
                         format(args.preset))

    # validate the checkpoints args
    if args.checkpoint_restore_dir is not None and not os.path.exists(
            args.checkpoint_restore_dir):
        screen.error(
            "The requested checkpoint folder to load from does not exist.")

    # no preset was given. check if the user requested to play some environment on its own
    if args.preset is None and args.play:
        if args.environment_type:
            args.agent_type = 'Human'
        else:
            screen.error(
                'When no preset is given for Coach to run, and the user requests human control over '
                'the environment, the user is expected to input the desired environment_type and level.'
                '\nAt least one of these parameters was not given.')
    elif args.preset and args.play:
        screen.error(
            "Both the --preset and the --play flags were set. These flags can not be used together. "
            "For human control, please use the --play flag together with the environment type flag (-et)"
        )
    elif args.preset is None and not args.play:
        screen.error(
            "Please choose a preset using the -p flag or use the --play flag together with choosing an "
            "environment type (-et) in order to play the game.")

    # get experiment name and path
    args.experiment_name = logger.get_experiment_name(args.experiment_name)
    args.experiment_path = logger.get_experiment_path(args.experiment_name)

    if args.play and args.num_workers > 1:
        screen.warning(
            "Playing the game as a human is only available with a single worker. "
            "The number of workers will be reduced to 1")
        args.num_workers = 1

    args.framework = Frameworks[args.framework.lower()]

    # checkpoints are only stored when a save interval was requested
    args.save_checkpoint_dir = os.path.join(
        args.experiment_path,
        'checkpoint') if args.save_checkpoint_secs is not None else None

    return args