def batch_test(num_threads, delay):
    """Exercise AWSNodeProvider.set_node_tags from concurrent threads.

    One thread per node id ("0" .. str(num_threads - 1)) calls
    ``provider.set_node_tags(node_id, {"foo": "bar"})``. Threads are started
    one at a time, sleeping ``delay`` seconds after each start, and are then
    all joined. EC2 resource creation is patched out and
    ``AWSNodeProvider._create_tags`` is replaced by ``mock_create_tags``.

    Args:
        num_threads: number of threads (and node ids) to use.
        delay: seconds to sleep after starting each thread.

    Returns:
        Tuple ``(provider.batch_counter, provider.tag_update_counter)``: the
        number of tag-update batches and the number of tags updated. Both
        counters start at 0 here and are presumably incremented by
        ``mock_create_tags`` (defined elsewhere; verify).
    """
    ec2_patch = mock.patch(
        "ray.autoscaler._private.aws.node_provider.make_ec2_resource")
    create_tags_patch = mock.patch.object(AWSNodeProvider, "_create_tags",
                                          mock_create_tags)
    with ec2_patch, create_tags_patch:
        provider = AWSNodeProvider(provider_config={"region": "nowhere"},
                                   cluster_name="default")
        provider.batch_counter = 0
        provider.tag_update_counter = 0
        node_ids = [str(i) for i in range(num_threads)]
        # Every node starts with an empty cached tag dict.
        provider.tag_cache = {node_id: {} for node_id in node_ids}

        workers = [
            threading.Thread(target=provider.set_node_tags,
                             args=(node_id, {"foo": "bar"}))
            for node_id in node_ids
        ]

        # Stagger thread launches so the batching behavior under different
        # arrival rates can be observed.
        for worker in workers:
            worker.start()
            time.sleep(delay)
        for worker in workers:
            worker.join()

        return provider.batch_counter, provider.tag_update_counter
# Example #2
# 0
# File: helpers.py  Project: stjordanis/ray
def apply_node_provider_config_updates(config, node_cfg, node_type_name,
                                       max_count):
    """Mirror, in place, the node_cfg changes AWSNodeProvider makes on launch.

    Intended for tests only. ``node_cfg`` is updated with ``MinCount``,
    ``MaxCount`` and ``TagSpecifications``. The tag specifications are built
    from ``node_provider_tags`` plus the cluster-name tag, and are then merged
    with any ``TagSpecifications`` the user already had in ``node_cfg``.

    Args:
        config: autoscaler config
        node_cfg: node config; mutated in place
        node_type_name: node type name
        max_count: max nodes of the given type to launch
    """
    tags = node_provider_tags(config, node_type_name)
    tags[TAG_RAY_CLUSTER_NAME] = DEFAULT_CLUSTER_NAME
    # Capture the user's tag specs before node_cfg's entry is replaced below.
    user_tag_specs = node_cfg.get("TagSpecifications", [])

    instance_tags = [{"Key": key, "Value": value}
                     for key, value in sorted(tags.items())]
    tag_specs = [{"ResourceType": "instance", "Tags": instance_tags}]

    tags.pop(TAG_RAY_CLUSTER_NAME)
    node_cfg.update({
        "MinCount": 1,
        "MaxCount": max_count,
        "TagSpecifications": tag_specs,
    })
    # tag_specs is the same list object now stored in node_cfg, so merging
    # the user overrides into it updates node_cfg as well (assuming
    # _merge_tag_specs mutates its first argument in place -- verify).
    AWSNodeProvider._merge_tag_specs(tag_specs, user_tag_specs)
# Example #3
# 0
def test_terminate_nodes(num_on_demand_nodes, num_spot_nodes, stop):
    """Check that terminate_nodes stops/terminates exactly the right nodes.

    This test makes sure that we stop or terminate all the nodes we're
    supposed to stop or terminate when we call ``terminate_nodes``. It also
    makes sure that we don't try to stop or terminate too many nodes in a
    single EC2 request. By default, only 1000 nodes can be stopped/terminated
    in one request, so larger batches must be broken up into multiple
    smaller requests.

    Args:
        num_on_demand_nodes: number of on-demand nodes to stop or terminate.
        num_spot_nodes: number of spot nodes to terminate.
        stop: True to stop on-demand nodes (the provider is created with
            ``cache_stopped_nodes=True``), False to terminate them. Spot
            instances are always terminated, even if ``stop`` is True.
    """
    # Generate disjoint sets of unique instance ids: on-demand ids first,
    # followed by the spot ids.
    on_demand_nodes = {
        "i-{:017d}".format(i)
        for i in range(num_on_demand_nodes)
    }
    spot_nodes = {
        "i-{:017d}".format(i + num_on_demand_nodes)
        for i in range(num_spot_nodes)
    }
    node_ids = list(on_demand_nodes.union(spot_nodes))

    with patch("ray.autoscaler._private.aws.node_provider.make_ec2_client"):
        provider = AWSNodeProvider(
            provider_config={
                "region": "nowhere",
                "cache_stopped_nodes": stop
            },
            cluster_name="default",
        )

    # "_get_cached_node" is used by the AWSNodeProvider to determine whether a
    # node is a spot instance or an on-demand instance.
    def mock_get_cached_node(node_id):
        result = Mock()
        result.spot_instance_request_id = ("sir-08b93456"
                                           if node_id in spot_nodes else "")
        return result

    provider._get_cached_node = mock_get_cached_node

    provider.terminate_nodes(node_ids)

    ec2_client = provider.ec2.meta.client
    stop_calls = ec2_client.stop_instances.call_args_list
    terminate_calls = ec2_client.terminate_instances.call_args_list

    # Copy spot_nodes instead of aliasing it, so that the update() below does
    # not mutate the set that mock_get_cached_node consults.
    nodes_to_stop = set()
    nodes_to_terminate = set(spot_nodes)

    if stop:
        nodes_to_stop.update(on_demand_nodes)
    else:
        nodes_to_terminate.update(on_demand_nodes)

    expectations = (
        (stop_calls, nodes_to_stop),
        (terminate_calls, nodes_to_terminate),
    )
    for calls, nodes_to_include_in_call in expectations:
        nodes_included_in_call = set()
        for call in calls:
            # call[1] is the keyword-argument dict of the recorded call.
            instance_ids = call[1]["InstanceIds"]
            assert len(instance_ids) <= provider.max_terminate_nodes
            nodes_included_in_call.update(instance_ids)

        assert nodes_to_include_in_call == nodes_included_in_call