def generate_deployment_spec(self, pod_template_spec):
    """Returns a TFJob template

    A ``Worker`` replica spec is always included, sized by
    ``self.distribution['Worker']``. ``Chief`` and ``PS`` replica specs are
    added only when ``self.distribution`` gives them a count greater than
    zero. Every replica spec reuses the same ``pod_template_spec`` object.

    :param pod_template_spec: template spec for pod
    :returns: a ``V1TFJob`` whose metadata uses ``self.job_name`` as
        ``generate_name`` and ``self.labels`` as its labels
    """
    self.set_container_name(pod_template_spec)

    # 'Worker' is indexed directly (not via .get), so a distribution without
    # a 'Worker' entry fails here instead of producing a job with no workers.
    tf_replica_specs = {
        "Worker": V1ReplicaSpec(replicas=self.distribution['Worker'],
                                template=pod_template_spec),
    }

    # Optional roles: a missing entry is treated as a count of 0 and skipped.
    for role in ("Chief", "PS"):
        replicas = self.distribution.get(role, 0)
        if replicas > 0:
            tf_replica_specs[role] = V1ReplicaSpec(replicas=replicas,
                                                   template=pod_template_spec)

    return V1TFJob(
        api_version=constants.TF_JOB_GROUP + "/" + constants.TF_JOB_VERSION,
        kind=constants.TF_JOB_KIND,
        metadata=k8s_client.V1ObjectMeta(generate_name=self.job_name,
                                         labels=self.labels),
        spec=V1TFJobSpec(tf_replica_specs=tf_replica_specs))
def test_sdk_e2e():
    """End-to-end SDK check: create a one-worker TFJob and require success.

    The job is created in ``SDK_TEST_NAMESPACE`` through ``TFJOB_CLIENT``,
    waited on with ``wait_for_job`` and checked with ``if_job_succeeded``.

    :raises RuntimeError: if the client reports that the job did not succeed
    """
    job_name = "mnist-ci-test"

    container = V1Container(
        name="tensorflow",
        image="gcr.io/kubeflow-ci/tf-mnist-with-summaries:1.0",
        command=[
            "python",
            "/var/tf_mnist/mnist_with_summaries.py",
            "--log_dir=/train/logs",
            "--learning_rate=0.01",
            "--batch_size=150",
        ])

    worker = V1ReplicaSpec(
        replicas=1,
        restart_policy="Never",
        template=V1PodTemplateSpec(spec=V1PodSpec(containers=[container])))

    tfjob = V1TFJob(
        api_version="kubeflow.org/v1",
        kind="TFJob",
        metadata=V1ObjectMeta(name=job_name, namespace=SDK_TEST_NAMESPACE),
        spec=V1TFJobSpec(clean_pod_policy="None",
                         tf_replica_specs={"Worker": worker}))

    # Create outside the try block: if creation fails there is nothing to
    # delete, and a failing delete would only obscure the original error.
    TFJOB_CLIENT.create(tfjob, namespace=SDK_TEST_NAMESPACE)
    try:
        TFJOB_CLIENT.wait_for_job(job_name, namespace=SDK_TEST_NAMESPACE)
        if not TFJOB_CLIENT.if_job_succeeded(job_name,
                                             namespace=SDK_TEST_NAMESPACE):
            raise RuntimeError("The TFJob is not succeeded.")
    finally:
        # The job name is fixed, so always remove the job, including when the
        # wait raises or the job fails, so it does not linger in the namespace.
        TFJOB_CLIENT.delete(job_name, namespace=SDK_TEST_NAMESPACE)