def new_trial(self, trial: Trial, auto_increment=False):
    """Insert a new trial into the in-memory storage.

    Parameters
    ----------
    trial: Trial
        the trial to register
    auto_increment: bool
        when a trial with the same uid already exists, bump its revision
        number instead of aborting

    Returns
    -------
    the registered trial, or None when it already exists and
    `auto_increment` is False
    """
    if trial.uid in self.storage.objects:
        if not auto_increment:
            return None

        # find the highest existing revision so the new one is unique
        trials = self.get_trial(trial)
        max_rev = 0
        for t in trials:
            max_rev = max(max_rev, t.revision)

        warning(f'Trial was already completed. Increasing revision number (rev={max_rev + 1})')
        trial.revision = max_rev + 1
        # invalidate the cached hash so uid is recomputed for the new revision
        trial._hash = None

    self.storage.objects[trial.uid] = trial
    self.storage.trials.add(trial.uid)

    if trial.project_id is not None:
        project = self.storage.objects.get(trial.project_id)

        # BUG FIX: the original condition `project is not None or self.strict`
        # dereferenced None when strict mode was on and the project was
        # missing (AttributeError). Attach when present; fail explicitly in
        # strict mode; otherwise warn as before.
        if project is not None:
            project.trials.add(trial)
        elif self.strict:
            raise RuntimeError(f'Project (id: {trial.project_id}) was not found')
        else:
            warning('Orphan trial')

    if trial.group_id is not None:
        group = self.storage.objects.get(trial.group_id)

        # Same fix as above: the original `group is not None or self.strict`
        # crashed on a missing group in strict mode.
        if group is not None:
            group.trials.add(trial.uid)
        elif self.strict:
            raise RuntimeError(f'Group (id: {trial.group_id}) was not found')

    # used by the concurrency tests to verify atomic updates
    trial.metadata['_update_count'] = 0
    return trial
def test_execute_query_custom_status():
    """`$in` on status should match trials across several custom statuses."""
    fresh = [Trial(name=str(i), status=CustomStatus('new', 1)) for i in range(0, 20)]
    stopped = [Trial(name=str(i), status=CustomStatus('interrupted', 1)) for i in range(0, 20)]

    query = dict(status={'$in': ['new', 'interrupted']})
    matched = check_query(fresh + stopped, query)

    assert len(matched) == 40
def increment():
    """Worker body: fetch the shared trial and bump its `count` metadata."""
    backend = make_local('file://test_parallel.json')

    lookup = Trial(_hash=trial_hash, revision=trial_rev)
    trial = backend.get_trial(lookup)[0]

    # there is no lock here so the number could have changed already
    previous = trial.metadata.get('count', 0)
    backend.log_trial_metadata(trial, count=previous + 1)
def test_execute_query_status():
    """`$in` with a Status enum member should match every default trial."""
    trials = [Trial(name=str(i)) for i in range(0, 20)]

    matched = check_query(trials, dict(status={'$in': [Status.CreatedGroup]}))
    assert len(matched) == 20
def set_trial(self, trial: Optional[Trial] = None, force: bool = False, **kwargs):
    """Set a new trial

    Parameters
    ----------
    trial: Optional[Trial]
        project definition you can use to create or set the project

    force: bool
        by default once the trial is set it cannot be changed.
        use force to override this behaviour.

    kwargs: {uid, hash, revision}
        arguments used to create a `Trial` object if no Trial object were provided.
        You should specify `uid` or the pair `(hash, revision)`.
        See :func:`~track.structure.Trial` for possible arguments

    Returns
    -------
    returns a trial logger

    Raises
    ------
    TrialDoesNotExist
        when no trial matches the given identifiers
    """
    if self.trial is not None and not force:
        info('Trial is already set, to override use force=True')
        return self.logger

    if trial is None:
        # hash/uid are not Trial constructor arguments; set them afterwards
        uhash = kwargs.pop('hash', None)
        uid = kwargs.pop('uid', None)
        version = kwargs.pop('version', self.version())

        trial = Trial(version=version, **kwargs)
        if uhash is not None:
            trial.hash = uhash

        if uid is not None:
            trial.uid = uid

    try:
        if trial.version is None:
            trial.version = self.version()

        trials = self.protocol.get_trial(trial)
        if trials is None:
            raise TrialDoesNotExist(
                f'Trial (hash: {trial.hash}, v:{trial.version} rev: {trial.revision}) does not exist!')

        self.trial = trials[0]
        self.logger = TrialLogger(self.trial, self.protocol)
        return self.logger

    except IndexError:
        # BUG FIX: the original message interpolated the builtin `hash`
        # function (printing `<built-in function hash>`) instead of the
        # trial's hash attribute
        raise TrialDoesNotExist(
            f'cannot set trial (id: {trial.uid}, hash:{trial.hash}) it does not exist')
def from_json(obj: Dict[str, any]) -> any:
    """Rebuild a track structure (Project / TrialGroup / Trial) from its
    serialized dict form.

    Anything that is not a dict, or a dict without a recognized `dtype`,
    is returned unchanged.
    """
    if not isinstance(obj, dict):
        return obj

    dtype = obj.get('dtype')

    if dtype == 'project':
        return Project(
            _uid=obj['uid'],
            name=obj['name'],
            description=obj['description'],
            metadata=obj['metadata'],
            groups={from_json(g) for g in obj['groups']},
            trials={from_json(t) for t in obj['trials']},
        )

    if dtype == 'trial_group':
        return TrialGroup(
            _uid=obj['uid'],
            name=obj['name'],
            description=obj['description'],
            metadata=obj['metadata'],
            trials=set(obj['trials']),
            project_id=obj['project_id'])

    if dtype == 'trial':
        return Trial(
            _hash=obj['hash'],
            revision=obj['revision'],
            name=obj['name'],
            description=obj['description'],
            tags=obj['tags'],
            version=obj['version'],
            group_id=obj['group_id'],
            project_id=obj['project_id'],
            parameters=obj['parameters'],
            metadata=to_json(obj['metadata']),
            metrics=obj['metrics'],
            chronos=obj['chronos'],
            errors=obj['errors'],
            status=status(name=obj['status']['name'], value=obj['status']['value']))

    return obj
def test_parallel_fetch_update(backend='file://test.json', workers=12):
    """Reserve trials from many workers and verify no trial is double-booked."""
    from multiprocessing import Queue

    proto, reservable = get_reservable_trials(backend)

    queue = Queue()
    reserve_trials(workers, backend, queue)
    reserved = retrieve_trials(queue)

    assert len(reserved) == len(set(reserved)), 'All reserved trials should be different'

    reserved = set(reserved)
    overlap = reservable.intersection(reserved)
    assert len(reserved) == len(overlap), 'All reserved trials should be reservable'

    # every reserved uid must now be flagged as reserved in the backend
    for uid in reserved:
        lookup = Trial()
        lookup.uid = uid

        fetched = proto.get_trial(lookup)[0]
        assert fetched.status.name == 'reserved'
def _make_trial(self, arguments, name=None, **kwargs):
    """Build a Trial bound to the current project/group and register it."""
    project_id = self.project.uid if self.project is not None else None
    group_id = self.group.uid if self.group is not None else None

    trial = Trial(
        name=name,
        version=self.version(),
        project_id=project_id,
        group_id=group_id,
        parameters=arguments,
        **kwargs)

    return self.protocol.new_trial(trial)
def log_trial_metrics(self, trial: Trial, step: any = None, aggregator: Callable[[], Aggregator] = None, **kwargs):
    """Append metric values to a trial and persist them into storage.

    Parameters
    ----------
    trial: Trial
        trial whose metrics are being updated
    step: any
        optional step/key the values are logged at; when the metric
        container is a dict the step is used as its key
    aggregator: Callable[[], Aggregator]
        factory used to build the container for a brand new metric
    kwargs:
        metric-name -> value pairs to record
    """
    ntrial = self.storage.objects.get(trial.uid)

    for k, v in kwargs.items():
        container = trial.metrics.get(k)

        if container is None:
            container = _make_container(step, aggregator)
            trial.metrics[k] = container

        if step is not None and isinstance(container, dict):
            container[step] = v
        elif step is not None:
            # BUG FIX: the original used `elif step:` which silently dropped
            # the step when it was falsy (e.g. step == 0) and appended the
            # bare value instead of the (step, value) pair
            container.append((step, v))
        else:
            container.append(v)

    # mirror the update onto the stored copy and bump its update counter
    ntrial.metrics.update(trial.metrics)
    self._inc_trial(ntrial)
def from_json(self, obj):
    """Deserialize a trial dictionary into a Trial object."""
    state = status(name=obj['status']['name'], value=obj['status']['value'])

    return Trial(
        _hash=obj['hash'],
        revision=obj['revision'],
        name=obj['name'],
        description=obj['description'],
        tags=obj['tags'],
        version=obj['version'],
        group_id=obj['group_id'],
        project_id=obj['project_id'],
        parameters=obj['parameters'],
        metadata=to_json(obj['metadata']),
        metrics=obj['metrics'],
        chronos={key: from_json(value) for key, value in obj['chronos'].items()},
        errors=obj['errors'],
        status=state)
def test_trial():
    """A Trial should survive a to_json/from_json round trip."""
    original = Trial(
        name='name',
        description='2',
        tags=['0', '1'],
        version='version',
        group_id='0',
        project_id='1',
        parameters=dict(a=1, b=2),
        metadata=dict(a=2, b=3),
        metrics=dict(a=3, b=4),
        chronos=dict(a=4, b=5),
        status=Status.FinishedGroup,
        errors=[])

    restored = from_json(to_json(original))

    # status is not serializable perfectly because we allow for custom status
    restored.status = original.status
    assert original == restored
def _make_trial(self, parameters, name=None, **kwargs):
    """Build a Trial bound to the current project/group and register it.

    Raises
    ------
    RuntimeError
        when the protocol reports the trial already exists
        (new_trial returned None)
    """
    project_id = None
    group_id = None
    if self.project is not None:
        project_id = self.project.uid

    if self.group is not None:
        group_id = self.group.uid

    trial = Trial(
        name=name,
        version=self.version(),
        project_id=project_id,
        group_id=group_id,
        parameters=parameters,
        **kwargs)

    trial = self.protocol.new_trial(trial)

    # `assert` is stripped under `python -O`; raise explicitly so the
    # duplicate-trial error cannot be silently skipped
    if trial is None:
        raise RuntimeError('Trial already exist!')

    return trial
def test_local_parallel(woker_count=20):
    """Here we check that _update_count is atomic and cannot run out of sync.

    `count` and the other can because it does not happen inside the lock
    (first fetch then increment).
    """
    global trial_hash, trial_rev

    # -- Create the object that are going to be accessed in parallel
    remove('test_parallel.json')
    backend = make_local('file://test_parallel.json')

    project = backend.new_project(Project(name='test'))
    group = backend.new_trial_group(TrialGroup(name='test_group', project_id=project.uid))

    trial = backend.new_trial(Trial(
        parameters={'batch': 256},
        project_id=project.uid,
        group_id=group.uid))

    initial = trial.metadata.get('count', 0)
    backend.log_trial_metadata(trial, count=initial)

    # uid format is '<hash>_<revision>'; workers read these globals
    trial_hash, trial_rev = trial.uid.split('_')
    # -- Setup done

    workers = [Process(target=increment) for _ in range(0, woker_count)]

    print('-- Start')
    for worker in workers:
        worker.start()

    for worker in workers:
        worker.join()

    trial = backend.get_trial(trial)[0]
    # remove('test_parallel.json')

    print(trial.metadata)
    assert trial.metadata.get('_update_count', 0) == woker_count + 1, \
        'Parallel write should wait for each other'
def process_args(self, args, cache=None):
    """Replace ids by their object reference so the backend modifies
    the objects and not a copy.

    Parameters
    ----------
    args: dict
        raw request arguments; the 'trial', 'project' and 'group' entries
        may be string ids or already-serialized (dict) objects
    cache:
        not read in this body; presumably reserved for lookup caching —
        TODO confirm against callers

    Returns
    -------
    a new dict with the same keys where known ids were resolved to live
    backend objects
    """
    new_args = dict()

    for k, v in args.items():
        if k == 'trial':
            if isinstance(v, str):
                # uid format is '<hash>_<revision>'
                hashid, rev = v.split('_')
                rev = int(rev)

                v = self.backend.get_trial(Trial(_hash=hashid, revision=rev))
                for i in v:
                    if i.revision == rev:
                        v = i
                        break
                else:
                    # NOTE(review): when no revision matches, v is left as the
                    # fetched list — confirm downstream handles that shape
                    warning('Was not able to find the correct trial revision')

            # from_json returns non-dict input unchanged, so this only
            # converts v when it is still a serialized dict
            v = from_json(v)

        elif k == 'project':
            if isinstance(v, str):
                v = self.backend.get_project(Project(name=v))
            v = from_json(v)

        elif k == 'group':
            if isinstance(v, str):
                v = self.backend.get_trial_group(TrialGroup(_uid=v))
            v = from_json(v)

        new_args[k] = v

    return new_args
def get_trial(self, trial: Trial):
    """Fetch every stored trial matching the given (hash, revision) pair."""
    self.cursor.execute(
        """
        SELECT hash, revision, name, description, tags, metadata, metrics,
               version, group_id, project_id, parameters, status, errors
        FROM track.trials
        WHERE hash = %s AND revision = %s
        """, (self.encode_uid(trial.hash), trial.revision))

    rows = self.cursor.fetchall()
    if rows is None:
        return []

    trials = []
    for row in rows:
        trials.append(Trial(
            _hash=self.decode_uid(row[0]),
            revision=row[1],
            name=row[2],
            description=row[3],
            tags=self.deserialize(row[4]),
            metadata=self.deserialize(row[5]),
            metrics=self.decode_metrics(self.deserialize(row[6])),
            version=row[7],
            group_id=self.decode_uid(row[8]),
            project_id=self.decode_uid(row[9]),
            parameters=row[10],
            status=make_status(row[11]),
            errors=row[12]))

    return trials
def test_cockroach_inserts():
    """End-to-end smoke test of the Cockroach backend.

    Spawns a local CockroachDB instance, then exercises project / group /
    trial creation, the different fetch_trials query shapes, and metric
    logging. Requires the cockroach binary to be available locally.
    """
    from track.distributed.cockroachdb import CockRoachDB

    db = CockRoachDB(location='/tmp/cockroach', addrs='localhost:8125')
    db.start(wait=True)

    for k, v in db.properties.items():
        print(k, v)

    try:
        proto = Cockroach('cockroach://localhost:8125')

        # project round trip
        p1 = Project(name='test')
        proto.new_project(p1)
        p2 = proto.get_project(p1)
        print(p1)
        assert p1.name == p2.name

        # trial group round trip
        g1 = TrialGroup(name='test', project_id=p1.uid)
        proto.new_trial_group(g1)
        g2 = proto.get_trial_group(TrialGroup(_uid=g1.uid))
        print(g1)
        assert g1.name == g2.name

        # trial round trip (fetched by hash)
        trial1 = Trial(parameters={'a': 1}, project_id=p1.uid, group_id=g1.uid)
        proto.new_trial(trial1)
        trial2 = proto.get_trial(Trial(_hash=trial1.hash))
        print(trial1)
        assert len(trial2) == 1
        assert trial1.parameters == trial2[0].parameters

        # fetch by project_id
        # trials = proto.fetch_trials({'status': Status.CreatedGroup, 'group_id': g1.uid})
        # print(trials)

        # fetch by group_id
        trials = proto.fetch_trials({'group_id': g1.uid})
        assert len(trials) == 1
        assert trials[0].uid == trial1.uid

        # fetch by status
        trials = proto.fetch_trials({'status': Status.CreatedGroup, 'group_id': g1.uid})
        assert len(trials) == 1
        assert trials[0].uid == trial1.uid

        # fetch by status
        trials = proto.fetch_trials({'status': {'$in': ['CreatedGroup']}, 'group_id': g1.uid})
        assert len(trials) == 1
        assert trials[0].uid == trial1.uid

        # stepped metrics accumulate into a dict keyed by step
        proto.log_trial_metrics(trial1, step=2, epoch_loss=1)
        proto.log_trial_metrics(trial1, step=3, epoch_loss=2)
        proto.log_trial_metrics(trial1, step=4, epoch_loss=3)

        # step-less metrics accumulate into a list, in logging order
        proto.log_trial_metrics(trial1, loss=3)
        proto.log_trial_metrics(trial1, loss=2)
        proto.log_trial_metrics(trial1, loss=1)

        trials = proto.fetch_trials({'group_id': g1.uid})
        assert len(trials) == 1
        print(trials[0].metrics)
        assert trials[0].metrics == {
            'loss': [3, 2, 1],
            'epoch_loss': {
                2: 1,
                3: 2,
                4: 3
            }
        }

        # second trial in the same group; fetch narrowed by uid
        trial3 = Trial(parameters={'a': 2}, project_id=p1.uid, group_id=g1.uid)
        proto.new_trial(trial3)
        trial4 = proto.get_trial(Trial(_hash=trial3.hash))

        trials = proto.fetch_trials({
            'group_id': g1.uid,
            'uid': trial3.uid
        })
        assert trials == [trial3]

    except Exception as e:
        # NOTE(review): `raise e` here only re-raises; kept as-is
        raise e

    finally:
        # always shut the spawned database down
        db.stop()
def test_execute_query_in():
    """Only the names actually present in the `$in` list should match."""
    trials = [Trial(name=str(i)) for i in range(0, 20)]

    matched = check_query(trials, dict(name={'$in': ['0', '10', '19', '30']}))
    assert len(matched) == 3
# Shared fixtures: one project/group and a pool of trials in various states.
reserved = CustomStatus('reserved', Status.CreatedGroup.value + 2)

project = Project(name='test')
group = TrialGroup(name='MyGroup', project_id=project.uid)

statuses = [
    Status.Interrupted,
    Status.Broken,
    Status.Completed,
    reserved,
    new, new, new, new, new
]

trials = []
for idx, status in enumerate(statuses):
    # NOTE(review): every trial is created with Status.Interrupted even though
    # we iterate over `statuses` — confirm whether the loop variable `status`
    # was meant to be used here
    trial = Trial(
        project_id=project.uid,
        group_id=group.uid,
        status=Status.Interrupted,
        parameters={
            'batch_size': 256,
            'id': idx
        })
    trials.append(trial)

TRIAL_COUNT = len(trials)


def remove(filename):
    """Delete `filename`, ignoring the error when it does not exist.

    Parameters
    ----------
    filename: str
        path of the file to delete
    """
    try:
        os.remove(filename)
    # BUG FIX: the original bare `except:` swallowed everything, including
    # KeyboardInterrupt and programming errors; os.remove only raises OSError
    except OSError:
        pass
def fetch_trials(self, query):
    """Fetch trials matching `query` from the Cockroach backend.

    Parameters
    ----------
    query: dict
        must contain 'group_id'; may additionally contain one of
        'metadata.heartbeat' (with a '$lte' bound), 'status' (a
        status object or a {'$in': [names]} dict) or 'uid'. The filters
        are checked in that order and only the first match is applied.

    Returns
    -------
    list of Trial objects (empty when nothing matches)
    """
    status = query.get('status')
    heartbeat = query.get('metadata.heartbeat')
    uid = query.get('uid')

    if heartbeat is not None:
        # stale-trial lookup: same status name, heartbeat not newer than
        # the given bound.
        # NOTE(review): this branch reads `status.name`, so a dict-shaped
        # status combined with a heartbeat filter would fail — confirm callers
        self.cursor.execute(
            """
            SELECT hash, revision, name, description, tags, metadata, metrics,
                   version, group_id, project_id, parameters, status, errors
            FROM track.trials
            WHERE group_id = %s AND
                  status->>'name' = %s AND
                  CAST (metadata->>'heartbeat' AS DECIMAL) <= %s
            """, (self.encode_uid(query['group_id']), status.name, heartbeat['$lte']))

    elif isinstance(status, dict):
        # {'$in': [...]} form: match any of the listed status names
        self.cursor.execute(
            """
            SELECT hash, revision, name, description, tags, metadata, metrics,
                   version, group_id, project_id, parameters, status, errors
            FROM track.trials
            WHERE group_id = %s AND
                  status->>'name' IN %s
            """, (self.encode_uid(query['group_id']), tuple(status['$in'])))

    elif status is not None:
        # exact status match: both name and value must agree
        self.cursor.execute(
            """
            SELECT hash, revision, name, description, tags, metadata, metrics,
                   version, group_id, project_id, parameters, status, errors
            FROM track.trials
            WHERE group_id = %s AND
                  status->>'name' = %s AND
                  CAST (status->>'value' AS INTEGER) = %s
            """, (self.encode_uid(query['group_id']), status.name, status.value))

    elif uid is not None:
        # single-trial lookup within the group
        self.cursor.execute(
            """
            SELECT hash, revision, name, description, tags, metadata, metrics,
                   version, group_id, project_id, parameters, status, errors
            FROM track.trials
            WHERE uid = %s AND
                  group_id = %s
            """, (self.encode_uid(uid), self.encode_uid(query['group_id']),))

    else:
        # no extra filter: every trial in the group
        self.cursor.execute(
            """
            SELECT hash, revision, name, description, tags, metadata, metrics,
                   version, group_id, project_id, parameters, status, errors
            FROM track.trials
            WHERE group_id = %s
            """, (self.encode_uid(query['group_id']), ))

    results = self.cursor.fetchall()
    if results is None:
        return []

    # decode each row back into a Trial object
    trials = []
    for r in results:
        t = Trial(
            _hash=self.decode_uid(r[0]),
            revision=r[1],
            name=r[2],
            description=r[3],
            tags=self.deserialize(r[4]),
            metadata=self.deserialize(r[5]),
            metrics=self.decode_metrics(self.deserialize(r[6])),
            version=r[7],
            group_id=self.decode_uid(r[8]),
            project_id=self.decode_uid(r[9]),
            parameters=r[10],
            status=make_status(r[11]),
            errors=r[12])
        trials.append(t)

    return trials