Example #1
def test_log_hyperparams_verbose(mock_api_post_block, capsys):
    # With the API call mocked out, logging should print a confirmation message
    data = {'arch_name': 'cnn_1', 'lr': .001}
    log_hyperparams(data)
    captured = capsys.readouterr()

    expected_result = '[jovian] Hyperparams logged.'
    assert expected_result == captured.out.strip()
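For context, a direct call in user code looks like the sketch below (assuming jovian is installed and configured; the confirmation message is the one this test asserts on, and verbose is the flag used in Example #4):

import jovian

# Logs the hyperparameters as a record attached to the current notebook/script.
# Prints '[jovian] Hyperparams logged.' unless verbose=False is passed.
jovian.log_hyperparams({'arch_name': 'cnn_1', 'lr': 0.001})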
Example #2
    def test_log_hyperparams(self, mock_api_post_block):
        # The first four records are pre-seeded test data; logging should append
        # the new ('fake_slug_3', 'hyperparams', data) tuple to _data_blocks
        data = {'arch_name': 'cnn_1', 'lr': .001}
        expected_result = [('fake_slug_metrics_1', 'metrics', {}),
                           ('fake_slug_metrics_2', 'metrics', {}),
                           ('fake_slug_hyperparams_1', 'hyperparams', {}),
                           ('fake_slug_hyperparams_2', 'hyperparams', {}),
                           ('fake_slug_3', 'hyperparams', data)]

        log_hyperparams(data)
        self.assertEqual(jovian.utils.records._data_blocks, expected_result)
Example #3
    def on_train_begin(self, n_epochs: int, metrics_names: list, **ka):
        # Collect training hyperparameters from the fastai Learner and log them
        # once at the start of training
        if self.reset_tracking:
            reset('hyperparams', 'metrics')
        hyp_dict = {
            'epochs': n_epochs,
            'batch_size': self.learn.data.batch_size,
            'loss_func': str(self.learn.loss_func.func),
            'opt_func': str(self.learn.opt_func.func).split("'")[1],
            'weight_decay': self.learn.wd,
            'learning_rate': str(self.learn.opt.lr)
        }
        if self.arch_name:
            hyp_dict['arch_name'] = self.arch_name
        log_hyperparams(hyp_dict)

        if self.valid_set:
            self.met_names.extend(metrics_names)
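A hypothetical usage sketch for this fastai callback follows. The module path jovian.callbacks.fastai.JovianFastaiCallback and the constructor parameters (learn, arch_name, reset_tracking) are inferred from the attributes used above rather than confirmed from the source; learn is assumed to be an existing fastai Learner.

from jovian.callbacks.fastai import JovianFastaiCallback

# Attach the callback so hyperparams (and later metrics) are logged automatically
jvn_cb = JovianFastaiCallback(learn, arch_name='resnet18')
learn.fit_one_cycle(2, callbacks=[jvn_cb])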
Example #4
    def on_train_begin(self, logs=None):
        # Reset state if required
        if self.reset_tracking:
            reset('hyperparams', 'metrics')

        hyp_dict = {
            'epochs': self.params['epochs'],
            'batch_size': self.params['batch_size'],
            'loss_func': self.model.loss,
            'opt': str(self.model.optimizer.__class__).split("'")[1].split('.')[-1],
            'wt_decay': self.model.optimizer.initial_decay,
            'lr': str(get_value(self.model.optimizer.lr))
        }
        if self.arch_name:
            hyp_dict['arch'] = self.arch_name
        log_hyperparams(hyp_dict, verbose=False)
        self.hyperparams = hyp_dict
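Similarly, a hypothetical usage sketch for the Keras callback above. The module path jovian.callbacks.keras.JovianKerasCallback and its keyword arguments are assumptions inferred from the attributes referenced in the code (arch_name, reset_tracking), and the toy model and random data are placeholders.

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from jovian.callbacks.keras import JovianKerasCallback

# Minimal model and random data, just enough to trigger on_train_begin
model = Sequential([Dense(1, input_shape=(4,))])
model.compile(optimizer='adam', loss='mse')

x, y = np.random.rand(32, 4), np.random.rand(32, 1)
model.fit(x, y, epochs=2, batch_size=8,
          callbacks=[JovianKerasCallback(arch_name='dense_1')])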