import functools import os import queue import numpy as np import libhpref from tartist import random, image from tartist.app import rl from tartist.core import get_env, get_logger from tartist.core.utils.cache import cached_result from tartist.core.utils.naming import get_dump_directory from tartist.data import flow from tartist.nn import opr as O, optimizer, summary logger = get_logger(__file__) __envs__ = { 'dir': { 'root': get_dump_directory(__file__), }, 'a3c': { 'env_name': 'Breakout-v0', 'input_shape': (84, 84), 'nr_history_frames': 4, 'max_nr_steps': 40000, 'gamma': 0.99, 'nr_td_steps': 5, 'nr_players': 2, 'nr_predictors': 2, 'predictor': {
# -*- coding:utf8 -*- # File : snapshot.py # Author : Jiayuan Mao # Email : [email protected] # Date : 2/26/17 # # This file is part of TensorArtist. from tartist.core import get_env, register_event, get_logger from tartist.core import io import os.path as osp import numpy as np logger = get_logger() __snapshot_dir__ = 'snapshots' __snapshot_ext__ = '.snapshot.pkl' __weights_ext__ = '.weights.pkl' def get_snapshot_dir(): return get_env('dir.snapshot', osp.join(get_env('dir.root'), __snapshot_dir__)) def enable_snapshot_saver(trainer, save_interval=1): def dump_snapshot_on_epoch_after(trainer): if trainer.epoch % save_interval != 0: return snapshot_dir = get_snapshot_dir() snapshot = trainer.dump_snapshot()