def __init__(self, env_name, out_file):
    """Initialize.

    Creates an OpenCV video writer whose frame size is taken from the
    observation space of the named environment, and registers the
    writer's release with ``tools.QuitWithResources``.

    :param env_name: name of an Atari environment.
    :param out_file: output file.
    """
    # Super
    Callback.__init__(self)

    # The environment is instantiated only to read the frame shape.
    # Close it as soon as the shape is known so it is not left open
    # for the rest of the run.
    probe_env = gym.make(env_name)
    try:
        frame_size = probe_env.observation_space.shape[:2]
    finally:
        probe_env.close()

    # Video Writer
    fourcc = cv.VideoWriter_fourcc(*"HFYU")
    fps = 20.0
    self._writer = cv.VideoWriter(
        filename=out_file,
        fourcc=fourcc,
        fps=fps,
        # The first two shape entries are swapped here; presumably the
        # observation is (height, width, ...) while the writer expects
        # (width, height) -- confirm against the environment used.
        frameSize=(frame_size[1], frame_size[0]),
        isColor=True,
    )

    # Close properly
    tools.QuitWithResources.add("VideoWriter", lambda: self._writer.release())
def __init__(self, logdir, agent):
    """Initialize.

    :param logdir: directory of tensorboard logs
    :param agent: a QAgentDef instance.
    """
    # Super
    Callback.__init__(self)

    # Keep a handle on the agent
    self._agent = agent

    # Per-episode accumulators, one for scalars and one for histograms.
    # Each maps {episode: data}, where data is itself a dict
    # {metric_name: episode_values} filled in during that episode.
    self._episode_scalars = {}
    self._episode_hists = {}

    # Names of the values reported after each step, split by how they
    # are visualized (scalar vs. histogram)
    self._kerasrl_step_scalars = ["reward", "metrics"]
    self._kerasrl_step_hists = ["action"]

    # Same split for the values reported after each episode
    self._kerasrl_episode_scalars = ["episode_reward", "nb_episode_steps"]
    self._kerasrl_episode_hists = []

    # Tf writer
    self.summary_writer = tf.summary.create_file_writer(logdir)
def __init__(self, policy):
    """Initialize.

    :param policy: object exposing ``_min_eps`` and ``_max_eps``; both
        values are copied onto this callback and the policy is kept.
    """
    # Super
    Callback.__init__(self)

    # Snapshot the epsilon bounds as they are at construction time
    lower_eps = policy._min_eps
    upper_eps = policy._max_eps

    # Store
    self._min_eps = lower_eps
    self._max_eps = upper_eps
    self._policy = policy
def __init__(self, env_name, skip_frames=None, port=None):
    """Initialize.

    :param env_name: name of an Atari environment.
    :param skip_frames: skip a random number of frames in
        [0, skip_frames].
    :param port: if given, overrides the default port.
    :raises ValueError: if skip_frames is given and is not positive.
    """
    # Super
    Callback.__init__(self)

    # Validate: None disables the check, anything else must be > 0
    bad_skip = skip_frames is not None and skip_frames <= 0
    if bad_skip:
        raise ValueError("skip_frames must be positive")

    # Sender is built first; the bookkeeping fields follow
    self.sender = AtariFramesSender(env_name, port=port)
    self.skip_frames = skip_frames
    self._skips_left = 0
    self._last_frame = None
def __init__(self, agent, path, interval):
    """Initialize.

    :param agent: a keras-rl agent
    :param path: directory of checkpoints
    :param interval: save frequency in number of steps
    """
    # Super
    Callback.__init__(self)

    # Store
    self.agent = agent
    self.interval = interval

    # All counters start from zero
    self.init_step = self.step = self.episode = 0

    # The counters file sits one level above the checkpoints directory
    parent_dir = os.path.join(path, os.path.pardir)
    self.counters_file = os.path.join(parent_dir, "counters.json")

    # Checkpoint path template; "{step}" is left as a placeholder here.
    # save_format is not set in this method -- presumably a class
    # attribute; confirm on the enclosing class.
    weights_name = "weights_{step}." + self.save_format
    self.step_checkpoints = os.path.join(path, weights_name)
def __init__(self):
    """Initialize an empty ``action_dict_list`` and set ``ep_ctr`` to 0."""
    # Super
    Callback.__init__(self)

    # Despite the "list" in its name, this starts out as an empty dict
    self.action_dict_list = {}
    self.ep_ctr = 0