Example #1
0
def run():
    """Seed the RNGs, optionally train, then evaluate the experiment.

    The settings come from ``get_model_args()``. Training runs only when
    ``is_train`` is set on them; evaluation always runs afterwards.
    """
    print("CHANGE exp_name TO NOT OVERRIDE PREV. EXPERIMENTS.")
    args = get_model_args()

    # Seed NumPy first, then PyTorch, both from the configured seed.
    for seed_rng in (np.random.seed, torch.manual_seed):
        seed_rng(args.seed)

    if args.is_train:
        train(args)
    test_experiment(args)
Example #2
0
def test_experiment(config):
    """Roll out the saved goal-conditioned modules in the environment.

    Builds the perception, visual-goal-encoder, plan-proposer and control
    modules, loads their weights from ``config.models_save_path`` and runs
    ``config.n_test_evals`` evaluation episodes with gradients disabled.
    Each episode encodes one goal from ``deg.get_random_goal()`` and then
    steps the environment with the control module's actions, rendering
    after every step, while the environment has not reported ``done`` and
    ``t <= config.max_test_timestep``.

    Args:
        config: Model settings. Read here for ``deg``, ``visual_state_dim``,
            ``goal_dim``, ``combined_state_dim``, ``latent_dim``,
            ``models_save_path``, ``n_test_evals`` and ``max_test_timestep``.
    """
    deg = config.deg()
    env = deg.get_env()
    vobs_dim, dof_dim = deg.obs_space[deg.vis_obv_key], deg.obs_space[deg.dof_obv_key]
    act_dim = deg.action_space

    with torch.no_grad():
        perception_module = PerceptionModule(vobs_dim, dof_dim, config.visual_state_dim).to(device)
        visual_goal_encoder = VisualGoalEncoder(config.visual_state_dim, config.goal_dim).to(device)
        plan_proposer = PlanProposerModule(config.combined_state_dim, config.goal_dim, config.latent_dim).to(device)
        control_module = ControlModule(act_dim, config.combined_state_dim, config.goal_dim, config.latent_dim).to(device)

        perception_module.load_state_dict(torch.load(os.path.join(config.models_save_path, 'perception.pth')))
        visual_goal_encoder.load_state_dict(torch.load(os.path.join(config.models_save_path, 'visual_goal.pth')))
        plan_proposer.load_state_dict(torch.load(os.path.join(config.models_save_path, 'plan_proposer.pth')))
        # NOTE(review): a training loop visible in this source saves the
        # control weights as 'control_policy.pth' - confirm this filename
        # matches what training actually writes.
        control_module.load_state_dict(torch.load(os.path.join(config.models_save_path, 'control_module.pth')))

        obvs = env.reset()
        for i in range(config.n_test_evals):
            obvs = env.reset()
            goal = torch.from_numpy(deg.get_random_goal()).float()
            # Reorder the shape to (1, dim2, dim0, dim1); presumably the goal
            # is an (H, W, C) image - TODO confirm.
            goal = goal.reshape(1, goal.shape[2], goal.shape[0], goal.shape[1])
            # NOTE(review): called with the visual input only, while the
            # rollout below passes (visual, dof) - confirm dof is optional.
            goal = perception_module(goal)
            goal, _, _, _ = visual_goal_encoder(goal)
            t, done = 0, False

            while (not done) and t <= config.max_test_timestep:
                # TODO : Figure out way to set done true when goal is reached
                # Convert each observation to a tensor first, then cast to
                # float (the original line had unbalanced parentheses).
                visual_obv = torch.from_numpy(deg._get_obs(obvs, deg.vis_obv_key)).float()
                dof_obv = torch.from_numpy(deg._get_obs(obvs, deg.dof_obv_key)).float()
                visual_obv = visual_obv.reshape(1, visual_obv.shape[2], visual_obv.shape[0], visual_obv.shape[1])
                dof_obv = dof_obv.reshape(1, dof_obv.shape[0])
                state = perception_module(visual_obv, dof_obv)
                z_p, _, _, _ = plan_proposer(state, goal)

                action, _ = control_module.step(state, goal, z_p)
                obvs, _, done, _ = env.step(action[0])
                #yield env.render(mode='rgb_array')
                env.render()
                t += 1

# Script entry point: evaluate with the settings returned by get_model_args().
if __name__ == '__main__':
    # Imported inside the guard, so it is only needed when run as a script.
    from model_config import get_model_args
    config = get_model_args()
    test_experiment(config)
def imitate_play():
    """Roll out the saved imitation policy and store sampled trajectories.

    Loads the perception module and imitation policy, then generates
    ``demons_config.n_gen_traj`` trajectories of ``demons_config.flush_freq``
    environment steps each. Every ``demons_config.collect_freq``-th step the
    visual observation, dof observation and action are recorded; each
    finished trajectory is handed to ``store_trajectoy`` under the
    ``'imitation'`` label. The environment is closed on exit, including when
    a rollout raises.

    Perception weights are read from ``model_config.models_save_path`` when
    ``demons_config.use_model_perception`` is set, otherwise from
    ``demons_config.models_save_path``. The policy weights always come from
    ``demons_config.models_save_path``.
    """
    model_config = get_model_args()
    demons_config = get_demons_args()

    deg = demons_config.deg()
    env = deg.get_env()

    vobs_dim, dof_dim = deg.obs_space[deg.vis_obv_key], deg.obs_space[
        deg.dof_obv_key]
    act_dim = deg.action_space

    with torch.no_grad():
        perception_module = PerceptionModule(
            vobs_dim, dof_dim, model_config.visual_state_dim).to(device)
        imitation_policy = ImitationPolicy(
            act_dim,
            model_config.combined_state_dim,
        ).to(device)

        if demons_config.use_model_perception:
            perception_module.load_state_dict(
                torch.load(
                    os.path.join(model_config.models_save_path,
                                 'perception.pth')))
        else:
            perception_module.load_state_dict(
                torch.load(
                    os.path.join(demons_config.models_save_path,
                                 'perception.pth')))
        imitation_policy.load_state_dict(
            torch.load(
                os.path.join(demons_config.models_save_path,
                             'imitation_policy.pth')))

        try:
            for _traj_idx in range(demons_config.n_gen_traj):
                obs = env.reset()
                tr_vobvs, tr_dof, tr_actions = [], [], []

                # Fixed: the original read the undefined `demon_config` here.
                for step in range(demons_config.flush_freq):
                    # Fixed: the original indexed the undefined `obvs`; the
                    # current observation is held in `obs`.
                    visual_obv = torch.from_numpy(
                        obs[deg.vis_obv_key]).float()
                    dof_obv = torch.from_numpy(obs[deg.dof_obv_key]).float()
                    # Reorder the shape to (1, dim2, dim0, dim1); presumably
                    # the observation is an (H, W, C) image - TODO confirm.
                    visual_obv = visual_obv.reshape(1, visual_obv.shape[2],
                                                    visual_obv.shape[0],
                                                    visual_obv.shape[1])
                    dof_obv = dof_obv.reshape(1, dof_obv.shape[0])
                    state = perception_module(visual_obv, dof_obv)

                    action, _ = imitation_policy.step(state)

                    if int(step % demons_config.collect_freq) == 0:
                        tr_vobvs.append(visual_obv)
                        tr_dof.append(dof_obv)
                        tr_actions.append(action[0])

                    # NOTE(review): the termination flag is ignored, so
                    # stepping continues for the full flush_freq steps.
                    obs, _, _, _ = env.step(action[0])

                print('Storing Trajectory')
                trajectory = {
                    deg.vis_obv_key: np.array(tr_vobvs),
                    deg.dof_obv_key: np.array(tr_dof),
                    'action': np.array(tr_actions)
                }
                store_trajectoy(trajectory, 'imitation')
        finally:
            # Release the environment even if a rollout fails part-way.
            env.close()
            loss.backward()
            optimizer.step()

            tensorboard_writer.add_scalar('loss/total', loss,
                                          epoch * len(dataloader.dataset) + i)

        if int(epoch % config.save_interval_epoch) == 0:
            torch.save(perception_module.state_dict(),
                       os.path.join(config.models_save_path, 'perception.pth'))
            torch.save(
                visual_goal_encoder.state_dict(),
                os.path.join(config.models_save_path, 'visual_goal.pth'))
            torch.save(
                plan_recognizer.state_dict(),
                os.path.join(config.models_save_path, 'plan_recognizer.pth'))
            torch.save(
                plan_proposer.state_dict(),
                os.path.join(config.models_save_path, 'plan_proposer.pth'))
            torch.save(
                control_policy.state_dict(),
                os.path.join(config.models_save_path, 'control_policy.pth'))
            torch.save(optimizer.state_dict(),
                       os.path.join(config.models_save_path, 'optimizer.pth'))


# Script entry point: seed PyTorch and train with the get_model_args() settings.
if __name__ == '__main__':
    # Imported inside the guard, so it is only needed when run as a script.
    from model_config import get_model_args
    config = get_model_args()
    # Seed PyTorch's RNG from the configured seed before training starts.
    torch.manual_seed(config.seed)

    train(config)
def train_imitation(demons_config):
    """Fit the perception module and imitation policy on stored trajectories.

    Batches from ``deg.traj_dataset`` are flattened over their first two
    axes, encoded by the perception module and scored by the imitation
    policy. The mean of the negated ``logp_a`` returned by the policy is
    minimised with Adam and logged to TensorBoard as ``'Clone Loss'``.
    Checkpoints are written whenever the batch index hits the configured
    save interval.

    Args:
        demons_config: Demonstration settings. Read here for ``deg``,
            ``tensorboard_path``, ``load_models``, ``use_model_perception``
            and ``models_save_path``.
    """
    model_config = get_model_args()

    def _ckpt(cfg, filename):
        # Checkpoints sit directly under the owning config's save directory.
        return os.path.join(cfg.models_save_path, filename)

    deg = demons_config.deg(get_episode_type='EPISODE_ROBOT_PLAY')

    vobs_dim = deg.obs_space[deg.vis_obv_key]
    dof_dim = deg.obs_space[deg.dof_obv_key]
    act_dim = deg.action_space

    tensorboard_writer = SummaryWriter(logdir=demons_config.tensorboard_path)
    perception_module = PerceptionModule(vobs_dim, dof_dim,
                                         model_config.visual_state_dim)
    perception_module = perception_module.to(device)
    imitation_policy = ImitationPolicy(act_dim,
                                       model_config.combined_state_dim)
    imitation_policy = imitation_policy.to(device)

    # Both networks are optimised jointly by a single Adam instance.
    params = [
        *perception_module.parameters(),
        *imitation_policy.parameters(),
    ]
    print("Number of parameters : {}".format(len(params)))

    optimizer = torch.optim.Adam(params, lr=model_config.learning_rate)

    if demons_config.load_models:
        # TODO : IMPORTANT - Check if file exist before loading
        # Perception weights come from the model run when requested,
        # otherwise from this demonstration run's own save directory.
        if demons_config.use_model_perception:
            perception_source = model_config
        else:
            perception_source = demons_config
        perception_module.load_state_dict(
            torch.load(_ckpt(perception_source, 'perception.pth')))
        imitation_policy.load_state_dict(
            torch.load(_ckpt(demons_config, 'imitation_policy.pth')))
        optimizer.load_state_dict(
            torch.load(_ckpt(demons_config, 'optimizer.pth')))

    print("Run : tensorboard --logdir={} --host '0.0.0.0' --port 6006".format(
        demons_config.tensorboard_path))
    loader = DataLoader(deg.traj_dataset,
                        batch_size=model_config.batch_size,
                        shuffle=True,
                        num_workers=1)
    # Multiplier for the TensorBoard step index (dataset length, as before).
    dataset_size = len(loader.dataset)

    epochs = tqdm(range(model_config.max_epochs), desc="Check Tensorboard")
    for epoch in epochs:
        for batch_idx, raw_batch in enumerate(loader):
            batch = {
                name: tensor.float().to(device)
                for name, tensor in raw_batch.items()
            }
            visual_obvs = batch[deg.vis_obv_key]
            dof_obs = batch[deg.dof_obv_key]

            # Merge the first two axes so each row is a single step.
            flat_len = visual_obvs.shape[0] * visual_obvs.shape[1]
            # Assumes vobs_dim is ordered (H, W, C) - TODO confirm.
            visual_obvs = visual_obvs.reshape(flat_len, vobs_dim[2],
                                              vobs_dim[0], vobs_dim[1])
            dof_obs = dof_obs.reshape(flat_len, dof_dim)
            actions = batch['action'].reshape(flat_len, -1)

            # DEBUG : Might raise in-place errors
            states = perception_module(visual_obvs, dof_obs)

            _, logp_a = imitation_policy(state=states, action=actions)

            optimizer.zero_grad()
            loss = (-logp_a).mean()

            tensorboard_writer.add_scalar('Clone Loss', loss,
                                          epoch * dataset_size + batch_idx)
            loss.backward()
            optimizer.step()

            if int(batch_idx % model_config.save_interval) == 0:
                # Perception is only written back when this run owns it.
                to_save = []
                if not demons_config.use_model_perception:
                    to_save.append((perception_module, 'perception.pth'))
                to_save.append((imitation_policy, 'imitation_policy.pth'))
                to_save.append((optimizer, 'optimizer.pth'))
                for stateful, filename in to_save:
                    torch.save(stateful.state_dict(),
                               _ckpt(demons_config, filename))