def batch_test(num_threads, delay):
    """Exercise tag batching by calling set_node_tags from many threads.

    A provider is built while ``make_ec2_resource`` is patched out and
    ``AWSNodeProvider._create_tags`` is replaced by ``mock_create_tags``.
    One thread per node id ``"0"`` .. ``str(num_threads - 1)`` then calls
    ``provider.set_node_tags(node_id, {"foo": "bar"})``. Threads are started
    one at a time with ``time.sleep(delay)`` after each start, and all of
    them are joined before the counters are read.

    Args:
        num_threads: number of concurrent set_node_tags callers (one node
            id per thread).
        delay: seconds to sleep after launching each thread.

    Returns:
        ``(provider.batch_counter, provider.tag_update_counter)``: the
        number of batches of tag updates and the number of tags updated.
    """
    ec2_resource_patch = mock.patch(
        "ray.autoscaler._private.aws.node_provider.make_ec2_resource")
    create_tags_patch = mock.patch.object(AWSNodeProvider, "_create_tags",
                                          mock_create_tags)
    with ec2_resource_patch, create_tags_patch:
        provider = AWSNodeProvider(
            provider_config={"region": "nowhere"}, cluster_name="default")
        # Start both counters at zero so the returned values cover only the
        # set_node_tags calls made below (presumably mock_create_tags is what
        # increments them).
        provider.batch_counter = 0
        provider.tag_update_counter = 0

        node_ids = [str(i) for i in range(num_threads)]
        provider.tag_cache = {node_id: {} for node_id in node_ids}

        workers = [
            threading.Thread(
                target=provider.set_node_tags,
                args=(node_id, {"foo": "bar"}))
            for node_id in node_ids
        ]

        # Stagger the launches by `delay` seconds; presumably this controls
        # how many calls end up in the same batch.
        for worker in workers:
            worker.start()
            time.sleep(delay)
        for worker in workers:
            worker.join()

        return provider.batch_counter, provider.tag_update_counter
def apply_node_provider_config_updates(config, node_cfg, node_type_name,
                                       max_count):
    """Apply, in place, the default updates AWSNodeProvider makes to a node config.

    This should only be used for testing purposes. ``node_cfg`` is mutated:
    ``MinCount`` becomes 1, ``MaxCount`` becomes ``max_count``, and
    ``TagSpecifications`` becomes a single ``"instance"`` spec. That spec
    carries the node provider tags, sorted by key, plus
    ``TAG_RAY_CLUSTER_NAME`` set to ``DEFAULT_CLUSTER_NAME``. Any
    ``TagSpecifications`` already present in ``node_cfg`` are then treated as
    user overrides and merged via ``AWSNodeProvider._merge_tag_specs``.

    Args:
        config: autoscaler config
        node_cfg: node config (updated in place)
        node_type_name: node type name
        max_count: max nodes of the given type to launch

    Returns:
        None.
    """
    provider_tags = node_provider_tags(config, node_type_name)
    provider_tags[TAG_RAY_CLUSTER_NAME] = DEFAULT_CLUSTER_NAME

    # Capture the user's specs now: node_cfg.update() below replaces them.
    user_tag_specs = node_cfg.get("TagSpecifications", [])

    instance_tags = [{
        "Key": key,
        "Value": provider_tags[key]
    } for key in sorted(provider_tags)]
    tag_specs = [{"ResourceType": "instance", "Tags": instance_tags}]

    # The cluster-name entry was only needed while building tag_specs.
    provider_tags.pop(TAG_RAY_CLUSTER_NAME)

    node_cfg.update({
        "MinCount": 1,
        "MaxCount": max_count,
        "TagSpecifications": tag_specs,
    })

    # merge node provider tag specs with user overrides. tag_specs is the
    # same list object now stored in node_cfg and the return value is
    # discarded, so presumably the merge happens in place.
    AWSNodeProvider._merge_tag_specs(tag_specs, user_tag_specs)
def test_terminate_nodes(num_on_demand_nodes, num_spot_nodes, stop):
    # This test makes sure that we stop or terminate all the nodes we're
    # supposed to stop or terminate when we call "terminate_nodes". This test
    # also makes sure that we don't try to stop or terminate too many nodes in
    # a single EC2 request. By default, only 1000 nodes can be
    # stopped/terminated in one request. To terminate more nodes, we must break
    # them up into multiple smaller requests. Each recorded request is checked
    # against provider.max_terminate_nodes.
    #
    # "num_on_demand_nodes" is the number of on-demand nodes to stop or
    #   terminate.
    # "num_spot_nodes" is the number of spot nodes to terminate.
    # "stop" is True if we want to stop nodes, and False to terminate nodes.
    #   Note that spot instances are always terminated, even if "stop" is True.

    # Generate unique instance ids to terminate; spot ids are offset by
    # num_on_demand_nodes so the two groups never overlap.
    on_demand_nodes = {
        "i-{:017d}".format(i)
        for i in range(num_on_demand_nodes)
    }
    spot_nodes = {
        "i-{:017d}".format(i + num_on_demand_nodes)
        for i in range(num_spot_nodes)
    }
    node_ids = list(on_demand_nodes.union(spot_nodes))

    with patch("ray.autoscaler._private.aws.node_provider.make_ec2_client"):
        provider = AWSNodeProvider(
            provider_config={
                "region": "nowhere",
                "cache_stopped_nodes": stop
            },
            cluster_name="default",
        )

    # "_get_cached_node" is used by the AWSNodeProvider to determine whether a
    # node is a spot instance or an on-demand instance.
    def mock_get_cached_node(node_id):
        result = Mock()
        result.spot_instance_request_id = ("sir-08b93456"
                                           if node_id in spot_nodes else "")
        return result

    provider._get_cached_node = mock_get_cached_node

    provider.terminate_nodes(node_ids)

    # provider.ec2 is presumably the Mock returned by the patched EC2
    # factory, so every stop/terminate request is recorded on these lists.
    ec2_client = provider.ec2.meta.client
    stop_calls = ec2_client.stop_instances.call_args_list
    terminate_calls = ec2_client.terminate_instances.call_args_list

    # Copy spot_nodes instead of aliasing it: mock_get_cached_node closes
    # over spot_nodes, and mutating that set would make on-demand nodes look
    # like spot instances to any later _get_cached_node lookup.
    nodes_to_stop = set()
    nodes_to_terminate = set(spot_nodes)
    if stop:
        nodes_to_stop.update(on_demand_nodes)
    else:
        nodes_to_terminate.update(on_demand_nodes)

    expectations = (
        (stop_calls, nodes_to_stop),
        (terminate_calls, nodes_to_terminate),
    )
    for calls, nodes_to_include_in_call in expectations:
        nodes_included_in_call = set()
        for recorded_call in calls:
            # recorded_call[1] holds the keyword arguments of the request.
            instance_ids = recorded_call[1]["InstanceIds"]
            # No single EC2 request may exceed the per-request limit.
            assert len(instance_ids) <= provider.max_terminate_nodes
            nodes_included_in_call.update(instance_ids)

        assert nodes_to_include_in_call == nodes_included_in_call