Пример #1
0
def handle_distributed_coach_tasks(graph_manager, args):
    """
    Start the trainer or the rollout worker of a distributed Coach run.

    :param graph_manager: the graph manager handed to the worker; its agent memory gets the memory backend
    parameters registered and its ``data_store_params`` attribute is set when the matching args are given.
    :param args: parsed arguments. Reads ``memory_backend_params`` and ``data_store_params`` (JSON strings,
    may be empty), ``distributed_coach_run_type`` and, for rollout workers, ``num_workers``.
    """
    container_ckpt_dir = "/checkpoint"

    if args.memory_backend_params:
        backend_cfg = json.loads(args.memory_backend_params)
        # The run type is stored in its string form alongside the decoded parameters.
        backend_cfg['run_type'] = str(args.distributed_coach_run_type)
        agent_memory = graph_manager.agent_params.memory
        agent_memory.register_var('memory_backend_params',
                                  construct_memory_params(backend_cfg))

    store_cfg = None
    if args.data_store_params:
        decoded_store_cfg = json.loads(args.data_store_params)
        store_cfg = construct_data_store_params(decoded_store_cfg)
        store_cfg.checkpoint_dir = container_ckpt_dir
        graph_manager.data_store_params = store_cfg

    if args.distributed_coach_run_type == RunType.TRAINER:
        training_worker(graph_manager=graph_manager,
                        checkpoint_dir=container_ckpt_dir)

    # Everything below only applies to rollout workers.
    if not (args.distributed_coach_run_type == RunType.ROLLOUT_WORKER):
        return

    store = None
    if args.data_store_params:
        store = get_data_store(store_cfg)
        # wait_for_checkpoint is called before the rollout worker is started.
        wait_for_checkpoint(checkpoint_dir=container_ckpt_dir,
                            data_store=store)

    rollout_worker(graph_manager=graph_manager,
                   checkpoint_dir=container_ckpt_dir,
                   data_store=store,
                   num_workers=args.num_workers)
Пример #2
0
def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
    """
    Start the trainer or the rollout worker of a distributed Coach run.

    :param graph_manager: the graph manager handed to the worker; its agent memory gets the memory backend
    parameters registered and its ``data_store_params`` attribute is set when the matching args are given.
    :param args: parsed arguments. Reads ``memory_backend_params`` and ``data_store_params`` (JSON strings,
    may be empty), ``experiment_path``, ``distributed_coach_run_type``, ``is_multi_node_test`` and
    ``num_workers``.
    :param task_parameters: task parameters forwarded to the worker; for the trainer its
    ``checkpoint_save_dir`` is set to the in-container checkpoint directory.
    """
    container_ckpt_dir = "/checkpoint"

    if args.memory_backend_params:
        backend_cfg = json.loads(args.memory_backend_params)
        # The run type is stored in its string form alongside the decoded parameters.
        backend_cfg['run_type'] = str(args.distributed_coach_run_type)
        agent_memory = graph_manager.agent_params.memory
        agent_memory.register_var('memory_backend_params',
                                  construct_memory_params(backend_cfg))

    # Stays None when no data store parameters were supplied.
    store = None
    if args.data_store_params:
        decoded_store_cfg = json.loads(args.data_store_params)
        store_cfg = construct_data_store_params(decoded_store_cfg)
        store_cfg.expt_dir = args.experiment_path
        store_cfg.checkpoint_dir = container_ckpt_dir
        graph_manager.data_store_params = store_cfg
        store = get_data_store(store_cfg)

    if args.distributed_coach_run_type == RunType.TRAINER:
        task_parameters.checkpoint_save_dir = container_ckpt_dir
        training_worker(
            graph_manager=graph_manager,
            data_store=store,
            task_parameters=task_parameters,
            is_multi_node_test=args.is_multi_node_test,
        )

    if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
        rollout_worker(
            graph_manager=graph_manager,
            data_store=store,
            num_workers=args.num_workers,
            task_parameters=task_parameters,
        )
Пример #3
0
    def __init__(self, params: KubernetesParameters):
        """
        Load the Kubernetes client configuration, then create the memory backend and the data store.

        For an "s3" data store the access and secret keys are additionally read into ``self.s3_access_key`` and
        ``self.s3_secret_key``, either from the credentials file or from the environment.

        :param params: The Kubernetes parameters which are used for deploying the components in Coach. These parameters
        include namespace and kubeconfig.
        """

        super().__init__(params)
        self.params = params
        # A set kubeconfig selects the kube-config loader; otherwise the in-cluster configuration is used.
        if self.params.kubeconfig:
            k8sconfig.load_kube_config()
        else:
            k8sconfig.load_incluster_config()

        # Without an explicit namespace, take the one of the current kube-config context.
        if not self.params.namespace:
            _, current_context = k8sconfig.list_kube_config_contexts()
            self.params.namespace = current_context['context']['namespace']

        http_proxy = os.environ.get('http_proxy')
        if http_proxy:
            k8sclient.Configuration._default.proxy = http_proxy

        self.params.memory_backend_parameters.orchestrator_params = {
            'namespace': self.params.namespace
        }
        self.memory_backend = get_memory_backend(
            self.params.memory_backend_parameters)

        self.params.data_store_params.orchestrator_params = {
            'namespace': self.params.namespace
        }
        self.params.data_store_params.namespace = self.params.namespace
        self.data_store = get_data_store(self.params.data_store_params)

        if self.params.data_store_params.store_type == "s3":
            self.s3_access_key = None
            self.s3_secret_key = None
            if self.params.data_store_params.creds_file:
                s3config = ConfigParser()
                s3config.read(self.params.data_store_params.creds_file)
                try:
                    self.s3_access_key = s3config.get('default',
                                                      'aws_access_key_id')
                    self.s3_secret_key = s3config.get('default',
                                                      'aws_secret_access_key')
                except Error as e:
                    # Best effort: the keys stay None when the file cannot be parsed.
                    # The message is interpolated here rather than handing the template and
                    # the exception to screen.print as separate arguments.
                    # NOTE(review): screen.print is assumed not to apply %-style arguments itself - confirm.
                    screen.print(
                        "Error when reading S3 credentials file: %s" % (e,))
            else:
                self.s3_access_key = os.environ.get('ACCESS_KEY_ID')
                self.s3_secret_key = os.environ.get('SECRET_ACCESS_KEY')