def train_loop(cfg, agent, logger):
    """Run the TRPO training loop.

    Args:
        cfg: project ParamDict-like config. Must provide the keys listed in
            ``cfg.require`` below; it is mutated in place each iteration with
            the updated iteration counter and the policy/filter state dicts,
            and is periodically saved as a checkpoint.
        agent: provides ``rollout(cfg)`` plus ``policy()`` / ``filter()``
            accessors used to read and propagate network state.
        logger: callable logging sink with a ``train(bool)`` mode switch.

    Side effects: writes ``I_{iter}.pkl`` checkpoints under ``model_dir(cfg)``
    every ``save interval`` iterations and a ``final.pkl`` at the end.
    """
    curr_iter, max_iter, eval_iter, eval_batch_sz, batch_sz, save_iter = \
        cfg.require("current training iter", "max iter", "eval interval",
                    "eval batch size", "batch size", "save interval")

    # Rollout settings for gathering fresh on-policy training data.
    training_cfg = ParamDict({
        "policy state dict": agent.policy().getStateDict(),
        "filter state dict": agent.filter().getStateDict(),
        "trajectory max step": 64,
        "batch size": batch_sz,
        "fixed environment": False,
        "fixed policy": False,
        "fixed filter": False
    })
    # Evaluation rollouts keep policy and filter frozen; the state dicts are
    # filled in each iteration before this config is used.
    validate_cfg = ParamDict({
        "policy state dict": None,
        "filter state dict": None,
        "trajectory max step": 64,
        "batch size": eval_batch_sz,
        "fixed environment": False,
        "fixed policy": True,
        "fixed filter": True
    })

    for i_iter in range(curr_iter, max_iter):
        s_time = float(running_time(fmt=False))
        # Sample a new batch and perform one TRPO policy update.
        batch_train, info_train = agent.rollout(training_cfg)
        trpo_step(cfg, batch_train, agent.policy())
        e_time = float(running_time(fmt=False))

        logger.train()
        info_train["duration"] = e_time - s_time
        info_train["epoch"] = i_iter
        logger(info_train)

        # Persist the latest state into cfg and propagate it to both rollout
        # configs so the next rollout / evaluation uses the updated networks.
        cfg["current training iter"] = i_iter + 1
        cfg["policy state dict"] = training_cfg["policy state dict"] = \
            validate_cfg["policy state dict"] = agent.policy().getStateDict()
        cfg["filter state dict"] = training_cfg["filter state dict"] = \
            validate_cfg["filter state dict"] = agent.filter().getStateDict()

        if i_iter % eval_iter == 0:
            _, info_eval = agent.rollout(validate_cfg)
            logger.train(False)
            # NOTE(review): the eval record reuses the *training* step
            # duration (the eval rollout itself is not timed) — preserved
            # as-is; confirm whether eval timing was intended here.
            info_eval["duration"] = e_time - s_time
            info_eval["epoch"] = i_iter
            logger(info_eval)

        if i_iter != 0 and i_iter % save_iter == 0:
            file_name = os.path.join(model_dir(cfg), f"I_{i_iter}.pkl")
            cfg.save(file_name)
            print(f"Saving current step at {file_name}")

    file_name = os.path.join(model_dir(cfg), "final.pkl")
    cfg.save(file_name)
    print(f"Total running time: {running_time(fmt=True)}, result saved at {file_name}")
def train_loop(cfg, agent, logger):
    """Run the MCPO training loop (policy updates guided by demonstrations).

    Loads the entire demonstration set once up front; if no demos are
    available, falls back to plain (demo-free) updates compatible with TRPO.
    Each iteration normalizes the demo states with the agent's current
    ZFilter statistics before passing them to ``mcpo_step``.

    Args:
        cfg: project ParamDict-like config. Must provide the keys listed in
            ``cfg.require`` below; it is mutated in place each iteration with
            the updated iteration counter and the policy/filter state dicts,
            and is periodically saved as a checkpoint.
        agent: provides ``rollout(cfg)`` plus ``policy()`` / ``filter()``
            accessors used to read and propagate network state.
        logger: callable logging sink with a ``train(bool)`` mode switch.

    Side effects: writes ``I_{iter}.pkl`` checkpoints under ``model_dir(cfg)``
    every ``save interval`` iterations and a ``final.pkl`` at the end.
    """
    curr_iter, max_iter, eval_iter, eval_batch_sz, batch_sz, save_iter, demo_loader = \
        cfg.require("current training iter", "max iter", "eval interval",
                    "eval batch size", "batch size", "save interval",
                    "demo loader")

    # Rollout settings for gathering fresh on-policy training data.
    training_cfg = ParamDict({
        "policy state dict": agent.policy().getStateDict(),
        "filter state dict": agent.filter().getStateDict(),
        "trajectory max step": 1024,
        "batch size": batch_sz,
        "fixed environment": False,
        "fixed policy": False,
        "fixed filter": False
    })
    # Evaluation rollouts keep policy and filter frozen; the state dicts are
    # filled in each iteration before this config is used.
    validate_cfg = ParamDict({
        "policy state dict": None,
        "filter state dict": None,
        "trajectory max step": 1024,
        "batch size": eval_batch_sz,
        "fixed environment": False,
        "fixed policy": True,
        "fixed filter": True
    })

    # We use the entire demo set without sampling.
    demo_trajectory = demo_loader.generate_all()
    if demo_trajectory is None:
        print("Warning: No demo loaded, fall back compatible with TRPO method")
    else:
        print("Info: Demo loaded successfully")
        # Flatten all demo trajectories into two big (state, action) tensors
        # on the policy's device. Assumes each step dict carries 's'/'a'
        # arrays of consistent width — TODO confirm against the demo loader.
        demo_actions = []
        demo_states = []
        for p in demo_trajectory:
            demo_actions.append(
                torch.as_tensor([t['a'] for t in p],
                                dtype=torch.float32,
                                device=agent.policy().device))
            demo_states.append(
                torch.as_tensor([t['s'] for t in p],
                                dtype=torch.float32,
                                device=agent.policy().device))
        demo_states = torch.cat(demo_states, dim=0)
        demo_actions = torch.cat(demo_actions, dim=0)
        demo_trajectory = (demo_states, demo_actions)

    for i_iter in range(curr_iter, max_iter):
        s_time = float(running_time(fmt=False))
        # Sample a new batch and perform one MCPO update.
        batch_train, info_train = agent.rollout(training_cfg)

        demo_batch = None
        if demo_trajectory is not None:
            # Re-normalize demo states with the current ZFilter running
            # statistics so they match the distribution of filtered rollouts.
            filter_dict = agent.filter().getStateDict()
            errsum, mean, n_step = filter_dict["zfilter errsum"], filter_dict[
                "zfilter mean"], filter_dict["zfilter n_step"]
            errsum = torch.as_tensor(errsum,
                                     dtype=torch.float32,
                                     device=agent.policy().device)
            mean = torch.as_tensor(mean,
                                   dtype=torch.float32,
                                   device=agent.policy().device)
            # NOTE(review): when n_step <= 1 the "std" falls back to `mean`,
            # which looks suspicious (a std equal to the mean) — preserved
            # as-is; confirm whether ones/zeros was intended.
            std = torch.sqrt(errsum / (n_step - 1)) if n_step > 1 else mean
            demo_batch = ((demo_trajectory[0] - mean) / (std + 1e-8),
                          demo_trajectory[1])

        mcpo_step(cfg, batch_train, agent.policy(), demo_batch)
        e_time = float(running_time(fmt=False))

        logger.train()
        info_train["duration"] = e_time - s_time
        info_train["epoch"] = i_iter
        logger(info_train)

        # Persist the latest state into cfg and propagate it to both rollout
        # configs so the next rollout / evaluation uses the updated networks.
        cfg["current training iter"] = i_iter + 1
        cfg["policy state dict"] = training_cfg[
            "policy state dict"] = validate_cfg[
                "policy state dict"] = agent.policy().getStateDict()
        cfg["filter state dict"] = training_cfg[
            "filter state dict"] = validate_cfg[
                "filter state dict"] = agent.filter().getStateDict()

        if i_iter % eval_iter == 0:
            _, info_eval = agent.rollout(validate_cfg)
            logger.train(False)
            # NOTE(review): the eval record reuses the *training* step
            # duration (the eval rollout itself is not timed) — preserved
            # as-is; confirm whether eval timing was intended here.
            info_eval["duration"] = e_time - s_time
            info_eval["epoch"] = i_iter
            logger(info_eval)

        if i_iter != 0 and i_iter % save_iter == 0:
            file_name = os.path.join(model_dir(cfg), f"I_{i_iter}.pkl")
            cfg.save(file_name)
            print(f"Saving current step at {file_name}")

    file_name = os.path.join(model_dir(cfg), "final.pkl")
    cfg.save(file_name)
    print(
        f"Total running time: {running_time(fmt=True)}, result saved at {file_name}"
    )
def train_loop(cfg, agent, logger):
    """Run the behavior-cloning (BC) training loop.

    Loads the entire demonstration set once up front (demos are mandatory
    here — unlike the MCPO variant there is no fallback) and fits the policy
    to it with ``bc_step``; no environment rollouts are used for training,
    only for periodic evaluation.

    Args:
        cfg: project ParamDict-like config. Must provide the keys listed in
            ``cfg.require`` below; it is mutated in place each iteration with
            the updated iteration counter and the policy/filter state dicts,
            and is periodically saved as a checkpoint.
        agent: provides ``rollout(cfg)`` plus ``policy()`` / ``filter()``
            accessors used to read and propagate network state.
        logger: callable logging sink with a ``train(bool)`` mode switch.

    Raises:
        FileNotFoundError: if the demo loader yields no demonstrations.

    Side effects: writes ``I_{iter}.pkl`` checkpoints under ``model_dir(cfg)``
    every ``save interval`` iterations and a ``final.pkl`` at the end.
    """
    curr_iter, max_iter, eval_iter, eval_batch_sz, save_iter, demo_loader = \
        cfg.require("current training iter", "max iter", "eval interval",
                    "eval batch size", "save interval", "demo loader")

    # Evaluation rollouts keep policy and filter frozen; the state dicts are
    # filled in each iteration before this config is used.
    validate_cfg = ParamDict({
        "policy state dict": None,
        "filter state dict": None,
        "trajectory max step": 64,
        "batch size": eval_batch_sz,
        "fixed environment": False,
        "fixed policy": True,
        "fixed filter": True
    })

    # We use the entire demo set without sampling.
    demo_trajectory = demo_loader.generate_all()
    if demo_trajectory is None:
        raise FileNotFoundError(
            "Demo file not exists or cannot be loaded, abort !")
    else:
        print("Info: Demo loaded successfully")
        # Flatten all demo trajectories into two big (state, action) tensors
        # on the policy's device. Assumes each step dict carries 's'/'a'
        # arrays of consistent width — TODO confirm against the demo loader.
        demo_actions = []
        demo_states = []
        for p in demo_trajectory:
            demo_actions.append(
                torch.as_tensor([t['a'] for t in p],
                                dtype=torch.float32,
                                device=agent.policy().device))
            demo_states.append(
                torch.as_tensor([t['s'] for t in p],
                                dtype=torch.float32,
                                device=agent.policy().device))
        demo_states = torch.cat(demo_states, dim=0)
        demo_actions = torch.cat(demo_actions, dim=0)
        demo_trajectory = (demo_states, demo_actions)

    for i_iter in range(curr_iter, max_iter):
        s_time = float(running_time(fmt=False))
        # Perform one behavior-cloning update on the full demo set.
        loss = bc_step(cfg, agent.policy(), demo_trajectory)
        e_time = float(running_time(fmt=False))

        # Persist the latest state into cfg and propagate it to the eval
        # config so the next evaluation uses the updated networks.
        cfg["current training iter"] = i_iter + 1
        cfg["policy state dict"] = validate_cfg[
            "policy state dict"] = agent.policy().getStateDict()
        cfg["filter state dict"] = validate_cfg[
            "filter state dict"] = agent.filter().getStateDict()

        if i_iter % eval_iter == 0:
            _, info_eval = agent.rollout(validate_cfg)
            logger.train(False)
            info_eval["duration"] = e_time - s_time
            info_eval["epoch"] = i_iter
            info_eval["loss"] = loss
            logger(info_eval)

        if i_iter != 0 and i_iter % save_iter == 0:
            file_name = os.path.join(model_dir(cfg), f"I_{i_iter}.pkl")
            cfg.save(file_name)
            print(f"Saving current step at {file_name}")

    file_name = os.path.join(model_dir(cfg), "final.pkl")
    cfg.save(file_name)
    print(
        f"Total running time: {running_time(fmt=True)}, result saved at {file_name}"
    )