Beispiel #1
0
def init(ctx):
    """Interactively configure the current directory for W&B.

    Asks for confirmation when the wandb directory already holds a
    ``settings`` file, makes sure the user is logged in, asks which entity
    (team or username) and project to use, writes the settings together with
    a ``.gitignore`` and finally prints the next steps.

    Args:
        ctx: click context; used to invoke the ``login`` command and handed
            on to ``prompt_for_project``.

    Raises:
        ClickException: when ``prompt_for_project`` fails with
            ``wandb.cli.ClickWandbException`` for the chosen entity.
        SystemExit: (exit code 1) when no viewer can be loaded even after
            logging in again.
    """
    from wandb import _set_stage_dir, __stage_dir__, wandb_dir

    if __stage_dir__ is None:
        _set_stage_dir('wandb')

    # The directory only counts as configured when it holds a settings file.
    previously_configured = os.path.isdir(wandb_dir())
    if previously_configured:
        previously_configured = os.path.exists(
            os.path.join(wandb_dir(), "settings"))

    if previously_configured:
        reconfigure_question = click.style(
            "This directory has been configured previously, should we re-configure it?",
            bold=True)
        click.confirm(reconfigure_question, abort=True)
    else:
        welcome = click.style("Let's setup this directory for W&B!",
                              fg="green",
                              bold=True)
        click.echo(welcome)

    if api.api_key is None:
        ctx.invoke(login)

    # Viewer can be `None` in case the API information became invalid, or
    # in testing when switching hosts.
    viewer = api.viewer()
    if not viewer:
        invalid_login = click.style(
            "Your login information seems to be invalid: can you log in again please?",
            fg="red",
            bold=True)
        click.echo(invalid_login)
        ctx.invoke(login)

    # Fetch the viewer again; a missing viewer at this point is unexpected.
    viewer = api.viewer()
    if not viewer:
        login_failed = click.style(
            "We're sorry, there was a problem logging you in. Please send us a note at [email protected] and tell us how this happened.",
            fg="red",
            bold=True)
        click.echo(login_failed)
        sys.exit(1)

    # At this point we should be logged in successfully.
    manual_entry = "Manual Entry"
    team_edges = viewer["teams"]["edges"]
    if len(team_edges) > 1:
        choices = [edge["node"]["name"] for edge in team_edges]
        choices.append(manual_entry)
        answers = whaaaaat.prompt([{
            'type': 'list',
            'name': 'team_name',
            'message': "Which team should we use?",
            'choices': choices
        }])
        # The answers can be empty on click; treat that like a manual entry.
        entity = answers['team_name'] if answers else manual_entry
        if entity == manual_entry:
            entity = click.prompt("Enter the name of the team you want to use")
    else:
        fallback_entity = viewer.get('entity', 'models')
        entity = click.prompt("What username or team should we use?",
                              default=fallback_entity)

    # TODO: this error handling sucks and the output isn't pretty
    try:
        project = prompt_for_project(ctx, entity)
    except wandb.cli.ClickWandbException:
        raise ClickException('Could not find team: %s' % entity)

    util.write_settings(entity, project, api.settings())

    # Ignore everything inside the wandb directory except the settings file.
    gitignore_path = os.path.join(wandb_dir(), '.gitignore')
    with open(gitignore_path, "w") as gitignore:
        gitignore.write("*\n!settings")

    headline = click.style(
        "This directory is configured!  Next, track a run:\n", fg="green")
    next_steps = textwrap.dedent("""\
        * In your training script:
            {code1}
            {code2}
        * then `{run}`.
        """).format(
        code1=click.style("import wandb", bold=True),
        code2=click.style("wandb.init()", bold=True),
        run=click.style("python <train.py>", bold=True))
    # Kept for reference until push/pull are re-enabled:
    # """
    # * Run `{push}` to manually add a file.
    # * Pull popular models into your project with: `{pull}`.
    # """
    # push=click.style("wandb push run_id weights.h5", bold=True),
    # pull=click.style("wandb pull models/inception-v4", bold=True)
    click.echo(headline + next_steps)
Beispiel #2
0
def start_experiment(output_dir,
                     cometml_project="",
                     wandb_project="",
                     run_id=None,
                     note_params="",
                     extra_kwargs=None,
                     overwrite=False):
    """Start a model training experiment.

    Creates the run-specific output directory and sets up the experiment
    management handles (comet.ml and/or wandb). As a side effect the current
    working directory is appended to ``sys.path`` and file logging is added
    to ``logger``.

    Args:
      output_dir: parent directory of the run. A subdirectory named
        `run_id` is created inside it.
      cometml_project: comet.ml `<workspace>/<project>` name. If empty,
        comet.ml is not used.
      wandb_project: wandb `<entity>/<project>` name. If empty, wandb is
        not used.
      run_id: manual run id. If not specified, it is taken from the wandb
        run directory name, otherwise from the comet.ml experiment id,
        otherwise it is a random uuid4.
      note_params: additional key=value pairs to take note of; parsed with
        `kv_string2dict` and written to `note_params.json`.
      extra_kwargs: currently unused.
      overwrite: if True, an existing run directory containing
        `config.gin` is removed first; otherwise a ValueError is raised.

    Returns:
      Tuple `(cometml_experiment, wandb_run, output_dir)`, where the first
      two entries are None when the respective service is not used and
      `output_dir` is the run-specific directory.

    Raises:
      ImportError: if `cometml_project` is given but comet.ml could not be
        imported.
      AssertionError: if `wandb_project` does not contain "/".
      ValueError: if `config.gin` already exists in the run directory and
        `overwrite` is False.
    """
    sys.path.append(os.getcwd())
    if cometml_project:
        logger.info("Using comet.ml")
        if Experiment is None:
            raise ImportError("Comet.ml could not be imported")
        workspace, project_name = cometml_project.split("/")
        cometml_experiment = Experiment(project_name=project_name,
                                        workspace=workspace)
        # TODO - get the experiment id
        # specify output_dir to that directory
    else:
        cometml_experiment = None

    if wandb_project:
        assert "/" in wandb_project
        entity, project = wandb_project.split("/")
        if wandb is None:
            logger.warning("wandb not installed. Not using it")
            wandb_run = None
        else:
            logger.info("Using wandb. Running wandb.init()")
            wandb._set_stage_dir("./")  # Don't prepend wandb to output file
            if run_id is not None:
                wandb.init(project=project,
                           dir=output_dir,
                           entity=entity,
                           reinit=True,
                           resume=run_id)
            else:
                # automatically set the output
                wandb.init(project=project,
                           entity=entity,
                           reinit=True,
                           dir=output_dir)
            wandb_run = wandb.run
            if wandb_run is None:
                logger.warning("Wandb run is None")
            print(wandb_run)
    else:
        wandb_run = None

    # update the output directory
    if run_id is None:
        if wandb_run is not None:
            run_id = os.path.basename(wandb_run.dir)
        elif cometml_experiment is not None:
            run_id = cometml_experiment.id
        else:
            # random run_id
            run_id = str(uuid4())
    output_dir = os.path.join(output_dir, run_id)

    if wandb_run is not None:
        # make sure the output directory is the same
        # wandb_run._dir = os.path.normpath(output_dir)  # This doesn't work
        # assert os.path.normpath(wandb_run.dir) == os.path.normpath(output_dir)
        # TODO - fix this assertion-> the output directories should be the same
        # in order for snakemake to work correctly
        pass
    # -----------------------------

    if os.path.exists(os.path.join(output_dir, 'config.gin')):
        if overwrite:
            # Both literals need the f-prefix, otherwise {output_dir} is
            # logged verbatim instead of being interpolated.
            logger.info(
                f"config.gin already exists in the output "
                f"directory {output_dir}. Removing the whole directory.")
            shutil.rmtree(output_dir)
        else:
            raise ValueError(f"Output directory {output_dir} shouldn't exist!")
    os.makedirs(output_dir,
                exist_ok=True)  # make the output directory. It shouldn't exist

    # add logging to the file
    add_file_logging(output_dir, logger)

    # write note_params.json
    if note_params:
        logger.info(f"note_params: {note_params}")
        note_params_dict = kv_string2dict(note_params)
    else:
        note_params_dict = dict()
    write_json(note_params_dict,
               os.path.join(output_dir, "note_params.json"),
               sort_keys=True,
               indent=2)

    if cometml_experiment is not None:
        cometml_experiment.log_parameters(note_params_dict)
        cometml_experiment.log_parameters(dict(output_dir=output_dir),
                                          prefix='cli/')

        exp_url = f"https://www.comet.ml/{cometml_experiment.workspace}/{cometml_experiment.project_name}/{cometml_experiment.id}"
        logger.info("Comet.ml url: " + exp_url)
        # write the information about comet.ml experiment
        write_json(
            {
                "url": exp_url,
                "key": cometml_experiment.id,
                "project": cometml_experiment.project_name,
                "workspace": cometml_experiment.workspace
            },
            os.path.join(output_dir, "cometml.json"),
            sort_keys=True,
            indent=2)

    if wandb_run is not None:
        wandb_run.config.update(note_params_dict)
        # write the information about the wandb run
        write_json(
            {
                "url": wandb_run.get_url(),
                "key": wandb_run.id,
                "project": wandb_run.project,
                "path": wandb_run.path,
                "group": wandb_run.group
            },
            os.path.join(output_dir, "wandb.json"),
            sort_keys=True,
            indent=2)
        wandb_run.config.update(
            dict_prefix_key(dict(output_dir=output_dir), prefix='cli/'))

    return cometml_experiment, wandb_run, output_dir
Beispiel #3
0
def init(ctx):
    """Interactively configure the current directory for W&B.

    Asks for confirmation when the wandb directory already exists, logs the
    user in if no API key is set, asks which entity (team or username) and
    project to use, then writes the ``settings`` and ``.gitignore`` files
    into the wandb directory and prints the next steps.

    Side effects: may rebind the module-level ``api`` after a login and sets
    the module-level ``IS_INIT`` flag to True.

    Args:
        ctx: click context; used to invoke the ``login`` command and handed
            on to ``prompt_for_project``.

    Raises:
        ClickException: when no viewer information can be loaded, or when
            ``prompt_for_project`` fails with
            ``wandb.cli.ClickWandbException`` for the chosen entity.
    """
    global api, IS_INIT

    from wandb import _set_stage_dir, __stage_dir__, wandb_dir
    if __stage_dir__ is None:
        _set_stage_dir('wandb')
    if os.path.isdir(wandb_dir()):
        click.confirm(click.style(
            "This directory has been configured previously, should we re-configure it?",
            bold=True),
                      abort=True)
    else:
        click.echo(
            click.style("Let's setup this directory for W&B!",
                        fg="green",
                        bold=True))

    if api.api_key is None:
        ctx.invoke(login)
        # Re-create the client after logging in.
        api = Api()

    IS_INIT = True

    viewer = api.viewer()
    if not viewer:
        # Without this guard an invalid login surfaces as an opaque
        # TypeError/KeyError from the subscript below.
        raise ClickException(
            'Could not load your user information; your login may be '
            'invalid. Please log in again.')

    if len(viewer["teams"]["edges"]) > 1:
        team_names = [e["node"]["name"] for e in viewer["teams"]["edges"]]
        question = {
            'type': 'list',
            'name': 'team_name',
            'message': "Which team should we use?",
            'choices': team_names + ["Manual Entry"]
        }
        answers = whaaaaat.prompt([question])
        # The prompt can return an empty result; fall back to manual entry
        # instead of failing with a KeyError.
        if answers:
            entity = answers['team_name']
        else:
            entity = "Manual Entry"
        if entity == "Manual Entry":
            entity = click.prompt("Enter the name of the team you want to use")
    else:
        entity = click.prompt("What username or team should we use?",
                              default=viewer.get('entity', 'models'))

    # TODO: this error handling sucks and the output isn't pretty
    try:
        project = prompt_for_project(ctx, entity)
    except wandb.cli.ClickWandbException:
        raise ClickException('Could not find team: %s' % entity)

    if not os.path.isdir(wandb_dir()):
        os.mkdir(wandb_dir())

    # Handles are not named `file`, to avoid shadowing the Python 2 builtin.
    with open(os.path.join(wandb_dir(), 'settings'), "w") as settings_file:
        print('[default]', file=settings_file)
        print('entity: {}'.format(entity), file=settings_file)
        print('project: {}'.format(project), file=settings_file)
        print('base_url: {}'.format(api.settings()['base_url']),
              file=settings_file)

    # Ignore everything inside the wandb directory except the settings file.
    with open(os.path.join(wandb_dir(), '.gitignore'), "w") as gitignore_file:
        gitignore_file.write("*\n!settings")

    click.echo(
        click.style("This directory is configured!  Next, track a run:\n",
                    fg="green") +
        textwrap.dedent("""\
        * In your training script:
            {code1}
            {code2}
        * then `{run}`.
        """).format(
            code1=click.style("import wandb", bold=True),
            code2=click.style("wandb.init()", bold=True),
            run=click.style("wandb run <train.py>", bold=True),
            # saving this here so I can easily put it back when we re-enable
            # push/pull
            # """
            # * Run `{push}` to manually add a file.
            # * Pull popular models into your project with: `{pull}`.
            # """
            # push=click.style("wandb push run_id weights.h5", bold=True),
            # pull=click.style("wandb pull models/inception-v4", bold=True)
        ))
Beispiel #4
0
def gin_train(gin_files,
              output_dir,
              gin_bindings='',
              gpu=0,
              memfrac=0.45,
              framework='tf',
              cometml_project="",
              wandb_project="",
              remote_dir="",
              run_id=None,
              note_params="",
              force_overwrite=False):
    """Train a model using gin-config

    Sets up the experiment tracking handles and the run-specific output
    directory, parses the gin configuration and then calls `train`.

    Args:
      gin_files: comma separated list of gin files
      output_dir: where to store the results. Note: a subdirectory `run_id`
        will be created in `output_dir`.
      gin_bindings: semicolon separated list of additional gin-bindings to use
      gpu: which gpu to use. Example: gpu=1. If None, no tf session is
        created.
      memfrac: what fraction of the GPU's memory to use
      framework: which framework to use. Available: tf
      cometml_project: comet_ml project name. Example: Avsecz/basepair.
        If not specified, cometml will not get used
      wandb_project: wandb `<entity>/<project>` name. Example: Avsecz/test.
        If not specified, wandb will not be used
      remote_dir: additional path to the remote directory. Can be an s3 path.
        Example: `s3://mybucket/model1/exp1`
      run_id: manual run id. If not specified, it will be either randomly
        generated or re-used from wandb or comet.ml.
      note_params: take note of additional key=value pairs.
        Example: --note-params note='my custom note',feature_set=this
      force_overwrite: if True, the output directory will be overwritten

    Returns:
      Whatever `train` returns for the prepared output directory and
      experiment handles.

    Raises:
      AssertionError: if `wandb_project` does not contain "/".
      ValueError: if `config.gin` already exists in the run directory and
        `force_overwrite` is False.
    """
    # Import the base package unconditionally: the conditional
    # `import gin.tf` below makes `gin` a function-local name, so without
    # this line any framework other than 'tf' would hit UnboundLocalError
    # at `gin.parse_config_files_and_bindings`.
    import gin

    sys.path.append(os.getcwd())
    if cometml_project:
        logger.info("Using comet.ml")
        workspace, project_name = cometml_project.split("/")
        cometml_experiment = Experiment(project_name=project_name,
                                        workspace=workspace)
        # TODO - get the experiment id
        # specify output_dir to that directory
    else:
        cometml_experiment = None

    if wandb_project:
        assert "/" in wandb_project
        entity, project = wandb_project.split("/")
        if wandb is None:
            logger.warning("wandb not installed. Not using it")
            wandb_run = None
        else:
            wandb._set_stage_dir("./")  # Don't prepend wandb to output file
            if run_id is not None:
                wandb.init(project=project,
                           dir=output_dir,
                           entity=entity,
                           resume=run_id)
            else:
                # automatically set the output
                wandb.init(project=project, entity=entity, dir=output_dir)
            wandb_run = wandb.run
            logger.info("Using wandb")
            print(wandb_run)
    else:
        wandb_run = None

    # update the output directory
    if run_id is None:
        if wandb_run is not None:
            run_id = os.path.basename(wandb_run.dir)
        elif cometml_experiment is not None:
            run_id = cometml_experiment.id
        else:
            # random run_id
            run_id = str(uuid4())
    output_dir = os.path.join(output_dir, run_id)
    if remote_dir:
        remote_dir = os.path.join(remote_dir, run_id)
    if wandb_run is not None:
        # make sure the output directory is the same
        # wandb_run._dir = os.path.normpath(output_dir)  # This doesn't work
        # assert os.path.normpath(wandb_run.dir) == os.path.normpath(output_dir)
        # TODO - fix this assertion-> the output directories should be the same
        # in order for snakemake to work correctly
        pass
    # -----------------------------

    if os.path.exists(os.path.join(output_dir, 'config.gin')):
        if force_overwrite:
            # Both literals need the f-prefix, otherwise {output_dir} is
            # logged verbatim instead of being interpolated.
            logger.info(
                f"config.gin already exists in the output "
                f"directory {output_dir}. Removing the whole directory.")
            import shutil
            shutil.rmtree(output_dir)
        else:
            raise ValueError(f"Output directory {output_dir} shouldn't exist!")
    os.makedirs(output_dir,
                exist_ok=True)  # make the output directory. It shouldn't exist

    # add logging to the file
    add_file_logging(output_dir, logger)

    if framework == 'tf':
        import gin.tf
        if gpu is not None:
            logger.info(f"Using gpu: {gpu}, memory fraction: {memfrac}")
            create_tf_session(gpu, per_process_gpu_memory_fraction=memfrac)

    gin.parse_config_files_and_bindings(gin_files.split(","),
                                        bindings=gin_bindings.split(";"),
                                        skip_unknown=False)

    # write note_params.json
    if note_params:
        logger.info(f"note_params: {note_params}")
        note_params_dict = kv_string2dict(note_params)
    else:
        note_params_dict = dict()
    write_json(note_params_dict,
               os.path.join(output_dir, "note_params.json"),
               sort_keys=True,
               indent=2)

    # comet - log environment
    if cometml_experiment is not None:
        # log other parameters
        cometml_experiment.log_multiple_params(dict(gin_files=gin_files,
                                                    gin_bindings=gin_bindings,
                                                    output_dir=output_dir,
                                                    gpu=gpu),
                                               prefix='cli/')
        cometml_experiment.log_multiple_params(note_params_dict)

        exp_url = f"https://www.comet.ml/{cometml_experiment.workspace}/{cometml_experiment.project_name}/{cometml_experiment.id}"
        logger.info("Comet.ml url: " + exp_url)
        # write the information about comet.ml experiment
        write_json(
            {
                "url": exp_url,
                "key": cometml_experiment.id,
                "project": cometml_experiment.project_name,
                "workspace": cometml_experiment.workspace
            },
            os.path.join(output_dir, "cometml.json"),
            sort_keys=True,
            indent=2)

    # wandb - log environment
    if wandb_run is not None:
        write_json(
            {
                "url": wandb_run.get_url(),
                "key": wandb_run.id,
                "project": wandb_run.project,
                "path": wandb_run.path,
                "group": wandb_run.group
            },
            os.path.join(output_dir, "wandb.json"),
            sort_keys=True,
            indent=2)
        # store general configs
        wandb_run.config.update(
            prefix_dict(dict(gin_files=gin_files,
                             gin_bindings=gin_bindings,
                             output_dir=output_dir,
                             gpu=gpu),
                        prefix='cli/'))
        wandb_run.config.update(note_params_dict)

    if remote_dir:
        logger.info("Test file upload to: {}".format(remote_dir))
        upload_dir(output_dir, remote_dir)
    return train(output_dir=output_dir,
                 remote_dir=remote_dir,
                 cometml_experiment=cometml_experiment,
                 wandb_run=wandb_run)