Example #1
    def run_simple_tfjob(self, component):
        api_client = k8s_client.ApiClient()

        # Setup the ksonnet app
        ks_util.setup_ks_app(self.app_dir, self.env, self.namespace, component,
                             self.params)

        # Create the TF job
        util.run(["ks", "apply", self.env, "-c", component], cwd=self.app_dir)
        logging.info("Created job %s in namespaces %s", self.name,
                     self.namespace)

        # Wait for the job to either be in Running state or a terminal state
        logging.info("Wait for conditions Running, Succeeded, or Failed")
        results = tf_job_client.wait_for_condition(
            api_client,
            self.namespace,
            self.name, ["Running", "Succeeded", "Failed"],
            status_callback=tf_job_client.log_status)
        logging.info("Current TFJob:\n %s", json.dumps(results, indent=2))

        # Wait for the job to complete.
        logging.info("Waiting for job to finish.")
        results = tf_job_client.wait_for_job(
            api_client,
            self.namespace,
            self.name,
            self.tfjob_version,
            status_callback=tf_job_client.log_status)
        logging.info("Final TFJob:\n %s", json.dumps(results, indent=2))

        if not tf_job_client.job_succeeded(results):
            self.failure = "Job {0} in namespace {1} in status {2}".format(
                self.name, self.namespace, results.get("status", {}))
            logging.error(self.failure)
            return

        # Check for creation failures.
        creation_failures = tf_job_client.get_creation_failures_from_tfjob(
            api_client, self.namespace, results)
        if creation_failures:
            # TODO(jlewi): Starting with
            # https://github.com/kubeflow/tf-operator/pull/646 the number of events
            # no longer seems to match the expected; it looks like maybe events
            # are being combined? For now we just log a warning rather than an
            # error.
            logging.warning(creation_failures)

        # Delete the TFJob.
        tf_job_client.delete_tf_job(api_client,
                                    self.namespace,
                                    self.name,
                                    version=self.tfjob_version)
        logging.info("Waiting for job %s in namespaces %s to be deleted.",
                     self.name, self.namespace)
        tf_job_client.wait_for_delete(api_client,
                                      self.namespace,
                                      self.name,
                                      self.tfjob_version,
                                      status_callback=tf_job_client.log_status)
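These snippets omit their imports. A minimal sketch of what Example #1 appears to rely on is shown here; the kubeflow module paths are assumptions based on the kubeflow/tf-operator test harness, not confirmed by the excerpt:

    # Hedged sketch only; the kubeflow module paths below are assumed, not confirmed.
    import json
    import logging

    from kubernetes import client as k8s_client       # official Kubernetes Python client

    from kubeflow.testing import ks_util, util         # assumed location of the ks/exec helpers
    from kubeflow.tf_operator import tf_job_client     # assumed location of the TFJob helper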
Example #2
    def run_tfjob_with_shutdown_policy(self, component, shutdown_policy):
        api_client = k8s_client.ApiClient()

        # Setup the ksonnet app
        ks_util.setup_ks_app(self.app_dir, self.env, self.namespace, component,
                             self.params)

        # Create the TF job
        util.run(["ks", "apply", self.env, "-c", component], cwd=self.app_dir)
        logging.info("Created job %s in namespaces %s", self.name,
                     self.namespace)

        # Wait for the job to either be in Running state or a terminal state
        logging.info("Wait for conditions Running, Succeeded, or Failed")
        results = tf_job_client.wait_for_condition(
            api_client,
            self.namespace,
            self.name, ["Running", "Succeeded", "Failed"],
            version=self.tfjob_version,
            status_callback=tf_job_client.log_status)
        logging.info("Current TFJob:\n %s", json.dumps(results, indent=2))

        if shutdown_policy == "worker":
            tf_job_client.terminate_replicas(api_client, self.namespace,
                                             self.name, "worker", 1)
        else:
            tf_job_client.terminate_replicas(api_client, self.namespace,
                                             self.name, "chief", 1)

        # Wait for the job to complete.
        logging.info("Waiting for job to finish.")
        results = tf_job_client.wait_for_job(
            api_client,
            self.namespace,
            self.name,
            self.tfjob_version,
            status_callback=tf_job_client.log_status)
        logging.info("Final TFJob:\n %s", json.dumps(results, indent=2))

        if not tf_job_client.job_succeeded(results):
            self.failure = "Job {0} in namespace {1} in status {2}".format(
                self.name, self.namespace, results.get("status", {}))
            logging.error(self.failure)
            return

        # Delete the TFJob.
        tf_job_client.delete_tf_job(api_client,
                                    self.namespace,
                                    self.name,
                                    version=self.tfjob_version)
        logging.info("Waiting for job %s in namespaces %s to be deleted.",
                     self.name, self.namespace)
        tf_job_client.wait_for_delete(api_client,
                                      self.namespace,
                                      self.name,
                                      self.tfjob_version,
                                      status_callback=tf_job_client.log_status)
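A hedged usage sketch of the method above, called from another test method; the component names are placeholders, not taken from the excerpt. Note that any shutdown_policy value other than "worker" falls through to terminating the chief.

    # Hypothetical invocation; the component names are placeholders.
    self.run_tfjob_with_shutdown_policy("worker_shutdown", "worker")   # terminate one worker
    self.run_tfjob_with_shutdown_policy("chief_shutdown", "chief")     # terminate the chief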
Example #3
    def run_distributed_training_job(self, component):
        api_client = k8s_client.ApiClient()

        # Setup the ksonnet app
        ks_util.setup_ks_app(self.app_dir, self.env, self.namespace, component,
                             self.params)

        # Create the TF job
        util.run(["ks", "apply", self.env, "-c", component], cwd=self.app_dir)
        logging.info("Created job %s in namespaces %s", self.name,
                     self.namespace)

        # Wait for the job to either be in Running state or a terminal state
        logging.info("Wait for conditions Running, Succeeded, or Failed")
        results = tf_job_client.wait_for_condition(
            api_client,
            self.namespace,
            self.name, ["Running", "Succeeded", "Failed"],
            version=self.tfjob_version,
            status_callback=tf_job_client.log_status)
        logging.info("Current TFJob:\n %s", json.dumps(results, indent=2))

        # Wait for the job to complete.
        logging.info("Waiting for job to finish.")
        results = tf_job_client.wait_for_job(
            api_client,
            self.namespace,
            self.name,
            self.tfjob_version,
            status_callback=tf_job_client.log_status)
        logging.info("Final TFJob:\n %s", json.dumps(results, indent=2))

        if not tf_job_client.job_succeeded(results):
            self.failure = "Job {0} in namespace {1} in status {2}".format(
                self.name, self.namespace, results.get("status", {}))
            logging.error(self.failure)
            return

        # Check for creation failures.
        creation_failures = tf_job_client.get_creation_failures_from_tfjob(
            api_client, self.namespace, results)
        if creation_failures:
            logging.warning(creation_failures)

        # Delete the TFJob.
        tf_job_client.delete_tf_job(api_client,
                                    self.namespace,
                                    self.name,
                                    version=self.tfjob_version)
        logging.info("Waiting for job %s in namespaces %s to be deleted.",
                     self.name, self.namespace)
        tf_job_client.wait_for_delete(api_client,
                                      self.namespace,
                                      self.name,
                                      self.tfjob_version,
                                      status_callback=tf_job_client.log_status)
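The success check above delegates to tf_job_client.job_succeeded. A rough sketch of what such a check presumably does, based only on the condition structure visible in these examples (an assumption, not the library's actual implementation):

    def job_succeeded_sketch(results):
        # Assumed behaviour: succeed only if a Succeeded condition with status "True" exists.
        conditions = results.get("status", {}).get("conditions", []) or []
        return any(c.get("type") == "Succeeded" and c.get("status") == "True"
                   for c in conditions)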
Example #4
    def test_invalid_tfjob_spec(self):
        api_client = k8s_client.ApiClient()
        component = INVALID_TFJOB_COMPONENT_NAME + "_" + self.tfjob_version

        # Setup the ksonnet app
        ks_util.setup_ks_app(self.app_dir, self.env, self.namespace, component,
                             self.params)

        # Create the TF job
        util.run(["ks", "apply", self.env, "-c", component], cwd=self.app_dir)
        logging.info("Created job %s in namespaces %s", self.name,
                     self.namespace)

        logging.info("Wait for conditions Failed")
        results = tf_job_client.wait_for_condition(
            api_client,
            self.namespace,
            self.name, ["Failed"],
            version=self.tfjob_version,
            status_callback=tf_job_client.log_status)

        logging.info("Final TFJob:\n %s", json.dumps(results, indent=2))

        # Verify that the job's final condition is Failed.
        last_condition = results.get("status", {}).get("conditions", [])[-1]
        if last_condition.get("type", "").lower() != "failed":
            self.failure = "Job {0} in namespace {1} did not fail; status {2}".format(
                self.name, self.namespace, results.get("status", {}))
            logging.error(self.failure)
            return

        pattern = ".*the spec is invalid.*"
        condition_message = last_condition.get("message", "")
        if not re.match(pattern, condition_message):
            self.failure = "Condition message {0} did not match pattern {1}".format(
                condition_message, pattern)
            logging.error(self.failure)

        # Delete the TFJob.
        tf_job_client.delete_tf_job(api_client,
                                    self.namespace,
                                    self.name,
                                    version=self.tfjob_version)
        logging.info("Waiting for job %s in namespaces %s to be deleted.",
                     self.name, self.namespace)
        tf_job_client.wait_for_delete(api_client,
                                      self.namespace,
                                      self.name,
                                      self.tfjob_version,
                                      status_callback=tf_job_client.log_status)
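Illustrative only: a hypothetical condition that the regular-expression check above would accept. The message text is invented for illustration and is not taken from the operator.

    # Hypothetical condition; the message text is made up.
    import re

    sample_condition = {
        "type": "Failed",
        "message": "TFJob my-job is invalid: the spec is invalid because no replicas are defined",
    }
    assert re.match(".*the spec is invalid.*", sample_condition.get("message", ""))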
Example #5
  def test_tfjob_and_verify_runconfig(self):
    api_client = k8s_client.ApiClient()
    masterHost = api_client.configuration.host

    # Setup the ksonnet app
    ks_util.setup_ks_app(self.app_dir, self.env, self.namespace, COMPONENT_NAME, self.params)

    # Create the TF job
    util.run(["ks", "apply", self.env, "-c", COMPONENT_NAME], cwd=self.app_dir)
    logging.info("Created job %s in namespaces %s", self.name, self.namespace)

    # Wait for the job to either be in Running state or a terminal state
    logging.info("Wait for conditions Running, Succeeded, or Failed")
    results = tf_job_client.wait_for_condition(
      api_client, self.namespace, self.name, ["Running", "Succeeded", "Failed"],
      status_callback=tf_job_client.log_status)
    logging.info("Current TFJob:\n %s", json.dumps(results, indent=2))

    num_ps = results.get("spec", {}).get("tfReplicaSpecs", {}).get(
      "PS", {}).get("replicas", 0)
    num_workers = results.get("spec", {}).get("tfReplicaSpecs", {}).get(
      "Worker", {}).get("replicas", 0)
    verify_runconfig(masterHost, self.namespace, self.name, "chief", num_ps, num_workers)
    verify_runconfig(masterHost, self.namespace, self.name, "worker", num_ps, num_workers)
    verify_runconfig(masterHost, self.namespace, self.name, "ps", num_ps, num_workers)

    tf_job_client.terminate_replicas(api_client, self.namespace, self.name, "chief", 1)

    # Wait for the job to complete.
    logging.info("Waiting for job to finish.")
    results = tf_job_client.wait_for_job(
      api_client, self.namespace, self.name, self.tfjob_version,
      status_callback=tf_job_client.log_status)
    logging.info("Final TFJob:\n %s", json.dumps(results, indent=2))

    if not tf_job_client.job_succeeded(results):
      self.failure = "Job {0} in namespace {1} in status {2}".format(
        self.name, self.namespace, results.get("status", {}))
      logging.error(self.failure)

    # Delete the TFJob.
    tf_job_client.delete_tf_job(api_client, self.namespace, self.name, version=self.tfjob_version)
    logging.info("Waiting for job %s in namespaces %s to be deleted.", self.name,
                 self.namespace)
    tf_job_client.wait_for_delete(
      api_client, self.namespace, self.name, self.tfjob_version,
      status_callback=tf_job_client.log_status)
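For reference, the replica counts read above come from a TFJob spec shaped roughly like the following; the counts are illustrative only.

    # Illustrative spec fragment; the replica counts are made up.
    sample_spec = {
        "tfReplicaSpecs": {
            "Chief": {"replicas": 1},
            "PS": {"replicas": 2},
            "Worker": {"replicas": 4},
        }
    }
    num_ps = sample_spec.get("tfReplicaSpecs", {}).get("PS", {}).get("replicas", 0)            # 2
    num_workers = sample_spec.get("tfReplicaSpecs", {}).get("Worker", {}).get("replicas", 0)   # 4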
Example #6
def run_test(args):  # pylint: disable=too-many-branches,too-many-statements
    """Run a test."""
    gcs_client = storage.Client(project=args.project)
    project = args.project
    cluster_name = args.cluster
    zone = args.zone
    # TODO(jlewi): When using GKE we should copy the .kube config and any other
    # files to the test directory. We should then set the environment variable
    # KUBECONFIG to point at that file. This should prevent us from having
    # to rerun util.configure_kubectl on each step. Instead we could run it once
    # as part of GKE cluster creation and store the config in the NFS directory.
    # This would make the handling of credentials
    # and KUBECONFIG more consistent between GKE and minikube and eventually
    # this could be extended to other K8s deployments.
    if cluster_name:
        util.configure_kubectl(project, zone, cluster_name)
    util.load_kube_config()

    api_client = k8s_client.ApiClient()
    masterHost = api_client.configuration.host

    t = test_util.TestCase()
    t.class_name = "tfjob_test"
    namespace, name, env = _setup_ks_app(args)
    t.name = os.path.basename(name)

    start = time.time()

    try:  # pylint: disable=too-many-nested-blocks
        # We repeat the test multiple times.
        # This ensures that if we delete the job we can create a new job with the
        # same name.

        # TODO(jlewi): We should make this an argument.
        num_trials = 2

        for trial in range(num_trials):
            logging.info("Trial %s", trial)
            util.run(["ks", "apply", env, "-c", args.component],
                     cwd=args.app_dir)

            logging.info("Created job %s in namespaces %s", name, namespace)
            logging.info("tfjob_version=%s", args.tfjob_version)
            # Wait for the job to either be in Running state or a terminal state
            if args.tfjob_version == "v1alpha1":
                logging.info("Wait for Phase Running, Done, or Failed")
                results = tf_job_client.wait_for_phase(
                    api_client,
                    namespace,
                    name, ["Running", "Done", "Failed"],
                    status_callback=tf_job_client.log_status)
            else:
                logging.info(
                    "Wait for conditions Running, Succeeded, or Failed")
                results = tf_job_client.wait_for_condition(
                    api_client,
                    namespace,
                    name, ["Running", "Succeeded", "Failed"],
                    status_callback=tf_job_client.log_status)

            logging.info("Current TFJob:\n %s", json.dumps(results, indent=2))

            # The job is now either running or done.
            if args.shutdown_policy:
                logging.info("Enforcing shutdownPolicy %s",
                             args.shutdown_policy)
                if args.shutdown_policy in ["master", "chief"]:
                    if args.tfjob_version == "v1alpha1":
                        replica = "master"
                    else:
                        replica = "chief"
                elif args.shutdown_policy in ["worker", "all_workers"]:
                    replica = "worker"
                else:
                    raise ValueError("Unrecognized shutdown_policy "
                                     "%s" % args.shutdown_policy)

                # Number of targets.
                num_targets = 1
                if args.shutdown_policy in ["all_workers"]:
                    # Assume v1alpha2
                    num_targets = results.get("spec", {}).get(
                        "tfReplicaSpecs", {}).get("Worker",
                                                  {}).get("replicas", 0)
                    logging.info("There are %s worker replicas", num_targets)

                if args.tfjob_version == "v1alpha1":
                    runtime_id = results.get("spec", {}).get("RuntimeId")
                    target = "{name}-{replica}-{runtime}".format(
                        name=name, replica=replica, runtime=runtime_id)
                    pod_labels = get_labels(name, runtime_id)
                    pod_selector = to_selector(pod_labels)
                else:
                    target = "{name}-{replica}".format(name=name,
                                                       replica=replica)
                    pod_labels = get_labels_v1alpha2(namespace, name)
                    pod_selector = to_selector(pod_labels)

                # Wait for the pods to be ready before we shutdown
                # TODO(jlewi): We get pods using a label selector, so there is
                # a risk that the pod we actually care about isn't present.
                logging.info(
                    "Waiting for pods to be running before shutting down.")
                wait_for_pods_to_be_in_phases(
                    api_client,
                    namespace,
                    pod_selector, ["Running"],
                    timeout=datetime.timedelta(minutes=4))
                logging.info("Pods are ready")
                logging.info("Issuing the terminate request")
                for num in range(num_targets):
                    full_target = target + "-{0}".format(num)
                    terminateReplica(masterHost, namespace, full_target)

            logging.info("Waiting for job to finish.")
            results = tf_job_client.wait_for_job(
                api_client,
                namespace,
                name,
                args.tfjob_version,
                status_callback=tf_job_client.log_status)

            if args.tfjob_version == "v1alpha1":
                if results.get("status", {}).get("state",
                                                 {}).lower() != "succeeded":
                    t.failure = "Trial {0} Job {1} in namespace {2} in state {3}".format(
                        trial, name, namespace,
                        results.get("status", {}).get("state", None))
                    logging.error(t.failure)
                    break
            else:
                # For v1alpha2, check that the last condition is Succeeded.
                last_condition = results.get("status",
                                             {}).get("conditions", [])[-1]
                if last_condition.get("type", "").lower() != "succeeded":
                    t.failure = "Trial {0} Job {1} in namespace {2} in status {3}".format(
                        trial, name, namespace, results.get("status", {}))
                    logging.error(t.failure)
                    break

            runtime_id = results.get("spec", {}).get("RuntimeId")
            logging.info("Trial %s Job %s in namespace %s runtime ID %s",
                         trial, name, namespace, runtime_id)

            uid = results.get("metadata", {}).get("uid")
            events = get_events(api_client, namespace, uid)
            for e in events:
                logging.info("K8s event: %s", e.message)

            # Print out the K8s events because it can be useful for debugging.
            for e in events:
                logging.info("Recieved K8s Event:\n%s", e)
            created_pods, created_services = parse_events(events)

            num_expected = 0
            if args.tfjob_version == "v1alpha1":
                for replica in results.get("spec", {}).get("replicaSpecs", []):
                    num_expected += replica.get("replicas", 0)
            else:
                for replicakey in results.get("spec",
                                              {}).get("tfReplicaSpecs", {}):
                    replica_spec = results.get("spec",
                                               {}).get("tfReplicaSpecs",
                                                       {}).get(replicakey, {})
                    if replica_spec:
                        num_expected += replica_spec.get("replicas", 1)

            creation_failures = []
            if len(created_pods) != num_expected:
                message = ("Expected {0} pods to be created but only "
                           "got {1} create events.").format(
                               num_expected, len(created_pods))
                creation_failures.append(message)

            if len(created_services) != num_expected:
                message = ("Expected {0} services to be created but only "
                           "got {1} create events.").format(
                               num_expected, len(created_services))
                creation_failures.append(message)

            if creation_failures:
                # TODO(jlewi): Starting with
                # https://github.com/kubeflow/tf-operator/pull/646 the number of events
                # no longer seems to match the expected; it looks like maybe events
                # are being combined? For now we just log a warning rather than an
                # error.
                logging.warning(creation_failures)
            if args.tfjob_version == "v1alpha1":
                pod_labels = get_labels(name, runtime_id)
                pod_selector = to_selector(pod_labels)
            else:
                pod_labels = get_labels_v1alpha2(name)
                pod_selector = to_selector(pod_labels)

            # We don't wait for pods to be deleted in v1alpha2 because CleanPodPolicy
            # means completed pods won't be deleted.
            # TODO(jlewi): We should add a test to deal with deleted pods.
            if args.tfjob_version == "v1alpha1":
                wait_for_pods_to_be_deleted(api_client, namespace,
                                            pod_selector)

            tf_job_client.delete_tf_job(api_client,
                                        namespace,
                                        name,
                                        version=args.tfjob_version)

            logging.info("Waiting for job %s in namespaces %s to be deleted.",
                         name, namespace)
            wait_for_delete(api_client,
                            namespace,
                            name,
                            args.tfjob_version,
                            status_callback=tf_job_client.log_status)

        # TODO(jlewi):
        #  Here are some validation checks to run:
        #  1. Check that all resources are garbage collected.
        # TODO(jlewi): Add an option to add chaos and randomly kill various resources?
        # TODO(jlewi): Are there other generic validation checks we should
        # run.
    except util.TimeoutError:
        t.failure = "Timeout waiting for {0} in namespace {1} to finish.".format(
            name, namespace)
        logging.exception(t.failure)
    except Exception as e:  # pylint: disable-msg=broad-except
        # TODO(jlewi): I'm observing flakes where the exception has message "status".
        # In an effort to nail down this exception we print out more
        # information about the exception.
        logging.exception("There was a problem running the job; Exception %s",
                          e)
        # We want to catch all exceptions because we want to mark the test as failed.
        t.failure = ("Exception occurred; type {0} message {1}".format(
            e.__class__, e.message))
    finally:
        t.time = time.time() - start
        if args.junit_path:
            test_util.create_junit_xml_file([t], args.junit_path, gcs_client)
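The shutdown path above builds a pod selector from get_labels / to_selector. A hedged sketch of what these helpers presumably produce; the label keys are assumptions, not the operator's confirmed labels.

    # Assumed behaviour only; the label keys are guesses.
    def get_labels_sketch(name, runtime_id):
        return {"tf_job_name": name, "runtime_id": runtime_id}

    def to_selector_sketch(labels):
        # Kubernetes label selectors are comma-separated key=value pairs.
        return ",".join("{0}={1}".format(k, v) for k, v in sorted(labels.items()))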
Example #7
    def test_pod_names(self):
        api_client = k8s_client.ApiClient()
        component = COMPONENT_NAME + "_" + self.tfjob_version

        ks_util.setup_ks_app(self.app_dir, self.env, self.namespace, component,
                             self.params)
        util.run(["ks", "apply", self.env, "-c", component], cwd=self.app_dir)
        logging.info("Created job %s in namespaces %s", self.name,
                     self.namespace)
        logging.info("Wait for conditions Running, Succeeded, or Failed")
        results = tf_job_client.wait_for_condition(
            api_client,
            self.namespace,
            self.name, ["Running", "Succeeded", "Failed"],
            version=self.tfjob_version,
            status_callback=tf_job_client.log_status)
        logging.info("Current TFJob:\n %s", json.dumps(results, indent=2))

        job_specs = extract_job_specs(
            results.get("spec", {}).get("tfReplicaSpecs", {}))
        expected_pod_names = []
        for replica_type, replica_num in job_specs.items():
            logging.info("job_type = %s, replica = %s", replica_type,
                         replica_num)
            for i in range(replica_num):
                expected_pod_names.append(
                    POD_NAME_FORMAT.format(name=self.name,
                                           replica=replica_type,
                                           index=i))
        expected_pod_names = set(expected_pod_names)
        actual_pod_names = tf_job_client.get_pod_names(api_client,
                                                       self.namespace,
                                                       self.name)

        # We cannot guarantee that pods selected by namespace and job name belong
        # only to this test run. Therefore we only do a partial check, i.e. make
        # sure the expected set of pod names is a subset of the selected pod names.
        if not (expected_pod_names & actual_pod_names) == expected_pod_names:
            msg = "Actual pod names don't match. Expected: {0} Actual: {1}".format(
                str(expected_pod_names), str(actual_pod_names))
            logging.error(msg)
            raise RuntimeError(msg)

        tf_job_client.terminate_replicas(api_client, self.namespace, self.name,
                                         "chief", 1)
        # Wait for the job to complete.
        logging.info("Waiting for job to finish.")
        results = tf_job_client.wait_for_job(
            api_client,
            self.namespace,
            self.name,
            self.tfjob_version,
            status_callback=tf_job_client.log_status)
        logging.info("Final TFJob:\n %s", json.dumps(results, indent=2))

        if not tf_job_client.job_succeeded(results):
            self.failure = "Job {0} in namespace {1} in status {2}".format(
                self.name, self.namespace, results.get("status", {}))
            logging.error(self.failure)

        # Delete the TFJob.
        tf_job_client.delete_tf_job(api_client,
                                    self.namespace,
                                    self.name,
                                    version=self.tfjob_version)
        logging.info("Waiting for job %s in namespaces %s to be deleted.",
                     self.name, self.namespace)
        tf_job_client.wait_for_delete(api_client,
                                      self.namespace,
                                      self.name,
                                      self.tfjob_version,
                                      status_callback=tf_job_client.log_status)
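POD_NAME_FORMAT is not shown in the excerpt; it presumably expands to something like the sketch below (an assumption about the operator's pod naming).

    # Assumed format; e.g. a job "pod-names-test" with two workers would expect
    # pods "pod-names-test-worker-0" and "pod-names-test-worker-1".
    POD_NAME_FORMAT = "{name}-{replica}-{index}"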
Example #8
def main(argv=None):
    parser = argparse.ArgumentParser(description='Kubeflow TFJob launcher')
    parser.add_argument(
        '--container-image',
        type=str,
        help=
        '''Container image to run using KubeFlow TFJob. The command line should be added after --.'''
    )
    parser.add_argument('--workers', type=int, default=0)
    parser.add_argument('--pss', type=int, default=0)
    parser.add_argument(
        '--cluster',
        type=str,
        help='GKE cluster set up for kubeflow. If set, zone must be provided. '
        + 'If not set, assuming this runs in a GKE container and current ' +
        'cluster is used.')
    parser.add_argument('--zone',
                        type=str,
                        help='zone of the kubeflow cluster.')
    parser.add_argument('--kfversion',
                        type=str,
                        default='v1alpha2',
                        help='The version of the deployed kubeflow. ' +
                        'If not set, the default version is v1alpha2')
    parser.add_argument('--tfjob-ns',
                        type=str,
                        default='default',
                        help='The namespace where the tfjob is submitted. ' +
                        'If not set, the default namespace is default.')
    parser.add_argument(
        '--tfjob-timeout-minutes',
        type=int,
        default=10,
        help='Time in minutes to wait for the TFJob to complete')
    parser.add_argument('--output-dir', type=str)
    parser.add_argument('--ui-metadata-type', type=str, default='tensorboard')
    import sys
    all_args = sys.argv[1:]
    separator_idx = all_args.index('--')
    launcher_args = all_args[:separator_idx]
    remaining_args = all_args[separator_idx + 1:]

    args = parser.parse_args(launcher_args)

    logging.getLogger().setLevel(logging.INFO)
    args_dict = vars(args)
    if args.cluster and args.zone:
        cluster = args_dict.pop('cluster')
        zone = args_dict.pop('zone')
    else:
        # Get cluster name and zone from metadata
        metadata_server = "http://metadata/computeMetadata/v1/instance/"
        metadata_flavor = {'Metadata-Flavor': 'Google'}
        cluster = requests.get(metadata_server + "attributes/cluster-name",
                               headers=metadata_flavor).text
        zone = requests.get(metadata_server + "zone",
                            headers=metadata_flavor).text.split('/')[-1]

    logging.info('Getting credentials for GKE cluster %s.' % cluster)
    subprocess.call([
        'gcloud', 'container', 'clusters', 'get-credentials', cluster,
        '--zone', zone
    ])

    workers = args_dict.pop('workers')
    pss = args_dict.pop('pss')
    kf_version = args_dict.pop('kfversion')
    tfjob_ns = args_dict.pop('tfjob_ns')
    tfjob_timeout_minutes = args_dict.pop('tfjob_timeout_minutes')
    trainer_image = args.container_image or os.environ['TRAINER_IMAGE_NAME']
    command = remaining_args
    logging.info('Generating training template.')
    template_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                 'train.template.yaml')
    content_yaml = _generate_train_yaml(template_file, tfjob_ns, workers, pss,
                                        trainer_image, command)

    logging.info('Start training.')
    # Set up handler for k8s clients
    config.load_incluster_config()
    api_client = k8s_client.ApiClient()
    create_response = tf_job_client.create_tf_job(api_client,
                                                  content_yaml,
                                                  version=kf_version)
    job_name = create_response['metadata']['name']

    if args.output_dir:
        # Create metadata.json file for visualization.
        metadata = {
            'outputs': [{
                'type': args.ui_metadata_type,
                'source': args.output_dir,
            }]
        }
        with open('/mlpipeline-ui-metadata.json', 'w') as f:
            json.dump(metadata, f)

    wait_response = tf_job_client.wait_for_job(
        api_client,
        tfjob_ns,
        job_name,
        kf_version,
        timeout=datetime.timedelta(minutes=tfjob_timeout_minutes))
    succ = True
    #TODO: update this failure checking after tf-operator has the condition checking function.
    if 'Worker' in wait_response['status']['tfReplicaStatuses']:
        if 'Failed' in wait_response['status']['tfReplicaStatuses']['Worker']:
            logging.error('Training failed since workers failed.')
            succ = False
    if 'PS' in wait_response['status']['tfReplicaStatuses']:
        if 'Failed' in wait_response['status']['tfReplicaStatuses']['PS']:
            logging.error('Training failed since PSs failed.')
            succ = False
    if 'MASTER' in wait_response['status']['tfReplicaStatuses']:
        if 'Failed' in wait_response['status']['tfReplicaStatuses']['MASTER']:
            logging.error('Training failed since MASTER failed.')
            succ = False

    #TODO: remove this after kubeflow fixes the wait_for_job issue
    # because the wait_for_job returns when the worker finishes but the master might not be complete yet.
    if 'MASTER' in wait_response['status'][
            'tfReplicaStatuses'] and 'active' in wait_response['status'][
                'tfReplicaStatuses']['MASTER']:
        master_active = True
        while master_active:
            # Wait for master to finish
            time.sleep(2)
            wait_response = tf_job_client.wait_for_job(
                api_client,
                tfjob_ns,
                job_name,
                kf_version,
                timeout=datetime.timedelta(minutes=tfjob_timeout_minutes))
            if 'active' not in wait_response['status']['tfReplicaStatuses'][
                    'MASTER']:
                master_active = False

    if succ:
        logging.info('Training success.')

    tf_job_client.delete_tf_job(api_client,
                                tfjob_ns,
                                job_name,
                                version=kf_version)
    with open('/output.txt', 'w') as f:
        f.write(args.output_dir)
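The launcher splits its command line on a literal "--": everything before it configures the launcher, everything after it is the command run inside the training container. An illustrative walk-through of that split; the flag values are made up.

    # Illustrative only; mirrors the sys.argv handling in main() above for a hypothetical call
    #   launcher.py --workers 2 --pss 1 -- python train.py --epochs 10
    all_args = ["--workers", "2", "--pss", "1", "--",
                "python", "train.py", "--epochs", "10"]
    separator_idx = all_args.index("--")
    launcher_args = all_args[:separator_idx]       # ["--workers", "2", "--pss", "1"]
    remaining_args = all_args[separator_idx + 1:]  # ["python", "train.py", "--epochs", "10"]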
Example #9
def run_test(args):  # pylint: disable=too-many-branches,too-many-statements
    """Run a test."""
    gcs_client = storage.Client(project=args.project)
    project = args.project
    cluster_name = args.cluster
    zone = args.zone
    # TODO(jlewi): When using GKE we should copy the .kube config and any other
    # files to the test directory. We should then set the environment variable
    # KUBECONFIG to point at that file. This should prevent us from having
    # to rerun util.configure_kubectl on each step. Instead we could run it once
    # as part of GKE cluster creation and store the config in the NFS directory.
    # This would make the handling of credentials
    # and KUBECONFIG more consistent between GKE and minikube and eventually
    # this could be extended to other K8s deployments.
    if cluster_name:
        util.configure_kubectl(project, zone, cluster_name)
    util.load_kube_config()

    api_client = k8s_client.ApiClient()

    salt = uuid.uuid4().hex[0:4]

    # Create a new environment for this run
    env = "test-env-{0}".format(salt)

    util.run(["ks", "env", "add", env], cwd=args.app_dir)

    name = None
    namespace = None
    for pair in args.params.split(","):
        k, v = pair.split("=", 1)
        if k == "name":
            name = v

        if k == "namespace":
            namespace = v
        util.run(["ks", "param", "set", "--env=" + env, args.component, k, v],
                 cwd=args.app_dir)

    if not name:
        raise ValueError("name must be provided as a parameter.")

    t = test_util.TestCase()
    t.class_name = "tfjob_test"
    t.name = os.path.basename(name)

    if not namespace:
        raise ValueError("namespace must be provided as a parameter.")

    start = time.time()

    try:
        # We repeat the test multiple times.
        # This ensures that if we delete the job we can create a new job with the
        # same name.

        # TODO(jlewi): We should make this an argument.
        num_trials = 2

        for trial in range(num_trials):
            logging.info("Trial %s", trial)
            util.run(["ks", "apply", env, "-c", args.component],
                     cwd=args.app_dir)

            logging.info("Created job %s in namespaces %s", name, namespace)
            results = tf_job_client.wait_for_job(
                api_client,
                namespace,
                name,
                status_callback=tf_job_client.log_status)

            if results.get("status", {}).get("state",
                                             {}).lower() != "succeeded":
                t.failure = "Trial {0} Job {1} in namespace {2} in state {3}".format(
                    trial, name, namespace,
                    results.get("status", {}).get("state", None))
                logging.error(t.failure)
                break

            runtime_id = results.get("spec", {}).get("RuntimeId")
            logging.info("Trial %s Job %s in namespace %s runtime ID %s",
                         trial, name, namespace, runtime_id)

            uid = results.get("metadata", {}).get("uid")
            events = get_events(api_client, namespace, uid)
            created_pods, created_services = parse_events(events)

            num_expected = 0
            for replica in results.get("spec", {}).get("replicaSpecs", []):
                num_expected += replica.get("replicas", 0)

            creation_failures = []
            if len(created_pods) != num_expected:
                message = ("Expected {0} pods to be created but only "
                           "got {1} create events.").format(
                               num_expected, len(created_pods))
                creation_failures.append(message)

            if len(created_services) != num_expected:
                message = ("Expected {0} services to be created but only "
                           "got {1} create events.").format(
                               num_expected, len(created_services))
                creation_failures.append(message)

            if creation_failures:
                t.failure = "Trial {0} Job {1} in namespace {2}: {3}".format(
                    trial, name, namespace, ", ".join(creation_failures))
                logging.error(t.failure)
                break
            pod_labels = get_labels(name, runtime_id)
            pod_selector = to_selector(pod_labels)

            wait_for_pods_to_be_deleted(api_client, namespace, pod_selector)

            tf_job_client.delete_tf_job(api_client, namespace, name)

            logging.info("Waiting for job %s in namespaces %s to be deleted.",
                         name, namespace)
            wait_for_delete(api_client,
                            namespace,
                            name,
                            status_callback=tf_job_client.log_status)

        # TODO(jlewi):
        #  Here are some validation checks to run:
        #  1. Check that all resources are garbage collected.
        # TODO(jlewi): Add an option to add chaos and randomly kill various resources?
        # TODO(jlewi): Are there other generic validation checks we should
        # run.
    except util.TimeoutError:
        t.failure = "Timeout waiting for {0} in namespace {1} to finish.".format(
            name, namespace)
        logging.error(t.failure)
    except Exception as e:  # pylint: disable-msg=broad-except
        # TODO(jlewi): I'm observing flakes where the exception has message "status".
        # In an effort to nail down this exception we print out more
        # information about the exception.
        logging.error("There was a problem running the job; Exception %s", e)
        logging.error(
            "There was a problem running the job; Exception "
            "message: %s", e.message)
        logging.error("Exception type: %s", e.__class__)
        logging.error("Exception args: %s", e.args)
        # We want to catch all exceptions because we want to mark the test as failed.
        t.failure = ("Exception occurred; type {0} message {1}".format(
            e.__class__, e.message))
    finally:
        t.time = time.time() - start
        if args.junit_path:
            test_util.create_junit_xml_file([t], args.junit_path, gcs_client)
Example #10
    def run_tfjob_with_replica_restart_policy(self, component,
                                              replica_restart_policy,
                                              exit_code):
        api_client = k8s_client.ApiClient()

        # Setup the ksonnet app
        ks_util.setup_ks_app(self.app_dir, self.env, self.namespace, component,
                             self.params)

        # Create the TF job
        util.run(["ks", "apply", self.env, "-c", component], cwd=self.app_dir)
        logging.info("Created job %s in namespaces %s", self.name,
                     self.namespace)

        # Wait for the job to either be in Running state or a terminal state
        logging.info("Wait for conditions Running, Succeeded, or Failed")
        results = tf_job_client.wait_for_condition(
            api_client,
            self.namespace,
            self.name, ["Running", "Succeeded", "Failed"],
            version=self.tfjob_version,
            status_callback=tf_job_client.log_status)
        logging.info("Current TFJob:\n %s", json.dumps(results, indent=2))

        if replica_restart_policy == "Always" and exit_code == 0:
            res = tf_job_client.terminate_and_verify_start_time(
                api_client, self.namespace, self.name, "ps", 0, exit_code,
                True)

        elif replica_restart_policy == "Always" and exit_code == 1:
            res = tf_job_client.terminate_and_verify_start_time(
                api_client, self.namespace, self.name, "ps", 0, exit_code,
                True)

        elif replica_restart_policy == "OnFailure" and exit_code == 1:
            res = tf_job_client.terminate_and_verify_start_time(
                api_client, self.namespace, self.name, "ps", 0, exit_code,
                True)

        elif replica_restart_policy == "OnFailure" and exit_code == 0:
            res = tf_job_client.terminate_and_verify_start_time(
                api_client, self.namespace, self.name, "ps", 0, exit_code,
                False)

        elif replica_restart_policy == "Never" and exit_code == 1:
            res = tf_job_client.terminate_and_verify_start_time(
                api_client, self.namespace, self.name, "ps", 0, exit_code,
                False)

        elif replica_restart_policy == "Never" and exit_code == 0:
            res = tf_job_client.terminate_and_verify_start_time(
                api_client, self.namespace, self.name, "ps", 0, exit_code,
                False)

        elif replica_restart_policy == "ExitCode" and exit_code == 1:
            res = tf_job_client.terminate_and_verify_start_time(
                api_client, self.namespace, self.name, "ps", 0, exit_code,
                False)

        else:
            res = tf_job_client.terminate_and_verify_start_time(
                api_client, self.namespace, self.name, "ps", 0, exit_code,
                True)

        if res is False:
            self.failure = "Job {0} in namespace {1} with restart policy {2} failed test \
        with exit_code {3}".format(self.name, self.namespace,
                                   replica_restart_policy, exit_code)
            logging.error(self.failure)
            return

        # Delete the TFJob.
        tf_job_client.delete_tf_job(api_client,
                                    self.namespace,
                                    self.name,
                                    version=self.tfjob_version)
        logging.info("Waiting for job %s in namespaces %s to be deleted.",
                     self.name, self.namespace)
        tf_job_client.wait_for_delete(api_client,
                                      self.namespace,
                                      self.name,
                                      self.tfjob_version,
                                      status_callback=tf_job_client.log_status)
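For readability, the expectation each branch above passes to terminate_and_verify_start_time can be summarized as follows (True means a restart, i.e. a changed start time, is expected):

    # Summary of the branches above (restart policy / exit code -> restart expected):
    #   Always    / 0 -> True       Always    / 1 -> True
    #   OnFailure / 1 -> True       OnFailure / 0 -> False
    #   Never     / 0 -> False      Never     / 1 -> False
    #   ExitCode  / 1 -> False      any other combination (e.g. ExitCode / 0) -> True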
Example #11
  def run_tfjob_with_cleanpod_policy(self, component, clean_pod_policy):
    api_client = k8s_client.ApiClient()

    # Setup the ksonnet app
    ks_util.setup_ks_app(self.app_dir, self.env, self.namespace, component,
                         self.params)

    # Create the TF job
    util.run(["ks", "apply", self.env, "-c", component], cwd=self.app_dir)
    logging.info("Created job %s in namespaces %s", self.name, self.namespace)

    # Wait for the job to either be in Running state or a terminal state
    logging.info("Wait for conditions Running, Succeeded, or Failed")
    results = tf_job_client.wait_for_condition(
      api_client,
      self.namespace,
      self.name, ["Running", "Succeeded", "Failed"],
      version=self.tfjob_version,
      status_callback=tf_job_client.log_status)
    logging.info("Current TFJob:\n %s", json.dumps(results, indent=2))

    # Wait for the job to complete.
    logging.info("Waiting for job to finish.")
    results = tf_job_client.wait_for_job(
      api_client,
      self.namespace,
      self.name,
      self.tfjob_version,
      status_callback=tf_job_client.log_status)
    logging.info("Final TFJob:\n %s", json.dumps(results, indent=2))

    if not tf_job_client.job_succeeded(results):
      self.failure = "Job {0} in namespace {1} in status {2}".format(
        self.name, self.namespace, results.get("status", {}))
      logging.error(self.failure)
      return

    # All pods are deleted.
    if clean_pod_policy == "All":
      pod_labels = tf_job_client.get_labels(self.name)
      pod_selector = tf_job_client.to_selector(pod_labels)
      k8s_util.wait_for_pods_to_be_deleted(api_client, self.namespace,
                                           pod_selector)
    # Only running pods (PS) are deleted, completed pods are not.
    elif clean_pod_policy == "Running":
      tf_job_client.wait_for_replica_type_in_phases(
        api_client, self.namespace, self.name, "Chief", ["Succeeded"])
      tf_job_client.wait_for_replica_type_in_phases(
        api_client, self.namespace, self.name, "Worker", ["Succeeded"])
      pod_labels = tf_job_client.get_labels(self.name, "PS")
      pod_selector = tf_job_client.to_selector(pod_labels)
      k8s_util.wait_for_pods_to_be_deleted(api_client, self.namespace,
                                           pod_selector)
    # No pods are deleted.
    elif clean_pod_policy == "None":
      tf_job_client.wait_for_replica_type_in_phases(
        api_client, self.namespace, self.name, "Chief", ["Succeeded"])
      tf_job_client.wait_for_replica_type_in_phases(
        api_client, self.namespace, self.name, "Worker", ["Succeeded"])
      tf_job_client.wait_for_replica_type_in_phases(
        api_client, self.namespace, self.name, "PS", ["Running"])

    # Delete the TFJob.
    tf_job_client.delete_tf_job(
      api_client, self.namespace, self.name, version=self.tfjob_version)
    logging.info("Waiting for job %s in namespaces %s to be deleted.",
                 self.name, self.namespace)
    tf_job_client.wait_for_delete(
      api_client,
      self.namespace,
      self.name,
      self.tfjob_version,
      status_callback=tf_job_client.log_status)
Example #12
def main(argv=None):
    parser = argparse.ArgumentParser(description='ML Trainer')
    parser.add_argument('--working-dir',
                        help='Training job working directory.',
                        required=True)
    parser.add_argument('--train-files-dir',
                        help='Path to training data',
                        required=True)
    parser.add_argument('--train-files-prefix',
                        help='The prefix of the training input files.',
                        required=True)

    parser.add_argument(
        '--tf-transform-dir',
        help='Tf-transform directory with model from preprocessing step',
        required=True)

    parser.add_argument('--output-dir',
                        help="""\
      Directory under which the serving model (under /serving_model_dir)\
      and the tf-model-analysis model (under /eval_model_dir) will be written\
      """,
                        required=True)

    parser.add_argument('--eval-files-dir',
                        help='Path to evaluation data',
                        required=True)
    parser.add_argument('--eval-files-prefix',
                        help='The prefix of the eval input files.',
                        required=True)

    # Training arguments
    parser.add_argument(
        '--job-dir',
        help='GCS location to write checkpoints and export models',
        required=True)

    # Argument to turn on all logging
    parser.add_argument(
        '--verbosity',
        choices=['DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN'],
        default='INFO',
    )
    # Experiment arguments
    parser.add_argument('--train-steps',
                        help='Count of steps to run the training job for',
                        required=True,
                        type=int)
    parser.add_argument(
        '--eval-steps',
        help='Number of steps to run evaluation for at each checkpoint',
        default=100,
        type=int)
    parser.add_argument('--workers', type=int, default=0)
    parser.add_argument('--pss', type=int, default=0)
    parser.add_argument(
        '--cluster',
        type=str,
        help='GKE cluster set up for kubeflow. If set, zone must be provided. '
        + 'If not set, assuming this runs in a GKE container and current ' +
        'cluster is used.')
    parser.add_argument('--zone',
                        type=str,
                        help='zone of the kubeflow cluster.')
    parser.add_argument('--kfversion',
                        type=str,
                        default='v1alpha2',
                        help='The version of the deployed kubeflow. ' +
                        'If not set, the default version is v1alpha2')
    parser.add_argument('--tfjob-ns',
                        type=str,
                        default='kubeflow',
                        help='The namespace where the tfjob is submitted. ' +
                        'If not set, the namespace is kubeflow.')
    parser.add_argument(
        '--tfjob-timeout-minutes',
        type=int,
        default=10,
        help='Time in minutes to wait for the TFJob to complete')
    args = parser.parse_args()

    # KUBEFLOW_NAMESPACE = 'default'

    logging.getLogger().setLevel(logging.INFO)
    args_dict = vars(args)
    if args.cluster and args.zone:
        cluster = args_dict.pop('cluster')
        zone = args_dict.pop('zone')
    else:
        # Get cluster name and zone from metadata
        metadata_server = "http://metadata/computeMetadata/v1/instance/"
        metadata_flavor = {'Metadata-Flavor': 'Google'}
        cluster = requests.get(metadata_server + "attributes/cluster-name",
                               headers=metadata_flavor).text
        zone = requests.get(metadata_server + "zone",
                            headers=metadata_flavor).text.split('/')[-1]

    logging.info('Getting credentials for GKE cluster %s.' % cluster)
    subprocess.call([
        'gcloud', 'container', 'clusters', 'get-credentials', cluster,
        '--zone', zone
    ])

    # Create metadata.json file for visualization.
    tb_dir = args_dict.pop(
        'working_dir')  # don't pass this arg to the training module
    metadata = {
        'outputs': [{
            'type': 'tensorboard',
            'source': tb_dir,
        }]
    }
    with file_io.FileIO('/mlpipeline-ui-metadata.json', 'w') as f:
        json.dump(metadata, f)

    workers = args_dict.pop('workers')
    pss = args_dict.pop('pss')
    kf_version = args_dict.pop('kfversion')
    tfjob_ns = args_dict.pop('tfjob_ns')
    tfjob_timeout_minutes = args_dict.pop('tfjob_timeout_minutes')
    args_list = [
        '--%s=%s' % (k.replace('_', '-'), v)
        for k, v in six.iteritems(args_dict) if v is not None
    ]
    logging.info('Generating training template.')
    template_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                 'train.template.yaml')
    content_yaml = _generate_train_yaml(template_file, tfjob_ns, workers, pss,
                                        args_list)

    logging.info('Start training.')
    # Set up handler for k8s clients
    config.load_incluster_config()
    api_client = k8s_client.ApiClient()
    create_response = tf_job_client.create_tf_job(api_client,
                                                  content_yaml,
                                                  version=kf_version)
    job_name = create_response['metadata']['name']

    wait_response = tf_job_client.wait_for_job(
        api_client,
        tfjob_ns,
        job_name,
        kf_version,
        timeout=datetime.timedelta(minutes=tfjob_timeout_minutes))
    succ = True
    #TODO: update this failure checking after tf-operator has the condition checking function.
    if 'Worker' in wait_response['status']['tfReplicaStatuses']:
        if 'Failed' in wait_response['status']['tfReplicaStatuses']['Worker']:
            logging.error('Training failed since workers failed.')
            succ = False
    if 'PS' in wait_response['status']['tfReplicaStatuses']:
        if 'Failed' in wait_response['status']['tfReplicaStatuses']['PS']:
            logging.error('Training failed since PSs failed.')
            succ = False
    if 'MASTER' in wait_response['status']['tfReplicaStatuses']:
        if 'Failed' in wait_response['status']['tfReplicaStatuses']['MASTER']:
            logging.error('Training failed since MASTER failed.')
            succ = False

    #TODO: remove this after kubeflow fixes the wait_for_job issue
    # because the wait_for_job returns when the worker finishes but the master might not be complete yet.
    if 'MASTER' in wait_response['status'][
            'tfReplicaStatuses'] and 'active' in wait_response['status'][
                'tfReplicaStatuses']['MASTER']:
        master_active = True
        while master_active:
            # Wait for master to finish
            time.sleep(2)
            wait_response = tf_job_client.wait_for_job(
                api_client,
                tfjob_ns,
                job_name,
                kf_version,
                timeout=datetime.timedelta(minutes=tfjob_timeout_minutes))
            if 'active' not in wait_response['status']['tfReplicaStatuses'][
                    'MASTER']:
                master_active = False

    if succ:
        logging.info('Training success.')

    tf_job_client.delete_tf_job(api_client,
                                tfjob_ns,
                                job_name,
                                version=kf_version)
    with open('/output.txt', 'w') as f:
        f.write(args.job_dir)
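The remaining parsed arguments are converted into command-line flags for the training module by the list comprehension above; an illustrative example with made-up values (the original uses six.iteritems for Python 2/3 compatibility, plain .items() behaves the same here):

    # Illustrative only; mirrors the args_list construction in main() above.
    sample_args_dict = {"train_steps": 1000, "eval_steps": None, "job_dir": "gs://bucket/run1"}
    args_list = ["--%s=%s" % (k.replace("_", "-"), v)
                 for k, v in sample_args_dict.items() if v is not None]
    # -> ["--train-steps=1000", "--job-dir=gs://bucket/run1"]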
Example #13
def run_test(args):  # pylint: disable=too-many-branches,too-many-statements
    """Run a test."""
    gcs_client = storage.Client(project=args.project)
    project = args.project
    cluster_name = args.cluster
    zone = args.zone
    util.configure_kubectl(project, zone, cluster_name)
    util.load_kube_config()

    api_client = k8s_client.ApiClient()

    salt = uuid.uuid4().hex[0:4]

    # Create a new environment for this run
    env = "test-env-{0}".format(salt)

    util.run(["ks", "env", "add", env], cwd=args.app_dir)

    name = None
    namespace = None
    for pair in args.params.split(","):
        k, v = pair.split("=", 1)
        if k == "name":
            name = v

        if k == "namespace":
            namespace = v
        util.run(["ks", "param", "set", "--env=" + env, args.component, k, v],
                 cwd=args.app_dir)

    if not name:
        raise ValueError("name must be provided as a parameter.")

    t = test_util.TestCase()
    t.class_name = "tfjob_test"
    t.name = os.path.basename(name)

    if not namespace:
        raise ValueError("namespace must be provided as a parameter.")

    start = time.time()

    try:
        # We repeat the test multiple times.
        # This ensures that if we delete the job we can create a new job with the
        # same name.

        # TODO(jlewi): We should make this an argument.
        num_trials = 2

        for trial in range(num_trials):
            logging.info("Trial %s", trial)
            util.run(["ks", "apply", env, "-c", args.component],
                     cwd=args.app_dir)

            logging.info("Created job %s in namespaces %s", name, namespace)
            results = tf_job_client.wait_for_job(
                api_client,
                namespace,
                name,
                status_callback=tf_job_client.log_status)

            if results.get("status", {}).get("state",
                                             {}).lower() != "succeeded":
                t.failure = "Trial {0} Job {1} in namespace {2} in state {3}".format(
                    trial, name, namespace,
                    results.get("status", {}).get("state", None))
                logging.error(t.failure)
                break

            runtime_id = results.get("spec", {}).get("RuntimeId")
            logging.info("Trial %s Job %s in namespace %s runtime ID %s",
                         trial, name, namespace, runtime_id)

            # TODO(jlewi): We should check that pods were created for each replica
            pod_labels = get_labels(name, runtime_id)
            pod_selector = to_selector(pod_labels)
            pods = list_pods(api_client, namespace, pod_selector)

            logging.info("Trial %s selector: %s matched %s pods", trial,
                         pod_selector, len(pods.items))

            if not pods.items:
                t.failure = (
                    "Trial {0} Job {1} in namespace {2} no pods found for "
                    " selector {3}").format(trial, name, namespace,
                                            pod_selector)
                logging.error(t.failure)
                break

            tf_job_client.delete_tf_job(api_client, namespace, name)

            wait_for_delete(api_client,
                            namespace,
                            name,
                            status_callback=tf_job_client.log_status)

            # Verify the pods have been deleted. tf_job_client uses foreground
            # deletion so there shouldn't be any resources for the job left
            # once the job is gone.
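            # (With Kubernetes foreground cascading deletion, the TFJob object is
            # only removed from the API server after its dependent pods are gone.)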
            pods = list_pods(api_client, namespace, pod_selector)

            logging.info("Trial %s selector: %s matched %s pods", trial,
                         pod_selector, len(pods.items))

            if pods.items:
                t.failure = (
                    "Trial {0} Job {1} in namespace {2} pods found for "
                    " selector {3}; pods\n{4}").format(trial, name, namespace,
                                                       pod_selector, pods)
                logging.error(t.failure)
                break

            logging.info("Trial %s all pods deleted.", trial)

        # TODO(jlewi):
        #  Here are some validation checks to run:
        #  1. Check that all resources are garbage collected.
        # TODO(jlewi): Add an option to add chaos and randomly kill various resources?
        # TODO(jlewi): Are there other generic validation checks we should
        # run?
    except util.TimeoutError:
        t.failure = "Timeout waiting for {0} in namespace {1} to finish.".format(
            name, namespace)
        logging.error(t.failure)
    except Exception as e:  # pylint: disable-msg=broad-except
        # TODO(jlewi): I'm observing flakes where the exception has the message
        # "status". In an effort to nail down this exception, we print out more
        # information about it.
        logging.error("There was a problem running the job; Exception %s", e)
        # Use str(e) instead of e.message; exceptions have no .message attribute
        # on Python 3.
        logging.error(
            "There was a problem running the job; Exception "
            "message: %s", str(e))
        logging.error("Exception type: %s", e.__class__)
        logging.error("Exception args: %s", e.args)
        # We catch all exceptions because any exception should mark the test as failed.
        t.failure = ("Exception occurred; type {0} message {1}".format(
            e.__class__, str(e)))
    finally:
        t.time = time.time() - start
        if args.junit_path:
            test_util.create_junit_xml_file([t], args.junit_path, gcs_client)
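
The test above relies on a few helpers (get_labels, to_selector, list_pods) that are not shown here. The following is a minimal sketch of what they might look like, assuming the standard kubernetes Python client; the exact label keys attached by the operator are an assumption, not taken from this example.

from kubernetes import client as k8s_client


def get_labels(name, runtime_id):
    # Hypothetical label keys identifying the pods belonging to one TFJob run.
    return {"tf_job_name": name, "runtime_id": runtime_id}


def to_selector(labels):
    # Convert a label dict into a Kubernetes label-selector string,
    # e.g. {"a": "1", "b": "2"} -> "a=1,b=2".
    return ",".join("{0}={1}".format(k, v) for k, v in labels.items())


def list_pods(api_client, namespace, label_selector):
    # List pods in the namespace matching the selector.
    core_api = k8s_client.CoreV1Api(api_client)
    return core_api.list_namespaced_pod(namespace, label_selector=label_selector)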
Exemplo n.º 14
0
def main(argv=None):
  parser = argparse.ArgumentParser(description='ML Trainer')
  parser.add_argument(
      '--working-dir',
      help='Training job working directory.',
      required=True)
  parser.add_argument(
      '--train-files-dir',
      help='Path to training data',
      required=True)
  parser.add_argument(
      '--train-files-prefix',
      help='The prefix of the training input files.',
      required=True)

  parser.add_argument(
      '--tf-transform-dir',
      help='Tf-transform directory with model from preprocessing step',
      required=True)

  parser.add_argument(
      '--output-dir',
      help="""\
      Directory under which which the serving model (under /serving_model_dir)\
      and the tf-mode-analysis model (under /eval_model_dir) will be written\
      """,
      required=True)

  parser.add_argument(
      '--eval-files-dir',
      help='Path to evaluation data',
      required=True
  )
  parser.add_argument(
      '--eval-files-prefix',
      help='The prefix of the eval input files.',
      required=True)

  # Training arguments
  parser.add_argument(
      '--job-dir',
      help='GCS location to write checkpoints and export models',
      required=True)

  # Argument to turn on all logging
  parser.add_argument(
      '--verbosity',
      choices=['DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN'],
      default='INFO',
  )
  # Experiment arguments
  parser.add_argument(
      '--train-steps',
      help='Count of steps to run the training job for',
      required=True,
      type=int)
  parser.add_argument(
      '--eval-steps',
      help='Number of steps to run evaluation for at each checkpoint',
      default=100,
      type=int)
  parser.add_argument('--workers', type=int, default=0)
  parser.add_argument('--pss', type=int, default=0)
  parser.add_argument('--cluster', type=str,
                      help='GKE cluster set up for kubeflow. If set, zone must be provided. ' +
                           'If not set, this is assumed to run in a GKE container and the ' +
                           'current cluster is used.')
  parser.add_argument('--zone', type=str, help='zone of the kubeflow cluster.')
  parser.add_argument('--kfversion', type=str,
                      default='v1beta1',
                      help='The version of the deployed kubeflow. ' +
                           'If not set, the default version is v1beta1')
  parser.add_argument('--tfjob-ns', type=str,
                      default='kubeflow',
                      help='The namespace where the TFJob is submitted. ' +
                           'If not set, the namespace is kubeflow.')
  parser.add_argument('--tfjob-timeout-minutes', type=int,
                      default=20,
                      help='Time in minutes to wait for the TFJob to complete')
  args = parser.parse_args()

  logging.getLogger().setLevel(logging.INFO)
  args_dict = vars(args)
  if args.cluster and args.zone:
    cluster = args_dict.pop('cluster')
    zone = args_dict.pop('zone')
  else:
    # Get cluster name and zone from metadata
    metadata_server = "http://metadata/computeMetadata/v1/instance/"
    metadata_flavor = {'Metadata-Flavor' : 'Google'}
    cluster = requests.get(metadata_server + "attributes/cluster-name",
                           headers = metadata_flavor).text
    zone = requests.get(metadata_server + "zone",
                        headers = metadata_flavor).text.split('/')[-1]
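    # The metadata server returns the zone as projects/<project-number>/zones/<zone>,
    # hence the trailing split.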

  # logging.info('Getting credentials for GKE cluster %s.' % cluster)
  # subprocess.call(['gcloud', 'container', 'clusters', 'get-credentials', cluster,
                   # '--zone', zone])

  # Create metadata.json file for visualization.
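  # Kubeflow Pipelines reads /mlpipeline-ui-metadata.json to render output viewers;
  # here it declares a TensorBoard viewer pointed at the working directory.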
  tb_dir = args_dict.pop('working_dir') # don't pass this arg to the training module
  metadata = {
    'outputs' : [{
      'type': 'tensorboard',
      'source': tb_dir,
    }]
  }
  with file_io.FileIO('/mlpipeline-ui-metadata.json', 'w') as f:
    json.dump(metadata, f)

  workers = args_dict.pop('workers')
  pss = args_dict.pop('pss')
  kf_version = args_dict.pop('kfversion')
  tfjob_ns = args_dict.pop('tfjob_ns')
  tfjob_timeout_minutes = args_dict.pop('tfjob_timeout_minutes')
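  # Re-serialize the remaining argparse values as --flag=value strings so they can
  # be passed through to the training container unchanged.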
  args_list = ['--%s=%s' % (k.replace('_', '-'),v)
               for k,v in six.iteritems(args_dict) if v is not None]
  logging.info('Generating training template.')
  template_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train.template.yaml')
  content_yaml = _generate_train_yaml(template_file, tfjob_ns, workers, pss, args_list)

  logging.info('Start training.')
  # Set up handler for k8s clients
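  # load_incluster_config() uses the service-account token mounted into the pod,
  # so this launcher is expected to run inside the cluster.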
  config.load_incluster_config()
  api_client = k8s_client.ApiClient()
  create_response = tf_job_client.create_tf_job(api_client, content_yaml, version=kf_version)
  job_name = create_response['metadata']['name']

  wait_response = tf_job_client.wait_for_job(
      api_client, tfjob_ns, job_name, kf_version,
      timeout=datetime.timedelta(minutes=tfjob_timeout_minutes))
  succ = True

  # TODO: update this failure checking after tf-operator has the condition checking function.
  if 'Worker' in wait_response['status']['replicaStatuses']:
    if 'Failed' in wait_response['status']['replicaStatuses']['Worker']:
      logging.error('Training failed since workers failed.')
      succ = False
  if 'PS' in wait_response['status']['replicaStatuses']:
    if 'Failed' in wait_response['status']['replicaStatuses']['PS']:
      logging.error('Training failed since PSs failed.')
      succ = False
  if 'Master' in wait_response['status']['replicaStatuses']:
    if 'Failed' in wait_response['status']['replicaStatuses']['Master']:
      logging.error('Training failed since Master failed.')
      succ = False

  # #TODO: remove this after kubeflow fixes the wait_for_job issue
  # # because the wait_for_job returns when the worker finishes but the master might not be complete yet.
  # if 'Master' in wait_response['status']['replicaStatuses'] and 'active' in wait_response['status']['replicaStatuses']['Master']:
  #   master_active = True
  #   while master_active:
  #     # Wait for master to finish
  #     time.sleep(2)
  #     wait_response = tf_job_client.wait_for_job(api_client, tfjob_ns, job_name, kf_version,
  #                                            timeout=datetime.timedelta(minutes=tfjob_timeout_minutes))
  #     if 'active' not in wait_response['status']['tfReplicaStatuses']['Master']:
  #       master_active = False

  if succ:
    logging.info('Training success.')

  tf_job_client.delete_tf_job(api_client, tfjob_ns, job_name, version=kf_version)
  with open('/output.txt', 'w') as f:
    f.write(args.job_dir)
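
The launcher above depends on a _generate_train_yaml helper that is not shown. The following is only a sketch of one way such a helper could fill in a TFJob template; the train.template.yaml layout and field names (tfReplicaSpecs keyed by MASTER/WORKER/PS) are assumptions, not taken from this example.

import yaml


def _generate_train_yaml(template_file, tfjob_ns, workers, pss, args_list):
    # Load the TFJob template and fill in namespace, replica counts, and trainer args.
    with open(template_file, 'r') as f:
        content = yaml.safe_load(f)

    content['metadata']['namespace'] = tfjob_ns

    # Assumed template layout: spec.tfReplicaSpecs with MASTER/WORKER/PS entries.
    replica_specs = content['spec']['tfReplicaSpecs']
    if workers:
        replica_specs['WORKER']['replicas'] = workers
    else:
        replica_specs.pop('WORKER', None)
    if pss:
        replica_specs['PS']['replicas'] = pss
    else:
        replica_specs.pop('PS', None)

    # Forward the remaining CLI flags to the training container on each replica.
    for spec in replica_specs.values():
        container = spec['template']['spec']['containers'][0]
        container['args'] = list(args_list)
    return content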