示例#1
0
    def _flush(self, force=False):
        """Write the stats recorder and the training manifest to disk.

        Does nothing unless ``self.write_upon_reset`` is set or ``force`` is
        true.
        """
        if not (self.write_upon_reset or force):
            return

        self.stats_recorder.flush()

        # The manifest gets a very distinctive file name, because it has to
        # be located on the filesystem again later.
        manifest_name = '{}.manifest.{}.manifest.json'.format(
            self.file_prefix, self.file_infix)
        path = os.path.join(self.directory, manifest_name)
        logger.debug('Writing training manifest file to %s', path)

        with atomic_write.atomic_write(path) as f:
            # Only base names are stored: the training directory may be moved
            # around, so paths in the manifest have to be relative to it.
            manifest = {
                'stats': os.path.basename(self.stats_recorder.path),
                'videos': [
                    (os.path.basename(video), os.path.basename(meta))
                    for video, meta in self.videos
                ],
                'env_info': self._env_info(),
            }
            json.dump(manifest, f, default=json_encode_np)
示例#2
0
    def close(self):
        """Close the recorder and write its metadata.

        Must be called manually, or else the encoder process is leaked.

        Does nothing when the recorder is not enabled. Otherwise:

        * if an encoder exists it is closed and dropped;
        * if no encoder exists (no frames were captured) the output file is
          removed, when present, and the metadata is flagged ``'empty'``;
        * if the recorder is marked broken the output file is removed, when
          present, and the metadata is flagged ``'broken'``.

        Finally ``self.write_metadata()`` is called.
        """
        if not self.enabled:
            return

        if self.encoder:
            logger.debug('Closing video encoder: path=%s', self.path)
            self.encoder.close()
            self.encoder = None
        else:
            # No frames captured. Set metadata, and remove the empty output
            # file. The file may never have been created (see the broken
            # case below), so only remove it if it is actually there;
            # otherwise os.remove would raise and the metadata would never
            # be written.
            if os.path.exists(self.path):
                os.remove(self.path)

            if self.metadata is None:
                self.metadata = {}
            self.metadata['empty'] = True

        # If broken, get rid of the output file, otherwise we'd leak it.
        if self.broken:
            logger.info(
                'Cleaning up paths for broken video recorder: path=%s metadata_path=%s',
                self.path, self.metadata_path)

            # Might have crashed before even starting the output file (or it
            # was already removed above); don't try to remove in that case.
            if os.path.exists(self.path):
                os.remove(self.path)

            if self.metadata is None:
                self.metadata = {}
            self.metadata['broken'] = True

        self.write_metadata()
示例#3
0
    def _past_limit(self):
        """Tell whether the step limit or the seconds limit has been reached.

        A limit that is ``None`` is ignored. Returns ``True`` as soon as one
        of the configured limits is less than or equal to the corresponding
        elapsed value, ``False`` otherwise.
        """
        step_cap = self._max_episode_steps
        if step_cap is not None and step_cap <= self._elapsed_steps:
            logger.debug("Env has passed the step limit defined by TimeLimit.")
            return True

        seconds_cap = self._max_episode_seconds
        if seconds_cap is not None and seconds_cap <= self._elapsed_seconds:
            logger.debug(
                "Env has passed the seconds limit defined by TimeLimit.")
            return True

        return False
示例#4
0
def update_rollout_dict(spec, rollout_dict):
    """
    Generate the rollout hashes for an environment spec and store them.

    Takes as input the environment spec for which the rollout is to be
    generated, and the existing dictionary of rollouts (keyed by ``spec.id``).
    Returns True iff the dictionary was modified.

    The dictionary is left untouched (and False is returned) when the spec is
    skipped for tests, when it is nondeterministic, when generating the
    rollout raises an exception, or when the freshly generated hashes all
    match the stored ones.
    """
    # Skip platform-dependent
    if should_skip_env_spec_for_tests(spec):
        logger.info("Skipping tests for {}".format(spec.id))
        return False

    # Skip environments that are nondeterministic
    if spec.nondeterministic:
        logger.info("Skipping tests for nondeterministic env {}".format(
            spec.id))
        return False

    logger.info("Generating rollout for {}".format(spec.id))

    try:
        observations_hash, actions_hash, rewards_hash, dones_hash = generate_rollout_hash(
            spec)
    except Exception as exc:
        # If running the env generates an exception, don't write to the
        # rollout file. Only Exception is caught, so that KeyboardInterrupt
        # and SystemExit still stop the run instead of being reported as a
        # failed rollout.
        logger.warn(
            "Exception {} thrown while generating rollout for {}. Rollout not added."
            .format(type(exc), spec.id))
        return False

    rollout = {
        'observations': observations_hash,
        'actions': actions_hash,
        'rewards': rewards_hash,
        'dones': dones_hash,
    }

    existing = rollout_dict.get(spec.id)
    if existing:
        # A stored entry that lacks one of the hashes counts as different,
        # so it gets overwritten rather than raising KeyError.
        differs = any(
            key not in existing or existing[key] != new_hash
            for key, new_hash in rollout.items())
        if not differs:
            logger.debug("Hashes match with existing for {}".format(spec.id))
            return False
        logger.warn("Got new hash for {}. Overwriting.".format(spec.id))

    rollout_dict[spec.id] = rollout
    return True
示例#5
0
def load_results(training_dir):
    """Collect everything recorded in the manifests of ``training_dir``.

    Returns ``None`` (after logging an error) when the directory does not
    exist or contains no manifests. Otherwise returns a dict with the
    manifest paths, the collapsed env info, the merged statistics and the
    list of ``(video_path, metadata_path)`` pairs, with the manifest-relative
    file names joined back onto ``training_dir``.
    """
    if not os.path.exists(training_dir):
        logger.error('Training directory %s not found', training_dir)
        return

    manifests = detect_training_manifests(training_dir)
    if not manifests:
        logger.error('No manifests found in training directory %s',
                     training_dir)
        return

    logger.debug('Uploading data from manifest %s', ', '.join(manifests))

    # Gather stats files, video files and env infos from every manifest.
    stats_files, videos, env_infos = [], [], []
    for manifest_path in manifests:
        with open(manifest_path) as manifest_file:
            contents = json.load(manifest_file)
            # Manifests store names relative to the training directory;
            # join them back onto it.
            stats_files.append(os.path.join(training_dir, contents['stats']))
            videos.extend([
                (os.path.join(training_dir, video_name),
                 os.path.join(training_dir, meta_name))
                for video_name, meta_name in contents['videos']
            ])
            env_infos.append(contents['env_info'])

    env_info = collapse_env_infos(env_infos, training_dir)
    merged = merge_stats_files(stats_files)
    (data_sources, initial_reset_timestamps, timestamps, episode_lengths,
     episode_rewards, episode_types, initial_reset_timestamp) = merged

    results = {
        'manifests': manifests,
        'env_info': env_info,
        'data_sources': data_sources,
        'timestamps': timestamps,
        'episode_lengths': episode_lengths,
        'episode_rewards': episode_rewards,
        'episode_types': episode_types,
        'initial_reset_timestamps': initial_reset_timestamps,
        'initial_reset_timestamp': initial_reset_timestamp,
        'videos': videos,
    }
    return results
示例#6
0
def add_new_rollouts(spec_ids, overwrite):
    """Generate rollouts for registered environments and update ROLLOUT_FILE.

    Only specs with an entry point are considered. When ``spec_ids`` is
    non-empty the specs are further restricted to those ids, and every
    requested id must be found. Specs already present in the rollout file
    are skipped unless ``overwrite`` is set. The file is rewritten only if
    at least one rollout was changed.
    """
    environments = []
    for spec in envs.registry.all():
        if spec._entry_point is not None:
            environments.append(spec)

    if spec_ids:
        environments = [spec for spec in environments if spec.id in spec_ids]
        assert len(environments) == len(spec_ids), "Some specs not found"

    with open(ROLLOUT_FILE) as data_file:
        rollout_dict = json.load(data_file)

    modified = False
    for spec in environments:
        if overwrite or spec.id not in rollout_dict:
            if update_rollout_dict(spec, rollout_dict):
                modified = True
        else:
            logger.debug("Rollout already exists for {}. Skipping.".format(
                spec.id))

    if not modified:
        logger.info("No modifications needed.")
        return

    logger.info("Writing new rollout file to {}".format(ROLLOUT_FILE))
    with open(ROLLOUT_FILE, "w") as outfile:
        json.dump(rollout_dict, outfile, indent=2, sort_keys=True)
示例#7
0
    def start(self):
        """Build the encoder command line and launch it as a subprocess.

        The command is stored as a tuple in ``self.cmdline`` and the running
        process in ``self.proc``; frames are fed to it through its stdin
        pipe.
        """
        general_args = (
            self.backend,
            '-nostats',
            '-loglevel', 'error',  # suppress warnings
            '-y',
            '-r', '%d' % self.frames_per_sec,
        )
        # input
        input_args = (
            '-f', 'rawvideo',
            '-s:v', '{}x{}'.format(*self.wh),
            '-pix_fmt', ('rgb32' if self.includes_alpha else 'rgb24'),
            # this used to be /dev/stdin, which is not Windows-friendly
            '-i', '-',
        )
        # output
        output_args = (
            '-vf', 'scale=trunc(iw/2)*2:trunc(ih/2)*2',
            '-vcodec', 'libx264',
            '-pix_fmt', 'yuv420p',
            self.output_path,
        )
        self.cmdline = general_args + input_args + output_args

        logger.debug('Starting ffmpeg with "%s"', ' '.join(self.cmdline))

        popen_kwargs = {'stdin': subprocess.PIPE}
        if hasattr(os, 'setsid'):
            # setsid is not present on Windows, so only use it when available.
            popen_kwargs['preexec_fn'] = os.setsid
        self.proc = subprocess.Popen(self.cmdline, **popen_kwargs)
示例#8
0
    def capture_frame(self):
        """Render ``self.env`` and add the resulting frame to the video.

        Does nothing when the recorder is not functional. A ``None`` frame is
        silently ignored in async mode; otherwise it marks the recorder as
        broken instead of raising.
        """
        if not self.functional:
            return
        logger.debug('Capturing video frame: path=%s', self.path)

        if self.ansi_mode:
            render_mode = 'ansi'
        else:
            render_mode = 'rgb_array'
        frame = self.env.render(mode=render_mode)

        if frame is not None:
            self.last_frame = frame
            if self.ansi_mode:
                encode = self._encode_ansi_frame
            else:
                encode = self._encode_image_frame
            encode(frame)
            return

        if self._async:
            return

        # Indicates a bug in the environment: don't want to raise an error
        # here, so the recorder is flagged as broken instead.
        logger.warn(
            'Env returned None on render(). Disabling further rendering for video recorder by marking as disabled: path=%s metadata_path=%s',
            self.path, self.metadata_path)
        self.broken = True