def gather(self):
    """Run one rollout, aggregate episode statistics across MPI workers,
    and (unless the network is frozen or warming up) run a training step.

    Returns a tuple of the runner's rollout buffers (cells, game rewards,
    trajectory ids, dones, observations, goals, actions, rewards,
    SIL-frame flags, return strategies, trajectory indices and lengths)
    for consumption by the caller.
    """
    self.runner.run()
    if hasattr(self.runner, 'trunc_lst_mb_sil_valid'):
        # SIL-enabled runner: frames flagged sil-valid are excluded from
        # the processed-frame count; only the remainder counts as fresh.
        diff = len(self.runner.trunc_lst_mb_sil_valid) - np.sum(
            self.runner.trunc_lst_mb_sil_valid)
        self.processed_frames = int(diff * self.frame_skip)
        sil_frames = self.runner.trunc_lst_mb_sil_valid
    else:
        # No SIL information available: every step of every env counts.
        self.processed_frames = self.runner.steps_taken * self.runner.nenv * self.frame_skip
        sil_frames = np.zeros_like(self.runner.trunc_lst_mb_dones)
    local_ep_infos = self.runner.epinfos
    if hvd.size() > 1:
        # Share episode infos with every worker so statistics are global.
        ep_infos = flatten_lists(mpi.COMM_WORLD.allgather(local_ep_infos))
    else:
        ep_infos = local_ep_infos
    self.ep_info_window.extend(ep_infos)
    # Report on the fresh batch when it is large enough; otherwise fall
    # back to the sliding window to smooth the statistics.
    if len(ep_infos) >= self.log_window_size:
        self.ep_infos_to_report = ep_infos
    else:
        self.ep_infos_to_report = self.ep_info_window
    self.nb_return_goals_reached = sum(
        [ei['reached'] for ei in self.ep_infos_to_report])
    self.nb_return_goals_chosen = len(self.ep_infos_to_report)
    self.nb_exploration_goals_reached = sum([
        ei['nb_exploration_goals_reached']
        for ei in self.ep_infos_to_report
    ])
    self.nb_exploration_goals_chosen = sum([
        ei['nb_exploration_goals_chosen']
        for ei in self.ep_infos_to_report
    ])
    self.reward_mean = safemean(
        [ei['r'] for ei in self.ep_infos_to_report])
    self.length_mean = safemean(
        [ei['l'] for ei in self.ep_infos_to_report])
    self.nb_of_episodes += len(ep_infos)
    # Per-worker (local) goal bookkeeping — intentionally NOT aggregated
    # across MPI, unlike the statistics above.
    self.return_goals_chosen = [ei['goal_chosen'] for ei in local_ep_infos]
    self.return_goals_reached = [ei['reached'] for ei in local_ep_infos]
    self.sub_goals = [ei['sub_goal'] for ei in local_ep_infos]
    self.ent_incs = [ei['inc_ent'] for ei in local_ep_infos]
    # We do not update the network during the warm up period
    if not self.freeze_network and not self.warm_up:
        self._train()
    return (self.runner.ar_mb_cells, self.runner.ar_mb_game_reward,
            self.runner.trunc_lst_mb_trajectory_ids,
            self.runner.trunc_lst_mb_dones, self.runner.trunc_mb_obs,
            self.runner.trunc_mb_goals, self.runner.trunc_mb_actions,
            self.runner.trunc_lst_mb_rewards, sil_frames,
            self.runner.ar_mb_ret_strat,
            self.runner.ar_mb_traj_index, self.runner.ar_mb_traj_len)
def gather(self):
    """Step the environments with the explorer's actions for
    ``self.num_steps`` steps, collect finished-episode infos, aggregate
    goal statistics across MPI workers, and return the trajectories.
    """
    episode_infos = []
    for _ in range(self.num_steps):
        chosen_actions = [
            self.explorer.get_action(env) for env in self.env.get_envs()
        ]
        _obs_and_goals, _rewards, _dones, step_infos = self.env.step(
            chosen_actions)
        # An 'episode' entry shows up in an info dict only when that
        # environment has just finished an episode.
        episode_infos.extend(
            info['episode'] for info in step_infos if info.get('episode'))
    if hvd.size() > 1:
        # Make every worker see the episodes of all workers.
        episode_infos = flatten_lists(
            mpi.COMM_WORLD.allgather(episode_infos))
    trajectories = [ei['trajectory'] for ei in episode_infos]
    self.ep_info_window.extend(episode_infos)
    # Prefer the fresh batch when it is large enough to be representative;
    # otherwise report over the sliding window.
    self.ep_infos_to_report = (episode_infos if len(episode_infos) >= 100
                               else self.ep_info_window)
    report = self.ep_infos_to_report
    self.nb_return_goals_reached = sum(
        ei['nb_return_goals_reached'] for ei in report)
    self.nb_return_goals_chosen = sum(
        ei['nb_return_goals_chosen'] for ei in report)
    self.nb_exploration_goals_reached = sum(
        ei['nb_exploration_goals_reached'] for ei in report)
    self.nb_exploration_goals_chosen = sum(
        ei['nb_exploration_goals_chosen'] for ei in report)
    # Per-episode goal lists are flattened over the fresh batch only.
    self.return_goals_chosen = flatten_lists(
        [ei['return_goals_chosen'] for ei in episode_infos])
    self.return_goals_info_chosen = flatten_lists(
        [ei['return_goals_info_chosen'] for ei in episode_infos])
    self.exploration_goals_chosen = flatten_lists(
        [ei['exploration_goals_chosen'] for ei in episode_infos])
    self.return_goals_reached = flatten_lists(
        [ei['return_goals_reached'] for ei in episode_infos])
    self.exploration_goals_reached = flatten_lists(
        [ei['exploration_goals_reached'] for ei in episode_infos])
    self.restored = flatten_lists(
        [ei['restored'] for ei in episode_infos])
    self.reward_mean = safemean([ei['r'] for ei in report])
    self.length_mean = safemean([ei['l'] for ei in report])
    self.nb_of_episodes += len(trajectories)
    return trajectories
def sync_before_checkpoint(self):
    """Before a checkpoint, ensure rank 0 holds every full trajectory:
    rank 0 builds requests for the trajectories it is missing, all ranks
    exchange the request lists, and every rank processes them.
    """
    # Only relevant when self-imitation / replay trajectories are kept.
    if self.sil != 'sil' and self.sil != 'replay':
        return
    traj_manager = self.archive.cell_trajectory_manager
    # Let everyone in the world know who has which full trajectory
    world_ownership = self.get_traj_owners(traj_manager.cell_trajectories)
    my_requests = []
    if hvd.rank() == 0:
        # Rank 0: list the trajectories it does not hold in full, look up
        # their owners, and build one (requester, traj, owner) request
        # per missing trajectory.
        missing = [
            traj_id for traj_id in traj_manager.cell_trajectories
            if not traj_manager.has_full_trajectory(traj_id)
        ]
        my_requests = [(hvd.rank(), traj_id,
                        self.get_traj_owner(world_ownership, traj_id))
                       for traj_id in missing]
    # Exchange requests so every rank can serve the ones addressed to it.
    all_requests = flatten_lists(mpi.COMM_WORLD.allgather(my_requests))
    self.process_requests(all_requests)
def _run(**kwargs):
    """Run the full experiment: setup, optional warm-up cycles, then the
    main loop of exploration cycles with per-cycle logging and periodic
    checkpointing, synchronized across all MPI/Horovod workers.

    NOTE(review): indentation was reconstructed from a flattened source;
    the nesting of the checkpoint-writing section under the rank-0
    logging block follows the in-code comments — confirm against the
    original layout.
    """
    # Make sure that, if one worker crashes, the entire MPI process is aborted
    def handle_exception(exc_type, exc_value, exc_traceback):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        sys.stderr.flush()
        if hvd.size() > 1:
            mpi.COMM_WORLD.Abort(1)
    sys.excepthook = handle_exception

    track_memory = kwargs['trace_memory']
    disable_logging = bool(kwargs['disable_logging'])
    warm_up_cycles = kwargs['warm_up_cycles']
    log_after_warm_up = kwargs['log_after_warm_up']
    screenshot_merge = kwargs['screenshot_merge']
    # ':'-separated list of checkpoint types to delete once superseded.
    clear_checkpoints = list(
        filter(None, kwargs['clear_checkpoints'].split(':')))
    if 'all' in clear_checkpoints:
        clear_checkpoints = CHECKPOINT_ABBREVIATIONS.keys()

    if track_memory:
        tracemalloc.start(25)

    kwargs = del_out_of_setup_args(kwargs)
    expl, log_par = setup(**kwargs)
    local_logger.info('setup done')

    # We only need one MPI worker to log the results
    local_logger.info('Initializing logger')
    logger = None
    traj_logger = None
    if hvd.rank() == 0 and not disable_logging:
        logger = SimpleLogger(log_par.base_path + '/log.txt')
        traj_logger = SimpleLogger(log_par.base_path + '/traj_log.txt')

    ########################
    # START THE EXPERIMENT #
    ########################
    local_logger.info('Starting experiment')
    checkpoint_tracker = CheckpointTracker(log_par, expl)
    prev_checkpoint = None
    merged_dict = {}
    sil_trajectories = []
    # Screenshots come either from a user-supplied directory or from the
    # experiment's own screen-shot directory.
    if screenshot_merge[0:9] == 'from_dir:':
        screen_shot_dir = screenshot_merge[9:]
    else:
        screen_shot_dir = f'{log_par.base_path}/screen_shots'

    local_logger.info('Initiate cycle')
    expl.init_cycle()
    local_logger.info('Initiating Cycle done')

    # When resuming from a saved exploration state, run warm-up cycles
    # first (the network is not updated during warm up).
    if kwargs['expl_state'] is not None:
        local_logger.info('Performing warm up cycles...')
        expl.start_warm_up()
        for i in range(warm_up_cycles):
            if hvd.rank() == 0:
                local_logger.info(f'Running warm up cycle: {i}')
            expl.run_cycle()
        expl.end_warm_up()
        checkpoint_tracker.n_iters = expl.cycles
        checkpoint_tracker.log_warmup = log_after_warm_up
        local_logger.info('Performing warm up cycles... done')

    while checkpoint_tracker.should_continue():
        # Run one iteration
        if hvd.rank() == 0:
            local_logger.info(
                f'Running cycle: {checkpoint_tracker.n_iters}')
        checkpoint_tracker.pre_cycle()
        expl.run_cycle()
        checkpoint_tracker.post_cycle()

        # Rank 0 decides whether this cycle produces a checkpoint and
        # broadcasts the decision so all workers agree.
        write_checkpoint = None
        if hvd.rank() == 0:
            write_checkpoint = checkpoint_tracker.calc_write_checkpoint()
        write_checkpoint = mpi.get_comm_world().bcast(
            write_checkpoint, root=0)
        checkpoint_tracker.set_should_write_checkpoint(write_checkpoint)

        # Code that should be executed by all workers at a checkpoint generation
        if checkpoint_tracker.should_write_checkpoint():
            local_logger.debug(
                f'Rank: {hvd.rank()} is exchanging screenshots for checkpoint: {expl.frames_compute}'
            )
            screenshots = expl.trajectory_gatherer.env.recursive_getattr(
                'rooms')
            if screenshot_merge == 'mpi':
                screenshots = flatten_lists(
                    mpi.COMM_WORLD.allgather(screenshots))

            # Merge per-env screenshot dicts; when a room appears twice,
            # prefer the entry whose element 0 flag is set and the
            # existing one's is not.
            merged_dict = {}
            for screenshot_dict in screenshots:
                for key, value in screenshot_dict.items():
                    if key not in merged_dict:
                        merged_dict[key] = value
                    else:
                        after_threshold_screenshot_taken_merged = merged_dict[
                            key][0]
                        after_threshold_screenshot_taken_current = screenshot_dict[
                            key][0]
                        if after_threshold_screenshot_taken_current and not after_threshold_screenshot_taken_merged:
                            merged_dict[key] = value

            if screenshot_merge == 'disk':
                # Each rank writes its screenshots to disk; a later pass
                # (at checkpoint-writing time) reads them back to merge.
                for key, value in merged_dict.items():
                    filename = f'{screen_shot_dir}/{key}_{hvd.rank()}.png'
                    os.makedirs(screen_shot_dir, exist_ok=True)
                    if not os.path.isfile(filename):
                        im = Image.fromarray(value[1])
                        im.save(filename)
                        # Sanity check: the PNG round-trip is lossless.
                        im_array = imageio.imread(filename)
                        assert (im_array == value[1]).all()
                mpi.COMM_WORLD.barrier()

            local_logger.debug('Merging SIL trajectories')
            sil_trajectories = [expl.prev_selected_traj]
            if hvd.size() > 1:
                sil_trajectories = flatten_lists(
                    mpi.COMM_WORLD.allgather(sil_trajectories))
            local_logger.debug(
                f'Rank: {hvd.rank()} is done merging trajectories for checkpoint: {expl.frames_compute}'
            )
            expl.sync_before_checkpoint()
            local_logger.debug(
                f'Rank: {hvd.rank()} is done synchronizing for checkpoint: {expl.frames_compute}'
            )

        # Code that should be executed only by the master
        if hvd.rank() == 0 and not disable_logging:
            gatherer = expl.trajectory_gatherer
            # -1 marks "no goals chosen", distinguishing it from a 0% rate.
            return_success_rate = -1
            if gatherer.nb_return_goals_chosen > 0:
                return_success_rate = gatherer.nb_return_goals_reached / gatherer.nb_return_goals_chosen
            exploration_success_rate = -1
            if gatherer.nb_exploration_goals_chosen > 0:
                exploration_success_rate = gatherer.nb_exploration_goals_reached / gatherer.nb_exploration_goals_chosen

            cum_success_rate = 0
            for reached in expl.archive.cells_reached_dict.values():
                success_rate = sum(reached) / len(reached)
                cum_success_rate += success_rate
            mean_success_rate = cum_success_rate / len(expl.archive.archive)

            logger.write('it', checkpoint_tracker.n_iters)
            logger.write('score', expl.archive.max_score)
            logger.write('cells', len(expl.archive.archive))
            logger.write('ret_suc', return_success_rate)
            logger.write('exp_suc', exploration_success_rate)
            logger.write('rew_mean', gatherer.reward_mean)
            logger.write('len_mean', gatherer.length_mean)
            logger.write('ep', gatherer.nb_of_episodes)
            logger.write('arch_suc', mean_success_rate)
            logger.write('cum_suc', cum_success_rate)
            logger.write('frames', expl.frames_compute)

            if len(gatherer.loss_values) > 0:
                loss_values = np.mean(gatherer.loss_values, axis=0)
                assert len(loss_values) == len(gatherer.model.loss_names)
                for (loss_value, loss_name) in zip(
                        loss_values, gatherer.model.loss_names):
                    logger.write(loss_name, loss_value)

            # Total number of frames held in full SIL trajectories.
            stored_frames = 0
            for traj in expl.archive.cell_trajectory_manager.full_trajectories.values(
            ):
                stored_frames += len(traj)
            logger.write('sil_frames', stored_frames)

            nb_no_score_cells = len(expl.archive.archive)
            for weight in expl.archive.cell_selector.selector_weights:
                if hasattr(weight, 'max_score_dict'):
                    nb_no_score_cells = len(weight.max_score_dict)
            logger.write('no_score_cells', nb_no_score_cells)

            # Count archive cells by the strategy that discovered them.
            cells_found_ret = 0
            cells_found_rand = 0
            cells_found_policy = 0
            for cell_key in expl.archive.archive:
                cell_info = expl.archive.archive[cell_key]
                if cell_info.ret_discovered == global_const.EXP_STRAT_NONE:
                    cells_found_ret += 1
                elif cell_info.ret_discovered == global_const.EXP_STRAT_RAND:
                    cells_found_rand += 1
                elif cell_info.ret_discovered == global_const.EXP_STRAT_POLICY:
                    cells_found_policy += 1
            logger.write('cells_found_ret', cells_found_ret)
            logger.write('cells_found_rand', cells_found_rand)
            logger.write('cells_found_policy', cells_found_policy)
            logger.flush()

            # Log every newly finished trajectory, ordered by finish frame.
            traj_manager = expl.archive.cell_trajectory_manager
            new_trajectories = sorted(
                traj_manager.new_trajectories,
                key=lambda t: traj_manager.cell_trajectories[t].frame_finished)
            for traj_id in new_trajectories:
                traj_info = traj_manager.cell_trajectories[traj_id]
                traj_logger.write('it', checkpoint_tracker.n_iters)
                traj_logger.write('frame', traj_info.frame_finished)
                traj_logger.write('exp_strat', traj_info.exp_strat)
                traj_logger.write('exp_new_cells', traj_info.exp_new_cells)
                traj_logger.write('ret_new_cells', traj_info.ret_new_cells)
                traj_logger.write('score', traj_info.score)
                traj_logger.write('total_actions', traj_info.total_actions)
                traj_logger.write('id', traj_info.id)
                traj_logger.flush()

            # Code that should be executed by only the master at a checkpoint generation
            if checkpoint_tracker.should_write_checkpoint():
                local_logger.info(
                    f'Rank: {hvd.rank()} is writing checkpoint: {expl.frames_compute}'
                )
                # Checkpoint files are named by zero-padded frame count.
                filename = f'{log_par.base_path}/{expl.frames_compute:0{log_par.n_digits}}'

                # Save pictures
                if len(log_par.save_pictures) > 0:
                    if screenshot_merge == 'disk':
                        # Fill in rooms this rank never saw from the
                        # screenshots other ranks wrote to disk.
                        for file_name in os.listdir(screen_shot_dir):
                            if file_name.endswith('.png'):
                                room = int(file_name.split('_')[0])
                                if room not in merged_dict:
                                    screen_shot = imageio.imread(
                                        f'{screen_shot_dir}/{file_name}')
                                    merged_dict[room] = (True, screen_shot)
                    elif screenshot_merge[0:9] == 'from_dir:':
                        for file_name in os.listdir(screen_shot_dir):
                            if file_name.endswith('.png'):
                                room = int(file_name.split('.')[0])
                                if room not in merged_dict:
                                    screen_shot = imageio.imread(
                                        f'{screen_shot_dir}/{file_name}')
                                    merged_dict[room] = (True, screen_shot)

                    render_pictures(log_par, expl, filename, prev_checkpoint,
                                    merged_dict, sil_trajectories)

                # Save archive state
                if log_par.save_archive:
                    save_state(expl.get_state(), filename + ARCHIVE_POSTFIX)
                    expl.archive.cell_trajectory_manager.dump(
                        filename + TRAJ_POSTFIX)

                # Save model
                if log_par.save_model:
                    expl.trajectory_gatherer.save_model(
                        filename + MODEL_POSTFIX)

                # Clean up previous checkpoint.
                if prev_checkpoint:
                    for checkpoint_type in clear_checkpoints:
                        if checkpoint_type in CHECKPOINT_ABBREVIATIONS:
                            postfix = CHECKPOINT_ABBREVIATIONS[checkpoint_type]
                        else:
                            postfix = checkpoint_type
                        # Missing files are fine: the type may not have
                        # been written for the previous checkpoint.
                        with contextlib.suppress(FileNotFoundError):
                            local_logger.debug(
                                f'Removing old checkpoint: {prev_checkpoint + postfix}'
                            )
                            os.remove(prev_checkpoint + postfix)

                prev_checkpoint = filename

                if track_memory:
                    snapshot = tracemalloc.take_snapshot()
                    display_top(snapshot)

                if PROFILER:
                    local_logger.info(
                        f'ITERATION: {checkpoint_tracker.n_iters}')
                    PROFILER.disable()
                    PROFILER.dump_stats(filename + '.stats')
                    PROFILER.enable()

    local_logger.info(f'Rank {hvd.rank()} finished experiment')
    mpi.get_comm_world().barrier()