def testSplitGcsUri(self):
  bucket, path = util.split_gcs_uri("gs://some-bucket/some/path")
  self.assertEquals("some-bucket", bucket)
  self.assertEquals("some/path", path)

  bucket, path = util.split_gcs_uri("gs://some-bucket")
  self.assertEquals("some-bucket", bucket)
  self.assertEquals("", path)
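# For reference, a minimal sketch of what util.split_gcs_uri presumably does,
# inferred only from the test above (the actual implementation is not shown
# in this section and may differ): strip the "gs://" scheme and split the
# remainder into a bucket name and an object path.
import re

def split_gcs_uri_sketch(gcs_uri):
  """Split a GCS URI like gs://bucket/some/path into (bucket, path)."""
  m = re.match("gs://([^/]*)(/(.*))?", gcs_uri)
  if not m:
    raise ValueError("{0} is not a valid GCS URI".format(gcs_uri))
  bucket = m.group(1)
  # group(3) is None for URIs like gs://some-bucket with no object path.
  path = m.group(3) or ""
  return bucket, path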
def build_new_release(args):  # pylint: disable=too-many-locals
  """Find the latest release and build the artifacts if they are newer than
  the current release.
  """
  if not args.src_dir:
    raise ValueError("src_dir must be provided when building last green.")

  gcs_client = storage.Client()
  sha = get_latest_green_presubmit(gcs_client)

  bucket_name, _ = util.split_gcs_uri(args.releases_path)
  bucket = gcs_client.get_bucket(bucket_name)

  logging.info("Latest passing postsubmit is %s", sha)

  last_release_sha = get_last_release(bucket)
  logging.info("Most recent release was for %s", last_release_sha)

  sha = build_and_push_image.GetGitHash(args.src_dir)

  if sha == last_release_sha:
    logging.info("Already cut release for %s", sha)
    return

  build(args)
def build_lastgreen(args):  # pylint: disable=too-many-locals
  """Find the latest green postsubmit and build the artifacts.
  """
  gcs_client = storage.Client()
  sha = get_latest_green_presubmit(gcs_client)

  bucket_name, _ = util.split_gcs_uri(args.releases_path)
  bucket = gcs_client.get_bucket(bucket_name)

  logging.info("Latest passing postsubmit is %s", sha)

  last_release_sha = get_last_release(bucket)
  logging.info("Most recent release was for %s", last_release_sha)

  if sha == last_release_sha:
    logging.info("Already cut release for %s", sha)
    return

  go_dir = tempfile.mkdtemp(prefix="tmpTfJobSrc")
  logging.info("Temporary go_dir: %s", go_dir)

  src_dir = os.path.join(go_dir, "src", "github.com", REPO_ORG, REPO_NAME)

  _, sha = util.clone_repo(src_dir, util.MASTER_REPO_OWNER,
                           util.MASTER_REPO_NAME, sha)

  build_and_push(go_dir, src_dir, args)
def create_junit_xml_file(test_cases, output_path, gcs_client=None):
  """Create a JUnit XML file.

  The junit schema is specified here:
  https://www.ibm.com/support/knowledgecenter/en/SSQ2R2_9.5.0/com.ibm.rsar.analysis.codereview.cobol.doc/topics/cac_useresults_junit.html

  Args:
    test_cases: TestSuite or List of test case objects.
    output_path: Path to write the XML
    gcs_client: GCS client to use if output is GCS.
  """
  total_time = 0
  failures = 0
  for c in test_cases:
    if c.time:
      total_time += c.time

    if c.failure:
      failures += 1

  attrib = {
    "failures": "{0}".format(failures),
    "tests": "{0}".format(len(test_cases)),
    "time": "{0}".format(total_time),
  }
  root = ElementTree.Element("testsuite", attrib)

  for c in test_cases:
    attrib = {
      "classname": c.class_name,
      "name": c.name,
    }
    if c.time:
      attrib["time"] = "{0}".format(c.time)

    # If the time isn't set and no failure message is set we interpret that
    # as the test not being run.
    if not c.time and not c.failure:
      attrib["failure"] = "Test was not run."

    if c.failure:
      attrib["failure"] = c.failure

    e = ElementTree.Element("testcase", attrib)
    root.append(e)

  t = ElementTree.ElementTree(root)
  logging.info("Creating %s", output_path)
  if output_path.startswith("gs://"):
    b = six.StringIO()
    t.write(b)

    bucket_name, path = util.split_gcs_uri(output_path)
    bucket = gcs_client.get_bucket(bucket_name)
    blob = bucket.blob(path)
    blob.upload_from_string(b.getvalue())
  else:
    t.write(output_path)
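# A hedged usage sketch of create_junit_xml_file. The test case fields used
# here (class_name, name, time, failure) mirror how cases are populated
# elsewhere in this code (e.g. the helm install test in setup()); the local
# output path is hypothetical.
case = test_util.TestCase()
case.class_name = "GKE"
case.name = "helm-tfjob-install"
case.time = 12.5
case.failure = None  # Set to a message string to mark the case as failed.

# Writes the XML locally; pass a gs:// path and a storage.Client() as
# gcs_client to upload the result to GCS instead.
create_junit_xml_file([case], "/tmp/junit_example.xml")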
def upload_file_to_gcs(source, target):
  gcs_client = storage.Client()
  bucket_name, path = util.split_gcs_uri(target)

  bucket = gcs_client.get_bucket(bucket_name)

  logging.info("Uploading file %s to %s.", source, target)
  blob = bucket.blob(path)
  blob.upload_from_filename(source)
def upload_to_gcs(contents, target):
  gcs_client = storage.Client()

  bucket_name, path = util.split_gcs_uri(target)
  bucket = gcs_client.get_bucket(bucket_name)

  logging.info("Writing %s", target)
  blob = bucket.blob(path)
  blob.upload_from_string(contents)
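# A hypothetical usage sketch of the two upload helpers above; the bucket and
# object names are made up for illustration.
upload_to_gcs("some inline text", "gs://my-bucket/notes/hello.txt")
upload_file_to_gcs("/tmp/build-log.txt", "gs://my-bucket/logs/build-log.txt")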
def upload_outputs(gcs_client, output_dir, build_log):
  bucket_name, path = util.split_gcs_uri(output_dir)
  bucket = gcs_client.get_bucket(bucket_name)

  if not os.path.exists(build_log):
    logging.error("File %s doesn't exist.", build_log)
  else:
    logging.info("Uploading file %s.", build_log)
    blob = bucket.blob(os.path.join(path, "build-log.txt"))
    blob.upload_from_filename(build_log)
def create_junit_xml_file(test_cases, output_path, gcs_client=None):
  """Create a JUnit XML file.

  Args:
    test_cases: List of test case objects.
    output_path: Path to write the XML
    gcs_client: GCS client to use if output is GCS.
  """
  total_time = 0
  failures = 0
  for c in test_cases:
    total_time += c.time
    if c.failure:
      failures += 1

  attrib = {
    "failures": "{0}".format(failures),
    "tests": "{0}".format(len(test_cases)),
    "time": "{0}".format(total_time),
  }
  root = ElementTree.Element("testsuite", attrib)

  for c in test_cases:
    attrib = {
      "classname": c.class_name,
      "name": c.name,
      "time": "{0}".format(c.time),
    }
    if c.failure:
      attrib["failure"] = c.failure

    e = ElementTree.Element("testcase", attrib)
    root.append(e)

  t = ElementTree.ElementTree(root)
  logging.info("Creating %s", output_path)
  if output_path.startswith("gs://"):
    b = six.StringIO()
    t.write(b)

    bucket_name, path = util.split_gcs_uri(output_path)
    bucket = gcs_client.get_bucket(bucket_name)
    blob = bucket.blob(path)
    blob.upload_from_string(b.getvalue())
  else:
    t.write(output_path)
def write_build_info(build_info, paths, project=None):
  """Write the build info files.
  """
  gcs_client = None

  contents = yaml.dump(build_info)

  for p in paths:
    logging.info("Writing build information to %s", p)
    if p.startswith("gs://"):
      if not gcs_client:
        gcs_client = storage.Client(project=project)
      bucket_name, path = util.split_gcs_uri(p)
      bucket = gcs_client.get_bucket(bucket_name)
      blob = bucket.blob(path)
      blob.upload_from_string(contents)
    else:
      with open(p, mode='w') as hf:
        hf.write(contents)
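# A hedged usage sketch of write_build_info. The build_info keys mirror the
# ones produced elsewhere in this pipeline (image, commit, helm_chart); the
# paths and project below are placeholders, not real locations.
build_info = {
  "image": "gcr.io/example/tf_operator:v20180101-abcd123",
  "commit": "abcd123",
  "helm_chart": "gs://example-releases/v20180101/tf-job-operator-chart.tgz",
}
write_build_info(build_info,
                 ["/tmp/build_info.yaml",
                  "gs://example-releases/v20180101/build_info.yaml"],
                 project="example-project")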
def check_no_errors(gcs_client, artifacts_dir, junit_files):
  """Check that all the XML files exist and there were no errors.

  Args:
    gcs_client: The GCS client.
    artifacts_dir: The directory where artifacts should be stored.
    junit_files: List of the names of the junit files.

  Returns:
    True if there were no errors and false otherwise.
  """
  bucket_name, prefix = util.split_gcs_uri(artifacts_dir)
  bucket = gcs_client.get_bucket(bucket_name)
  no_errors = True

  # Get a list of actual junit files.
  actual_junit = _get_actual_junit_files(bucket, prefix)

  for f in junit_files:
    full_path = os.path.join(artifacts_dir, f)
    logging.info("Checking %s", full_path)
    b = bucket.blob(os.path.join(prefix, f))
    if not b.exists():
      logging.error("Missing %s", full_path)
      no_errors = False
      continue

    xml_contents = b.download_as_string()

    if test_util.get_num_failures(xml_contents) > 0:
      logging.info("Test failures in %s", full_path)
      no_errors = False

  # Check if there were any extra tests that ran and treat
  # that as a failure.
  extra = set(actual_junit) - set(junit_files)
  if extra:
    logging.error("Extra junit files found: %s", ",".join(extra))
    no_errors = False

  return no_errors
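# check_no_errors relies on a helper _get_actual_junit_files that is not shown
# in this section. A plausible sketch, assuming it simply lists the junit
# objects under the artifacts prefix (the real helper may differ):
def _get_actual_junit_files(bucket, prefix):
  actual_junit = set()
  # List every object whose name starts with <prefix>/junit and keep only the
  # file name, so it can be compared against the expected junit file names.
  for b in bucket.list_blobs(prefix=os.path.join(prefix, "junit")):
    actual_junit.add(os.path.basename(b.name))
  return actual_junit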
def create_junit_xml_file(test_cases, output_path, gcs_client=None):
  """Create a JUnit XML file.

  The junit schema is specified here:
  https://www.ibm.com/support/knowledgecenter/en/SSQ2R2_9.5.0/com.ibm.rsar.analysis.codereview.cobol.doc/topics/cac_useresults_junit.html

  Args:
    test_cases: TestSuite or List of test case objects.
    output_path: Path to write the XML
    gcs_client: GCS client to use if output is GCS.
  """
  t = create_xml(test_cases)
  logging.info("Creating %s", output_path)
  if output_path.startswith("gs://"):
    b = six.StringIO()
    t.write(b)

    bucket_name, path = util.split_gcs_uri(output_path)
    bucket = gcs_client.get_bucket(bucket_name)
    blob = bucket.blob(path)
    blob.upload_from_string(b.getvalue())
  else:
    t.write(output_path)
def create_junit_xml_file(test_cases, output_path, gcs_client=None):
  """Create a JUnit XML file.

  The junit schema is specified here:
  https://www.ibm.com/support/knowledgecenter/en/SSQ2R2_9.5.0/com.ibm.rsar.analysis.codereview.cobol.doc/topics/cac_useresults_junit.html

  Args:
    test_cases: TestSuite or List of test case objects.
    output_path: Path to write the XML
    gcs_client: GCS client to use if output is GCS.
  """
  t = create_xml(test_cases)
  logging.info("Creating %s", output_path)
  if output_path.startswith("gs://"):
    b = six.StringIO()
    t.write(b)

    bucket_name, path = util.split_gcs_uri(output_path)
    bucket = gcs_client.get_bucket(bucket_name)
    blob = bucket.blob(path)
    blob.upload_from_string(b.getvalue())
  else:
    dir_name = os.path.dirname(output_path)
    if not os.path.exists(dir_name):
      logging.info("Creating directory %s", dir_name)
      try:
        os.makedirs(dir_name)
      except OSError as e:
        if e.errno == errno.EEXIST:
          # The path already exists. This is probably a race condition
          # with some other test creating the directory.
          # We should just be able to continue.
          pass
        else:
          raise
    t.write(output_path)
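# The two variants above delegate XML construction to a create_xml helper that
# is not included in this section. A sketch of what it plausibly looks like,
# reconstructed from the inline logic in the earlier versions of
# create_junit_xml_file (the actual helper may differ):
def create_xml(test_cases):
  total_time = 0
  failures = 0
  for c in test_cases:
    if c.time:
      total_time += c.time
    if c.failure:
      failures += 1

  attrib = {
    "failures": "{0}".format(failures),
    "tests": "{0}".format(len(test_cases)),
    "time": "{0}".format(total_time),
  }
  root = ElementTree.Element("testsuite", attrib)

  for c in test_cases:
    attrib = {
      "classname": c.class_name,
      "name": c.name,
    }
    if c.time:
      attrib["time"] = "{0}".format(c.time)
    if c.failure:
      attrib["failure"] = c.failure
    root.append(ElementTree.Element("testcase", attrib))

  return ElementTree.ElementTree(root)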
def build_images(dag_run=None, ti=None, **_kwargs):  # pylint: disable=too-many-statements
  """
  Args:
    dag_run: A DagRun object. This is passed in as a result of setting
      provide_context to true for the operator.
  """
  # Create a temporary directory suitable for checking out and building the
  # code.
  if not dag_run:
    # When running via airflow test dag_run isn't set
    logging.warn("Using fake dag_run")
    dag_run = FakeDagrun()

  logging.info("dag_id: %s", dag_run.dag_id)
  logging.info("run_id: %s", dag_run.run_id)

  run_dir = ti.xcom_pull(None, key="run_dir")
  logging.info("Using run_dir=%s", run_dir)

  src_dir = ti.xcom_pull(None, key="src_dir")
  logging.info("Using src_dir=%s", src_dir)

  gcs_path = run_path(dag_run.dag_id, dag_run.run_id)
  logging.info("gcs_path %s", gcs_path)

  conf = dag_run.conf
  if not conf:
    conf = {}

  logging.info("conf=%s", conf)
  artifacts_path = conf.get("ARTIFACTS_PATH", gcs_path)
  logging.info("artifacts_path %s", artifacts_path)

  # We use a GOPATH that is specific to this run because we don't want
  # interference from different runs.
  newenv = os.environ.copy()
  newenv["GOPATH"] = os.path.join(run_dir, "go")

  # Make sure pull_number is a string
  pull_number = "{0}".format(conf.get("PULL_NUMBER", ""))

  args = ["python", "-m", "py.release", "build", "--src_dir=" + src_dir]
  dryrun = bool(conf.get("dryrun", False))

  build_info_file = os.path.join(gcs_path, "build_info.yaml")
  args.append("--build_info_path=" + build_info_file)
  args.append("--releases_path=" + gcs_path)
  args.append("--project=" + GCB_PROJECT)

  # We want subprocess output to bypass logging module otherwise multiline
  # output is squashed together.
  util.run(args, use_print=True, dryrun=dryrun, env=newenv)

  # Read the output yaml and publish relevant values to xcom.
  if not dryrun:
    gcs_client = storage.Client(project=GCB_PROJECT)
    logging.info("Reading %s", build_info_file)
    bucket_name, build_path = util.split_gcs_uri(build_info_file)
    bucket = gcs_client.get_bucket(bucket_name)
    blob = bucket.blob(build_path)
    contents = blob.download_as_string()
    build_info = yaml.load(contents)
  else:
    build_info = {
      "image": "gcr.io/dryrun/dryrun:latest",
      "commit": "1234abcd",
      "helm_chart": "gs://dryrun/dryrun.latest.",
    }

  for k, v in six.iteritems(build_info):
    logging.info("xcom push: %s=%s", k, v)
    ti.xcom_push(key=k, value=v)
def build_and_push_artifacts(go_dir, src_dir, registry, publish_path=None,
                             gcb_project=None, build_info_path=None):
  """Build and push the artifacts.

  Args:
    go_dir: The GOPATH directory
    src_dir: The root directory where we checked out the repo.
    registry: Docker registry to use.
    publish_path: (Optional) The GCS path where artifacts should be published.
      Set to none to only build locally.
    gcb_project: The project to use with GCB to build docker images.
      If set to none uses docker to build.
    build_info_path: (Optional): GCS location to write YAML file containing
      information about the build.
  """
  # Update the GOPATH to the temporary directory.
  env = os.environ.copy()
  if go_dir:
    env["GOPATH"] = go_dir

  bin_dir = os.path.join(src_dir, "bin")
  if not os.path.exists(bin_dir):
    os.makedirs(bin_dir)

  build_info = build_operator_image(src_dir, registry, project=gcb_project)

  # Copy the chart to a temporary directory because we will modify some
  # of its YAML files.
  chart_build_dir = tempfile.mkdtemp(prefix="tmpTFJobChartBuild")
  shutil.copytree(os.path.join(src_dir, "tf-job-operator-chart"),
                  os.path.join(chart_build_dir, "tf-job-operator-chart"))
  version = build_info["image"].split(":")[-1]
  values_file = os.path.join(chart_build_dir, "tf-job-operator-chart",
                             "values.yaml")
  update_values(values_file, build_info["image"])

  chart_file = os.path.join(chart_build_dir, "tf-job-operator-chart",
                            "Chart.yaml")
  update_chart(chart_file, version)

  # Delete any existing matches because we assume there is only 1 below.
  matches = glob.glob(os.path.join(bin_dir, "tf-job-operator-chart*.tgz"))
  for m in matches:
    logging.info("Delete previous build: %s", m)
    os.unlink(m)

  util.run([
    "helm", "package", "--save=false", "--destination=" + bin_dir,
    "./tf-job-operator-chart"
  ], cwd=chart_build_dir)

  matches = glob.glob(os.path.join(bin_dir, "tf-job-operator-chart*.tgz"))

  if len(matches) != 1:
    raise ValueError(
      "Expected 1 chart archive to match but found {0}".format(matches))

  chart_archive = matches[0]

  release_path = version

  targets = [
    os.path.join(release_path, os.path.basename(chart_archive)),
    "latest/tf-job-operator-chart-latest.tgz",
  ]

  if publish_path:
    gcs_client = storage.Client(project=gcb_project)
    bucket_name, base_path = util.split_gcs_uri(publish_path)
    bucket = gcs_client.get_bucket(bucket_name)
    for t in targets:
      blob = bucket.blob(os.path.join(base_path, t))
      gcs_path = util.to_gcs_uri(bucket_name, blob.name)
      if not t.startswith("latest"):
        build_info["helm_chart"] = gcs_path
      if blob.exists() and not t.startswith("latest"):
        logging.warn("%s already exists", gcs_path)
        continue
      logging.info("Uploading %s to %s.", chart_archive, gcs_path)
      blob.upload_from_filename(chart_archive)

    create_latest(bucket, build_info["commit"],
                  util.to_gcs_uri(bucket_name, targets[0]))

  # Always write to the bin dir.
  paths = [os.path.join(bin_dir, "build_info.yaml")]

  if build_info_path:
    paths.append(build_info_path)

  write_build_info(build_info, paths, project=gcb_project)
def setup(args):
  """Set up a GKE cluster for TensorFlow jobs.

  Args:
    args: Command line arguments that control the setup process.
  """
  gke = discovery.build("container", "v1")

  project = args.project
  cluster_name = args.cluster
  zone = args.zone
  chart = args.chart
  machine_type = "n1-standard-8"

  cluster_request = {
    "cluster": {
      "name": cluster_name,
      "description": "A GKE cluster for TF.",
      "initialNodeCount": 1,
      "nodeConfig": {
        "machineType": machine_type,
        "oauthScopes": [
          "https://www.googleapis.com/auth/cloud-platform",
        ],
      },
      # TODO(jlewi): Stop pinning GKE version once 1.8 becomes the default.
      "initialClusterVersion": "1.8.1-gke.1",
    }
  }

  if args.accelerators:
    # TODO(jlewi): Stop enabling Alpha once GPUs make it out of Alpha
    cluster_request["cluster"]["enableKubernetesAlpha"] = True

    cluster_request["cluster"]["nodeConfig"]["accelerators"] = []
    for accelerator_spec in args.accelerators:
      accelerator_type, accelerator_count = accelerator_spec.split("=", 1)
      cluster_request["cluster"]["nodeConfig"]["accelerators"].append({
        "acceleratorCount": accelerator_count,
        "acceleratorType": accelerator_type,
      })

  util.create_cluster(gke, project, zone, cluster_request)

  util.configure_kubectl(project, zone, cluster_name)

  util.load_kube_config()
  # Create an API client object to talk to the K8s master.
  api_client = k8s_client.ApiClient()

  util.setup_cluster(api_client)

  # Initialize gcs_client so the junit upload below doesn't raise a NameError
  # when the chart is a local path rather than a GCS URI.
  gcs_client = None
  if chart.startswith("gs://"):
    remote = chart
    chart = os.path.join(tempfile.gettempdir(), os.path.basename(chart))
    gcs_client = storage.Client(project=project)
    bucket_name, path = util.split_gcs_uri(remote)

    bucket = gcs_client.get_bucket(bucket_name)
    blob = bucket.blob(path)
    logging.info("Downloading %s to %s", remote, chart)
    blob.download_to_filename(chart)

  t = test_util.TestCase()
  try:
    start = time.time()
    util.run([
      "helm", "install", chart, "-n", "tf-job", "--wait", "--replace",
      "--set", "rbac.install=true,cloud=gke"
    ])
  except subprocess.CalledProcessError as e:
    t.failure = "helm install failed;\n" + e.output
  finally:
    t.time = time.time() - start
    t.name = "helm-tfjob-install"
    t.class_name = "GKE"
    test_util.create_junit_xml_file([t], args.junit_path, gcs_client)
def build_images(dag_run=None, ti=None, **_kwargs):  # pylint: disable=too-many-statements
  """
  Args:
    dag_run: A DagRun object. This is passed in as a result of setting
      provide_context to true for the operator.
  """
  # Create a temporary directory suitable for checking out and building the
  # code.
  if not dag_run:
    # When running via airflow test dag_run isn't set
    logging.warn("Using fake dag_run")
    dag_run = FakeDagrun()

  logging.info("dag_id: %s", dag_run.dag_id)
  logging.info("run_id: %s", dag_run.run_id)

  gcs_path = run_path(dag_run.dag_id, dag_run.run_id)
  logging.info("gcs_path %s", gcs_path)

  conf = dag_run.conf
  if not conf:
    conf = {}

  logging.info("conf=%s", conf)
  artifacts_path = conf.get("ARTIFACTS_PATH", gcs_path)
  logging.info("artifacts_path %s", artifacts_path)

  # Make sure pull_number is a string
  pull_number = "{0}".format(conf.get("PULL_NUMBER", ""))

  args = ["python", "-m", "py.release"]

  if pull_number:
    commit = conf.get("PULL_PULL_SHA", "")
    args.append("pr")
    args.append("--pr=" + pull_number)
    if commit:
      args.append("--commit=" + commit)
  else:
    commit = conf.get("PULL_BASE_SHA", "")
    args.append("postsubmit")
    if commit:
      args.append("--commit=" + commit)

  dryrun = bool(conf.get("dryrun", False))

  # Pick the directory where the source will be checked out.
  # This should be a persistent location that is accessible from subsequent
  # tasks; e.g. an NFS share or PD.
  src_dir = os.path.join(os.getenv("SRC_DIR", tempfile.gettempdir()),
                         dag_run.dag_id.replace(":", "_"),
                         dag_run.run_id.replace(":", "_"))
  logging.info("Using src_dir %s", src_dir)
  os.makedirs(src_dir)

  logging.info("xcom push: src_dir=%s", src_dir)
  ti.xcom_push(key="src_dir", value=src_dir)

  build_info_file = os.path.join(gcs_path, "build_info.yaml")
  args.append("--build_info_path=" + build_info_file)
  args.append("--releases_path=" + gcs_path)
  args.append("--project=" + GCB_PROJECT)
  args.append("--src_dir=" + src_dir)

  # We want subprocess output to bypass logging module otherwise multiline
  # output is squashed together.
  util.run(args, use_print=True, dryrun=dryrun)

  # Read the output yaml and publish relevant values to xcom.
  if not dryrun:
    gcs_client = storage.Client(project=GCB_PROJECT)
    logging.info("Reading %s", build_info_file)
    bucket_name, build_path = util.split_gcs_uri(build_info_file)
    bucket = gcs_client.get_bucket(bucket_name)
    blob = bucket.blob(build_path)
    contents = blob.download_as_string()
    build_info = yaml.load(contents)
  else:
    build_info = {
      "image": "gcr.io/dryrun/dryrun:latest",
      "commit": "1234abcd",
      "helm_chart": "gs://dryrun/dryrun.latest.",
    }

  for k, v in six.iteritems(build_info):
    logging.info("xcom push: %s=%s", k, v)
    ti.xcom_push(key=k, value=v)