Beispiel #1
0
def main():
    """Run the training / evaluation loop.

    Configuration is read from the module-level ``args`` namespace.

    Loop structure:
    - Every environment step, the local policy picks an action and the
      vectorised environments are stepped.
    - The Neural SLAM module then updates the local map and pose.
    - Every ``args.num_local_steps`` steps, the global policy (trained with
      PPO) samples a new long-term goal inside the local map.
    - ``envs.get_short_term_goal`` turns that goal into a short-term goal plus
      an expert action. The local policy imitates that action via a
      cross-entropy loss.

    Side effects:
    - Creates ``<dump_location>/models/<exp_name>/``, which holds
      ``train.log`` and the best checkpoints.
    - Creates ``<dump_location>/dump/<exp_name>/``, which holds the periodic
      checkpoints and, in eval mode, ``explored_area.txt`` and
      ``explored_ratio.txt``.
    """
    # Setup Logging
    log_dir = "{}/models/{}/".format(args.dump_location, args.exp_name)
    dump_dir = "{}/dump/{}/".format(args.dump_location, args.exp_name)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs("{}/images/".format(dump_dir), exist_ok=True)

    logging.basicConfig(filename=log_dir + 'train.log', level=logging.INFO)
    print("Dumping at {}".format(log_dir))
    print(args)
    logging.info(args)

    # Logging and loss variables
    num_scenes = args.num_processes
    num_episodes = int(args.num_episodes)
    device = args.device = torch.device("cuda:0" if args.cuda else "cpu")
    policy_loss = 0

    best_cost = 100000
    costs = deque(maxlen=1000)
    exp_costs = deque(maxlen=1000)
    pose_costs = deque(maxlen=1000)

    # Per-scene masks: 0 where an episode ended, 1 otherwise.
    g_masks = torch.ones(num_scenes).float().to(device)
    l_masks = torch.zeros(num_scenes).float().to(device)

    best_local_loss = np.inf
    best_g_reward = -np.inf

    if args.eval:
        # One column per global step of every episode, per scene.
        traj_lengths = args.max_episode_length // args.num_local_steps
        explored_area_log = np.zeros((num_scenes, num_episodes, traj_lengths))
        explored_ratio_log = np.zeros((num_scenes, num_episodes, traj_lengths))

    g_episode_rewards = deque(maxlen=1000)

    l_action_losses = deque(maxlen=1000)

    g_value_losses = deque(maxlen=1000)
    g_action_losses = deque(maxlen=1000)
    g_dist_entropies = deque(maxlen=1000)

    per_step_g_rewards = deque(maxlen=1000)

    # Reward accumulated so far in each scene's running episode.
    g_process_rewards = np.zeros((num_scenes))

    # Starting environments
    torch.set_num_threads(1)
    envs = make_vec_envs(args)
    obs, infos = envs.reset()

    # Initialize map variables
    ### Full map consists of 4 channels containing the following:
    ### 1. Obstacle Map
    ### 2. Explored Area
    ### 3. Current Agent Location
    ### 4. Past Agent Locations

    # Gradients stay disabled except around the explicit training sections
    # inside the loop below.
    torch.set_grad_enabled(False)

    # Calculating full and local map sizes
    map_size = args.map_size_cm // args.map_resolution
    full_w, full_h = map_size, map_size
    local_w, local_h = int(full_w / args.global_downscaling), \
                       int(full_h / args.global_downscaling)

    # Initializing full and local map
    full_map = torch.zeros(num_scenes, 4, full_w, full_h).float().to(device)
    local_map = torch.zeros(num_scenes, 4, local_w, local_h).float().to(device)

    # Initial full and local pose
    full_pose = torch.zeros(num_scenes, 3).float().to(device)
    local_pose = torch.zeros(num_scenes, 3).float().to(device)

    # Origin of local map
    origins = np.zeros((num_scenes, 3))

    # Local Map Boundaries
    lmb = np.zeros((num_scenes, 4)).astype(int)

    ### Planner pose inputs has 7 dimensions
    ### 1-3 store continuous global agent location
    ### 4-7 store local map boundaries
    planner_pose_inputs = np.zeros((num_scenes, 7))

    # Parameter-free pooling layer that downscales the full map for the
    # global policy input. It is built once rather than on every global step.
    full_map_pool = nn.MaxPool2d(args.global_downscaling)

    def pose_to_cell(pose):
        """Convert a pose (x, y, ...) in metres to (row, col) map-cell indices."""
        return (int(pose[1] * 100.0 / args.map_resolution),
                int(pose[0] * 100.0 / args.map_resolution))

    def mark_agent(grid, e, loc_r, loc_c, radius):
        """Set channels 2-3 of grid[e] to 1 in a square around (loc_r, loc_c).

        Slice bounds are clamped at 0. A negative bound would otherwise wrap
        around to the far edge of the map, and the agent marker would be
        written in the wrong place or silently dropped.
        """
        r0, r1 = max(loc_r - radius, 0), max(loc_r + radius + 1, 0)
        c0, c1 = max(loc_c - radius, 0), max(loc_c + radius + 1, 0)
        grid[e, 2:, r0:r1, c0:c1] = 1.

    def recenter_local_window(e):
        """Re-centre scene e's local map window on full_pose[e].

        Updates lmb[e], origins[e] and planner_pose_inputs[e, 3:]. Then copies
        the new window out of full_map into local_map[e] and re-expresses the
        pose relative to the new origin in local_pose[e].
        """
        loc_r, loc_c = pose_to_cell(full_pose[e].cpu().numpy())
        lmb[e] = get_local_map_boundaries(
            (loc_r, loc_c), (local_w, local_h), (full_w, full_h))

        planner_pose_inputs[e, 3:] = lmb[e]
        # Window origin as (x, y, 0) in metres: column bound -> x,
        # row bound -> y.
        origins[e] = [
            lmb[e][2] * args.map_resolution / 100.0,
            lmb[e][0] * args.map_resolution / 100.0, 0.
        ]

        local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1],
                                lmb[e, 2]:lmb[e, 3]]
        local_pose[e] = full_pose[e] - \
                        torch.from_numpy(origins[e]).to(device).float()

    def init_map_and_pose():
        """Clear the maps and place every agent at the centre of the full map."""
        full_map.fill_(0.)
        full_pose.fill_(0.)
        full_pose[:, :2] = args.map_size_cm / 100.0 / 2.0

        locs = full_pose.cpu().numpy()
        planner_pose_inputs[:, :3] = locs
        for e in range(num_scenes):
            loc_r, loc_c = pose_to_cell(locs[e])
            mark_agent(full_map, e, loc_r, loc_c, radius=1)
            recenter_local_window(e)

    def mean_or_nan(values):
        """Mean of values, or nan (without a numpy warning) when empty."""
        return np.mean(values) if len(values) > 0 else float('nan')

    init_map_and_pose()

    # Global policy observation space
    g_observation_space = gym.spaces.Box(0,
                                         1, (8, local_w, local_h),
                                         dtype='uint8')

    # Global policy action space
    g_action_space = gym.spaces.Box(low=0.0,
                                    high=1.0,
                                    shape=(2, ),
                                    dtype=np.float32)

    # Local policy observation space
    l_observation_space = gym.spaces.Box(
        0, 255, (3, args.frame_width, args.frame_width), dtype='uint8')

    # Local and Global policy recurrent layer sizes
    l_hidden_size = args.local_hidden_size
    g_hidden_size = args.global_hidden_size

    # slam
    nslam_module = Neural_SLAM_Module(args).to(device)
    slam_optimizer = get_optimizer(nslam_module.parameters(),
                                   args.slam_optimizer)

    # Global policy
    g_policy = RL_Policy(g_observation_space.shape,
                         g_action_space,
                         base_kwargs={
                             'recurrent': args.use_recurrent_global,
                             'hidden_size': g_hidden_size,
                             'downscaling': args.global_downscaling
                         }).to(device)
    g_agent = algo.PPO(g_policy,
                       args.clip_param,
                       args.ppo_epoch,
                       args.num_mini_batch,
                       args.value_loss_coef,
                       args.entropy_coef,
                       lr=args.global_lr,
                       eps=args.eps,
                       max_grad_norm=args.max_grad_norm)

    # Local policy
    l_policy = Local_IL_Policy(
        l_observation_space.shape,
        envs.action_space.n,
        recurrent=args.use_recurrent_local,
        hidden_size=l_hidden_size,
        deterministic=args.use_deterministic_local).to(device)
    local_optimizer = get_optimizer(l_policy.parameters(),
                                    args.local_optimizer)

    # Storage
    g_rollouts = GlobalRolloutStorage(args.num_global_steps, num_scenes,
                                      g_observation_space.shape,
                                      g_action_space, g_policy.rec_state_size,
                                      1).to(device)

    slam_memory = FIFOMemory(args.slam_memory_size)

    # Loading model ("0" means "do not load")
    if args.load_slam != "0":
        print("Loading slam {}".format(args.load_slam))
        state_dict = torch.load(args.load_slam,
                                map_location=lambda storage, loc: storage)
        nslam_module.load_state_dict(state_dict)

    if not args.train_slam:
        nslam_module.eval()

    if args.load_global != "0":
        print("Loading global {}".format(args.load_global))
        state_dict = torch.load(args.load_global,
                                map_location=lambda storage, loc: storage)
        g_policy.load_state_dict(state_dict)

    if not args.train_global:
        g_policy.eval()

    if args.load_local != "0":
        print("Loading local {}".format(args.load_local))
        state_dict = torch.load(args.load_local,
                                map_location=lambda storage, loc: storage)
        l_policy.load_state_dict(state_dict)

    if not args.train_local:
        l_policy.eval()

    # Predict map from frame 1:
    poses = torch.from_numpy(
        np.asarray([
            infos[env_idx]['sensor_pose'] for env_idx in range(num_scenes)
        ])).float().to(device)

    _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
        nslam_module(obs, obs, poses, local_map[:, 0, :, :],
                     local_map[:, 1, :, :], local_pose)

    # Compute Global policy input
    locs = local_pose.cpu().numpy()
    global_input = torch.zeros(num_scenes, 8, local_w, local_h)
    global_orientation = torch.zeros(num_scenes, 1).long()

    for e in range(num_scenes):
        loc_r, loc_c = pose_to_cell(locs[e])
        mark_agent(local_map, e, loc_r, loc_c, radius=1)
        # Heading (presumably degrees in [-180, 180)) bucketed into 5-degree
        # bins.
        global_orientation[e] = int((locs[e, 2] + 180.0) / 5.)

    global_input[:, 0:4, :, :] = local_map.detach()
    global_input[:, 4:, :, :] = full_map_pool(full_map)

    g_rollouts.obs[0].copy_(global_input)
    g_rollouts.extras[0].copy_(global_orientation)

    # Run Global Policy (global_goals = Long-Term Goal)
    g_value, g_action, g_action_log_prob, g_rec_states = \
        g_policy.act(
            g_rollouts.obs[0],
            g_rollouts.rec_states[0],
            g_rollouts.masks[0],
            extras=g_rollouts.extras[0],
            deterministic=False
        )

    # Squash the action into [0, 1] and scale it to a local-map cell.
    cpu_actions = torch.sigmoid(g_action).cpu().numpy()
    global_goals = [[int(action[0] * local_w),
                     int(action[1] * local_h)] for action in cpu_actions]

    # Compute planner inputs
    planner_inputs = [{} for e in range(num_scenes)]
    for e, p_input in enumerate(planner_inputs):
        p_input['goal'] = global_goals[e]
        p_input['map_pred'] = global_input[e, 0, :, :].detach().cpu().numpy()
        p_input['exp_pred'] = global_input[e, 1, :, :].detach().cpu().numpy()
        p_input['pose_pred'] = planner_pose_inputs[e]

    # Output stores local goals as well as the ground-truth action
    # (last column).
    output = envs.get_short_term_goal(planner_inputs)

    last_obs = obs.detach()
    local_rec_states = torch.zeros(num_scenes, l_hidden_size).to(device)
    start = time.time()

    total_num_steps = -1
    g_reward = 0

    for ep_num in range(num_episodes):
        for step in range(args.max_episode_length):
            total_num_steps += 1

            g_step = (step // args.num_local_steps) % args.num_global_steps
            eval_g_step = step // args.num_local_steps + 1
            l_step = step % args.num_local_steps

            # ------------------------------------------------------------------
            # Local Policy
            del last_obs
            last_obs = obs.detach()
            local_masks = l_masks
            local_goals = output[:, :-1].to(device).long()

            if args.train_local:
                torch.set_grad_enabled(True)

            action, action_prob, local_rec_states = l_policy(
                obs,
                local_rec_states,
                local_masks,
                extras=local_goals,
            )

            if args.train_local:
                # Imitation loss against the planner's action. It accumulates
                # until the next local policy update.
                action_target = output[:, -1].long().to(device)
                policy_loss += F.cross_entropy(action_prob, action_target)
                torch.set_grad_enabled(False)
            l_action = action.cpu()
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Env step
            obs, _, done, infos = envs.step(l_action)

            l_masks = torch.FloatTensor([0 if x else 1
                                         for x in done]).to(device)
            # g_masks[e] drops to 0 if scene e finished an episode at any
            # point during the current global step.
            g_masks *= l_masks
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Reinitialize variables when episode ends
            if step == args.max_episode_length - 1:  # Last episode step
                init_map_and_pose()
                del last_obs
                last_obs = obs.detach()
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Neural SLAM Module
            if args.train_slam:
                # Add frames to memory
                for env_idx in range(num_scenes):
                    env_obs = obs[env_idx].to("cpu")
                    env_poses = torch.from_numpy(
                        np.asarray(
                            infos[env_idx]['sensor_pose'])).float().to("cpu")
                    env_gt_fp_projs = torch.from_numpy(
                        np.asarray(infos[env_idx]['fp_proj'])).unsqueeze(
                            0).float().to("cpu")
                    env_gt_fp_explored = torch.from_numpy(
                        np.asarray(infos[env_idx]['fp_explored'])).unsqueeze(
                            0).float().to("cpu")
                    env_gt_pose_err = torch.from_numpy(
                        np.asarray(
                            infos[env_idx]['pose_err'])).float().to("cpu")
                    slam_memory.push(
                        (last_obs[env_idx].cpu(), env_obs, env_poses),
                        (env_gt_fp_projs, env_gt_fp_explored, env_gt_pose_err))

            poses = torch.from_numpy(
                np.asarray([
                    infos[env_idx]['sensor_pose']
                    for env_idx in range(num_scenes)
                ])).float().to(device)

            _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
                nslam_module(last_obs, obs, poses, local_map[:, 0, :, :],
                             local_map[:, 1, :, :], local_pose, build_maps=True)

            locs = local_pose.cpu().numpy()
            planner_pose_inputs[:, :3] = locs + origins
            local_map[:, 2, :, :].fill_(0.)  # Resetting current location channel
            for e in range(num_scenes):
                loc_r, loc_c = pose_to_cell(locs[e])
                mark_agent(local_map, e, loc_r, loc_c, radius=2)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Global Policy
            if l_step == args.num_local_steps - 1:
                # For every global step, update the full and local maps
                for e in range(num_scenes):
                    # Write the local window back into the full map, then
                    # re-centre the window on the agent's current position.
                    full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]] = \
                        local_map[e]
                    full_pose[e] = local_pose[e] + \
                                   torch.from_numpy(origins[e]).to(device).float()
                    recenter_local_window(e)

                locs = local_pose.cpu().numpy()
                for e in range(num_scenes):
                    global_orientation[e] = int((locs[e, 2] + 180.0) / 5.)
                global_input[:, 0:4, :, :] = local_map
                global_input[:, 4:, :, :] = full_map_pool(full_map)

                # Get exploration reward and metrics
                g_reward = torch.from_numpy(
                    np.asarray([
                        infos[env_idx]['exp_reward']
                        for env_idx in range(num_scenes)
                    ])).float().to(device)

                if args.eval:
                    g_reward = g_reward * 50.0  # Convert reward to area in m2

                g_reward_np = g_reward.cpu().numpy()
                g_masks_np = g_masks.cpu().numpy()
                g_process_rewards += g_reward_np
                # A scene finished an episode iff its global mask was zeroed.
                # Log its total reward, even when that total is exactly 0, and
                # restart its accumulator.
                finished = g_masks_np == 0
                g_episode_rewards.extend(g_process_rewards[finished].tolist())
                g_process_rewards *= g_masks_np
                per_step_g_rewards.append(np.mean(g_reward_np))

                if args.eval:
                    exp_ratio = np.asarray([
                        infos[env_idx]['exp_ratio']
                        for env_idx in range(num_scenes)
                    ]).astype(np.float32)

                    # Cumulative logs. The first global step of an episode
                    # starts from 0 instead of reading column -1 (which would
                    # wrap around to the last column).
                    g_idx = eval_g_step - 1
                    if g_idx > 0:
                        prev_area = explored_area_log[:, ep_num, g_idx - 1]
                        prev_ratio = explored_ratio_log[:, ep_num, g_idx - 1]
                    else:
                        prev_area = prev_ratio = 0.
                    explored_area_log[:, ep_num, g_idx] = \
                        prev_area + g_reward_np
                    explored_ratio_log[:, ep_num, g_idx] = \
                        prev_ratio + exp_ratio

                # Add samples to global policy storage
                g_rollouts.insert(global_input, g_rec_states, g_action,
                                  g_action_log_prob, g_value, g_reward,
                                  g_masks, global_orientation)

                # Sample long-term goal from global policy
                g_value, g_action, g_action_log_prob, g_rec_states = \
                    g_policy.act(
                        g_rollouts.obs[g_step + 1],
                        g_rollouts.rec_states[g_step + 1],
                        g_rollouts.masks[g_step + 1],
                        extras=g_rollouts.extras[g_step + 1],
                        deterministic=False
                    )
                cpu_actions = torch.sigmoid(g_action).cpu().numpy()
                global_goals = [[
                    int(action[0] * local_w),
                    int(action[1] * local_h)
                ] for action in cpu_actions]

                g_reward = 0
                g_masks = torch.ones(num_scenes).float().to(device)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Get short term goal
            planner_inputs = [{} for e in range(num_scenes)]
            for e, p_input in enumerate(planner_inputs):
                p_input['map_pred'] = local_map[e, 0, :, :].cpu().numpy()
                p_input['exp_pred'] = local_map[e, 1, :, :].cpu().numpy()
                p_input['pose_pred'] = planner_pose_inputs[e]
                p_input['goal'] = global_goals[e]

            output = envs.get_short_term_goal(planner_inputs)
            # ------------------------------------------------------------------

            ### TRAINING
            torch.set_grad_enabled(True)
            # ------------------------------------------------------------------
            # Train Neural SLAM Module
            if args.train_slam and len(slam_memory) > args.slam_batch_size:
                for _ in range(args.slam_iterations):
                    inputs, outputs = slam_memory.sample(args.slam_batch_size)
                    b_obs_last, b_obs, b_poses = inputs
                    gt_fp_projs, gt_fp_explored, gt_pose_err = outputs

                    b_obs = b_obs.to(device)
                    b_obs_last = b_obs_last.to(device)
                    b_poses = b_poses.to(device)

                    gt_fp_projs = gt_fp_projs.to(device)
                    gt_fp_explored = gt_fp_explored.to(device)
                    gt_pose_err = gt_pose_err.to(device)

                    b_proj_pred, b_fp_exp_pred, _, _, b_pose_err_pred, _ = \
                        nslam_module(b_obs_last, b_obs, b_poses,
                                     None, None, None,
                                     build_maps=False)
                    loss = 0
                    if args.proj_loss_coeff > 0:
                        proj_loss = F.binary_cross_entropy(
                            b_proj_pred, gt_fp_projs)
                        costs.append(proj_loss.item())
                        loss += args.proj_loss_coeff * proj_loss

                    if args.exp_loss_coeff > 0:
                        exp_loss = F.binary_cross_entropy(
                            b_fp_exp_pred, gt_fp_explored)
                        exp_costs.append(exp_loss.item())
                        loss += args.exp_loss_coeff * exp_loss

                    if args.pose_loss_coeff > 0:
                        pose_loss = F.mse_loss(b_pose_err_pred, gt_pose_err)
                        pose_costs.append(args.pose_loss_coeff *
                                          pose_loss.item())
                        loss += args.pose_loss_coeff * pose_loss

                    # `loss` stays the int 0 when every loss coefficient is
                    # disabled. There is then nothing to back-propagate, and
                    # calling .backward() on it would raise.
                    if torch.is_tensor(loss):
                        slam_optimizer.zero_grad()
                        loss.backward()
                        slam_optimizer.step()

                    del b_obs_last, b_obs, b_poses
                    del gt_fp_projs, gt_fp_explored, gt_pose_err
                    del b_proj_pred, b_fp_exp_pred, b_pose_err_pred

            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Train Local Policy
            if (l_step + 1) % args.local_policy_update_freq == 0 \
                    and args.train_local:
                local_optimizer.zero_grad()
                policy_loss.backward()
                local_optimizer.step()
                l_action_losses.append(policy_loss.item())
                policy_loss = 0
                # Truncate back-propagation through time at the update.
                local_rec_states = local_rec_states.detach_()
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Train Global Policy (once the rollout storage is full)
            if g_step == args.num_global_steps - 1 \
                    and l_step == args.num_local_steps - 1:
                if args.train_global:
                    g_next_value = g_policy.get_value(
                        g_rollouts.obs[-1],
                        g_rollouts.rec_states[-1],
                        g_rollouts.masks[-1],
                        extras=g_rollouts.extras[-1]).detach()

                    g_rollouts.compute_returns(g_next_value, args.use_gae,
                                               args.gamma, args.tau)
                    g_value_loss, g_action_loss, g_dist_entropy = \
                        g_agent.update(g_rollouts)
                    g_value_losses.append(g_value_loss)
                    g_action_losses.append(g_action_loss)
                    g_dist_entropies.append(g_dist_entropy)
                g_rollouts.after_update()
            # ------------------------------------------------------------------

            # Finish Training
            torch.set_grad_enabled(False)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Logging
            if total_num_steps % args.log_interval == 0:
                end = time.time()
                time_elapsed = time.gmtime(end - start)
                # Guard the FPS division against a zero interval on coarse
                # clocks.
                elapsed_s = max(end - start, 1e-6)
                log = " ".join([
                    "Time: {0:0=2d}d".format(time_elapsed.tm_mday - 1),
                    "{},".format(time.strftime("%Hh %Mm %Ss", time_elapsed)),
                    "num timesteps {},".format(total_num_steps *
                                               num_scenes),
                    "FPS {},".format(int(total_num_steps * num_scenes /
                                         elapsed_s))
                ])

                log += "\n\tRewards:"

                if len(g_episode_rewards) > 0:
                    log += " ".join([
                        " Global step mean/med rew:",
                        "{:.4f}/{:.4f},".format(np.mean(per_step_g_rewards),
                                                np.median(per_step_g_rewards)),
                        " Global eps mean/med/min/max eps rew:",
                        "{:.3f}/{:.3f}/{:.3f}/{:.3f},".format(
                            np.mean(g_episode_rewards),
                            np.median(g_episode_rewards),
                            np.min(g_episode_rewards),
                            np.max(g_episode_rewards))
                    ])

                log += "\n\tLosses:"

                if args.train_local and len(l_action_losses) > 0:
                    log += " ".join([
                        " Local Loss:",
                        "{:.3f},".format(np.mean(l_action_losses))
                    ])

                if args.train_global and len(g_value_losses) > 0:
                    log += " ".join([
                        " Global Loss value/action/dist:",
                        "{:.3f}/{:.3f}/{:.3f},".format(
                            np.mean(g_value_losses), np.mean(g_action_losses),
                            np.mean(g_dist_entropies))
                    ])

                if args.train_slam and len(costs) > 0:
                    # exp/pose costs may be empty when their coefficient is 0.
                    log += " ".join([
                        " SLAM Loss proj/exp/pose:",
                        "{:.4f}/{:.4f}/{:.4f}".format(mean_or_nan(costs),
                                                      mean_or_nan(exp_costs),
                                                      mean_or_nan(pose_costs))
                    ])

                print(log)
                logging.info(log)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Save best models
            if (total_num_steps * num_scenes) % args.save_interval < \
                    num_scenes:

                # Save Neural SLAM Model
                if len(costs) >= 1000 and np.mean(costs) < best_cost \
                        and not args.eval:
                    best_cost = np.mean(costs)
                    torch.save(nslam_module.state_dict(),
                               os.path.join(log_dir, "model_best.slam"))

                # Save Local Policy Model
                if len(l_action_losses) >= 100 and \
                        (np.mean(l_action_losses) <= best_local_loss) \
                        and not args.eval:
                    torch.save(l_policy.state_dict(),
                               os.path.join(log_dir, "model_best.local"))

                    best_local_loss = np.mean(l_action_losses)

                # Save Global Policy Model
                if len(g_episode_rewards) >= 100 and \
                        (np.mean(g_episode_rewards) >= best_g_reward) \
                        and not args.eval:
                    torch.save(g_policy.state_dict(),
                               os.path.join(log_dir, "model_best.global"))
                    best_g_reward = np.mean(g_episode_rewards)

            # Save periodic models
            if (total_num_steps * num_scenes) % args.save_periodic < \
                    num_scenes:
                # Separate name so the loop variable `step` is not shadowed.
                saved_steps = total_num_steps * num_scenes
                if args.train_slam:
                    torch.save(
                        nslam_module.state_dict(),
                        os.path.join(dump_dir,
                                     "periodic_{}.slam".format(saved_steps)))
                if args.train_local:
                    torch.save(
                        l_policy.state_dict(),
                        os.path.join(dump_dir,
                                     "periodic_{}.local".format(saved_steps)))
                if args.train_global:
                    torch.save(
                        g_policy.state_dict(),
                        os.path.join(dump_dir,
                                     "periodic_{}.global".format(saved_steps)))
            # ------------------------------------------------------------------

    # Print and save model performance numbers during evaluation
    if args.eval:
        # One line per (scene, episode): the per-global-step cumulative values.
        for file_name, values in (("explored_area.txt", explored_area_log),
                                  ("explored_ratio.txt", explored_ratio_log)):
            with open("{}/{}".format(dump_dir, file_name), "w") as logfile:
                for scene_log in values:
                    for episode_log in scene_log:
                        logfile.write(str(episode_log) + "\n")

        log = "Final Exp Area: \n"
        for i in range(explored_area_log.shape[2]):
            log += "{:.5f}, ".format(np.mean(explored_area_log[:, :, i]))

        log += "\nFinal Exp Ratio: \n"
        for i in range(explored_ratio_log.shape[2]):
            log += "{:.5f}, ".format(np.mean(explored_ratio_log[:, :, i]))

        print(log)
        logging.info(log)
Beispiel #2
0
def main():
    """Run the Neural SLAM training/evaluation loop with oracle navigation actions.

    On every step the action comes from ``envs.get_optimal_action``.  The
    learned local policy is disabled and is kept only as commented-out code.
    The Neural SLAM module updates the local map and pose from consecutive
    observations.  When ``args.train_slam`` is set, the module is also trained
    on transitions sampled from a FIFO replay memory.  All configuration comes
    from the module-level ``args`` namespace.

    Side effects:
        * creates ``<dump_location>/models/<exp_name>/`` (``train.log``,
          TensorBoard events, ``model_best.slam``) and
          ``<dump_location>/dump/<exp_name>/images/``;
        * saves periodic ``periodic_<step>.slam`` checkpoints in the dump
          directory while the SLAM module is being trained;
        * when ``args.eval`` is set, writes ``explored_area.txt`` and
          ``explored_ratio.txt`` to the dump directory and logs their means.
    """
    print("---------------------")
    print("Actions")
    print("STOP", HabitatSimActions.STOP)
    print("FORWARD", HabitatSimActions.MOVE_FORWARD)
    print("LEFT", HabitatSimActions.TURN_LEFT)
    print("RIGHT", HabitatSimActions.TURN_RIGHT)

    # Setup logging directories (exist_ok avoids the exists()/makedirs() race).
    log_dir = "{}/models/{}/".format(args.dump_location, args.exp_name)
    dump_dir = "{}/dump/{}/".format(args.dump_location, args.exp_name)
    tb_dir = log_dir + "tensorboard"
    os.makedirs(tb_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs("{}/images/".format(dump_dir), exist_ok=True)
    logging.basicConfig(
        filename=log_dir + 'train.log',
        level=logging.INFO)
    print("Dumping at {}".format(log_dir))
    print("Arguments starting with ", args)
    logging.info(args)
    device = args.device = torch.device("cuda:0" if args.cuda else "cpu")

    # Logging and loss variables
    num_scenes = args.num_processes
    num_episodes = int(args.num_episodes)

    # Running SLAM losses (last 1000 values of each term) and the best mean
    # projection loss seen so far, which decides when model_best.slam is saved.
    best_cost = float('inf')
    costs = deque(maxlen=1000)
    exp_costs = deque(maxlen=1000)
    pose_costs = deque(maxlen=1000)
    # Only consumed by the disabled learned local policy (see below).
    l_masks = torch.zeros(num_scenes).float().to(device)
    # policy_loss = 0
    # best_local_loss = np.inf
    # l_action_losses = deque(maxlen=1000)

    # Exploration statistics read by the evaluation report at the end of this
    # function: one slot per scene, per episode, per num_local_steps steps.
    # NOTE(review): nothing in the loop below fills these arrays yet, so the
    # report currently contains zeros.
    if args.eval:
        traj_lengths = args.max_episode_length // args.num_local_steps
        explored_area_log = np.zeros((num_scenes, num_episodes, traj_lengths))
        explored_ratio_log = np.zeros((num_scenes, num_episodes, traj_lengths))
    print("Setup rewards")

    print("starting envrionments ...")
    # Starting environments
    torch.set_num_threads(1)
    envs = make_vec_envs(args)
    obs, infos = envs.reset()
    print("environments reset")

    # Initialize map variables
    ### Full map consists of 4 channels containing the following:
    ### 1. Obstacle Map
    ### 2. Explored Area
    ### 3. Current Agent Location
    ### 4. Past Agent Locations
    print("creating maps and poses ")
    torch.set_grad_enabled(False)
    # Calculating full and local map sizes (in cells)
    map_size = args.map_size_cm // args.map_resolution
    full_w, full_h = map_size, map_size
    local_w, local_h = int(full_w / args.global_downscaling), \
                       int(full_h / args.global_downscaling)
    # Initializing full and local map
    full_map = torch.zeros(num_scenes, 4, full_w, full_h).float().to(device)
    local_map = torch.zeros(num_scenes, 4, local_w, local_h).float().to(device)
    # Initial full and local pose: indices 0/1 are x/y in metres
    # (index 2 is presumably the heading -- TODO confirm).
    full_pose = torch.zeros(num_scenes, 3).float().to(device)
    local_pose = torch.zeros(num_scenes, 3).float().to(device)
    # Origin of local map (metres, within the full map)
    origins = np.zeros((num_scenes, 3))
    # Local Map Boundaries
    lmb = np.zeros((num_scenes, 4)).astype(int)
    ### Planner pose inputs has 7 dimensions
    ### 1-3 store continuous global agent location
    ### 4-7 store local map boundaries
    planner_pose_inputs = np.zeros((num_scenes, 7))

    # Full-map pose at the map centre; passed to envs.get_optimal_action below.
    start_full_pose = np.zeros(3)
    start_full_pose[:2] = args.map_size_cm / 100.0 / 2.0

    def init_map_and_pose():
        """Clear the maps and poses and place every agent at the map centre.

        Marks a 3x3 patch around the agent in channels 2-3 of the full map,
        recomputes the local map boundaries and origins, and copies the local
        window of the full map/pose into ``local_map``/``local_pose``.
        """
        full_map.fill_(0.)
        full_pose.fill_(0.)
        full_pose[:, :2] = args.map_size_cm / 100.0 / 2.0

        full_pose_np = full_pose.cpu().numpy()
        planner_pose_inputs[:, :3] = full_pose_np
        for e in range(num_scenes):
            r, c = full_pose_np[e, 1], full_pose_np[e, 0]
            loc_r, loc_c = [int(r * 100.0 / args.map_resolution),
                            int(c * 100.0 / args.map_resolution)]

            full_map[e, 2:, loc_r - 1:loc_r + 2, loc_c - 1:loc_c + 2] = 1.0

            lmb[e] = get_local_map_boundaries((loc_r, loc_c),
                                              (local_w, local_h),
                                              (full_w, full_h))

            planner_pose_inputs[e, 3:] = lmb[e]
            origins[e] = [lmb[e][2] * args.map_resolution / 100.0,
                          lmb[e][0] * args.map_resolution / 100.0, 0.]
        for e in range(num_scenes):
            local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]]
            local_pose[e] = full_pose[e] - \
                            torch.from_numpy(origins[e]).to(device).float()

    def update_agent_location():
        """Refresh the planner pose inputs and the agent-location channels.

        Stores ``local_pose + origins`` (the full-map pose) in
        ``planner_pose_inputs[:, :3]``.  Clears the current-location channel
        (2) and stamps a 5x5 patch around each agent into channels 2 and 3
        of ``local_map``.

        Returns:
            The current local poses as a NumPy array.
        """
        pose_np = local_pose.cpu().numpy()
        planner_pose_inputs[:, :3] = pose_np + origins
        local_map[:, 2, :, :].fill_(0.)  # Resetting current location channel
        for e in range(num_scenes):
            r, c = pose_np[e, 1], pose_np[e, 0]
            loc_r, loc_c = [int(r * 100.0 / args.map_resolution),
                            int(c * 100.0 / args.map_resolution)]
            local_map[e, 2:, loc_r - 2:loc_r + 3, loc_c - 2:loc_c + 3] = 1.
        return pose_np

    init_map_and_pose()
    print("maps and poses intialized")

    print("defining architecture")
    # slam
    nslam_module = Neural_SLAM_Module(args).to(device)
    slam_optimizer = get_optimizer(nslam_module.parameters(), args.slam_optimizer)
    slam_memory = FIFOMemory(args.slam_memory_size)

    # # Local policy
    # print("policy observation space", envs.observation_space.spaces['rgb'])
    # print("policy action space ", envs.action_space)
    # l_observation_space = gym.spaces.Box(0, 255,
    #                                      (3,
    #                                       args.frame_width,
    #                                       args.frame_width), dtype='uint8')
    # # todo change this to use envs.observation_space.spaces['rgb'].shape later
    # l_policy = Local_IL_Policy(l_observation_space.shape, envs.action_space.n,
    #                            recurrent=args.use_recurrent_local,
    #                            hidden_size=args.local_hidden_size,
    #                            deterministic=args.use_deterministic_local).to(device)
    # local_optimizer = get_optimizer(l_policy.parameters(), args.local_optimizer)

    print("loading model weights")
    # Loading model
    if args.load_slam != "0":
        print("Loading slam {}".format(args.load_slam))
        state_dict = torch.load(args.load_slam,
                                map_location=lambda storage, loc: storage)
        nslam_module.load_state_dict(state_dict)
    if not args.train_slam:
        nslam_module.eval()

    #     if args.load_local != "0":
    #         print("Loading local {}".format(args.load_local))
    #         state_dict = torch.load(args.load_local,
    #                                 map_location=lambda storage, loc: storage)
    #         l_policy.load_state_dict(state_dict)
    #     if not args.train_local:
    #         l_policy.eval()

    print("predicting first pose and initializing maps")
    # delta_pose is the expected change in pose when action is applied at
    # the current pose in the absence of noise.
    # initially no action is applied so this is zero.
    delta_poses = torch.zeros_like(local_pose)
    # initial estimate for local pose and local map from first observation,
    # initialized (zero) pose and map
    _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
        nslam_module(obs, obs, delta_poses, local_map[:, 0, :, :],
                     local_map[:, 1, :, :], local_pose)
    # if args.use_gt_pose:
    #     # todo update local_pose here
    #     full_pose = envs.get_gt_pose()
    #     for e in range(num_scenes):
    #         local_pose[e] = full_pose[e] - \
    #                         torch.from_numpy(origins[e]).to(device).float()
    # if args.use_gt_map:
    #     full_map = envs.get_gt_map()
    #     for e in range(num_scenes):
    #         local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]]
    print("slam module returned pose and maps")

    # NOT NEEDED : 4/29
    local_pose_np = local_pose.cpu().numpy()
    # update local map for each scene - input for planner
    for e in range(num_scenes):
        r, c = local_pose_np[e, 1], local_pose_np[e, 0]
        loc_r, loc_c = [int(r * 100.0 / args.map_resolution),
                        int(c * 100.0 / args.map_resolution)]
        local_map[e, 2:, loc_r - 1:loc_r + 2, loc_c - 1:loc_c + 2] = 1.

    #     # todo get goal from env here
    global_goals = envs.get_goal_coords().int()

    # Compute planner inputs
    planner_inputs = [{} for e in range(num_scenes)]
    for e, p_input in enumerate(planner_inputs):
        p_input['goal'] = global_goals[e].detach().cpu().numpy()
        p_input['map_pred'] = local_map[e, 0, :, :].detach().cpu().numpy()
        p_input['exp_pred'] = local_map[e, 1, :, :].detach().cpu().numpy()
        p_input['pose_pred'] = planner_pose_inputs[e]

    # Output stores local goals as well as the the ground-truth action
    planner_out = envs.get_short_term_goal(planner_inputs)
    # planner output contains:
    # Distance to short term goal - positive discretized number
    # angle to short term goal -  angle -180 to 180 but in buckets of 5 degrees so multiply by 5 to ge true angle
    # GT action - action to be taken according to planner (int)

    # going to step through the episodes, so cache previous information
    last_obs = obs.detach()
    # Only consumed by the disabled learned local policy.
    local_rec_states = torch.zeros(num_scenes, args.local_hidden_size).to(device)
    start = time.time()
    total_num_steps = -1
    torch.set_grad_enabled(False)

    def gettime():
        """Return the current local time without fractional seconds."""
        return str(datetime.now()).split('.')[0]

    print("starting episodes")
    with TensorboardWriter(
            tb_dir, flush_secs=60
    ) as writer:
        for ep_num in range(num_episodes):
            print("------------------------------------------------------")
            print("Episode", ep_num)

            step_bar = tqdm(range(args.max_episode_length))
            for step in step_bar:
                total_num_steps += 1
                l_step = step % args.num_local_steps

                # Local Policy
                # ------------------------------------------------------------------
                # cache previous information
                del last_obs
                last_obs = obs.detach()
                #             if not args.use_optimal_policy and not args.use_shortest_path_gt:
                #                 local_masks = l_masks
                #                 local_goals = planner_out[:, :-1].to(device).long()

                #                 if args.train_local:
                #                     torch.set_grad_enabled(True)

                #                 # local policy "step"
                #                 action, action_prob, local_rec_states = l_policy(
                #                     obs,
                #                     local_rec_states,
                #                     local_masks,
                #                     extras=local_goals,
                #                 )

                #                 if args.train_local:
                #                     action_target = planner_out[:, -1].long().to(device)
                #                     # doubt: this is probably wrong? one is action probability and the other is action
                #                     policy_loss += nn.CrossEntropyLoss()(action_prob, action_target)
                #                     torch.set_grad_enabled(False)
                #                 l_action = action.cpu()
                #             else:
                #                 if args.use_optimal_policy:
                #                     l_action = planner_out[:, -1]
                #                 else:
                #                     l_action = envs.get_optimal_gt_action()

                l_action = envs.get_optimal_action(start_full_pose, full_pose).cpu()

                # ------------------------------------------------------------------
                # Env step
                obs, rew, done, infos = envs.step(l_action)

                # ------------------------------------------------------------------
                # Reinitialize variables when episode ends.
                # ``.any()`` keeps the STOP test well-defined when l_action holds
                # one action per scene; ``if tensor`` raises for tensors with more
                # than one element.  The same flag ends the step loop below.
                # doubt what if episode ends before max_episode_length?
                # maybe add (or done ) here?
                stopped = bool((l_action == HabitatSimActions.STOP).any())
                if stopped or step == args.max_episode_length - 1:
                    print("l_action", l_action)
                    init_map_and_pose()
                    del last_obs
                    last_obs = obs.detach()
                    print("Reinitialize since at end of episode ")
                    obs, infos = envs.reset()

                # Debug: print(rew, done, infos[0]['sensor_pose'], infos[0]['pose_err'])
                # l_masks = torch.FloatTensor([0 if x else 1
                #                              for x in done]).to(device)

                # ------------------------------------------------------------------
                # Neural SLAM Module
                # Expected noise-free pose change caused by this step's action.
                # It is computed before the transitions are pushed, so the memory
                # stores the actual delta rather than a zero placeholder.  It is
                # computed on every step, so the SLAM update below uses it even
                # when the module is not being trained.
                delta_poses_np = np.zeros(local_pose_np.shape)
                for env_idx in range(num_scenes):
                    delta_poses_np[env_idx] = get_delta_pose(
                        local_pose_np[env_idx], l_action[env_idx])
                if args.train_slam:
                    # Add frames to memory
                    for env_idx in range(num_scenes):
                        env_obs = obs[env_idx].to("cpu")
                        env_poses = torch.from_numpy(np.asarray(
                            delta_poses_np[env_idx]
                        )).float().to("cpu")
                        env_gt_fp_projs = torch.from_numpy(np.asarray(
                            infos[env_idx]['fp_proj']
                        )).unsqueeze(0).float().to("cpu")
                        env_gt_fp_explored = torch.from_numpy(np.asarray(
                            infos[env_idx]['fp_explored']
                        )).unsqueeze(0).float().to("cpu")
                        # TODO change pose err here
                        env_gt_pose_err = torch.from_numpy(np.asarray(
                            infos[env_idx]['pose_err']
                        )).float().to("cpu")
                        slam_memory.push(
                            (last_obs[env_idx].cpu(), env_obs, env_poses),
                            (env_gt_fp_projs, env_gt_fp_explored, env_gt_pose_err))
                delta_poses = torch.from_numpy(delta_poses_np).float().to(device)
                _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
                    nslam_module(last_obs, obs, delta_poses, local_map[:, 0, :, :],
                                 local_map[:, 1, :, :], local_pose, build_maps=True)
                # if args.use_gt_pose:
                #     # todo update local_pose here
                #     full_pose = envs.get_gt_pose()
                #     for e in range(num_scenes):
                #         local_pose[e] = full_pose[e] - \
                #                         torch.from_numpy(origins[e]).to(device).float()
                #     print("updated local pose from gt ", local_pose)
                # if args.use_gt_map:
                #     full_map = envs.get_gt_map()
                #     for e in range(num_scenes):
                #         local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]]
                #     print("updated local map from gt")
                local_pose_np = update_agent_location()
                if l_step == args.num_local_steps - 1:
                    # For every global step, update the full and local maps
                    for e in range(num_scenes):
                        full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]] = \
                            local_map[e]
                        full_pose[e] = local_pose[e] + \
                                       torch.from_numpy(origins[e]).to(device).float()

                        full_pose_np = full_pose[e].cpu().numpy()
                        r, c = full_pose_np[1], full_pose_np[0]
                        loc_r, loc_c = [int(r * 100.0 / args.map_resolution),
                                        int(c * 100.0 / args.map_resolution)]

                        lmb[e] = get_local_map_boundaries((loc_r, loc_c),
                                                          (local_w, local_h),
                                                          (full_w, full_h))

                        planner_pose_inputs[e, 3:] = lmb[e]
                        origins[e] = [lmb[e][2] * args.map_resolution / 100.0,
                                      lmb[e][0] * args.map_resolution / 100.0, 0.]

                        local_map[e] = full_map[e, :,
                                       lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]]
                        local_pose[e] = full_pose[e] - \
                                        torch.from_numpy(origins[e]).to(device).float()

                # Re-stamp the location after a possible local-window shift.
                local_pose_np = update_agent_location()

                planner_inputs = [{} for e in range(num_scenes)]
                for e, p_input in enumerate(planner_inputs):
                    p_input['map_pred'] = local_map[e, 0, :, :].cpu().numpy()
                    p_input['exp_pred'] = local_map[e, 1, :, :].cpu().numpy()
                    p_input['pose_pred'] = planner_pose_inputs[e]
                    p_input['goal'] = global_goals[e].cpu().numpy()
                planner_out = envs.get_short_term_goal(planner_inputs)

                ### TRAINING
                torch.set_grad_enabled(True)
                # ------------------------------------------------------------------
                # Train Neural SLAM Module
                if args.train_slam and len(slam_memory) > args.slam_batch_size:
                    for _ in range(args.slam_iterations):
                        inputs, outputs = slam_memory.sample(args.slam_batch_size)
                        b_obs_last, b_obs, b_poses = inputs
                        gt_fp_projs, gt_fp_explored, gt_pose_err = outputs

                        b_obs = b_obs.to(device)
                        b_obs_last = b_obs_last.to(device)
                        b_poses = b_poses.to(device)

                        gt_fp_projs = gt_fp_projs.to(device)
                        gt_fp_explored = gt_fp_explored.to(device)
                        gt_pose_err = gt_pose_err.to(device)

                        b_proj_pred, b_fp_exp_pred, _, _, b_pose_err_pred, _ = \
                            nslam_module(b_obs_last, b_obs, b_poses,
                                         None, None, None,
                                         build_maps=False)
                        loss = 0
                        if args.proj_loss_coeff > 0:
                            proj_loss = F.binary_cross_entropy(b_proj_pred,
                                                               gt_fp_projs)
                            costs.append(proj_loss.item())
                            loss += args.proj_loss_coeff * proj_loss

                        if args.exp_loss_coeff > 0:
                            exp_loss = F.binary_cross_entropy(b_fp_exp_pred,
                                                              gt_fp_explored)
                            exp_costs.append(exp_loss.item())
                            loss += args.exp_loss_coeff * exp_loss

                        if args.pose_loss_coeff > 0:
                            pose_loss = torch.nn.MSELoss()(b_pose_err_pred,
                                                           gt_pose_err)
                            pose_costs.append(args.pose_loss_coeff *
                                              pose_loss.item())
                            loss += args.pose_loss_coeff * pose_loss

                        # When every coefficient is 0, loss is still the int 0
                        # and has nothing to back-propagate.
                        if isinstance(loss, torch.Tensor):
                            slam_optimizer.zero_grad()
                            loss.backward()
                            slam_optimizer.step()

                        del b_obs_last, b_obs, b_poses
                        del gt_fp_projs, gt_fp_explored, gt_pose_err
                        del b_proj_pred, b_fp_exp_pred, b_pose_err_pred

                # ------------------------------------------------------------------
                # Train Local Policy
                # if (l_step + 1) % args.local_policy_update_freq == 0 \
                #         and args.train_local:
                #     local_optimizer.zero_grad()
                #     policy_loss.backward()
                #     local_optimizer.step()
                #     l_action_losses.append(policy_loss.item())
                #     policy_loss = 0
                #     local_rec_states = local_rec_states.detach_()
                # ------------------------------------------------------------------

                # Finish Training
                torch.set_grad_enabled(False)
                # ------------------------------------------------------------------

                # ------------------------------------------------------------------
                # Logging (skip empty deques: np.mean([]) is NaN plus a warning)
                if costs:
                    writer.add_scalar("SLAM_Loss_Proj", np.mean(costs), total_num_steps)
                if exp_costs:
                    writer.add_scalar("SLAM_Loss_Exp", np.mean(exp_costs), total_num_steps)
                if pose_costs:
                    writer.add_scalar("SLAM_Loss_Pose", np.mean(pose_costs), total_num_steps)

                if total_num_steps % args.log_interval == 0:
                    end = time.time()
                    time_elapsed = time.gmtime(end - start)
                    log = " ".join([
                        "Time: {0:0=2d}d".format(time_elapsed.tm_mday - 1),
                        "{},".format(time.strftime("%Hh %Mm %Ss", time_elapsed)),
                        gettime(),
                        "num timesteps {},".format(total_num_steps *
                                                   num_scenes),
                        "FPS {},".format(int(total_num_steps * num_scenes
                                             / (end - start)))
                    ])

                    log += "\n\tLosses:"

                    # if args.train_local and len(l_action_losses) > 0:
                    #     log += " ".join([
                    #         " Local Loss:",
                    #         "{:.3f},".format(
                    #             np.mean(l_action_losses))
                    #     ])

                    if args.train_slam and len(costs) > 0:
                        log += " ".join([
                            " SLAM Loss proj/exp/pose:"
                            "{:.4f}/{:.4f}/{:.4f}".format(
                                np.mean(costs),
                                np.mean(exp_costs),
                                np.mean(pose_costs))
                        ])

                    print(log)
                    logging.info(log)
                # ------------------------------------------------------------------

                # ------------------------------------------------------------------
                # Save best models
                if (total_num_steps * num_scenes) % args.save_interval < \
                        num_scenes:

                    # Save Neural SLAM Model
                    if len(costs) >= 1000 and np.mean(costs) < best_cost \
                            and not args.eval:
                        print("Saved best model")
                        best_cost = np.mean(costs)
                        torch.save(nslam_module.state_dict(),
                                   os.path.join(log_dir, "model_best.slam"))

                    # Save Local Policy Model
                    # if len(l_action_losses) >= 100 and \
                    #         (np.mean(l_action_losses) <= best_local_loss) \
                    #         and not args.eval:
                    #     torch.save(l_policy.state_dict(),
                    #                os.path.join(log_dir, "model_best.local"))
                    #
                    #     best_local_loss = np.mean(l_action_losses)

                # Save periodic models
                if (total_num_steps * num_scenes) % args.save_periodic < \
                        num_scenes:
                    step = total_num_steps * num_scenes
                    if args.train_slam:
                        torch.save(nslam_module.state_dict(),
                                   os.path.join(dump_dir,
                                                "periodic_{}.slam".format(step)))
                    # if args.train_local:
                    #     torch.save(l_policy.state_dict(),
                    #                os.path.join(dump_dir,
                    #                             "periodic_{}.local".format(step)))
                # ------------------------------------------------------------------

                if stopped:  # Last episode step
                    break
            # tqdm only closes itself when the iterator is exhausted; close it
            # explicitly after an early STOP break (close() is idempotent).
            step_bar.close()

    # Print and save model performance numbers during evaluation
    if args.eval:
        with open("{}/explored_area.txt".format(dump_dir), "w+") as logfile:
            for e in range(num_scenes):
                for row in explored_area_log[e]:
                    logfile.write(str(row) + "\n")

        with open("{}/explored_ratio.txt".format(dump_dir), "w+") as logfile:
            for e in range(num_scenes):
                for row in explored_ratio_log[e]:
                    logfile.write(str(row) + "\n")

        log = "Final Exp Area: \n"
        for i in range(explored_area_log.shape[2]):
            log += "{:.5f}, ".format(
                np.mean(explored_area_log[:, :, i]))

        log += "\nFinal Exp Ratio: \n"
        for i in range(explored_ratio_log.shape[2]):
            log += "{:.5f}, ".format(
                np.mean(explored_ratio_log[:, :, i]))

        print(log)
        logging.info(log)
Beispiel #3
0
def test():

    ##########################################################
    # # Realsense test
    # pipeline = rs.pipeline()
    # config = rs.config()
    # config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30)
    # config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)
    # pipeline.start(config)

    # frames = pipeline.wait_for_frames()
    # color_frame = frames.get_color_frame()
    # img = np.asanyarray(color_frame.get_data())
    # img = cv2.resize(img, dsize=(256, 256), interpolation=cv2.INTER_CUBIC)
    # cv2.namedWindow('RealSense', cv2.WINDOW_AUTOSIZE)
    # cv2.imshow('RealSense', img)
    # cv2.waitKey(1)
    ##########################################################

    device = args.device = torch.device("cuda:0" if args.cuda else "cpu")

    # Setup Logging
    log_dir = "{}/models/{}/".format(args.dump_location, args.exp_name)
    dump_dir = "{}/dump/{}/".format(args.dump_location, args.exp_name)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists("{}/images/".format(dump_dir)):
        os.makedirs("{}/images/".format(dump_dir))

    logging.basicConfig(filename=log_dir + 'train.log', level=logging.INFO)
    print("Dumping at {}".format(log_dir))
    logging.info(args)

    # Logging and loss variables
    num_scenes = args.num_processes
    num_episodes = int(args.num_episodes)
    device = args.device = torch.device("cuda:0" if args.cuda else "cpu")
    policy_loss = 0

    best_cost = 100000
    costs = deque(maxlen=1000)
    exp_costs = deque(maxlen=1000)
    pose_costs = deque(maxlen=1000)

    g_masks = torch.ones(num_scenes).float().to(device)
    l_masks = torch.zeros(num_scenes).float().to(device)

    best_local_loss = np.inf
    best_g_reward = -np.inf

    if args.eval:
        traj_lengths = args.max_episode_length // args.num_local_steps
        explored_area_log = np.zeros((num_scenes, num_episodes, traj_lengths))
        explored_ratio_log = np.zeros((num_scenes, num_episodes, traj_lengths))

    g_episode_rewards = deque(maxlen=1000)

    l_action_losses = deque(maxlen=1000)

    g_value_losses = deque(maxlen=1000)
    g_action_losses = deque(maxlen=1000)
    g_dist_entropies = deque(maxlen=1000)

    per_step_g_rewards = deque(maxlen=1000)

    g_process_rewards = np.zeros((num_scenes))

    # Starting environments
    torch.set_num_threads(1)
    envs = make_vec_envs(args)
    obs, infos = envs.reset()

    # Initialize map variables
    ### Full map consists of 4 channels containing the following:
    ### 1. Obstacle Map
    ### 2. Exploread Area
    ### 3. Current Agent Location
    ### 4. Past Agent Locations

    torch.set_grad_enabled(False)

    # Calculating full and local map sizes
    map_size = args.map_size_cm // args.map_resolution
    full_w, full_h = map_size, map_size

    local_w, local_h = int(full_w / args.global_downscaling), \
                       int(full_h / args.global_downscaling)

    # Initializing full and local map
    full_map = torch.zeros(num_scenes, 4, full_w, full_h).float().to(device)
    local_map = torch.zeros(num_scenes, 4, local_w, local_h).float().to(device)

    # Initial full and local pose
    full_pose = torch.zeros(num_scenes, 3).float().to(device)
    local_pose = torch.zeros(num_scenes, 3).float().to(device)

    # Origin of local map
    origins = np.zeros((num_scenes, 3))

    # Local Map Boundaries
    lmb = np.zeros((num_scenes, 4)).astype(int)

    ### Planner pose inputs The global agent location
    ### 4-7 store local map boundaries
    planner_pose_inputs = np.zeros((num_scenes, 7))

    # Initialize full_map and full_pose
    def init_map_and_pose():
        full_map.fill_(0.)
        full_pose.fill_(0.)
        full_pose[:, :2] = args.map_size_cm / 100.0 / 2.0

        locs = full_pose.cpu().numpy()
        planner_pose_inputs[:, :3] = locs
        for e in range(num_scenes):
            r, c = locs[e, 1], locs[e, 0]
            loc_r, loc_c = [
                int(r * 100.0 / args.map_resolution),
                int(c * 100.0 / args.map_resolution)
            ]

            full_map[e, 2:, loc_r - 1:loc_r + 2, loc_c - 1:loc_c + 2] = 1.0

            lmb[e] = get_local_map_boundaries(
                (loc_r, loc_c), (local_w, local_h), (full_w, full_h))

            planner_pose_inputs[e, 3:] = lmb[e]
            origins[e] = [
                lmb[e][2] * args.map_resolution / 100.0,
                lmb[e][0] * args.map_resolution / 100.0, 0.
            ]

        for e in range(num_scenes):
            local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1],
                                    lmb[e, 2]:lmb[e, 3]]
            local_pose[e] = full_pose[e] - \
                            torch.from_numpy(origins[e]).to(device).float()

    init_map_and_pose()
    # Global policy observation space
    g_observation_space = gym.spaces.Box(0,
                                         1, (8, local_w, local_h),
                                         dtype='uint8')

    # Global policy action space
    g_action_space = gym.spaces.Box(low=0.0,
                                    high=1.0,
                                    shape=(2, ),
                                    dtype=np.float32)

    # Local policy observation space
    l_observation_space = gym.spaces.Box(
        0, 255, (3, args.frame_width, args.frame_width), dtype='uint8')

    # Local and Global policy recurrent layer sizes
    l_hidden_size = args.local_hidden_size
    g_hidden_size = args.global_hidden_size

    # slam
    nslam_module = Neural_SLAM_Module(args).to(device)
    slam_optimizer = get_optimizer(nslam_module.parameters(),
                                   args.slam_optimizer)

    # Global policy
    # obse_space.shape= [8, 500, 500]
    # act_space= Box shape (2,)
    # g_hidden_size = 256
    g_policy = RL_Policy(g_observation_space.shape,
                         g_action_space,
                         base_kwargs={
                             'recurrent': args.use_recurrent_global,
                             'hidden_size': g_hidden_size,
                             'downscaling': args.global_downscaling
                         }).to(device)
    g_agent = algo.PPO(g_policy,
                       args.clip_param,
                       args.ppo_epoch,
                       args.num_mini_batch,
                       args.value_loss_coef,
                       args.entropy_coef,
                       lr=args.global_lr,
                       eps=args.eps,
                       max_grad_norm=args.max_grad_norm)

    # Local policy
    l_policy = Local_IL_Policy(
        l_observation_space.shape,
        envs.action_space.n,
        recurrent=args.use_recurrent_local,
        hidden_size=l_hidden_size,
        deterministic=args.use_deterministic_local).to(device)
    local_optimizer = get_optimizer(l_policy.parameters(),
                                    args.local_optimizer)

    # Storage
    g_rollouts = GlobalRolloutStorage(args.num_global_steps, num_scenes,
                                      g_observation_space.shape,
                                      g_action_space, g_policy.rec_state_size,
                                      1).to(device)

    slam_memory = FIFOMemory(args.slam_memory_size)
    '''

    '''

    # Loading model
    if args.load_slam != "0":
        print("Loading slam {}".format(args.load_slam))
        state_dict = torch.load(args.load_slam,
                                map_location=lambda storage, loc: storage)
        nslam_module.load_state_dict(state_dict)

    if not args.train_slam:
        nslam_module.eval()

    if args.load_global != "0":
        print("Loading global {}".format(args.load_global))
        state_dict = torch.load(args.load_global,
                                map_location=lambda storage, loc: storage)
        g_policy.load_state_dict(state_dict)

    if not args.train_global:
        g_policy.eval()

    if args.load_local != "0":
        print("Loading local {}".format(args.load_local))
        state_dict = torch.load(args.load_local,
                                map_location=lambda storage, loc: storage)
        l_policy.load_state_dict(state_dict)

    if not args.train_local:
        l_policy.eval()

    # /////////////////////////////////////////////////////////////// TESTING
    from matplotlib import image
    if args.testing:
        test_images = {}
        for i in range(5):
            for j in range(12):
                img_pth = 'imgs/robots_rs/test_{}_{}.jpg'.format(i + 1, j)
                img = image.imread(img_pth)
                test_images[(i + 1, j)] = np.array(img)

        poses_array = []
        for i in range(8):
            poses_array.append(np.array([[0.3, 0.0, 0.0], [0.3, 0.0, 0.0]]))
        for i in range(4):
            poses_array.append(
                np.array([[0.0, 0.0, -0.24587], [0.0, 0.0, -0.27587]]))

        # index from 1 to 5
        test_1_idx = 3
        test_2_idx = 1
        # image1_1 = image.imread('imgs/robots_rs/img_128_6.jpg')
        # image1_2 = image.imread('imgs/robots_rs/img_128_7.jpg')
        # image1_3 = image.imread('imgs/robots_rs/img_128_8.jpg')
        # image2_1 = image.imread('imgs/robots_rs/img_128_30.jpg')
        # image2_2 = image.imread('imgs/robots_rs/img_128_31.jpg')
        # image2_3 = image.imread('imgs/robots_rs/img_128_32.jpg')
        # # image_data = np.asarray(image)
        # # plt.imshow(image)
        # # plt.show()
        # image_data_1_1 = np.array(image1_1)
        # image_data_1_2 = np.array(image1_2)
        # image_data_1_3 = np.array(image1_3)
        # image_data_2_1 = np.array(image2_1)
        # image_data_2_2 = np.array(image2_2)
        # image_data_2_3 = np.array(image2_3)
        # image_data_1_all = np.array([image_data_1_1, image_data_2_1])
        # image_data_2_all = np.array([image_data_1_2, image_data_2_2])
        # image_data_3_all = np.array([image_data_1_3, image_data_2_3])
        image_data_all = np.array(
            [test_images[(test_1_idx, 0)], test_images[(test_2_idx, 0)]])
        obs = torch.from_numpy(image_data_all).float().to(device)
        obs = obs.permute((0, 3, 1, 2)).contiguous()

        # print(f"New obs: {obs}")
        print(f"New obs size: {obs.size()}")
    # /////////////////////////////////////////////////////////////// TESTING

    # Predict map from frame 1:
    poses = torch.from_numpy(
        np.asarray([
            infos[env_idx]['sensor_pose'] for env_idx in range(num_scenes)
        ])).float().to(device)

    _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
        nslam_module(obs, obs, poses, local_map[:, 0, :, :],
                     local_map[:, 1, :, :], local_pose)

    # print(f"\n\n local_map shape: {local_map.shape}")
    # print(f"\n obs shape: {obs.shape}")
    # print(f"\n poses shape: {poses.shape}")

    # Compute Global policy input
    locs = local_pose.cpu().numpy()

    global_input = torch.zeros(num_scenes, 8, local_w, local_h)
    global_orientation = torch.zeros(num_scenes, 1).long()

    for e in range(num_scenes):
        r, c = locs[e, 1], locs[e, 0]
        loc_r, loc_c = [
            int(r * 100.0 / args.map_resolution),
            int(c * 100.0 / args.map_resolution)
        ]

        local_map[e, 2:, loc_r - 1:loc_r + 2, loc_c - 1:loc_c + 2] = 1.
        global_orientation[e] = int((locs[e, 2] + 180.0) / 5.)

    global_input[:, 0:4, :, :] = local_map.detach()
    global_input[:, 4:, :, :] = nn.MaxPool2d(args.global_downscaling)(full_map)

    g_rollouts.obs[0].copy_(global_input)
    g_rollouts.extras[0].copy_(global_orientation)

    # Run Global Policy (global_goals = Long-Term Goal)
    g_value, g_action, g_action_log_prob, g_rec_states = \
        g_policy.act(
            g_rollouts.obs[0],
            g_rollouts.rec_states[0],
            g_rollouts.masks[0],
            extras=g_rollouts.extras[0],
            deterministic=False
        )

    cpu_actions = nn.Sigmoid()(g_action).cpu().numpy()
    global_goals = [[int(action[0] * local_w),
                     int(action[1] * local_h)] for action in cpu_actions]

    # Compute planner inputs
    planner_inputs = [{} for e in range(num_scenes)]
    for e, p_input in enumerate(planner_inputs):
        p_input['goal'] = global_goals[e]
        p_input['map_pred'] = global_input[e, 0, :, :].detach().cpu().numpy()
        p_input['exp_pred'] = global_input[e, 1, :, :].detach().cpu().numpy()
        p_input['pose_pred'] = planner_pose_inputs[e]

    # Output stores local goals as well as the the ground-truth action
    output = envs.get_short_term_goal(planner_inputs)

    last_obs = obs.detach()
    local_rec_states = torch.zeros(num_scenes, l_hidden_size).to(device)
    start = time.time()

    total_num_steps = -1
    g_reward = 0

    torch.set_grad_enabled(False)

    # fig, axis = plt.subplots(1,3)
    fig, axis = plt.subplots(2, 3)
    # a = [[1, 0, 1], [1, 0, 1], [1, 0, 1]]
    # plt.imshow(a)

    for ep_num in range(num_episodes):
        for step in range(args.max_episode_length):

            total_num_steps += 1

            g_step = (step // args.num_local_steps) % args.num_global_steps
            eval_g_step = step // args.num_local_steps + 1
            l_step = step % args.num_local_steps

            # ------------------------------------------------------------------
            # Local Policy
            del last_obs
            last_obs = obs.detach()
            local_masks = l_masks
            local_goals = output[:, :-1].to(device).long()

            if args.train_local:
                torch.set_grad_enabled(True)

            action, action_prob, local_rec_states = l_policy(
                obs,
                local_rec_states,
                local_masks,
                extras=local_goals,
            )

            if args.train_local:
                action_target = output[:, -1].long().to(device)
                policy_loss += nn.CrossEntropyLoss()(action_prob,
                                                     action_target)
                torch.set_grad_enabled(False)
            l_action = action.cpu()
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            print(f"l_action: {l_action}")
            print(f"l_action size: {l_action.size()}")
            # Env step
            obs, rew, done, infos = envs.step(l_action)

            # ////////////////////////////////////////////////////////////////// TESTING
            # obs_all = _process_obs_for_display(obs)
            # _ims = [transform_rgb_bgr(obs_all[0]), transform_rgb_bgr(obs_all[1])]

            # ax1.imshow(_ims[0])
            # ax2.imshow(_ims[1])
            # plt.savefig(f"imgs/img_0_{step}.png")
            # # plt.clf()

            # ////////////////////////////////////////////////////////////////// TESTING

            l_masks = torch.FloatTensor([0 if x else 1
                                         for x in done]).to(device)
            g_masks *= l_masks
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Reinitialize variables when episode ends
            if step == args.max_episode_length - 1:  # Last episode step
                print("Final step")
                init_map_and_pose()
                del last_obs
                last_obs = obs.detach()
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Neural SLAM Module
            if args.train_slam:
                # Add frames to memory
                for env_idx in range(num_scenes):
                    env_obs = obs[env_idx].to("cpu")
                    env_poses = torch.from_numpy(
                        np.asarray(
                            infos[env_idx]['sensor_pose'])).float().to("cpu")
                    env_gt_fp_projs = torch.from_numpy(
                        np.asarray(infos[env_idx]['fp_proj'])).unsqueeze(
                            0).float().to("cpu")
                    env_gt_fp_explored = torch.from_numpy(
                        np.asarray(infos[env_idx]['fp_explored'])).unsqueeze(
                            0).float().to("cpu")
                    env_gt_pose_err = torch.from_numpy(
                        np.asarray(
                            infos[env_idx]['pose_err'])).float().to("cpu")
                    slam_memory.push(
                        (last_obs[env_idx].cpu(), env_obs, env_poses),
                        (env_gt_fp_projs, env_gt_fp_explored, env_gt_pose_err))

            poses = torch.from_numpy(
                np.asarray([
                    infos[env_idx]['sensor_pose']
                    for env_idx in range(num_scenes)
                ])).float().to(device)

            # ///////////////////////////////////////////////////////////////// TESTING
            if args.testing:
                # obs = torch.from_numpy(obs_).float().to(self.device)
                # obs_cpu = obs.detach().cpu().numpy()
                # last_obs_cpu = last_obs.detach().cpu().numpy()
                # print(f"obs shape: {obs_cpu.shape}")
                # print(f"last_obs shape: {last_obs_cpu.shape}")

                original_obs = obs
                original_last_obs = last_obs
                original_poses = poses

                print(f"step: {step}")
                last_obs = torch.from_numpy(image_data_all).float().to(device)
                last_obs = last_obs.permute((0, 3, 1, 2)).contiguous()
                image_data_all = np.array([
                    test_images[(test_1_idx, step + 1)],
                    test_images[(test_2_idx, step + 1)]
                ])
                obs = torch.from_numpy(image_data_all).float().to(device)
                obs = obs.permute((0, 3, 1, 2)).contiguous()
                _poses = poses_array[step]
                poses = torch.from_numpy(_poses).float().to(device)
                # if step == 0:
                #     print(f"step: {step}")
                #     last_obs = torch.from_numpy(image_data_1_all).float().to(device)
                #     last_obs = last_obs.permute((0, 3, 1, 2)).contiguous()
                #     obs = torch.from_numpy(image_data_2_all).float().to(device)
                #     obs = obs.permute((0, 3, 1, 2)).contiguous()
                #     _poses = np.array([[0.2, 0.0, 0.0], [0.2, 0.0, 0.0]])
                #     poses = torch.from_numpy(_poses).float().to(device)
                # elif step == 1:
                #     print(f"step: {step}")
                #     last_obs = torch.from_numpy(image_data_2_all).float().to(device)
                #     last_obs = last_obs.permute((0, 3, 1, 2)).contiguous()
                #     obs = torch.from_numpy(image_data_3_all).float().to(device)
                #     obs = obs.permute((0, 3, 1, 2)).contiguous()
                #     _poses = np.array([[0.4, 0.0, 0.0], [0.2, 0.0, 0.17587]])
                #     poses = torch.from_numpy(_poses).float().to(device)
                # _poses = np.asarray([infos[env_idx]['sensor_pose'] for env_idx in range(num_scenes)])
                # print(f"New poses: {_poses}")
                # last_obs = torch.from_numpy(image_data_1_1).float().to(device)
                # obs = torch.from_numpy(image_data_1_2).float().to(device)

                # print(f"Original obs: {original_obs}")
                # print(f"Original obs shape: {original_obs.size()}")
                # print(f"Obs: {obs}")
                # print(f"Obs shape: {obs.size()}")
                # print(f"Original last_obs: {original_last_obs}")
                # print(f"Original last_obs shape: {original_last_obs.size()}")
                # print(f"last_obs: {last_obs}")
                # print(f"Last_obs shape: {last_obs.size()}")
                # print(f"Original poses: {original_poses}")
                # print(f"Original poses shape: {original_poses.size()}")
                print(f"Local poses : {local_pose}")
            # ///////////////////////////////////////////////////////////////// TESTING


            _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
                nslam_module(last_obs, obs, poses, local_map[:, 0, :, :],
                             local_map[:, 1, :, :], local_pose, build_maps=True)

            locs = local_pose.cpu().numpy()
            planner_pose_inputs[:, :3] = locs + origins
            local_map[:,
                      2, :, :].fill_(0.)  # Resetting current location channel
            for e in range(num_scenes):
                r, c = locs[e, 1], locs[e, 0]
                loc_r, loc_c = [
                    int(r * 100.0 / args.map_resolution),
                    int(c * 100.0 / args.map_resolution)
                ]

                local_map[e, 2:, loc_r - 2:loc_r + 3, loc_c - 2:loc_c + 3] = 1.

            # //////////////////////////////////////////////////////////////////
            if args.testing:
                local_map_draw = local_map

                if step % 1 == 0:
                    obs_all = _process_obs_for_display(obs)
                    _ims = [
                        transform_rgb_bgr(obs_all[0]),
                        transform_rgb_bgr(obs_all[1])
                    ]

                    imgs_1 = local_map_draw[0, :, :, :].cpu().numpy()
                    imgs_2 = local_map_draw[1, :, :, :].cpu().numpy()

                    # axis[1].imshow(imgs_1[0], cmap='gray')
                    # axis[2].imshow(imgs_1[1], cmap='gray')
                    # axis[0].imshow(_ims[0])
                    axis[0][1].imshow(imgs_1[0], cmap='gray')
                    axis[0][2].imshow(imgs_1[1], cmap='gray')
                    axis[0][0].imshow(_ims[0])
                    axis[1][1].imshow(imgs_2[0], cmap='gray')
                    axis[1][2].imshow(imgs_2[1], cmap='gray')
                    axis[1][0].imshow(_ims[1])
                    plt.savefig(f"imgs/test_{step}.png")

                obs = original_obs
                last_obs = original_last_obs
                poses = original_poses
            # //////////////////////////////////////////////////////////////////

            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Global Policy
            if l_step == args.num_local_steps - 1:
                # For every global step, update the full and local maps
                for e in range(num_scenes):
                    full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]] = \
                        local_map[e]
                    full_pose[e] = local_pose[e] + \
                                   torch.from_numpy(origins[e]).to(device).float()

                    locs = full_pose[e].cpu().numpy()
                    r, c = locs[1], locs[0]
                    loc_r, loc_c = [
                        int(r * 100.0 / args.map_resolution),
                        int(c * 100.0 / args.map_resolution)
                    ]

                    lmb[e] = get_local_map_boundaries(
                        (loc_r, loc_c), (local_w, local_h), (full_w, full_h))

                    planner_pose_inputs[e, 3:] = lmb[e]
                    origins[e] = [
                        lmb[e][2] * args.map_resolution / 100.0,
                        lmb[e][0] * args.map_resolution / 100.0, 0.
                    ]

                    local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1],
                                            lmb[e, 2]:lmb[e, 3]]
                    local_pose[e] = full_pose[e] - \
                                    torch.from_numpy(origins[e]).to(device).float()

                locs = local_pose.cpu().numpy()
                for e in range(num_scenes):
                    global_orientation[e] = int((locs[e, 2] + 180.0) / 5.)
                global_input[:, 0:4, :, :] = local_map
                global_input[:, 4:, :, :] = \
                    nn.MaxPool2d(args.global_downscaling)(full_map)

                if False:
                    for i in range(4):
                        ax[i].clear()
                        ax[i].set_yticks([])
                        ax[i].set_xticks([])
                        ax[i].set_yticklabels([])
                        ax[i].set_xticklabels([])
                        ax[i].imshow(global_input.cpu().numpy()[0, 4 + i])
                    plt.gcf().canvas.flush_events()
                    # plt.pause(0.1)
                    fig.canvas.start_event_loop(0.001)
                    plt.gcf().canvas.flush_events()

                # Get exploration reward and metrics
                g_reward = torch.from_numpy(
                    np.asarray([
                        infos[env_idx]['exp_reward']
                        for env_idx in range(num_scenes)
                    ])).float().to(device)

                if args.eval:
                    g_reward = g_reward * 50.0  # Convert reward to area in m2

                g_process_rewards += g_reward.cpu().numpy()
                g_total_rewards = g_process_rewards * \
                                  (1 - g_masks.cpu().numpy())
                g_process_rewards *= g_masks.cpu().numpy()
                per_step_g_rewards.append(np.mean(g_reward.cpu().numpy()))

                if np.sum(g_total_rewards) != 0:
                    for tr in g_total_rewards:
                        g_episode_rewards.append(tr) if tr != 0 else None

                if args.eval:
                    exp_ratio = torch.from_numpy(
                        np.asarray([
                            infos[env_idx]['exp_ratio']
                            for env_idx in range(num_scenes)
                        ])).float()

                    for e in range(num_scenes):
                        explored_area_log[e, ep_num, eval_g_step - 1] = \
                            explored_area_log[e, ep_num, eval_g_step - 2] + \
                            g_reward[e].cpu().numpy()
                        explored_ratio_log[e, ep_num, eval_g_step - 1] = \
                            explored_ratio_log[e, ep_num, eval_g_step - 2] + \
                            exp_ratio[e].cpu().numpy()

                # Add samples to global policy storage
                g_rollouts.insert(global_input, g_rec_states, g_action,
                                  g_action_log_prob, g_value, g_reward,
                                  g_masks, global_orientation)

                # Sample long-term goal from global policy
                g_value, g_action, g_action_log_prob, g_rec_states = \
                    g_policy.act(
                        g_rollouts.obs[g_step + 1],
                        g_rollouts.rec_states[g_step + 1],
                        g_rollouts.masks[g_step + 1],
                        extras=g_rollouts.extras[g_step + 1],
                        deterministic=False
                    )
                cpu_actions = nn.Sigmoid()(g_action).cpu().numpy()
                global_goals = [[
                    int(action[0] * local_w),
                    int(action[1] * local_h)
                ] for action in cpu_actions]

                g_reward = 0
                g_masks = torch.ones(num_scenes).float().to(device)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Get short term goal
            planner_inputs = [{} for e in range(num_scenes)]
            for e, p_input in enumerate(planner_inputs):
                p_input['map_pred'] = local_map[e, 0, :, :].cpu().numpy()
                p_input['exp_pred'] = local_map[e, 1, :, :].cpu().numpy()
                p_input['pose_pred'] = planner_pose_inputs[e]
                p_input['goal'] = global_goals[e]

            output = envs.get_short_term_goal(planner_inputs)

            # print(f"\n output (short term goal): {output}\n")

            # ------------------------------------------------------------------

            ### TRAINING
            torch.set_grad_enabled(True)
            # ------------------------------------------------------------------
            # Train Neural SLAM Module
            if args.train_slam and len(slam_memory) > args.slam_batch_size:
                for _ in range(args.slam_iterations):
                    inputs, outputs = slam_memory.sample(args.slam_batch_size)
                    b_obs_last, b_obs, b_poses = inputs
                    gt_fp_projs, gt_fp_explored, gt_pose_err = outputs

                    b_obs = b_obs.to(device)
                    b_obs_last = b_obs_last.to(device)
                    b_poses = b_poses.to(device)

                    gt_fp_projs = gt_fp_projs.to(device)
                    gt_fp_explored = gt_fp_explored.to(device)
                    gt_pose_err = gt_pose_err.to(device)

                    b_proj_pred, b_fp_exp_pred, _, _, b_pose_err_pred, _ = \
                        nslam_module(b_obs_last, b_obs, b_poses,
                                     None, None, None,
                                     build_maps=False)
                    loss = 0
                    if args.proj_loss_coeff > 0:
                        proj_loss = F.binary_cross_entropy(
                            b_proj_pred, gt_fp_projs)
                        costs.append(proj_loss.item())
                        loss += args.proj_loss_coeff * proj_loss

                    if args.exp_loss_coeff > 0:
                        exp_loss = F.binary_cross_entropy(
                            b_fp_exp_pred, gt_fp_explored)
                        exp_costs.append(exp_loss.item())
                        loss += args.exp_loss_coeff * exp_loss

                    if args.pose_loss_coeff > 0:
                        pose_loss = torch.nn.MSELoss()(b_pose_err_pred,
                                                       gt_pose_err)
                        pose_costs.append(args.pose_loss_coeff *
                                          pose_loss.item())
                        loss += args.pose_loss_coeff * pose_loss

                    if args.train_slam:
                        slam_optimizer.zero_grad()
                        loss.backward()
                        slam_optimizer.step()

                    del b_obs_last, b_obs, b_poses
                    del gt_fp_projs, gt_fp_explored, gt_pose_err
                    del b_proj_pred, b_fp_exp_pred, b_pose_err_pred

            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Train Local Policy
            if (l_step + 1) % args.local_policy_update_freq == 0 \
                    and args.train_local:
                local_optimizer.zero_grad()
                policy_loss.backward()
                local_optimizer.step()
                l_action_losses.append(policy_loss.item())
                policy_loss = 0
                local_rec_states = local_rec_states.detach_()
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Train Global Policy
            if g_step % args.num_global_steps == args.num_global_steps - 1 \
                    and l_step == args.num_local_steps - 1:
                if args.train_global:
                    g_next_value = g_policy.get_value(
                        g_rollouts.obs[-1],
                        g_rollouts.rec_states[-1],
                        g_rollouts.masks[-1],
                        extras=g_rollouts.extras[-1]).detach()

                    g_rollouts.compute_returns(g_next_value, args.use_gae,
                                               args.gamma, args.tau)
                    g_value_loss, g_action_loss, g_dist_entropy = \
                        g_agent.update(g_rollouts)
                    g_value_losses.append(g_value_loss)
                    g_action_losses.append(g_action_loss)
                    g_dist_entropies.append(g_dist_entropy)
                g_rollouts.after_update()
            # ------------------------------------------------------------------

            # Finish Training
            torch.set_grad_enabled(False)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Logging
            if total_num_steps % args.log_interval == 0:
                end = time.time()
                time_elapsed = time.gmtime(end - start)
                log = " ".join([
                    "Time: {0:0=2d}d".format(time_elapsed.tm_mday - 1),
                    "{},".format(time.strftime("%Hh %Mm %Ss", time_elapsed)),
                    "num timesteps {},".format(total_num_steps *
                                               num_scenes),
                    "FPS {},".format(int(total_num_steps * num_scenes \
                                         / (end - start)))
                ])

                log += "\n\tRewards:"

                if len(g_episode_rewards) > 0:
                    log += " ".join([
                        " Global step mean/med rew:",
                        "{:.4f}/{:.4f},".format(np.mean(per_step_g_rewards),
                                                np.median(per_step_g_rewards)),
                        " Global eps mean/med/min/max eps rew:",
                        "{:.3f}/{:.3f}/{:.3f}/{:.3f},".format(
                            np.mean(g_episode_rewards),
                            np.median(g_episode_rewards),
                            np.min(g_episode_rewards),
                            np.max(g_episode_rewards))
                    ])

                log += "\n\tLosses:"

                if args.train_local and len(l_action_losses) > 0:
                    log += " ".join([
                        " Local Loss:",
                        "{:.3f},".format(np.mean(l_action_losses))
                    ])

                if args.train_global and len(g_value_losses) > 0:
                    log += " ".join([
                        " Global Loss value/action/dist:",
                        "{:.3f}/{:.3f}/{:.3f},".format(
                            np.mean(g_value_losses), np.mean(g_action_losses),
                            np.mean(g_dist_entropies))
                    ])

                if args.train_slam and len(costs) > 0:
                    log += " ".join([
                        " SLAM Loss proj/exp/pose:"
                        "{:.4f}/{:.4f}/{:.4f}".format(np.mean(costs),
                                                      np.mean(exp_costs),
                                                      np.mean(pose_costs))
                    ])

                print(log)
                logging.info(log)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Save best models
            if (total_num_steps * num_scenes) % args.save_interval < \
                    num_scenes:

                # Save Neural SLAM Model
                if len(costs) >= 1000 and np.mean(costs) < best_cost \
                        and not args.eval:
                    best_cost = np.mean(costs)
                    torch.save(nslam_module.state_dict(),
                               os.path.join(log_dir, "model_best.slam"))

                # Save Local Policy Model
                if len(l_action_losses) >= 100 and \
                        (np.mean(l_action_losses) <= best_local_loss) \
                        and not args.eval:
                    torch.save(l_policy.state_dict(),
                               os.path.join(log_dir, "model_best.local"))

                    best_local_loss = np.mean(l_action_losses)

                # Save Global Policy Model
                if len(g_episode_rewards) >= 100 and \
                        (np.mean(g_episode_rewards) >= best_g_reward) \
                        and not args.eval:
                    torch.save(g_policy.state_dict(),
                               os.path.join(log_dir, "model_best.global"))
                    best_g_reward = np.mean(g_episode_rewards)

            # Save periodic models
            if (total_num_steps * num_scenes) % args.save_periodic < \
                    num_scenes:
                step = total_num_steps * num_scenes
                if args.train_slam:
                    torch.save(
                        nslam_module.state_dict(),
                        os.path.join(dump_dir,
                                     "periodic_{}.slam".format(step)))
                if args.train_local:
                    torch.save(
                        l_policy.state_dict(),
                        os.path.join(dump_dir,
                                     "periodic_{}.local".format(step)))
                if args.train_global:
                    torch.save(
                        g_policy.state_dict(),
                        os.path.join(dump_dir,
                                     "periodic_{}.global".format(step)))
            # ------------------------------------------------------------------
    print("Finishing Epsiods")

    # Print and save model performance numbers during evaluation
    if args.eval:
        logfile = open("{}/explored_area.txt".format(dump_dir), "w+")
        for e in range(num_scenes):
            for i in range(explored_area_log[e].shape[0]):
                logfile.write(str(explored_area_log[e, i]) + "\n")
                logfile.flush()

        logfile.close()

        logfile = open("{}/explored_ratio.txt".format(dump_dir), "w+")
        for e in range(num_scenes):
            for i in range(explored_ratio_log[e].shape[0]):
                logfile.write(str(explored_ratio_log[e, i]) + "\n")
                logfile.flush()

        logfile.close()

        log = "Final Exp Area: \n"
        for i in range(explored_area_log.shape[2]):
            log += "{:.5f}, ".format(np.mean(explored_area_log[:, :, i]))

        log += "\nFinal Exp Ratio: \n"
        for i in range(explored_ratio_log.shape[2]):
            log += "{:.5f}, ".format(np.mean(explored_ratio_log[:, :, i]))

        print(log)
        logging.info(log)

    imgs_1 = local_map[0, :, :, :].cpu().numpy()
    imgs_2 = local_map[1, :, :, :].cpu().numpy()

    obs_all = _process_obs_for_display(obs)

    # fig, axis = plt.subplots(1, 3)
    # axis[0].imshow(obs_all[0])
    # axis[1].imshow(imgs_1[0], cmap='gray')
    # axis[2].imshow(imgs_1[1], cmap='gray')
    return

    cv2.imshow("Camer", transform_rgb_bgr(obs_all[0]))
    cv2.imshow("Proj", imgs_1[0])
    cv2.imshow("Map", imgs_1[1])

    cv2.imshow("Camer2", transform_rgb_bgr(obs_all[1]))
    cv2.imshow("Proj2", imgs_2[0])
    cv2.imshow("Map2", imgs_2[1])

    action = 1
    while action != 4:
        k = cv2.waitKey(0)
        if k == 119:
            action = 1
            action_2 = 1
        elif k == 100:
            action = 3
            action_2 = 1
        elif k == 97:
            action = 2
            action_2 = 2
        elif k == 102:
            action = 4
            break
        else:
            action = 1

        last_obs = obs.detach()

        obs, rew, done, infos = envs.step(
            torch.from_numpy(np.array([action, action_2])))

        obs_all = _process_obs_for_display(obs)
        cv2.imshow("Camer", transform_rgb_bgr(obs_all[0]))
        cv2.imshow("Camer2", transform_rgb_bgr(obs_all[1]))

        poses = torch.from_numpy(
            np.asarray([
                infos[env_idx]['sensor_pose'] for env_idx in range(num_scenes)
            ])).float().to(device)

        _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
            nslam_module(last_obs, obs, poses, local_map[:, 0, :, :],
                            local_map[:, 1, :, :], local_pose, build_maps=True)

        imgs_1 = local_map[0, :, :, :].cpu().numpy()
        imgs_2 = local_map[1, :, :, :].cpu().numpy()
        cv2.imshow("Proj", imgs_1[0])
        cv2.imshow("Map", imgs_1[1])
        cv2.imshow("Proj2", imgs_2[0])
        cv2.imshow("Map2", imgs_2[1])

    # plt.show()

    print("\n\nDone\n\n")