Example #1
    def testSplitGcsUri(self):
        bucket, path = util.split_gcs_uri("gs://some-bucket/some/path")
        self.assertEquals("some-bucket", bucket)
        self.assertEquals("some/path", path)

        bucket, path = util.split_gcs_uri("gs://some-bucket")
        self.assertEquals("some-bucket", bucket)
        self.assertEquals("", path)
Example #2
def build_new_release(args):  # pylint: disable=too-many-locals
    """Find the latest release and build the artifacts if they are newer then
  the current release.
  """
    if not args.src_dir:
        raise ValueError("src_dir must be provided when building last green.")

    gcs_client = storage.Client()
    sha = get_latest_green_presubmit(gcs_client)

    bucket_name, _ = util.split_gcs_uri(args.releases_path)
    bucket = gcs_client.get_bucket(bucket_name)

    logging.info("Latest passing postsubmit is %s", sha)

    last_release_sha = get_last_release(bucket)
    logging.info("Most recent release was for %s", last_release_sha)

    sha = build_and_push_image.GetGitHash(args.src_dir)

    if sha == last_release_sha:
        logging.info("Already cut release for %s", sha)
        return

    build(args)
Example #3
def build_lastgreen(args):  # pylint: disable=too-many-locals
    """Find the latest green postsubmit and build the artifacts.
  """
    gcs_client = storage.Client()
    sha = get_latest_green_presubmit(gcs_client)

    bucket_name, _ = util.split_gcs_uri(args.releases_path)
    bucket = gcs_client.get_bucket(bucket_name)

    logging.info("Latest passing postsubmit is %s", sha)

    last_release_sha = get_last_release(bucket)
    logging.info("Most recent release was for %s", last_release_sha)

    if sha == last_release_sha:
        logging.info("Already cut release for %s", sha)
        return

    go_dir = tempfile.mkdtemp(prefix="tmpTfJobSrc")
    logging.info("Temporary go_dir: %s", go_dir)

    src_dir = os.path.join(go_dir, "src", "github.com", REPO_ORG, REPO_NAME)

    _, sha = util.clone_repo(src_dir, util.MASTER_REPO_OWNER,
                             util.MASTER_REPO_NAME, sha)
    build_and_push(go_dir, src_dir, args)
Example #4
def create_junit_xml_file(test_cases, output_path, gcs_client=None):
    """Create a JUnit XML file.

  The junit schema is specified here:
  https://www.ibm.com/support/knowledgecenter/en/SSQ2R2_9.5.0/com.ibm.rsar.analysis.codereview.cobol.doc/topics/cac_useresults_junit.html

  Args:
    test_cases: TestSuite or List of test case objects.
    output_path: Path to write the XML
    gcs_client: GCS client to use if output is GCS.
  """
    total_time = 0
    failures = 0
    for c in test_cases:
        if c.time:
            total_time += c.time

        if c.failure:
            failures += 1
    attrib = {
        "failures": "{0}".format(failures),
        "tests": "{0}".format(len(test_cases)),
        "time": "{0}".format(total_time)
    }
    root = ElementTree.Element("testsuite", attrib)

    for c in test_cases:
        attrib = {
            "classname": c.class_name,
            "name": c.name,
        }
        if c.time:
            attrib["time"] = "{0}".format(c.time)

        # If the time isn't set and no message is set we interpret that as
        # the test not being run.
        if not c.time and not c.failure:
            attrib["failure"] = "Test was not run."

        if c.failure:
            attrib["failure"] = c.failure
        e = ElementTree.Element("testcase", attrib)

        root.append(e)

    t = ElementTree.ElementTree(root)
    logging.info("Creating %s", output_path)
    if output_path.startswith("gs://"):
        b = six.StringIO()
        t.write(b)

        bucket_name, path = util.split_gcs_uri(output_path)
        bucket = gcs_client.get_bucket(bucket_name)
        blob = bucket.blob(path)
        blob.upload_from_string(b.getvalue())
    else:
        t.write(output_path)
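
A hedged usage sketch for the function above: test_util.TestCase (seen in Example #15) is assumed to expose class_name, name, time, and failure attributes; the attribute values and output paths here are purely illustrative.

# Build a single passing test case and write it out.
case = test_util.TestCase()
case.class_name = "GKE"
case.name = "helm-tfjob-install"
case.time = 12.3
case.failure = None

# Local path: no GCS client is needed.
create_junit_xml_file([case], "/tmp/junit_helm-tfjob-install.xml")

# GCS path: a client must be supplied so the XML can be uploaded.
create_junit_xml_file([case], "gs://some-bucket/artifacts/junit_helm.xml",
                      gcs_client=storage.Client())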
Example #5
def upload_file_to_gcs(source, target):
  gcs_client = storage.Client()
  bucket_name, path = util.split_gcs_uri(target)

  bucket = gcs_client.get_bucket(bucket_name)

  logging.info("Uploading file %s to %s.", source, target)
  blob = bucket.blob(path)
  blob.upload_from_filename(source)
Example #6
def upload_to_gcs(contents, target):
  gcs_client = storage.Client()

  bucket_name, path = util.split_gcs_uri(target)

  bucket = gcs_client.get_bucket(bucket_name)
  logging.info("Writing %s", target)
  blob = bucket.blob(path)
  blob.upload_from_string(contents)
Example #7
def upload_outputs(gcs_client, output_dir, build_log):
    bucket_name, path = util.split_gcs_uri(output_dir)

    bucket = gcs_client.get_bucket(bucket_name)

    if not os.path.exists(build_log):
        logging.error("File %s doesn't exist.", build_log)
    else:
        logging.info("Uploading file %s.", build_log)
        blob = bucket.blob(os.path.join(path, "build-log.txt"))
        blob.upload_from_filename(build_log)
Example #8
def create_junit_xml_file(test_cases, output_path, gcs_client=None):
    """Create a JUnit XML file.

  Args:
    test_cases: List of test case objects.
    output_path: Path to write the XML
    gcs_client: GCS client to use if output is GCS.
  """
    total_time = 0
    failures = 0
    for c in test_cases:
        total_time += c.time

        if c.failure:
            failures += 1
    attrib = {
        "failures": "{0}".format(failures),
        "tests": "{0}".format(len(test_cases)),
        "time": "{0}".format(total_time)
    }
    root = ElementTree.Element("testsuite", attrib)

    for c in test_cases:
        attrib = {
            "classname": c.class_name,
            "name": c.name,
            "time": "{0}".format(c.time),
        }
        if c.failure:
            attrib["failure"] = c.failure
        e = ElementTree.Element("testcase", attrib)

        root.append(e)

    t = ElementTree.ElementTree(root)
    logging.info("Creationg %s", output_path)
    if output_path.startswith("gs://"):
        b = six.StringIO()
        t.write(b)

        bucket_name, path = util.split_gcs_uri(output_path)
        bucket = gcs_client.get_bucket(bucket_name)
        blob = bucket.blob(path)
        blob.upload_from_string(b.getvalue())
    else:
        t.write(output_path)
Example #9
def write_build_info(build_info, paths, project=None):
    """Write the build info files.
  """
    gcs_client = None

    contents = yaml.dump(build_info)

    for p in paths:
        logging.info("Writing build information to %s", p)
        if p.startswith("gs://"):
            if not gcs_client:
                gcs_client = storage.Client(project=project)
            bucket_name, path = util.split_gcs_uri(p)
            bucket = gcs_client.get_bucket(bucket_name)
            blob = bucket.blob(path)
            blob.upload_from_string(contents)

        else:
            with open(p, mode='w') as hf:
                hf.write(contents)
Example #10
def check_no_errors(gcs_client, artifacts_dir, junit_files):
    """Check that all the XML files exist and there were no errors.

  Args:
    gcs_client: The GCS client.
    artifacts_dir: The directory where artifacts should be stored.
    junit_files: List of the names of the junit files.

  Returns:
    True if there were no errors and false otherwise.
  """
    bucket_name, prefix = util.split_gcs_uri(artifacts_dir)
    bucket = gcs_client.get_bucket(bucket_name)
    no_errors = True

    # Get a list of actual junit files.
    actual_junit = _get_actual_junit_files(bucket, prefix)

    for f in junit_files:
        full_path = os.path.join(artifacts_dir, f)
        logging.info("Checking %s", full_path)
        b = bucket.blob(os.path.join(prefix, f))
        if not b.exists():
            logging.error("Missing %s", full_path)
            no_errors = False
            continue

        xml_contents = b.download_as_string()

        if test_util.get_num_failures(xml_contents) > 0:
            logging.info("Test failures in %s", full_path)
            no_errors = False

    # Check if there were any extra tests that ran and treat
    # that as a failure.
    extra = set(actual_junit) - set(junit_files)
    if extra:
        logging.error("Extra junit files found: %s", ",".join(extra))
        no_errors = False
    return no_errors
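
The helper _get_actual_junit_files is referenced above but not reproduced on this page. A plausible sketch, assuming it simply lists the junit XML blobs that exist under the artifacts prefix (it relies on the same os and storage imports as the surrounding examples):

def _get_actual_junit_files(bucket, prefix):
    # Sketch (assumption): collect the base names of the junit files that
    # actually exist under the artifacts prefix.
    actual_junit = set()
    for b in bucket.list_blobs(prefix=os.path.join(prefix, "junit")):
        actual_junit.add(os.path.basename(b.name))
    return actual_junit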
Example #11
def create_junit_xml_file(test_cases, output_path, gcs_client=None):
    """Create a JUnit XML file.

  The junit schema is specified here:
  https://www.ibm.com/support/knowledgecenter/en/SSQ2R2_9.5.0/com.ibm.rsar.analysis.codereview.cobol.doc/topics/cac_useresults_junit.html

  Args:
    test_cases: TestSuite or List of test case objects.
    output_path: Path to write the XML
    gcs_client: GCS client to use if output is GCS.
  """
    t = create_xml(test_cases)
    logging.info("Creating %s", output_path)
    if output_path.startswith("gs://"):
        b = six.StringIO()
        t.write(b)

        bucket_name, path = util.split_gcs_uri(output_path)
        bucket = gcs_client.get_bucket(bucket_name)
        blob = bucket.blob(path)
        blob.upload_from_string(b.getvalue())
    else:
        t.write(output_path)
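
Examples #11 and #12 delegate the XML construction to a create_xml helper that is not shown on this page. A sketch of what it presumably does, based on the inline construction in Example #4:

def create_xml(test_cases):
    # Sketch based on Example #4: build a <testsuite> element with one
    # <testcase> child per test case and return an ElementTree.
    total_time = 0
    failures = 0
    for c in test_cases:
        if c.time:
            total_time += c.time
        if c.failure:
            failures += 1
    attrib = {
        "failures": "{0}".format(failures),
        "tests": "{0}".format(len(test_cases)),
        "time": "{0}".format(total_time),
    }
    root = ElementTree.Element("testsuite", attrib)
    for c in test_cases:
        attrib = {"classname": c.class_name, "name": c.name}
        if c.time:
            attrib["time"] = "{0}".format(c.time)
        if c.failure:
            attrib["failure"] = c.failure
        root.append(ElementTree.Element("testcase", attrib))
    return ElementTree.ElementTree(root)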
Example #12
def create_junit_xml_file(test_cases, output_path, gcs_client=None):
    """Create a JUnit XML file.

  The junit schema is specified here:
  https://www.ibm.com/support/knowledgecenter/en/SSQ2R2_9.5.0/com.ibm.rsar.analysis.codereview.cobol.doc/topics/cac_useresults_junit.html

  Args:
    test_cases: TestSuite or List of test case objects.
    output_path: Path to write the XML
    gcs_client: GCS client to use if output is GCS.
  """
    t = create_xml(test_cases)
    logging.info("Creating %s", output_path)
    if output_path.startswith("gs://"):
        b = six.StringIO()
        t.write(b)

        bucket_name, path = util.split_gcs_uri(output_path)
        bucket = gcs_client.get_bucket(bucket_name)
        blob = bucket.blob(path)
        blob.upload_from_string(b.getvalue())
    else:
        dir_name = os.path.dirname(output_path)
        if not os.path.exists(dir_name):
            logging.info("Creating directory %s", dir_name)
            try:
                os.makedirs(dir_name)
            except OSError as e:
                if e.errno == errno.EEXIST:
                    # The path already exists. This is probably a race condition
                    # with some other test creating the directory.
                    # We should just be able to continue
                    pass
                else:
                    raise
        t.write(output_path)
Example #13
def build_images(dag_run=None, ti=None, **_kwargs):  # pylint: disable=too-many-statements
    """
  Args:
    dag_run: A DagRun object. This is passed in as a result of setting
      provide_context to true for the operator.
  """
    # Create a temporary directory suitable for checking out and building the
    # code.
    if not dag_run:
        # When running via airflow test dag_run isn't set
        logging.warn("Using fake dag_run")
        dag_run = FakeDagrun()

    logging.info("dag_id: %s", dag_run.dag_id)
    logging.info("run_id: %s", dag_run.run_id)

    run_dir = ti.xcom_pull(None, key="run_dir")
    logging.info("Using run_dir=%s", run_dir)

    src_dir = ti.xcom_pull(None, key="src_dir")
    logging.info("Using src_dir=%s", src_dir)

    gcs_path = run_path(dag_run.dag_id, dag_run.run_id)
    logging.info("gcs_path %s", gcs_path)

    conf = dag_run.conf
    if not conf:
        conf = {}
    logging.info("conf=%s", conf)
    artifacts_path = conf.get("ARTIFACTS_PATH", gcs_path)
    logging.info("artifacts_path %s", artifacts_path)

    # We use a GOPATH that is specific to this run because we don't want
    # interference from different runs.
    newenv = os.environ.copy()
    newenv["GOPATH"] = os.path.join(run_dir, "go")

    # Make sure pull_number is a string
    pull_number = "{0}".format(conf.get("PULL_NUMBER", ""))
    args = ["python", "-m", "py.release", "build", "--src_dir=" + src_dir]

    dryrun = bool(conf.get("dryrun", False))

    build_info_file = os.path.join(gcs_path, "build_info.yaml")
    args.append("--build_info_path=" + build_info_file)
    args.append("--releases_path=" + gcs_path)
    args.append("--project=" + GCB_PROJECT)
    # We want subprocess output to bypass logging module otherwise multiline
    # output is squashed together.
    util.run(args, use_print=True, dryrun=dryrun, env=newenv)

    # Read the output yaml and publish relevant values to xcom.
    if not dryrun:
        gcs_client = storage.Client(project=GCB_PROJECT)
        logging.info("Reading %s", build_info_file)
        bucket_name, build_path = util.split_gcs_uri(build_info_file)
        bucket = gcs_client.get_bucket(bucket_name)
        blob = bucket.blob(build_path)
        contents = blob.download_as_string()
        build_info = yaml.load(contents)
    else:
        build_info = {
            "image": "gcr.io/dryrun/dryrun:latest",
            "commit": "1234abcd",
            "helm_chart": "gs://dryrun/dryrun.latest.",
        }
    for k, v in six.iteritems(build_info):
        logging.info("xcom push: %s=%s", k, v)
        ti.xcom_push(key=k, value=v)
Example #14
def build_and_push_artifacts(go_dir,
                             src_dir,
                             registry,
                             publish_path=None,
                             gcb_project=None,
                             build_info_path=None):
    """Build and push the artifacts.

  Args:
    go_dir: The GOPATH directory
    src_dir: The root directory where we checked out the repo.
    registry: Docker registry to use.
    publish_path: (Optional) The GCS path where artifacts should be published.
       Set to none to only build locally.
    gcb_project: The project to use with GCB to build docker images.
      If set to none uses docker to build.
    build_info_path: (Optional): GCS location to write YAML file containing
      information about the build.
  """
    # Update the GOPATH to the temporary directory.
    env = os.environ.copy()
    if go_dir:
        env["GOPATH"] = go_dir

    bin_dir = os.path.join(src_dir, "bin")
    if not os.path.exists(bin_dir):
        os.makedirs(bin_dir)

    build_info = build_operator_image(src_dir, registry, project=gcb_project)

    # Copy the chart to a temporary directory because we will modify some
    # of its YAML files.
    chart_build_dir = tempfile.mkdtemp(prefix="tmpTFJobChartBuild")
    shutil.copytree(os.path.join(src_dir, "tf-job-operator-chart"),
                    os.path.join(chart_build_dir, "tf-job-operator-chart"))
    version = build_info["image"].split(":")[-1]
    values_file = os.path.join(chart_build_dir, "tf-job-operator-chart",
                               "values.yaml")
    update_values(values_file, build_info["image"])

    chart_file = os.path.join(chart_build_dir, "tf-job-operator-chart",
                              "Chart.yaml")
    update_chart(chart_file, version)

    # Delete any existing matches because we assume there is only 1 below.
    matches = glob.glob(os.path.join(bin_dir, "tf-job-operator-chart*.tgz"))
    for m in matches:
        logging.info("Delete previous build: %s", m)
        os.unlink(m)

    util.run([
        "helm", "package", "--save=false", "--destination=" + bin_dir,
        "./tf-job-operator-chart"
    ],
             cwd=chart_build_dir)

    matches = glob.glob(os.path.join(bin_dir, "tf-job-operator-chart*.tgz"))

    if len(matches) != 1:
        raise ValueError(
            "Expected 1 chart archive to match but found {0}".format(matches))

    chart_archive = matches[0]

    release_path = version

    targets = [
        os.path.join(release_path, os.path.basename(chart_archive)),
        "latest/tf-job-operator-chart-latest.tgz",
    ]

    if publish_path:
        gcs_client = storage.Client(project=gcb_project)
        bucket_name, base_path = util.split_gcs_uri(publish_path)
        bucket = gcs_client.get_bucket(bucket_name)
        for t in targets:
            blob = bucket.blob(os.path.join(base_path, t))
            gcs_path = util.to_gcs_uri(bucket_name, blob.name)
            if not t.startswith("latest"):
                build_info["helm_chart"] = gcs_path
            if blob.exists() and not t.startswith("latest"):
                logging.warn("%s already exists", gcs_path)
                continue
            logging.info("Uploading %s to %s.", chart_archive, gcs_path)
            blob.upload_from_filename(chart_archive)

        create_latest(bucket, build_info["commit"],
                      util.to_gcs_uri(bucket_name, targets[0]))

    # Always write to the bin dir.
    paths = [os.path.join(bin_dir, "build_info.yaml")]

    if build_info_path:
        paths.append(build_info_path)

    write_build_info(build_info, paths, project=gcb_project)
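
The example above also uses util.to_gcs_uri, the counterpart of util.split_gcs_uri. A one-line sketch of what it presumably does (an assumption; the real helper may normalize paths differently):

def to_gcs_uri(bucket_name, path):
    # Presumed inverse of split_gcs_uri: join a bucket and object path back
    # into a gs:// URI.
    return "gs://{0}/{1}".format(bucket_name, path)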
Example #15
def setup(args):
    """Setup a GKE cluster for TensorFlow jobs.

  Args:
    args: Command line arguments that control the setup process.
  """
    gke = discovery.build("container", "v1")

    project = args.project
    cluster_name = args.cluster
    zone = args.zone
    chart = args.chart
    machine_type = "n1-standard-8"

    cluster_request = {
        "cluster": {
            "name": cluster_name,
            "description": "A GKE cluster for TF.",
            "initialNodeCount": 1,
            "nodeConfig": {
                "machineType":
                machine_type,
                "oauthScopes": [
                    "https://www.googleapis.com/auth/cloud-platform",
                ],
            },
            # TODO(jlewi): Stop pinning GKE version once 1.8 becomes the default.
            "initialClusterVersion": "1.8.1-gke.1",
        }
    }

    if args.accelerators:
        # TODO(jlewi): Stop enabling Alpha once GPUs make it out of Alpha
        cluster_request["cluster"]["enableKubernetesAlpha"] = True

        cluster_request["cluster"]["nodeConfig"]["accelerators"] = []
        for accelerator_spec in args.accelerators:
            accelerator_type, accelerator_count = accelerator_spec.split(
                "=", 1)
            cluster_request["cluster"]["nodeConfig"]["accelerators"].append({
                "acceleratorCount":
                accelerator_count,
                "acceleratorType":
                accelerator_type,
            })

    util.create_cluster(gke, project, zone, cluster_request)

    util.configure_kubectl(project, zone, cluster_name)

    util.load_kube_config()
    # Create an API client object to talk to the K8s master.
    api_client = k8s_client.ApiClient()

    util.setup_cluster(api_client)

    if chart.startswith("gs://"):
        remote = chart
        chart = os.path.join(tempfile.gettempdir(), os.path.basename(chart))
        gcs_client = storage.Client(project=project)
        bucket_name, path = util.split_gcs_uri(remote)

        bucket = gcs_client.get_bucket(bucket_name)
        blob = bucket.blob(path)
        logging.info("Downloading %s to %s", remote, chart)
        blob.download_to_filename(chart)

    t = test_util.TestCase()
    try:
        start = time.time()
        util.run([
            "helm", "install", chart, "-n", "tf-job", "--wait", "--replace",
            "--set", "rbac.install=true,cloud=gke"
        ])
    except subprocess.CalledProcessError as e:
        t.failure = "helm install failed;\n" + e.output
    finally:
        t.time = time.time() - start
        t.name = "helm-tfjob-install"
        t.class_name = "GKE"
        test_util.create_junit_xml_file([t], args.junit_path, gcs_client)
Example #16
def build_images(dag_run=None, ti=None, **_kwargs):  # pylint: disable=too-many-statements
    """
  Args:
    dag_run: A DagRun object. This is passed in as a result of setting
      provide_context to true for the operator.
  """
    # Create a temporary directory suitable for checking out and building the
    # code.
    if not dag_run:
        # When running via airflow test dag_run isn't set
        logging.warn("Using fake dag_run")
        dag_run = FakeDagrun()

    logging.info("dag_id: %s", dag_run.dag_id)
    logging.info("run_id: %s", dag_run.run_id)

    gcs_path = run_path(dag_run.dag_id, dag_run.run_id)
    logging.info("gcs_path %s", gcs_path)

    conf = dag_run.conf
    if not conf:
        conf = {}
    logging.info("conf=%s", conf)
    artifacts_path = conf.get("ARTIFACTS_PATH", gcs_path)
    logging.info("artifacts_path %s", artifacts_path)

    # Make sure pull_number is a string
    pull_number = "{0}".format(conf.get("PULL_NUMBER", ""))
    args = ["python", "-m", "py.release"]
    if pull_number:
        commit = conf.get("PULL_PULL_SHA", "")
        args.append("pr")
        args.append("--pr=" + pull_number)
        if commit:
            args.append("--commit=" + commit)
    else:
        commit = conf.get("PULL_BASE_SHA", "")
        args.append("postsubmit")
        if commit:
            args.append("--commit=" + commit)

    dryrun = bool(conf.get("dryrun", False))

    # Pick the directory where the source will be checked out.
    # This should be a persistent location that is accessible from subsequent
    # tasks; e.g. an NFS share or PD.
    src_dir = os.path.join(os.getenv("SRC_DIR", tempfile.gettempdir()),
                           dag_run.dag_id.replace(":", "_"),
                           dag_run.run_id.replace(":", "_"))
    logging.info("Using src_dir %s", src_dir)
    os.makedirs(src_dir)
    logging.info("xcom push: src_dir=%s", src_dir)
    ti.xcom_push(key="src_dir", value=src_dir)

    build_info_file = os.path.join(gcs_path, "build_info.yaml")
    args.append("--build_info_path=" + build_info_file)
    args.append("--releases_path=" + gcs_path)
    args.append("--project=" + GCB_PROJECT)
    args.append("--src_dir=" + src_dir)
    # We want subprocess output to bypass logging module otherwise multiline
    # output is squashed together.
    util.run(args, use_print=True, dryrun=dryrun)

    # Read the output yaml and publish relevant values to xcom.
    if not dryrun:
        gcs_client = storage.Client(project=GCB_PROJECT)
        logging.info("Reading %s", build_info_file)
        bucket_name, build_path = util.split_gcs_uri(build_info_file)
        bucket = gcs_client.get_bucket(bucket_name)
        blob = bucket.blob(build_path)
        contents = blob.download_as_string()
        build_info = yaml.load(contents)
    else:
        build_info = {
            "image": "gcr.io/dryrun/dryrun:latest",
            "commit": "1234abcd",
            "helm_chart": "gs://dryrun/dryrun.latest.",
        }
    for k, v in six.iteritems(build_info):
        logging.info("xcom push: %s=%s", k, v)
        ti.xcom_push(key=k, value=v)