Example #1
    def gather(self):
        self.runner.run()

        if hasattr(self.runner, 'trunc_lst_mb_sil_valid'):
            diff = len(self.runner.trunc_lst_mb_sil_valid) - np.sum(
                self.runner.trunc_lst_mb_sil_valid)
            self.processed_frames = int(diff * self.frame_skip)
            sil_frames = self.runner.trunc_lst_mb_sil_valid
        else:
            self.processed_frames = self.runner.steps_taken * self.runner.nenv * self.frame_skip
            sil_frames = np.zeros_like(self.runner.trunc_lst_mb_dones)

        local_ep_infos = self.runner.epinfos
        if hvd.size() > 1:
            ep_infos = flatten_lists(mpi.COMM_WORLD.allgather(local_ep_infos))
        else:
            ep_infos = local_ep_infos

        self.ep_info_window.extend(ep_infos)
        if len(ep_infos) >= self.log_window_size:
            self.ep_infos_to_report = ep_infos
        else:
            self.ep_infos_to_report = self.ep_info_window
        self.nb_return_goals_reached = sum(
            [ei['reached'] for ei in self.ep_infos_to_report])
        self.nb_return_goals_chosen = len(self.ep_infos_to_report)
        self.nb_exploration_goals_reached = sum([
            ei['nb_exploration_goals_reached']
            for ei in self.ep_infos_to_report
        ])
        self.nb_exploration_goals_chosen = sum([
            ei['nb_exploration_goals_chosen'] for ei in self.ep_infos_to_report
        ])
        self.reward_mean = safemean(
            [ei['r'] for ei in self.ep_infos_to_report])
        self.length_mean = safemean(
            [ei['l'] for ei in self.ep_infos_to_report])
        self.nb_of_episodes += len(ep_infos)
        self.return_goals_chosen = [ei['goal_chosen'] for ei in local_ep_infos]
        self.return_goals_reached = [ei['reached'] for ei in local_ep_infos]
        self.sub_goals = [ei['sub_goal'] for ei in local_ep_infos]
        self.ent_incs = [ei['inc_ent'] for ei in local_ep_infos]

        # We do not update the network during the warm up period
        if not self.freeze_network and not self.warm_up:
            self._train()

        return (self.runner.ar_mb_cells, self.runner.ar_mb_game_reward,
                self.runner.trunc_lst_mb_trajectory_ids,
                self.runner.trunc_lst_mb_dones, self.runner.trunc_mb_obs,
                self.runner.trunc_mb_goals, self.runner.trunc_mb_actions,
                self.runner.trunc_lst_mb_rewards, sil_frames,
                self.runner.ar_mb_ret_strat, self.runner.ar_mb_traj_index,
                self.runner.ar_mb_traj_len)
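
# A minimal sketch (not part of the original example) of the cross-worker
# episode-info gathering pattern used above, assuming mpi4py and simple
# flatten_lists/safemean helpers; the names below are illustrative only.
import numpy as np
from mpi4py import MPI


def flatten_lists(list_of_lists):
    # Concatenate the per-rank lists returned by allgather into one flat list.
    return [item for sublist in list_of_lists for item in sublist]


def safemean(xs):
    # Mean that tolerates an empty list instead of raising a warning.
    return np.nan if len(xs) == 0 else float(np.mean(xs))


def gather_ep_infos(local_ep_infos):
    comm = MPI.COMM_WORLD
    if comm.Get_size() > 1:
        # Every rank receives a list of per-rank lists; flatten it.
        return flatten_lists(comm.allgather(local_ep_infos))
    return local_ep_infos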
Example #2
    def gather(self):
        ep_infos = []
        for _ in range(self.num_steps):
            actions = [
                self.explorer.get_action(env) for env in self.env.get_envs()
            ]
            obs_and_goals, rewards, dones, infos = self.env.step(actions)
            for info in infos:
                maybe_ep_info = info.get('episode')
                if maybe_ep_info:
                    ep_infos.append(maybe_ep_info)

        if hvd.size() > 1:
            ep_infos = flatten_lists(mpi.COMM_WORLD.allgather(ep_infos))

        trajectories = [ei['trajectory'] for ei in ep_infos]

        self.ep_info_window.extend(ep_infos)
        if len(ep_infos) >= 100:
            self.ep_infos_to_report = ep_infos
        else:
            self.ep_infos_to_report = self.ep_info_window

        self.nb_return_goals_reached = sum(
            [ei['nb_return_goals_reached'] for ei in self.ep_infos_to_report])
        self.nb_return_goals_chosen = sum(
            [ei['nb_return_goals_chosen'] for ei in self.ep_infos_to_report])
        self.nb_exploration_goals_reached = sum([
            ei['nb_exploration_goals_reached']
            for ei in self.ep_infos_to_report
        ])
        self.nb_exploration_goals_chosen = sum([
            ei['nb_exploration_goals_chosen'] for ei in self.ep_infos_to_report
        ])
        self.return_goals_chosen = flatten_lists(
            [ei['return_goals_chosen'] for ei in ep_infos])
        self.return_goals_info_chosen = flatten_lists(
            [ei['return_goals_info_chosen'] for ei in ep_infos])
        self.exploration_goals_chosen = flatten_lists(
            [ei['exploration_goals_chosen'] for ei in ep_infos])
        self.return_goals_reached = flatten_lists(
            [ei['return_goals_reached'] for ei in ep_infos])
        self.exploration_goals_reached = flatten_lists(
            [ei['exploration_goals_reached'] for ei in ep_infos])
        self.restored = flatten_lists([ei['restored'] for ei in ep_infos])
        self.reward_mean = safemean(
            [ei['r'] for ei in self.ep_infos_to_report])
        self.length_mean = safemean(
            [ei['l'] for ei in self.ep_infos_to_report])
        self.nb_of_episodes += len(trajectories)

        return trajectories
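
# A minimal sketch (an assumption, not from the original source) of the
# ep_info_window bookkeeping above: keep a bounded window of recent episode
# infos and report the fresh batch only when it is large enough on its own.
from collections import deque

ep_info_window = deque(maxlen=100)


def select_infos_to_report(ep_infos):
    ep_info_window.extend(ep_infos)
    # With at least 100 fresh episodes, report just this cycle's batch;
    # otherwise fall back to the rolling window for smoother statistics.
    if len(ep_infos) >= 100:
        return ep_infos
    return list(ep_info_window)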
Example #3
    def sync_before_checkpoint(self):
        if self.sil == 'sil' or self.sil == 'replay':
            # Let everyone in the world know who has which full trajectory
            owned_by_world = self.get_traj_owners(self.archive.cell_trajectory_manager.cell_trajectories)

            requests = []
            if hvd.rank() == 0:
                # Rank 0: figure out which trajectories you are missing
                owned_by_others = []
                for traj_id in self.archive.cell_trajectory_manager.cell_trajectories:
                    if not self.archive.cell_trajectory_manager.has_full_trajectory(traj_id):
                        owned_by_others.append(traj_id)

                # Rank 0: figure out who owns those trajectories
                owners = [self.get_traj_owner(owned_by_world, other_traj_id) for other_traj_id in owned_by_others]

                # Rank 0: construct a set of requests
                requests = [(hvd.rank(), traj_id, owner) for traj_id, owner in zip(owned_by_others, owners)]

            # Exchange requests
            requests = mpi.COMM_WORLD.allgather(requests)
            requests = flatten_lists(requests)
            self.process_requests(requests)
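
# A minimal sketch (illustrative, not the original implementation) of the
# request-exchange pattern above: rank 0 lists the trajectories it is missing,
# the requests are allgather'ed so every rank sees them, and each owner can
# then serve the requests for trajectories it holds.
from mpi4py import MPI


def exchange_requests(missing_ids, owner_of):
    comm = MPI.COMM_WORLD
    requests = []
    if comm.Get_rank() == 0:
        # (requesting rank, trajectory id, owning rank) triples
        requests = [(0, traj_id, owner_of[traj_id]) for traj_id in missing_ids]
    # Every rank contributes its (possibly empty) request list and sees all of them.
    return [r for per_rank in comm.allgather(requests) for r in per_rank]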
Example #4
def _run(**kwargs):
    # Make sure that, if one worker crashes, the entire MPI job is aborted
    def handle_exception(exc_type, exc_value, exc_traceback):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        sys.stderr.flush()
        if hvd.size() > 1:
            mpi.COMM_WORLD.Abort(1)

    sys.excepthook = handle_exception

    track_memory = kwargs['trace_memory']
    disable_logging = bool(kwargs['disable_logging'])
    warm_up_cycles = kwargs['warm_up_cycles']
    log_after_warm_up = kwargs['log_after_warm_up']
    screenshot_merge = kwargs['screenshot_merge']
    clear_checkpoints = list(
        filter(None, kwargs['clear_checkpoints'].split(':')))

    if 'all' in clear_checkpoints:
        clear_checkpoints = CHECKPOINT_ABBREVIATIONS.keys()

    if track_memory:
        tracemalloc.start(25)

    kwargs = del_out_of_setup_args(kwargs)
    expl, log_par = setup(**kwargs)
    local_logger.info('setup done')

    # We only need one MPI worker to log the results
    local_logger.info('Initializing logger')
    logger = None
    traj_logger = None
    if hvd.rank() == 0 and not disable_logging:
        logger = SimpleLogger(log_par.base_path + '/log.txt')
        traj_logger = SimpleLogger(log_par.base_path + '/traj_log.txt')

    ########################
    # START THE EXPERIMENT #
    ########################

    local_logger.info('Starting experiment')
    checkpoint_tracker = CheckpointTracker(log_par, expl)
    prev_checkpoint = None
    merged_dict = {}
    sil_trajectories = []
    if screenshot_merge[0:9] == 'from_dir:':
        screen_shot_dir = screenshot_merge[9:]
    else:
        screen_shot_dir = f'{log_par.base_path}/screen_shots'

    local_logger.info('Initiate cycle')
    expl.init_cycle()
    local_logger.info('Initiating Cycle done')

    if kwargs['expl_state'] is not None:
        local_logger.info('Performing warm up cycles...')
        expl.start_warm_up()
        for i in range(warm_up_cycles):
            if hvd.rank() == 0:
                local_logger.info(f'Running warm up cycle: {i}')
            expl.run_cycle()
        expl.end_warm_up()
        checkpoint_tracker.n_iters = expl.cycles
        checkpoint_tracker.log_warmup = log_after_warm_up
        local_logger.info('Performing warm up cycles... done')

    while checkpoint_tracker.should_continue():
        # Run one iteration
        if hvd.rank() == 0:
            local_logger.info(f'Running cycle: {checkpoint_tracker.n_iters}')

        checkpoint_tracker.pre_cycle()
        expl.run_cycle()
        checkpoint_tracker.post_cycle()

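        # Rank 0 decides whether to write a checkpoint; broadcast the decision so all workers agree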
        write_checkpoint = None
        if hvd.rank() == 0:
            write_checkpoint = checkpoint_tracker.calc_write_checkpoint()
        write_checkpoint = mpi.get_comm_world().bcast(write_checkpoint, root=0)
        checkpoint_tracker.set_should_write_checkpoint(write_checkpoint)

        # Code that should be executed by all workers when a checkpoint is generated
        if checkpoint_tracker.should_write_checkpoint():
            local_logger.debug(
                f'Rank: {hvd.rank()} is exchanging screenshots for checkpoint: {expl.frames_compute}'
            )
            screenshots = expl.trajectory_gatherer.env.recursive_getattr(
                'rooms')
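            # Merge the per-worker screenshot dictionaries; on key collisions, prefer the entry whose after-threshold flag is set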
            if screenshot_merge == 'mpi':
                screenshots = flatten_lists(
                    mpi.COMM_WORLD.allgather(screenshots))
            merged_dict = {}
            for screenshot_dict in screenshots:
                for key, value in screenshot_dict.items():
                    if key not in merged_dict:
                        merged_dict[key] = value
                    else:
                        after_threshold_screenshot_taken_merged = merged_dict[key][0]
                        after_threshold_screenshot_taken_current = screenshot_dict[key][0]
                        if after_threshold_screenshot_taken_current and not after_threshold_screenshot_taken_merged:
                            merged_dict[key] = value

            if screenshot_merge == 'disk':
                for key, value in merged_dict.items():
                    filename = f'{screen_shot_dir}/{key}_{hvd.rank()}.png'
                    os.makedirs(screen_shot_dir, exist_ok=True)
                    if not os.path.isfile(filename):
                        im = Image.fromarray(value[1])
                        im.save(filename)
                        im_array = imageio.imread(filename)
                        assert (im_array == value[1]).all()

                mpi.COMM_WORLD.barrier()

            local_logger.debug('Merging SIL trajectories')
            sil_trajectories = [expl.prev_selected_traj]
            if hvd.size() > 1:
                sil_trajectories = flatten_lists(
                    mpi.COMM_WORLD.allgather(sil_trajectories))
            local_logger.debug(
                f'Rank: {hvd.rank()} is done merging trajectories for checkpoint: {expl.frames_compute}'
            )

            expl.sync_before_checkpoint()
            local_logger.debug(
                f'Rank: {hvd.rank()} is done synchronizing for checkpoint: {expl.frames_compute}'
            )

        # Code that should be executed only by the master
        if hvd.rank() == 0 and not disable_logging:
            gatherer = expl.trajectory_gatherer
            return_success_rate = -1
            if gatherer.nb_return_goals_chosen > 0:
                return_success_rate = gatherer.nb_return_goals_reached / gatherer.nb_return_goals_chosen
            exploration_success_rate = -1
            if gatherer.nb_exploration_goals_chosen > 0:
                exploration_success_rate = gatherer.nb_exploration_goals_reached / gatherer.nb_exploration_goals_chosen

            cum_success_rate = 0
            for reached in expl.archive.cells_reached_dict.values():
                success_rate = sum(reached) / len(reached)
                cum_success_rate += success_rate
            mean_success_rate = cum_success_rate / len(expl.archive.archive)

            logger.write('it', checkpoint_tracker.n_iters)
            logger.write('score', expl.archive.max_score)
            logger.write('cells', len(expl.archive.archive))
            logger.write('ret_suc', return_success_rate)
            logger.write('exp_suc', exploration_success_rate)
            logger.write('rew_mean', gatherer.reward_mean)
            logger.write('len_mean', gatherer.length_mean)
            logger.write('ep', gatherer.nb_of_episodes)
            logger.write('arch_suc', mean_success_rate)
            logger.write('cum_suc', cum_success_rate)
            logger.write('frames', expl.frames_compute)

            if len(gatherer.loss_values) > 0:
                loss_values = np.mean(gatherer.loss_values, axis=0)
                assert len(loss_values) == len(gatherer.model.loss_names)
                for (loss_value, loss_name) in zip(loss_values,
                                                   gatherer.model.loss_names):
                    logger.write(loss_name, loss_value)

            stored_frames = 0
            for traj in expl.archive.cell_trajectory_manager.full_trajectories.values():
                stored_frames += len(traj)

            logger.write('sil_frames', stored_frames)

            nb_no_score_cells = len(expl.archive.archive)
            for weight in expl.archive.cell_selector.selector_weights:
                if hasattr(weight, 'max_score_dict'):
                    nb_no_score_cells = len(weight.max_score_dict)
            logger.write('no_score_cells', nb_no_score_cells)

            cells_found_ret = 0
            cells_found_rand = 0
            cells_found_policy = 0
            for cell_key in expl.archive.archive:
                cell_info = expl.archive.archive[cell_key]
                if cell_info.ret_discovered == global_const.EXP_STRAT_NONE:
                    cells_found_ret += 1
                elif cell_info.ret_discovered == global_const.EXP_STRAT_RAND:
                    cells_found_rand += 1
                elif cell_info.ret_discovered == global_const.EXP_STRAT_POLICY:
                    cells_found_policy += 1

            logger.write('cells_found_ret', cells_found_ret)
            logger.write('cells_found_rand', cells_found_rand)
            logger.write('cells_found_policy', cells_found_policy)
            logger.flush()

            traj_manager = expl.archive.cell_trajectory_manager
            new_trajectories = sorted(
                traj_manager.new_trajectories,
                key=lambda t: traj_manager.cell_trajectories[t].frame_finished)
            for traj_id in new_trajectories:
                traj_info = traj_manager.cell_trajectories[traj_id]
                traj_logger.write('it', checkpoint_tracker.n_iters)
                traj_logger.write('frame', traj_info.frame_finished)
                traj_logger.write('exp_strat', traj_info.exp_strat)
                traj_logger.write('exp_new_cells', traj_info.exp_new_cells)
                traj_logger.write('ret_new_cells', traj_info.ret_new_cells)
                traj_logger.write('score', traj_info.score)
                traj_logger.write('total_actions', traj_info.total_actions)
                traj_logger.write('id', traj_info.id)
                traj_logger.flush()

            # Code that should be executed only by the master when a checkpoint is generated
            if checkpoint_tracker.should_write_checkpoint():
                local_logger.info(
                    f'Rank: {hvd.rank()} is writing checkpoint: {expl.frames_compute}'
                )
                filename = f'{log_par.base_path}/{expl.frames_compute:0{log_par.n_digits}}'

                # Save pictures
                if len(log_par.save_pictures) > 0:
                    if screenshot_merge == 'disk':
                        for file_name in os.listdir(screen_shot_dir):
                            if file_name.endswith('.png'):
                                room = int(file_name.split('_')[0])
                                if room not in merged_dict:
                                    screen_shot = imageio.imread(
                                        f'{screen_shot_dir}/{file_name}')
                                    merged_dict[room] = (True, screen_shot)

                    elif screenshot_merge[0:9] == 'from_dir:':
                        for file_name in os.listdir(screen_shot_dir):
                            if file_name.endswith('.png'):
                                room = int(file_name.split('.')[0])
                                if room not in merged_dict:
                                    screen_shot = imageio.imread(
                                        f'{screen_shot_dir}/{file_name}')
                                    merged_dict[room] = (True, screen_shot)

                    render_pictures(log_par, expl, filename, prev_checkpoint,
                                    merged_dict, sil_trajectories)

                # Save archive state
                if log_par.save_archive:
                    save_state(expl.get_state(), filename + ARCHIVE_POSTFIX)
                    expl.archive.cell_trajectory_manager.dump(filename +
                                                              TRAJ_POSTFIX)

                # Save model
                if log_par.save_model:
                    expl.trajectory_gatherer.save_model(filename +
                                                        MODEL_POSTFIX)

                # Clean up previous checkpoint.
                if prev_checkpoint:
                    for checkpoint_type in clear_checkpoints:
                        if checkpoint_type in CHECKPOINT_ABBREVIATIONS:
                            postfix = CHECKPOINT_ABBREVIATIONS[checkpoint_type]
                        else:
                            postfix = checkpoint_type
                        with contextlib.suppress(FileNotFoundError):
                            local_logger.debug(
                                f'Removing old checkpoint: {prev_checkpoint + postfix}'
                            )
                            os.remove(prev_checkpoint + postfix)
                prev_checkpoint = filename

                if track_memory:
                    snapshot = tracemalloc.take_snapshot()
                    display_top(snapshot)

                if PROFILER:
                    local_logger.info(
                        f'ITERATION: {checkpoint_tracker.n_iters}')
                    PROFILER.disable()
                    PROFILER.dump_stats(filename + '.stats')
                    PROFILER.enable()

    local_logger.info(f'Rank {hvd.rank()} finished experiment')
    mpi.get_comm_world().barrier()
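
# A minimal sketch (illustrative assumption) of the checkpoint-coordination
# idiom used in the loop above: rank 0 alone decides whether to write a
# checkpoint, and the decision is broadcast so all workers act in lockstep.
from mpi4py import MPI


def agree_on_checkpoint(should_write_fn):
    comm = MPI.COMM_WORLD
    write_checkpoint = None
    if comm.Get_rank() == 0:
        # Only the master evaluates the checkpointing condition.
        write_checkpoint = should_write_fn()
    # Broadcast rank 0's decision; every rank returns the same value.
    return comm.bcast(write_checkpoint, root=0)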