def run():
    print("CHANGE exp_name TO NOT OVERWRITE PREVIOUS EXPERIMENTS.")
    config = get_model_args()

    # Seed NumPy and PyTorch for reproducibility.
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)

    if config.is_train:
        train(config)
    test_experiment(config)
def test_experiment(config):
    deg = config.deg()
    env = deg.get_env()

    vobs_dim, dof_dim = deg.obs_space[deg.vis_obv_key], deg.obs_space[deg.dof_obv_key]
    act_dim = deg.action_space

    with torch.no_grad():
        # Build the modules and restore the trained weights.
        perception_module = PerceptionModule(vobs_dim, dof_dim, config.visual_state_dim).to(device)
        visual_goal_encoder = VisualGoalEncoder(config.visual_state_dim, config.goal_dim).to(device)
        plan_proposer = PlanProposerModule(config.combined_state_dim, config.goal_dim,
                                           config.latent_dim).to(device)
        control_module = ControlModule(act_dim, config.combined_state_dim, config.goal_dim,
                                       config.latent_dim).to(device)

        perception_module.load_state_dict(
            torch.load(os.path.join(config.models_save_path, 'perception.pth')))
        visual_goal_encoder.load_state_dict(
            torch.load(os.path.join(config.models_save_path, 'visual_goal.pth')))
        plan_proposer.load_state_dict(
            torch.load(os.path.join(config.models_save_path, 'plan_proposer.pth')))
        control_module.load_state_dict(
            torch.load(os.path.join(config.models_save_path, 'control_module.pth')))

        for i in range(config.n_test_evals):
            obvs = env.reset()

            # Encode a random goal image into the goal-latent space.
            goal = torch.from_numpy(deg.get_random_goal()).float()
            goal = goal.reshape(1, goal.shape[2], goal.shape[0], goal.shape[1])  # HWC -> NCHW
            goal = perception_module(goal)
            goal, _, _, _ = visual_goal_encoder(goal)

            t, done = 0, False
            while (not done) and t <= config.max_test_timestep:
                # TODO : Figure out a way to set done to True when the goal is reached.
                visual_obv = torch.from_numpy(deg._get_obs(obvs, deg.vis_obv_key)).float()
                dof_obv = torch.from_numpy(deg._get_obs(obvs, deg.dof_obv_key)).float()
                visual_obv = visual_obv.reshape(1, visual_obv.shape[2],
                                                visual_obv.shape[0], visual_obv.shape[1])
                dof_obv = dof_obv.reshape(1, dof_obv.shape[0])

                state = perception_module(visual_obv, dof_obv)
                z_p, _, _, _ = plan_proposer(state, goal)
                action, _ = control_module.step(state, goal, z_p)

                obvs, _, done, _ = env.step(action[0])
                # yield env.render(mode='rgb_array')
                env.render()
                t += 1


if __name__ == '__main__':
    from model_config import get_model_args

    config = get_model_args()
    test_experiment(config)
def imitate_play():
    model_config = get_model_args()
    demons_config = get_demons_args()

    deg = demons_config.deg()
    env = deg.get_env()

    vobs_dim, dof_dim = deg.obs_space[deg.vis_obv_key], deg.obs_space[deg.dof_obv_key]
    act_dim = deg.action_space

    with torch.no_grad():
        perception_module = PerceptionModule(vobs_dim, dof_dim, model_config.visual_state_dim).to(device)
        imitation_policy = ImitationPolicy(act_dim, model_config.combined_state_dim).to(device)

        # The perception module may have been trained with the main model or with
        # the imitation policy; load whichever checkpoint applies.
        if demons_config.use_model_perception:
            perception_module.load_state_dict(
                torch.load(os.path.join(model_config.models_save_path, 'perception.pth')))
        else:
            perception_module.load_state_dict(
                torch.load(os.path.join(demons_config.models_save_path, 'perception.pth')))
        imitation_policy.load_state_dict(
            torch.load(os.path.join(demons_config.models_save_path, 'imitation_policy.pth')))

        for run in range(demons_config.n_gen_traj):
            obvs = env.reset()
            tr_vobvs, tr_dof, tr_actions = [], [], []

            for step in range(demons_config.flush_freq):
                visual_obv = torch.from_numpy(obvs[deg.vis_obv_key]).float()
                dof_obv = torch.from_numpy(obvs[deg.dof_obv_key]).float()
                visual_obv = visual_obv.reshape(1, visual_obv.shape[2],
                                                visual_obv.shape[0], visual_obv.shape[1])
                dof_obv = dof_obv.reshape(1, dof_obv.shape[0])

                state = perception_module(visual_obv, dof_obv)
                action, _ = imitation_policy.step(state)

                if int(step % demons_config.collect_freq) == 0:
                    tr_vobvs.append(visual_obv)
                    tr_dof.append(dof_obv)
                    tr_actions.append(action[0])

                obvs, _, done, _ = env.step(action[0])

            print('Storing Trajectory')
            trajectory = {
                deg.vis_obv_key: np.array(tr_vobvs),
                deg.dof_obv_key: np.array(tr_dof),
                'action': np.array(tr_actions)
            }
            store_trajectoy(trajectory, 'imitation')
            trajectory, tr_vobvs, tr_dof, tr_actions = {}, [], [], []

    env.close()
            loss.backward()
            optimizer.step()

            tensorboard_writer.add_scalar('loss/total', loss,
                                          epoch * len(dataloader.dataset) + i)

            # Periodically checkpoint every module and the optimizer state.
            if int(epoch % config.save_interval_epoch) == 0:
                torch.save(perception_module.state_dict(),
                           os.path.join(config.models_save_path, 'perception.pth'))
                torch.save(visual_goal_encoder.state_dict(),
                           os.path.join(config.models_save_path, 'visual_goal.pth'))
                torch.save(plan_recognizer.state_dict(),
                           os.path.join(config.models_save_path, 'plan_recognizer.pth'))
                torch.save(plan_proposer.state_dict(),
                           os.path.join(config.models_save_path, 'plan_proposer.pth'))
                torch.save(control_policy.state_dict(),
                           os.path.join(config.models_save_path, 'control_policy.pth'))
                torch.save(optimizer.state_dict(),
                           os.path.join(config.models_save_path, 'optimizer.pth'))


if __name__ == '__main__':
    from model_config import get_model_args

    config = get_model_args()
    torch.manual_seed(config.seed)
    train(config)
def train_imitation(demons_config):
    model_config = get_model_args()

    deg = demons_config.deg(get_episode_type='EPISODE_ROBOT_PLAY')
    vobs_dim, dof_dim = deg.obs_space[deg.vis_obv_key], deg.obs_space[deg.dof_obv_key]
    act_dim = deg.action_space

    tensorboard_writer = SummaryWriter(logdir=demons_config.tensorboard_path)

    perception_module = PerceptionModule(vobs_dim, dof_dim, model_config.visual_state_dim).to(device)
    imitation_policy = ImitationPolicy(act_dim, model_config.combined_state_dim).to(device)

    params = list(perception_module.parameters()) + list(imitation_policy.parameters())
    print("Number of parameters : {}".format(sum(p.numel() for p in params)))
    optimizer = torch.optim.Adam(params, lr=model_config.learning_rate)

    if demons_config.load_models:
        # TODO : IMPORTANT - Check that the checkpoint files exist before loading.
        if demons_config.use_model_perception:
            perception_module.load_state_dict(
                torch.load(os.path.join(model_config.models_save_path, 'perception.pth')))
        else:
            perception_module.load_state_dict(
                torch.load(os.path.join(demons_config.models_save_path, 'perception.pth')))
        imitation_policy.load_state_dict(
            torch.load(os.path.join(demons_config.models_save_path, 'imitation_policy.pth')))
        optimizer.load_state_dict(
            torch.load(os.path.join(demons_config.models_save_path, 'optimizer.pth')))

    print("Run : tensorboard --logdir={} --host '0.0.0.0' --port 6006".format(
        demons_config.tensorboard_path))

    data_loader = DataLoader(deg.traj_dataset,
                             batch_size=model_config.batch_size,
                             shuffle=True,
                             num_workers=1)
    max_step_size = len(data_loader.dataset)

    for epoch in tqdm(range(model_config.max_epochs), desc="Check Tensorboard"):
        for i, trajectory in enumerate(data_loader):
            trajectory = {key: trajectory[key].float().to(device) for key in trajectory.keys()}
            visual_obvs, dof_obs = trajectory[deg.vis_obv_key], trajectory[deg.dof_obv_key]

            # Flatten (batch, sequence) into a single batch dimension and reshape
            # frames from HWC to CHW before the perception module.
            batch_size, seq_len = visual_obvs.shape[0], visual_obvs.shape[1]
            visual_obvs = visual_obvs.reshape(batch_size * seq_len, vobs_dim[2],
                                              vobs_dim[0], vobs_dim[1])
            dof_obs = dof_obs.reshape(batch_size * seq_len, dof_dim)
            actions = trajectory['action'].reshape(batch_size * seq_len, -1)

            states = perception_module(visual_obvs, dof_obs)  # DEBUG : Might raise in-place errors
            pi, logp_a = imitation_policy(state=states, action=actions)

            # Behavioural cloning: maximise the log-likelihood of the demonstrated actions.
            optimizer.zero_grad()
            loss = -logp_a.mean()
            tensorboard_writer.add_scalar('Clone Loss', loss, epoch * max_step_size + i)
            loss.backward()
            optimizer.step()

            if int(i % model_config.save_interval) == 0:
                if not demons_config.use_model_perception:
                    torch.save(perception_module.state_dict(),
                               os.path.join(demons_config.models_save_path, 'perception.pth'))
                torch.save(imitation_policy.state_dict(),
                           os.path.join(demons_config.models_save_path, 'imitation_policy.pth'))
                torch.save(optimizer.state_dict(),
                           os.path.join(demons_config.models_save_path, 'optimizer.pth'))
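

# A minimal entry-point sketch, by analogy with the `if __name__ == '__main__'`
# guards used elsewhere in this repo. It assumes `get_demons_args` is importable
# from a `demons_config` module (the same helper `imitate_play` above relies on);
# adjust the import and the call order to the actual module layout if it differs.
if __name__ == '__main__':
    from demons_config import get_demons_args

    demons_config = get_demons_args()
    train_imitation(demons_config)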