Beispiel #1
0
    def generate_deployment_spec(self, pod_template_spec):
        """Returns a TFJob template built from ``self.distribution``.

        A ``Worker`` replica spec is always emitted. ``Chief`` and ``PS``
        replica specs are added only when their count in
        ``self.distribution`` is greater than zero. All replica types share
        the same ``pod_template_spec`` object.

        :param pod_template_spec: template spec for pod; its container name is
            set via ``self.set_container_name`` before it is used
        :returns: a ``V1TFJob`` whose metadata is built with
            ``generate_name=self.job_name`` and ``labels=self.labels``
        :raises KeyError: if ``self.distribution`` has no ``'Worker'`` entry
            (assuming it is a dict-like mapping)

        """
        self.set_container_name(pod_template_spec)

        # Worker is mandatory: index directly so a missing count fails loudly.
        tf_replica_specs = {
            "Worker": V1ReplicaSpec(replicas=self.distribution['Worker'],
                                    template=pod_template_spec),
        }

        # Optional roles are emitted only when at least one replica is
        # requested; iteration order keeps Worker, Chief, PS key order.
        for role in ("Chief", "PS"):
            count = self.distribution.get(role, 0)
            if count > 0:
                tf_replica_specs[role] = V1ReplicaSpec(
                    replicas=count, template=pod_template_spec)

        api_version = constants.TF_JOB_GROUP + "/" + constants.TF_JOB_VERSION
        return V1TFJob(api_version=api_version,
                       kind=constants.TF_JOB_KIND,
                       metadata=k8s_client.V1ObjectMeta(
                           generate_name=self.job_name, labels=self.labels),
                       spec=V1TFJobSpec(tf_replica_specs=tf_replica_specs))
Beispiel #2
0
def test_sdk_e2e():
    """End-to-end check: run the MNIST TFJob and require that it succeeds.

    Creates a single-worker TFJob named ``mnist-ci-test`` in
    ``SDK_TEST_NAMESPACE``, waits for it, and raises ``RuntimeError`` if it
    did not succeed. The job is deleted afterwards whether or not it
    succeeded, so a failed or timed-out run does not leave a job with this
    fixed name behind.
    """
    job_name = "mnist-ci-test"

    container = V1Container(
        name="tensorflow",
        image="gcr.io/kubeflow-ci/tf-mnist-with-summaries:1.0",
        command=[
            "python", "/var/tf_mnist/mnist_with_summaries.py",
            "--log_dir=/train/logs", "--learning_rate=0.01", "--batch_size=150"
        ])

    worker = V1ReplicaSpec(
        replicas=1,
        restart_policy="Never",
        template=V1PodTemplateSpec(spec=V1PodSpec(containers=[container])))

    tfjob = V1TFJob(api_version="kubeflow.org/v1",
                    kind="TFJob",
                    metadata=V1ObjectMeta(name=job_name,
                                          namespace=SDK_TEST_NAMESPACE),
                    spec=V1TFJobSpec(clean_pod_policy="None",
                                     tf_replica_specs={"Worker": worker}))

    TFJOB_CLIENT.create(tfjob, namespace=SDK_TEST_NAMESPACE)
    try:
        TFJOB_CLIENT.wait_for_job(job_name, namespace=SDK_TEST_NAMESPACE)
        if not TFJOB_CLIENT.if_job_succeeded(job_name,
                                             namespace=SDK_TEST_NAMESPACE):
            raise RuntimeError("The TFJob is not succeeded.")
    finally:
        # Always clean up: the job name is fixed, so a leftover job from a
        # failed run could conflict with the next run's create().
        TFJOB_CLIENT.delete(job_name, namespace=SDK_TEST_NAMESPACE)