예제 #1
0
 def test_explicit_run_finalize(self):
     """Check that ``end_time`` is None before ``finalize()`` and set after it."""
     tracked_run = Run(system_tracking_interval=None)
     for value in range(10):
         tracked_run.track(value, name='seq')

     # finalize() has not been called yet, so the test expects no end time.
     self.assertIsNone(tracked_run.end_time)

     tracked_run.finalize()
     self.assertIsNotNone(tracked_run.end_time)
예제 #2
0
    def test_incompatible_type_after_tracking_restart(self):
        """Check that a re-opened run rejects a value of a different type.

        A float is tracked on the 'numbers' sequence, the run is finalized and
        re-opened by its hash; tracking an int on the same sequence is then
        expected to raise ``ValueError`` with a specific message.
        """
        expected_message = (
            'Cannot log value \'1\' on sequence \'numbers\'. Incompatible data types.'
        )

        original = Run(system_tracking_interval=None)
        saved_hash = original.hash
        original.track(1., name='numbers', context={})
        original.finalize()
        # Drop this test's reference to the first Run before re-opening by hash.
        del original

        reopened = Run(run_hash=saved_hash, system_tracking_interval=None)
        with self.assertRaises(ValueError) as raised:
            reopened.track(1, name='numbers', context={})

        self.assertEqual(expected_message, raised.exception.args[0])
예제 #3
0
def fill_up_test_data():
    """Populate the default repo with 10 dummy runs for tests.

    Calls ``remove_test_data()`` first, then creates 10 runs with random
    hashes. Each run gets 'hparams', 'run_index', 'start_time' and 'name'
    entries, and 100 values are tracked for every (metric, context) pair,
    except that 'accuracy' is skipped for contexts containing a 'subset' key.
    All runs are finalized after every run has been filled.
    """
    remove_test_data()

    # put dummy data into test repo with 10 runs, tracking 2 metrics over 3 contexts
    repo = Repo.default_repo()
    run_hashes = [hex(random.getrandbits(64))[-7:] for _ in range(10)]

    metrics = ['loss', 'accuracy']
    contexts = [
        {'is_training': True, 'subset': 'train'},
        {'is_training': True, 'subset': 'val'},
        {'is_training': False},
    ]

    with repo.structured_db:
        created_runs = []
        for run_index, run_hash in enumerate(run_hashes):
            run = Run(run_hash, repo=repo, system_tracking_interval=None)
            run['hparams'] = create_run_params()
            run['run_index'] = run_index
            run['start_time'] = datetime.datetime.utcnow().isoformat()
            run['name'] = f'Run # {run_index}'
            run.name = run['name']
            created_runs.append(run)

            for metric, context in itertools.product(metrics, contexts):
                # 'accuracy' is only tracked for contexts without a 'subset' key.
                if metric == 'accuracy' and 'subset' in context:
                    continue
                # track 100 values per run
                for step in range(100):
                    value = 1.0 - 1.0 / (step + 1)
                    run.track(value, name=metric, step=step, epoch=1, context=context)

        # Finalize only after every run has been created and filled.
        for run in created_runs:
            run.finalize()
예제 #4
0
def finalize_stalled_runs(repo: 'Repo', runs: set):
    """Finalize every run in ``runs``, skipping those that time out on a lock.

    Args:
        repo: Repo passed to each ``Run`` that is opened.
        runs: Hashes of the runs to finalize.

    A run whose ``Run(...)`` construction raises ``filelock.Timeout`` is
    not finalized; its hash is collected and reported via ``click`` once the
    loop is done.
    """
    skipped_hashes = []
    progress = tqdm.tqdm(runs, desc='Finalizing stalled runs', total=len(runs))
    for run_hash in progress:
        try:
            stalled_run = Run(run_hash=run_hash, repo=repo, system_tracking_interval=None)
        except filelock.Timeout:
            skipped_hashes.append(run_hash)
            continue
        # TODO: [AT] handle lock timeout on index db (retry logic).
        stalled_run.finalize()

    if not skipped_hashes:
        return

    click.echo('Skipped indexing for the following runs in progress:')
    for run_hash in skipped_hashes:
        click.secho(f'\t\'{run_hash}\'', fg='yellow')