def worker_logs(self, path='./logs'):
    """
    Start tailing the logs of every rollout-worker pod into files under `path`.

    :param path: Path to store the worker logs.
    """
    worker_params = self.params.run_type_params.get(
        str(RunType.ROLLOUT_WORKER), None)
    if worker_params is None:
        return

    api_client = k8sclient.CoreV1Api()
    selector = 'app={}'.format(
        worker_params.orchestration_params['job_name'])
    try:
        pods = api_client.list_namespaced_pod(
            self.params.namespace, label_selector=selector)
    except k8sclient.rest.ApiException as e:
        screen.print("Got exception: %s\n while reading pods", e)
        return

    if not pods or not pods.items:
        return

    # Spawn one daemon process per worker pod so tailing never blocks
    # the main control flow.
    for pod in pods.items:
        Process(target=self._tail_log_file,
                args=(pod.metadata.name, api_client,
                      self.params.namespace, path),
                daemon=True).start()
def undeploy(self):
    """
    Undeploy all the components, such as trainer and rollout worker(s),
    Redis pub/sub and data store, when required.
    """
    api_client = k8sclient.BatchV1Api()
    delete_options = k8sclient.V1DeleteOptions(
        propagation_policy="Foreground")

    # Remove the trainer Job, when one was deployed.
    trainer_params = self.params.run_type_params.get(
        str(RunType.TRAINER), None)
    if trainer_params:
        try:
            api_client.delete_namespaced_job(
                trainer_params.orchestration_params['job_name'],
                self.params.namespace, delete_options)
        except k8sclient.rest.ApiException as e:
            screen.print("Got exception: %s\n while deleting trainer", e)

    # Remove the rollout-worker Job, when one was deployed.
    worker_params = self.params.run_type_params.get(
        str(RunType.ROLLOUT_WORKER), None)
    if worker_params:
        try:
            api_client.delete_namespaced_job(
                worker_params.orchestration_params['job_name'],
                self.params.namespace, delete_options)
        except k8sclient.rest.ApiException as e:
            screen.print("Got exception: %s\n while deleting workers", e)

    # Finally tear down the shared services.
    self.memory_backend.undeploy()
    self.data_store.undeploy()
def undeploy(self):
    """
    Undeploy the Redis Pub/Sub service in an orchestrator.

    Deletes the redis-server Deployment and its Service. Failures are
    logged and swallowed so the rest of the teardown can proceed.
    """
    # NOTE(review): `deployed` appears to mean the redis server is managed
    # externally (nothing for us to tear down) -- confirm against callers.
    if self.params.deployed:
        return

    # BUG FIX: the import was duplicated -- once before the guard above
    # (executed even when returning early) and once after. Import once,
    # after the guard.
    from kubernetes import client

    api_client = client.AppsV1Api()
    delete_options = client.V1DeleteOptions()
    try:
        api_client.delete_namespaced_deployment(
            self.redis_server_name,
            self.params.orchestrator_params['namespace'], delete_options)
    except client.rest.ApiException as e:
        screen.print("Got exception: %s\n while deleting redis-server", e)

    api_client = client.CoreV1Api()
    try:
        api_client.delete_namespaced_service(
            self.redis_service_name,
            self.params.orchestrator_params['namespace'], delete_options)
    except client.rest.ApiException as e:
        screen.print("Got exception: %s\n while deleting redis-server", e)
def trainer_logs(self):
    """
    Get the logs from trainer.
    """
    trainer_params = self.params.run_type_params.get(
        str(RunType.TRAINER), None)
    if trainer_params is None:
        return

    api_client = k8sclient.CoreV1Api()
    selector = 'app={}'.format(
        trainer_params.orchestration_params['job_name'])
    pod = None
    try:
        pods = api_client.list_namespaced_pod(
            self.params.namespace, label_selector=selector)
        pod = pods.items[0]
    except k8sclient.rest.ApiException as e:
        screen.print("Got exception: %s\n while reading pods", e)
        return

    # Tail the single trainer pod's log; its return value is the
    # trainer's exit status.
    if pod:
        return self.tail_log(pod.metadata.name, api_client)
def _save_to_store(self, checkpoint_dir):
    """
    save_to_store() uploads the policy checkpoint, gifs and videos to the
    S3 data store. It reads the checkpoint state files and uploads only
    the latest checkpoint files to S3. It is used by the trainer in Coach
    when used in the distributed mode.

    :param checkpoint_dir: Local directory containing the checkpoint state
        file and the checkpoint files to upload.
    """
    try:
        # remove lock file if it exists
        self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)

        # Acquire lock: while this zero-byte object exists in the bucket,
        # readers (see load_from_store) wait.
        self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)

        state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
        if state_file.exists():
            ckpt_state = state_file.read()
            checkpoint_file = None
            for root, dirs, files in os.walk(checkpoint_dir):
                for filename in files:
                    if filename == CheckpointStateFile.checkpoint_state_filename:
                        # Remember the state file's location; it is uploaded
                        # last, after the checkpoint files are in place.
                        checkpoint_file = (root, filename)
                        continue
                    if filename.startswith(ckpt_state.name):
                        # Upload only files that belong to the latest checkpoint.
                        abs_name = os.path.abspath(os.path.join(root, filename))
                        rel_name = os.path.relpath(abs_name, checkpoint_dir)
                        self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)

            # Upload the checkpoint state file itself.
            # NOTE(review): assumes the walk above found the state file
            # (checkpoint_file is not None); if not, this raises TypeError
            # and the lock object is left behind -- confirm.
            abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
            rel_name = os.path.relpath(abs_name, checkpoint_dir)
            self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)

        # upload Finished if present
        if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)):
            self.mc.put_object(self.params.bucket_name, SyncFiles.FINISHED.value, io.BytesIO(b''), 0)

        # upload Ready if present
        if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)):
            self.mc.put_object(self.params.bucket_name, SyncFiles.TRAINER_READY.value, io.BytesIO(b''), 0)

        # release lock
        self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)

        # Upload experiment result files (csv/json), then videos and gifs,
        # when an experiment directory was configured.
        if self.params.expt_dir and os.path.exists(self.params.expt_dir):
            for filename in os.listdir(self.params.expt_dir):
                if filename.endswith((".csv", ".json")):
                    self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, filename))

        if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'videos')):
            for filename in os.listdir(os.path.join(self.params.expt_dir, 'videos')):
                self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'videos', filename))

        if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'gifs')):
            for filename in os.listdir(os.path.join(self.params.expt_dir, 'gifs')):
                self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename))
    except S3Error as e:
        # Best effort: S3 failures are logged, never raised to the trainer.
        screen.print("Got exception: %s\n while saving to S3", e)
def load_from_store(self):
    """
    load_from_store() downloads a new checkpoint from the S3 data store
    when it is not available locally. It is used by the rollout workers
    when using Coach in distributed mode.

    Blocks until the trainer's lock object disappears from the bucket,
    fetches the checkpoint state file, the optional finished/ready
    markers, and any checkpoint files not yet present locally.
    """
    try:
        state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))

        # wait until lock is removed
        while True:
            objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value)

            if next(objects, None) is None:
                try:
                    # fetch checkpoint state file from S3
                    self.mc.fget_object(self.params.bucket_name, state_file.filename, state_file.path)
                except Exception:
                    # BUG FIX: this retry previously `continue`d straight
                    # back into the poll, hammering S3 in a tight busy
                    # loop; back off like the lock-present path does.
                    time.sleep(10)
                    continue
                break
            time.sleep(10)

        # Check if there's a finished file (best effort; marker is optional).
        objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.FINISHED.value)
        if next(objects, None) is not None:
            try:
                self.mc.fget_object(
                    self.params.bucket_name, SyncFiles.FINISHED.value,
                    os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.FINISHED.value))
                )
            except Exception:
                pass

        # Check if there's a ready file (best effort; marker is optional).
        objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.TRAINER_READY.value)
        if next(objects, None) is not None:
            try:
                self.mc.fget_object(
                    self.params.bucket_name, SyncFiles.TRAINER_READY.value,
                    os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.TRAINER_READY.value))
                )
            except Exception:
                pass

        # Download every checkpoint file named by the state file that is
        # not already present locally.
        checkpoint_state = state_file.read()
        if checkpoint_state is not None:
            objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
            for obj in objects:
                filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))
                if not os.path.exists(filename):
                    self.mc.fget_object(obj.bucket_name, obj.object_name, filename)
    except S3Error as e:
        # Best effort: S3 failures are logged, never raised to the worker.
        screen.print("Got exception: %s\n while loading from S3", e)
def __init__(self, params: KubernetesParameters):
    """
    :param params: The Kubernetes parameters which are used for deploying
        the components in Coach. These parameters include namespace and
        kubeconfig.
    """
    super().__init__(params)
    self.params = params
    # A kubeconfig implies we run outside the cluster; otherwise use the
    # in-cluster service-account configuration.
    if self.params.kubeconfig:
        k8sconfig.load_kube_config()
    else:
        k8sconfig.load_incluster_config()

    # Default the namespace from the current kubeconfig context.
    if not self.params.namespace:
        _, current_context = k8sconfig.list_kube_config_contexts()
        self.params.namespace = current_context['context']['namespace']

    # Honor an http_proxy environment variable for all API calls.
    if os.environ.get('http_proxy'):
        k8sclient.Configuration._default.proxy = os.environ.get(
            'http_proxy')

    # Point the memory backend and the data store at the same namespace,
    # then instantiate them.
    self.params.memory_backend_parameters.orchestrator_params = {
        'namespace': self.params.namespace
    }
    self.memory_backend = get_memory_backend(
        self.params.memory_backend_parameters)

    self.params.data_store_params.orchestrator_params = {
        'namespace': self.params.namespace
    }
    self.params.data_store_params.namespace = self.params.namespace
    self.data_store = get_data_store(self.params.data_store_params)

    if self.params.data_store_params.store_type == "s3":
        # Resolve S3 credentials from the creds file when given, else from
        # environment variables; these are later injected into the
        # trainer/worker pods as env vars.
        self.s3_access_key = None
        self.s3_secret_key = None
        if self.params.data_store_params.creds_file:
            s3config = ConfigParser()
            s3config.read(self.params.data_store_params.creds_file)
            try:
                self.s3_access_key = s3config.get('default', 'aws_access_key_id')
                self.s3_secret_key = s3config.get('default', 'aws_secret_access_key')
            except Error as e:
                screen.print("Error when reading S3 credentials file: %s", e)
        else:
            self.s3_access_key = os.environ.get('ACCESS_KEY_ID')
            self.s3_secret_key = os.environ.get('SECRET_ACCESS_KEY')
def create_k8s_nfs_resources(self) -> bool:
    """
    Create NFS resources such as PV and PVC in Kubernetes.

    :return: True when both the PV and the PVC were created, else False.
    """
    from kubernetes import client as k8sclient

    core_api = k8sclient.CoreV1Api()

    # Build and create the PersistentVolume backed by the NFS server.
    pv_name = "nfs-ckpt-pv-{}".format(uuid.uuid4())
    pv_spec = k8sclient.V1PersistentVolumeSpec(
        access_modes=["ReadWriteMany"],
        nfs=k8sclient.V1NFSVolumeSource(path=self.params.path,
                                        server=self.params.server),
        capacity={'storage': '10Gi'},
        storage_class_name="")
    persistent_volume = k8sclient.V1PersistentVolume(
        api_version="v1",
        kind="PersistentVolume",
        metadata=k8sclient.V1ObjectMeta(name=pv_name,
                                        labels={'app': pv_name}),
        spec=pv_spec)
    try:
        core_api.create_persistent_volume(persistent_volume)
        self.params.pv_name = pv_name
    except k8sclient.rest.ApiException as e:
        screen.print("Got exception: %s\n while creating the NFS PV", e)
        return False

    # Build and create the claim bound to that PV via its app label.
    pvc_name = "nfs-ckpt-pvc-{}".format(uuid.uuid4())
    pvc_spec = k8sclient.V1PersistentVolumeClaimSpec(
        access_modes=["ReadWriteMany"],
        resources=k8sclient.V1ResourceRequirements(
            requests={'storage': '10Gi'}),
        selector=k8sclient.V1LabelSelector(
            match_labels={'app': self.params.pv_name}),
        storage_class_name="")
    persistent_volume_claim = k8sclient.V1PersistentVolumeClaim(
        api_version="v1",
        kind="PersistentVolumeClaim",
        metadata=k8sclient.V1ObjectMeta(name=pvc_name),
        spec=pvc_spec)
    try:
        core_api.create_namespaced_persistent_volume_claim(
            self.params.namespace, persistent_volume_claim)
        self.params.pvc_name = pvc_name
    except k8sclient.rest.ApiException as e:
        screen.print("Got exception: %s\n while creating the NFS PVC", e)
        return False

    return True
def __init__(self, params: S3DataStoreParameters):
    """
    :param params: The parameters required to use the S3 data store.
    """
    super(S3DataStore, self).__init__(params)
    self.params = params

    # Resolve credentials: prefer the creds file, fall back to env vars.
    access_key, secret_key = None, None
    if params.creds_file:
        cfg = ConfigParser()
        cfg.read(params.creds_file)
        try:
            access_key = cfg.get('default', 'aws_access_key_id')
            secret_key = cfg.get('default', 'aws_secret_access_key')
        except Error as e:
            screen.print("Error when reading S3 credentials file: %s", e)
    else:
        access_key = os.environ.get('ACCESS_KEY_ID')
        secret_key = os.environ.get('SECRET_ACCESS_KEY')

    # Minio client used for all bucket operations.
    self.mc = Minio(self.params.end_point,
                    access_key=access_key,
                    secret_key=secret_key)
def tail_log(self, pod_name, corev1_api):
    """
    Stream a pod's logs to the screen until the pod stops.

    :param pod_name: Name of the pod whose logs should be tailed.
    :param corev1_api: CoreV1Api client used for the log and pod reads.
    :return: 1 when the pod is stuck in an unrecoverable waiting state,
        otherwise the terminated container's exit code. Keeps looping
        while the pod is healthy or only transiently unreadable.
    """
    while True:
        time.sleep(10)
        # Try to tail the pod logs
        try:
            # follow=True streams until the server ends the stream;
            # _preload_content=False yields raw byte lines instead of a
            # single decoded blob.
            for line in corev1_api.read_namespaced_pod_log(
                    pod_name, self.params.namespace, follow=True,
                    _preload_content=False):
                screen.print(line.decode('utf-8'), flush=True, end='')
        except k8sclient.rest.ApiException as e:
            # Pod may not be ready yet; fall through to the status check.
            pass

        # This part will get executed if the pod is one of the following phases: not ready, failed or terminated.
        # Check if the pod has errored out, else just try again.

        # Get the pod
        try:
            pod = corev1_api.read_namespaced_pod(pod_name, self.params.namespace)
        except k8sclient.rest.ApiException as e:
            continue

        if not hasattr(pod, 'status') or not pod.status:
            continue
        if not hasattr(
                pod.status,
                'container_statuses') or not pod.status.container_statuses:
            continue

        for container_status in pod.status.container_statuses:
            if container_status.state.waiting is not None:
                # Unrecoverable waiting reasons: report failure to caller.
                if container_status.state.waiting.reason == 'Error' or \
                        container_status.state.waiting.reason == 'CrashLoopBackOff' or \
                        container_status.state.waiting.reason == 'ImagePullBackOff' or \
                        container_status.state.waiting.reason == 'ErrImagePull':
                    return 1
            if container_status.state.terminated is not None:
                # Propagate the container's own exit code.
                return container_status.state.terminated.exit_code
def delete_k8s_nfs_resources(self) -> bool:
    """
    Delete NFS resources such as PV and PVC from the Kubernetes orchestrator.

    :return: True when both deletions succeed, else False.
    """
    from kubernetes import client as k8sclient

    options = k8sclient.V1DeleteOptions()
    core_api = k8sclient.CoreV1Api()

    # Remove the cluster-scoped persistent volume first.
    try:
        core_api.delete_persistent_volume(self.params.pv_name, options)
    except k8sclient.rest.ApiException as e:
        screen.print("Got exception: %s\n while deleting NFS PV", e)
        return False

    # Then drop the namespaced claim that pointed at it.
    try:
        core_api.delete_namespaced_persistent_volume_claim(
            self.params.pvc_name, self.params.namespace, options)
    except k8sclient.rest.ApiException as e:
        screen.print("Got exception: %s\n while deleting NFS PVC", e)
        return False

    return True
def undeploy_k8s_nfs(self) -> bool:
    """
    Tear down the NFS server Deployment and its Service.

    :return: True when both deletions succeed, else False.
    """
    from kubernetes import client as k8sclient

    options = k8sclient.V1DeleteOptions()

    # Delete the nfs-server Deployment.
    apps_api = k8sclient.AppsV1Api()
    try:
        apps_api.delete_namespaced_deployment(self.params.name,
                                              self.params.namespace,
                                              options)
    except k8sclient.rest.ApiException as e:
        screen.print("Got exception: %s\n while deleting nfs-server", e)
        return False

    # Delete the Service that fronted it.
    core_api = k8sclient.CoreV1Api()
    try:
        core_api.delete_namespaced_service(self.params.svc_name,
                                           self.params.namespace,
                                           options)
    except k8sclient.rest.ApiException as e:
        screen.print(
            "Got exception: %s\n while deleting the service for nfs-server",
            e)
        return False

    return True
def close(self):
    """Emit a terminating blank line on the screen."""
    screen.print("")
def __exit__(self, type, value, traceback):
    """Print the prefix and the elapsed time since __enter__ set self.start."""
    elapsed = time.time() - self.start
    screen.print(self.prefix, elapsed)
def deploy_kubernetes(self):
    """
    Deploy the Redis Pub/Sub service in Kubernetes orchestrator.

    Creates a single-replica redis Deployment and a ClusterIP Service in
    front of it, then records the in-cluster address and port on
    self.params for the other components to use.

    :return: True on success, False if either API call failed.
    """
    if 'namespace' not in self.params.orchestrator_params:
        self.params.orchestrator_params['namespace'] = "default"

    from kubernetes import client, config

    container = client.V1Container(
        name=self.redis_server_name,
        image='redis:4-alpine',
        resources=client.V1ResourceRequirements(limits={
            "cpu": "8",
            "memory": "4Gi"
            # "nvidia.com/gpu": "0",
        }),
    )
    template = client.V1PodTemplateSpec(
        metadata=client.V1ObjectMeta(
            labels={'app': self.redis_server_name}),
        spec=client.V1PodSpec(containers=[container]))
    deployment_spec = client.V1DeploymentSpec(
        replicas=1,
        template=template,
        selector=client.V1LabelSelector(
            match_labels={'app': self.redis_server_name}))
    deployment = client.V1Deployment(
        api_version='apps/v1',
        kind='Deployment',
        metadata=client.V1ObjectMeta(
            name=self.redis_server_name,
            labels={'app': self.redis_server_name}),
        spec=deployment_spec)

    # NOTE(review): this always loads the local kubeconfig; it would fail
    # in-cluster where load_incluster_config() is required -- confirm.
    config.load_kube_config()
    api_client = client.AppsV1Api()
    try:
        # BUG FIX: removed a leftover debug print that dumped the raw
        # orchestrator params dict to the screen on every deploy.
        api_client.create_namespaced_deployment(
            self.params.orchestrator_params['namespace'], deployment)
    except client.rest.ApiException as e:
        screen.print("Got exception: %s\n while creating redis-server", e)
        return False

    core_v1_api = client.CoreV1Api()
    service = client.V1Service(
        api_version='v1',
        kind='Service',
        metadata=client.V1ObjectMeta(name=self.redis_service_name),
        spec=client.V1ServiceSpec(selector={'app': self.redis_server_name},
                                  ports=[
                                      client.V1ServicePort(
                                          protocol='TCP',
                                          port=6379,
                                          target_port=6379)
                                  ]))
    try:
        core_v1_api.create_namespaced_service(
            self.params.orchestrator_params['namespace'], service)
        # Record the cluster-internal DNS name and port for clients.
        self.params.redis_address = '{}.{}.svc'.format(
            self.redis_service_name,
            self.params.orchestrator_params['namespace'])
        self.params.redis_port = 6379
        return True
    except client.rest.ApiException as e:
        screen.print(
            "Got exception: %s\n while creating a service for redis-server",
            e)
        return False
def deploy_trainer(self) -> bool:
    """
    Deploys the training worker in Kubernetes.

    Appends the (already deployed) memory-backend and data-store
    parameters to the trainer command line, then creates a
    single-completion Job whose pod template depends on the configured
    data store type (nfs / s3 / redis).

    :return: True when the Job was created, False otherwise.
    """
    trainer_params = self.params.run_type_params.get(
        str(RunType.TRAINER), None)
    if not trainer_params:
        return False

    # The containerized trainer reaches the shared services through these
    # serialized parameter blobs.
    trainer_params.command += [
        '--memory_backend_params',
        json.dumps(self.params.memory_backend_parameters.__dict__)
    ]
    trainer_params.command += [
        '--data_store_params',
        json.dumps(self.params.data_store_params.__dict__)
    ]

    name = "{}-{}".format(trainer_params.run_type, uuid.uuid4())

    # TODO: instead of defining each container and template spec from scratch, loaded default
    # configuration and modify them as necessary depending on the store type
    if self.params.data_store_params.store_type == "nfs":
        # NFS: mount the shared checkpoint PVC into the container.
        container = k8sclient.V1Container(
            name=name,
            image=trainer_params.image,
            command=trainer_params.command,
            args=trainer_params.arguments,
            image_pull_policy='Always',
            volume_mounts=[
                k8sclient.V1VolumeMount(
                    name='nfs-pvc',
                    mount_path=trainer_params.checkpoint_dir)
            ],
            stdin=True,
            tty=True)
        template = k8sclient.V1PodTemplateSpec(
            metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
            spec=k8sclient.V1PodSpec(
                containers=[container],
                volumes=[
                    k8sclient.V1Volume(
                        name="nfs-pvc",
                        persistent_volume_claim=self.nfs_pvc)
                ],
                restart_policy='Never'),
        )
    elif self.params.data_store_params.store_type == "s3":
        # S3: pass the credentials resolved in __init__ as env vars.
        container = k8sclient.V1Container(
            name=name,
            image=trainer_params.image,
            command=trainer_params.command,
            args=trainer_params.arguments,
            image_pull_policy='Always',
            env=[
                k8sclient.V1EnvVar("ACCESS_KEY_ID", self.s3_access_key),
                k8sclient.V1EnvVar("SECRET_ACCESS_KEY", self.s3_secret_key)
            ],
            stdin=True,
            tty=True)
        template = k8sclient.V1PodTemplateSpec(
            metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
            spec=k8sclient.V1PodSpec(containers=[container],
                                     restart_policy='Never'),
        )
    elif self.params.data_store_params.store_type == "redis":
        # Redis: no volumes or creds; request one GPU for training.
        container = k8sclient.V1Container(
            name=name,
            image=trainer_params.image,
            command=trainer_params.command,
            args=trainer_params.arguments,
            image_pull_policy='Always',
            stdin=True,
            tty=True,
            resources=k8sclient.V1ResourceRequirements(
                limits={
                    "cpu": "24",
                    "memory": "4Gi",
                    "nvidia.com/gpu": "1",
                }),
        )
        template = k8sclient.V1PodTemplateSpec(
            metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
            spec=k8sclient.V1PodSpec(containers=[container],
                                     restart_policy='Never'),
        )
    else:
        raise ValueError(
            "unexpected store_type {}. expected 's3', 'nfs', 'redis'".
            format(self.params.data_store_params.store_type))

    # Exactly one trainer completion per experiment.
    job_spec = k8sclient.V1JobSpec(completions=1, template=template)
    job = k8sclient.V1Job(api_version="batch/v1",
                          kind="Job",
                          metadata=k8sclient.V1ObjectMeta(name=name),
                          spec=job_spec)

    api_client = k8sclient.BatchV1Api()
    try:
        api_client.create_namespaced_job(self.params.namespace, job)
        # Remember the job name so logs/undeploy can find it later.
        trainer_params.orchestration_params['job_name'] = name
        return True
    except k8sclient.rest.ApiException as e:
        screen.print("Got exception: %s\n while creating job", e)
        return False
def handle_distributed_coach_orchestrator(args):
    """
    Run distributed Coach on Kubernetes: deploy the trainer and rollout
    worker(s), optionally dump worker logs, stream the trainer's logs
    until it finishes, then undeploy everything.

    :param args: Parsed Coach command-line arguments.
    :return: The trainer's exit code, or 1 on any setup/deploy failure.
    """
    from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, \
        RunTypeParameters

    ckpt_inside_container = "/checkpoint"
    arg_list = sys.argv[1:]
    # Strip the run-type flag and its value; each launched process gets
    # its own value appended below.
    try:
        i = arg_list.index('--distributed_coach_run_type')
        arg_list.pop(i)
        arg_list.pop(i)
    except ValueError:
        pass

    trainer_command = [
        'python3', 'rl_coach/coach.py', '--distributed_coach_run_type',
        str(RunType.TRAINER)
    ] + arg_list
    rollout_command = [
        'python3', 'rl_coach/coach.py', '--distributed_coach_run_type',
        str(RunType.ROLLOUT_WORKER)
    ] + arg_list

    # Propagate the experiment name so trainer and workers share it.
    if '--experiment_name' not in rollout_command:
        rollout_command = rollout_command + [
            '--experiment_name', args.experiment_name
        ]
    if '--experiment_name' not in trainer_command:
        trainer_command = trainer_command + [
            '--experiment_name', args.experiment_name
        ]

    memory_backend_params = None
    if args.memory_backend == "redispubsub":
        memory_backend_params = RedisPubSubMemoryBackendParameters()

    ds_params_instance = None
    if args.data_store == "s3":
        ds_params = DataStoreParameters("s3", "", "")
        ds_params_instance = S3DataStoreParameters(
            ds_params=ds_params,
            end_point=args.s3_end_point,
            bucket_name=args.s3_bucket_name,
            creds_file=args.s3_creds_file,
            checkpoint_dir=ckpt_inside_container,
            expt_dir=args.experiment_path)
    elif args.data_store == "nfs":
        ds_params = DataStoreParameters("nfs", "kubernetes", "")
        ds_params_instance = NFSDataStoreParameters(ds_params)
    elif args.data_store == "redis":
        ds_params = DataStoreParameters("redis", "kubernetes", "")
        ds_params_instance = RedisDataStoreParameters(ds_params)
    else:
        # BUG FIX: the old message read "data_store {} found. Expected
        # 's3' or 'nfs'" -- it said "found" for the failure case and
        # omitted the supported 'redis' store.
        raise ValueError(
            "Unexpected data_store {}. Expected 's3', 'nfs' or 'redis'".
            format(args.data_store))

    worker_run_type_params = RunTypeParameters(args.image,
                                               rollout_command,
                                               run_type=str(
                                                   RunType.ROLLOUT_WORKER),
                                               num_replicas=args.num_workers)
    trainer_run_type_params = RunTypeParameters(args.image,
                                                trainer_command,
                                                run_type=str(RunType.TRAINER))

    orchestration_params = KubernetesParameters(
        [worker_run_type_params, trainer_run_type_params],
        kubeconfig='~/.kube/config',
        memory_backend_parameters=memory_backend_params,
        data_store_params=ds_params_instance)
    orchestrator = Kubernetes(orchestration_params)
    if not orchestrator.setup(args.checkpoint_restore_dir):
        screen.print("Could not setup.")
        return 1

    if orchestrator.deploy_trainer():
        screen.print("Successfully deployed trainer.")
    else:
        screen.print("Could not deploy trainer.")
        return 1

    if orchestrator.deploy_worker():
        screen.print("Successfully deployed rollout worker(s).")
    else:
        screen.print("Could not deploy rollout worker(s).")
        return 1

    if args.dump_worker_logs:
        screen.log_title("Dumping rollout worker logs in: {}".format(
            args.experiment_path))
        orchestrator.worker_logs(path=args.experiment_path)

    # Block on the trainer's logs; its exit code becomes ours. A Ctrl-C
    # simply falls through to the teardown below.
    exit_code = 1
    try:
        exit_code = orchestrator.trainer_logs()
    except KeyboardInterrupt:
        pass

    orchestrator.undeploy()
    return exit_code
def deploy_k8s_nfs(self) -> bool:
    """
    Deploy the NFS server in the Kubernetes orchestrator.

    Creates a privileged nfs-server Deployment exporting a host-path
    directory, plus a Service in front of it, and records the generated
    names and the service's cluster IP on self.params.

    :return: True when both the Deployment and the Service were created.
    """
    from kubernetes import client as k8sclient

    name = "nfs-server-{}".format(uuid.uuid4())
    container = k8sclient.V1Container(
        name=name,
        image="k8s.gcr.io/volume-nfs:0.8",
        ports=[
            k8sclient.V1ContainerPort(name="nfs",
                                      container_port=2049,
                                      protocol="TCP"),
            k8sclient.V1ContainerPort(name="rpcbind", container_port=111),
            k8sclient.V1ContainerPort(name="mountd", container_port=20048),
        ],
        volume_mounts=[
            k8sclient.V1VolumeMount(name='nfs-host-path',
                                    mount_path='/exports')
        ],
        # The volume-nfs image requires privileged mode to export volumes.
        security_context=k8sclient.V1SecurityContext(privileged=True))
    template = k8sclient.V1PodTemplateSpec(
        metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
        spec=k8sclient.V1PodSpec(
            containers=[container],
            volumes=[
                # Back the export with a unique host-path directory.
                k8sclient.V1Volume(
                    name="nfs-host-path",
                    host_path=k8sclient.V1HostPathVolumeSource(
                        path='/tmp/nfsexports-{}'.format(uuid.uuid4())))
            ]))
    deployment_spec = k8sclient.V1DeploymentSpec(
        replicas=1,
        template=template,
        selector=k8sclient.V1LabelSelector(match_labels={'app': name}))
    deployment = k8sclient.V1Deployment(api_version='apps/v1',
                                        kind='Deployment',
                                        metadata=k8sclient.V1ObjectMeta(
                                            name=name, labels={'app': name}),
                                        spec=deployment_spec)

    k8s_apps_v1_api_client = k8sclient.AppsV1Api()
    try:
        k8s_apps_v1_api_client.create_namespaced_deployment(
            self.params.namespace, deployment)
        self.params.name = name
    except k8sclient.rest.ApiException as e:
        screen.print("Got exception: %s\n while creating nfs-server", e)
        return False

    k8s_core_v1_api_client = k8sclient.CoreV1Api()
    svc_name = "nfs-service-{}".format(uuid.uuid4())
    service = k8sclient.V1Service(
        api_version='v1',
        kind='Service',
        metadata=k8sclient.V1ObjectMeta(name=svc_name),
        spec=k8sclient.V1ServiceSpec(selector={'app': self.params.name},
                                     ports=[
                                         k8sclient.V1ServicePort(
                                             protocol='TCP',
                                             port=2049,
                                             target_port=2049)
                                     ]))
    try:
        svc_response = k8s_core_v1_api_client.create_namespaced_service(
            self.params.namespace, service)
        self.params.svc_name = svc_name
        # The cluster IP is what the PV's nfs.server field points at.
        self.params.server = svc_response.spec.cluster_ip
    except k8sclient.rest.ApiException as e:
        screen.print(
            "Got exception: %s\n while creating a service for nfs-server",
            e)
        return False

    return True
def deploy_worker(self):
    """
    Deploys the rollout worker(s) in Kubernetes.

    Appends the shared-service parameters and worker count to the worker
    command line, then creates a Job with `num_replicas` parallel
    completions whose pod template depends on the data store type.

    :return: True when the Job was created, False otherwise.
    """
    worker_params = self.params.run_type_params.get(
        str(RunType.ROLLOUT_WORKER), None)
    if not worker_params:
        return False

    # At this point, the memory backend and data store have been deployed and in the process,
    # these parameters have been updated to include things like the hostname and port the
    # service can be found at.
    worker_params.command += [
        '--memory_backend_params',
        json.dumps(self.params.memory_backend_parameters.__dict__)
    ]
    worker_params.command += [
        '--data_store_params',
        json.dumps(self.params.data_store_params.__dict__)
    ]
    worker_params.command += [
        '--num_workers', '{}'.format(worker_params.num_replicas)
    ]

    name = "{}-{}".format(worker_params.run_type, uuid.uuid4())

    # TODO: instead of defining each container and template spec from scratch, loaded default
    # configuration and modify them as necessary depending on the store type
    if self.params.data_store_params.store_type == "nfs":
        # NFS: mount the shared checkpoint PVC into each worker container.
        container = k8sclient.V1Container(
            name=name,
            image=worker_params.image,
            command=worker_params.command,
            args=worker_params.arguments,
            image_pull_policy='Always',
            volume_mounts=[
                k8sclient.V1VolumeMount(
                    name='nfs-pvc',
                    mount_path=worker_params.checkpoint_dir)
            ],
            stdin=True,
            tty=True)
        template = k8sclient.V1PodTemplateSpec(
            metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
            spec=k8sclient.V1PodSpec(
                containers=[container],
                volumes=[
                    k8sclient.V1Volume(
                        name="nfs-pvc",
                        persistent_volume_claim=self.nfs_pvc)
                ],
                restart_policy='Never'),
        )
    elif self.params.data_store_params.store_type == "s3":
        # S3: pass the credentials resolved in __init__ as env vars.
        container = k8sclient.V1Container(
            name=name,
            image=worker_params.image,
            command=worker_params.command,
            args=worker_params.arguments,
            image_pull_policy='Always',
            env=[
                k8sclient.V1EnvVar("ACCESS_KEY_ID", self.s3_access_key),
                k8sclient.V1EnvVar("SECRET_ACCESS_KEY", self.s3_secret_key)
            ],
            stdin=True,
            tty=True)
        template = k8sclient.V1PodTemplateSpec(
            metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
            spec=k8sclient.V1PodSpec(containers=[container],
                                     restart_policy='Never'))
    elif self.params.data_store_params.store_type == "redis":
        # Redis: CPU-only workers, no volumes or credentials needed.
        container = k8sclient.V1Container(
            name=name,
            image=worker_params.image,
            command=worker_params.command,
            args=worker_params.arguments,
            image_pull_policy='Always',
            stdin=True,
            tty=True,
            resources=k8sclient.V1ResourceRequirements(limits={
                "cpu": "4",
                "memory": "4Gi",
                # "nvidia.com/gpu": "0",
            }),
        )
        template = k8sclient.V1PodTemplateSpec(
            metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
            spec=k8sclient.V1PodSpec(containers=[container],
                                     restart_policy='Never'))
    else:
        raise ValueError('unexpected store type {}'.format(
            self.params.data_store_params.store_type))

    # Run all replicas in parallel; the Job completes when each finishes.
    job_spec = k8sclient.V1JobSpec(completions=worker_params.num_replicas,
                                   parallelism=worker_params.num_replicas,
                                   template=template)
    job = k8sclient.V1Job(api_version="batch/v1",
                          kind="Job",
                          metadata=k8sclient.V1ObjectMeta(name=name),
                          spec=job_spec)

    api_client = k8sclient.BatchV1Api()
    try:
        api_client.create_namespaced_job(self.params.namespace, job)
        # Remember the job name so logs/undeploy can find it later.
        worker_params.orchestration_params['job_name'] = name
        return True
    except k8sclient.rest.ApiException as e:
        screen.print("Got exception: %s\n while creating Job", e)
        return False