def shell_run_with_retry(cmd, retries=1, **kwargs):
  """Run `cmd` through `cloud.shell_run`, retrying when it fails.

  The command is attempted up to `retries + 1` times. Only
  `sp.CalledProcessError` triggers a retry; any other exception propagates
  immediately. If `retries` is negative the command is never run.

  Args:
    cmd: The command passed as the first argument to `cloud.shell_run`.
    retries: Number of additional attempts after the first failure.
    **kwargs: Forwarded unchanged to `cloud.shell_run` on every attempt.

  Raises:
    sp.CalledProcessError: If the final attempt also fails.
  """
  for attempt in range(retries + 1):
    if attempt > 0:
      tf.logging.info("Retry %d for %s", attempt, cmd)
    try:
      cloud.shell_run(cmd, **kwargs)
    except sp.CalledProcessError as err:
      # Earlier failures are swallowed so the loop can try again; only the
      # last one is surfaced to the caller.
      if attempt == retries:
        raise err
    else:
      return
def wait_for_ssh(ip):
  """Wait for SSH to be available at given IP address.

  Runs the `SSH_CHECK` command repeatedly, sleeping 10 seconds after each
  failed attempt, and gives up on the 14th consecutive failure.

  Args:
    ip: Passed to `cloud.shell_run` as the `ip` keyword for `SSH_CHECK`.

  Returns:
    True as soon as the check succeeds, False once the attempts run out.
  """
  failures = 0
  while True:
    try:
      cloud.shell_run(SSH_CHECK, ip=ip)
    except sp.CalledProcessError:
      if failures > 12:  # ~2m
        return False
      time.sleep(10)
      failures += 1
    else:
      return True
def upload_trainer_package_to_gcs(train_dir):
  """Upload trainer package to GCS.

  Copies `dist/<PACKAGE_NAME>-<VERSION>.tar.gz` from the current working
  directory into `train_dir` using `gsutil cp`.

  Args:
    train_dir: The GCS directory in which to stage the trainer package.

  Returns:
    The path to the trainer package staged in GCS.
  """
  # Imported locally so this function does not depend on a module-level import.
  import posixpath

  tf.logging.info('Uploading trainer package to %s.', train_dir)
  src_base = '{}-{}.tar.gz'.format(PACKAGE_NAME, VERSION)
  # The package lives on the local filesystem, so the host separator is right.
  package_path = os.path.join(os.getcwd(), 'dist', src_base)
  # GCS object paths always use "/"; os.path.join would insert the host
  # separator instead (a backslash on Windows) and produce a bad URL.
  final_destination = posixpath.join(train_dir, src_base)
  cloud.shell_run(
      'gsutil cp {package_path} {final_destination}',
      package_path=package_path,
      final_destination=final_destination)
  return final_destination
def _tar_and_copy(src_dir, target_dir):
  """Tar and gzip src_dir and copy to GCS target_dir.

  The archive is named after the last path component of `src_dir` and is
  first written to the system temp directory, then copied with `gsutil cp`.
  The temporary archive is left in place afterwards.

  Args:
    src_dir: Local directory whose contents are archived. Trailing slashes
      are ignored.
    target_dir: Destination directory for the archive. Trailing slashes are
      ignored.

  Returns:
    The destination path of the uploaded archive,
    `<target_dir>/<basename of src_dir>.tar.gz`.
  """
  source = src_dir.rstrip("/")
  destination_root = target_dir.rstrip("/")
  scratch = tempfile.gettempdir().rstrip("/")
  archive_stem = os.path.basename(source)
  final_destination = destination_root + "/" + archive_stem + ".tar.gz"

  # Step 1: build the gzipped tarball in the temp directory.
  cloud.shell_run(
      "tar -zcf {tmp_dir}/{src_base}.tar.gz -C {src_dir} .",
      src_dir=source,
      src_base=archive_stem,
      tmp_dir=scratch)

  # Step 2: push the tarball to its final location.
  cloud.shell_run(
      "gsutil cp {tmp_dir}/{src_base}.tar.gz {final_destination}",
      tmp_dir=scratch,
      src_base=archive_stem,
      final_destination=final_destination)
  return final_destination
def delete_instance(instance_name):
  """Run the `DELETE` command for the named instance.

  Args:
    instance_name: Passed to `cloud.shell_run` as the `name` keyword for
      `DELETE`.
  """
  cloud.shell_run(
      DELETE,
      name=instance_name,
  )