Esempio n. 1
0
def test_project():
    """Check that a `Project` survives a `to_json` / `from_json` round trip."""
    original = Project(
        name='1',
        description='2',
        metadata={'a': 2, 'b': 3},
        groups={TrialGroup(name='TG', project_id='1')},
        trials=set(),
    )

    # Serialize, then rebuild; the rebuilt object must compare equal
    restored = from_json(to_json(original))
    assert original == restored
Esempio n. 2
0
 def from_json(self, obj):
     """Build a `TrialGroup` from its dictionary representation.

     Parameters
     ----------
     obj: dict
         mapping providing the keys ``uid``, ``name``, ``description``,
         ``metadata``, ``trials`` and ``project_id``

     Returns
     -------
     returns the `TrialGroup` built from those fields; ``trials`` is
     converted to a set
     """
     uid = obj['uid']
     name = obj['name']
     description = obj['description']
     metadata = obj['metadata']
     # the serialized form holds an iterable of trials; the structure wants a set
     trials = set(obj['trials'])
     project_id = obj['project_id']

     return TrialGroup(_uid=uid,
                       name=name,
                       description=description,
                       metadata=metadata,
                       trials=trials,
                       project_id=project_id)
Esempio n. 3
0
def test_trial_group():
    """Check that a `TrialGroup` survives a `to_json` / `from_json` round trip."""
    original = TrialGroup(
        name='1',
        description='2',
        metadata={'a': 2, 'b': 3},
        trials=set(),
        project_id=1,
    )

    # Serialize, then rebuild; the rebuilt object must compare equal
    restored = from_json(to_json(original))
    assert original == restored
Esempio n. 4
0
def from_json(obj: Dict[str, any]) -> any:
    """Rebuild a structure object from its dictionary representation.

    The ``dtype`` entry of `obj` selects what gets built:

    * ``'project'``: a `Project`; its ``groups`` and ``trials`` entries are
      converted recursively and stored as sets
    * ``'trial_group'``: a `TrialGroup`; ``trials`` is stored as a set
    * ``'trial'``: a `Trial`; ``metadata`` goes through `to_json` and
      ``status`` is rebuilt from its ``name`` and ``value`` entries

    Anything that is not a dictionary, and any dictionary whose ``dtype`` is
    missing or not one of the values above, is returned unchanged.
    """
    if not isinstance(obj, dict):
        return obj

    kind = obj.get('dtype')

    if kind == 'project':
        return Project(
            _uid=obj['uid'],
            name=obj['name'],
            description=obj['description'],
            metadata=obj['metadata'],
            groups={from_json(group) for group in obj['groups']},
            trials={from_json(trial) for trial in obj['trials']},
        )

    if kind == 'trial_group':
        return TrialGroup(
            _uid=obj['uid'],
            name=obj['name'],
            description=obj['description'],
            metadata=obj['metadata'],
            trials=set(obj['trials']),
            project_id=obj['project_id'],
        )

    if kind == 'trial':
        return Trial(
            _hash=obj['hash'],
            revision=obj['revision'],
            name=obj['name'],
            description=obj['description'],
            tags=obj['tags'],
            version=obj['version'],
            group_id=obj['group_id'],
            project_id=obj['project_id'],
            parameters=obj['parameters'],
            metadata=to_json(obj['metadata']),
            metrics=obj['metrics'],
            chronos=obj['chronos'],
            errors=obj['errors'],
            status=status(name=obj['status']['name'],
                          value=obj['status']['value']),
        )

    # Unknown or missing dtype: leave the dictionary as it is
    return obj
Esempio n. 5
0
    def set_group(self,
                  group: Optional[TrialGroup] = None,
                  force: bool = False,
                  get_only: bool = False,
                  **kwargs):
        """Select the current trial group, creating it if it does not exist

        Parameters
        ----------
        group: Optional[TrialGroup]
            trial group definition used to look up or create the group.
            When omitted, one is built from `kwargs`.

        force: bool
            once a trial group is set it is kept as is by default;
            use force to replace it.

        get_only: bool
            if true the group is only looked up and never inserted.
            default to false

        kwargs
            arguments used to create a `TrialGroup` object if no TrialGroup object were provided.
            See :func:`~track.structure.TrialGroup` for possible arguments

        Returns
        -------
        returns the trial group that is now set: the one already set, the one
        found through the protocol, or the newly created one

        Raises
        ------
        RuntimeError
            if `get_only` is true and the group was not found
        """

        already_set = self.group is not None
        if already_set and not force:
            info('Group is already set, to override use force=True')
            return self.group

        if group is None:
            group = TrialGroup(**kwargs)

        # a group without project is attached to the current project
        if group.project_id is None:
            group.project_id = self.project.uid

        # try to fetch an existing group before inserting a new one
        self.group = self.protocol.get_trial_group(group)

        if self.group is None:
            if get_only:
                raise RuntimeError(f'Group (name: {group.name}) was not found!')

            self.group = self.protocol.new_trial_group(group)

        return self.group
Esempio n. 6
0
    def create_experiment(self, config):
        """Insert a new experiment inside the database

        The experiment is stored as a `TrialGroup` named after
        ``experiment_uid(name, version)``, attached to the current project,
        with the whole configuration (through `to_json`) as metadata.

        Parameters
        ----------
        config: dict
            experiment configuration; ``name`` and ``version`` are read from it

        Returns
        -------
        returns the same `config` object, updated in place with ``_id`` set to
        the uid of the created group

        Raises
        ------
        DuplicateKeyError
            if the backend returns no group for the insertion
        """
        self._get_project(config['name'])

        group_name = experiment_uid(name=config['name'],
                                    version=config['version'])
        definition = TrialGroup(name=group_name,
                                project_id=self.project.uid,
                                metadata=to_json(config))

        self.group = self.backend.new_trial_group(definition)

        if self.group is None:
            raise DuplicateKeyError('Experiment was already created')

        config['_id'] = self.group.uid
        return config
Esempio n. 7
0
    def get_trial_group(self, group: TrialGroup):
        """Fetch the trial group whose uid matches ``group.uid``.

        Returns
        -------
        returns a `TrialGroup` rebuilt from the stored row, or None when no
        row matches
        """
        query = """
            SELECT
                uid, name, description, metadata, trials, project_id
            FROM
                track.trial_groups
            WHERE
                uid = %s
            """
        self.cursor.execute(query, (group.uid, ))

        row = self.cursor.fetchone()
        if row is None:
            return None

        # columns come back in the order of the SELECT clause
        uid, name, description, metadata, trials, project_id = row

        return TrialGroup(
            _uid=self.decode_uid(uid),
            name=name,
            description=description,
            metadata=self.deserialize(metadata),
            trials=set(self.process_uuid_array(trials)),
            project_id=self.decode_uid(project_id),
        )
Esempio n. 8
0
def test_local_parallel(woker_count=20):
    """Check that `_update_count` stays consistent under parallel writes.

    A file backed local backend is created with one project, one group and
    one trial. `woker_count` processes then each run `increment` and the
    trial's `_update_count` metadata is expected to be `woker_count + 1`
    (presumably the extra update is the `log_trial_metadata` call done during
    the setup -- to be confirmed).

    The `count` metadata is not checked: per the original note it can get out
    of sync because it is fetched first and incremented afterwards, outside
    of the lock.
    """

    global trial_hash, trial_rev

    # -- Build the objects the workers are going to access in parallel
    remove('test_parallel.json')
    backend = make_local('file://test_parallel.json')

    project = backend.new_project(Project(name='test'))
    group = backend.new_trial_group(
        TrialGroup(name='test_group', project_id=project.uid))

    trial = backend.new_trial(
        Trial(parameters={'batch': 256},
              project_id=project.uid,
              group_id=group.uid))

    initial_count = trial.metadata.get('count', 0)
    backend.log_trial_metadata(trial, count=initial_count)

    # the trial identity is shared with the workers through module globals
    trial_hash, trial_rev = trial.uid.split('_')
    # -- Setup done

    workers = [Process(target=increment) for _ in range(woker_count)]
    print('-- Start')
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    trial = backend.get_trial(trial)[0]

    # cleanup is currently disabled: test_parallel.json is left on disk
    print(trial.metadata)
    expected_updates = woker_count + 1
    assert trial.metadata.get('_update_count', 0) == expected_updates, 'Parallel write should wait for each other'
Esempio n. 9
0
    def process_args(self, args, cache=None):
        """Replace ids by their object reference so the backend modifies the objects and not a copy

        Only the ``trial``, ``project`` and ``group`` entries are touched:

        * a string ``trial`` is split into ``<hash>_<revision>`` and looked up
          through the backend; the result with the matching revision is used.
          When none matches a warning is emitted and the whole lookup result
          is kept
        * a string ``project`` is looked up by name
        * a string ``group`` is looked up by uid

        These three entries then always go through `from_json`. Every other
        entry is copied as is. `cache` is not used.

        Returns
        -------
        returns a new dictionary; `args` itself is not modified
        """
        reference_keys = ('trial', 'project', 'group')
        resolved = {}

        for key, value in args.items():
            if key in reference_keys:
                if isinstance(value, str):
                    if key == 'trial':
                        hashid, rev = value.split('_')
                        rev = int(rev)

                        candidates = self.backend.get_trial(
                            Trial(_hash=hashid, revision=rev))
                        match = next(
                            (t for t in candidates if t.revision == rev), None)

                        if match is None:
                            warning(
                                'Was not able to find the correct trial revision')
                            value = candidates
                        else:
                            value = match

                    elif key == 'project':
                        value = self.backend.get_project(Project(name=value))

                    else:
                        value = self.backend.get_trial_group(
                            TrialGroup(_uid=value))

                value = from_json(value)

            resolved[key] = value

        return resolved
Esempio n. 10
0
    def fetch_groups(self, query):
        """Fetch a trial group by name and by the ``user`` entry of its metadata.

        Parameters
        ----------
        query: dict
            must provide the ``name`` and ``metadata.user`` keys

        Returns
        -------
        returns a single `TrialGroup` built from the first matching row
        (only one row is fetched), or None when nothing matches. The ``uid``
        and ``project_id`` columns are passed through as returned by the
        cursor.
        """
        statement = """
            SELECT
                uid, name, description, metadata, trials, project_id
            FROM
                track.trial_groups
            WHERE
                name = %s AND
                metadata->>'user' = %s
            """
        self.cursor.execute(statement, (query['name'], query['metadata.user']))

        row = self.cursor.fetchone()
        if row is None:
            return None

        # columns come back in the order of the SELECT clause
        uid, name, description, metadata, trials, project_id = row

        return TrialGroup(
            _uid=uid,
            name=name,
            description=description,
            metadata=self.deserialize(metadata),
            trials=set(self.process_uuid_array(trials)),
            project_id=project_id,
        )
Esempio n. 11
0
import os

from track.persistence import get_protocol
from track.structure import Trial, TrialGroup, Project, Status, CustomStatus

# Two custom statuses, numbered right after Status.CreatedGroup
new = CustomStatus('new', Status.CreatedGroup.value + 1)
reserved = CustomStatus('reserved', Status.CreatedGroup.value + 2)

# Shared fixtures: one project holding one group
project = Project(name='test')
group = TrialGroup(name='MyGroup', project_id=project.uid)

# One entry per trial to create: three built-in statuses, one reserved, five new
statuses = [Status.Interrupted, Status.Broken, Status.Completed, reserved]
statuses += [new] * 5

trials = []
for idx, status in enumerate(statuses):
    # NOTE(review): the loop's `status` is not used, every trial is created
    # with Status.Interrupted -- confirm this is intended
    trial = Trial(
        project_id=project.uid,
        group_id=group.uid,
        status=Status.Interrupted,
        parameters={'batch_size': 256, 'id': idx},
    )
    trials.append(trial)

TRIAL_COUNT = len(trials)


def remove(filename):
Esempio n. 12
0
def test_cockroach_inserts():
    """Exercise inserts, lookups and metric logging through the Cockroach protocol.

    A `CockRoachDB` instance is started at ``/tmp/cockroach`` on
    ``localhost:8125``. The test then inserts a project, a group and trials,
    reads them back, queries trials by group, status and uid, and checks the
    metrics logged on a trial. The database is stopped once the checks are
    done, whether they passed or not.
    """
    from track.distributed.cockroachdb import CockRoachDB

    db = CockRoachDB(location='/tmp/cockroach', addrs='localhost:8125')
    db.start(wait=True)

    for key, value in db.properties.items():
        print(key, value)

    try:
        proto = Cockroach('cockroach://localhost:8125')

        def fetch_single(query):
            """Run `fetch_trials` and return its one and only result"""
            found = proto.fetch_trials(query)
            assert len(found) == 1
            return found[0]

        # -- project
        p1 = Project(name='test')
        proto.new_project(p1)
        p2 = proto.get_project(p1)
        print(p1)
        assert p1.name == p2.name

        # -- group
        g1 = TrialGroup(name='test', project_id=p1.uid)
        proto.new_trial_group(g1)
        g2 = proto.get_trial_group(TrialGroup(_uid=g1.uid))
        print(g1)
        assert g1.name == g2.name

        # -- trial
        trial1 = Trial(parameters={'a': 1}, project_id=p1.uid, group_id=g1.uid)
        proto.new_trial(trial1)
        trial2 = proto.get_trial(Trial(_hash=trial1.hash))

        print(trial1)
        assert len(trial2) == 1
        assert trial1.parameters == trial2[0].parameters

        # fetch by group_id
        assert fetch_single({'group_id': g1.uid}).uid == trial1.uid

        # fetch by status
        by_status = {'status': Status.CreatedGroup, 'group_id': g1.uid}
        assert fetch_single(by_status).uid == trial1.uid

        # fetch by status, using a `$in` filter on the status name
        by_status_in = {'status': {'$in': ['CreatedGroup']}, 'group_id': g1.uid}
        assert fetch_single(by_status_in).uid == trial1.uid

        # -- metrics: logged with a step, then without
        for step, value in ((2, 1), (3, 2), (4, 3)):
            proto.log_trial_metrics(trial1, step=step, epoch_loss=value)

        for value in (3, 2, 1):
            proto.log_trial_metrics(trial1, loss=value)

        metrics = fetch_single({'group_id': g1.uid}).metrics
        print(metrics)

        assert metrics == {
            'loss': [3, 2, 1],
            'epoch_loss': {
                2: 1,
                3: 2,
                4: 3
            }
        }

        # -- second trial in the same group, fetched by uid
        trial3 = Trial(parameters={'a': 2}, project_id=p1.uid, group_id=g1.uid)
        proto.new_trial(trial3)
        # the result is not used; the lookup itself is kept
        proto.get_trial(Trial(_hash=trial3.hash))

        trials = proto.fetch_trials({
            'group_id': g1.uid,
            'uid': trial3.uid
        })
        assert trials == [trial3]

    finally:
        db.stop()