def __init__(self, config=None):
    super().__init__(config)
    self.actor_critic = None
    self.agent = None
    self.envs = None
    if config is not None:
        logger.info(f"config: {config}")
def run_exp(exp_config: str, run_type: str, opts=None) -> None:
    r"""Runs experiment given mode and config

    Args:
        exp_config: path to config file.
        run_type: "train" or "eval".
        opts: list of strings of additional config options.

    Returns:
        None.
    """
    config = get_config(exp_config, opts)
    logger.info(f"config: {config}")
    logger.add_filehandler(config.LOG_FILE)

    random.seed(config.TASK_CONFIG.SEED)
    np.random.seed(config.TASK_CONFIG.SEED)
    torch.manual_seed(config.TASK_CONFIG.SEED)
    torch.backends.cudnn.benchmark = True

    if run_type == "eval" and config.EVAL.EVAL_NONLEARNING:
        evaluate_agent(config)
        return

    trainer_init = baseline_registry.get_trainer(config.TRAINER_NAME)
    assert trainer_init is not None, f"{config.TRAINER_NAME} is not supported"
    trainer = trainer_init(config)

    if run_type == "train":
        trainer.train()
    elif run_type == "eval":
        trainer.eval()
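# A minimal sketch (not from the source) of a CLI entry point that dispatches
# to run_exp above; the flag names are assumptions chosen for illustration.
import argparse


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-config", type=str, required=True,
                        help="path to the experiment config file")
    parser.add_argument("--run-type", choices=["train", "eval"], required=True,
                        help="whether to train or evaluate")
    parser.add_argument("opts", default=None, nargs=argparse.REMAINDER,
                        help="trailing config options, e.g. TASK_CONFIG.SEED 7")
    args = parser.parse_args()
    run_exp(args.exp_config, args.run_type, args.opts)


if __name__ == "__main__":
    main()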
def _setup_auxiliary_tasks(self, aux_cfg, ppo_cfg, task_cfg,
                           observation_space=None, is_eval=False):
    aux_task_strings = [task.lower() for task in aux_cfg.tasks]
    # Differentiate repeated instances of a task by appending an index,
    # e.g. ["cpca", "cpca", "gid"] becomes ["cpca", "cpca_1", "gid"].
    aux_counts = {}
    for i, x in enumerate(aux_task_strings):
        if x in aux_counts:
            aux_task_strings[i] = f"{aux_task_strings[i]}_{aux_counts[x]}"
            aux_counts[x] += 1
        else:
            aux_counts[x] = 1
    logger.info(f"Auxiliary tasks: {aux_task_strings}")

    num_recurrent_memories = 1
    if ppo_cfg.policy in MULTIPLE_BELIEF_CLASSES:
        num_recurrent_memories = len(aux_cfg.tasks)

    init_aux_tasks = []
    if not is_eval:
        for task in aux_cfg.tasks:
            task_class = get_aux_task_class(task)
            aux_module = task_class(
                ppo_cfg, aux_cfg[task], task_cfg, self.device,
                observation_space=observation_space,
            ).to(self.device)
            init_aux_tasks.append(aux_module)

    return init_aux_tasks, num_recurrent_memories, aux_task_strings
def _setup_eval_config(self, checkpoint_config: Config) -> Config:
    r"""Sets up and returns a merged config for evaluation. Config
    object saved from checkpoint is merged into config file specified
    at evaluation time with the following overwrite priority:
        eval_opts > ckpt_opts > eval_cfg > ckpt_cfg
    If the saved config is outdated, only the eval config is returned.

    Args:
        checkpoint_config: saved config from checkpoint.

    Returns:
        Config: merged config for eval.
    """
    config = self.config.clone()

    ckpt_cmd_opts = checkpoint_config.CMD_TRAILING_OPTS
    eval_cmd_opts = config.CMD_TRAILING_OPTS

    try:
        config.merge_from_other_cfg(checkpoint_config)
        config.merge_from_other_cfg(self.config)
        config.merge_from_list(ckpt_cmd_opts)
        config.merge_from_list(eval_cmd_opts)
    except KeyError:
        logger.info("Saved config is outdated, using solely eval config")
        config = self.config.clone()
        config.merge_from_list(eval_cmd_opts)

    config.defrost()
    if config.TASK_CONFIG.DATASET.SPLIT == "train":
        config.TASK_CONFIG.DATASET.SPLIT = "val"
    config.TASK_CONFIG.SIMULATOR.AGENT_0.SENSORS = self.config.SENSORS
    config.freeze()

    return config
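# A minimal sketch (assuming yacs, which habitat's Config wraps) of why the
# merge order above yields the priority eval_opts > ckpt_opts > eval_cfg >
# ckpt_cfg: with yacs, whatever is merged last wins.
from yacs.config import CfgNode

ckpt_cfg = CfgNode({"LR": 1e-3})
eval_cfg = CfgNode({"LR": 2e-3})

merged = ckpt_cfg.clone()
merged.merge_from_other_cfg(eval_cfg)   # eval cfg overrides ckpt cfg
merged.merge_from_list(["LR", "5e-4"])  # trailing opts override both
assert merged.LR == 5e-4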
def __init__(self, config=None):
    super().__init__(config)
    self.actor_critic = None
    self.agent = None
    self.envs = None
    self.obs_transforms = []
    self.rms = RunningMeanStd()
    if config is not None:
        logger.info(f"config: {config}")
        self.checkpoint_prefix = config.TENSORBOARD_DIR.split('/')[-1]

    self._static_encoder = False
    self._encoder = None

    def get_color_distortion(s=1.0):
        # s is the strength of color distortion.
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.4 * s)
        rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
        rnd_gray = transforms.RandomGrayscale(p=0.2)
        color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
        return color_distort

    self.color_transform = get_color_distortion()
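# A minimal sketch (not from the source) of applying the color distortion
# above to a batch of uint8 RGB observations shaped (N, H, W, C). Recent
# torchvision transforms accept batched float tensors in (..., C, H, W), so
# we permute and rescale first; this helper is an assumption for illustration.
import torch

def augment_rgb_batch(color_transform, rgb: torch.Tensor) -> torch.Tensor:
    x = rgb.permute(0, 3, 1, 2).float() / 255.0  # NHWC uint8 -> NCHW float
    return color_transform(x)                    # jitter + random grayscale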
def __init__(self, config=None):
    if config is not None:
        self._synchronize_configs(config)
    super().__init__(config)

    # Set pytorch random seed for initialization
    random.seed(config.PYT_RANDOM_SEED)
    np.random.seed(config.PYT_RANDOM_SEED)
    torch.manual_seed(config.PYT_RANDOM_SEED)

    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # Disable cuDNN to prevent:
        # RuntimeError: cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.
        # This error may appear if you passed in a non-contiguous input.
        torch.backends.cudnn.enabled = False

    self.mapper = None
    self.local_actor_critic = None
    self.ans_net = None
    self.planner = None
    self.envs = None
    if config is not None:
        logger.info(f"config: {config}")
def _setup_actor_critic_agent(self, config: Config, load_from_ckpt: bool,
                              ckpt_path: str) -> None:
    r"""Sets up actor critic and agent.

    Args:
        config: MODEL config
        load_from_ckpt: whether to load weights from a checkpoint
        ckpt_path: path of the checkpoint to load

    Returns:
        None
    """
    config.defrost()
    config.TORCH_GPU_ID = self.config.TORCH_GPU_ID
    config.freeze()

    if config.CMA.use:
        self.actor_critic = CMAPolicy(
            observation_space=self.envs.observation_spaces[0],
            action_space=self.envs.action_spaces[0],
            model_config=config,
        )
    else:
        self.actor_critic = Seq2SeqPolicy(
            observation_space=self.envs.observation_spaces[0],
            action_space=self.envs.action_spaces[0],
            model_config=config,
        )
    self.actor_critic.to(self.device)

    self.optimizer = torch.optim.Adam(
        self.actor_critic.parameters(), lr=self.config.DAGGER.LR
    )
    if load_from_ckpt:
        ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu")
        self.actor_critic.load_state_dict(ckpt_dict["state_dict"])
        logger.info(f"Loaded weights from checkpoint: {ckpt_path}")
    logger.info("Finished setting up actor critic model.")
def load_resume_state(filename_or_config: Union[Config, str],
                      load_ckpt: bool) -> Optional[Any]:
    r"""Loads the saved resume state

    :param filename_or_config: The filename of the saved state or the config
        to construct it.
    :param load_ckpt: If no resume state is found, whether to fall back to
        the checkpoint at ``EVAL_CKPT_PATH_DIR``.

    :return: The saved state if the file exists, else none
    """
    if isinstance(filename_or_config, Config):
        filename = resume_state_filename(filename_or_config)
        cfg = filename_or_config
        found_f = None
        for f in os.listdir(cfg.CHECKPOINT_FOLDER):
            if cfg.PREFIX in f:
                found_f = f
        if found_f is not None:
            filename = osp.join(cfg.CHECKPOINT_FOLDER, found_f,
                                RESUME_STATE_BASE_NAME + ".pth")
        elif os.path.isfile(cfg.EVAL_CKPT_PATH_DIR) and load_ckpt:
            filename = cfg.EVAL_CKPT_PATH_DIR
    else:
        filename = filename_or_config

    print('Trying to resume state from', filename)
    if not osp.exists(filename):
        return None

    if rank0_only():
        logger.info(f"Loading resume state: {filename}")

    return torch.load(filename, map_location="cpu")
def requeue_job(): r"""Requeue the job by calling ``scontrol requeue ${SLURM_JOBID}``""" if SLURM_JOBID is None: return if os.environ['SLURM_PROCID'] == '0' and os.getpid() == MAIN_PID: logger.info(f"Requeueing job {SLURM_JOBID}") subprocess.check_call(shlex.split(f"scontrol requeue {SLURM_JOBID}"))
def __init__(self, config=None):
    super().__init__(config)
    self.device = (
        torch.device("cuda", self.config.TORCH_GPU_ID)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    if config is not None:
        logger.info(f"config: {config}")
def eval(self) -> None:
    r"""Main method of trainer evaluation. Calls _eval_checkpoint() that is
    specified in Trainer class that inherits from BaseRLTrainer
    or BaseILTrainer

    Returns:
        None
    """
    self.device = (
        torch.device("cuda", self.config.TORCH_GPU_ID)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )

    if "tensorboard" in self.config.VIDEO_OPTION:
        assert (
            len(self.config.TENSORBOARD_DIR) > 0
        ), "Must specify a tensorboard directory for video display"
        os.makedirs(self.config.TENSORBOARD_DIR, exist_ok=True)

    if "disk" in self.config.VIDEO_OPTION:
        assert (
            len(self.config.VIDEO_DIR) > 0
        ), "Must specify a directory for storing videos on disk"

    with TensorboardWriter(
        self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
    ) as writer:
        if os.path.isfile(self.config.EVAL_CKPT_PATH_DIR):
            # evaluate single checkpoint
            proposed_index = get_checkpoint_id(self.config.EVAL_CKPT_PATH_DIR)
            if proposed_index is not None:
                ckpt_idx = proposed_index
            else:
                ckpt_idx = 0
            self._eval_checkpoint(
                self.config.EVAL_CKPT_PATH_DIR,
                writer,
                checkpoint_index=ckpt_idx,
            )
        else:
            # evaluate multiple checkpoints in order
            prev_ckpt_ind = -1
            while True:
                current_ckpt = None
                while current_ckpt is None:
                    current_ckpt = poll_checkpoint_folder(
                        self.config.EVAL_CKPT_PATH_DIR, prev_ckpt_ind
                    )
                    time.sleep(2)  # sleep for 2 secs before polling again
                logger.info(f"=======current_ckpt: {current_ckpt}=======")
                prev_ckpt_ind += 1
                self._eval_checkpoint(
                    checkpoint_path=current_ckpt,
                    writer=writer,
                    checkpoint_index=prev_ckpt_ind,
                )
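# A hedged sketch of the poll_checkpoint_folder helper used above (the real
# implementation may differ, and some variants below also take an
# eval_interval): it returns the next unevaluated checkpoint in mtime order,
# or None if it does not exist yet.
import glob
import os

def poll_checkpoint_folder(checkpoint_folder: str, previous_ckpt_ind: int):
    models_paths = list(
        filter(os.path.isfile, glob.glob(checkpoint_folder + "/*"))
    )
    models_paths.sort(key=os.path.getmtime)
    ind = previous_ckpt_ind + 1
    if ind < len(models_paths):
        return models_paths[ind]
    return None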
def eval(self, eval_ckpt=None, log_diagnostics=[], output_dir=".",
         label="eval") -> None:
    r"""Main method of trainer evaluation. Calls _eval_checkpoint() that is
    specified in Trainer class that inherits from BaseRLTrainer

    Returns:
        None
    """
    self.device = (
        torch.device("cuda", self.config.TORCH_GPU_ID)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )

    if "tensorboard" in self.config.VIDEO_OPTION:
        assert (
            len(self.config.TENSORBOARD_DIR) > 0
        ), "Must specify a tensorboard directory for video display"
    if "disk" in self.config.VIDEO_OPTION:
        assert (
            len(self.config.VIDEO_DIR) > 0
        ), "Must specify a directory for storing videos on disk"

    with TensorboardWriter(
        self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
    ) as writer:
        if eval_ckpt is not None:
            # evaluate a single checkpoint from path
            ckpt_index = os.path.split(eval_ckpt)[1].split(".")[-2]
            self._eval_checkpoint(
                eval_ckpt,
                writer,
                checkpoint_index=ckpt_index,
                log_diagnostics=log_diagnostics,
                output_dir=output_dir,
                label=label,
            )
        elif os.path.isfile(self.config.EVAL_CKPT_PATH_DIR):
            # evaluate single checkpoint; parse its index from the filename
            ckpt_index = self.config.EVAL_CKPT_PATH_DIR.split('.')[-2]
            self._eval_checkpoint(
                self.config.EVAL_CKPT_PATH_DIR,
                writer,
                checkpoint_index=ckpt_index,
            )
        else:
            # evaluate multiple checkpoints in order
            prev_ckpt_ind = -1
            while True:
                current_ckpt = None
                while current_ckpt is None:
                    current_ckpt = poll_checkpoint_folder(
                        self.config.EVAL_CKPT_PATH_DIR, prev_ckpt_ind
                    )
                    time.sleep(2)  # sleep for 2 secs before polling again
                logger.info(f"=======current_ckpt: {current_ckpt}=======")
                prev_ckpt_ind += 1
                self._eval_checkpoint(
                    checkpoint_path=current_ckpt,
                    writer=writer,
                    checkpoint_index=prev_ckpt_ind,
                )
def create_tar_archive(archive_path: str, dataset_path: str) -> None:
    """Creates a gzipped tar archive of the dataset.
    Used in VQA trainer's webdataset.
    """
    logger.info("[ Creating tar archive. This will take a few minutes. ]")
    with tarfile.open(archive_path, "w:gz") as tar:
        for file in sorted(os.listdir(dataset_path)):
            tar.add(os.path.join(dataset_path, file))
def __init__(self, config=None):
    super().__init__(config)
    self.actor_critic = None
    self.agent = None
    self.envs = None
    self.device = None
    self.video_option = []
    if config is not None:
        logger.info(f"config: {config}")
def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None:
    r"""Sets up actor critic and agent for PPO.

    Args:
        ppo_cfg: config node with relevant params

    Returns:
        None
    """
    logger.add_filehandler(self.config.LOG_FILE)

    model_cfg = self.config.MODEL
    model_cfg.defrost()
    model_cfg.TORCH_GPU_ID = self.config.TORCH_GPU_ID
    model_cfg.freeze()

    assert model_cfg.POLICY in SUPPORTED_POLICIES, \
        f"{model_cfg.POLICY} not in {SUPPORTED_POLICIES}"
    if model_cfg.POLICY == "seq2seq":
        self.actor_critic = Seq2SeqPolicy(
            observation_space=self.envs.observation_spaces[0],
            action_space=self.envs.action_spaces[0],
            model_config=self.config.MODEL,
        )
    elif model_cfg.POLICY == "cma":
        self.actor_critic = CMAPolicy(
            observation_space=self.envs.observation_spaces[0],
            action_space=self.envs.action_spaces[0],
            model_config=self.config.MODEL,
        )
    self.actor_critic.to(self.device)

    self.agent = PPO(
        actor_critic=self.actor_critic,
        clip_param=ppo_cfg.clip_param,
        ppo_epoch=ppo_cfg.ppo_epoch,
        num_mini_batch=ppo_cfg.num_mini_batch,
        value_loss_coef=ppo_cfg.value_loss_coef,
        entropy_coef=ppo_cfg.entropy_coef,
        lr=ppo_cfg.lr,
        eps=ppo_cfg.eps,
        max_grad_norm=ppo_cfg.max_grad_norm,
        use_normalized_advantage=ppo_cfg.use_normalized_advantage,
    )

    if self.config.LOAD_FROM_CKPT:
        ckpt_dict = self.load_checkpoint(
            self.config.LOAD_CKPT_PATH, map_location="cpu"
        )
        self.actor_critic.load_state_dict(ckpt_dict["state_dict_ac"])
        self.agent.load_state_dict(ckpt_dict["state_dict_agent"])
        logger.info(
            f"Loaded weights from checkpoint: {self.config.LOAD_CKPT_PATH}"
        )
    logger.info("Finished setting up actor critic model.")
def __init__(self, config=None): super().__init__(config) self.actor_critic = None self.agent = None self.envs = None if config is not None: logger.info(f"config: {config}") self.checkpoint_prefix = config.TENSORBOARD_DIR.split('/')[-1] self._static_encoder = False self._encoder = None
def construct_envs(args):
    env_configs = []
    baseline_configs = []

    basic_config = cfg_env(config_paths=args.task_config, opts=args.opts)
    dataset = make_dataset(basic_config.DATASET.TYPE)
    scenes = dataset.get_scenes_to_load(basic_config.DATASET)

    if len(scenes) > 0:
        random.shuffle(scenes)
        assert len(scenes) >= args.num_processes, (
            "reduce the number of processes as there "
            "aren't enough number of scenes"
        )

    # Assign scenes to processes round-robin.
    scene_splits = [[] for _ in range(args.num_processes)]
    for j, s in enumerate(scenes):
        scene_splits[j % len(scene_splits)].append(s)
    assert sum(map(len, scene_splits)) == len(scenes)

    for i in range(args.num_processes):
        config_env = cfg_env(config_paths=args.task_config, opts=args.opts)
        config_env.defrost()
        if len(scenes) > 0:
            config_env.DATASET.CONTENT_SCENES = scene_splits[i]

        config_env.SIMULATOR.HABITAT_SIM_V0.GPU_DEVICE_ID = args.sim_gpu_id

        agent_sensors = args.sensors.strip().split(",")
        for sensor in agent_sensors:
            assert sensor in ["RGB_SENSOR", "DEPTH_SENSOR"]
        config_env.SIMULATOR.AGENT_0.SENSORS = agent_sensors
        config_env.freeze()
        env_configs.append(config_env)

        config_baseline = cfg_baseline()
        baseline_configs.append(config_baseline)

        logger.info("config_env: {}".format(config_env))

    envs = habitat.VectorEnv(
        make_env_fn=make_env_fn,
        env_fn_args=tuple(
            zip(env_configs, baseline_configs, range(args.num_processes))
        ),
    )

    return envs
def __init__(self, config=None):
    super().__init__(config)
    self.actor_critic = None
    self.agent = None
    self.envs = None
    self.obs_transforms = []
    if config is not None:
        logger.info(f"config: {config}")

    self._static_encoder = False
    self._encoder = None
def requeue_job(): r"""Requeues the job by calling ``scontrol requeue ${SLURM_JOBID}``""" if SLURM_JOBID is None: return if not REQUEUE.is_set(): return distrib.barrier() if distrib.get_rank() == 0: logger.info(f"Requeueing job {SLURM_JOBID}") subprocess.check_call(shlex.split(f"scontrol requeue {SLURM_JOBID}"))
def eval(self, eval_interval=1, prev_ckpt_ind=-1, use_last_ckpt=False) -> None:
    r"""Main method of trainer evaluation. Calls _eval_checkpoint() that is
    specified in Trainer class that inherits from BaseRLTrainer

    Returns:
        None
    """
    self.device = (
        torch.device("cuda", self.config.TORCH_GPU_ID)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )

    if "tensorboard" in self.config.VIDEO_OPTION:
        assert (
            len(self.config.TENSORBOARD_DIR) > 0
        ), "Must specify a tensorboard directory for video display"
    if "disk" in self.config.VIDEO_OPTION:
        assert (
            len(self.config.VIDEO_DIR) > 0
        ), "Must specify a directory for storing videos on disk"

    with TensorboardWriter(
        self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
    ) as writer:
        # evaluate the most recent checkpoint in the folder
        if use_last_ckpt:
            models_paths = list(
                filter(
                    os.path.isfile,
                    glob.glob(self.config.EVAL_CKPT_PATH_DIR + "/*"),
                )
            )
            models_paths.sort(key=os.path.getmtime)
            self.config.defrost()
            self.config.EVAL_CKPT_PATH_DIR = models_paths[-1]
            self.config.freeze()

        if os.path.isfile(self.config.EVAL_CKPT_PATH_DIR):
            # evaluate single checkpoint
            result = self._eval_checkpoint(
                self.config.EVAL_CKPT_PATH_DIR, writer
            )
            return result
        else:
            # evaluate multiple checkpoints in order
            while True:
                current_ckpt = None
                while current_ckpt is None:
                    current_ckpt = poll_checkpoint_folder(
                        self.config.EVAL_CKPT_PATH_DIR,
                        prev_ckpt_ind,
                        eval_interval,
                    )
                    time.sleep(2)  # sleep for 2 secs before polling again
                logger.info(f"=======current_ckpt: {current_ckpt}=======")
                prev_ckpt_ind += eval_interval
                self._eval_checkpoint(
                    checkpoint_path=current_ckpt,
                    writer=writer,
                    checkpoint_index=prev_ckpt_ind,
                )
def requeue_job(): r"""Requeues the job by calling ``scontrol requeue ${SLURM_JOBID}``""" if not is_slurm_batch_job(): return if not REQUEUE.is_set(): return if distrib.is_initialized(): distrib.barrier() if rank0_only(): logger.info(f"Requeueing job {SLURM_JOBID}") subprocess.check_call(["scontrol", "requeue", str(SLURM_JOBID)])
def transform_observation_space(self, observation_space,
                                trans_keys=["rgb", "depth", "semantic"]):
    size = self._size
    observation_space = copy.deepcopy(observation_space)
    if size:
        for key in observation_space.spaces:
            if (key in trans_keys
                    and observation_space.spaces[key].shape != size):
                logger.info(
                    "Overwriting CNN input size of %s: %s" % (key, size)
                )
                observation_space.spaces[key] = overwrite_gym_box_shape(
                    observation_space.spaces[key], size
                )
    self.observation_space = observation_space
    return observation_space
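# A hedged sketch of the overwrite_gym_box_shape helper referenced above (the
# real implementation may differ): it rebuilds a gym Box with the new spatial
# size while preserving any trailing channel dims, bounds, and dtype.
import numpy as np
from gym import spaces

def overwrite_gym_box_shape(box: spaces.Box, shape) -> spaces.Box:
    if box.shape == tuple(shape):
        return box
    new_shape = tuple(shape) + box.shape[len(shape):]
    low = box.low if np.isscalar(box.low) else np.min(box.low)
    high = box.high if np.isscalar(box.high) else np.max(box.high)
    return spaces.Box(low=low, high=high, shape=new_shape, dtype=box.dtype)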
def __init__(self, config=None): super().__init__(config) self.actor_critic = None self.agent = None self.envs = None if config is not None: logger.info(f"config: {config}") self._static_encoder = False self._encoder = None self.last_observations = None self.last_recurrent_hidden_states = None self.last_prev_actions = None self.last_masks = None
def __init__(self, config=None):
    if config is not None:
        self._synchronize_configs(config)
    super().__init__(config)

    # Set pytorch random seed for initialization
    torch.manual_seed(config.PYT_RANDOM_SEED)

    # Initialize the mapper and mapper_agent for training
    self.mapper = None
    self.mapper_agent = None
    self.envs = None
    if config is not None:
        logger.info(f"config: {config}")
def __call__(self, args: EvaluatorArgs, prev_ckpt_ind=-1, num_frames=0):
    logger.info("CUDA_VISIBLE_DEVICES: {}".format(
        os.environ["CUDA_VISIBLE_DEVICES"]))
    logger.info("Hostname: {}".format(socket.gethostname()))

    config = get_config(args.exp_config, args.opts)
    random.seed(config.TASK_CONFIG.SEED)
    np.random.seed(config.TASK_CONFIG.SEED)

    trainer_init = baseline_registry.get_trainer(config.TRAINER_NAME)
    assert trainer_init is not None, f"{config.TRAINER_NAME} is not supported"
    self.trainer = trainer_init(config)
    self.trainer.prev_ckpt_ind = prev_ckpt_ind
    self.trainer.num_frames = num_frames

    self.trainer.eval()
def evaluate_agent(config: Config) -> None:
    split = config.EVAL.SPLIT

    config.defrost()
    # turn off RGBD rendering as neither RandomAgent nor HandcraftedAgent
    # use it.
    config.TASK_CONFIG.SIMULATOR.AGENT_0.SENSORS = []
    config.TASK_CONFIG.TASK.SENSORS = []
    config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False
    config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = -1
    config.TASK_CONFIG.DATASET.SPLIT = split
    config.TASK_CONFIG.TASK.NDTW.SPLIT = split
    config.TASK_CONFIG.TASK.SDTW.SPLIT = split
    config.freeze()

    env = Env(config=config.TASK_CONFIG)

    assert config.EVAL.NONLEARNING.AGENT in [
        "RandomAgent",
        "HandcraftedAgent",
    ], "EVAL.NONLEARNING.AGENT must be either RandomAgent or HandcraftedAgent."

    if config.EVAL.NONLEARNING.AGENT == "RandomAgent":
        agent = RandomAgent()
    else:
        agent = HandcraftedAgent()

    stats = defaultdict(float)
    num_episodes = min(config.EVAL.EPISODE_COUNT, len(env.episodes))
    for i in tqdm(range(num_episodes)):
        obs = env.reset()
        agent.reset()

        while not env.episode_over:
            action = agent.act(obs)
            obs = env.step(action)

        for m, v in env.get_metrics().items():
            stats[m] += v

    stats = {k: v / num_episodes for k, v in stats.items()}

    logger.info(f"Averaged benchmark for {config.EVAL.NONLEARNING.AGENT}:")
    for stat_key in stats.keys():
        logger.info("{}: {:.3f}".format(stat_key, stats[stat_key]))

    with open(f"stats_{config.EVAL.NONLEARNING.AGENT}_{split}.json", "w") as f:
        json.dump(stats, f, indent=4)
def construct_envs(args):
    env_configs = []
    baseline_configs = []

    basic_config = cfg_env(config_file=args.task_config)
    scenes = PointNavDatasetV1.get_scenes_to_load(basic_config.DATASET)

    if len(scenes) > 0:
        random.shuffle(scenes)
        assert len(scenes) >= args.num_processes, (
            "reduce the number of processes as there "
            "aren't enough number of scenes"
        )
    # Split scenes into contiguous chunks, one per process.
    scene_split_size = int(np.floor(len(scenes) / args.num_processes))

    for i in range(args.num_processes):
        config_env = cfg_env(config_file=args.task_config)
        config_env.defrost()

        if len(scenes) > 0:
            config_env.DATASET.POINTNAVV1.CONTENT_SCENES = scenes[
                i * scene_split_size:(i + 1) * scene_split_size
            ]

        config_env.SIMULATOR.HABITAT_SIM_V0.GPU_DEVICE_ID = args.sim_gpu_id

        agent_sensors = args.sensors.strip().split(",")
        for sensor in agent_sensors:
            assert sensor in ["RGB_SENSOR", "DEPTH_SENSOR"]
        config_env.SIMULATOR.AGENT_0.SENSORS = agent_sensors
        config_env.freeze()
        env_configs.append(config_env)

        config_baseline = cfg_baseline()
        baseline_configs.append(config_baseline)

        logger.info("config_env: {}".format(config_env))

    envs = habitat.VectorEnv(
        make_env_fn=make_env_fn,
        env_fn_args=tuple(
            zip(env_configs, baseline_configs, range(args.num_processes))
        ),
    )

    return envs
def load_resume_state(filename_or_config: Union[Config, str]) -> Optional[Any]:
    r"""Loads the saved resume state

    :param filename_or_config: The filename of the saved state or the config
        to construct it.

    :return: The saved state if the file exists, else none
    """
    if isinstance(filename_or_config, Config):
        filename = resume_state_filename(filename_or_config)
    else:
        filename = filename_or_config

    if not osp.exists(filename):
        return None

    if rank0_only():
        logger.info(f"Loading resume state: {filename}")

    return torch.load(filename, map_location="cpu")
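# A minimal usage sketch (not from the source) of resuming training via
# load_resume_state above. The state-dict keys ("state_dict", "optim_state",
# "step") are assumptions for illustration; the real keys depend on how the
# resume state was saved.
def try_resume(agent, optimizer, config) -> int:
    """Restore agent/optimizer from a resume state; return the step to resume at."""
    state = load_resume_state(config)
    if state is None:
        return 0
    agent.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optim_state"])
    return state.get("step", 0)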
def __init__(self, config=None):
    if config is not None:
        self._synchronize_configs(config)
    super().__init__(config)

    # Set pytorch random seed for initialization
    random.seed(config.PYT_RANDOM_SEED)
    np.random.seed(config.PYT_RANDOM_SEED)
    torch.manual_seed(config.PYT_RANDOM_SEED)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    self.mapper = None
    self.local_actor_critic = None
    self.ans_net = None
    self.planner = None
    self.envs = None
    if config is not None:
        logger.info(f"config: {config}")
def reset(self):
    # print(f'{self.ep} reset {self.step}')
    # We don't reset the hidden state because the RNN accounts for masks,
    # and we ignore actions because they aren't used.
    self.not_done_masks = torch.zeros(
        1, 1, device=self.device, dtype=torch.bool
    )
    self.prev_actions = torch.zeros(
        1, 1, dtype=torch.long, device=self.device
    )
    self.test_recurrent_hidden_states = torch.zeros(
        self.model_cfg.STATE_ENCODER.num_recurrent_layers,
        1,  # num_processes
        self.model_cfg.STATE_ENCODER.hidden_size,
        device=self.device,
    )
    # self.step = 0
    self.ep += 1
    logger.info("Episode done: {}".format(self.ep))