Exemplo n.º 1
0
class PVNBaselineTrainer(Trainer):
    """Train the PVN baseline.

    A ``Mapper`` builds spatial maps from simulator panoramas and a
    ``UnetNavigator`` predicts, from the instruction and those maps, a goal
    heatmap (output channel 0) and a path heatmap (output channel 1). Both
    networks are optimized jointly against target heatmaps built from the
    dataloader's goal and path coordinates.
    """

    def __init__(self, args):
        super(PVNBaselineTrainer, self).__init__(args, filepath=None)

        self.mapper = Mapper(self.args).to(self.args.device)
        self.navigator = UnetNavigator(self.args).to(self.args.device)
        self.navigator.init_weights()

        self.goal_criterion = MaskedMSELoss()
        self.path_criterion = MaskedMSELoss()

    # ------------------------------------------------------------------
    # Shared forward pass (used by both ``train`` and ``validate``)
    # ------------------------------------------------------------------

    def _rollout_mapper(self, dataloader):
        """Load a minibatch, reset the simulator and run the mapper.

        Runs the mapper for ``args.timesteps`` steps. With ``args.multi_maps``
        the map from every step is kept and the batch and time dimensions are
        flattened together (the instruction tensors are repeated to match);
        otherwise only the final map is returned.

        Returns:
            (seq, seq_lens, spatial_map, mask)
        """
        seq, _seq_mask, seq_lens, batch = dataloader.get_batch()
        self.sim.newEpisode(batch)

        all_spatial_maps = []
        all_masks = []

        xyzhe = self.sim.getXYZHE()
        spatial_map, mask = self.mapper.init_map(xyzhe)
        for _ in range(self.args.timesteps):
            rgb, depth, states = self.sim.getPanos()
            spatial_map, mask, _ftm = self.mapper(rgb, depth, states, spatial_map, mask)
            if self.args.multi_maps:
                all_spatial_maps.append(spatial_map.unsqueeze(1))
                all_masks.append(mask.unsqueeze(1))
            if self.args.timesteps != 1:
                # Presumably advances the agent using the dataloader's
                # supervision so the next panorama comes from a new viewpoint.
                self.sim.takePseudoSupervisedAction(dataloader.get_supervision)

        if self.args.multi_maps:
            timesteps = self.args.timesteps
            spatial_map = torch.cat(all_spatial_maps, dim=1).flatten(0, 1)
            mask = torch.cat(all_masks, dim=1).flatten(0, 1)
            seq = seq.unsqueeze(1).expand(-1, timesteps, -1).flatten(0, 1)
            seq_lens = seq_lens.unsqueeze(1).expand(-1, timesteps).flatten(0, 1)

        return seq, seq_lens, spatial_map, mask

    @staticmethod
    def _normalize_peak(heatmap):
        """Scale each map of a (batch, H, W) stack so that its maximum is 1."""
        peak, _ = torch.max(heatmap, dim=2, keepdim=True)
        peak, _ = torch.max(peak, dim=1, keepdim=True)
        return heatmap / peak

    def _targets(self, dataloader, mask):
        """Build goal/path target heatmaps matching the navigator's output.

        The targets are peak-normalized, optionally downsampled by
        ``args.belief_downsample_factor`` (the mask is max-pooled alongside),
        and, with ``args.multi_maps``, repeated once per timestep.

        Returns:
            (goal_map, path_map, mask)
        """
        goal_map = self._normalize_peak(
            self.mapper.heatmap(dataloader.goal_coords(), self.args.goal_heatmap_sigma))
        path_map = self._normalize_peak(
            self.mapper.heatmap(dataloader.path_coords(), self.args.path_heatmap_sigma))

        factor = self.args.belief_downsample_factor
        if factor > 1:
            goal_map = nn.functional.avg_pool2d(goal_map, kernel_size=factor, stride=factor)
            path_map = nn.functional.avg_pool2d(path_map, kernel_size=factor, stride=factor)
            mask = nn.functional.max_pool2d(mask, kernel_size=factor, stride=factor)

        if self.args.multi_maps:
            timesteps = self.args.timesteps
            goal_map = goal_map.unsqueeze(1).expand(-1, timesteps, -1, -1).flatten(0, 1)
            path_map = path_map.unsqueeze(1).expand(-1, timesteps, -1, -1).flatten(0, 1)

        return goal_map, path_map, mask

    def _minibatch_loss(self, dataloader):
        """Run mapper + navigator on one minibatch and return the summed loss."""
        seq, seq_lens, spatial_map, mask = self._rollout_mapper(dataloader)

        pred = self.navigator(seq, seq_lens, spatial_map)
        goal_pred = pred[:, 0, :, :]
        path_pred = pred[:, 1, :, :]

        goal_map, path_map, mask = self._targets(dataloader, mask)

        # NOTE(review): `mask` is computed (and downsampled) but never passed
        # to the MaskedMSELoss criteria -- confirm whether that is intended.
        return (self.goal_criterion(goal_pred, goal_map)
                + self.path_criterion(path_pred, path_map))

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def validate(self, split):
        """Return the mean loss over ``args.validation_iterations`` minibatches.

        Args:
            split: "val_seen" or "val_unseen".

        Returns:
            A 0-dim tensor on ``args.device``.

        Raises:
            ValueError: if ``split`` is not one of the two supported splits.
        """
        if split == "val_seen":
            val_dataloader = self.valseendata
        elif split == "val_unseen":
            val_dataloader = self.valunseendata
        else:
            raise ValueError(
                'split must be "val_seen" or "val_unseen", got %r' % (split,))

        n_iters = self.args.validation_iterations
        # Accumulate on the same device as the per-batch losses.
        total_loss = torch.zeros((), device=self.args.device)
        with torch.no_grad():
            for _ in range(n_iters):
                loss = self._minibatch_loss(val_dataloader)
                total_loss += (1.0 / n_iters) * loss

        return total_loss

    def logging_loop(self, it):
        """Log losses and save checkpoints at iteration ``it``.

        Every ``args.validate_every`` iterations this validates on both splits
        and prints/plots train and val losses. Otherwise, every
        ``args.log_every`` iterations, it prints/plots the train loss. Every
        ``args.checkpoint_every`` iterations it saves both models' state dicts.
        """
        if it % self.args.validate_every == 0:

            self.mapper.eval()
            self.navigator.eval()
            try:
                loss_val_seen = self.validate("val_seen")
                loss_val_unseen = self.validate("val_unseen")
            finally:
                # Always return to training mode, even if validation fails.
                self.mapper.train()
                self.navigator.train()

            self.prev_time, time_taken = utils.time_it(self.prev_time)

            print("Iteration: %d Loss: %f Val Seen Loss: %f Val Unseen Loss: %f Time: %0.2f secs"
                    %(it, self.loss.item(), loss_val_seen.item(), loss_val_unseen.item(), time_taken))

            if self.visdom:
                # visdom: X, Y, key, line_name, x_label, y_label, fig_title
                self.visdom.line(it, self.loss.item(), "train_loss", "Train Loss", "Iterations", "Loss", title="Train Phase")
                self.visdom.line(it, loss_val_seen.item(), "val_loss", "Val Seen Loss", "Iterations", "Loss", title="Val Phase")
                self.visdom.line(it, loss_val_unseen.item(), "val_loss", "Val Unseen Loss", "Iterations", "Loss", title="Val Phase")

        elif it % self.args.log_every == 0:
            self.prev_time, time_taken = utils.time_it(self.prev_time)
            print("Iteration: %d Loss: %f Time: %0.2f secs" % (it, self.loss.item(), time_taken))

            if self.visdom:
                self.visdom.line(it, self.loss.item(), "train_loss", "Train Loss", "Iterations", "Loss", title="Train Phase")

        if it % self.args.checkpoint_every == 0:
            saver = {"mapper": self.mapper.state_dict(),
                     "navigator": self.navigator.state_dict(),
                     "args": self.args}
            snapshot_dir = os.path.join(self.args.snapshot_dir, self.args.exp_name)
            os.makedirs(snapshot_dir, exist_ok=True)
            torch.save(
                saver, os.path.join(snapshot_dir, "%s_%d" % (self.args.exp_name, it))
            )
            if self.visdom:
                self.visdom.save()

    def plot_grad_flow(self, named_parameters):
        """Plot mean |grad| per trainable non-bias parameter to ``gradient.png``.

        Draws into the current matplotlib figure, so repeated calls overlay
        lines (hence the low alpha). Parameters without a gradient are skipped.
        """
        ave_grads = []
        layers = []

        for name, param in named_parameters:
            if not param.requires_grad or "bias" in name or param.grad is None:
                continue
            layers.append(name)
            # .item() gives a plain float; matplotlib cannot plot CUDA tensors.
            ave_grads.append(param.grad.abs().mean().item())

        plt.plot(ave_grads, alpha=0.3, color="b")
        plt.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
        plt.xticks(range(len(ave_grads)), layers, rotation="vertical")
        plt.xlim(left=0, right=len(ave_grads))
        plt.xlabel("Layers")
        plt.ylabel("average gradient")
        plt.title("Gradient flow")
        plt.grid(True)
        plt.savefig("gradient.png")

    def train(self):
        """Jointly train the mapper and navigator with Adam.

        Runs ``args.max_iterations`` iterations. Each iteration stores its loss
        in ``self.loss`` (read by ``logging_loop``).
        """
        self.prev_time = time.time()

        map_opt = optim.Adam(self.mapper.parameters(), lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay)
        nav_opt = optim.Adam(self.navigator.parameters(), lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay)

        # Set models to train phase
        self.mapper.train()
        self.navigator.train()

        for it in range(1, self.args.max_iterations + 1):
            map_opt.zero_grad()
            nav_opt.zero_grad()

            self.loss = self._minibatch_loss(self.traindata)

            self.loss.backward()
            map_opt.step()
            nav_opt.step()
            self.logging_loop(it)
Exemplo n.º 2
0
    def evaluate(self, split):
        """Evaluate goal prediction and mapping metrics on a validation split.

        Iterates once over every example of ``split`` in batches of
        ``args.batch_size``. It collects per-example navigation error, success
        rate, map coverage and goal-seen rate, both averaged over timesteps and
        per timestep, and hands them to ``self.eval_logging``.

        Args:
            split: "val_seen" or "val_unseen".

        Raises:
            ValueError: for any other ``split``.
        """
        # Validate before any side effects. Otherwise self.dataloader would
        # keep its previous value, and the wrong split would be evaluated and
        # logged under `split`.
        if split == "val_seen":
            dataloader = self.valseendata
        elif split == "val_unseen":
            dataloader = self.valunseendata
        else:
            raise ValueError(
                'split must be "val_seen" or "val_unseen", got %r' % (split,))

        self.set_models_eval()

        with torch.no_grad():
            # Kept as an attribute; other methods may read self.dataloader.
            self.dataloader = dataloader

            iterations = int(
                math.ceil(
                    len(self.dataloader.data) / float(self.args.batch_size)))
            # Number of real examples in the final (possibly partial) batch.
            last_batch_valid_idx = len(
                self.dataloader.data) - (iterations - 1) * self.args.batch_size

            timestep_nav_error = [[] for _ in range(self.args.timesteps)]
            timestep_success_rate = [[] for _ in range(self.args.timesteps)]
            average_nav_error = []
            average_success_rate = []

            timestep_map_coverage = [[] for _ in range(self.args.timesteps)]
            timestep_goal_seen = [[] for _ in range(self.args.timesteps)]
            average_map_coverage = []
            average_goal_seen = []

            mapper = Mapper(self.args).to(self.args.device)

            for it in tqdm(range(iterations),
                           desc="Evaluation Progress for %s split" % split):

                # Only the first `valid_batch_len` rows of the last batch are
                # real examples (presumably the loader pads/wraps the rest).
                valid_batch_len = last_batch_valid_idx if it == iterations - 1 else self.args.batch_size

                seq, seq_mask, seq_lens, batch = self.dataloader.get_batch()

                self.sim.newEpisode(batch)
                self.floorplan_images = utils.get_floorplan_images(
                    self.sim.getState(),
                    self.floorplan,
                    self.args.map_range_x,
                    self.args.map_range_y,
                    scale_factor=self.args.debug_scale)

                xyzhe = self.sim.getXYZHE()
                mapper.init_map(xyzhe)
                goal_pos = self.dataloader.get_goal_coords_on_map_grid(
                    mapper.map_center)

                pred_goal_map, pred_path_map, mask = self.get_predictions(
                    seq, seq_mask, seq_lens, batch, xyzhe,
                    self.simulator_next_action)

                # Target heatmaps, one 2-D map per example: (batch_size, H, W).
                self.goal_map = mapper.heatmap(self.dataloader.goal_coords(),
                                               self.args.goal_heatmap_sigma)
                self.path_map = mapper.heatmap(self.dataloader.path_coords(),
                                               self.args.path_heatmap_sigma)

                if self.args.multi_maps:
                    # Repeat per timestep: (batch_size*timesteps, H, W).
                    self.goal_map = self.goal_map.unsqueeze(1).expand(
                        -1, self.args.timesteps, -1, -1).flatten(0, 1)
                    self.path_map = self.path_map.unsqueeze(1).expand(
                        -1, self.args.timesteps, -1, -1).flatten(0, 1)

                # Argmax pixel of each predicted goal map, with its two
                # coordinates swapped by flip(1); assumed
                # (batch_size*timesteps, 2).
                goal_pred_argmax = utils.compute_argmax(pred_goal_map).flip(1)

                # shape: (batch_size, timesteps, 2)
                # NOTE(review): this reshape assumes one prediction per
                # timestep (i.e. multi_maps-style outputs) -- confirm.
                goal_pred_argmax = goal_pred_argmax.reshape(
                    -1, self.args.timesteps, 2)
                goal_pred_xy = self.dataloader.convert_map_pixels_to_xy_coords(
                    goal_pred_argmax,
                    mapper.map_center,
                    multi_timestep_input=True)

                # shape: (batch_size, 2)
                goal_target_xy = self.dataloader.goal_coords()
                # shape: (batch_size, timesteps, 2)
                goal_target_xy = goal_target_xy.unsqueeze(1).expand(
                    -1, self.args.timesteps, -1)

                # shape: (batch_size, timesteps, map_range_y, map_range_x)
                b_t_mask = mask.reshape(-1, self.args.timesteps,
                                        self.args.map_range_y,
                                        self.args.map_range_x)

                batch_timestep_map_coverage, batch_average_map_coverage = \
                    metrics.map_coverage(b_t_mask)
                batch_timestep_goal_seen, batch_average_goal_seen = \
                    metrics.goal_seen_rate(b_t_mask, goal_pos, self.args)
                batch_timestep_nav_error, batch_average_nav_error = \
                    metrics.nav_error(goal_target_xy, goal_pred_xy, self.args)
                batch_timestep_success_rate, batch_average_success_rate = \
                    metrics.success_rate(goal_target_xy, goal_pred_xy, self.args)

                # (destination list, batch metric) pairs.
                per_example = (
                    (average_nav_error, batch_average_nav_error),
                    (average_success_rate, batch_average_success_rate),
                    (average_map_coverage, batch_average_map_coverage),
                    (average_goal_seen, batch_average_goal_seen),
                )
                per_timestep = (
                    (timestep_nav_error, batch_timestep_nav_error),
                    (timestep_success_rate, batch_timestep_success_rate),
                    (timestep_map_coverage, batch_timestep_map_coverage),
                    (timestep_goal_seen, batch_timestep_goal_seen),
                )
                n_timesteps = batch_timestep_map_coverage.shape[1]

                # Record only the real (non-padding) examples of this batch.
                for n in range(valid_batch_len):
                    for dest, values in per_example:
                        dest.append(values[n].item())
                    for t in range(n_timesteps):
                        for dest, values in per_timestep:
                            dest[t].append(values[n][t].item())

            self.eval_logging(split, timestep_nav_error, average_nav_error,
                              timestep_success_rate, average_success_rate,
                              timestep_map_coverage, average_map_coverage,
                              timestep_goal_seen, average_goal_seen)