示例#1
0
    def test_user_verify(
        self,
        mocked_input,
        invalid_strings,
    ):
        """Check utils.user_verify against scripted answers for each default.

        For both default=True and default=False this checks that:
          * an empty answer returns the default;
          * 'y' or 'Y', supplied after every entry of invalid_strings,
            returns True;
          * 'n' or 'N', supplied after every entry of invalid_strings,
            returns False.

        mocked_input is presumably a patched input() whose side_effect
        scripts successive answers (TODO confirm against the decorator).
        """
        for dflt in (True, False):
            # an empty answer should fall back to the default
            mocked_input.side_effect = ['']
            got = utils.user_verify('test default', default=dflt)
            self.assertEqual(got, dflt)

            # every invalid answer precedes the valid one in each script
            for answer in ('y', 'Y'):
                mocked_input.side_effect = invalid_strings + [answer]
                self.assertTrue(utils.user_verify('y input', default=dflt))

            for answer in ('n', 'N'):
                mocked_input.side_effect = invalid_strings + [answer]
                self.assertFalse(utils.user_verify('n input', default=dflt))
示例#2
0
文件: cli.py 项目: sagravat/caliban
def _check_for_existing_cluster(cluster_name: str, project_id: str,
                                creds: Credentials):
    """Decide whether creating a new cluster in this project should proceed.

    Args:
      cluster_name: name of the cluster the caller wants to create
      project_id: project id whose clusters are listed
      creds: credentials used to list the clusters

    Returns:
      True if cluster creation should proceed, False otherwise. Returns True
      outright when the project has no clusters, False when a cluster with
      the same name is already listed, and otherwise the user's answer to a
      confirmation prompt.
    """
    existing = Cluster.list(project_id=project_id, creds=creds)

    # nothing exists yet, so there is nothing to confirm
    if len(existing) == 0:
        return True

    # never create a second cluster with an already-used name
    if cluster_name in existing:
        logging.error(f'cluster {cluster_name} already exists')
        return False

    # other clusters exist: show them, then let the user decide
    logging.info(f'{len(existing)} clusters already exist for this project:')
    for name in existing:
        logging.info(name)

    prompt = 'Do you really want to create a new cluster?'
    return utils.user_verify(prompt, default=False)
示例#3
0
文件: cli.py 项目: sagravat/caliban
def stop(args: Dict[str, Any]) -> None:
    '''executes the `caliban stop` cli command

    Queries jobs whose stored status is SUBMITTED or RUNNING (optionally
    restricted to the experiment group named by args['xgroup']), refreshes
    each job's status, and lists those that are still active. Unless
    args['dry_run'] is set, it then asks the user to confirm and stops each
    listed job.

    Args:
      args: parsed commandline args; reads 'xgroup' and 'dry_run'

    Returns:
      None
    '''

    # NOTE(review): `user` is unused here, whereas `resubmit` scopes its
    # jobs to the current user -- confirm whether stop should do the same.
    user = current_user()
    xgroup = args.get('xgroup')
    dry_run = args.get('dry_run', False)

    active_statuses = (JobStatus.SUBMITTED, JobStatus.RUNNING)

    with session_scope(get_sql_engine()) as session:
        query = session.query(Job).join(Experiment).join(
            ExperimentGroup).filter(
                or_(Job.status == JobStatus.SUBMITTED,
                    Job.status == JobStatus.RUNNING))

        if xgroup is not None:
            query = query.filter(ExperimentGroup.name == xgroup)

        # this is necessary to filter out jobs that have finished but whose
        # status has not yet been updated in the backing store
        running_jobs = [
            j for j in query.all() if update_job_status(j) in active_statuses
        ]

        # check emptiness only *after* the refresh, so we never list an empty
        # set of jobs or prompt the user to stop 0 jobs
        if not running_jobs:
            logging.info('no running jobs found')
            return

        logging.info('the following jobs would be stopped:')
        for j in running_jobs:
            logging.info(_experiment_command_str(j.experiment))
            logging.info(f'    job {_job_str(j)}')

        if dry_run:
            logging.info(
                'to actually stop these jobs, re-run the command without '
                'the --dry_run flag')
            return

        # make sure; the second argument is presumably the default answer
        if not user_verify(
                f'do you wish to stop these {len(running_jobs)} jobs?', False):
            return

        for j in running_jobs:
            logging.info(f'stopping job: {_job_str(j)}')
            stop_job(j)

        logging.info(
            'requested job cancellation, please be patient as it may take '
            'a short while for this status change to be reflected in the '
            'gcp dashboard or from the `caliban status` command.')
示例#4
0
文件: cli.py 项目: sagravat/caliban
def _cluster_delete(args: dict, cluster: Cluster) -> None:
    """Delete the given cluster, but only after the user confirms.

    Args:
      args: commandline args (not consulted by this function)
      cluster: cluster to delete

    Returns:
      None
    """
    prompt = 'Are you sure you want to delete {}?'.format(cluster.name)

    # bail out unless the user explicitly answers yes (default is no)
    if not utils.user_verify(prompt, default=False):
        return

    cluster.delete()
示例#5
0
文件: cli.py 项目: sagravat/caliban
def resubmit(args: Dict[str, Any]) -> None:
    '''executes the `caliban resubmit` command

    Collects the jobs of experiment group args['xgroup'] to resubmit and
    lists them. Unless args['dry_run'] is set, and only after the user
    confirms, it rebuilds their containers and resubmits the updated job
    specs grouped by platform.

    Args:
      args: parsed commandline args; reads 'xgroup', 'dry_run', 'all_jobs',
        'project_id' and 'cloud_key'

    Returns:
      None
    '''

    user = current_user()
    xgroup = args.get('xgroup')
    dry_run = args.get('dry_run', False)
    all_jobs = args.get('all_jobs', False)
    project_id = args.get('project_id')
    creds_file = args.get('cloud_key')
    # NOTE(review): always True, so the `else` branch below is currently
    # unreachable; presumably a placeholder for a future option
    rebuild = True

    if xgroup is None:
        logging.error('you must specify an experiment group for this command')
        return

    with session_scope(get_sql_engine()) as session:
        jobs = _get_resubmit_jobs(
            session=session,
            xgroup=xgroup,
            user=user,
            all_jobs=all_jobs,
        )

        # a None result means there is nothing for this command to do
        if jobs is None:
            return

        # if we have CAIP or GKE jobs, then we need to have a project_id
        project_id = _get_resubmit_project_id(jobs, project_id, creds_file)

        # show what would be done
        logging.info('the following jobs would be resubmitted:')
        for j in jobs:
            logging.info(_experiment_command_str(j.experiment))
            logging.info(f'  job {_job_str(j)}')

        if dry_run:
            logging.info(
                'to actually resubmit these jobs, run this command again '
                'without the --dry_run flag')
            return

        # make sure; the second argument is presumably the default answer
        if not user_verify(f'do you wish to resubmit these {len(jobs)} jobs?',
                           False):
            return

        # rebuild all containers first
        if rebuild:
            logging.info('rebuilding containers...')
            image_id_map = _rebuild_containers(jobs, project_id=project_id)
        else:
            image_id_map = {j: j.container for j in jobs}

        # create new job specs that point at the (re)built images
        job_specs = [
            replace_job_spec_image(spec=j.spec, image_id=image_id_map[j])
            for j in jobs
        ]

        # submit jobs, grouped by platform; a failure on one platform is
        # logged and does not stop submission to the remaining platforms
        for platform in [Platform.CAIP, Platform.GKE, Platform.LOCAL]:
            pspecs = [s for s in job_specs if s.platform == platform]
            if not pspecs:
                # nothing to submit for this platform
                continue
            try:
                submit_job_specs(
                    specs=pspecs,
                    platform=platform,
                    project_id=project_id,
                    credentials_path=creds_file,
                )
            except Exception:
                # log first (with traceback) so the cause is recorded even if
                # the commit below fails; previously it was discarded
                logging.exception('there was an error submitting some jobs')
                session.commit()  # avoid rollback