def test_insert():
    """Insert through every public entry point and check the stored intervals.

    Covers slice assignment, ``add``, ``addi`` and ``update``; after each
    step both the size and the exact contents of the tree are verified.
    """
    tree = IntervalTree()
    expected = set()

    # Slice assignment stores the assigned value as the interval's data.
    tree[0:1] = "data"
    expected.add(Interval(0, 1, "data"))
    assert len(tree) == 1
    assert tree.items() == expected

    # add() takes a ready-made Interval.
    tree.add(Interval(10, 20))
    expected.add(Interval(10, 20))
    assert len(tree) == 2
    assert tree.items() == expected

    # addi() takes the bounds directly.
    tree.addi(19.9, 20)
    expected.add(Interval(19.9, 20))
    assert len(tree) == 3
    assert tree.items() == expected

    # update() inserts several intervals at once.
    batch = [Interval(19.9, 20.1), Interval(20.1, 30)]
    tree.update(batch)
    expected.update(batch)
    assert len(tree) == 5
    assert tree.items() == expected
def test_insert():
    """Check tree size and contents after each kind of insertion."""
    first = Interval(0, 1, "data")
    second = Interval(10, 20)
    third = Interval(19.9, 20)
    fourth = Interval(19.9, 20.1)
    fifth = Interval(20.1, 30)

    tree = IntervalTree()

    tree[0:1] = "data"
    assert len(tree) == 1
    assert tree.items() == {first}

    tree.add(Interval(10, 20))
    assert len(tree) == 2
    assert tree.items() == {first, second}

    tree.addi(19.9, 20)
    assert len(tree) == 3
    assert tree.items() == {first, third, second}

    tree.update([Interval(19.9, 20.1), Interval(20.1, 30)])
    assert len(tree) == 5
    assert tree.items() == {first, third, second, fourth, fifth}
def concatDifferences(diffs):
    """Merge difference records whose ranges overlap, one axis after the other.

    Each record is read by index as ``[begin_a, end_a, begin_b, end_b]``.
    The records are first merged on the ``a`` range (carrying the ``b`` pair
    as payload), then the result is merged again on the ``b`` range, and the
    output is returned in the original ``[begin_a, end_a, begin_b, end_b]``
    layout. Payloads of merged intervals are combined by ``tupleReducer``
    (defined elsewhere).

    :param diffs: list of 4-element records
    :return: ``diffs`` itself when it holds at most one record, otherwise a
        new list of merged 4-element lists (order not specified, since the
        intermediate results are sets)
    """
    if len(diffs) <= 1:
        return diffs

    def merge_pass(records):
        """Merge ``records`` on their first two fields; return a set of items."""
        zero_width = []
        tree = IntervalTree()
        for rec in records:
            if rec[0] == rec[1]:
                # Empty ranges cannot be stored in the tree; handle them after.
                zero_width.append(rec)
            else:
                tree[rec[0]:rec[1]] = (rec[2], rec[3])
        tree.merge_overlaps(tupleReducer)

        merged = tree.items()
        for rec in zero_width:
            # Keep a zero-width record only if no merged interval covers it.
            if not tree[rec[0]]:
                merged.add((rec[0], rec[1], (rec[2], rec[3])))
        return merged

    by_first_axis = merge_pass(diffs)
    # Swap the roles: the payload pair becomes the range and vice versa.
    swapped = [[it[2][0], it[2][1], it[0], it[1]] for it in by_first_axis]
    by_second_axis = merge_pass(swapped)
    # Swap back so the result has the same field order as the input.
    return [[it[2][0], it[2][1], it[0], it[1]] for it in by_second_axis]
def pprint(self, depth=0):
    """Dump this node: child nodes recursively, the gaps between them as raw data.

    The node's ``[begin, end)`` range is laid out in an interval tree, and
    each child's range is cut out and replaced by an interval carrying the
    child. The sorted result alternates between child nodes, which are
    printed recursively at ``depth + 1``, and uncovered gaps, which are read
    from ``self.fp`` and passed to ``oha``.

    :param depth: nesting level, used to indent this node's comment
    """
    layout = IntervalTree()
    layout.add(Interval(self.begin, self.end))
    for child in self.children:
        layout.chop(child.begin, child.end)
        layout.add(Interval(child.begin, child.end, child))
    intervals = sorted(layout.items())

    # if a child exists right where we start, emit a comment for this
    # enveloping structure, otherwise the first gap gets our comment
    comment = ' ' * depth + self.comment
    if isinstance(intervals[0].data, OhaNode):
        oha_comment(self.begin, comment)
    else:
        intervals[0] = Interval(intervals[0].begin, intervals[0].end, comment)

    for interval in intervals:
        if isinstance(interval.data, OhaNode):
            interval.data.pprint(depth + 1)
        else:
            # Read the gap's own bytes. Seeking to self.begin here would
            # re-read the start of the node for every gap while labelling
            # the output with interval.begin.
            self.fp.seek(interval.begin)
            data = self.fp.read(interval.length())
            oha(data, interval.begin, interval.data)
class DumpTarget(Target):
    """Target that records raw memory writes and register writes for dumping.

    Memory writes are stored in an interval tree keyed by address range;
    register writes are stored in a dict. ``dump`` writes both into the
    avatar's output directory.
    """

    def __init__(self, avatar, name='dumper', file_prefix=None):
        """
        :param avatar: owning avatar object; its ``output_directory`` is
            where ``dump`` writes its files
        :param name: target name forwarded to the base class
        :param file_prefix: stored on the instance (not used by this class)
        """
        super(DumpTarget, self).__init__(avatar, name=name)
        # Keep our own reference: dump() needs the output directory and
        # must not rely on a module-level ``avatar`` name existing.
        self._dump_avatar = avatar
        self.file_prefix = file_prefix
        self.memory = IntervalTree()
        self.registers = {}

    def init(self):
        """Mark the target as stopped."""
        self.update_state(TargetStates.STOPPED)

    def write_memory(self, address, size, value, num_words=1, raw=False):
        """Record a raw memory write; non-raw writes are ignored.

        :param address: start address of the write
        :param size: unused here
        :param value: buffer to record; its length defines the range
        :param num_words: unused here
        :param raw: only writes with a truthy ``raw`` are recorded
        """
        if raw:
            self.memory[address:address + len(value)] = value

    def write_register(self, register, value):
        """Remember the latest value written to ``register``."""
        self.registers[register] = value

    def dump(self):
        """Write each recorded memory range and the registers to disk.

        Each memory range goes to ``<output_directory>/0x<begin>-0x<end>.bin``
        and the registers to ``<output_directory>/regs.json``.
        """
        output_directory = self._dump_avatar.output_directory
        for mem in self.memory.items():
            file_name = '%s/0x%x-0x%x.bin' % (output_directory,
                                              mem.begin, mem.end)
            with open(file_name, 'wb') as f:
                f.write(mem.data)

        # json.dumps returns text, so the file must be opened in text mode
        # (binary mode raises TypeError on Python 3).
        with open(output_directory + '/regs.json', 'w') as f:
            f.write(json.dumps(self.registers))
def test_list_init():
    """Build a tree from a list and check its size, contents and bounds."""
    seed = [Interval(-10, 10), Interval(-20.0, -10.0)]

    tree = IntervalTree(seed)
    tree.verify()

    assert tree
    assert len(tree) == 2
    assert tree.items() == set(seed)
    assert tree.begin() == -20
    assert tree.end() == 10
def test_list_init():
    """A tree constructed from a list holds exactly the given intervals."""
    right = Interval(-10, 10)
    left = Interval(-20.0, -10.0)

    tree = IntervalTree([right, left])
    tree.verify()

    # A non-empty tree is truthy.
    assert tree
    assert len(tree) == 2
    assert tree.items() == {right, left}
    # The overall bounds come from the outermost endpoints.
    assert tree.begin() == -20
    assert tree.end() == 10
def test_chop_method_w_limit():
    """Scenario 3: with ``limit=1`` only one enveloping period is chopped.

    Two intervals envelope the range 3..7; chopping with ``limit=1`` is
    expected to leave five intervals in the tree.
    """
    intervals = [
        Interval(0, 10, data=0),  # v1
        Interval(0, 10, data=1),  # v2
        Interval(0, 4, data=2),   # v3, first part
        Interval(6, 10, data=2),  # v3, second part
    ]
    avail_tree = IntervalTree(intervals)

    # Example booking for v1.
    avail_tree.chop_intervals_that_envelope_range(3, 7, limit=1)

    remaining = avail_tree.items()
    assert len(remaining) == 5
def test_generator_init():
    """A tree can be constructed from a generator of intervals."""
    bounds = [(-10, 10), (-20, -10), (10, 20)]

    tree = IntervalTree(Interval(lo, hi) for lo, hi in bounds)
    tree.verify()

    assert tree
    assert len(tree) == 3
    assert tree.items() == {
        Interval(-20, -10),
        Interval(-10, 10),
        Interval(10, 20),
    }
    assert tree.begin() == -20
    assert tree.end() == 20
def find_diff(list_a, list_b):
    """Return the parts of the ranges in ``list_a`` left after removing ``list_b``.

    Every element of both lists is read as ``(begin, end)`` via indices 0
    and 1. The ranges of ``list_a`` are loaded into an interval tree and
    each range of ``list_b`` is chopped out of it.

    :return: list of ``(begin, end)`` tuples for what remains, in the
        (unspecified) iteration order of the tree's item set
    """
    remaining = IntervalTree()
    for span in list_a:
        remaining.add(Interval(span[0], span[1]))

    for span in list_b:
        remaining.chop(span[0], span[1])

    return [(piece.begin, piece.end) for piece in remaining.items()]
def test_generator_init():
    """Construct a tree from a generator expression and verify its state."""
    pairs = [(-10, 10), (-20, -10), (10, 20)]
    source = (Interval(begin, end) for begin, end in pairs)

    tree = IntervalTree(source)
    tree.verify()

    expected = set([
        Interval(-20, -10),
        Interval(-10, 10),
        Interval(10, 20),
    ])
    assert tree
    assert len(tree) == 3
    assert tree.items() == expected
    assert tree.begin() == -20
    assert tree.end() == 20
def test_empty_queries():
    """Every query against an empty tree must come back empty."""
    tree = IntervalTree()
    nothing = set()

    # Size and emptiness.
    assert len(tree) == 0
    assert tree.is_empty()

    # Point and range lookups.
    assert tree[3] == nothing
    assert tree[4:6] == nothing

    # Bounds of an empty tree are both zero.
    assert tree.begin() == 0
    assert tree.end() == 0
    assert tree[tree.begin():tree.end()] == nothing

    # Content views.
    assert tree.items() == nothing
    assert set(tree) == nothing
    assert set(tree.copy()) == nothing
    assert tree.find_nested() == {}

    tree.verify()
def find_all_common_intervals(interval_list):
    """Collect the overlapping regions among the given intervals.

    For every interval in the tree, the overlaps found so far are revisited
    via ``add_overlap_to_dict`` and the interval is compared against the
    other tree members that overlap it via ``add_new_overlap_to_dict``
    (both helpers are defined elsewhere and receive the result dict).

    :param interval_list: intervals used to build the interval tree
    :return: the dict filled in by the helper functions
    """
    overlap_dict = dict()
    interval_tree = IntervalTree(interval_list)

    # items() hands back a copy of the tree's contents to iterate over.
    for interval in interval_tree.items():
        # Check the interval against the overlaps recorded so far. Iterate
        # over a snapshot of the keys: the helper is handed the dict itself,
        # and inserting into a dict while iterating its live key view raises
        # RuntimeError on Python 3.
        for overlap in list(overlap_dict):
            add_overlap_to_dict(overlap, interval, overlap_dict)

        # Check the interval against the other intervals in the tree to
        # find overlapping regions.
        other_intervals = find_other_intervals_which_overlap(
            interval_tree, interval)
        for other_interval in other_intervals:
            add_new_overlap_to_dict(interval, other_interval, overlap_dict)

    return overlap_dict
def test_duplicate_insert():
    """Re-inserting an interval that is already present must be a no-op.

    Checked for an interval with string data and for one with ``None``
    data, through slice assignment, ``addi``, ``add`` and ``update``.
    """
    labelled = Interval(-10, 20, "arbitrary data")
    bare = Interval(-10, 20)
    tree = IntervalTree()

    # string data
    contents = frozenset([labelled])

    tree[-10:20] = "arbitrary data"
    assert len(tree) == 1
    assert tree.items() == contents

    tree.addi(-10, 20, "arbitrary data")
    assert len(tree) == 1
    assert tree.items() == contents

    tree.add(Interval(-10, 20, "arbitrary data"))
    assert len(tree) == 1
    assert tree.items() == contents

    tree.update([Interval(-10, 20, "arbitrary data")])
    assert len(tree) == 1
    assert tree.items() == contents

    # None data: same bounds but different data is a distinct interval
    contents = frozenset([bare, labelled])

    tree[-10:20] = None
    assert len(tree) == 2
    assert tree.items() == contents

    tree.addi(-10, 20)
    assert len(tree) == 2
    assert tree.items() == contents

    tree.add(Interval(-10, 20))
    assert len(tree) == 2
    assert tree.items() == contents

    tree.update([Interval(-10, 20), Interval(-10, 20, "arbitrary data")])
    assert len(tree) == 2
    assert tree.items() == contents
def test_duplicate_insert():
    """Inserting duplicates through any API leaves the tree unchanged."""
    tree = IntervalTree()

    def check(expected):
        """Assert the tree holds exactly the intervals in ``expected``."""
        assert len(tree) == len(expected)
        assert tree.items() == expected

    # string data
    contents = frozenset([Interval(-10, 20, "arbitrary data")])
    assert len(contents) == 1

    tree[-10:20] = "arbitrary data"
    check(contents)
    tree.addi(-10, 20, "arbitrary data")
    check(contents)
    tree.add(Interval(-10, 20, "arbitrary data"))
    check(contents)
    tree.update([Interval(-10, 20, "arbitrary data")])
    check(contents)

    # None data
    contents = frozenset([
        Interval(-10, 20),
        Interval(-10, 20, "arbitrary data"),
    ])
    assert len(contents) == 2

    tree[-10:20] = None
    check(contents)
    tree.addi(-10, 20)
    check(contents)
    tree.add(Interval(-10, 20))
    check(contents)
    tree.update([Interval(-10, 20), Interval(-10, 20, "arbitrary data")])
    check(contents)
def test_empty_queries():
    """Queries, bounds and range of an empty tree are all empty or zero."""
    empty_tree = IntervalTree()
    no_results = set()

    assert len(empty_tree) == 0
    assert empty_tree.is_empty()

    assert empty_tree[3] == no_results
    assert empty_tree[4:6] == no_results

    lower = empty_tree.begin()
    assert lower == 0
    upper = empty_tree.end()
    assert upper == 0
    assert empty_tree[lower:upper] == no_results

    assert empty_tree.items() == no_results
    assert set(empty_tree) == no_results
    assert set(empty_tree.copy()) == no_results
    assert empty_tree.find_nested() == {}

    # The overall range of an empty tree is a null, zero-length interval.
    assert empty_tree.range().is_null()
    assert empty_tree.range().length() == 0

    empty_tree.verify()
def test_empty_queries():
    """An empty tree answers every kind of query with an empty result."""
    tree = IntervalTree()
    empty = set()

    assert len(tree) == 0
    assert tree.is_empty()

    # Point lookup and slice lookup.
    assert tree[3] == empty
    assert tree[4:6] == empty

    # Bounds default to zero when nothing is stored.
    assert tree.begin() == 0
    assert tree.end() == 0

    # Lookups spanning the (empty) bounds, via each query API.
    assert tree[tree.begin():tree.end()] == empty
    assert tree.overlap(tree.begin(), tree.end()) == empty
    assert tree.envelop(tree.begin(), tree.end()) == empty

    # Whole-tree views.
    assert tree.items() == empty
    assert set(tree) == empty
    assert set(tree.copy()) == empty
    assert tree.find_nested() == {}

    # Overall range.
    assert tree.range().is_null()
    assert tree.range().length() == 0

    tree.verify()
def find_all_common_intervals(interval_list):
    """
    Find the common (overlapping) regions of the given intervals and label
    each region with the data attribute of the intervals that intersect
    there (e.g. the user id).

    :param interval_list: list of Interval objects
    :return: a dict (key: Interval, value: set of user_ids which share that interval)
    """
    common = dict()
    tree = IntervalTree(interval_list)

    for current in tree.items():
        # Compare the interval against the overlaps recorded so far; the
        # keys are snapshotted because the dict is handed to the helper.
        known_overlaps = list(common)
        for known in known_overlaps:
            add_overlap_to_dict(current, known, common)

        # Compare the interval against the other original intervals in the
        # tree that overlap it.
        neighbours = find_other_intervals_which_overlap(tree, current)
        for neighbour in neighbours:
            add_new_overlap_to_dict(current, neighbour, common)

    return common
class Node(object):
    """A state of an automaton with labelled and epsilon transitions.

    Labelled transitions are kept in an ``IntervalTree``: each interval
    covers a range of integer symbols (``match`` looks them up with
    ``ord(c)``) and its data is the destination node. Epsilon transitions
    are kept in a separate set of destination nodes.
    """

    def __init__(self):
        self.accepting = False
        self._outgoing_edges = IntervalTree()
        self._outgoing_epsilon = set()

    def eclose(self):
        """Yield this node and every node reachable through epsilon edges only."""
        seen = set()
        search = [self]
        while search:
            cur = search.pop()
            if cur in seen:
                continue
            seen.add(cur)
            yield cur
            search.extend(cur._outgoing_epsilon)

    def add_edge(self, k, v):
        """Add a transition labelled ``k`` leading to node ``v``.

        :param k: ``EPSILON``, a single symbol, a ``(start, end)`` tuple or a
            ``slice``; for every non-epsilon form both ends are inclusive
        :param v: destination node
        :raises TypeError: if a tuple key does not have exactly two elements
        :raises ValueError: if a slice key is unordered or has a step
        """
        if k is EPSILON:
            self._outgoing_epsilon.add(v)
            return

        if isinstance(k, tuple):
            if len(k) != 2:
                raise TypeError("add_edge only accepts 2-tuples")
            start, end = k
        elif isinstance(k, slice):
            # slice objects expose ``stop``; there is no ``end`` attribute.
            if k.start > k.stop:
                raise ValueError("add_edge only accepts ordered slices")
            if k.step:
                raise ValueError("add_edge only accepts slices without steps")
            start, end = k.start, k.stop
        else:
            start = end = k

        # The tree's upper bound is exclusive, hence the +1.
        self._outgoing_edges[start:end + 1] = v

    def has_epsilon(self):
        """Return True if any node reachable from here has an epsilon edge."""
        return any(node._outgoing_epsilon for node in self.iter_nodes())

    def iter_nodes(self):
        """Yield every node reachable from this one through any kind of edge."""
        seen = set()
        search = [self]
        while search:
            cur = search.pop()
            if cur in seen:
                continue
            seen.add(cur)
            yield cur
            search.extend(cur._outgoing_epsilon)
            # Follow the labelled edges of the node being expanded; using
            # the root's edges here would stop the walk after one level.
            for v in cur._outgoing_edges.items():
                search.append(v.data)

    def remove_epsilon(self):
        """Replace epsilon transitions by equivalent labelled transitions.

        This node becomes accepting if its epsilon closure contains an
        accepting node. Every labelled edge is then duplicated to each node
        in the epsilon closure of its destination, and finally all epsilon
        edges are dropped.
        """
        if any(node.accepting for node in self.eclose()):
            self.accepting = True

        for node in self.iter_nodes():
            # items() returns a copy, so edges can be added while looping.
            for iv in node._outgoing_edges.items():
                for edge_to in iv.data.eclose():
                    if edge_to == iv.data:
                        continue
                    node._outgoing_edges[iv.begin:iv.end] = edge_to

        for node in self.iter_nodes():
            node._outgoing_epsilon = set()

    def to_dfa(self):
        """Not implemented yet."""
        pass

    def to_utf16_code_units(self):
        """Rewrite edges above the BMP as pairs of UTF-16 surrogate edges.

        For every reachable node, each edge in ``0x10000..0x110000`` is
        replaced by a high-surrogate edge to a new intermediate node, which
        in turn carries the low-surrogate edge to the original destination.
        """
        for node in self.iter_nodes():
            # Get the edges
            edges = node._outgoing_edges

            # Split the intervals at the end of the BMP
            # NOTE(review): this also removes the range [0xFFFF, 0x10000),
            # i.e. the symbol 0xFFFF itself - confirm that is intended.
            edges.slice(0xFFFF)
            edges.chop(0xFFFF, 0x10000)

            # Rewrite all these intervals using UTF-16 code-units
            for iv in edges[0x10000:0x110000]:
                begin, end, out_node = iv
                # NOTE(review): ``end`` is the tree's exclusive bound while
                # add_edge treats tuple ends as inclusive - confirm whether
                # ``end - 1`` should be converted here.
                high_beg, low_beg = _to_utf16_code_units(begin)
                high_end, low_end = _to_utf16_code_units(end)

                if high_beg == high_end:
                    # One high surrogate covers the range: the low-surrogate
                    # edge belongs on the intermediate node, as in the
                    # multi-surrogate branch below.
                    high_node = Node()
                    node.add_edge(high_beg, high_node)
                    high_node.add_edge((low_beg, low_end), out_node)
                else:
                    assert high_end > high_beg

                    # First high surrogate: low range runs to the top.
                    high_start_node = Node()
                    node.add_edge(high_beg, high_start_node)
                    high_start_node.add_edge((low_beg, 0xDFFF), out_node)

                    # High surrogates strictly in between: full low range.
                    if high_beg + 1 != high_end:
                        high_mid_node = Node()
                        node.add_edge((high_beg + 1, high_end - 1),
                                      high_mid_node)
                        high_mid_node.add_edge((0xDC00, 0xDFFF), out_node)

                    # Last high surrogate: low range starts at the bottom.
                    high_end_node = Node()
                    node.add_edge(high_end, high_end_node)
                    high_end_node.add_edge((0xDC00, low_end), out_node)

            edges.remove_overlap(0x10000, 0x110000)

    def match(self, s):
        """Return True if the automaton starting at this node accepts ``s``.

        The current state set starts as this node's epsilon closure; for
        each character it is replaced by the epsilon closures of all
        destinations reachable on that character's code point.
        """
        states = set(self.eclose())
        for c in s:
            code = ord(c)
            old_states = states
            states = set()
            for state in old_states:
                for iv in state._outgoing_edges[code]:
                    states |= set(iv.data.eclose())
        return any(state.accepting for state in states)
class DateIntervalTree:
    """A thin adaptation of the intervaltree library for python dates.

    The underlying interval tree stores integer ranges, so dates are first
    converted to integers in an order-preserving way; ``toordinal()`` on
    ``datetime.date`` provides exactly that.

    Unless noted otherwise, intervals include their lower bound and exclude
    their upper bound: an interval from A to B contains A but not B.
    """

    def __init__(self):
        self.tree = IntervalTree()

    @staticmethod
    def to_date_interval(begin: date, end: date, data: Any) -> Interval:
        """Convert a date interval (and associated data, if any) into an ordinal interval"""
        lower = begin.toordinal()
        upper = end.toordinal()
        return Interval(lower, upper, data)

    @staticmethod
    def from_date_interval(ival: Interval) -> Interval:
        """Convert an ordinal interval to a date interval"""
        lower = date.fromordinal(ival.begin)
        upper = date.fromordinal(ival.end)
        return Interval(lower, upper, ival.data)

    def add(self, begin: date, end: date, data: Any = None):
        """Add a date interval to the interval tree, along with any associated data"""
        ordinal_ival = DateIntervalTree.to_date_interval(begin, end, data)
        self.tree.add(ordinal_ival)

    def merge_overlaps(self, reducer: Callable = None, strict: bool = True):
        """Merge overlapping date intervals in the tree.

        ``reducer`` decides how the data elements of merged intervals are
        combined. ``strict`` controls whether "kissing" intervals are
        merged: if true (the default) only strictly overlapping intervals
        are merged, otherwise adjacent intervals are merged too. See the
        intervaltree documentation of ``merge_overlaps`` for details.
        """
        self.tree.merge_overlaps(data_reducer=reducer, strict=strict)

    def intervals(self) -> List[Interval]:
        """Return all date intervals in this tree"""
        # Stored values are ordinals; hand back real date objects.
        as_dates = []
        for stored in self.tree.items():
            as_dates.append(DateIntervalTree.from_date_interval(stored))
        return as_dates

    def overlaps(self, begin: date, end: date, strict: bool = True) -> bool:
        """Tell whether the given date interval overlaps any interval in the tree.

        Intervals include the lower bound but not the upper bound, so
        2015-07-23 - 2015-08-21 does not overlap 2015-08-21 - 2015-09-21.
        If ``strict`` is false, a day is added to the end date so that
        single-day overlaps return True.
        """
        query_end = end if strict else end + timedelta(days=1)
        ival = DateIntervalTree.to_date_interval(begin, query_end, None)
        return self.tree.overlaps(ival.begin, ival.end)

    def range_query(self, begin: date, end: date) -> List[Interval]:
        """Return all intervals in the tree that strictly overlap with the given interval"""
        query = DateIntervalTree.to_date_interval(begin, end, None)
        hits = self.tree.overlap(query.begin, query.end)
        return [DateIntervalTree.from_date_interval(hit) for hit in hits]

    def point_query(self, point: date) -> List[Interval]:
        """Return all intervals in the tree that contain the given date"""
        hits = self.tree.at(point.toordinal())
        return [DateIntervalTree.from_date_interval(hit) for hit in hits]

    @staticmethod
    def shift_endpoints(date_tree: "DateIntervalTree") -> "DateIntervalTree":
        """Produce a new tree in which touching intervals no longer share a
        boundary, by shifting the end date of the earlier interval.

        E.g., the intervals
            (1/1/2000, 1/10/2000), (1/10/2000, 1/20/2000)
        become
            (1/1/2000, 1/9/2000), (1/10/2000, 1/20/2000)
        (a day is subtracted from the first end date).

        Loop earliest -> latest, adjusting end date.
        """
        adjusted = DateIntervalTree()
        ordered = sorted(date_tree.intervals())
        last_index = len(ordered) - 1

        for index, current in enumerate(ordered):
            if index < last_index:
                following = ordered[index + 1]
                if current.end == following.begin:
                    current = Interval(current.begin,
                                       current.end - timedelta(days=1),
                                       current.data)
            adjusted.add(current.begin, current.end, current.data)

        return adjusted

    @staticmethod
    def shift_endpoints_start(
            date_tree: "DateIntervalTree") -> "DateIntervalTree":
        """Produce a new tree in which touching intervals no longer share a
        boundary, by shifting the start date of the later interval.

        E.g., the intervals
            (1/1/2000, 1/10/2000), (1/10/2000, 1/20/2000)
        become
            (1/1/2000, 1/10/2000), (1/11/2000, 1/20/2000)
        (a day is added to the second start date).

        Loop latest -> earliest, adjusting start date.
        """
        adjusted = DateIntervalTree()
        ordered = sorted(date_tree.intervals(), reverse=True)
        last_index = len(ordered) - 1

        for index, current in enumerate(ordered):
            if index < last_index:
                # In reverse order the next element is the earlier interval.
                earlier = ordered[index + 1]
                if current.begin == earlier.end:
                    log.debug(
                        "adjusting start of billing period: %s-%s",
                        current.begin,
                        current.end,
                    )
                    current = Interval(current.begin + timedelta(days=1),
                                       current.end,
                                       current.data)
            adjusted.add(current.begin, current.end, current.data)

        return adjusted

    @staticmethod
    def shift_endpoints_end(
            date_tree: "DateIntervalTree") -> "DateIntervalTree":
        """Produce a new tree in which touching intervals no longer share a
        boundary, by shifting the end date of the earlier interval.

        E.g., the intervals
            (1/1/2000, 1/10/2000), (1/10/2000, 1/20/2000)
        become
            (1/1/2000, 1/9/2000), (1/10/2000, 1/20/2000)
        (a day is subtracted from the first end date).

        Loop latest -> earliest, adjusting end date.
        """
        adjusted = DateIntervalTree()
        later = None  # the (already adjusted) interval handled just before

        for current in sorted(date_tree.intervals(), reverse=True):
            if later:
                # Pull the end back one day at a time until it lies strictly
                # before the start of the later interval.
                while current.end >= later.begin:
                    shifted_start = current.begin
                    shifted_end = current.end - timedelta(days=1)
                    if shifted_start == shifted_end:
                        # If new interval is one day long, shift start date
                        # back one day too.
                        shifted_start = shifted_start - timedelta(days=1)
                    current = Interval(shifted_start, shifted_end,
                                       current.data)
            later = current
            adjusted.add(current.begin, current.end, current.data)

        return adjusted
def test_all():
    """Smoke test: queries, membership, insertion, copying and deletion."""
    from intervaltree import Interval, IntervalTree
    from pprint import pprint
    from operator import attrgetter

    get_data = attrgetter('data')

    def makeinterval(lst):
        """Build an interval whose data is the string "<begin>-<end>"."""
        return Interval(lst[0], lst[1], "{}-{}".format(*lst))

    def data(s):
        """Collect the data fields of a collection of intervals."""
        return {get_data(iv) for iv in s}

    def expect_remove_fails(tree, iv):
        """remove() of an absent interval must raise ValueError."""
        try:
            tree.remove(iv)
        except ValueError:
            return
        raise AssertionError("Expected ValueError")

    bounds = [
        [1, 2],
        [4, 7],
        [5, 9],
        [6, 10],
        [8, 10],
        [8, 15],
        [10, 12],
        [12, 14],
        [14, 15],
    ]
    ivs = [makeinterval(pair) for pair in bounds]
    t = IntervalTree(ivs)
    t.verify()

    # Query tests
    print('Query tests...')
    assert data(t[4]) == {'4-7'}
    assert data(t[4:5]) == {'4-7'}
    assert data(t[4:6]) == {'4-7', '5-9'}
    assert data(t[9]) == {'6-10', '8-10', '8-15'}
    assert data(t[15]) == set()
    assert data(t.search(5)) == {'4-7', '5-9'}
    assert data(t.search(6, 11, strict=True)) == {'6-10', '8-10'}
    print(' passed')

    # Membership tests
    print('Membership tests...')
    assert ivs[1] in t
    assert Interval(1, 3, '1-3') not in t
    assert t.overlaps(4)
    assert t.overlaps(9)
    assert not t.overlaps(15)
    assert t.overlaps(0, 4)
    assert t.overlaps(1, 2)
    assert t.overlaps(1, 3)
    assert t.overlaps(8, 15)
    assert not t.overlaps(15, 16)
    assert not t.overlaps(-1, 0)
    assert not t.overlaps(2, 4)
    print(' passed')

    # Insertion tests
    print('Insertion tests...')
    t.add(makeinterval([1, 2]))  # adding duplicate should do nothing
    assert data(t[1]) == {'1-2'}

    t[1:2] = '1-2'  # adding duplicate should do nothing
    assert data(t[1]) == {'1-2'}

    t.add(makeinterval([2, 4]))
    assert data(t[2]) == {'2-4'}
    t.verify()

    t[13:15] = '13-15'
    assert data(t[14]) == {'8-15', '13-15', '14-15'}
    t.verify()
    print(' passed')

    # Duplication tests
    print('Interval duplication tests...')
    t.add(Interval(14, 15, '14-15####'))
    assert data(t[14]) == {'8-15', '13-15', '14-15', '14-15####'}
    t.verify()
    print(' passed')

    # Copying and casting
    print('Tree copying and casting...')
    tcopy = IntervalTree(t)
    tcopy.verify()
    assert t == tcopy

    tlist = list(t)
    for iv in tlist:
        assert iv in t
    for iv in t:
        assert iv in tlist

    tset = set(t)
    assert tset == t.items()
    print(' passed')

    # Deletion tests
    print('Deletion tests...')
    expect_remove_fails(t, Interval(1, 3, "Doesn't exist"))
    expect_remove_fails(t, Interval(500, 1000, "Doesn't exist"))

    structure_before = t.print_structure(True)
    # discard() of absent intervals must be silent and change nothing.
    t.discard(Interval(1, 3, "Doesn't exist"))
    t.discard(Interval(500, 1000, "Doesn't exist"))

    assert data(t[14]) == {'8-15', '13-15', '14-15', '14-15####'}
    t.remove(Interval(14, 15, '14-15####'))
    assert data(t[14]) == {'8-15', '13-15', '14-15'}
    t.verify()

    assert data(t[2]) == {'2-4'}
    t.discard(makeinterval([2, 4]))
    assert data(t[2]) == set()
    t.verify()

    assert t[14]
    t.remove_overlap(14)
    t.verify()
    assert not t[14]

    # Emptying the tree
    for iv in sorted(iter(t)):
        t.remove(iv)
        t.verify()
    assert len(t) == 0
    assert t.is_empty()
    assert not t

    t = IntervalTree(ivs)
    t.remove_overlap(1)
    t.verify()
    t.remove_overlap(8)
    print(' passed')

    t = IntervalTree(ivs)
    pprint(t)
    t.split_overlaps()
    pprint(t)
class SegmentProducer(object):
    """Split a download into byte-range segments and track completed ones.

    Two interval trees are kept: ``work_pool`` holds the byte ranges still
    to be handed out, ``completed`` holds the ranges reported as finished.
    Work is handed to child processes through ``q_work``; finished
    intervals come back through ``q_complete``. The completed tree is
    periodically pickled to the download's state file so an interrupted
    download can be resumed.
    """

    save_interval = SAVE_INTERVAL

    def __init__(self, download, n_procs):
        """
        :param download: download object; its ``size`` must already be set
        :param n_procs: number of worker processes the work is split for
        """
        assert download.size is not None,\
            'Segment producer passed uninitizalied Download!'

        self.download = download
        self.n_procs = n_procs

        # Initialize producer
        self.load_state()
        self._setup_pbar()
        self._setup_queues()
        self._setup_work()
        self.schedule()

    def _setup_pbar(self):
        """Create the progress bar for this download."""
        # Set first so the attribute exists even if get_pbar raises.
        self.pbar = None
        self.pbar = get_pbar(self.download.ID, self.download.size)

    def _setup_work(self):
        """Compute the segment size from the outstanding work."""
        if self.is_complete():
            log.info('File already complete.')
            return

        work_size = self.integrate(self.work_pool)
        # Floor division keeps byte offsets integral on Python 3. The
        # lower bound of 1 keeps schedule() from looping forever on
        # zero-length segments when there are fewer outstanding bytes
        # than workers.
        self.block_size = max(1, work_size // self.n_procs)

    def _setup_queues(self):
        """Create the work and completion queues."""
        if WINDOWS:
            self.q_work = Queue()
            self.q_complete = Queue()
        else:
            manager = Manager()
            self.q_work = manager.Queue()
            self.q_complete = manager.Queue()

    def integrate(self, itree):
        """Return the total length covered by the intervals in ``itree``."""
        return sum(i.end - i.begin for i in itree.items())

    def validate_segment_md5sums(self):
        """Re-check the md5sum of every completed segment on disk.

        Segments whose checksum does not match are removed from
        ``completed`` so they are downloaded again. Returns True when
        checking is disabled; returns None otherwise, including the early
        exit taken when a segment carries no recorded md5sum.
        """
        if not self.download.check_segment_md5sums:
            return True

        corrupt_segments = 0
        # sorted() copies the items, so removing from the tree is safe.
        intervals = sorted(self.completed.items())

        pbar = ProgressBar(widgets=[
            'Checksumming {}: '.format(self.download.ID), Percentage(), ' ',
            Bar(marker='#', left='[', right=']'), ' ', ETA()])

        with mmap_open(self.download.path) as data:
            for interval in pbar(intervals):
                log.debug('Checking segment md5: {}'.format(interval))
                if not interval.data or 'md5sum' not in interval.data:
                    log.error(STRIP(
                        """User opted to check segment md5sums on restart.
                        Previous download did not record segment md5sums (--no-segment-md5sums)."""))
                    return
                chunk = data[interval.begin:interval.end]
                checksum = md5sum(chunk)
                if checksum != interval.data.get('md5sum'):
                    log.debug('Redownloading corrupt segment {}, {}.'.format(
                        interval, checksum))
                    corrupt_segments += 1
                    self.completed.remove(interval)

        if corrupt_segments:
            log.warn('Redownloading {} currupt segments.'.format(
                corrupt_segments))

    def load_state(self):
        """Initialise ``work_pool``/``completed``, resuming from a state file if possible."""
        # Establish default intervals
        self.work_pool = IntervalTree([Interval(0, self.download.size)])
        self.completed = IntervalTree()
        self.size_complete = 0

        have_state = os.path.isfile(self.download.state_path)
        have_file = os.path.isfile(self.download.path)

        if not have_state and have_file:
            log.warn(STRIP(
                "A file named '{} was found but no state file was found at "
                "at '{}'. Either this file was downloaded to a different "
                "location, the state file was moved, or the state file was "
                "deleted. Parcel refuses to claim the file has been "
                "successfully downloaded and will restart the "
                "download.\n").format(
                    self.download.path, self.download.state_path))
            return

        if not have_state:
            self.download.setup_file()
            return

        # If there is a file at load_path, attempt to remove
        # downloaded sections from work_pool
        log.info('Found state file {}, attempting to resume download'.format(
            self.download.state_path))

        if not have_file:
            log.warn(STRIP(
                """State file found at '{}' but no file for {}.
                Restarting entire download.""".format(
                    self.download.state_path, self.download.ID)))
            return

        try:
            # NOTE(security): unpickling executes arbitrary code if the
            # state file has been tampered with; only load state files
            # from trusted locations.
            with open(self.download.state_path, "rb") as f:
                self.completed = pickle.load(f)
            assert isinstance(self.completed, IntervalTree), \
                "Bad save state: {}".format(self.download.state_path)
        except Exception as e:
            self.completed = IntervalTree()
            log.error('Unable to resume file state: {}'.format(str(e)))
        else:
            self.validate_segment_md5sums()
            self.size_complete = self.integrate(self.completed)
            for interval in self.completed:
                self.work_pool.chop(interval.begin, interval.end)

    def save_state(self):
        """Pickle ``completed`` to the state file via a temp file and rename.

        A keyboard interrupt during saving is logged and the temp file is
        removed; any other failure is logged and re-raised.
        """
        temp = None
        try:
            # Grab a temp file in the same directory (hopefully avoid
            # cross device links) in order to atomically write our save file
            temp = tempfile.NamedTemporaryFile(
                prefix='.parcel_',
                dir=os.path.abspath(self.download.state_directory),
                delete=False)
            # Write completed state
            pickle.dump(self.completed, temp)
            # Make sure all data is written to disk
            temp.flush()
            os.fsync(temp.fileno())
            temp.close()

            # Rename temp file as our save file, this could fail if
            # the state file and the temp directory are on different devices
            if OS_WINDOWS and os.path.exists(self.download.state_path):
                # If we're on windows, there's not much we can do here
                # except stash the old state file, rename the new one,
                # and back up if there is a problem.
                old_path = os.path.join(tempfile.gettempdir(), ''.join(
                    random.choice(string.ascii_lowercase + string.digits)
                    for _ in range(10)))
                try:
                    # stash the old state file
                    os.rename(self.download.state_path, old_path)
                    # move the new state file into place
                    os.rename(temp.name, self.download.state_path)
                    # if no exception, then delete the old stash
                    os.remove(old_path)
                except Exception as msg:
                    log.error('Unable to write state file: {}'.format(msg))
                    try:
                        os.rename(old_path, self.download.state_path)
                    except Exception as rollback_error:
                        # Best effort only: the original failure is the one
                        # re-raised below, but leave a trace of this one.
                        log.debug(
                            'Unable to restore previous state file: {}'
                            .format(rollback_error))
                    raise
            else:
                # If we're not on windows, then we'll just try to
                # atomically rename the file
                os.rename(temp.name, self.download.state_path)

        except KeyboardInterrupt:
            if temp is None:
                # Interrupted before the temp file existed; nothing to undo.
                log.warn('Keyboard interrupt. removing temp save file')
            else:
                log.warn('Keyboard interrupt. removing temp save file {}'
                         .format(temp.name))
                temp.close()
                # The file is gone already if the rename had completed.
                if os.path.exists(temp.name):
                    os.remove(temp.name)
        except Exception as e:
            log.error('Unable to save state: {}'.format(str(e)))
            raise

    def schedule(self):
        """Move every remaining segment from ``work_pool`` onto ``q_work``."""
        while True:
            interval = self._get_next_interval()
            log.debug('Returning interval: {}'.format(interval))
            if not interval:
                return
            self.q_work.put(interval)

    def _get_next_interval(self):
        """Cut the next segment (at most ``block_size`` long) out of ``work_pool``.

        :return: the segment as an ``Interval``, or None when no work is left
        """
        pending = self.work_pool.items()
        if not pending:
            return None
        # Smallest interval in sort order, without sorting the whole pool.
        interval = min(pending)
        start = interval.begin
        end = min(interval.end, start + self.block_size)
        self.work_pool.chop(start, end)
        return Interval(start, end)

    def print_progress(self):
        """Update the progress bar, if there is one; failures are only logged."""
        if not self.pbar:
            return
        try:
            self.pbar.update(self.size_complete)
        except Exception as e:
            log.debug('Unable to update pbar: {}'.format(str(e)))

    def check_file_exists_and_size(self):
        """Return whether the target exists (and, for regular files, has the expected size)."""
        if self.download.is_regular_file:
            return (os.path.isfile(self.download.path)
                    and os.path.getsize(
                        self.download.path) == self.download.size)
        else:
            log.debug('File is not a regular file, refusing to check size.')
            return (os.path.exists(self.download.path))

    def is_complete(self):
        """Return whether the completed segments cover the whole download and the file checks out."""
        return (self.integrate(self.completed) == self.download.size and
                self.check_file_exists_and_size())

    def finish_download(self):
        """Tell the workers to stop and wait until they have all taken the signal."""
        # Tell the children there is no more work, each child should
        # pull one NoneType from the queue and exit
        for i in range(self.n_procs):
            self.q_work.put(None)

        # Wait for all the children to exit by checking to make sure
        # that everyone has taken their NoneType from the queue.
        # Otherwise, the segment producer will exit before the
        # children return, causing them to read from a closed queue
        log.debug('Waiting for children to report')
        while not self.q_work.empty():
            time.sleep(0.1)

        # Finish the progressbar
        if self.pbar:
            self.pbar.finish()

    def wait_for_completion(self):
        """Collect finished segments until the download is complete.

        State is saved each time roughly ``save_interval`` bytes have been
        completed since the last save, and once more when the download
        completes. The workers are always shut down at the end.
        """
        try:
            since_save = 0
            while not self.is_complete():
                while since_save < self.save_interval:
                    interval = self.q_complete.get()
                    self.completed.add(interval)
                    # Account for the segment before testing completeness
                    # so the final segment is included in size_complete
                    # and in the progress display.
                    this_size = interval.end - interval.begin
                    self.size_complete += this_size
                    since_save += this_size
                    self.print_progress()
                    if self.is_complete():
                        break
                since_save = 0
                self.save_state()
        finally:
            self.finish_download()