コード例 #1
0
ファイル: result_viewer.py プロジェクト: zaelot11/badger-2019
    def update_experiment(self):
        """Load the experiment whose id is typed into the id widget and refresh the view.

        Parses ``self.widget_text_experiment_id.value`` as an int and opens a
        ``SacredReader`` for it. Stores the reader and its epochs in ``self.state`` and
        fills the epoch selector, selecting the last epoch. Renders the experiment config
        into ``self.widget_config_div``, rebuilds the loss figure and finally calls
        ``self._update()``.

        A ``ValueError`` (a non-integer id, or an experiment with no stored epochs) is
        printed and also written into the config div. This keeps the div from staying
        stuck on the ``'loading'`` placeholder.
        """
        import html  # escapes the error text before it is shown in the HTML div

        print('Update experiment')
        try:
            self.widget_config_div.text = 'loading'
            self.state.experiment_id = int(self.widget_text_experiment_id.value)
            sr = SacredReader(self.state.experiment_id, self._sacred_config)
            self.state.sacred_reader = sr

            epochs = sr.get_epochs()
            self.state.epochs = epochs
            if not epochs:
                # Guard before indexing epochs[-1]: an empty list would raise IndexError,
                # which the ValueError handler below would not catch.
                raise ValueError(
                    f'experiment {self.state.experiment_id} has no stored epochs')
            self.widget_button.label = f'Experiment Id: {self.state.experiment_id}'
            self.widget_epoch_select.options = list(map(str, epochs))
            self.widget_epoch_select.value = str(epochs[-1])

            formatted_config = '<br/>'.join([f'{k}: {v}' for k, v in sr.config.items()])
            self.widget_config_div.text = f'<pre>{formatted_config}</pre>'
            # update epochs figure
            self.widget_loss_pane.children.clear()
            self.widget_loss_pane.children.append(
                self._create_loss_figure(self._sacred_utils.load_metrics([self.state.experiment_id])))

            self._update()
        except ValueError as e:
            print(f'Error: {e}')
            # Replace the 'loading' placeholder so the failure is visible in the view.
            self.widget_config_div.text = f'Error: {html.escape(str(e))}'
コード例 #2
0
    def update_experiment(self):
        """Load the experiment whose id is typed into the id widget and refresh every view.

        Steps:

        * Parse ``self.widget_text_experiment_id.value`` as an int and open a
          ``SacredReader`` for it.
        * Fill the epoch selector, selecting the last epoch.
        * Render the experiment config into ``self.widget_config_div``.
        * Copy ``rollout_size``, ``n_experts`` and ``task_size`` from the config into the
          inference-parameter text widgets.
        * Refresh the loss plot and call ``self._update()``.

        Any exception is caught and reported rather than propagated. It is printed and also
        shown in the config div, so the div does not stay on ``'loading'``.
        """
        import html  # escapes the error text before it is shown in the HTML div

        print('Update experiment')
        try:
            # parse experiment id and load experiment from sacred
            self.widget_config_div.text = 'loading'
            self.state.experiment_id = int(
                self.widget_text_experiment_id.value)
            sr = SacredReader(self.state.experiment_id, self._sacred_config)
            self.state.sacred_reader = sr

            # update epochs select
            epochs = sr.get_epochs()
            self.state.epochs = epochs
            if not epochs:
                # Descriptive message instead of the opaque IndexError from epochs[-1].
                raise ValueError(
                    f'experiment {self.state.experiment_id} has no stored epochs')
            self.widget_button.label = f'Experiment Id: {self.state.experiment_id}'
            self.widget_epoch_select.options = list(map(str, epochs))
            self.widget_epoch_select.value = str(epochs[-1])

            # update config
            formatted_config = '<br/>'.join(
                [f'{k}: {v}' for k, v in sr.config.items()])
            self.widget_config_div.text = f'<pre>{formatted_config}</pre>'

            # update inference params
            params = Params(**sr.config)
            self.widget_text_n_rollouts.value = str(params.rollout_size)
            self.widget_text_n_experts.value = str(params.n_experts)
            self.widget_text_n_inputs.value = str(params.task_size)

            self.update_loss_plot()
            self._update()
        except Exception as e:
            print(f'Error: {e}')
            # Replace the 'loading' placeholder so the failure is visible in the view.
            self.widget_config_div.text = f'Error: {html.escape(str(e))}'
コード例 #3
0
def render(exp_id: int, gen: Optional[int], sleep: float, num_episodes: Optional[int],
           max_ep_length: Optional[int]):
    """Download a config and policy from Sacred and run the policy with rendering enabled.

    Args:
        exp_id: id of the Sacred experiment to load.
        gen: epoch whose policy parameters are loaded; ``None`` selects the last stored
            epoch (``reader.find_last_epoch()``).
        sleep: forwarded to ``simulate`` as ``sleep_render``.
        num_episodes: number of episodes to simulate; ``None`` uses ``config.num_episodes``.
        max_ep_length: maximum episode length; ``None`` uses ``config.max_ep_length``.

    Prints the fitness and number of steps returned by ``simulate``; returns ``None``.
    """
    # init the reader, with data_dir set to the current working directory
    reader = SacredReader(exp_id, get_sacred_storage(), data_dir=Path.cwd())

    # obtain the config; explicitly passed arguments override the stored values
    config = Namespace(**reader.config)
    num_episodes = num_episodes if num_episodes is not None else config.num_episodes
    max_ep_length = max_ep_length if max_ep_length is not None else config.max_ep_length
    # -1 is passed when the config stores no seed
    # NOTE(review): how -1 is interpreted is up to build_policy_env — confirm.
    env_seed = config.env_seed if config.env_seed is not None else -1

    policy, env = build_policy_env(config, env_seed)

    # deserialize the model parameters (default: the most recent epoch)
    if gen is None:
        gen = reader.find_last_epoch()

    print(f'Deserialization from the epoch: {gen}')
    # NOTE(review): fixed 2 s pause, presumably so the message above can be read
    # before rendering starts — confirm.
    time.sleep(2)
    policy.load(reader=reader, epoch=gen)

    fitness, num_steps_used = simulate(env=env,
                                       policy=policy,
                                       num_episodes=num_episodes,
                                       max_ep_length=max_ep_length,
                                       render=True,
                                       sleep_render=sleep)

    print(f'\n\n Done, fitness is: {fitness}, num_steps: {num_steps_used}\n\n')
コード例 #4
0
def try_deserialize(config: Namespace, policy: Policy):
    """Load policy parameters from an earlier experiment when the config asks for it.

    When ``config.load_from`` is ``None`` this is a no-op. Otherwise ``load_from`` is used
    as a Sacred experiment id. A ``SacredReader`` is opened on it (with ``data_dir`` set to
    the current working directory), and the policy is loaded from that reader's last
    epoch.
    """
    source_exp_id = config.load_from
    if source_exp_id is None:
        return

    print(f'DESERIALIZATION: will load policy from the exp-id {source_exp_id}')
    # imported only when deserialization is actually requested
    from badger_utils.sacred import SacredReader

    source_reader = SacredReader(source_exp_id, get_sacred_storage(), data_dir=Path.cwd())
    newest_epoch = source_reader.find_last_epoch()
    policy.load(reader=source_reader, epoch=newest_epoch)
コード例 #5
0
 def get_reader(self, experiment_id: int, data_dir: Optional[Path] = None) -> SacredReader:
     """Create a SacredReader for one experiment using this object's Sacred config.

     Args:
         experiment_id: id of the experiment to be loaded
         data_dir: optional directory for caching the data from sacred (will append data/loaded_from_sacred)

     Returns:
         A new ``SacredReader`` for ``experiment_id`` built with ``self._config``.
     """
     sacred_config = self._config
     reader = SacredReader(experiment_id, sacred_config, data_dir)
     return reader
コード例 #6
0
 def load(self, reader: SacredReader, epoch: int):
     """Restore this model's stored state from ``reader`` at ``epoch``.

     The stored model is looked up under ``self.name``.
     """
     stored_name = self.name
     reader.load_model(model=self, name=stored_name, epoch=epoch)
コード例 #7
0
import torch
from torch import Tensor
import numpy as np
from badger_utils.sacred import SacredReader, SacredConfigFactory
from bokeh.plotting import figure, show, output_notebook

from badger_utils.view.observer_utils import Observer, MultiObserver
from attention.learning_loop import LearningLoop
from attention.search_experiment import Params, load_agent, create_agent, create_task

# Evaluate a stored agent from a Sacred experiment and report its mean error.
experiment_id, epoch = 1923, 100000

sr = SacredReader(experiment_id, SacredConfigFactory.local())
p = Params(**sr.config)

# Override the stored experiment parameters for this evaluation run.
p.batch_size = 1
p.task_size = 5
p.n_experts = 1

# Gradient tracking is disabled for everything in this block.
with torch.no_grad():
    agent = create_agent(p)
    task = create_task(p)
    sr.load_model(agent, 'agent', epoch)
    inner_loop = LearningLoop()

    observer = MultiObserver()
    agent.init_rollout(p.batch_size, p.n_experts)
    err = inner_loop.train_fixed_steps(agent, task, p.rollout_size, observer)
    err = torch.mean(err)

# Report the result; previously it was computed and then discarded.
print(f'Mean error: {err.item()}')
コード例 #8
0
def load_agent(experiment_id, epoch, agent):
    """Load the model stored as ``'agent'`` at ``epoch`` into ``agent``.

    The reader is built with ``SacredConfigFactory.local()`` for ``experiment_id``.
    """
    SacredReader(experiment_id, SacredConfigFactory.local()).load_model(agent, 'agent', epoch)
コード例 #9
0
            {'filename': f'{self._artifact_prefix(run_id)}{filename}'})[0]
        return file.read()


# Inspect the artifacts stored in GridFS for one Sacred run.
# NOTE(review): `observer` is defined outside this view — presumably the Sacred
# Mongo observer whose GridFS handle (`observer.fs`) holds the artifacts; confirm.
reader = GridFSReader(observer.fs)
run_id = 83

files = reader.list_artifacts(run_id)
print(f'Files: \n{files}')
# Deserialize one stored model artifact from an in-memory byte buffer.
# NOTE(review): torch.load unpickles these bytes — only use it on artifacts from a
# trusted database.
item = torch.load(
    io.BytesIO(reader.read_artifact(run_id, 'agent_ep_1000.model')))
print(type(item))

# %%
# Open experiment 1 through SacredConfigFactory.local() and print its config.
experiment_id = 1
sr = SacredReader(experiment_id, SacredConfigFactory.local())
print(sr.config)

# Bare expression: displays `sr.run_data` when this cell is run interactively;
# it has no effect when the file is executed as a plain script.
sr.run_data


# %%
def analyze_runs(observer, run_ids: List[int]):
    runs = observer.runs.find({'_id': {
        '$in': run_ids
    }}, {
        '_id': 1,
        'config': 1
    })
    items = list(runs)
    common_keys = find_common_keys(items, lambda x: x['config'])