Example 1
def delete_instances(compute, node_list, arg_job_id):

    batch_list = []
    curr_batch = 0
    req_cnt = 0
    batch_list.insert(
        curr_batch,
        compute.new_batch_http_request(callback=delete_instances_cb))

    for node_name in node_list:

        pid = util.get_pid(node_name)
        if (not arg_job_id and cfg.instance_defs[pid].exclusive):
            # Node was deleted by EpilogSlurmctld, skip for SuspendProgram
            continue

        if req_cnt >= TOT_REQ_CNT:
            req_cnt = 0
            curr_batch += 1
            batch_list.insert(
                curr_batch,
                compute.new_batch_http_request(callback=delete_instances_cb))

        zone = None
        if cfg.instance_defs[pid].regional_capacity:
            node_find = util.ensure_execute(
                compute.instances().aggregatedList(
                    project=cfg.project, filter=f'name={node_name}'))
            for key, zone_value in node_find['items'].items():
                if 'instances' in zone_value:
                    zone = zone_value['instances'][0]['zone'].split('/')[-1]
                    break
            if zone is None:
                log.error(f"failed to find regional node '{node_name}' to delete")
                continue
        else:
            zone = cfg.instance_defs[pid].zone

        batch_list[curr_batch].add(
            compute.instances().delete(project=cfg.project,
                                       zone=zone,
                                       instance=node_name),
            request_id=node_name)
        req_cnt += 1

    try:
        for i, batch in enumerate(batch_list):
            util.ensure_execute(batch)
            if i < (len(batch_list) - 1):
                time.sleep(30)
    except Exception:
        log.exception("error in batch:")
Example 2
def delete_instances(compute, node_list, arg_job_id):

    batch_list = []
    curr_batch = 0
    req_cnt = 0
    batch_list.insert(
        curr_batch,
        compute.new_batch_http_request(callback=delete_instances_cb))

    def_list = {
        pid: cfg.instance_defs[pid]
        for pid, nodes in groupby(node_list, util.get_pid)
    }
    regional_instances = util.get_regional_instances(compute, cfg.project,
                                                     def_list)

    for node_name in node_list:

        pid = util.get_pid(node_name)
        if (not arg_job_id and cfg.instance_defs[pid].exclusive):
            # Node was deleted by EpilogSlurmctld, skip for SuspendProgram
            continue

        zone = None
        if cfg.instance_defs[pid].regional_capacity:
            instance = regional_instances.get(node_name, None)
            if instance is None:
                log.debug("Regional node not found. Already deleted?")
                continue
            zone = instance['zone'].split('/')[-1]
        else:
            zone = cfg.instance_defs[pid].zone

        if req_cnt >= TOT_REQ_CNT:
            req_cnt = 0
            curr_batch += 1
            batch_list.insert(
                curr_batch,
                compute.new_batch_http_request(callback=delete_instances_cb))

        batch_list[curr_batch].add(
            compute.instances().delete(project=cfg.project,
                                       zone=zone,
                                       instance=node_name),
            request_id=node_name)
        req_cnt += 1

    try:
        for i, batch in enumerate(batch_list):
            util.ensure_execute(batch)
            if i < (len(batch_list) - 1):
                time.sleep(30)
    except Exception:
        log.exception("error in batch:")
Example 3
def start_instances(compute, node_list, gcp_nodes):

    req_cnt = 0
    curr_batch = 0
    batch_list = []
    batch_list.insert(
        curr_batch,
        compute.new_batch_http_request(callback=start_instances_cb))

    for node in node_list:

        pid = util.get_pid(node)
        zone = cfg.instance_defs[pid].zone

        if cfg.instance_defs[pid].regional_capacity:
            g_node = gcp_nodes.get(node, None)
            if not g_node:
                log.error(f"Didn't find regional GCP record for '{node}'")
                continue
            zone = g_node['zone'].split('/')[-1]

        if req_cnt >= TOT_REQ_CNT:
            req_cnt = 0
            curr_batch += 1
            batch_list.insert(
                curr_batch,
                compute.new_batch_http_request(callback=start_instances_cb))

        batch_list[curr_batch].add(
            compute.instances().start(project=cfg.project, zone=zone,
                                      instance=node),
            request_id=node)
        req_cnt += 1
    try:
        for i, batch in enumerate(batch_list):
            util.ensure_execute(batch)
            if i < (len(batch_list) - 1):
                time.sleep(30)
    except Exception:
        log.exception("error in start batch: ")
Example 4
def create_placement_groups(arg_job_id, vm_count, region):
    log.debug(f"Creating PG: {arg_job_id} vm_count:{vm_count} region:{region}")

    pg_names = []
    pg_ops = []
    pg_index = 0

    auth_http = None
    if not cfg.google_app_cred_path:
        http = set_user_agent(httplib2.Http(),
                              "Slurm_GCP_Scripts/1.2 (GPN:SchedMD)")
        creds = compute_engine.Credentials()
        auth_http = google_auth_httplib2.AuthorizedHttp(creds, http=http)
    compute = googleapiclient.discovery.build('compute',
                                              'v1',
                                              http=auth_http,
                                              cache_discovery=False)

    for i in range(vm_count):
        if i % PLACEMENT_MAX_CNT:
            continue
        pg_index += 1
        pg_name = f'{cfg.cluster_name}-{arg_job_id}-{pg_index}'
        pg_names.append(pg_name)

        config = {
            'name': pg_name,
            'region': region,
            'groupPlacementPolicy': {
                "collocation": "COLLOCATED",
                "vmCount": min(vm_count - i, PLACEMENT_MAX_CNT)
            }
        }

        pg_ops.append(
            util.ensure_execute(compute.resourcePolicies().insert(
                project=cfg.project, region=region, body=config)))

    for operation in pg_ops:
        result = util.wait_for_operation(compute, cfg.project, operation)
        if result and 'error' in result:
            err_msg = result['error']['errors'][0]['message']
            log.error(f" placement group operation failed: {err_msg}")
            os._exit(1)

    return pg_names
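
The loop above creates one placement group per PLACEMENT_MAX_CNT VMs, with the last group holding the remainder. A sketch of just that size calculation, using hypothetical values:

PLACEMENT_MAX_CNT = 10  # illustrative value

def placement_group_sizes(vm_count, max_cnt=PLACEMENT_MAX_CNT):
    return [min(vm_count - i, max_cnt) for i in range(0, vm_count, max_cnt)]

print(placement_group_sizes(22))  # [10, 10, 2] -> three groups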
Example 5
def delete_placement_groups(job_id, region, partition_name):
    def delete_placement_request(pg_name):
        return compute.resourcePolicies().delete(project=cfg.project,
                                                 region=region,
                                                 resourcePolicy=pg_name)

    flt = f"name={cfg.slurm_cluster_name}-{partition_name}-{job_id}-*"
    req = compute.resourcePolicies().list(project=cfg.project,
                                          region=region,
                                          filter=flt)
    result = ensure_execute(req).get("items")
    if not result:
        log.debug(f"No placement groups found to delete for job id {job_id}")
        return
    requests = {
        pg["name"]: delete_placement_request(pg["name"])
        for pg in result
    }
    done, failed = batch_execute(requests)
    if failed:
        failed_pg = [f"{n}: {e}" for n, (_, e) in failed.items()]
        log.error(f"some nodes failed to delete: {failed_pg}")
Example 6
def main():
    compute = googleapiclient.discovery.build('compute', 'v1',
                                              cache_discovery=False)

    try:
        s_nodes = dict()
        cmd = (f"{SCONTROL} show nodes | "
               r"grep -oP '^NodeName=\K(\S+)|State=\K(\S+)' | "
               "paste -sd',\n'")
        nodes = util.run(cmd, shell=True, check=True, get_stdout=True).stdout
        if nodes:
            # result is a dict like:
            # {nodename: StateTuple(base='base_state', flags=<set of state flags>)}
            # from 'nodename,base_state+flag1+flag2'
            # state flags include: CLOUD, COMPLETING, DRAIN, FAIL, POWERED_DOWN,
            #   POWERING_DOWN
            # Modifiers on base state still include: @ (reboot), $ (maint),
            #   * (nonresponsive), # (powering up)
            StateTuple = collections.namedtuple('StateTuple', 'base,flags')

            def make_state_tuple(state):
                return StateTuple(state[0], set(state[1:]))
            s_nodes = {node: make_state_tuple(args.split('+'))
                       for node, args in
                       map(lambda x: x.split(','), nodes.rstrip().splitlines())
                       if 'CLOUD' in args}

        g_nodes = util.get_regional_instances(compute, cfg.project,
                                              cfg.instance_defs)
        for pid, part in cfg.instance_defs.items():
            page_token = ""
            while True:
                if not part.regional_capacity:
                    resp = util.ensure_execute(
                        compute.instances().list(
                            project=cfg.project, zone=part.zone,
                            fields='items(name,zone,status),nextPageToken',
                            pageToken=page_token, filter=f"name={pid}-*"))

                    if "items" in resp:
                        g_nodes.update({instance['name']: instance
                                       for instance in resp['items']})
                    if "nextPageToken" in resp:
                        page_token = resp['nextPageToken']
                        continue

                break

        to_down = []
        to_idle = []
        to_start = []
        for s_node, s_state in s_nodes.items():
            g_node = g_nodes.get(s_node, None)
            pid = util.get_pid(s_node)

            if (('POWERED_DOWN' not in s_state.flags) and
                    ('POWERING_DOWN' not in s_state.flags)):
                # slurm nodes that aren't powered down and are stopped in GCP:
                #   mark down in slurm
                #   start them in gcp
                if g_node and (g_node['status'] == "TERMINATED"):
                    if not s_state.base.startswith('DOWN'):
                        to_down.append(s_node)
                    if cfg.instance_defs[pid].preemptible_bursting != 'false':
                        to_start.append(s_node)

                # can't check if the node doesn't exist in GCP while the node
                # is booting because it might not have been created yet by the
                # resume script.
                # This should catch the completing states as well.
                if (g_node is None and "POWERING_UP" not in s_state.flags and
                        not s_state.base.startswith('DOWN')):
                    to_down.append(s_node)

            elif g_node is None:
                # find nodes that are down~ in slurm and don't exist in gcp:
                #   mark idle~
                if s_state.base.startswith('DOWN') and 'POWERED_DOWN' in s_state.flags:
                    to_idle.append(s_node)
                elif 'POWERING_DOWN' in s_state.flags:
                    to_idle.append(s_node)
                elif s_state.base.startswith('COMPLETING'):
                    to_down.append(s_node)

        if len(to_down):
            log.info("{} stopped/deleted instances ({})".format(
                len(to_down), ",".join(to_down)))
            log.info("{} instances to start ({})".format(
                len(to_start), ",".join(to_start)))
            hostlist = to_hostlist(to_down)

            util.run(f"{SCONTROL} update nodename={hostlist} state=down "
                     "reason='Instance stopped/deleted'")

            while True:
                start_instances(compute, to_start, g_nodes)
                if not len(retry_list):
                    break

                log.debug("got {} nodes to retry ({})"
                          .format(len(retry_list), ','.join(retry_list)))
                to_start = list(retry_list)
                del retry_list[:]

        if len(to_idle):
            log.info("{} instances to resume ({})".format(
                len(to_idle), ','.join(to_idle)))

            hostlist = to_hostlist(to_idle)
            util.run(f"{SCONTROL} update nodename={hostlist} state=resume")

        orphans = [
            inst for inst, info in g_nodes.items()
            if info['status'] == 'RUNNING' and (
                inst not in s_nodes or 'POWERED_DOWN' in s_nodes[inst].flags
            )
        ]
        if orphans:
            if args.debug:
                for orphan in orphans:
                    info = g_nodes.get(orphan)
                    state = s_nodes.get(orphan, None)
                    log.debug(f"orphan {orphan}: status={info['status']} state={state}")
            hostlist = to_hostlist(orphans)
            log.info(f"{len(orphans)} orphan instances found to terminate: {hostlist}")
            util.run(f"{SCRIPTS_DIR}/suspend.py {hostlist}")

    except Exception:
        log.exception("failed to sync instances")
Example 7
def create_instance(compute, instance_def, node_list, placement_group_name):

    # Configure the machine

    meta_files = {
        'config': SCRIPTS_DIR/'config.yaml',
        'util-script': SCRIPTS_DIR/'util.py',
        'startup-script': SCRIPTS_DIR/'startup.sh',
        'setup-script': SCRIPTS_DIR/'setup.py',
    }
    custom_compute = SCRIPTS_DIR/'custom-compute-install'
    if custom_compute.exists():
        meta_files['custom-compute-install'] = str(custom_compute)

    metadata = {
        'enable-oslogin': '******',
        'VmDnsSetting': 'GlobalOnly',
        'instance_type': 'compute',
    }
    if not instance_def.image_hyperthreads:
        metadata['google_mpi_tuning'] = '--nosmt'

    config = {
        'name': 'notused',

        # Specify a network interface
        'networkInterfaces': [{
            'subnetwork': (
                "projects/{}/regions/{}/subnetworks/{}".format(
                    cfg.shared_vpc_host_project or cfg.project,
                    instance_def.region,
                    (instance_def.vpc_subnet
                     or f'{cfg.cluster_name}-{instance_def.region}'))
            ),
        }],

        'tags': {'items': ['compute']},

        'metadata': {
            'items': [
                *[{'key': k, 'value': v} for k, v in metadata.items()],
                *[{'key': k, 'value': Path(v).read_text()} for k, v in meta_files.items()]
            ]
        }
    }

    if instance_def.machine_type:
        config['machineType'] = instance_def.machine_type

    if (instance_def.image and
            instance_def.compute_disk_type and
            instance_def.compute_disk_size_gb):
        config['disks'] = [{
            'boot': True,
            'autoDelete': True,
            'initializeParams': {
                'sourceImage': instance_def.image,
                'diskType': instance_def.compute_disk_type,
                'diskSizeGb': instance_def.compute_disk_size_gb
            }
        }]

    if cfg.compute_node_service_account and cfg.compute_node_scopes:
        # Allow the instance to access cloud storage and logging.
        config['serviceAccounts'] = [{
            'email': cfg.compute_node_service_account,
            'scopes': cfg.compute_node_scopes
        }]

    if placement_group_name is not None:
        config['scheduling'] = {
            'onHostMaintenance': 'TERMINATE',
            'automaticRestart': False
        }
        config['resourcePolicies'] = [placement_group_name]

    if instance_def.gpu_count:
        config['guestAccelerators'] = [{
            'acceleratorCount': instance_def.gpu_count,
            'acceleratorType': instance_def.gpu_type
        }]
        config['scheduling'] = {'onHostMaintenance': 'TERMINATE'}

    if instance_def.preemptible_bursting:
        config['scheduling'] = {
            'preemptible': True,
            'onHostMaintenance': 'TERMINATE',
            'automaticRestart': False
        }

    if instance_def.compute_labels:
        config['labels'] = instance_def.compute_labels

    if instance_def.cpu_platform:
        config['minCpuPlatform'] = instance_def.cpu_platform

    if cfg.external_compute_ips:
        config['networkInterfaces'][0]['accessConfigs'] = [
            {'type': 'ONE_TO_ONE_NAT', 'name': 'External NAT'}
        ]

    perInstanceProperties = {k: {} for k in node_list}
    body = {
        'count': len(node_list),
        'instanceProperties': config,
        'perInstanceProperties': perInstanceProperties,
    }

    if instance_def.instance_template:
        body['sourceInstanceTemplate'] = (
            "projects/{}/global/instanceTemplates/{}".format(
                cfg.project, instance_def.instance_template)
        )

    # For non-exclusive requests, create as many instances as possible as the
    # nodelist isn't tied to a specific set of instances.
    if not instance_def.exclusive:
        body['minCount'] = 1

    if instance_def.regional_capacity:
        if instance_def.regional_policy:
            body['locationPolicy'] = instance_def.regional_policy
        op = compute.regionInstances().bulkInsert(
            project=cfg.project, region=instance_def.region,
            body=body)
        return op.execute()

    return util.ensure_execute(compute.instances().bulkInsert(
        project=cfg.project, zone=instance_def.zone, body=body))
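
The metadata block above has to be flattened into the list-of-items form that the Compute Engine API expects, merging plain string values with file contents. A minimal sketch (the keys and the empty meta_files dict are illustrative):

from pathlib import Path

metadata = {'VmDnsSetting': 'GlobalOnly', 'instance_type': 'compute'}
meta_files = {}  # e.g. {'startup-script': Path('startup.sh')}

items = [
    *[{'key': k, 'value': v} for k, v in metadata.items()],
    *[{'key': k, 'value': Path(v).read_text()} for k, v in meta_files.items()],
]
print(items)
# [{'key': 'VmDnsSetting', 'value': 'GlobalOnly'},
#  {'key': 'instance_type', 'value': 'compute'}]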
Example 8
def main():
    compute = googleapiclient.discovery.build('compute',
                                              'v1',
                                              cache_discovery=False)

    try:
        s_nodes = dict()
        cmd = (f"{SCONTROL} show nodes | "
               r"grep -oP '^NodeName=\K(\S+)|State=\K(\S+)' | "
               "paste -sd',\n'")
        nodes = util.run(cmd, shell=True, check=True, get_stdout=True).stdout
        if nodes:
            # result is a list of tuples like:
            # (nodename, (base='base_state', flags=<set of state flags>))
            # from 'nodename,base_state+flag1+flag2'
            # state flags include: CLOUD, COMPLETING, DRAIN, FAIL, POWER,
            #   POWERING_DOWN
            # Modifiers on base state still include: @ (reboot), $ (maint),
            #   * (nonresponsive), # (powering up)
            StateTuple = collections.namedtuple('StateTuple', 'base,flags')

            def make_state_tuple(state):
                return StateTuple(state[0], set(state[1:]))

            s_nodes = [(node, make_state_tuple(args.split('+')))
                       for node, args in map(lambda x: x.split(','),
                                             nodes.rstrip().splitlines())
                       if 'CLOUD' in args]

        g_nodes = util.get_regional_instances(compute, cfg.project,
                                              cfg.instance_defs)
        for pid, part in cfg.instance_defs.items():
            page_token = ""
            while True:
                if not part.regional_capacity:
                    resp = util.ensure_execute(compute.instances().list(
                        project=cfg.project,
                        zone=part.zone,
                        fields='items(name,zone,status),nextPageToken',
                        pageToken=page_token,
                        filter=f"name={pid}-*"))

                    if "items" in resp:
                        g_nodes.update({
                            instance['name']: instance
                            for instance in resp['items']
                        })
                    if "nextPageToken" in resp:
                        page_token = resp['nextPageToken']
                        continue

                break

        to_down = []
        to_idle = []
        to_start = []
        for s_node, s_state in s_nodes:
            g_node = g_nodes.get(s_node, None)
            pid = util.get_pid(s_node)

            if (('POWER' not in s_state.flags)
                    and ('POWERING_DOWN' not in s_state.flags)):
                # slurm nodes that aren't in power_save and are stopped in GCP:
                #   mark down in slurm
                #   start them in gcp
                if g_node and (g_node['status'] == "TERMINATED"):
                    if not s_state.base.startswith('DOWN'):
                        to_down.append(s_node)
                    if (cfg.instance_defs[pid].preemptible_bursting):
                        to_start.append(s_node)

                # can't check if the node doesn't exist in GCP while the node
                # is booting because it might not have been created yet by the
                # resume script.
                # This should catch the completing states as well.
                if (g_node is None and "#" not in s_state.base
                        and not s_state.base.startswith('DOWN')):
                    to_down.append(s_node)

            elif g_node is None:
                # find nodes that are down~ in slurm and don't exist in gcp:
                #   mark idle~
                if (s_state.base.startswith('DOWN')
                        and 'POWER' in s_state.flags):
                    to_idle.append(s_node)
                elif 'POWERING_DOWN' in s_state.flags:
                    to_idle.append(s_node)
                elif s_state.base.startswith('COMPLETING'):
                    to_down.append(s_node)

        if len(to_down):
            log.info("{} stopped/deleted instances ({})".format(
                len(to_down), ",".join(to_down)))
            log.info("{} instances to start ({})".format(
                len(to_start), ",".join(to_start)))

            # write hosts to a file that can be given to get a slurm
            # hostlist. Since the number of hosts could be large.
            tmp_file = tempfile.NamedTemporaryFile(mode='w+t', delete=False)
            tmp_file.writelines("\n".join(to_down))
            tmp_file.close()
            log.debug("tmp_file = {}".format(tmp_file.name))

            hostlist = util.run(f"{SCONTROL} show hostlist {tmp_file.name}",
                                check=True,
                                get_stdout=True).stdout.rstrip()
            log.debug("hostlist = {}".format(hostlist))
            os.remove(tmp_file.name)

            util.run(f"{SCONTROL} update nodename={hostlist} state=down "
                     "reason='Instance stopped/deleted'")

            while True:
                start_instances(compute, to_start, g_nodes)
                if not len(retry_list):
                    break

                log.debug("got {} nodes to retry ({})".format(
                    len(retry_list), ','.join(retry_list)))
                to_start = list(retry_list)
                del retry_list[:]

        if len(to_idle):
            log.info("{} instances to resume ({})".format(
                len(to_idle), ','.join(to_idle)))

            # write hosts to a file that can be given to get a slurm
            # hostlist. Since the number of hosts could be large.
            tmp_file = tempfile.NamedTemporaryFile(mode='w+t', delete=False)
            tmp_file.writelines("\n".join(to_idle))
            tmp_file.close()
            log.debug("tmp_file = {}".format(tmp_file.name))

            hostlist = util.run(f"{SCONTROL} show hostlist {tmp_file.name}",
                                check=True,
                                get_stdout=True).stdout.rstrip()
            log.debug("hostlist = {}".format(hostlist))
            os.remove(tmp_file.name)

            util.run(f"{SCONTROL} update nodename={hostlist} state=resume")

    except Exception:
        log.exception("failed to sync instances")