예제 #1
0
    def _get_project(self, name):
        if self.project is None:
            self.project = self.backend.get_project(Project(name=name))

            if self.project is None:
                self.project = self.backend.new_project(Project(name=name))

        assert self.project, "Project should have been found"
예제 #2
0
def test_project():
    """Round-trip a Project through ``to_json``/``from_json`` and compare."""
    group = TrialGroup(name='TG', project_id='1')
    original = Project(
        name='1',
        description='2',
        metadata={'a': 2, 'b': 3},
        groups={group},
        trials=set(),
    )

    restored = from_json(to_json(original))
    assert original == restored
예제 #3
0
 def from_json(self, obj):
     """Build a ``Project`` from its JSON dict representation.

     Scalar fields are copied from ``obj``. Each entry of ``obj['groups']``
     and ``obj['trials']`` is decoded with ``from_json``. Inside a method
     body that name resolves to the module-level function, not to this method.
     """
     fields = dict(
         _uid=obj['uid'],
         name=obj['name'],
         description=obj['description'],
         metadata=obj['metadata'],
     )
     fields['groups'] = {from_json(g) for g in obj['groups']}
     fields['trials'] = {from_json(t) for t in obj['trials']}
     return Project(**fields)
예제 #4
0
def from_json(obj: Dict[str, object]) -> object:
    """Rebuild a track structure from its JSON dict form.

    Dicts whose ``'dtype'`` is ``'project'``, ``'trial_group'`` or ``'trial'``
    are converted to ``Project``, ``TrialGroup`` or ``Trial`` respectively.
    Any other value (non-dicts, or dicts without a recognised ``dtype``) is
    returned unchanged.

    Parameters
    ----------
    obj:
        value to decode

    Returns
    -------
    The decoded structure, or ``obj`` itself when it is not a known structure.

    Raises
    ------
    KeyError
        if a recognised structure is missing one of the keys read below
    """
    if not isinstance(obj, dict):
        return obj

    dtype = obj.get('dtype')

    if dtype == 'project':
        return Project(
            _uid=obj['uid'],
            name=obj['name'],
            description=obj['description'],
            metadata=obj['metadata'],
            # nested groups and trials are decoded recursively
            groups={from_json(g) for g in obj['groups']},
            trials={from_json(t) for t in obj['trials']},
        )

    if dtype == 'trial_group':
        return TrialGroup(
            _uid=obj['uid'],
            name=obj['name'],
            description=obj['description'],
            metadata=obj['metadata'],
            # kept as stored (not decoded), unlike Project.trials above
            trials=set(obj['trials']),
            project_id=obj['project_id']
        )

    if dtype == 'trial':
        return Trial(
            _hash=obj['hash'],
            revision=obj['revision'],
            name=obj['name'],
            description=obj['description'],
            tags=obj['tags'],
            version=obj['version'],
            group_id=obj['group_id'],
            project_id=obj['project_id'],
            parameters=obj['parameters'],
            # NOTE(review): project/group metadata is used as-is, but trial
            # metadata is passed through to_json here — confirm intentional.
            metadata=to_json(obj['metadata']),
            metrics=obj['metrics'],
            chronos=obj['chronos'],
            errors=obj['errors'],
            status=status(
                name=obj['status']['name'],
                value=obj['status']['value'])
        )

    return obj
예제 #5
0
파일: client.py 프로젝트: Delaunay/track
    def set_project(self,
                    project: Optional[Project] = None,
                    force: bool = False,
                    get_only: bool = False,
                    **kwargs):
        """Set or create a new project

        The project is looked up through ``self.protocol``. If it does not
        exist it is created, unless ``get_only`` is set. The result is stored
        in ``self.project``.

        Parameters
        ----------
        project: Optional[Project]
            project definition you can use to create or set the project

        force: bool
            by default once the project is set it cannot be changed.
            use force to override this behaviour.

        get_only: bool
            if true does not insert the project if missing.
            default to false

        kwargs
            arguments used to create a `Project` object if no project object were provided
            (ignored when ``project`` is given).
            See :func:`~track.structure.Project` for possible arguments

        Returns
        -------
        The project now stored in ``self.project``: the already-set one (when
        not forced), the one found by the protocol, or the newly created one.

        Raises
        ------
        RuntimeError
            if ``get_only`` is true and the project does not exist; in that
            case ``self.project`` is left unchanged.
        """
        if self.project is not None and not force:
            info('Project is already set, to override use force=True')
            return self.project

        if project is None:
            project = Project(**kwargs)

        # NOTE: assert is skipped when Python runs with -O
        assert project.name is not None, 'Project name cannot be none'

        # does the project exist ?
        # Look it up into a local so a failed lookup does not clobber a
        # previously set project before we know whether we will raise.
        existing = self.protocol.get_project(project)

        if existing is not None:
            self.project = existing
            return self.project

        if get_only:
            raise RuntimeError(
                f'Project (name: {project.name}) was not found!')

        self.project = self.protocol.new_project(project)

        debug(f'set project to (project: {self.project.name})')
        return self.project
예제 #6
0
파일: cockroach.py 프로젝트: Delaunay/track
    def get_project(self, project: Project):
        """Fetch the stored project whose uid equals ``project.uid``.

        Returns ``None`` when ``track.projects`` has no matching row.
        Otherwise it returns a new ``Project`` built from the row:
        - the uid is decoded;
        - the metadata is deserialized;
        - the group and trial id arrays are turned into sets.
        """
        query = """
            SELECT
                uid, name, description, metadata, trial_groups, trials
            FROM
                track.projects
            WHERE
                uid = %s
            """
        self.cursor.execute(query, (project.uid, ))

        record = self.cursor.fetchone()
        if record is None:
            return None

        # Columns follow the SELECT order above.
        fields = {
            '_uid': self.decode_uid(record[0]),
            'name': record[1],
            'description': record[2],
            'metadata': self.deserialize(record[3]),
            'groups': set(self.process_uuid_array(record[4])),
            'trials': set(self.process_uuid_array(record[5])),
        }
        return Project(**fields)
예제 #7
0
def test_local_parallel(woker_count=20):
    """Check that ``_update_count`` is atomic and cannot run out of sync.

    ``woker_count`` processes each run ``increment`` against the same trial
    of a local file backend. ``_update_count`` is updated inside the backend
    lock, so it must end at ``woker_count + 1``. The ``+ 1`` is presumably the
    initial ``log_trial_metadata`` call below.

    ``count`` (and other user keys) may drift out of sync: it is fetched and
    then incremented outside the lock. That is why it is not asserted.
    """

    global trial_hash, trial_rev

    # -- Create the object that are going to be accessed in parallel
    remove('test_parallel.json')
    backend = make_local('file://test_parallel.json')

    project_def = Project(name='test')
    project = backend.new_project(project_def)

    group_def = TrialGroup(name='test_group', project_id=project.uid)
    group = backend.new_trial_group(group_def)

    trial = backend.new_trial(
        Trial(
            parameters={'batch': 256},
            project_id=project.uid,
            group_id=group.uid)
    )

    count = trial.metadata.get('count', 0)
    backend.log_trial_metadata(trial, count=count)

    # Published as module globals, presumably so `increment` (run in the
    # child processes) can locate the trial — confirm against `increment`.
    trial_hash, trial_rev = trial.uid.split('_')
    # -- Setup done

    processes = [Process(target=increment) for _ in range(woker_count)]
    print('-- Start')
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    trial = backend.get_trial(trial)[0]

    # remove('test_parallel.json')
    print(trial.metadata)
    assert trial.metadata.get('_update_count', 0) == woker_count + 1, 'Parallel write should wait for each other'
예제 #8
0
파일: socketed.py 프로젝트: bouthilx/track
    def process_args(self, args, cache=None):
        """Replace ids by their object reference so the backend modifies the objects and not a copy.

        For the keys ``'trial'``, ``'project'`` and ``'group'``, string values
        are first looked up through ``self.backend``:
        - a trial id has the form ``<hash>_<revision>``;
        - a project is looked up by name;
        - a group is looked up by uid.

        Every value under those keys is then passed through ``from_json``.
        All other keys are copied unchanged.

        Parameters
        ----------
        args: dict
            arguments to resolve
        cache:
            unused; kept for interface compatibility

        Returns
        -------
        A new dict with the resolved arguments. An unresolvable trial revision
        yields ``None`` (with a warning).
        """
        new_args = dict()

        for k, v in args.items():
            if k == 'trial':
                if isinstance(v, str):
                    hashid, rev = v.split('_')
                    rev = int(rev)

                    candidates = self.backend.get_trial(
                        Trial(_hash=hashid, revision=rev))
                    # get_trial may return several trials; keep the requested
                    # revision. Previously a miss left the whole list in `v`.
                    v = next(
                        (t for t in candidates or () if t.revision == rev),
                        None)
                    if v is None:
                        warning(
                            'Was not able to find the correct trial revision')

                v = from_json(v)

            elif k == 'project':
                if isinstance(v, str):
                    v = self.backend.get_project(Project(name=v))

                v = from_json(v)

            elif k == 'group':
                if isinstance(v, str):
                    v = self.backend.get_trial_group(TrialGroup(_uid=v))

                v = from_json(v)

            new_args[k] = v

        return new_args
예제 #9
0
import os

from track.persistence import get_protocol
from track.structure import Trial, TrialGroup, Project, Status, CustomStatus

# Custom statuses whose values sit just after Status.CreatedGroup.
new = CustomStatus('new', Status.CreatedGroup.value + 1)
reserved = CustomStatus('reserved', Status.CreatedGroup.value + 2)

project = Project(name='test')

group = TrialGroup(name='MyGroup', project_id=project.uid)

# One trial is created per entry, carrying that entry's status.
statuses = [
    Status.Interrupted, Status.Broken, Status.Completed, reserved, new, new,
    new, new, new
]
trials = []
for idx, status in enumerate(statuses):
    trial = Trial(project_id=project.uid,
                  group_id=group.uid,
                  # use this entry's status (was hard-coded to Interrupted,
                  # which left `statuses` unused)
                  status=status,
                  parameters={
                      'batch_size': 256,
                      # distinct per trial so each trial is unique
                      'id': idx
                  })
    trials.append(trial)

TRIAL_COUNT = len(trials)


def remove(filename):
예제 #10
0
def test_cockroach_inserts():
    """End-to-end check of the Cockroach protocol against a local CockroachDB.

    Starts a CockRoachDB instance on ``localhost:8125``. It then exercises
    project, trial-group and trial insertion/lookup, ``fetch_trials`` filters
    and ``log_trial_metrics``. The database is stopped even if a check fails.
    """
    from track.distributed.cockroachdb import CockRoachDB

    db = CockRoachDB(location='/tmp/cockroach', addrs='localhost:8125')
    db.start(wait=True)

    for k, v in db.properties.items():
        print(k, v)

    try:
        proto = Cockroach('cockroach://localhost:8125')

        # -- projects
        p1 = Project(name='test')
        proto.new_project(p1)
        p2 = proto.get_project(p1)
        print(p1)
        assert p1.name == p2.name

        # -- trial groups
        g1 = TrialGroup(name='test', project_id=p1.uid)
        proto.new_trial_group(g1)
        g2 = proto.get_trial_group(TrialGroup(_uid=g1.uid))
        print(g1)
        assert g1.name == g2.name

        # -- trials
        trial1 = Trial(parameters={'a': 1}, project_id=p1.uid, group_id=g1.uid)
        proto.new_trial(trial1)
        trial2 = proto.get_trial(Trial(_hash=trial1.hash))

        print(trial1)
        assert len(trial2) == 1
        assert trial1.parameters == trial2[0].parameters

        # fetch by group_id
        trials = proto.fetch_trials({'group_id': g1.uid})
        assert len(trials) == 1
        assert trials[0].uid == trial1.uid

        # fetch by status
        trials = proto.fetch_trials({'status': Status.CreatedGroup, 'group_id': g1.uid})

        assert len(trials) == 1
        assert trials[0].uid == trial1.uid

        # fetch by status using the `$in` operator
        trials = proto.fetch_trials({'status': {'$in': ['CreatedGroup']}, 'group_id': g1.uid})

        assert len(trials) == 1
        assert trials[0].uid == trial1.uid

        # metrics logged with `step` are keyed by step; the others are appended
        proto.log_trial_metrics(trial1, step=2, epoch_loss=1)
        proto.log_trial_metrics(trial1, step=3, epoch_loss=2)
        proto.log_trial_metrics(trial1, step=4, epoch_loss=3)

        proto.log_trial_metrics(trial1, loss=3)
        proto.log_trial_metrics(trial1, loss=2)
        proto.log_trial_metrics(trial1, loss=1)

        trials = proto.fetch_trials({'group_id': g1.uid})
        assert len(trials) == 1

        print(trials[0].metrics)

        assert trials[0].metrics == {
            'loss': [3, 2, 1],
            'epoch_loss': {
                2: 1,
                3: 2,
                4: 3
            }
        }

        # -- fetch by uid
        trial3 = Trial(parameters={'a': 2}, project_id=p1.uid, group_id=g1.uid)
        proto.new_trial(trial3)
        trial4 = proto.get_trial(Trial(_hash=trial3.hash))
        assert len(trial4) == 1

        trials = proto.fetch_trials({
            'group_id': g1.uid,
            'uid': trial3.uid
        })
        assert trials == [trial3]

    finally:
        db.stop()