def test_range_set_equality(self) -> None:
    """
    Immutable range sets built from the same ranges compare equal regardless
    of the order in which the ranges were added to the builder.
    """
    lower_added_first = (
        ImmutableRangeSet.builder()  # type: ignore
        .add(Range.at_most(2))
        .add(Range.at_least(5))
        .build()
    )
    upper_added_first = (
        ImmutableRangeSet.builder()  # type: ignore
        .add(Range.at_least(5))
        .add(Range.at_most(2))
        .build()
    )
    self.assertEqual(lower_added_first, upper_added_first)
def test_range_clear(self) -> None:
    """Clearing a populated mutable range set leaves no ranges behind."""
    populated: MutableRangeSet[int] = RangeSet.create_mutable()
    initial_ranges = [Range.at_most(2), Range.open_closed(5, 8), Range.at_least(10)]
    populated.add_all(initial_ranges)

    populated.clear()

    remaining = populated.as_ranges()
    self.assertEqual(0, len(remaining))
def test_intersection_singleton(self):
    """Intersections of the single-point range [3..3] with various ranges."""
    point = Range.closed(3, 3)

    # Each of these contains 3, so intersecting leaves the point unchanged.
    for containing in (
        point,
        Range.at_most(4),
        Range.at_most(3),
        Range.at_least(3),
        Range.at_least(2),
    ):
        self.assertEqual(point, point.intersection(containing))

    # A range that touches 3 but excludes it yields an empty range anchored at 3.
    self.assertEqual(Range.closed_open(3, 3), point.intersection(Range.less_than(3)))
    self.assertEqual(Range.open_closed(3, 3), point.intersection(Range.greater_than(3)))

    # Ranges that do not reach 3 at all cannot be intersected.
    for disjoint in (Range.at_least(4), Range.at_most(2)):
        with self.assertRaises(ValueError):
            point.intersection(disjoint)
def test_at_least(self):
    """Range.at_least(6) is closed below at 6 and unbounded above."""
    six_and_up = Range.at_least(6)

    # membership
    self.assertNotIn(5, six_and_up)
    self.assertIn(6, six_and_up)
    self.assertIn(sys.maxsize, six_and_up)

    # lower bound
    self.assertTrue(six_and_up.has_lower_bound())
    self.assertEqual(6, six_and_up.lower_endpoint)
    self.assertEqual(BoundType.closed(), six_and_up.lower_bound_type)

    # upper bound, emptiness and rendering
    self.assert_unbounded_above(six_and_up)
    self.assertFalse(six_and_up.is_empty())
    self.assertEqual("[6..+\u221e)", str(six_and_up))
def main(cluster_params: Parameters, job_param_file: Path) -> None:
    """
    Run the entry point described by *job_param_file* through a `SlurmPythonRunner`
    configured from *cluster_params*.

    Job-level settings (entry point, memory, working directory, GPU/CPU counts,
    job name, script path) come from the YAML job parameter file; cluster-level
    settings (partition, echo_template) come from *cluster_params*.
    """
    runner = SlurmPythonRunner.from_parameters(cluster_params)
    job_params = YAMLParametersLoader().load(job_param_file)

    entry_point = job_params.string("entry_point")
    memory = MemoryAmount.parse(job_params.string("memory"))

    # Values are read in the same order the original keyword arguments were evaluated.
    partition = cluster_params.string("partition")
    # Fall back to the process's current directory when none is configured.
    working_directory = (
        job_params.optional_creatable_directory("working_directory") or Path.cwd()
    )
    num_gpus = job_params.integer("num_gpus", default=0, valid_range=Range.at_least(0))
    num_cpus = job_params.integer("num_cpus", default=1, valid_range=Range.at_least(1))
    job_name = job_params.string("job_name", default=entry_point)
    echo_template = cluster_params.boolean("echo_template", default=False)
    slurm_script_path = job_params.optional_creatable_file("slurm_script_path")

    runner.run_entry_point(
        entry_point_name=entry_point,
        param_file=job_param_file,
        partition=partition,
        working_directory=working_directory,
        num_gpus=num_gpus,
        num_cpus=num_cpus,
        job_name=job_name,
        memory_request=memory,
        echo_template=echo_template,
        slurm_script_path=slurm_script_path,
    )
def test_encloses_closed(self):
    """[2..5] encloses exactly the ranges that fit inside its closed endpoints."""
    outer = Range.closed(2, 5)

    inside = (
        outer,
        Range.open(2, 5),
        Range.open_closed(2, 5),
        Range.closed_open(2, 5),
        Range.closed(3, 5),
        Range.closed(2, 4),
    )
    for candidate in inside:
        self.assertTrue(outer.encloses(candidate))

    # Each of these extends past 2 or 5 (or is unbounded) on at least one side.
    outside = (
        Range.open(1, 6),
        Range.greater_than(3),
        Range.less_than(3),
        Range.at_least(3),
        Range.at_most(3),
        Range.all(),
    )
    for candidate in outside:
        self.assertFalse(outer.encloses(candidate))
def test_range_enclosing_range(self) -> None:
    """range_enclosing_range returns the single member range covering a query, else None."""
    range_set: MutableRangeSet[int] = RangeSet.create_mutable()
    range_set.add_all([Range.at_most(2), Range.open_closed(5, 8), Range.at_least(10)])

    # (expected result, query range)
    cases = (
        # 3 is outside every member range, so nothing encloses [2..3].
        (None, Range.closed(2, 3)),
        (Range.at_most(2), Range.open(-1, 0)),
        (Range.open_closed(5, 8), Range.closed_open(6, 7)),
        # 5 is excluded from (5..8], so [5..8] is not enclosed.
        (None, Range.closed(5, 8)),
    )
    for expected, query in cases:
        self.assertEqual(expected, range_set.range_enclosing_range(query))
def test_intersection_de_facto_empty(self):
    """Intersections involving the open range (3..4) between adjacent integers."""
    gap = Range.open(3, 4)
    self.assertEqual(gap, gap.intersection(gap))

    # Touching an excluded endpoint yields an empty range anchored there.
    self.assertEqual(Range.open_closed(3, 3), gap.intersection(Range.at_most(3)))
    self.assertEqual(Range.closed_open(4, 4), gap.intersection(Range.at_least(4)))

    # Ranges that cannot meet (3..4) at all raise.
    for disjoint in (Range.less_than(3), Range.greater_than(4)):
        with self.assertRaises(ValueError):
            gap.intersection(disjoint)

    closed_span = Range.closed(3, 4)
    self.assertEqual(Range.open_closed(4, 4), closed_span.intersection(Range.greater_than(4)))
class MemoryAmount:
    """
    A quantity of memory: a positive integer *amount* together with the
    `MemoryUnit` *unit* it is measured in.
    """

    # Must be an int contained in Range.at_least(1).
    amount: int = attrib(validator=and_(instance_of(int), in_(Range.at_least(1))))
    unit: MemoryUnit = attrib(validator=None)

    # Digits, an optional single space, a unit letter (T/G/M/K, any case),
    # and an optional trailing "B"/"b".
    _PARSE_PATTERN = re.compile(r"(\d+) ?([TtGgMmKk])[bB]?")

    @staticmethod
    def parse(memory_string: str) -> "MemoryAmount":
        """
        Parse strings such as ``"4G"``, ``"16 GB"`` or ``"512mb"`` into a `MemoryAmount`.

        Raises:
            RuntimeError: if *memory_string* does not begin with the expected form.
        """
        # NOTE(review): `match` only anchors at the start of the string, so
        # trailing characters after the unit are ignored.
        found = MemoryAmount._PARSE_PATTERN.match(memory_string)
        if not found:
            raise RuntimeError(
                f"Cannot parse {memory_string} as an amount of memory. "
                f"Expected an integer followed by K, M, G, or T"
            )
        digits, unit_letter = found.groups()
        return MemoryAmount(amount=int(digits), unit=MemoryUnit.parse(unit_letter))
class TestRangeSet(TestCase):
    """
    Tests for RangeSet

    Derived from Guava's TreeRangeSet tests, which were written by Louis Wasserman and Chris Povirk
    """

    MIN_BOUND = -1
    MAX_BOUND = 1
    BOUND_TYPES = [BoundType.open(), BoundType.closed()]

    # Every query shape (unbounded, half-bounded, single-point/empty, and bounded)
    # over the endpoints MIN_BOUND..MAX_BOUND; drives the exhaustive
    # enclose/intersect tests below.
    QUERY_RANGES: List[Range[int]] = [Range.all()]
    for i in range(MIN_BOUND, MAX_BOUND + 1):
        QUERY_RANGES.extend(
            [
                Range.at_most(i),
                Range.at_least(i),
                Range.less_than(i),
                Range.greater_than(i),
                Range.closed(i, i),
                Range.open_closed(i, i),
                Range.closed_open(i, i),
            ]
        )
        for j in range(i + 1, MAX_BOUND + 1):
            QUERY_RANGES.extend(
                [
                    Range.open(i, j),
                    Range.open_closed(i, j),
                    Range.closed_open(i, j),
                    Range.closed(i, j),
                ]
            )

    def test_empty_enclosing(self):
        self._test_encloses(RangeSet.create_mutable())

    def test_empty_intersects(self):
        self._test_intersects(RangeSet.create_mutable())

    def test_all_single_ranges_enclosing(self):
        for query_range in TestRangeSet.QUERY_RANGES:
            self._test_encloses(RangeSet.create_mutable().add(query_range))
        # also test for the complement of empty once complements are implemented

    def test_all_pair_ranges_enclosing(self):
        for query_range_1 in TestRangeSet.QUERY_RANGES:
            for query_range_2 in TestRangeSet.QUERY_RANGES:
                self._test_encloses(
                    RangeSet.create_mutable().add(query_range_1).add(query_range_2)
                )

    def test_intersect_ranges(self):
        range_set = RangeSet.create_mutable()
        range_set.add_all(
            [
                Range.closed(2, 4),
                Range.closed(5, 7),
                Range.closed(10, 12),
                Range.closed(18, 20),
            ]
        )
        self.assertEqual(
            range_set.ranges_overlapping(Range.closed(19, 21)),
            immutableset([Range.closed(18, 20)]),
        )
        self.assertEqual(
            range_set.ranges_overlapping(Range.closed(11, 19)),
            immutableset([Range.closed(10, 12), Range.closed(18, 20)]),
        )
        self.assertEqual(range_set.ranges_overlapping(Range.closed(0, 1)), immutableset())
        self.assertEqual(range_set.ranges_overlapping(Range.closed(21, 23)), immutableset())
        self.assertEqual(range_set.ranges_overlapping(Range.closed(13, 15)), immutableset())
        self.assertEqual(
            range_set.ranges_overlapping(Range.closed(0, 2)),
            immutableset([Range.closed(2, 4)]),
        )
        self.assertEqual(
            range_set.ranges_overlapping(Range.closed(12, 15)),
            immutableset([Range.closed(10, 12)]),
        )
        self.assertEqual(
            range_set.ranges_overlapping(Range.closed(5, 16)),
            immutableset([Range.closed(5, 7), Range.closed(10, 12)]),
        )

    def test_merges_connected_with_overlap(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(1, 4))
        range_set.add(Range.open(2, 6))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed_open(1, 6) in range_set.as_ranges())

    def test_merges_connected_disjoint(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(1, 4))
        range_set.add(Range.open(4, 6))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed_open(1, 6) in range_set.as_ranges())

    def test_ignores_smaller_sharing_no_bound(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(1, 6))
        range_set.add(Range.open(2, 4))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed(1, 6) in range_set.as_ranges())

    def test_ignores_smaller_sharing_lower_bound(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(1, 6))
        range_set.add(Range.closed(1, 4))
        self._test_invariants(range_set)
        self.assertEqual(tuple([Range.closed(1, 6)]), tuple(range_set.as_ranges()))

    def test_ignores_smaller_sharing_upper_bound(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(1, 6))
        range_set.add(Range.closed(3, 6))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed(1, 6) in range_set.as_ranges())

    def test_ignores_equal(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(1, 6))
        range_set.add(Range.closed(1, 6))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed(1, 6) in range_set.as_ranges())

    def test_extend_same_lower_bound(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(1, 4))
        range_set.add(Range.closed(1, 6))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed(1, 6) in range_set.as_ranges())

    def test_extend_same_upper_bound(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(3, 6))
        range_set.add(Range.closed(1, 6))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed(1, 6) in range_set.as_ranges())

    def test_extend_both_directions(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(3, 4))
        range_set.add(Range.closed(1, 6))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed(1, 6) in range_set.as_ranges())

    def test_add_empty(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed_open(3, 3))
        self._test_invariants(range_set)
        self.assertEqual(0, len(range_set.as_ranges()))
        self.assertTrue(range_set.is_empty())

    def test_fill_hole_exactly(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed_open(1, 3))
        range_set.add(Range.closed_open(4, 6))
        range_set.add(Range.closed_open(3, 4))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed_open(1, 6) in range_set.as_ranges())

    def test_fill_hole_with_overlap(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed_open(1, 3))
        range_set.add(Range.closed_open(4, 6))
        range_set.add(Range.closed_open(2, 5))
        self._test_invariants(range_set)
        self.assertEqual(tuple([Range.closed_open(1, 6)]), tuple(range_set.as_ranges()))

    def test_add_many_pairs(self):
        # Build every candidate range group once instead of once per outer iteration;
        # pairs are still tested in the same order as a nested a/b sweep.
        range_groups = [
            self._ranges_between(low, high)
            for low in range(0, 6)
            for high in range(0, 6)
            if low <= high
        ]
        for a_ranges in range_groups:
            for b_ranges in range_groups:
                for a_range in a_ranges:
                    for b_range in b_ranges:
                        self._pair_test(a_range, b_range)

    def test_range_containing1(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(3, 10))
        self.assertEqual(Range.closed(3, 10), range_set.range_containing(5))
        self.assertTrue(5 in range_set)
        self.assertIsNone(range_set.range_containing(1))
        self.assertFalse(1 in range_set)

    def test_add_all(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(3, 10))
        range_set.add_all([Range.open(1, 3), Range.closed(5, 8), Range.closed(9, 11)])
        self.assertEqual(tuple(range_set.as_ranges()), tuple([Range.open_closed(1, 11)]))

    def test_all_single_ranges_intersecting(self):
        for query in TestRangeSet.QUERY_RANGES:
            self._test_intersects(RangeSet.create_mutable().add(query))

    def test_all_two_ranges_intersecting(self):
        for query_1 in TestRangeSet.QUERY_RANGES:
            for query_2 in TestRangeSet.QUERY_RANGES:
                self._test_intersects(RangeSet.create_mutable().add(query_1).add(query_2))

    # forms the basis for corresponding tests in test_range_map
    def test_rightmost_containing_or_below(self):
        range_set = RangeSet.create_mutable().add_all(
            (
                Range.closed(-2, -1),
                Range.closed_open(0, 2),
                # we use [2.1, 3] rather than [2, 3] after [0, 2) because adjacent
                # ranges would coalesce; likewise (5.1, 7) rather than (5, 7) after (4, 5]
                Range.closed(2.1, 3),
                Range.open_closed(4, 5),
                Range.open(5.1, 7),
            )
        )

        # probe value is in the middle of a set
        # [2.1 ... *2.5* ... 3]
        self.assertEqual(Range.closed(2.1, 3.0), range_set.rightmost_containing_or_below(2.5))
        # probe value is at a closed upper limit
        # [2.1 .... *3*]
        self.assertEqual(Range.closed(2.1, 3.0), range_set.rightmost_containing_or_below(3.0))
        # probe value is at a closed lower limit
        # [*2.1* .... 3]
        self.assertEqual(Range.closed(2.1, 3.0), range_set.rightmost_containing_or_below(2.1))
        # probe value is at an open lower limit
        # [2.1 ... 3], (*4* ... 5]
        self.assertEqual(Range.closed(2.1, 3.0), range_set.rightmost_containing_or_below(4.0))
        # probe value is at an open upper limit
        # [0 ... *2*)
        self.assertEqual(
            Range.closed_open(0.0, 2.0), range_set.rightmost_containing_or_below(2.0)
        )
        # probe value falls into a gap
        # [-2, -1] ... *-0.5* ... [0, 2)
        self.assertEqual(
            Range.closed(-2.0, -1.0), range_set.rightmost_containing_or_below(-0.5)
        )
        # no range below
        # *-3* .... [-2,-1]
        self.assertIsNone(range_set.rightmost_containing_or_below(-3.0))
        # no range at or below the probe: the only range lies entirely above it
        # *0* ... [1, 2]
        self.assertIsNone(
            RangeSet.create_mutable()
            .add(Range.closed(1.0, 2.0))
            .rightmost_containing_or_below(0.0)
        )
        # lowest range has open lower bound
        # (*1*,2)
        self.assertIsNone(
            RangeSet.create_mutable()
            .add(Range.open(1.0, 2.0))
            .rightmost_containing_or_below(1.0)
        )

    # forms the basis for corresponding tests in test_range_map
    def test_leftmost_containing_or_above(self):
        range_set = RangeSet.create_mutable().add_all(
            (
                Range.closed(-2, -1),
                Range.closed_open(0, 2),
                # we use [2.1, 3] rather than [2, 3] after [0, 2) because adjacent
                # ranges would coalesce; likewise (5.1, 7) rather than (5, 7) after (4, 5]
                Range.closed(2.1, 3),
                Range.open_closed(4, 5),
                Range.open(5.1, 7),
            )
        )

        # probe value is in the middle of a set
        # [2.1 ... *2.5* ... 3]
        self.assertEqual(Range.closed(2.1, 3.0), range_set.leftmost_containing_or_above(2.5))
        # probe value is at a closed upper limit
        # [2.1 .... *3*]
        self.assertEqual(Range.closed(2.1, 3.0), range_set.leftmost_containing_or_above(3.0))
        # probe value is at a closed lower limit
        # [*2.1* .... 3]
        self.assertEqual(Range.closed(2.1, 3.0), range_set.leftmost_containing_or_above(2.1))
        # probe value is at an open lower limit
        # [2.1 ... 3], (*4* ... 5]
        self.assertEqual(
            Range.open_closed(4.0, 5.0), range_set.leftmost_containing_or_above(4.0)
        )
        # probe value is at an open upper limit
        # [0 ... *2*) [2.1, 3.0]
        self.assertEqual(Range.closed(2.1, 3.0), range_set.leftmost_containing_or_above(2.0))
        # probe value falls into a gap
        # [-2, -1] ... *-0.5* ... [0, 2)
        self.assertEqual(Range.closed_open(0, 2), range_set.leftmost_containing_or_above(-0.5))
        # no range above
        # (5.1 ... 7) ... *8*
        self.assertIsNone(range_set.leftmost_containing_or_above(8))
        # no range at or above the probe: the only range lies entirely below it
        # [1, 2] ... *3*
        self.assertIsNone(
            RangeSet.create_mutable()
            .add(Range.closed(1.0, 2.0))
            .leftmost_containing_or_above(3.0)
        )
        # highest range has open upper bound
        # (1,*2*)
        self.assertIsNone(
            RangeSet.create_mutable()
            .add(Range.open(1.0, 2.0))
            .leftmost_containing_or_above(2.0)
        )

    def test_len(self):
        self.assertEqual(0, len(RangeSet.create_mutable()))
        self.assertEqual(1, len(RangeSet.create_mutable().add(Range.closed(1, 2))))
        self.assertEqual(
            2,
            len(RangeSet.create_mutable().add(Range.closed(1, 2)).add(Range.open(3, 4))),
        )

    # support methods

    @staticmethod
    def _ranges_between(low: int, high: int) -> List[Range[int]]:
        """All closed/half-open ranges from *low* to *high*, plus the open one when low != high."""
        ranges = [
            Range.closed(low, high),
            Range.open_closed(low, high),
            Range.closed_open(low, high),
        ]
        # (low, low) would be an invalid open range, so it is only added for distinct endpoints.
        if low != high:
            ranges.append(Range.open(low, high))
        return ranges

    def _pair_test(self, a: Range[int], b: Range[int]) -> None:
        """Check that adding *a* then *b* yields the expected coalesced, ordered ranges."""
        range_set: MutableRangeSet[int] = RangeSet.create_mutable()
        range_set.add(a)
        range_set.add(b)
        if a.is_empty() and b.is_empty():
            self.assertTrue(range_set.is_empty())
            self.assertFalse(range_set.as_ranges())
        elif a.is_empty():
            self.assertTrue(b in range_set.as_ranges())
        elif b.is_empty():
            self.assertTrue(a in range_set.as_ranges())
        elif a.is_connected(b):
            self.assertEqual(tuple(range_set.as_ranges()), tuple([a.span(b)]))
        elif a.lower_endpoint < b.lower_endpoint:
            self.assertEqual(tuple(range_set.as_ranges()), tuple([a, b]))
        else:
            # Non-empty ranges sharing a lower endpoint are always connected, so here
            # b starts strictly before a and must be listed first.
            self.assertEqual(tuple(range_set.as_ranges()), tuple([b, a]))

    def _test_encloses(self, range_set: RangeSet[int]):
        """*range_set* encloses a query iff one of its member ranges does."""
        self.assertTrue(range_set.encloses_all(ImmutableSet.empty()))
        for query_range in TestRangeSet.QUERY_RANGES:
            expected_to_enclose = any(
                x.encloses(query_range) for x in range_set.as_ranges()
            )
            self.assertEqual(expected_to_enclose, range_set.encloses(query_range))
            self.assertEqual(expected_to_enclose, range_set.encloses_all([query_range]))

    def _test_intersects(self, range_set: RangeSet[int]):
        """*range_set* intersects a query iff some member has a non-empty intersection with it."""
        for query in TestRangeSet.QUERY_RANGES:
            expect_intersects = any(
                r.is_connected(query) and not r.intersection(query).is_empty()
                for r in range_set.as_ranges()
            )
            self.assertEqual(expect_intersects, range_set.intersects(query))

    def _test_invariants(self, range_set: RangeSet[int]):
        """Structural invariants every RangeSet must satisfy after any mutation."""
        self.assertEqual(len(range_set.as_ranges()) == 0, range_set.is_empty())
        as_ranges: Sequence[Range[int]] = tuple(range_set.as_ranges())

        # test that connected ranges are coalesced
        for (range_1, range_2) in tile_with_pairs(as_ranges):
            self.assertFalse(range_1.is_connected(range_2))

        for rng in as_ranges:
            self.assertFalse(rng.is_empty())

        # test that the RangeSet's span is the span of all the ranges
        if as_ranges:
            self.assertEqual(Range.create_spanning(range_set.as_ranges()), range_set.span)
        else:
            with self.assertRaises(ValueError):
                # pylint: disable=pointless-statement
                # noinspection PyStatementEffect
                range_set.span

    # test internal utility functions
    def test_entry_above_below(self):
        sorted_dict = SortedDict({1: 1, 3: 3, 5: 5, 7: 7, 9: 9})

        value_at_or_below_reference = (
            (0, None),
            (1, 1),
            (2, 1),
            (3, 3),
            (4, 3),
            (5, 5),
            (6, 5),
            (7, 7),
            (8, 7),
            (9, 9),
            (10, 9),
            (200, 9),
        )
        for (key, ref) in value_at_or_below_reference:
            self.assertEqual(_value_at_or_below(sorted_dict, key), ref)

        value_below_reference = (
            (0, None),
            (1, None),
            (2, 1),
            (3, 1),
            (4, 3),
            (5, 3),
            (6, 5),
            (7, 5),
            (8, 7),
            (9, 7),
            (10, 9),
            (200, 9),
        )
        for (key, ref) in value_below_reference:
            self.assertEqual(_value_below(sorted_dict, key), ref)

        value_at_or_above_reference = (
            (0, 1),
            (1, 1),
            (2, 3),
            (3, 3),
            (4, 5),
            (5, 5),
            (6, 7),
            (7, 7),
            (8, 9),
            (9, 9),
            (10, None),
            (200, None),
        )
        for (key, ref) in value_at_or_above_reference:
            self.assertEqual(_value_at_or_above(sorted_dict, key), ref)

    def test_pickling(self):
        empty_mutable_rangeset = MutableRangeSet.create_mutable()
        empty_immutable_rangeset = ImmutableRangeSet.builder().build()
        ranges = (Range.closed(0, 2), Range.closed(5, 29), Range.closed(35, 39))
        mutable_rangeset = MutableRangeSet.create_mutable().add_all(ranges)
        immutable_rangeset = ImmutableRangeSet.builder().add_all(ranges).build()

        self.assertEqual(
            empty_mutable_rangeset, pickle.loads(pickle.dumps(empty_mutable_rangeset))
        )
        self.assertEqual(
            empty_immutable_rangeset, pickle.loads(pickle.dumps(empty_immutable_rangeset))
        )
        self.assertEqual(mutable_rangeset, pickle.loads(pickle.dumps(mutable_rangeset)))
        self.assertEqual(immutable_rangeset, pickle.loads(pickle.dumps(immutable_rangeset)))

        self.assertEqual(empty_mutable_rangeset.__getstate__(), ())
        self.assertEqual(empty_immutable_rangeset.__getstate__(), ())
        self.assertEqual(mutable_rangeset.__getstate__(), ranges)
        self.assertEqual(immutable_rangeset.__getstate__(), ranges)
_T = TypeVar("_T")


# TODO: move this to vistautils
def _in_range(_range: Range[_T]) -> Callable[[Any, Any, Any], None]:
    """
    Build an attrs-style validator that rejects values outside *_range*.

    The returned callable takes ``(obj, attribute, value)`` and raises
    `ValueError` when ``value not in _range``. The rejected value is included
    in the message to make failures diagnosable.
    """

    def validator(obj, attribute: Attribute, value) -> None:  # type: ignore
        if value not in _range:
            raise ValueError(
                f"Attribute {attribute.name}'s value {value!r} is not in required range "
                f"{_range} for object of type {type(obj)}"
            )

    return validator


_positive = _in_range(Range.greater_than(0.0))  # pylint:disable=invalid-name
_non_negative = _in_range(Range.at_least(0.0))  # pylint:disable=invalid-name
_degrees = _in_range(Range.closed_open(-360.0, 360.0))  # pylint:disable=invalid-name


@attrs(frozen=True)
class Cylinder:
    """
    A cylinder, irrespective of orientation.

    Marr's representation builds objects up from generalized cylinders;
    right now we only represent cylinders with circular cross-sections.
    """

    # Both dimensions must be strictly positive (validated by `_positive`).
    length_in_meters: float = attrib(validator=_positive, kw_only=True)
    diameter_in_meters: float = attrib(validator=_positive, kw_only=True)
class SlurmResourceRequest(ResourceRequest):
    """
    A `ResourceRequest` for a job running on a SLURM cluster.
    """

    # Accepts a partition name, converted via `Partition.from_str`; falsy values become None.
    partition: Optional[Partition] = attrib(
        converter=lambda x: Partition.from_str(x) if x else None,
        kw_only=True,
        default=None,
    )
    memory: Optional[MemoryAmount] = attrib(
        validator=optional(instance_of(MemoryAmount)), kw_only=True, default=None
    )
    num_cpus: Optional[int] = attrib(
        validator=optional(in_(Range.at_least(1))), default=None, kw_only=True
    )
    num_gpus: Optional[int] = attrib(
        validator=optional(in_(Range.at_least(0))), default=None, kw_only=True
    )
    job_time_in_minutes: Optional[int] = attrib(
        validator=optional(instance_of(int)), default=None, kw_only=True
    )
    # Comma-separated node names passed to SLURM's --exclude.
    exclude_list: Optional[str] = attrib(
        validator=optional(instance_of(str)), kw_only=True, default=None
    )
    # A single node name passed to SLURM's --nodelist (validated by `check`).
    run_on_single_node: Optional[str] = attrib(
        validator=optional(instance_of(str)), kw_only=True, default=None
    )

    def __attrs_post_init__(self):
        """Fill in a missing job time from the partition's max walltime (or the project default)."""
        if not self.job_time_in_minutes:
            if not self.partition:
                logging.warning(
                    "Could not find selected partition. Setting job with no job time specified to max project partition walltime."
                )
                partition_job_time = _PROJECT_PARTITION_JOB_TIME_IN_MINUTES
            else:
                logging.warning(
                    "Defaulting job with no job time specified to max walltime of selected partition '%s'",
                    self.partition.name,
                )
                partition_job_time = self.partition.max_walltime
            # Workaround suggested by maintainers of attrs.
            # See https://www.attrs.org/en/stable/how-does-it-work.html#how-frozen
            object.__setattr__(self, "job_time_in_minutes", partition_job_time)

    @run_on_single_node.validator
    def check(self, _, value: str):
        """Reject a run_on_single_node value naming more than one (comma-separated) node."""
        if value and len(value.split(",")) != 1:
            raise ValueError("run_on_single_node parameter must provide only one node!")

    @staticmethod
    def from_parameters(params: Parameters) -> ResourceRequest:
        """Build a request from *params*; ``partition`` is required, everything else optional."""
        return SlurmResourceRequest(
            partition=params.string("partition"),
            num_cpus=params.optional_positive_integer("num_cpus"),
            num_gpus=params.optional_integer("num_gpus"),
            memory=MemoryAmount.parse(params.string("memory"))
            if "memory" in params
            else None,
            job_time_in_minutes=params.optional_integer("job_time_in_minutes"),
            exclude_list=params.optional_string("exclude_list"),
            run_on_single_node=params.optional_string("run_on_single_node"),
        )

    def unify(self, other: "ResourceRequest") -> "SlurmResourceRequest":
        """
        Combine two requests, preferring *other*'s settings where it has them.

        Raises:
            RuntimeError: if *other* is not a `SlurmResourceRequest`.
        """
        if not isinstance(other, SlurmResourceRequest):
            raise RuntimeError(
                f"Unable to unify a Non-SlurmResourceRequest with a Slurm Resource Request. "
                f"Other: {other} & Self {self}"
            )

        partition = other.partition or self.partition
        # NOTE(review): __attrs_post_init__ fills job_time_in_minutes whenever it is
        # unset, so other's (possibly defaulted) value usually wins here — confirm intent.
        return SlurmResourceRequest(
            # The partition converter takes a name; keep None when neither side has one
            # instead of failing with AttributeError on `None.name`.
            partition=partition.name if partition else None,
            memory=other.memory or self.memory,
            num_cpus=other.num_cpus or self.num_cpus,
            num_gpus=other.num_gpus if other.num_gpus is not None else self.num_gpus,
            job_time_in_minutes=other.job_time_in_minutes or self.job_time_in_minutes,
            exclude_list=other.exclude_list or self.exclude_list,
            run_on_single_node=other.run_on_single_node or self.run_on_single_node,
        )

    def apply_to_job(self, job: Job, *, job_name: str) -> None:
        """
        Attach this request's SLURM settings to *job* as Pegasus/DAGMan profiles.

        Raises:
            RuntimeError: if no partition is set.
            ValueError: if the job time exceeds the partition's max walltime, or if
                run_on_single_node is one of the excluded nodes.
        """
        if not self.partition:
            raise RuntimeError("A partition to run on must be specified.")

        if self.partition.max_walltime < self.job_time_in_minutes:
            raise ValueError(
                f"Partition '{self.partition.name}' has a max walltime of "
                f"{self.partition.max_walltime} mins, which is less than the time given "
                f"({self.job_time_in_minutes} mins) for job: {job_name}."
            )

        slurm_resource_content = SLURM_RESOURCE_STRING.format(
            num_cpus=self.num_cpus or 1,
            num_gpus=self.num_gpus if self.num_gpus is not None else 0,
            job_name=job_name,
            mem_str=to_slurm_memory_string(self.memory or _SLURM_DEFAULT_MEMORY),
        )

        if self.exclude_list and self.run_on_single_node:
            # Compare whole node names: a substring test would make "n1" clash with "n10".
            excluded_nodes = {node.strip() for node in self.exclude_list.split(",")}
            if self.run_on_single_node.strip() in excluded_nodes:
                raise ValueError(
                    "the 'exclude_list' and 'run_on_single_node' options are not consistent."
                )

        if self.exclude_list:
            slurm_resource_content += f" --exclude={self.exclude_list}"

        if self.run_on_single_node:
            slurm_resource_content += f" --nodelist={self.run_on_single_node}"

        if self.partition.name in (SCAVENGE, EPHEMERAL):
            slurm_resource_content += f" --qos={self.partition.name}"

        job.add_pegasus_profile(
            runtime=str(self.job_time_in_minutes * 60),
            queue=str(self.partition.name),
            project=_BORROWED_KEY
            if self.partition.name in (EPHEMERAL, SCAVENGE)
            else self.partition.name,
            glite_arguments=slurm_resource_content,
        )

        # Only set a DAGMan category if one has not already been assigned.
        if (
            "dagman" not in job.profiles.keys()
            or "CATEGORY" not in job.profiles["dagman"].keys()
        ):
            job.add_dagman_profile(category=str(self.partition))
def test_ranges_enclosed_by_out_of_bounds(self) -> None:
    """A query lying wholly above the set's only range encloses no ranges."""
    expected = ImmutableSet.empty()
    range_set = RangeSet.create_mutable().add(Range.closed(0, 10))  # type: ignore
    enclosed = range_set.ranges_enclosed_by(Range.at_least(20))
    self.assertEqual(expected, enclosed)
class SlurmResourceRequest(ResourceRequest):
    """
    A `ResourceRequest` for a job running on a SLURM cluster.
    """

    memory: Optional[MemoryAmount] = attrib(
        validator=optional(instance_of(MemoryAmount)), kw_only=True, default=None
    )
    partition: Optional[str] = attrib(
        validator=optional(instance_of(str)), kw_only=True, default=None
    )
    num_cpus: Optional[int] = attrib(
        validator=optional(in_(Range.at_least(1))), default=None, kw_only=True
    )
    num_gpus: Optional[int] = attrib(
        validator=optional(in_(Range.at_least(0))), default=None, kw_only=True
    )
    job_time_in_minutes: Optional[int] = attrib(
        validator=optional(instance_of(int)),
        default=_DEFAULT_JOB_TIME_IN_MINUTES,
        kw_only=True,
    )

    @staticmethod
    def from_parameters(params: Parameters) -> ResourceRequest:
        """Build a request from *params*; ``partition`` is required, everything else optional."""
        return SlurmResourceRequest(
            partition=params.string("partition"),
            num_cpus=params.optional_positive_integer("num_cpus"),
            num_gpus=params.optional_integer("num_gpus"),
            memory=MemoryAmount.parse(params.string("memory"))
            if "memory" in params
            else None,
            job_time_in_minutes=params.optional_integer("job_time_in_minutes"),
        )

    def unify(self, other: ResourceRequest) -> ResourceRequest:
        """Combine two requests, preferring *other*'s settings where they are set."""
        if isinstance(other, SlurmResourceRequest):
            partition = other.partition if other.partition else self.partition
        else:
            partition = self.partition

        # NOTE(review): for a non-Slurm `other`, this still reads memory/num_cpus/...
        # from it — presumably declared on ResourceRequest; confirm.
        return SlurmResourceRequest(
            partition=partition,
            memory=other.memory if other.memory else self.memory,
            num_cpus=other.num_cpus if other.num_cpus else self.num_cpus,
            num_gpus=other.num_gpus if other.num_gpus is not None else self.num_gpus,
            job_time_in_minutes=other.job_time_in_minutes
            if other.job_time_in_minutes
            else self.job_time_in_minutes,
        )

    def convert_time_to_slurm_format(self, job_time_in_minutes: int) -> str:
        """
        Format a duration in minutes as a SLURM ``hours:minutes:seconds`` string.

        Minutes are zero-padded on the left, e.g. 65 -> ``"1:05:00"``.
        (Previously a single-digit minute count had the zero appended, turning
        5 minutes into ``"50"``.)
        """
        hours, mins = divmod(job_time_in_minutes, 60)
        return f"{hours}:{mins:02d}:00"

    def apply_to_job(self, job: Job, *, job_name: str) -> None:
        """
        Attach this request's SLURM arguments and DAGMan category to *job*.

        Raises:
            RuntimeError: if no partition is set.
        """
        if not self.partition:
            raise RuntimeError("A partition to run on must be specified.")

        # The scavenge/ephemeral partitions are selected by QOS; others by account.
        qos_or_account = (
            f"qos {self.partition}"
            if self.partition in (SCAVENGE, EPHEMERAL)
            else f"account {self.partition}"
        )
        slurm_resource_content = SLURM_RESOURCE_STRING.format(
            qos_or_account=qos_or_account,
            partition=self.partition,
            num_cpus=self.num_cpus if self.num_cpus else 1,
            num_gpus=self.num_gpus if self.num_gpus is not None else 0,
            job_name=job_name,
            mem_str=to_slurm_memory_string(
                self.memory if self.memory else _SLURM_DEFAULT_MEMORY
            ),
            time=self.convert_time_to_slurm_format(
                self.job_time_in_minutes
                if self.job_time_in_minutes
                else _DEFAULT_JOB_TIME_IN_MINUTES
            ),
        )
        logging.debug(
            "Slurm Resource Request for %s: %s", job_name, slurm_resource_content
        )
        job.addProfile(
            Profile(Namespace.PEGASUS, "glite.arguments", slurm_resource_content)
        )
        # Only add the DAGMan category profile if the job does not already have it.
        category_profile = Profile(Namespace.DAGMAN, "category", self.partition)
        if not job.hasProfile(category_profile):
            job.addProfile(category_profile)