Exemplo n.º 1
0
    def __init__(self,
                 monitor='val_loss',
                 verbose=0,
                 mode='auto',
                 save_weights_only=False,
                 log_weights=False,
                 log_gradients=False,
                 save_model=True,
                 training_data=None,
                 validation_data=None,
                 labels=None,
                 data_type=None,
                 predictions=36,
                 generator=None):
        """Constructor.

        # Arguments
            monitor: quantity to monitor.
            verbose: verbosity setting, stored on the instance as given.
            mode: one of {auto, min, max}.
                'min' - save model when monitor is minimized
                'max' - save model when monitor is maximized
                'auto' - try to guess when to save the model: maximize when
                    the monitor name contains 'acc' or starts with
                    'fmeasure', otherwise minimize.
            save_weights_only: if True, then only the model's weights will be
                saved (`model.save_weights(filepath)`), else the full model
                is saved (`model.save(filepath)`).
            save_model:
                True - save a model when monitor beats all previous epochs
                False - don't save models
            log_weights: if True save the weights in wandb.history
            log_gradients: if True log the training gradients in wandb.history
            training_data: tuple (X,y) needed for calculating gradients
            validation_data: deprecated. When passed, a deprecation notice is
                logged and data_type falls back to "image" if it was not set.
            data_type: the type of data we're saving, set to "image" for saving images
            labels: list of labels to convert numeric output to if you are building a
                multiclass classifier.  If you are making a binary classifier you can pass in
                a list of two labels ["label for false", "label for true"].
                Defaults to an empty list.
            predictions: the number of predictions to make each epoch if data_type is set, max is 100.
            generator: a generator to use for making predictions

        # Raises
            wandb.Error: if wandb.init() has not been called (wandb.run is None).
            ValueError: if training_data is given but is not of length two.
        """
        if wandb.run is None:
            raise wandb.Error(
                'You must call wandb.init() before WandbCallback()')
        if validation_data is not None:
            wandb.termlog(
                "DEPRECATED: validation_data is pulled from the model definition, set data_type."
            )
            # For backwards compatibility: resolve the fallback on the local
            # value so the assignment to self.data_type below keeps it.
            data_type = data_type or "image"

        # A fresh list per instance; a shared mutable default would leak
        # state between callbacks.
        self.labels = [] if labels is None else labels
        self.data_type = data_type
        self.predictions = min(predictions, 100)

        self.monitor = monitor
        self.verbose = verbose
        self.save_weights_only = save_weights_only

        wandb.save('model-best.h5')
        self.filepath = os.path.join(wandb.run.dir, 'model-best.h5')
        self.save_model = save_model
        self.log_weights = log_weights
        self.log_gradients = log_gradients
        self.training_data = training_data
        self.generator = generator

        if self.training_data and len(self.training_data) != 2:
            raise ValueError("training data must be a tuple of length two")

        # From Keras
        if mode not in ['auto', 'min', 'max']:
            print('WandbCallback mode %s is unknown, '
                  'fallback to auto mode.' % (mode))
            mode = 'auto'

        if mode == 'auto':
            # Guess the direction from the metric name.
            maximize = ('acc' in self.monitor
                        or self.monitor.startswith('fmeasure'))
        else:
            maximize = mode == 'max'

        if maximize:
            self.monitor_op = operator.gt
            self.best = float('-inf')
        else:
            self.monitor_op = operator.lt
            self.best = float('inf')
Exemplo n.º 2
0
def docker(ctx, docker_run_args, docker_image, nvidia, digest, jupyter, dir,
           no_dir, shell, port, cmd, no_tty):
    """W&B docker lets you run your code in a docker image ensuring wandb is configured. It adds the WANDB_DOCKER and WANDB_API_KEY
    environment variables to your container and mounts the current directory in /app by default.  You can pass additional
    args which will be added to `docker run` before the image name is declared, we'll choose a default image for you if
    one isn't passed:

    wandb docker -v /mnt/dataset:/app/data
    wandb docker gcr.io/kubeflow-images-public/tensorflow-1.12.0-notebook-cpu:v0.4.0 --jupyter
    wandb docker wandb/deepo:keras-gpu --no-tty --cmd "python train.py --epochs=5"

    By default we override the entrypoint to check for the existence of wandb and install it if not present.  If you pass the --jupyter
    flag we will ensure jupyter is installed and start jupyter lab on port 8888.  If we detect nvidia-docker on your system we will use
    the nvidia runtime.  If you just want wandb to add environment variables to an existing docker run command, see the wandb docker-run
    command.
    """
    # NOTE(review): the docstring above is presumably rendered as CLI help
    # text, so implementation notes are kept in comments instead.
    if not find_executable('docker'):
        raise ClickException(
            "Docker not installed, install it from https://docker.com")
    args = list(docker_run_args)
    image = docker_image or ""
    # remove run for users used to nvidia-docker
    if args and args[0] == "run":
        args.pop(0)
    if image == "" and args:
        image = args.pop(0)
    # If the user adds docker args without specifying an image (should be rare)
    if not util.docker_image_regex(image.split("@")[0]):
        if image:
            # What looked like an image is really a docker argument.
            args = args + [image]
        image = wandb.docker.default_image(gpu=nvidia)
        subprocess.call(["docker", "pull", image])
    # The parsed parts are unused; the call is retained in case parsing is
    # relied upon to reject malformed image names (TODO confirm).
    wandb.docker.parse(image)

    resolved_image = wandb.docker.image_id(image)
    if resolved_image is None:
        raise ClickException(
            "Couldn't find image locally or in a registry, try running `docker pull %s`"
            % image)
    if digest:
        # Only print the resolved image id and stop.
        sys.stdout.write(resolved_image)
        # sys.exit rather than the site-provided exit(), which is meant for
        # the interactive shell and is not guaranteed to be defined.
        sys.exit(0)

    existing = wandb.docker.shell(
        ["ps", "-f", "ancestor=%s" % resolved_image, "-q"])
    if existing:
        question = {
            'type':
            'confirm',
            'name':
            'attach',
            'message':
            "Found running container with the same image, do you want to attach?",
        }
        result = whaaaaat.prompt([question])
        if result and result['attach']:
            # Attach to the first container id listed.
            subprocess.call(['docker', 'attach', existing.split("\n")[0]])
            sys.exit(0)
    cwd = os.getcwd()
    command = [
        'docker', 'run', '-e', 'LANG=C.UTF-8', '-e',
        'WANDB_DOCKER=%s' % resolved_image, '--ipc=host', '-v',
        wandb.docker.entrypoint + ':/wandb-entrypoint.sh', '--entrypoint',
        '/wandb-entrypoint.sh'
    ]
    if nvidia:
        command.extend(['--runtime', 'nvidia'])
    if not no_dir:
        #TODO: We should default to the working directory if defined
        command.extend(['-v', cwd + ":" + dir, '-w', dir])
    if api.api_key:
        command.extend(['-e', 'WANDB_API_KEY=%s' % api.api_key])
    else:
        wandb.termlog(
            "Couldn't find WANDB_API_KEY, run `wandb login` to enable streaming metrics"
        )
    if jupyter:
        command.extend(['-e', 'WANDB_ENSURE_JUPYTER=1', '-p', port + ':8888'])
        # Jupyter runs as the container command, so no interactive tty.
        no_tty = True
        cmd = "jupyter lab --no-browser --ip=0.0.0.0 --allow-root --NotebookApp.token= --notebook-dir %s" % dir
    command.extend(args)
    if no_tty:
        # NOTE(review): assumes cmd is set whenever no_tty is used without
        # --jupyter; a None here would be passed to subprocess - confirm.
        command.extend([image, shell, "-c", cmd])
    else:
        if cmd:
            command.extend(['-e', 'WANDB_COMMAND=%s' % cmd])
        command.extend(['-it', image, shell])
        wandb.termlog("Launching docker container \U0001F6A2")
    subprocess.call(command)
Exemplo n.º 3
0
    def from_directory(cls,
                       directory,
                       project=None,
                       entity=None,
                       run_id=None,
                       api=None,
                       ignore_globs=None):
        """Create a run from the files in `directory` and upload it.

        Upserts a run, streams the history file (or converts any tfevents
        files when no history file exists), updates the run with config,
        summary and metadata found in the directory, then pushes the files.

        Arguments:
            directory: path of the directory to sync.
            project: project name; falls back to the api's "project" setting,
                then to `run.auto_project_name`.
            entity: optional entity, stored as the api's "entity" setting.
            run_id: run id to use; a new id is generated when omitted.
            api: api object to use; an `InternalApi` is created when omitted.
            ignore_globs: optional iterable of fnmatch patterns; matching
                relative paths are not uploaded.

        Returns:
            The `Run` object created for `directory`.

        Raises:
            ValueError: if no project could be determined.
        """
        api = api or InternalApi()
        run_id = run_id or util.generate_id()
        run = Run(run_id=run_id, dir=directory)
        project = project or api.settings("project") or run.auto_project_name(
            api=api)
        if project is None:
            raise ValueError("You must specify project")
        api.set_current_run_id(run_id)
        api.set_setting("project", project)
        if entity:
            api.set_setting("entity", entity)
        res = api.upsert_run(name=run_id, project=project, entity=entity)
        entity = res["project"]["entity"]["name"]
        wandb.termlog("Syncing {} to:".format(directory))
        wandb.termlog(res["displayName"] + " " + run.get_url(api))

        file_api = api.get_file_stream_api()
        file_api.start()
        snap = DirectorySnapshot(directory)
        paths = [
            os.path.relpath(abs_path, directory) for abs_path in snap.paths
            if os.path.isfile(abs_path)
        ]
        if ignore_globs:
            paths = set(paths)
            for g in ignore_globs:
                paths = paths - set(fnmatch.filter(paths, g))
            paths = list(paths)
        run_update = {"id": res["id"]}
        # Locate the well-known files by substring match on the path.
        tfevents = sorted(p for p in snap.paths if ".tfevents." in p)
        history = next((p for p in snap.paths if HISTORY_FNAME in p), None)
        event = next((p for p in snap.paths if EVENTS_FNAME in p), None)
        config = next((p for p in snap.paths if CONFIG_FNAME in p), None)
        user_config = next((p for p in snap.paths if USER_CONFIG_FNAME in p),
                           None)
        summary = next((p for p in snap.paths if SUMMARY_FNAME in p), None)
        meta = next((p for p in snap.paths if METADATA_FNAME in p), None)
        if history:
            wandb.termlog("Uploading history metrics")
            file_api.stream_file(history)
            snap.paths.remove(history)
        elif tfevents:
            from wandb import tensorflow as wbtf
            wandb.termlog("Found tfevents file, converting...")
            # From here on `summary` is a dict built from the tfevents files
            # rather than a path to a summary file.
            summary = {}
            for path in tfevents:
                filename = os.path.basename(path)
                # Sub-directory of the events file relative to `directory`.
                namespace = path.replace(filename,
                                         "").replace(directory,
                                                     "").strip(os.sep)
                summary.update(
                    wbtf.stream_tfevents(path,
                                         file_api,
                                         run,
                                         namespace=namespace))
            for path in glob.glob(os.path.join(directory, "media/**/*"),
                                  recursive=True):
                if os.path.isfile(path):
                    paths.append(path)
        else:
            wandb.termerror(
                "No history or tfevents files found, only syncing files")
        if event:
            file_api.stream_file(event)
            snap.paths.remove(event)
        if config:
            with open(config) as config_file:
                run_update["config"] = util.load_yaml(config_file)
        elif user_config:
            # TODO: half baked support for config.json
            # `user_config` is a path; load it before iterating over its
            # items (iterating over the path string itself would raise).
            with open(user_config) as user_config_file:
                user_values = json.load(user_config_file)
            run_update["config"] = {
                k: {
                    "value": v
                }
                for k, v in six.iteritems(user_values)
            }
        if isinstance(summary, dict):
            #TODO: summary should already have data_types converted here...
            run_update["summary_metrics"] = util.json_dumps_safer(summary)
        elif summary:
            with open(summary) as summary_file:
                run_update["summary_metrics"] = summary_file.read()
        if meta:
            with open(meta) as meta_file:
                meta = json.load(meta_file)
            if meta.get("git"):
                run_update["commit"] = meta["git"].get("commit")
                run_update["repo"] = meta["git"].get("remote")
            run_update["host"] = meta["host"]
            run_update["program_path"] = meta["program"]
            run_update["job_type"] = meta.get("jobType")
        else:
            run_update["host"] = socket.gethostname()

        wandb.termlog("Updating run and uploading files")
        api.upsert_run(**run_update)
        pusher = FilePusher(api)
        for k in paths:
            path = os.path.abspath(os.path.join(directory, k))
            pusher.update_file(k, path)
            pusher.file_changed(k, path)
        pusher.finish()
        pusher.print_status()
        file_api.finish(0)
        # Remove temporary media images generated from tfevents
        media_dir = os.path.join(directory, "media")
        if history is None and os.path.exists(media_dir):
            shutil.rmtree(media_dir)
        wandb.termlog("Finished!")
        return run
Exemplo n.º 4
0
def heatmap(x_labels, y_labels, matrix_values, show_text=False):
    """
        Generates a heatmap.

        Arguments:
         matrix_values (arr): 2D dataset that can be coerced into an ndarray,
                            indexed as matrix_values[j][i] for y_labels[j]
                            and x_labels[i].
         x_labels  (list): Named labels for the x axis.
         y_labels  (list): Named labels for the y axis.
         show_text (bool): Show text values in heatmap cells.

        Returns:
         The result of wandb.visualize for the heatmap table, or None when
               the inputs fail the missing-value or type checks. To see
               plots, go to your W&B run page then expand the 'media' tab
               under 'auto visualizations'.

        Example:
         wandb.log({'heatmap': wandb.plots.HeatMap(x_labels, y_labels,
                    matrix_values)})
        """
    deprecation_notice()

    np = util.get_module(
        "numpy",
        required=
        "heatmap requires the numpy library, install with `pip install numpy`",
    )
    # NOTE(review): sklearn is not used below; the availability check is kept
    # so that environments without it fail as before - confirm it is needed.
    util.get_module(
        "sklearn",
        required=
        "heatmap requires the scikit library, install with `pip install scikit-learn`",
    )

    inputs_ok = test_missing(x_labels=x_labels,
                             y_labels=y_labels,
                             matrix_values=matrix_values) and test_types(
                                 x_labels=x_labels,
                                 y_labels=y_labels,
                                 matrix_values=matrix_values)
    if not inputs_ok:
        return None

    matrix_values = np.array(matrix_values)
    wandb.termlog("Visualizing heatmap.")

    def heatmap_table(x_labels, y_labels, matrix_values, show_text):
        """Flatten the matrix into (x, y, value) rows and visualize them."""
        rows = []
        truncated = False
        for i, x in enumerate(x_labels):
            for j, y in enumerate(y_labels):
                rows.append([x, y, round(matrix_values[j][i], 2)])
                if len(rows) >= chart_limit:
                    wandb.termwarn(
                        "wandb uses only the first %d datapoints to create the plots."
                        % wandb.Table.MAX_ROWS)
                    truncated = True
                    break
            if truncated:
                # Leave the outer loop too, so the limit is enforced and the
                # warning is emitted only once.
                break
        if show_text:
            heatmap_key = "wandb/heatmap/v1"
        else:
            heatmap_key = "wandb/heatmap_no_text/v1"
        return wandb.visualize(
            heatmap_key,
            wandb.Table(
                columns=["x_axis", "y_axis", "values"],
                data=rows,
            ),
        )

    return heatmap_table(x_labels, y_labels, matrix_values, show_text)
Exemplo n.º 5
0
    def run(self, launch_project: LaunchProject) -> Optional[AbstractRun]:
        """Launch `launch_project` as a SageMaker job.

        When the resource args already name a `TrainingImage`, the job is
        launched directly on that image. Otherwise the image is pulled (if
        the project specifies a docker image) or built, tagged with the run
        id, pushed to the configured ECR repo, and then launched.

        Arguments:
            launch_project: the project to launch; its `resource_args` must
                contain a "sagemaker" entry with an `EcrRepoName` (or
                `ecr_repo_name`) key.

        Returns:
            The launched run, or None when acknowledging the run queue item
            fails with a CommError.

        Raises:
            LaunchError: if the sagemaker args or ECR repo name are missing,
                or if the ECR login or image push fails.
        """
        _logger.info("using AWSSagemakerRunner")

        boto3 = get_module(
            "boto3", "AWSSagemakerRunner requires boto3 to be installed")

        validate_docker_installation()
        given_sagemaker_args = launch_project.resource_args.get("sagemaker")
        if given_sagemaker_args is None:
            raise LaunchError(
                "No sagemaker args specified. Specify sagemaker args in resource_args"
            )
        # Both key spellings are accepted for the ECR repo name.
        ecr_repo_name = given_sagemaker_args.get(
            "EcrRepoName", given_sagemaker_args.get("ecr_repo_name"))
        if ecr_repo_name is None:
            raise LaunchError(
                "AWS sagemaker requires an ECR Repo to push the container to "
                "set this by adding a `EcrRepoName` key to the sagemaker "
                "field of resource_args")

        region = get_region(given_sagemaker_args)
        access_key, secret_key = get_aws_credentials(given_sagemaker_args)
        client = boto3.client("sts",
                              aws_access_key_id=access_key,
                              aws_secret_access_key=secret_key)
        account_id = client.get_caller_identity()["Account"]

        # if the user provided the image they want to use, use that, but warn it won't have swappable artifacts
        if (given_sagemaker_args.get("AlgorithmSpecification",
                                     {}).get("TrainingImage") is not None):
            wandb.termwarn(
                "Launching sagemaker job with user provided ECR image, this image will not be able to swap artifacts"
            )
            sagemaker_client = boto3.client(
                "sagemaker",
                region_name=region,
                aws_access_key_id=access_key,
                aws_secret_access_key=secret_key,
            )
            sagemaker_args = build_sagemaker_args(launch_project, account_id)
            _logger.info(
                f"Launching sagemaker job on user supplied image with args: {sagemaker_args}"
            )
            run = launch_sagemaker_job(launch_project, sagemaker_args,
                                       sagemaker_client)
            if self.backend_config[PROJECT_SYNCHRONOUS]:
                run.wait()
            return run

        _logger.info("Connecting to AWS ECR Client")
        ecr_client = boto3.client(
            "ecr",
            region_name=region,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )
        token = ecr_client.get_authorization_token()

        # Registry host (scheme stripped) plus the repo name.
        aws_registry = (token["authorizationData"][0]["proxyEndpoint"].replace(
            "https://", "") + f"/{ecr_repo_name}")

        if self.backend_config[PROJECT_DOCKER_ARGS]:
            wandb.termwarn(
                "Docker args are not supported for Sagemaker Resource. Not using docker args"
            )

        entry_point = launch_project.get_single_entry_point()

        if launch_project.docker_image:
            _logger.info("Pulling user provided docker image")
            pull_docker_image(launch_project.docker_image)
            # Bind the pulled image so it can be tagged and pushed below;
            # previously `image` was left undefined on this path.
            image = launch_project.docker_image
        else:
            # build our own image
            image_uri = construct_local_image_uri(launch_project)
            _logger.info("Building docker image")
            image = generate_docker_image(self._api, launch_project, image_uri,
                                          entry_point, {}, "sagemaker")

        _logger.info("Logging in to AWS ECR")
        login_resp = aws_ecr_login(region, aws_registry)
        if login_resp is None or "Login Succeeded" not in login_resp:
            raise LaunchError(
                f"Unable to login to ECR, response: {login_resp}")

        aws_tag = f"{aws_registry}:{launch_project.run_id}"
        docker.tag(image, aws_tag)

        wandb.termlog(f"Pushing image {image} to registry {aws_registry}")
        push_resp = docker.push(aws_registry, launch_project.run_id)
        if push_resp is None:
            raise LaunchError("Failed to push image to repository")
        if f"The push refers to repository [{aws_registry}]" not in push_resp:
            raise LaunchError(
                f"Unable to push image to ECR, response: {push_resp}")

        if self.backend_config.get("runQueueItemId"):
            try:
                self._api.ack_run_queue_item(
                    self.backend_config["runQueueItemId"],
                    launch_project.run_id)
            except CommError:
                wandb.termerror(
                    "Error acking run queue item. Item lease may have ended or another process may have acked it."
                )
                return None
        _logger.info("Connecting to sagemaker client")

        sagemaker_client = boto3.client(
            "sagemaker",
            region_name=region,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )

        # The flattened command is only used for the log message below.
        command_args = get_entry_point_command(entry_point,
                                               launch_project.override_args)
        command_args = list(
            itertools.chain.from_iterable(
                ca.split(" ") for ca in command_args))
        wandb.termlog("Launching run on sagemaker with entrypoint: {}".format(
            " ".join(command_args)))

        sagemaker_args = build_sagemaker_args(launch_project, account_id,
                                              aws_tag)
        _logger.info(f"Launching sagemaker job with args: {sagemaker_args}")
        run = launch_sagemaker_job(launch_project, sagemaker_args,
                                   sagemaker_client)
        if self.backend_config[PROJECT_SYNCHRONOUS]:
            run.wait()
        return run
Exemplo n.º 6
0
def prompt_api_key(api,
                   input_callback=None,
                   browser_callback=None,
                   no_offline=False):
    """Decide how to log in, obtain an API key, and store it.

    The login choice is forced to anonymous when the env.ANONYMOUS variable
    is "must", defaults to dry-run when stdout or stdin is not a tty, and is
    otherwise picked interactively from a numbered menu.

    Arguments:
        api: api object; used to create anonymous keys, to build the
            authorize URL and passed to `set_api_key`.
        input_callback: callable used to read a pasted API key; defaults to
            `getpass.getpass`.
        browser_callback: optional callable that tries to obtain a key via
            the browser before falling back to `input_callback`.
        no_offline: if True, the dry-run choice is not offered.

    Returns:
        The API key that was stored; may be None on the dry-run path.
    """
    input_callback = input_callback or getpass.getpass

    choices = list(LOGIN_CHOICES)
    if os.environ.get(env.ANONYMOUS, "never") == "never":
        # Omit LOGIN_CHOICE_ANON as a choice if the env var is set to never
        choices.remove(LOGIN_CHOICE_ANON)
    if os.environ.get(env.JUPYTER, "false") == "true" or no_offline:
        choices.remove(LOGIN_CHOICE_DRYRUN)

    if os.environ.get(env.ANONYMOUS) == "must":
        result = LOGIN_CHOICE_ANON
    # If we're not in an interactive environment, default to dry-run.
    elif not isatty(sys.stdout) or not isatty(sys.stdin):
        result = LOGIN_CHOICE_DRYRUN
    else:
        # Only prompt on this interactive path; prompting unconditionally
        # would override the forced-anonymous and dry-run decisions above.
        for i, choice in enumerate(choices):
            wandb.termlog("(%i) %s" % (i + 1, choice))

        def prompt_choice():
            """Read a 1-based menu number; return its 0-based index or -1."""
            try:
                return int(
                    six.moves.input(
                        "%s: Enter your choice: " % wandb.core.LOG_STRING)) - 1
            except ValueError:
                return -1

        idx = -1
        while idx < 0 or idx > len(choices) - 1:
            idx = prompt_choice()
            if idx < 0 or idx > len(choices) - 1:
                wandb.termwarn("Invalid choice")
        result = choices[idx]
        wandb.termlog("You chose '%s'" % result)

    if result == LOGIN_CHOICE_ANON:
        key = api.create_anonymous_api_key()

        set_api_key(api, key, anonymous=True)
        return key
    elif result == LOGIN_CHOICE_NEW:
        key = browser_callback(signup=True) if browser_callback else None

        if not key:
            wandb.termlog(
                'Create an account here: {}/authorize?signup=true'.format(
                    api.app_url))
            key = input_callback(
                '%s: Paste an API key from your profile and hit enter' %
                wandb.core.LOG_STRING).strip()

        set_api_key(api, key)
        return key
    elif result == LOGIN_CHOICE_EXISTS:
        key = browser_callback() if browser_callback else None

        if not key:
            wandb.termlog(
                'You can find your API key in your browser here: {}/authorize'.
                format(api.app_url))
            key = input_callback(
                '%s: Paste an API key from your profile and hit enter' %
                wandb.core.LOG_STRING).strip()
        set_api_key(api, key)
        return key
    else:
        # Jupyter environments don't have a tty, but we can still try logging in using the browser callback if one
        # is supplied.
        key, anonymous = browser_callback() if os.environ.get(
            env.JUPYTER, "false") == "true" and browser_callback else (None,
                                                                       False)

        set_api_key(api, key, anonymous=anonymous)
        return key
Exemplo n.º 7
0
def patch(save=True,
          tensorboardX=TENSORBOARDX_LOADED,
          pytorch=PYTORCH_TENSORBOARD):
    """Monkeypatches tensorboard or tensorboardX so that all events are logged to tfevents files and wandb.
    We save the tfevents files and graphs to wandb by default.

    Arguments:
        save, default: True - Passing False will skip storing tfevent files.
        tensorboardX, default: True if module can be imported - You can override this when calling patch
        pytorch, default: PYTORCH_TENSORBOARD - Only consulted when tensorboardX is falsy. If truthy and
            `tensorboard.summary.writer.event_file_writer` can be loaded, the TENSORBOARD_PYTORCH_MODULE
            writer is patched; otherwise TENSORBOARD_LEGACY_MODULE is patched.

    Raises:
        ValueError: if tensorboard has already been patched, or if Summary is None.
    """

    # Refuse to patch twice: every successful patch is recorded in
    # wandb.patched["tensorboard"] at the end of this function.
    if len(wandb.patched["tensorboard"]) > 0:
        raise ValueError(
            "Tensorboard already patched, remove sync_tensorboard=True from wandb.init or only call wandb.tensorboard.patch once."
        )
    elif Summary is None:
        raise ValueError(
            "Couldn't import tensorboard or tensorflow, ensure you have have tensorboard installed."
        )

    # Pick the name of the python module whose EventFileWriter gets patched.
    if tensorboardX:
        tensorboard_py_module = "tensorboardX.writer"
        if TENSORFLOW_LOADED:
            wandb.termlog(
                "Found tensorboardX and tensorflow, pass tensorboardX=False to patch regular tensorboard."
            )
    else:
        if wandb.util.get_module(
                "tensorboard.summary.writer.event_file_writer") and pytorch:
            # If we haven't imported tensorflow, let's patch the python tensorboard writer
            tensorboard_py_module = TENSORBOARD_PYTORCH_MODULE
        else:
            # If we're using tensorflow >= 2.0 this patch won't be used, but we'll do it anyway
            tensorboard_py_module = TENSORBOARD_LEGACY_MODULE
    # Event file names seen so far; used below to derive a common log dir.
    writers = set()
    writer = wandb.util.get_module(tensorboard_py_module)

    def add_event(orig_event):
        """TensorboardX, TensorFlow <= 1.14 patch, and Tensorboard Patch"""
        def _add_event(self, event):
            """Add event monkeypatch for python event writers"""
            # Always call the original writer first; wandb logging below is
            # best-effort and its failures are only reported.
            orig_event(self, event)
            try:
                # Work out the events file name; the attribute holding it
                # differs between writer implementations.
                if hasattr(self, "_file_name"):
                    # Current Tensorboard
                    name = self._file_name
                elif hasattr(self, "_ev_writer"):
                    if hasattr(self._ev_writer, "FileName"):
                        # Legacy Tensorflow
                        try:
                            name = self._ev_writer.FileName().decode("utf-8")
                        except AttributeError:
                            # FileName() returned something without .decode
                            name = self._ev_writer.FileName()
                    elif hasattr(self._ev_writer, "_file_name"):
                        # Current TensorboardX
                        name = self._ev_writer._file_name
                    else:
                        # Legacy TensorboardX
                        name = self._ev_writer._file_prefix
                else:
                    # NOTE(review): the address in this message looks like an
                    # obfuscated e-mail placeholder - confirm the intended text.
                    wandb.termerror(
                        "Couldn't patch tensorboard, email [email protected] with the tensorboard version you're using."
                    )
                    # Unknown writer layout: undo the patch and stop.
                    writer.EventFileWriter.add_event = orig_event
                    return None
                writers.add(name)
                # This is a little hacky, there is a case where the log_dir changes.
                # Because the events files will have the same names in sub directories
                # we simply overwrite the previous symlink in wandb.save if the log_dir
                # changes.
                log_dir = os.path.dirname(os.path.commonprefix(list(writers)))
                filename = os.path.basename(name)
                # Tensorboard loads all tfevents files in a directory and prepends
                # their values with the path.  Passing namespace to log allows us
                # to nest the values in wandb
                namespace = name.replace(filename,
                                         "").replace(log_dir, "").strip(os.sep)
                if save:
                    wandb.save(name, base_path=log_dir)
                    # Also save .pbtxt files in log_dir that were modified at
                    # or after wandb.START_TIME.
                    for path in glob.glob(os.path.join(log_dir, "*.pbtxt")):
                        if os.stat(path).st_mtime >= wandb.START_TIME:
                            wandb.save(path, base_path=log_dir)
                log(event, namespace=namespace, step=event.step)
            except Exception as e:
                wandb.termerror("Unable to log event %s" % e)

        return _add_event

    if writer:
        # This is for TensorboardX and PyTorch 1.1 python tensorboard logging
        # Keep a reference to the unpatched method on the class.
        writer.EventFileWriter.orig_add_event = writer.EventFileWriter.add_event
        writer.EventFileWriter.add_event = add_event(
            writer.EventFileWriter.add_event)
        wandb.patched["tensorboard"].append(
            [tensorboard_py_module, "EventFileWriter.add_event"])

    # This configures TensorFlow 2 style Tensorboard logging
    c_writer = wandb.util.get_module(TENSORBOARD_C_MODULE)
    if c_writer:
        old_csfw_func = c_writer.create_summary_file_writer

        def new_csfw_func(*args, **kwargs):
            """Tensorboard 2+ monkeypatch for streaming events from the filesystem"""
            # logdir is read from the keyword arguments only; values exposing
            # .numpy() are converted to a utf8 string first.
            logdir = kwargs['logdir'].numpy().decode("utf8") if hasattr(
                kwargs['logdir'], 'numpy') else kwargs['logdir']
            wandb.run.send_message(
                {"tensorboard": {
                    "logdir": logdir,
                    "save": save
                }})
            return old_csfw_func(*args, **kwargs)

        # Keep a reference to the unpatched function on the module.
        c_writer.orig_create_summary_file_writer = old_csfw_func
        c_writer.create_summary_file_writer = new_csfw_func
        wandb.patched["tensorboard"].append(
            [TENSORBOARD_C_MODULE, "create_summary_file_writer"])
Exemplo n.º 8
0
def local(ctx, port, env, daemon, upgrade, edge):
    # Start the W&B local server in a docker container named "wandb-local".
    #
    # NOTE: documented with comments rather than a docstring because this
    # looks like a CLI command whose docstring could surface as help text.
    #
    # ctx:     context used to invoke the `login` command after a daemon start.
    # port:    host port that is mapped to port 8080 of the container.
    # env:     iterable of extra values, each handed to docker after a "-e".
    # daemon:  when truthy the container is started detached ("-d").
    # upgrade: when truthy a newer image is pulled and an already running
    #          container is stopped instead of only being reported.
    # edge:    when truthy the github package registry image is run.
    api = InternalApi()
    if not find_executable("docker"):
        raise ClickException("Docker not installed, install it from https://docker.com")

    # Compare the locally available image against the one in the registry.
    local_image_id = wandb.docker.image_id("wandb/local")
    registry_image_id = wandb.docker.image_id_from_registry("wandb/local")
    if local_image_id != registry_image_id:
        if not upgrade:
            wandb.termlog(
                "A new version of W&B local is available, upgrade by calling `wandb local --upgrade`"
            )
        else:
            subprocess.call(["docker", "pull", "wandb/local"])

    # A non-empty answer means a container with our name is already up.
    container_ids = subprocess.check_output(
        ["docker", "ps", "--filter", "name=wandb-local", "--format", "{{.ID}}"]
    )
    if container_ids != b"":
        if not upgrade:
            wandb.termerror(
                "A container named wandb-local is already running, run `docker stop wandb-local` if you want to start a new instance"
            )
            exit(1)
        else:
            subprocess.call(["docker", "stop", "wandb-local"])

    if edge:
        image = "docker.pkg.github.com/wandb/core/local"
    else:
        image = "wandb/local"

    docker_env = ["-e", "LOCAL_USERNAME=%s" % getpass.getuser()]
    for extra in env:
        docker_env.extend(["-e", extra])

    command = [
        "docker",
        "run",
        "--rm",
        "-v",
        "wandb:/vol",
        "-p",
        port + ":8080",
        "--name",
        "wandb-local",
    ]
    command.extend(docker_env)

    # Point the client settings at the instance we are about to launch.
    host = "http://localhost:%s" % port
    api.set_setting("base_url", host, globally=True, persist=True)

    if daemon:
        command.append("-d")
    command.append(image)

    # DEVNULL is only in py3
    try:
        from subprocess import DEVNULL as devnull
    except ImportError:
        devnull = open(os.devnull, "wb")
    code = subprocess.call(command, stdout=devnull)

    # The exit code is only inspected when the container was detached.
    if not daemon:
        return
    if code != 0:
        wandb.termerror(
            "Failed to launch the W&B local container, see the above error."
        )
        exit(1)
    else:
        wandb.termlog("W&B local started at http://localhost:%s \U0001F680" % port)
        wandb.termlog(
            "You can stop the server by running `docker stop wandb-local`"
        )
        if not api.api_key:
            # Let the server start before potentially launching a browser
            time.sleep(2)
            ctx.invoke(login, host=host)
Exemplo n.º 9
0
    def __init__(self,
                 api,
                 run,
                 project=None,
                 tags=None,
                 cloud=True,
                 job_type="train",
                 port=None):
        """Wire up file watching, stats, metadata and uploads for a run.

        Args:
            api: API client; used for settings lookup, pushing files, saving
                patches and obtaining the file stream API.
            run: the run being synced; its ``dir`` is the watched directory.
            project: project name, falls back to ``api.settings("project")``.
            tags: optional list of tags stored on the instance.  ``None``
                (the default) yields a fresh empty list per instance.
            cloud: when truthy the directory observer is started, patches are
                saved and the output log file policy is registered.
            job_type: stored in the run metadata under ``"jobType"``.
            port: handed to ``wandb_socket.Client``.
        """
        self._api = api
        self._run = run
        self._cloud = cloud
        self._port = port

        self._project = project if project else api.settings("project")
        # A literal ``[]`` default would be one list shared by every instance
        # created without tags, so the default is materialised here instead.
        self._tags = [] if tags is None else tags
        self._watch_dir = self._run.dir

        logger.debug("Initialized sync for %s/%s", self._project, self._run.id)

        # Route filesystem events for the run directory to our callbacks.
        self._handler = PatternMatchingEventHandler()
        self._handler.on_created = self.on_file_created
        self._handler.on_modified = self.on_file_modified
        self.url = self._run.get_url(api)
        self._observer = Observer()

        self._observer.schedule(self._handler, self._watch_dir, recursive=True)

        self._config = run.config

        self._stats = stats.Stats()
        # This starts a thread to write system stats every 30 seconds
        self._system_stats = stats.SystemStats(run)
        self._meta = meta.Meta(api, self._run.dir)
        self._meta.data["jobType"] = job_type
        if self._run.program:
            self._meta.data["program"] = self._run.program

        def push_function(save_name, path):
            """Upload ``path`` under ``save_name`` and report upload progress."""
            with open(path, 'rb') as f:
                self._api.push(self._project, {save_name: f},
                               run=self._run.id,
                               progress=lambda _, total: self._stats.
                               update_progress(path, total))

        self._file_pusher = file_pusher.FilePusher(push_function)

        self._event_handlers = {}

        self._handler._patterns = [
            os.path.join(self._watch_dir, os.path.normpath('*'))
        ]
        # Ignore hidden files/folders and output.log because we stream it specially
        self._handler._ignore_patterns = [
            '*/.*', '*.tmp',
            os.path.join(self._run.dir, OUTPUT_FNAME)
        ]

        self._socket = wandb_socket.Client(self._port)

        if self._cloud:
            self._observer.start()

            self._api.save_patches(self._watch_dir)

            wandb.termlog("Syncing %s" % self.url)
            wandb.termlog('Run directory: %s' % os.path.relpath(run.dir))
            wandb.termlog()

            self._api.get_file_stream_api().set_file_policy(
                OUTPUT_FNAME, CRDedupeFilePolicy())
Exemplo n.º 10
0
 def _summary():
     """Report how many runs are synced / unsynced and print usage hints.

     Reads ``show``, ``clean``, ``path`` and ``sync_all`` from the
     enclosing scope.
     """
     synced, unsynced = [], []
     for run_item in get_runs():
         if run_item.synced:
             synced.append(run_item)
         else:
             unsynced.append(run_item)

     have_synced = bool(synced)
     have_unsynced = bool(unsynced)

     if have_synced:
         wandb.termlog("Number of synced runs: {}".format(len(synced)))
     if have_unsynced:
         wandb.termlog("Number of runs to be synced: {}".format(len(unsynced)))
         if show and show < len(unsynced):
             wandb.termlog("Showing {} unsynced runs:".format(show))
         # A falsy ``show`` means "list every unsynced run".
         limit = show or len(unsynced)
         for run_item in unsynced[:limit]:
             wandb.termlog("  {}".format(run_item))
     if have_synced and not clean:
         wandb.termlog(
             "NOTE: use sync --clean to cleanup synced runs from local directory."
         )
     if have_unsynced and not path and not sync_all:
         wandb.termlog("NOTE: use sync --sync-all to sync all unsynced runs.")
Exemplo n.º 11
0
def sweep(
    ctx,
    project,
    entity,
    controller,
    verbose,
    name,
    program,
    settings,
    update,
    config_yaml,
):  # noqa: C901
    # Create a sweep from a yaml configuration file, or update an existing one.
    #
    # NOTE: documented with comments rather than a docstring because this
    # looks like a CLI command whose docstring could surface as help text.
    #
    # ctx:         context used to invoke `login` when no api key is set.
    # project:     project name; falls back to WANDB_PROJECT, the config's
    #              "project" and finally util.auto_project_name().
    # entity:      entity name; falls back to WANDB_ENTITY and the config's
    #              "entity".
    # controller:  when truthy the controller type is forced to "local" and
    #              the local controller is run after the sweep is upserted.
    # verbose:     forwarded to the local controller's run().
    # name:        overrides the config's "name" when given.
    # program:     overrides the config's "program" when given.
    # settings:    comma separated key=value pairs merged into
    #              config["settings"].
    # update:      identifier of an existing sweep to update instead of create.
    # config_yaml: path of the sweep configuration file.
    def _parse_settings(settings):
        """Parse comma separated ``key=value`` assignments into a dict.

        Malformed pairs (anything that does not split into exactly one key
        and one value) are reported with a warning and skipped.
        """
        ret = {}
        # TODO(jhr): merge with magic:_parse_magic
        if settings.find("=") > 0:
            for item in settings.split(","):
                kv = item.split("=")
                if len(kv) != 2:
                    wandb.termwarn(
                        "Unable to parse sweep settings key value pair", repeat=False
                    )
                    # Without skipping, dict([kv]) below raises ValueError
                    # for anything that is not a 2-element pair.
                    continue
                ret.update(dict([kv]))
            return ret
        wandb.termwarn("Unable to parse settings parameter", repeat=False)
        return ret

    api = InternalApi()
    if api.api_key is None:
        wandb.termlog("Login to W&B to use the sweep feature")
        ctx.invoke(login, no_offline=True)

    # When updating, resolve the existing sweep first so we fail early.
    sweep_obj_id = None
    if update:
        parts = dict(entity=entity, project=project, name=update)
        err = util.parse_sweep_id(parts)
        if err:
            wandb.termerror(err)
            return
        entity = parts.get("entity") or entity
        project = parts.get("project") or project
        sweep_id = parts.get("name") or update
        found = api.sweep(sweep_id, "{}", entity=entity, project=project)
        if not found:
            wandb.termerror(
                "Could not find sweep {}/{}/{}".format(entity, project, sweep_id)
            )
            return
        sweep_obj_id = found["id"]

    wandb.termlog(
        "{} sweep from: {}".format(
            "Updating" if sweep_obj_id else "Creating", config_yaml
        )
    )
    try:
        yaml_file = open(config_yaml)
    except OSError:
        wandb.termerror("Couldn't open sweep file: %s" % config_yaml)
        return
    # The context manager closes the file on every path, including the
    # early return on a YAML error.
    with yaml_file:
        try:
            config = util.load_yaml(yaml_file)
        except yaml.YAMLError as err:
            wandb.termerror("Error in configuration file: %s" % err)
            return
    if config is None:
        wandb.termerror("Configuration file is empty")
        return

    # Set or override parameters
    if name:
        config["name"] = name
    if program:
        config["program"] = program
    if settings:
        settings = _parse_settings(settings)
        if settings:
            config.setdefault("settings", {})
            config["settings"].update(settings)
    if controller:
        config.setdefault("controller", {})
        config["controller"]["type"] = "local"

    # A local controller validates the configuration before it is uploaded.
    is_local = config.get("controller", {}).get("type") == "local"
    if is_local:
        tuner = wandb_controller.controller()
        err = tuner._validate(config)
        if err:
            wandb.termerror("Error in sweep file: %s" % err)
            return

    env = os.environ
    entity = entity or env.get("WANDB_ENTITY") or config.get("entity")
    project = (
        project
        or env.get("WANDB_PROJECT")
        or config.get("project")
        or util.auto_project_name(config.get("program"))
    )
    sweep_id = api.upsert_sweep(
        config, project=project, entity=entity, obj_id=sweep_obj_id
    )
    wandb.termlog(
        "{} sweep with ID: {}".format(
            "Updated" if sweep_obj_id else "Created", click.style(sweep_id, fg="yellow")
        )
    )
    sweep_url = wandb_controller._get_sweep_url(api, sweep_id)
    if sweep_url:
        wandb.termlog(
            "View sweep at: {}".format(
                click.style(sweep_url, underline=True, fg="blue")
            )
        )

    # reprobe entity and project if it was autodetected by upsert_sweep
    entity = entity or env.get("WANDB_ENTITY")
    project = project or env.get("WANDB_PROJECT")

    if entity and project:
        sweep_path = "{}/{}/{}".format(entity, project, sweep_id)
    elif project:
        sweep_path = "{}/{}".format(project, sweep_id)
    else:
        sweep_path = sweep_id

    # Quote the path so the suggested command survives a copy/paste.
    if sweep_path.find(" ") >= 0:
        sweep_path = '"{}"'.format(sweep_path)

    wandb.termlog(
        "Run sweep agent with: {}".format(
            click.style("wandb agent %s" % sweep_path, fg="yellow")
        )
    )
    if controller:
        wandb.termlog("Starting wandb controller...")
        tuner = wandb_controller.controller(sweep_id)
        tuner.run(verbose=verbose)
Exemplo n.º 12
0
def sync(
    ctx,
    path=None,
    view=None,
    verbose=None,
    run_id=None,
    project=None,
    entity=None,
    include_globs=None,
    exclude_globs=None,
    include_online=None,
    include_offline=None,
    include_synced=None,
    mark_synced=None,
    sync_all=None,
    ignore=None,
    show=None,
    clean=None,
    clean_old_hours=24,
    clean_force=None,
):
    # Sync, clean up or summarise local runs.
    #
    # NOTE: documented with comments rather than a docstring because this
    # looks like a CLI command whose docstring could surface as help text.
    #
    # Exactly one action is taken, checked in this order:
    #   sync_all -> sync every run returned by get_runs()
    #   clean    -> delete run directories (given paths, or old synced runs)
    #   path     -> sync the given paths
    #   neither  -> print a summary of synced / unsynced runs
    #
    # `ignore` replaces `exclude_globs`; both glob options are comma separated
    # strings that are split into lists before use.
    api = InternalApi()
    if api.api_key is None:
        wandb.termlog("Login to W&B to sync offline runs")
        ctx.invoke(login, no_offline=True)

    if ignore:
        exclude_globs = ignore
    if include_globs:
        include_globs = include_globs.split(",")
    if exclude_globs:
        exclude_globs = exclude_globs.split(",")

    def _summary():
        """Report how many runs are synced / unsynced and print usage hints."""
        synced, unsynced = [], []
        for run_item in get_runs():
            if run_item.synced:
                synced.append(run_item)
            else:
                unsynced.append(run_item)

        have_synced = bool(synced)
        have_unsynced = bool(unsynced)

        if have_synced:
            wandb.termlog("Number of synced runs: {}".format(len(synced)))
        if have_unsynced:
            wandb.termlog("Number of runs to be synced: {}".format(len(unsynced)))
            if show and show < len(unsynced):
                wandb.termlog("Showing {} unsynced runs:".format(show))
            # A falsy ``show`` means "list every unsynced run".
            limit = show or len(unsynced)
            for run_item in unsynced[:limit]:
                wandb.termlog("  {}".format(run_item))
        if have_synced and not clean:
            wandb.termlog(
                "NOTE: use sync --clean to cleanup synced runs from local directory."
            )
        if have_unsynced and not path and not sync_all:
            wandb.termlog("NOTE: use sync --sync-all to sync all unsynced runs.")

    def _sync_path(targets):
        """Feed every entry of ``targets`` to a SyncManager and wait for it."""
        if run_id and len(targets) > 1:
            wandb.termerror("id can only be set for a single run.")
            sys.exit(1)
        manager = SyncManager(
            project=project,
            entity=entity,
            run_id=run_id,
            mark_synced=mark_synced,
            app_url=api.app_url,
            view=view,
            verbose=verbose,
        )
        for target in targets:
            manager.add(target)
        manager.start()
        # Keep polling until the manager reports completion; the poll
        # result itself is not used.
        while not manager.is_done():
            manager.poll()

    def _sync_all():
        """Sync every run that matches the include/exclude options."""
        candidates = get_runs(
            include_online=include_online,
            include_offline=include_offline,
            include_synced=include_synced,
            exclude_globs=exclude_globs,
            include_globs=include_globs,
        )
        if candidates:
            _sync_path(candidates)
        else:
            wandb.termerror("Nothing to sync.")

    def _confirm_removal(count):
        """Ask before deleting ``count`` runs unless ``clean_force`` is set."""
        if clean_force:
            return
        click.confirm(
            click.style(
                "Are you sure you want to remove %i runs?" % count,
                bold=True,
            ),
            abort=True,
        )

    def _clean():
        """Delete the given run paths, or synced runs older than the cutoff."""
        if path:
            explicit_runs = [get_run_from_path(p) for p in path]
            _confirm_removal(len(explicit_runs))
            for doomed in explicit_runs:
                shutil.rmtree(doomed.path)
            click.echo(click.style("Success!", fg="green"))
            return

        runs = get_runs(
            include_online=True,
            include_offline=True,
            include_synced=True,
            include_unsynced=False,
            exclude_globs=exclude_globs,
            include_globs=include_globs,
        )
        cutoff = datetime.datetime.now() - datetime.timedelta(hours=clean_old_hours)
        stale = sorted(
            [candidate for candidate in runs if candidate.datetime < cutoff],
            key=lambda candidate: candidate.datetime,
        )
        if not stale:
            click.echo(
                click.style(
                    "No runs older than %i hours found" % clean_old_hours, fg="red"
                )
            )
            return

        click.echo(
            "Found {} runs, {} are older than {} hours".format(
                len(runs), len(stale), clean_old_hours
            )
        )
        for doomed in stale:
            click.echo(doomed.path)
        _confirm_removal(len(stale))
        for doomed in stale:
            shutil.rmtree(doomed.path)
        click.echo(click.style("Success!", fg="green"))

    # Pick the single action to perform.
    if sync_all:
        _sync_all()
    elif clean:
        _clean()
    elif path:
        _sync_path(path)
    else:
        _summary()
Exemplo n.º 13
0
def restore(ctx, run, no_git, branch, project, entity):
    # Restore the code state, config and (optionally) docker image of a run.
    #
    # NOTE: documented with comments rather than a docstring because this
    # looks like a CLI command whose docstring could surface as help text.
    #
    # ctx:     context used to invoke the `docker` command for the run image.
    # run:     run identifier; accepted spellings include
    #          "entity/project:run", "project:run" and "entity/project/run".
    # no_git:  when truthy no commit is checked out and no patch is applied.
    # branch:  when truthy the commit is checked out on branch "wandb/<run>",
    #          otherwise in detached mode.
    # project: default project, may be overridden by what `run` contains.
    # entity:  default entity, may be overridden by what `run` contains.
    #
    # Returns the tuple (commit, json_config, patch_content, repo, metadata).
    from wandb.old.core import wandb_dir

    api = _get_cling_api()
    if ":" in run:
        if "/" in run:
            entity, rest = run.split("/", 1)
        else:
            rest = run
        project, run = rest.split(":", 1)
    elif run.count("/") > 1:
        entity, run = run.split("/", 1)

    project, run = api.parse_slug(run, project=project)
    commit, json_config, patch_content, metadata = api.run_config(
        project, run=run, entity=entity
    )
    repo = metadata.get("git", {}).get("repo")
    image = metadata.get("docker")
    restore_message = (
        """`wandb restore` needs to be run from the same git repository as the original run.
Run `git clone %s` and restore from there or pass the --no-git flag."""
        % repo
    )
    if no_git:
        commit = None
    elif not api.git.enabled:
        if repo:
            raise ClickException(restore_message)
        elif image:
            wandb.termlog(
                "Original run has no git history.  Just restoring config and docker"
            )

    if commit and api.git.enabled:
        subprocess.check_call(["git", "fetch", "--all"])
        try:
            api.git.repo.commit(commit)
        except ValueError:
            # The recorded commit is unknown locally; look for an uploaded
            # "upstream_diff_<sha>.patch" whose sha we do have.
            wandb.termlog("Couldn't find original commit: {}".format(commit))
            commit = None
            files = api.download_urls(project, run=run, entity=entity)
            for filename in files:
                if filename.startswith("upstream_diff_") and filename.endswith(
                    ".patch"
                ):
                    commit = filename[len("upstream_diff_") : -len(".patch")]
                    try:
                        api.git.repo.commit(commit)
                    except ValueError:
                        commit = None
                    else:
                        break

            if commit:
                # `filename` still names the patch the loop stopped on.
                wandb.termlog("Falling back to upstream commit: {}".format(commit))
                patch_path, _ = api.download_write_file(files[filename])
            else:
                raise ClickException(restore_message)
        else:
            if patch_content:
                patch_path = os.path.join(wandb_dir(), "diff.patch")
                with open(patch_path, "w") as f:
                    f.write(patch_content)
            else:
                patch_path = None

        branch_name = "wandb/%s" % run
        if branch and branch_name not in api.git.repo.branches:
            api.git.repo.git.checkout(commit, b=branch_name)
            wandb.termlog("Created branch %s" % click.style(branch_name, bold=True))
        elif branch:
            wandb.termlog(
                "Using existing branch, run `git branch -D %s` from master for a clean checkout"
                % branch_name
            )
            api.git.repo.git.checkout(branch_name)
        else:
            wandb.termlog("Checking out %s in detached mode" % commit)
            api.git.repo.git.checkout(commit)

        if patch_path:
            # we apply the patch from the repository root so git doesn't exclude
            # things outside the current directory
            root = api.git.root
            patch_rel_path = os.path.relpath(patch_path, start=root)
            # --reject is necessary or else this fails any time a binary file
            # occurs in the diff
            # we use .call() instead of .check_call() for the same reason
            # TODO(adrian): this means there is no error checking here
            subprocess.call(["git", "apply", "--reject", patch_rel_path], cwd=root)
            wandb.termlog("Applied patch")

    # Write the run's config (minus wandb bookkeeping keys) to config.yaml.
    util.mkdir_exists_ok(wandb_dir())
    config_path = os.path.join(wandb_dir(), "config.yaml")
    config = Config()
    for k, v in json_config.items():
        if k not in ("_wandb", "wandb_version"):
            config[k] = v
    s = b"wandb_version: 1"
    s += b"\n\n" + yaml.dump(
        config._as_dict(),
        Dumper=yaml.SafeDumper,
        default_flow_style=False,
        allow_unicode=True,
        encoding="utf-8",
    )
    s = s.decode("utf-8")
    with open(config_path, "w") as f:
        f.write(s)

    wandb.termlog("Restored config variables to %s" % config_path)
    if image:
        # Metadata without a "program" entry falls through to the
        # "environment only" branch instead of raising KeyError.
        program = metadata.get("program")
        args = metadata.get("args")
        if program and not program.startswith("<") and args is not None:
            # TODO: we may not want to default to python here.
            runner = util.find_runner(program) or ["python"]
            command = runner + [program] + args
            cmd = " ".join(command)
        else:
            wandb.termlog("Couldn't find original command, just restoring environment")
            cmd = None
        wandb.termlog("Docker image found, attempting to start")
        ctx.invoke(docker, docker_run_args=[image], cmd=cmd)

    return commit, json_config, patch_content, repo, metadata
Exemplo n.º 14
0
def put(path, name, description, type, alias):
    # Upload a file, directory or reference URL as an artifact.
    #
    # NOTE: documented with comments rather than a docstring because this
    # looks like a CLI command whose docstring could surface as help text.
    #
    # path:        local file, local directory, or a "scheme://" reference.
    # name:        artifact path; defaults to the basename of `path`.
    # description: description stored with the artifact.
    # type:        artifact type.
    # alias:       sequence of aliases; the first one is used in the messages.
    if name is None:
        name = os.path.basename(path)
    entity, project, artifact_name = PublicApi()._parse_artifact_path(name)
    if project is None:
        project = click.prompt("Enter the name of the project you want to use")
    # TODO: settings nightmare...
    api = InternalApi()
    api.set_setting("entity", entity)
    api.set_setting("project", project)
    artifact = wandb.Artifact(name=artifact_name, type=type, description=description)
    artifact_path = "{entity}/{project}/{name}:{alias}".format(
        entity=entity, project=project, name=artifact_name, alias=alias[0]
    )

    # Decide what kind of thing `path` is; the checks run in this order.
    if os.path.isdir(path):
        message = 'Uploading directory {path} to: "{artifact_path}" ({type})'
        add_to_artifact = artifact.add_dir
    elif os.path.isfile(path):
        message = 'Uploading file {path} to: "{artifact_path}" ({type})'
        add_to_artifact = artifact.add_file
    elif "://" in path:
        message = (
            'Logging reference artifact from {path} to: "{artifact_path}" ({type})'
        )
        add_to_artifact = artifact.add_reference
    else:
        raise ClickException("Path argument must be a file or directory")
    wandb.termlog(message.format(path=path, type=type, artifact_path=artifact_path))
    add_to_artifact(path)

    run = wandb.init(
        entity=entity, project=project, config={"path": path}, job_type="cli_put"
    )
    # We create the artifact manually to get the current version
    alias_specs = [
        {"artifactCollectionName": artifact_name, "alias": a} for a in alias
    ]
    res, _ = api.create_artifact(
        type,
        artifact_name,
        artifact.digest,
        entity_name=entity,
        project_name=project,
        run_name=run.id,
        description=description,
        aliases=alias_specs,
    )
    # Swap the alias shown to the user for the concrete version, if known.
    path_without_alias = artifact_path.split(":")[0]
    artifact_path = path_without_alias + ":" + res.get("version", "latest")
    # Re-create the artifact and actually upload any files needed
    run.log_artifact(artifact, aliases=alias)
    wandb.termlog(
        "Artifact uploaded, use this artifact in a run by adding:\n", prefix=False
    )

    wandb.termlog(
        '    artifact = run.use_artifact("{path}")\n'.format(path=artifact_path),
        prefix=False,
    )
Exemplo n.º 15
0
    def save(
        self,
        glob_str=None,
        base_path=None,
        policy="live",
    ):
        """ Ensure all files matching *glob_str* are synced to wandb with the policy specified.

        Matching files are symlinked into the run directory and the glob is
        registered with the backend under the given policy.

        Args:
            glob_str (string): a relative or absolute path to a unix glob or regular
                path.  If this isn't specified the method is a noop.
            base_path (string): the base path to run the glob relative to
            policy (string): one of "live", "now", or "end"
                live: upload the file as it changes, overwriting the previous version
                now: upload the file once now
                end: only upload file when the run ends

        Returns:
            ``True`` when called without *glob_str* (legacy no-op), an empty
            list for ``gs://`` / ``s3://`` urls, otherwise the list of paths
            inside the run directory that match or were symlinked.

        Raises:
            ValueError: for an unknown policy, a non-string *glob_str*, or a
                glob that would walk above *base_path*.
        """
        if glob_str is None:
            # noop for historical reasons, run.save() may be called in legacy code
            wandb.termwarn(
                ("Calling run.save without any arguments is deprecated. "
                 "Changes to attributes are automatically persisted."))
            return True
        if policy not in ("live", "end", "now"):
            raise ValueError(
                'Only "live" "end" and "now" policies are currently supported.'
            )
        if isinstance(glob_str, bytes):
            glob_str = glob_str.decode("utf-8")
        if not isinstance(glob_str, string_types):
            raise ValueError(
                "Must call wandb.save(glob_str) with glob_str a str")

        # Reject cloud urls before any path arithmetic: relpath() of a url
        # against an explicit base_path can contain ".." and would otherwise
        # raise the unrelated "walk above base_path" error below.
        if glob_str.startswith(("gs://", "s3://")):
            wandb.termlog(
                "%s is a cloud storage url, can't save file to wandb." %
                glob_str)
            return []

        if base_path is None:
            if os.path.isabs(glob_str):
                base_path = os.path.dirname(glob_str)
                wandb.termwarn(
                    ("Saving files without folders. If you want to preserve "
                     "sub directories pass base_path to wandb.save, i.e. "
                     'wandb.save("/mnt/folder/file.h5", base_path="/mnt")'))
            else:
                base_path = "."
        wandb_glob_str = os.path.relpath(glob_str, base_path)
        if ".." + os.sep in wandb_glob_str:
            raise ValueError("globs can't walk above base_path")
        files = glob.glob(os.path.join(self.dir, wandb_glob_str))
        # Only warn when a wildcard matched nothing in the run directory yet.
        warn = False
        if len(files) == 0 and "*" in wandb_glob_str:
            warn = True
        for path in glob.glob(glob_str):
            file_name = os.path.relpath(path, base_path)
            abs_path = os.path.abspath(path)
            wandb_path = os.path.join(self.dir, file_name)
            wandb.util.mkdir_exists_ok(os.path.dirname(wandb_path))
            # We overwrite symlinks because namespaces can change in Tensorboard
            if os.path.islink(
                    wandb_path) and abs_path != os.readlink(wandb_path):
                os.remove(wandb_path)
                os.symlink(abs_path, wandb_path)
            elif not os.path.exists(wandb_path):
                os.symlink(abs_path, wandb_path)
            files.append(wandb_path)
        if warn:
            file_str = "%i file" % len(files)
            if len(files) > 1:
                file_str += "s"
            wandb.termwarn(
                ("Symlinked %s into the W&B run directory, "
                 "call wandb.save again to sync new files.") % file_str)
        files_dict = dict(files=[(wandb_glob_str, policy)])
        self._backend.interface.send_files(files_dict)
        return files
Exemplo n.º 16
0
    def _sync_etc(self, headless=False):
        """Wait for the user program to end, then finalize and verify the run.

        Steps, in order: wait for an exit code (from the socket or by polling
        the process), record exit code and state in the run metadata, shut
        down stats and output streams, print summary/history, stop the file
        observer, wait for uploads to finish and compare md5s of uploaded
        files against the local copies.

        Args:
            headless: when truthy the child is not killed after Ctrl-c and
                ``self._socket.done()`` is called once syncing has finished.

        Returns:
            None.
        """
        # Ignore SIGQUIT (ctrl-\). The child process will # handle it, and we'll
        # exit when the child process does.
        #
        # We disable these signals after running the process so the child doesn't
        # inherit this behaviour.
        try:
            signal.signal(signal.SIGQUIT, signal.SIG_IGN)
        except AttributeError:  # SIGQUIT doesn't exist on windows
            pass

        exitcode = None
        try:
            while True:
                res = bytearray()
                try:
                    res = self._socket.recv(2)
                except socket.timeout:
                    pass
                # A valid message is two bytes: the marker 2, then the exit code.
                if len(res) == 2 and res[0] == 2:
                    exitcode = res[1]
                    break
                elif len(res) > 0:
                    wandb.termerror(
                        "Invalid message received from child process: %s" %
                        str(res))
                    break
                else:
                    exitcode = self.proc.poll()
                    if exitcode is not None:
                        break
                    time.sleep(1)
        except KeyboardInterrupt:
            exitcode = 255
            wandb.termlog('Ctrl-c pressed; waiting for program to end.')
            keyboard_interrupt_time = time.time()
            if not headless:
                # give the process a couple of seconds to die, then kill it
                while self.proc.poll() is None and (
                        time.time() - keyboard_interrupt_time) < 2:
                    time.sleep(0.1)
                if self.proc.poll() is None:
                    wandb.termlog('Program still alive. Killing it.')
                    try:
                        self.proc.kill()
                    except OSError:
                        pass
        # TODO(adrian): garbage that appears in the logs sometimes
        #
        # Exception ignored in: <bound method Popen.__del__ of <subprocess.Popen object at 0x111adce48>>
        # Traceback (most recent call last):
        #   File "/Users/adrian/.pyenv/versions/3.6.0/Python.framework/Versions/3.6/lib/python3.6/subprocess.py", line 760, in __del__
        # AttributeError: 'NoneType' object has no attribute 'warn'
        wandb.termlog()

        if exitcode is None:
            exitcode = 254
            wandb.termlog(
                'Killing program failed; syncing files anyway. Press ctrl-c to abort syncing.'
            )
        else:
            if exitcode == 0:
                wandb.termlog('Program ended.')
            else:
                wandb.termlog(
                    'Program failed with code %d. Press ctrl-c to abort syncing.'
                    % exitcode)
        #termlog('job (%s) Process exited with code: %s' % (program, exitcode))

        self._meta.data["exitcode"] = exitcode
        if exitcode == 0:
            self._meta.data["state"] = "finished"
        elif exitcode == 255:
            self._meta.data["state"] = "killed"
        else:
            self._meta.data["state"] = "failed"
        self._meta.shutdown()
        self._system_stats.shutdown()
        # NOTE(review): `exitcode or 254` also maps a successful exit code 0
        # to 254 -- confirm that is what the stream shutdown expects.
        self._close_stdout_stderr_streams(exitcode or 254)

        # If we're not syncing to the cloud, we're done
        if not self._cloud:
            self._socket.done()
            return None

        # Show run summary/history
        self._run.summary.load()
        summary = self._run.summary._summary
        if len(summary):
            wandb.termlog('Run summary:')
            max_len = max([len(k) for k in summary.keys()])
            format_str = '  {:>%s} {}' % max_len
            for k, v in summary.items():
                wandb.termlog(format_str.format(k, v))
            # NOTE(review): history is only loaded when the summary is
            # non-empty -- confirm this nesting is intended.
            self._run.history.load()

        history_keys = self._run.history.keys()
        if len(history_keys):
            wandb.termlog('Run history:')
            max_len = max([len(k) for k in history_keys])
            for key in history_keys:
                vals = util.downsample(self._run.history.column(key), 40)
                line = sparkline.sparkify(vals)
                format_str = u'  {:>%s} {}' % max_len
                wandb.termlog(format_str.format(key, line))

        if self._run.has_examples:
            wandb.termlog('Saved %s examples' % self._run.examples.count())

        wandb.termlog('Waiting for final file modifications.')
        # This is a a heuristic delay to catch files that were written just before
        # the end of the script.
        # TODO: ensure we catch all saved files.
        # TODO(adrian): do we need this?
        time.sleep(2)
        try:
            # avoid hanging if we crashed before the observer was started
            if self._observer.is_alive():
                self._observer.stop()
                self._observer.join()
        # TODO: py2 TypeError: PyCObject_AsVoidPtr called with null pointer
        except TypeError:
            pass
        # TODO: py3 SystemError: <built-in function stop> returned a result with an error set
        except SystemError:
            pass

        for handler in self._event_handlers.values():
            handler.finish()
        self._file_pusher.finish()

        wandb.termlog('Syncing files in %s:' %
                      os.path.relpath(self._watch_dir))
        for file_path in self._stats.files():
            wandb.termlog('  %s' % os.path.relpath(file_path, self._watch_dir))
        step = 0
        spinner_states = ['-', '\\', '|', '/']
        stop = False
        self._stats.update_all_files()
        # Redraw the progress line until the pusher is done; the liveness
        # check comes first so one final line is printed after it finishes.
        while True:
            if not self._file_pusher.is_alive():
                stop = True
            summary = self._stats.summary()
            line = (
                ' %(completed_files)s of %(total_files)s files,'
                ' %(uploaded_bytes).03f of %(total_bytes).03f bytes uploaded\r'
                % summary)
            line = spinner_states[step % 4] + line
            step += 1
            wandb.termlog(line, newline=False)
            if stop:
                break
            time.sleep(0.25)
            #print('FP: ', self._file_pusher._pending, self._file_pusher._jobs)
        # clear progress line.
        wandb.termlog(' ' * 79)

        # Check md5s of uploaded files against what's on the file system.
        # TODO: We're currently using the list of uploaded files as our source
        #     of truth, but really we should use the files on the filesystem
        #     (ie if we missed a file this wouldn't catch it).
        # This polls the server, because there a delay between when the file
        # is done uploading, and when the datastore gets updated with new
        # metadata via pubsub.
        wandb.termlog('Verifying uploaded files... ', newline=False)
        error = False
        mismatched = None
        for delay_base in range(4):
            mismatched = []
            download_urls = self._api.download_urls(self._project,
                                                    run=self._run.id)
            for fname, info in download_urls.items():
                # Skip only the history file and the streamed output log.
                # The previous `... or OUTPUT_FNAME` was always truthy, so
                # every file was skipped and nothing was ever verified.
                if fname in ('wandb-history.h5', OUTPUT_FNAME):
                    continue
                local_path = os.path.join(self._watch_dir, fname)
                local_md5 = util.md5_file(local_path)
                if local_md5 != info['md5']:
                    mismatched.append((local_path, local_md5, info['md5']))
            if not mismatched:
                break
            wandb.termlog('  Retrying after %ss' % (delay_base**2))
            time.sleep(delay_base**2)

        if mismatched:
            print('')
            error = True
            for local_path, local_md5, remote_md5 in mismatched:
                wandb.termerror(
                    '%s (%s) did not match uploaded file (%s) md5' %
                    (local_path, local_md5, remote_md5))
        else:
            print('verified!')

        if error:
            wandb.termerror('Sync failed %s' % self.url)
        else:
            wandb.termlog('Synced %s' % self.url)

        if headless:
            self._socket.done()
Exemplo n.º 17
0
    def _on_final(self):
        """Print the end-of-run report to the terminal.

        In order: reporter warnings and errors (if a reporter is set), the
        user/internal log file locations (when configured), the run summary,
        the synced-file totals, and finally the run name and URL.
        """
        reporter = self._reporter
        if reporter:
            # (heading, attribute holding the lines, attribute holding the
            # total count, note printed when the count exceeds the lines held)
            sections = (
                ("Warnings:", "warning_lines", "warning_count", "More warnings"),
                ("Errors:", "error_lines", "error_count", "More errors"),
            )
            for heading, lines_attr, count_attr, overflow_note in sections:
                kept = getattr(reporter, lines_attr)
                if not kept:
                    continue
                wandb.termlog(heading)
                for entry in kept:
                    wandb.termlog(entry)
                if len(kept) < getattr(reporter, count_attr):
                    wandb.termlog(overflow_note)

        user_log = self._settings.log_user
        if user_log:
            wandb.termlog(
                "Find user logs for this run at: {}".format(user_log))
        internal_log = self._settings.log_internal
        if internal_log:
            wandb.termlog(
                "Find internal logs for this run at: {}".format(internal_log))

        self._print_summary()

        synced = self._exit_result.files
        if synced:
            logger.info("logging synced files")
            template = "Synced {} W&B file(s), {} media file(s), {} artifact file(s) and {} other file(s)"  # noqa:E501
            wandb.termlog(
                template.format(
                    synced.wandb_count,
                    synced.media_count,
                    synced.artifact_count,
                    synced.other_count,
                ))

        if not self._run_obj:
            return
        run_url = self._get_run_url()
        run_name = self._get_run_name()
        styled_name = click.style(run_name, fg="yellow")
        styled_url = click.style(run_url, fg="blue")
        wandb.termlog("\nSynced {}: {}".format(styled_name, styled_url))
Exemplo n.º 18
0
    def run(self):  # noqa: C901
        """Register as an agent for the sweep and supervise runs until stopped.

        Steps, as implemented below:
          1. Look up the sweep and, if its stored config defines a list-valued
             "command", remember it in ``self._sweep_command``.
          2. Register this host as an agent and keep the returned agent id.
          3. Loop while ``self._running``: answer locally queued commands,
             poll every tracked run process, clean up finished ones, then
             send a heartbeat and process the commands it returns.
          4. On exit (normal or Ctrl-C) terminate remaining processes and
             wait for them; a further Ctrl-C kills them.
        """

        # TODO: catch exceptions, handle errors, show validation warnings, and make more generic
        sweep_obj = self._api.sweep(self._sweep_id, "{}")
        if sweep_obj:
            sweep_yaml = sweep_obj.get("config")
            if sweep_yaml:
                sweep_config = yaml.safe_load(sweep_yaml)
                if sweep_config:
                    sweep_command = sweep_config.get("command")
                    # Only a list-valued command is accepted; anything else is ignored.
                    if sweep_command and isinstance(sweep_command, list):
                        self._sweep_command = sweep_command

        # TODO: include sweep ID
        agent = self._api.register_agent(socket.gethostname(),
                                         sweep_id=self._sweep_id)
        agent_id = agent["id"]

        try:
            while self._running:
                # Serve commands submitted through the local queue; each
                # command carries its own response queue.
                commands = util.read_many_from_queue(self._queue, 100,
                                                     self.POLL_INTERVAL)
                for command in commands:
                    command["resp_queue"].put(self._process_command(command))

                # Log the set of tracked runs right after any change
                # (_last_report_time reset to None) or, when a non-zero
                # report interval is set, once that interval has elapsed.
                now = util.stopwatch_now()
                if self._last_report_time is None or (
                        self._report_interval != 0 and
                        now > self._last_report_time + self._report_interval):
                    logger.info("Running runs: %s",
                                list(self._run_processes.keys()))
                    self._last_report_time = now
                run_status = {}
                # Iterate over a snapshot because entries are deleted below.
                for run_id, run_process in list(
                        six.iteritems(self._run_processes)):
                    poll_result = run_process.poll()
                    if poll_result is None:
                        # Still running: report it in the heartbeat.
                        run_status[run_id] = True
                        continue
                    elif (not isinstance(poll_result, bool)
                          and isinstance(poll_result, int)
                          and poll_result > 0):
                        # Positive integer exit code counts as a failure.
                        self._failed += 1
                        if self.is_flapping():
                            logger.error(
                                "Detected %i failed runs in the first %i seconds, shutting down.",
                                self.FLAPPING_MAX_FAILURES,
                                self.FLAPPING_MAX_SECONDS,
                            )
                            logger.info(
                                "To disable this check set WANDB_AGENT_DISABLE_FLAPPING=true"
                            )
                            # Stop the agent; this run is left in
                            # _run_processes and handled by the finally block.
                            self._running = False
                            break
                    # Reached for every finished run that did not trigger the
                    # flapping shutdown (including non-flapping failures).
                    logger.info("Cleaning up finished run: %s", run_id)
                    del self._run_processes[run_id]
                    self._last_report_time = None
                    self._finished += 1

                # Parsed as: (_count and _finished >= _count) or not _running.
                # `continue` re-checks the while condition, which is now
                # False, so the heartbeat below is skipped on shutdown.
                if self._count and self._finished >= self._count or not self._running:
                    self._running = False
                    continue

                commands = self._api.agent_heartbeat(agent_id, {}, run_status)

                # TODO: send _server_responses
                self._server_responses = []
                for command in commands:
                    self._server_responses.append(
                        self._process_command(command))

        except KeyboardInterrupt:
            # First Ctrl-C: wait for runs to end on their own. A second
            # Ctrl-C abandons the wait and falls through to `finally`.
            try:
                wandb.termlog(
                    "Ctrl-c pressed. Waiting for runs to end. Press ctrl-c again to terminate them."
                )
                for _, run_process in six.iteritems(self._run_processes):
                    run_process.wait()
            except KeyboardInterrupt:
                pass
        finally:
            # Always runs: terminate whatever is still tracked, then wait.
            try:
                if not self._in_jupyter:
                    wandb.termlog(
                        "Terminating and syncing runs. Press ctrl-c to kill.")
                for _, run_process in six.iteritems(self._run_processes):
                    try:
                        run_process.terminate()
                    except OSError:
                        pass  # if process is already dead
                for _, run_process in six.iteritems(self._run_processes):
                    run_process.wait()
            except KeyboardInterrupt:
                # Ctrl-C while terminating/waiting: kill every tracked run.
                wandb.termlog("Killing runs and quitting.")
                for _, run_process in six.iteritems(self._run_processes):
                    try:
                        run_process.kill()
                    except OSError:
                        pass  # if process is already dead
Exemplo n.º 19
0
    def run(self):
        """Register as an agent and supervise run processes until stopped.

        After registering this host as an agent for ``self._sweep_id``, the
        method loops while ``self._running``:

          * commands read from ``self._queue`` are processed and each result
            is put on the command's own ``resp_queue``;
          * every tracked process in ``self._run_processes`` is polled; live
            ones are reported in the heartbeat, finished ones are dropped;
          * a heartbeat is sent and the commands it returns are processed,
            with the results stored in ``self._server_responses``.

        Ctrl-C handling: the first interrupt waits for runs to finish, the
        next one terminates them, and an interrupt during termination kills
        every process that is still tracked.
        """
        # TODO: include sweep ID
        agent = self._api.register_agent(
            socket.gethostname(), sweep_id=self._sweep_id)
        agent_id = agent['id']

        try:
            while self._running:
                commands = util.read_many_from_queue(
                    self._queue, 100, self.POLL_INTERVAL)
                for command in commands:
                    command['resp_queue'].put(self._process_command(command))

                # Log tracked runs right after any change (_last_report_time
                # reset to None) or once a non-zero report interval elapsed.
                now = util.stopwatch_now()
                if self._last_report_time is None or (self._report_interval != 0 and
                                                      now > self._last_report_time + self._report_interval):
                    logger.info('Running runs: %s', list(
                        self._run_processes.keys()))
                    self._last_report_time = now
                run_status = {}
                # Iterate over a snapshot because entries are deleted below.
                for run_id, run_process in list(six.iteritems(self._run_processes)):
                    if run_process.poll() is None:
                        run_status[run_id] = True
                    else:
                        logger.info('Cleaning up dead run: %s', run_id)
                        del self._run_processes[run_id]
                        self._last_report_time = None

                commands = self._api.agent_heartbeat(agent_id, {}, run_status)

                # TODO: send _server_responses
                self._server_responses = []
                for command in commands:
                    self._server_responses.append(
                        self._process_command(command))
        except KeyboardInterrupt:
            # First Ctrl-C: wait for runs to end on their own. A second
            # Ctrl-C abandons the wait and falls through to `finally`.
            try:
                wandb.termlog(
                    'Ctrl-c pressed. Waiting for runs to end. Press ctrl-c again to terminate them.')
                for _, run_process in six.iteritems(self._run_processes):
                    run_process.wait()
            except KeyboardInterrupt:
                pass
        finally:
            try:
                if not self._in_jupyter:
                    wandb.termlog(
                        'Terminating and syncing runs. Press ctrl-c to kill.')
                for _, run_process in six.iteritems(self._run_processes):
                    try:
                        run_process.terminate()
                    except OSError:
                        pass  # if process is already dead
                for _, run_process in six.iteritems(self._run_processes):
                    run_process.wait()
            except KeyboardInterrupt:
                wandb.termlog('Killing runs and quitting.')
                # Kill every tracked process. Previously only the leftover
                # loop variable was killed, which missed the other runs and
                # raised a NameError when no process had been iterated.
                for _, run_process in six.iteritems(self._run_processes):
                    try:
                        run_process.kill()
                    except OSError:
                        pass  # if process is already dead
Exemplo n.º 20
0
def prompt_api_key(  # noqa: C901
    settings,
    api=None,
    input_callback=None,
    browser_callback=None,
    no_offline=False,
    no_create=False,
    local=False,
):
    """Work out how the user wants to log in and obtain an API key.

    The login choice is forced by the environment where possible (anonymous
    mode "must", no tty / databricks, ``local``, or a single remaining
    choice); otherwise the user is shown a numbered menu.

    Returns:
        str - if key is configured
        None - if dryrun is selected
        False - if unconfigured (notty)
    """
    if not input_callback:
        input_callback = getpass.getpass
    log_string = term.LOG_STRING
    if not api:
        api = InternalApi(settings)
    anon_mode = _fixup_anon_mode(settings.anonymous)
    jupyter = settings._jupyter or False
    app_url = api.app_url

    # Start from every login choice and strip the ones that do not apply.
    choices = list(LOGIN_CHOICES)
    if anon_mode == "never":
        # Anonymous login is not offered when anonymous mode is "never".
        choices.remove(LOGIN_CHOICE_ANON)
    if jupyter or no_offline:
        choices.remove(LOGIN_CHOICE_DRYRUN)
    if jupyter or no_create:
        choices.remove(LOGIN_CHOICE_NEW)

    in_colab = jupyter and "google.colab" in sys.modules
    if in_colab:
        log_string = term.LOG_STRING_NOCOLOR
        key = wandb.jupyter.attempt_colab_login(app_url)
        if key is not None:
            write_key(settings, key, api=api)
            return key

    if anon_mode == "must":
        result = LOGIN_CHOICE_ANON
    elif (not jupyter and not (isatty(sys.stdout)
                               and isatty(sys.stdin))) or _is_databricks():
        # Non-interactive environment: nothing can be prompted for.
        result = LOGIN_CHOICE_NOTTY
    elif local:
        result = LOGIN_CHOICE_EXISTS
    elif len(choices) == 1:
        result = choices[0]
    else:
        for number, choice in enumerate(choices, 1):
            wandb.termlog("(%i) %s" % (number, choice))

        # Keep asking until the answer indexes into the menu.
        while True:
            idx = _prompt_choice()
            if 0 <= idx <= len(choices) - 1:
                break
            wandb.termwarn("Invalid choice")
        result = choices[idx]
        wandb.termlog("You chose '%s'" % result)

    api_ask = "%s: Paste an API key from your profile and hit enter: " % log_string

    if result == LOGIN_CHOICE_ANON:
        key = api.create_anonymous_api_key()
        write_key(settings, key, api=api, anonymous=True)
        return key

    if result == LOGIN_CHOICE_NEW:
        key = None
        if browser_callback:
            key = browser_callback(signup=True)
        if not key:
            wandb.termlog(
                "Create an account here: {}/authorize?signup=true".format(
                    app_url))
            key = input_callback(api_ask).strip()
        write_key(settings, key, api=api)
        return key

    if result == LOGIN_CHOICE_EXISTS:
        key = None
        if browser_callback:
            key = browser_callback()
        if not key:
            wandb.termlog(
                "You can find your API key in your browser here: {}/authorize".
                format(app_url))
            key = input_callback(api_ask).strip()
        write_key(settings, key, api=api)
        return key

    if result == LOGIN_CHOICE_NOTTY:
        # TODO: Needs refactor as this needs to be handled by caller
        return False

    # Remaining choice: Jupyter has no tty, but the browser callback (when
    # supplied) can still be tried there.
    if jupyter and browser_callback:
        key, _anonymous = browser_callback()
    else:
        key, _anonymous = None, False

    write_key(settings, key, api=api)
    return key
Exemplo n.º 21
0
    def _fetch_project_local(self, internal_api: Api) -> None:
        """Fetch a project (either wandb run or git repo) into ``self.project_dir``.

        Nothing is returned; the code is placed in ``self.project_dir`` and
        related state on ``self`` is updated (cuda settings, ``build_image``,
        ``python_version``, entry points, ``override_args``).

        Args:
            internal_api: API client passed through to the ``utils`` helpers
                that fetch run info, code artifacts, diffs and requirements.

        Raises:
            ExecutionError: the source run has neither a code artifact nor
                git info.
            LaunchError: the entry point is missing locally and could not be
                downloaded.
        """
        assert self.source != LaunchSource.LOCAL
        _logger.info("Fetching project locally...")
        if utils._is_wandb_uri(self.uri):
            # --- Source is an existing wandb run ---
            source_entity, source_project, source_run_name = utils.parse_wandb_uri(
                self.uri)
            run_info = utils.fetch_wandb_project_run_info(
                source_entity, source_project, source_run_name, internal_api)
            # NOTE: run_info["program"] is evaluated eagerly as the default,
            # so the "program" key must exist even when "codePath" is set.
            entry_point = run_info.get("codePath", run_info["program"])

            if run_info.get("cudaVersion"):
                # Keep only the first two dot-separated version components.
                original_cuda_version = ".".join(
                    run_info["cudaVersion"].split(".")[:2])

                if self.cuda is None:
                    # only set cuda on by default if cuda is None (unspecified), not False (user specifically requested cpu image)
                    wandb.termlog(
                        "Original wandb run {} was run with cuda version {}. Enabling cuda builds by default; to build on a CPU-only image, run again with --cuda=False"
                        .format(source_run_name, original_cuda_version))
                    self.cuda_version = original_cuda_version
                    self.cuda = True
                # A user-specified version wins; only inform about the mismatch.
                if (self.cuda and self.cuda_version
                        and self.cuda_version != original_cuda_version):
                    wandb.termlog(
                        "Specified cuda version {} differs from original cuda version {}. Running with specified version {}"
                        .format(self.cuda_version, original_cuda_version,
                                self.cuda_version))

            downloaded_code_artifact = utils.check_and_download_code_artifacts(
                source_entity,
                source_project,
                source_run_name,
                internal_api,
                self.project_dir,
            )

            if downloaded_code_artifact:
                self.build_image = True
            # Always true when reached; equivalent to a plain `else`.
            elif not downloaded_code_artifact:
                # No code artifact: fall back to git checkout + stored diff.
                if not run_info["git"]:
                    raise ExecutionError(
                        "Reproducing a run requires either an associated git repo or a code artifact logged with `run.log_code()`"
                    )
                utils._fetch_git_repo(
                    self.project_dir,
                    run_info["git"]["remote"],
                    run_info["git"]["commit"],
                )
                patch = utils.fetch_project_diff(source_entity, source_project,
                                                 source_run_name, internal_api)

                if patch:
                    utils.apply_patch(patch, self.project_dir)
                # For cases where the entry point wasn't checked into git
                if not os.path.exists(
                        os.path.join(self.project_dir, entry_point)):
                    downloaded_entrypoint = utils.download_entry_point(
                        source_entity,
                        source_project,
                        source_run_name,
                        internal_api,
                        entry_point,
                        self.project_dir,
                    )
                    if not downloaded_entrypoint:
                        raise LaunchError(
                            f"Entrypoint: {entry_point} does not exist, "
                            "and could not be downloaded. Please specify the entrypoint for this run."
                        )
                    # if the entrypoint is downloaded and inserted into the project dir
                    # need to rebuild image with new code
                    self.build_image = True

            # Notebook entry points are replaced by the converted script
            # returned from the helper.
            if entry_point.endswith("ipynb"):
                entry_point = utils.convert_jupyter_notebook_to_script(
                    entry_point, self.project_dir)

            # Download any frozen requirements
            utils.download_wandb_python_deps(
                source_entity,
                source_project,
                source_run_name,
                internal_api,
                self.project_dir,
            )

            # Specify the python runtime for jupyter2docker
            self.python_version = run_info.get("python", "3")

            # Entry points already set on this object take precedence over
            # the one recorded for the source run.
            if not self._entry_points:
                self.add_entry_point(entry_point)
            self.override_args = utils.merge_parameters(
                self.override_args, run_info["args"])
        else:
            # --- Source is a plain git repository ---
            assert utils._GIT_URI_REGEX.match(
                self.uri), ("Non-wandb URI %s should be a Git URI" % self.uri)

            if not self._entry_points:
                wandb.termlog(
                    "Entry point for repo not specified, defaulting to main.py"
                )
                self.add_entry_point("main.py")
            utils._fetch_git_repo(self.project_dir, self.uri, self.git_version)
Exemplo n.º 22
0
    def on_finish(self):
        """Display the run, then reporter warnings/errors and log file paths."""
        if self._run_obj:
            self._display_run()

        reporter = self._reporter
        if reporter:
            # (heading, attribute holding the lines, attribute holding the
            # total count, note printed when the count exceeds the lines held)
            sections = (
                ("Warnings:", "warning_lines", "warning_count", "More warnings"),
                ("Errors:", "error_lines", "error_count", "More errors"),
            )
            for heading, lines_attr, count_attr, overflow_note in sections:
                kept = getattr(reporter, lines_attr)
                if not kept:
                    continue
                wandb.termlog(heading)
                for entry in kept:
                    wandb.termlog(entry)
                if len(kept) < getattr(reporter, count_attr):
                    wandb.termlog(overflow_note)

        user_log = self._settings.log_user
        if user_log:
            wandb.termlog(
                "Find user logs for this run at: {}".format(user_log))
        internal_log = self._settings.log_internal
        if internal_log:
            wandb.termlog(
                "Find internal logs for this run at: {}".format(internal_log))
Exemplo n.º 23
0
def upload_h5(file, run, entity=None, project=None):
    """Push a single local file to a run.

    The file is uploaded under its base name via ``Api.push``.

    Args:
        file: path of the local file to upload; it is opened in binary mode.
        run: run identifier forwarded to ``Api.push``.
        entity: optional entity forwarded to ``Api.push``.
        project: optional project forwarded to ``Api.push``.
    """
    api = Api()
    wandb.termlog("Uploading summary data...")
    # Use a context manager so the handle is released even when the push
    # raises; previously the file object was never closed by this function.
    with open(file, 'rb') as h5_file:
        api.push({os.path.basename(file): h5_file}, run=run, project=project,
                 entity=entity)
Exemplo n.º 24
0
    def wandb_save(
        glob_str: Optional[str] = None,
        base_path: Optional[str] = None,
        policy: str = "live",
    ) -> Union[bool, List[str]]:
        """
        NOTE: This reimplements wandb.save, but copies files instead of symlinking.
        The symlinks have caused many issues on Windows and google colab.

        Ensure all files matching `glob_str` are synced to wandb with the policy specified.
        Each matching file is copied into the run directory at its path
        relative to `base_path`, so sub directories are preserved.

        Arguments:
            glob_str: (string) a relative or absolute path to a unix glob or regular
                path.  If this isn't specified the method is a noop.
            base_path: (string) the base path to run the glob relative to
            policy: (string) one of `live`, `now`, or `end`
                - live: upload the file as it changes, overwriting the previous version
                - now: upload the file once now
                - end: only upload file when the run ends

        Returns:
            True when called without `glob_str` (deprecated no-op), an empty
            list for `gs://` / `s3://` urls, otherwise the list of paths
            inside the run directory covered by this call.

        Raises:
            ValueError: unknown `policy`, `glob_str` that is not a string, or
                a glob that walks above `base_path`.
        """
        if glob_str is None:
            # noop for historical reasons, run.save() may be called in legacy code
            wandb.termwarn(
                ("Calling run.save without any arguments is deprecated."
                 "Changes to attributes are automatically persisted."))
            return True
        if policy not in ("live", "end", "now"):
            raise ValueError(
                'Only "live" "end" and "now" policies are currently supported.'
            )
        if isinstance(glob_str, bytes):
            glob_str = glob_str.decode("utf-8")
        if not isinstance(glob_str, string_types):
            raise ValueError(
                "Must call wandb.save(glob_str) with glob_str a str")

        if base_path is None:
            if os.path.isabs(glob_str):
                base_path = os.path.dirname(glob_str)
                wandb.termwarn(
                    ("Saving files without folders. If you want to preserve "
                     "sub directories pass base_path to wandb.save, i.e. "
                     'wandb.save("/mnt/folder/file.h5", base_path="/mnt")'))
            else:
                base_path = ""
        wandb_glob_str = os.path.relpath(glob_str, base_path)
        if ".." + os.sep in wandb_glob_str:
            raise ValueError("globs can't walk above base_path")

        with telemetry.context(run=wandb.run) as tel:
            tel.feature.save = True

        if glob_str.startswith("gs://") or glob_str.startswith("s3://"):
            wandb.termlog(
                "%s is a cloud storage url, can't save file to wandb." %
                glob_str)
            return []

        # Files already present in the run directory that match the glob.
        files = glob.glob(os.path.join(wandb.run.dir, wandb_glob_str))
        # Warn only for wildcard globs that matched nothing in the run dir yet.
        warn = len(files) == 0 and "*" in wandb_glob_str
        for path in glob.glob(glob_str):
            file_name = os.path.relpath(path, base_path)
            abs_path = os.path.abspath(path)
            wandb_path = os.path.join(wandb.run.dir, file_name)
            wandb.util.mkdir_exists_ok(os.path.dirname(wandb_path))
            # Copy to wandb_path itself, not to the run directory root:
            # copying to the root dropped the sub directory created above, so
            # the returned path and the published glob pointed at nothing.
            if os.path.islink(
                    wandb_path) and abs_path != os.readlink(wandb_path):
                # Replace stale symlinks (namespaces can change in
                # Tensorboard) with a fresh copy.
                os.remove(wandb_path)
                shutil.copy(abs_path, wandb_path)
            elif not os.path.exists(wandb_path):
                shutil.copy(abs_path, wandb_path)
            files.append(wandb_path)
        if warn:
            file_str = "%i file" % len(files)
            if len(files) > 1:
                file_str += "s"
            # Files are copied, not symlinked, so the message says so.
            wandb.termwarn(
                ("Copied %s into the W&B run directory, "
                 "call wandb.save again to sync new files.") % file_str)
        files_dict = dict(files=[(wandb_glob_str, policy)])
        if wandb.run._backend:
            wandb.run._backend.interface.publish_files(files_dict)
        return files
Exemplo n.º 25
0
def restore(ctx, run, no_git, branch, project, entity):
    # Restore the state of a previous run: check out its git commit (optionally
    # on a `wandb/<run>` branch), apply its stored patch, write its config under
    # ./wandb and, when a docker image was recorded, try to start it.
    #
    # Returns the tuple (commit, json_config, patch_content, repo, metadata).
    # Raises ClickException when the original repository/commit is required
    # but cannot be found.
    # (Plain comments are used instead of a docstring so that any command-line
    # help derived from the docstring is left unchanged.)

    # Accept "entity/project:run" and "project:run", as well as a
    # slash-separated form whose first component is the entity.
    if ":" in run:
        if "/" in run:
            entity, rest = run.split("/", 1)
        else:
            rest = run
        project, run = rest.split(":", 1)
    elif run.count("/") > 1:
        entity, run = run.split("/", 1)

    project, run = api.parse_slug(run, project=project)
    commit, json_config, patch_content, metadata = api.run_config(
        project, run=run, entity=entity)
    repo = metadata.get("git", {}).get("repo")
    image = metadata.get("docker")
    RESTORE_MESSAGE = """`wandb restore` needs to be run from the same git repository as the original run.
Run `git clone %s` and restore from there or pass the --no-git flag.""" % repo
    if no_git:
        # Clearing the commit skips the whole git section below.
        commit = None
    elif not api.git.enabled:
        if repo:
            raise ClickException(RESTORE_MESSAGE)
        elif image:
            wandb.termlog(
                "Original run has no git history.  Just restoring config and docker"
            )

    if commit and api.git.enabled:
        subprocess.check_call(['git', 'fetch', '--all'])
        try:
            api.git.repo.commit(commit)
        except ValueError:
            # The exact commit is not available locally: look for an uploaded
            # `upstream_diff_<commit>.patch` whose commit does exist here.
            wandb.termlog("Couldn't find original commit: {}".format(commit))
            commit = None
            files = api.download_urls(project, run=run, entity=entity)
            for filename in files:
                if filename.startswith('upstream_diff_') and filename.endswith(
                        '.patch'):
                    commit = filename[len('upstream_diff_'):-len('.patch')]
                    try:
                        api.git.repo.commit(commit)
                    except ValueError:
                        commit = None
                    else:
                        # `filename` keeps the matching name for use below.
                        break

            if commit:
                wandb.termlog(
                    "Falling back to upstream commit: {}".format(commit))
                patch_path, _ = api.download_write_file(files[filename])
            else:
                raise ClickException(RESTORE_MESSAGE)
        else:
            # Original commit found: use the diff stored with the run, if any.
            if patch_content:
                patch_path = os.path.join(wandb.wandb_dir(), 'diff.patch')
                with open(patch_path, "w") as f:
                    f.write(patch_content)
            else:
                patch_path = None

        branch_name = "wandb/%s" % run
        if branch and branch_name not in api.git.repo.branches:
            api.git.repo.git.checkout(commit, b=branch_name)
            wandb.termlog("Created branch %s" %
                          click.style(branch_name, bold=True))
        elif branch:
            wandb.termlog(
                "Using existing branch, run `git branch -D %s` from master for a clean checkout"
                % branch_name)
            api.git.repo.git.checkout(branch_name)
        else:
            wandb.termlog("Checking out %s in detached mode" % commit)
            api.git.repo.git.checkout(commit)

        if patch_path:
            # we apply the patch from the repository root so git doesn't exclude
            # things outside the current directory
            root = api.git.root
            patch_rel_path = os.path.relpath(patch_path, start=root)
            # --reject is necessary or else this fails any time a binary file
            # occurs in the diff
            # we use .call() instead of .check_call() for the same reason
            # TODO(adrian): this means there is no error checking here
            subprocess.call(['git', 'apply', '--reject', patch_rel_path],
                            cwd=root)
            wandb.termlog("Applied patch")

    # TODO: we should likely respect WANDB_DIR here.
    util.mkdir_exists_ok("wandb")
    config = Config(run_dir="wandb")
    config.load_json(json_config)
    config.persist()
    wandb.termlog("Restored config variables to %s" % config._config_path())
    if image:
        # Rebuild the original command line only when the program name does
        # not start with "<" and arguments were recorded.
        if not metadata["program"].startswith("<") and metadata.get(
                "args") is not None:
            # TODO: we may not want to default to python here.
            runner = util.find_runner(metadata["program"]) or ["python"]
            command = runner + [metadata["program"]] + metadata["args"]
            cmd = " ".join(command)
        else:
            wandb.termlog(
                "Couldn't find original command, just restoring environment")
            cmd = None
        wandb.termlog("Docker image found, attempting to start")
        ctx.invoke(docker, docker_run_args=[image], cmd=cmd)

    return commit, json_config, patch_content, repo, metadata
Exemplo n.º 26
0
def prompt_api_key(
    settings,
    api=None,
    input_callback=None,
    browser_callback=None,
    no_offline=False,
    local=False,
):
    """Choose a login method, obtain an API key, persist it and return it.

    Args:
        settings: object read for ``anonymous``, ``jupyter`` and ``base_url``;
            it is also handed to ``write_key`` together with the key.
        api: API client; an ``InternalApi()`` is created when omitted.  Used
            for the Colab login URL and for creating anonymous keys.
        input_callback: callable that takes a prompt and returns the pasted
            key; defaults to ``getpass.getpass``.
        browser_callback: optional callable that tries to obtain the key
            through the browser.  It is called with ``signup=True`` for the
            "new account" choice and without arguments otherwise.
        no_offline: when true, the dry-run choice is removed from the menu.
        local: when true (and the terminal is interactive) the menu is
            skipped and the "existing account" flow is used.

    Returns:
        The API key that was passed to ``write_key``.  On the fallback
        (dry-run) path this is ``None`` unless a browser callback supplied
        a key.
    """
    input_callback = input_callback or getpass.getpass
    api = api or InternalApi()
    anon_mode = _fixup_anon_mode(settings.anonymous)
    jupyter = settings.jupyter or False
    app_url = settings.base_url.replace("//api.", "//app.")

    def paste_key():
        # Shared by the "new account" and "existing account" flows.
        return input_callback(
            "%s: Paste an API key from your profile and hit enter"
            % LOG_STRING
        ).strip()

    choices = list(LOGIN_CHOICES)
    if anon_mode == "never":
        # Omit LOGIN_CHOICE_ANON as a choice if the env var is set to never
        choices.remove(LOGIN_CHOICE_ANON)
    if jupyter or no_offline:
        choices.remove(LOGIN_CHOICE_DRYRUN)

    if jupyter and 'google.colab' in sys.modules:
        # NOTE(review): the key is stored and returned even if the Colab
        # login attempt yields nothing -- confirm whether a fallback to the
        # regular prompt is wanted here.
        key = wandb.jupyter.attempt_colab_login(api.app_url)
        write_key(settings, key)
        return key

    if anon_mode == "must":
        result = LOGIN_CHOICE_ANON
    # If we're not in an interactive environment, default to dry-run.
    elif not isatty(sys.stdout) or not isatty(sys.stdin):
        result = LOGIN_CHOICE_DRYRUN
    elif local:
        result = LOGIN_CHOICE_EXISTS
    else:
        for i, choice in enumerate(choices):
            wandb.termlog("(%i) %s" % (i + 1, choice))

        def prompt_choice():
            # Menu entries are shown 1-based; -1 marks unparsable input.
            try:
                return int(input("%s: Enter your choice: " % LOG_STRING)) - 1
            except ValueError:
                return -1

        while True:
            idx = prompt_choice()
            if 0 <= idx < len(choices):
                break
            wandb.termwarn("Invalid choice")
        result = choices[idx]
        wandb.termlog("You chose '%s'" % result)

    if result == LOGIN_CHOICE_ANON:
        key = api.create_anonymous_api_key()
    elif result == LOGIN_CHOICE_NEW:
        key = browser_callback(signup=True) if browser_callback else None
        if not key:
            wandb.termlog(
                "Create an account here: {}/authorize?signup=true".format(app_url)
            )
            key = paste_key()
    elif result == LOGIN_CHOICE_EXISTS:
        key = browser_callback() if browser_callback else None
        if not key:
            wandb.termlog(
                "You can find your API key in your browser here: {}/authorize".format(
                    app_url
                )
            )
            key = paste_key()
    else:
        # Jupyter environments don't have a tty, but we can still try logging in using
        # the browser callback if one is supplied.
        key = browser_callback() if jupyter and browser_callback else None
        # The other branches treat the callback's return value as a plain
        # key, so accept that shape here as well as a ``(key, anonymous)``
        # pair instead of crashing on tuple unpacking.
        if isinstance(key, (tuple, list)):
            key = key[0] if key else None

    write_key(settings, key)
    return key
Exemplo n.º 27
0
 def ensure_configured(self):
     """Exit the process with status 1 when no project has been configured.

     Nothing happens when the ``WANDB_DEBUG`` environment variable is set
     or when ``self.settings("project")`` returns a truthy value.
     """
     # The WANDB_DEBUG check ensures tests still work.
     if os.getenv('WANDB_DEBUG') or self.settings("project"):
         return

     message = ('wandb.init() called but system not configured.\n'
                'Run "wandb init" or set environment variables to get started')
     wandb.termlog(message)
     sys.exit(1)
Exemplo n.º 28
0
 def _run_jobs_from_queue(self):
     """Run queued sweep jobs one at a time until the agent should stop.

     Every job is executed by ``self._run_job`` on its own thread, and that
     thread is joined before the next job is fetched.  The loop ends when
     ``self._exit_flag`` is set, when ``self._count`` jobs have been
     started, when failed runs trip the flapping check, or on Ctrl + C.

     Raises:
         Exception: an unexpected error raised while ``self._exit_flag`` is
             not set is re-raised to the caller.
     """
     waiting = False
     count = 0
     while True:
         if self._exit_flag:
             return
         try:
             try:
                 job = self._queue.get(timeout=5)
                 if self._exit_flag:
                     logger.debug("Exiting main loop due to exit flag.")
                     wandb.termlog("Sweep Agent: Exiting.")
                     return
             except queue.Empty:
                 # Announce the idle state only once per idle period.
                 if not waiting:
                     logger.debug("Paused.")
                     wandb.termlog("Sweep Agent: Waiting for job.")
                     waiting = True
                 time.sleep(5)
                 if self._exit_flag:
                     logger.debug("Exiting main loop due to exit flag.")
                     wandb.termlog("Sweep Agent: Exiting.")
                     return
                 continue
             if waiting:
                 logger.debug("Resumed.")
                 wandb.termlog("Job received.")
                 waiting = False
             count += 1
             run_id = job.run_id
             logger.debug("Spawning new thread for run {}.".format(run_id))
             thread = threading.Thread(target=self._run_job, args=(job, ))
             self._run_threads[run_id] = thread
             thread.start()
             # Jobs run strictly one after another.
             thread.join()
             logger.debug("Thread joined for run {}.".format(run_id))
             exc = self._errored_runs.get(run_id)
             if exc:
                 logger.error("Run {} errored: {}".format(
                     run_id, repr(exc)))
                 wandb.termerror("Run {} errored: {}".format(
                     run_id, repr(exc)))
                 # The message below tells users that setting this variable
                 # to "true" disables the flapping check, so honour that
                 # instead of stopping the agent on the first failed run.
                 flapping_disabled = (
                     os.getenv(wandb.env.AGENT_DISABLE_FLAPPING) == "true")
                 flapping = (
                     time.time() - self._start_time <
                     self.FLAPPING_MAX_SECONDS) and (
                         len(self._errored_runs) >=
                         self.FLAPPING_MAX_FAILURES)
                 if flapping and not flapping_disabled:
                     msg = "Detected {} failed runs in the first {} seconds, killing sweep.".format(
                         self.FLAPPING_MAX_FAILURES,
                         self.FLAPPING_MAX_SECONDS)
                     logger.error(msg)
                     wandb.termerror(msg)
                     wandb.termlog(
                         "To disable this check set WANDB_AGENT_DISABLE_FLAPPING=true"
                     )
                     self._exit_flag = True
                     return
             del self._run_threads[job.run_id]
             # A falsy self._count means there is no limit on the job count.
             if self._count and self._count == count:
                 logger.debug(
                     "Exiting main loop because max count reached.")
                 self._exit_flag = True
                 return
         except KeyboardInterrupt:
             logger.debug("Ctrl + C detected. Stopping sweep.")
             wandb.termlog("Ctrl + C detected. Stopping sweep.")
             self._exit()
             return
         except Exception:
             if self._exit_flag:
                 logger.debug("Exiting main loop due to exit flag.")
                 wandb.termlog("Sweep Agent: Killed.")
                 return
             # Bare raise keeps the original traceback intact.
             raise
Exemplo n.º 29
0
    def from_environment_or_defaults(cls, environment=None):
        """Build a run from values found in an environment mapping.

        ``environment`` defaults to ``os.environ``.  The run id, resume
        setting, storage id, mode, group, job type, run directory, sweep id,
        program, description, name and notes are looked up under the
        corresponding ``env.*`` keys; args, the wandb directory and tags come
        from the ``env.get_*`` helpers.  Values missing from the mapping are
        passed on as ``None`` (presumably the constructor fills in defaults
        -- verify against ``cls.__init__``).

        When W&B is disabled and no mode is given, the mode becomes
        "dryrun"; otherwise a disabled W&B only produces a terminal message.

        The config is created by ``Config.from_environment_or_defaults()``,
        which does not receive ``environment``.
        """
        environ = os.environ if environment is None else environment

        run_id = environ.get(env.RUN_ID)
        resume = environ.get(env.RESUME)
        storage_id = environ.get(env.RUN_STORAGE_ID)
        mode = environ.get(env.MODE)

        api = InternalApi(environ=environ)
        if api.disabled():
            if not mode:
                mode = "dryrun"
            elif mode != "dryrun":
                wandb.termwarn(
                    "WANDB_MODE is set to run, but W&B was disabled.  Run `wandb on` to remove this message"
                )
            else:
                wandb.termlog(
                    'W&B is disabled in this directory.  Run `wandb on` to enable cloud syncing.'
                )

        group = environ.get(env.RUN_GROUP)
        job_type = environ.get(env.JOB_TYPE)
        run_dir = environ.get(env.RUN_DIR)
        sweep_id = environ.get(env.SWEEP_ID)

        extras = {
            "program": environ.get(env.PROGRAM),
            "description": environ.get(env.DESCRIPTION),
            "name": environ.get(env.NAME),
            "notes": environ.get(env.NOTES),
            "args": env.get_args(env=environ),
            "wandb_dir": env.get_dir(env=environ),
            "tags": env.get_tags(env=environ),
            "resume": resume,
            "api": api,
        }

        # TODO(adrian): should pass environment into here as well.
        config = Config.from_environment_or_defaults()

        return cls(run_id, mode, run_dir, group, job_type, config, sweep_id,
                   storage_id, **extras)
Exemplo n.º 30
0
def sweep(ctx, project, entity, controller, verbose, name, program, settings, update, config_yaml):
    # Create a sweep from the YAML file `config_yaml`, or update the existing
    # sweep named by `update`, then print how to start an agent for it.
    #
    # `name`, `program` and `settings` override the matching entries of the
    # loaded configuration.  A truthy `controller` forces a local controller,
    # which is started (with `verbose`) once the sweep has been upserted.
    # Every failure is reported with wandb.termerror and ends the command
    # early; nothing is returned.
    def _parse_settings(settings):
        """Parse comma separated ``key=value`` assignments into a dict.

        Malformed items are reported with a warning and skipped.
        """
        ret = {}
        # TODO(jhr): merge with magic_impl:_parse_magic
        if settings.find('=') > 0:
            for item in settings.split(","):
                kv = item.split("=")
                if len(kv) != 2:
                    wandb.termwarn("Unable to parse sweep settings key value pair", repeat=False)
                    # Skip the bad item; building a dict from it would raise.
                    continue
                key, value = kv
                ret[key] = value
            return ret
        wandb.termwarn("Unable to parse settings parameter", repeat=False)
        return ret

    if api.api_key is None:
        wandb.termlog("Login to W&B to use the sweep feature")
        ctx.invoke(login, no_offline=True)

    sweep_obj_id = None
    if update:
        parts = dict(entity=entity, project=project, name=update)
        err = util.parse_sweep_id(parts)
        if err:
            wandb.termerror(err)
            return
        entity = parts.get("entity") or entity
        project = parts.get("project") or project
        sweep_id = parts.get("name") or update
        found = api.sweep(sweep_id, '{}', entity=entity, project=project)
        if not found:
            wandb.termerror('Could not find sweep {}/{}/{}'.format(entity, project, sweep_id))
            return
        sweep_obj_id = found['id']

    wandb.termlog('{} sweep from: {}'.format(
            'Updating' if sweep_obj_id else 'Creating',
            config_yaml))
    try:
        yaml_file = open(config_yaml)
    except (OSError, IOError):
        wandb.termerror('Couldn\'t open sweep file: %s' % config_yaml)
        return
    # The context manager closes the file on every exit path.
    with yaml_file:
        try:
            config = util.load_yaml(yaml_file)
        except yaml.YAMLError as err:
            wandb.termerror('Error in configuration file: %s' % err)
            return
    if config is None:
        wandb.termerror('Configuration file is empty')
        return

    # Set or override parameters
    if name:
        config["name"] = name
    if program:
        config["program"] = program
    if settings:
        settings = _parse_settings(settings)
        if settings:
            config.setdefault("settings", {})
            config["settings"].update(settings)
    if controller:
        config.setdefault("controller", {})
        config["controller"]["type"] = "local"

    # Only configurations for a local controller are validated here.
    is_local = config.get('controller', {}).get('type') == 'local'
    if is_local:
        tuner = wandb_controller.controller()
        err = tuner._validate(config)
        if err:
            wandb.termerror('Error in sweep file: %s' % err)
            return

    entity = entity or env.get_entity() or config.get('entity')
    project = project or env.get_project() or config.get('project') or util.auto_project_name(
            config.get("program"), api)
    sweep_id = api.upsert_sweep(config, project=project, entity=entity, obj_id=sweep_obj_id)
    wandb.termlog('{} sweep with ID: {}'.format(
            'Updated' if sweep_obj_id else 'Created',
            click.style(sweep_id, fg="yellow")))
    sweep_url = wandb_controller._get_sweep_url(api, sweep_id)
    if sweep_url:
        wandb.termlog("View sweep at: {}".format(
            click.style(sweep_url, underline=True, fg='blue')))

    # reprobe entity and project if it was autodetected by upsert_sweep
    entity = entity or env.get_entity()
    project = project or env.get_project()

    # Use the most specific path that the known parts allow.
    if entity and project:
        sweep_path = "{}/{}/{}".format(entity, project, sweep_id)
    elif project:
        sweep_path = "{}/{}".format(project, sweep_id)
    else:
        sweep_path = sweep_id

    wandb.termlog("Run sweep agent with: {}".format(
            click.style("wandb agent %s" % sweep_path, fg="yellow")))
    if controller:
        wandb.termlog('Starting wandb controller...')
        tuner = wandb_controller.controller(sweep_id)
        tuner.run(verbose=verbose)