Ejemplo n.º 1
0
    def new_trial(self, trial: Trial, auto_increment=False):
        """Insert ``trial`` into the storage and link it to its project and group.

        Parameters
        ----------
        trial: Trial
            trial to insert

        auto_increment: bool
            when the uid of ``trial`` is already stored, give ``trial`` a new
            revision instead of refusing the insertion

        Returns
        -------
        the inserted trial, or None when its uid is already stored and
        ``auto_increment`` is False
        """
        already_stored = trial.uid in self.storage.objects

        if already_stored and not auto_increment:
            return None

        if already_stored:
            # the new revision is one past the highest revision found (never below 1)
            max_rev = max([0] + [previous.revision for previous in self.get_trial(trial)])

            warning(
                f'Trial was already completed. Increasing revision number (rev={max_rev + 1})'
            )
            trial.revision = max_rev + 1
            # clear the stored hash; presumably it is recomputed lazily -- TODO confirm
            trial._hash = None

        self.storage.objects[trial.uid] = trial
        self.storage.trials.add(trial.uid)

        if trial.project_id is None:
            warning('Orphan trial')
        else:
            project = self.storage.objects.get(trial.project_id)

            # in strict mode a missing project is not skipped: the call below fails
            if project is not None or self.strict:
                project.trials.add(trial)

        if trial.group_id is not None:
            group = self.storage.objects.get(trial.group_id)
            # groups keep trial uids whereas projects keep the trial objects
            if group is not None or self.strict:
                group.trials.add(trial.uid)

        trial.metadata['_update_count'] = 0
        return trial
Ejemplo n.º 2
0
def test_execute_query_custom_status():
    """A ``$in`` status query listing both custom status names selects all 40 trials."""
    # 20 trials named '0'..'19' for each custom status, 'new' first
    data = [
        Trial(name=str(i), status=CustomStatus(status_name, 1))
        for status_name in ('new', 'interrupted')
        for i in range(20)
    ]

    query = {'status': {'$in': ['new', 'interrupted']}}

    matched = check_query(data, query)
    assert len(matched) == 40
Ejemplo n.º 3
0
def increment():
    """Fetch the shared trial and store its ``count`` metadata increased by one.

    The trial is identified by the module-level ``trial_hash`` and ``trial_rev``.
    """
    backend = make_local('file://test_parallel.json')
    lookup = Trial(_hash=trial_hash, revision=trial_rev)
    trial = backend.get_trial(lookup)[0]

    # the read and the write are separate calls with no lock in between,
    # so the stored number could have changed already
    previous = trial.metadata.get('count', 0)
    backend.log_trial_metadata(trial, count=previous + 1)
Ejemplo n.º 4
0
def test_execute_query_status():
    """A ``$in`` query on ``Status.CreatedGroup`` selects all 20 trials built without a status."""
    data = [Trial(name=str(i)) for i in range(20)]

    query = {'status': {'$in': [Status.CreatedGroup]}}

    matched = check_query(data, query)
    assert len(matched) == 20
Ejemplo n.º 5
0
    def set_trial(self,
                  trial: Optional[Trial] = None,
                  force: bool = False,
                  **kwargs):
        """Set the current trial

        Parameters
        ----------
        trial: Optional[Trial]
            trial definition used to look up the trial through the protocol

        force: bool
            by default once the trial is set it cannot be changed.
            use force to override this behaviour.

        kwargs: {uid, hash, revision, version}
            arguments used to create a `Trial` object if no Trial object were provided.
            You should specify `uid` or the pair `(hash, revision)`.
            `version` defaults to `self.version()`.
            See :func:`~track.structure.Trial` for possible arguments

        Returns
        -------
        returns a trial logger

        Raises
        ------
        TrialDoesNotExist
            if the protocol returns None or an empty result for the trial
        """
        if self.trial is not None and not force:
            info('Trial is already set, to override use force=True')
            return self.logger

        if trial is None:
            # hash and uid are assigned after construction rather than passed to Trial(...)
            uhash = kwargs.pop('hash', None)
            uid = kwargs.pop('uid', None)
            version = kwargs.pop('version', self.version())

            trial = Trial(version=version, **kwargs)
            if uhash is not None:
                trial.hash = uhash

            if uid is not None:
                trial.uid = uid

        try:
            if trial.version is None:
                trial.version = self.version()

            trials = self.protocol.get_trial(trial)

            if trials is None:
                raise TrialDoesNotExist(
                    f'Trial (hash: {trial.hash}, v:{trial.version} rev: {trial.revision}) does not exist!'
                )

            # an empty result raises IndexError here, reported below as TrialDoesNotExist
            self.trial = trials[0]
            self.logger = TrialLogger(self.trial, self.protocol)
            return self.logger

        except IndexError:
            # report the hash of the trial, not the builtin `hash` function
            raise TrialDoesNotExist(
                f'cannot set trial (id: {trial.uid}, hash:{trial.hash}) it does not exist'
            )
Ejemplo n.º 6
0
def from_json(obj: Dict[str, any]) -> any:
    """Rebuild a ``Project``, ``TrialGroup`` or ``Trial`` from its dictionary form.

    The ``dtype`` entry selects the type to build. Anything that is not a
    dictionary, or whose ``dtype`` is not ``project``, ``trial_group`` or
    ``trial``, is returned unchanged.
    """
    if not isinstance(obj, dict):
        return obj

    kind = obj.get('dtype')

    if kind == 'project':
        # nested groups and trials are themselves decoded recursively
        return Project(
            _uid=obj['uid'],
            name=obj['name'],
            description=obj['description'],
            metadata=obj['metadata'],
            groups={from_json(group) for group in obj['groups']},
            trials={from_json(trial) for trial in obj['trials']},
        )

    if kind == 'trial_group':
        # the trials of a group are kept as they are stored (no recursive decoding)
        return TrialGroup(
            _uid=obj['uid'],
            name=obj['name'],
            description=obj['description'],
            metadata=obj['metadata'],
            trials=set(obj['trials']),
            project_id=obj['project_id']
        )

    if kind == 'trial':
        return Trial(
            _hash=obj['hash'],
            revision=obj['revision'],
            name=obj['name'],
            description=obj['description'],
            tags=obj['tags'],
            version=obj['version'],
            group_id=obj['group_id'],
            project_id=obj['project_id'],
            parameters=obj['parameters'],
            metadata=to_json(obj['metadata']),
            metrics=obj['metrics'],
            chronos=obj['chronos'],
            errors=obj['errors'],
            status=status(
                name=obj['status']['name'],
                value=obj['status']['value'])
        )

    # unknown or missing dtype: plain dictionary, handed back untouched
    return obj
Ejemplo n.º 7
0
def test_parallel_fetch_update(backend='file://test.json', workers=12):
    """Reserve trials through ``workers`` workers and check the reservations.

    Every reserved uid must be unique, must belong to the reservable set and
    must be stored with the ``reserved`` status.
    """
    from multiprocessing import Queue

    proto, reservable = get_reservable_trials(backend)

    queue = Queue()

    reserve_trials(workers, backend, queue)

    reserved_list = retrieve_trials(queue)
    reserved_set = set(reserved_list)
    assert len(reserved_list) == len(reserved_set), 'All reserved trials should be different'

    intersect = reservable.intersection(reserved_set)
    assert len(reserved_set) == len(intersect), 'All reserved trials should be reservable'

    for uid in reserved_set:
        # look the trial up by uid only
        lookup = Trial()
        lookup.uid = uid

        stored = proto.get_trial(lookup)[0]
        assert stored.status.name == 'reserved'
Ejemplo n.º 8
0
    def _make_trial(self, arguments, name=None, **kwargs):
        """Build a trial from ``arguments`` and register it through the protocol.

        The trial is attached to the current project and group when they are set.
        Returns whatever ``self.protocol.new_trial`` returns.
        """
        project_id = self.project.uid if self.project is not None else None
        group_id = self.group.uid if self.group is not None else None

        definition = Trial(
            name=name,
            version=self.version(),
            project_id=project_id,
            group_id=group_id,
            parameters=arguments,
            **kwargs,
        )

        return self.protocol.new_trial(definition)
Ejemplo n.º 9
0
    def log_trial_metrics(self, trial: Trial, step: any = None, aggregator: Callable[[], Aggregator] = None, **kwargs):
        """Record metric values on ``trial`` and on its stored counterpart.

        Parameters
        ----------
        trial: Trial
            trial whose ``metrics`` receive the new values

        step: any
            optional step the values belong to. ``0`` is a valid step.

        aggregator: Callable[[], Aggregator]
            passed to ``_make_container`` when a metric has no container yet

        kwargs:
            metric name to value

        Each value is stored as ``container[step] = value`` for dict containers,
        appended as ``(step, value)`` for other containers when a step is given,
        and appended bare when no step is given.
        """
        ntrial = self.storage.objects.get(trial.uid)
        for k, v in kwargs.items():
            container = trial.metrics.get(k)

            if container is None:
                container = _make_container(step, aggregator)
                trial.metrics[k] = container

            # compare against None explicitly: a truthiness test would drop step 0
            if step is None:
                container.append(v)
            elif isinstance(container, dict):
                container[step] = v
            else:
                container.append((step, v))

        # mirror the updated metrics onto the stored trial
        ntrial.metrics.update(trial.metrics)
        self._inc_trial(ntrial)
Ejemplo n.º 10
0
 def from_json(self, obj):
     """Build a ``Trial`` from its dictionary form ``obj``."""
     # collected in the same order the fields are read from ``obj``
     fields = dict(
         _hash=obj['hash'],
         revision=obj['revision'],
         name=obj['name'],
         description=obj['description'],
         tags=obj['tags'],
         version=obj['version'],
         group_id=obj['group_id'],
         project_id=obj['project_id'],
         parameters=obj['parameters'],
         metadata=to_json(obj['metadata']),
         metrics=obj['metrics'],
         # every chrono entry is decoded with the module-level from_json
         chronos={key: from_json(value) for key, value in obj['chronos'].items()},
         errors=obj['errors'],
         status=status(
             name=obj['status']['name'],
             value=obj['status']['value']),
     )
     return Trial(**fields)
Ejemplo n.º 11
0
def test_trial():
    """Round-trip a fully populated ``Trial`` through ``to_json`` and ``from_json``."""
    original = Trial(
        name='name',
        description='2',
        tags=['0', '1'],
        version='version',
        group_id='0',
        project_id='1',
        parameters={'a': 1, 'b': 2},
        metadata={'a': 2, 'b': 3},
        metrics={'a': 3, 'b': 4},
        chronos={'a': 4, 'b': 5},
        status=Status.FinishedGroup,
        errors=[],
    )

    restored = from_json(to_json(original))

    # status is not serializable perfectly because we allow for custom status,
    # so it is copied over before comparing
    restored.status = original.status
    assert original == restored
Ejemplo n.º 12
0
    def _make_trial(self, parameters, name=None, **kwargs):
        """Build a trial from ``parameters`` and register it through the protocol.

        The trial is attached to the current project and group when they are set.
        Asserts that the protocol returned a trial.
        """
        project_id = self.project.uid if self.project is not None else None
        group_id = self.group.uid if self.group is not None else None

        definition = Trial(
            name=name,
            version=self.version(),
            project_id=project_id,
            group_id=group_id,
            parameters=parameters,
            **kwargs,
        )

        created = self.protocol.new_trial(definition)
        assert created is not None, 'Trial already exist!'
        return created
Ejemplo n.º 13
0
def test_local_parallel(woker_count=20):
    """Check that ``_update_count`` stays in sync under parallel writes.

    ``count`` itself may drift because each worker fetches it first and
    increments it afterwards, outside of any lock.
    """

    global trial_hash, trial_rev

    # -- Create the objects that are going to be accessed in parallel
    remove('test_parallel.json')
    backend = make_local('file://test_parallel.json')

    project = backend.new_project(Project(name='test'))
    group = backend.new_trial_group(
        TrialGroup(name='test_group', project_id=project.uid))

    trial_def = Trial(
        parameters={'batch': 256},
        project_id=project.uid,
        group_id=group.uid,
    )
    trial = backend.new_trial(trial_def)

    # first metadata write, done before any worker starts
    initial_count = trial.metadata.get('count', 0)
    backend.log_trial_metadata(trial, count=initial_count)

    # the workers find the trial again through these module-level names
    trial_hash, trial_rev = trial.uid.split('_')
    # -- Setup done

    processes = [Process(target=increment) for _ in range(woker_count)]
    print('-- Start')
    for process in processes:
        process.start()
    for process in processes:
        process.join()

    trial = backend.get_trial(trial)[0]

    # remove('test_parallel.json')
    print(trial.metadata)
    # one update per worker plus the write done during setup
    assert trial.metadata.get('_update_count', 0) == woker_count + 1, 'Parallel write should wait for each other'
Ejemplo n.º 14
0
    def process_args(self, args, cache=None):
        """Replace ids by their object reference so the backend modifies the objects and not a copy.

        Only the ``trial``, ``project`` and ``group`` entries are touched: a string
        value is looked up through the backend, and the entry is then passed
        through ``from_json``. ``cache`` is not used.
        """
        resolved = {}

        for key, value in args.items():
            is_reference = isinstance(value, str)

            if key == 'trial' and is_reference:
                # a trial reference has the form "<hash>_<revision>"
                hashid, rev = value.split('_')
                rev = int(rev)

                value = self.backend.get_trial(
                    Trial(_hash=hashid, revision=rev))

                # keep the trial with the requested revision; when none matches
                # the backend result is kept as returned
                for candidate in value:
                    if candidate.revision == rev:
                        value = candidate
                        break
                else:
                    warning(
                        'Was not able to find the correct trial revision')

            elif key == 'project' and is_reference:
                value = self.backend.get_project(Project(name=value))

            elif key == 'group' and is_reference:
                value = self.backend.get_trial_group(TrialGroup(_uid=value))

            if key in ('trial', 'project', 'group'):
                value = from_json(value)

            resolved[key] = value

        return resolved
Ejemplo n.º 15
0
    def get_trial(self, trial: Trial):
        """Return the stored trials matching the hash and revision of ``trial``.

        Returns a list of ``Trial`` objects built from the selected rows.
        """
        lookup = (self.encode_uid(trial.hash), trial.revision)

        self.cursor.execute(
            """
            SELECT
                hash, revision, name, description,
                tags, metadata, metrics, version,
                group_id, project_id, parameters,
                status, errors
            FROM
                track.trials
            WHERE
                hash = %s AND
                revision = %s
            """, lookup)

        rows = self.cursor.fetchall()
        if rows is None:
            return []

        # column positions follow the SELECT list above
        return [
            Trial(
                _hash=self.decode_uid(row[0]),
                revision=row[1],
                name=row[2],
                description=row[3],
                tags=self.deserialize(row[4]),
                metadata=self.deserialize(row[5]),
                metrics=self.decode_metrics(self.deserialize(row[6])),
                version=row[7],
                group_id=self.decode_uid(row[8]),
                project_id=self.decode_uid(row[9]),
                parameters=row[10],
                status=make_status(row[11]),
                errors=row[12],
            )
            for row in rows
        ]
Ejemplo n.º 16
0
def test_cockroach_inserts():
    """Exercise project, group and trial round trips against a local CockroachDB.

    The database is started before the checks and is always stopped afterwards,
    whether the checks pass or fail.
    """
    from track.distributed.cockroachdb import CockRoachDB

    db = CockRoachDB(location='/tmp/cockroach', addrs='localhost:8125')
    db.start(wait=True)

    for k, v in db.properties.items():
        print(k, v)

    # no `except` clause needed: a failure propagates unchanged once `finally` has run
    try:
        proto = Cockroach('cockroach://localhost:8125')

        p1 = Project(name='test')
        proto.new_project(p1)
        p2 = proto.get_project(p1)
        print(p1)
        assert p1.name == p2.name

        g1 = TrialGroup(name='test', project_id=p1.uid)
        proto.new_trial_group(g1)
        g2 = proto.get_trial_group(TrialGroup(_uid=g1.uid))
        print(g1)
        assert g1.name == g2.name

        trial1 = Trial(parameters={'a': 1}, project_id=p1.uid, group_id=g1.uid)
        proto.new_trial(trial1)
        trial2 = proto.get_trial(Trial(_hash=trial1.hash))

        print(trial1)
        assert len(trial2) == 1
        assert trial1.parameters == trial2[0].parameters

        # fetch by group_id
        trials = proto.fetch_trials({'group_id': g1.uid})
        assert len(trials) == 1
        assert trials[0].uid == trial1.uid

        # fetch by status object
        trials = proto.fetch_trials({'status': Status.CreatedGroup, 'group_id': g1.uid})

        assert len(trials) == 1
        assert trials[0].uid == trial1.uid

        # fetch by status name using `$in`
        trials = proto.fetch_trials({'status': {'$in': ['CreatedGroup']}, 'group_id': g1.uid})

        assert len(trials) == 1
        assert trials[0].uid == trial1.uid

        # metrics logged with a step
        proto.log_trial_metrics(trial1, step=2, epoch_loss=1)
        proto.log_trial_metrics(trial1, step=3, epoch_loss=2)
        proto.log_trial_metrics(trial1, step=4, epoch_loss=3)

        # metrics logged without a step
        proto.log_trial_metrics(trial1, loss=3)
        proto.log_trial_metrics(trial1, loss=2)
        proto.log_trial_metrics(trial1, loss=1)

        trials = proto.fetch_trials({'group_id': g1.uid})
        assert len(trials) == 1

        print(trials[0].metrics)

        assert trials[0].metrics == {
            'loss': [3, 2, 1],
            'epoch_loss': {
                2: 1,
                3: 2,
                4: 3
            }
        }

        trial3 = Trial(parameters={'a': 2}, project_id=p1.uid, group_id=g1.uid)
        proto.new_trial(trial3)
        # the lookup is still performed; its result is not needed afterwards
        proto.get_trial(Trial(_hash=trial3.hash))

        # fetch by uid within the group
        trials = proto.fetch_trials({
            'group_id': g1.uid,
            'uid': trial3.uid
        })
        assert trials == [trial3]

    finally:
        db.stop()
Ejemplo n.º 17
0
def test_execute_query_in():
    """A ``$in`` query on ``name`` selects only the trials whose name is listed."""
    data = [Trial(name=str(i)) for i in range(20)]

    # '30' is outside the generated names '0'..'19', so three names match
    query = {'name': {'$in': ['0', '10', '19', '30']}}

    matched = check_query(data, query)
    assert len(matched) == 3
Ejemplo n.º 18
0
# Custom status whose value is two above Status.CreatedGroup
reserved = CustomStatus('reserved', Status.CreatedGroup.value + 2)

project = Project(name='test')

group = TrialGroup(name='MyGroup', project_id=project.uid)

# One entry per fixture trial
statuses = [
    Status.Interrupted, Status.Broken, Status.Completed, reserved, new, new,
    new, new, new
]
trials = []
for idx, status in enumerate(statuses):
    # each trial gets its own status from the list; `id` keeps the parameters distinct
    trial = Trial(project_id=project.uid,
                  group_id=group.uid,
                  status=status,
                  parameters={
                      'batch_size': 256,
                      'id': idx
                  })
    trials.append(trial)

TRIAL_COUNT = len(trials)


def remove(filename):
    """Delete ``filename``, ignoring the failure when it cannot be removed.

    Best-effort cleanup: a missing file or any other OS-level error is
    ignored. Only ``OSError`` is suppressed, so ``KeyboardInterrupt`` and
    ``SystemExit`` still propagate.
    """
    try:
        os.remove(filename)
    except OSError:
        pass

Ejemplo n.º 19
0
    def fetch_trials(self, query):
        """Fetch the trials of one group, optionally narrowed by a single filter.

        ``query`` is a dictionary. ``group_id`` is read with ``query['group_id']``
        in every branch, so it must be present. Only the first applicable filter
        below is used; the others are ignored:

        1. ``metadata.heartbeat``: status name equal to ``status.name`` and
           heartbeat lower or equal to ``heartbeat['$lte']``
        2. ``status`` given as a dictionary: status name in ``status['$in']``
        3. ``status`` given as an object: same ``name`` and ``value``
        4. ``uid``: the trial with that uid
        5. nothing else: every trial of the group

        Returns
        -------
        a list of ``Trial`` built from the selected rows
        """
        status = query.get('status')
        heartbeat = query.get('metadata.heartbeat')
        uid = query.get('uid')

        # NOTE(review): this branch reads `status.name`, so it needs a status
        # object in the query as well -- confirm against callers
        if heartbeat is not None:
            self.cursor.execute(
                """
                SELECT
                    hash, revision, name, description, tags,
                    metadata, metrics, version, group_id,
                    project_id, parameters, status, errors
                FROM
                    track.trials
                WHERE
                    group_id = %s AND
                    status->>'name' = %s AND
                    CAST (metadata->>'heartbeat' AS DECIMAL) <= %s
                """, (self.encode_uid(
                    query['group_id']), status.name, heartbeat['$lte']))

        # status given as {'$in': [names]}: the names are sent as one tuple parameter
        elif isinstance(status, dict):
            self.cursor.execute(
                """
                SELECT
                    hash, revision, name, description, tags,
                    metadata, metrics, version, group_id,
                    project_id, parameters, status, errors
                FROM
                    track.trials
                WHERE
                    group_id = %s AND
                    status->>'name' IN %s
                """,
                (self.encode_uid(query['group_id']), tuple(status['$in'])))

        # status given as an object: both its name and its value must match
        elif status is not None:
            self.cursor.execute(
                """
                SELECT
                    hash, revision, name, description, tags,
                     metadata, metrics, version, group_id,
                     project_id, parameters, status, errors
                FROM
                    track.trials
                WHERE
                    group_id = %s AND
                    status->>'name' = %s AND
                    CAST (status->>'value' AS INTEGER) = %s
                """, (self.encode_uid(
                    query['group_id']), status.name, status.value))
        # uid is only considered when no status filter was given
        elif uid is not None:
            self.cursor.execute(
                """
                SELECT
                    hash, revision, name, description, tags,
                    metadata, metrics, version, group_id,
                    project_id, parameters, status, errors
                FROM
                    track.trials
                WHERE
                    uid = %s AND
                    group_id = %s
                """, (
                    self.encode_uid(uid),
                    self.encode_uid(query['group_id']),
                ))
        # no filter: every trial of the group
        else:
            self.cursor.execute(
                """
                SELECT
                    hash, revision, name, description, tags,
                    metadata, metrics, version, group_id,
                    project_id, parameters, status, errors
                FROM
                    track.trials
                WHERE
                    group_id = %s
                """, (self.encode_uid(query['group_id']), ))

        results = self.cursor.fetchall()
        if results is None:
            return []

        # column positions follow the SELECT list shared by all the queries above
        trials = []
        for r in results:
            t = Trial(_hash=self.decode_uid(r[0]),
                      revision=r[1],
                      name=r[2],
                      description=r[3],
                      tags=self.deserialize(r[4]),
                      metadata=self.deserialize(r[5]),
                      metrics=self.decode_metrics(self.deserialize(r[6])),
                      version=r[7],
                      group_id=self.decode_uid(r[8]),
                      project_id=self.decode_uid(r[9]),
                      parameters=r[10],
                      status=make_status(r[11]),
                      errors=r[12])
            trials.append(t)

        return trials