class PVNBaselineTrainer(Trainer):
    """Supervised trainer for the PVN baseline.

    Combines a Mapper (builds egocentric spatial maps from panoramic RGB-D
    observations) with a UNet navigator that predicts goal and path heatmaps
    over the map. Training supervision comes from ground-truth goal/path
    coordinates rendered as Gaussian heatmaps.
    """

    def __init__(self, args):
        super(PVNBaselineTrainer, self).__init__(args, filepath=None)
        self.mapper = Mapper(self.args).to(self.args.device)
        self.navigator = UnetNavigator(self.args).to(self.args.device)
        self.navigator.init_weights()
        # Separate (but identical) criteria for the goal and path heads.
        self.goal_criterion = MaskedMSELoss()
        self.path_criterion = MaskedMSELoss()

    def _run_mapper(self, dataloader, seq, seq_lens):
        """Roll the mapper forward for args.timesteps simulator steps.

        Shared by train() and validate() (previously duplicated inline).

        Returns:
            (spatial_map, mask, seq, seq_lens). When args.multi_maps is set,
            the per-timestep maps are stacked and flattened to a leading
            (batch * timesteps) dimension, and seq/seq_lens are tiled to match.
        """
        all_spatial_maps = []
        all_masks = []
        xyzhe = self.sim.getXYZHE()
        spatial_map, mask = self.mapper.init_map(xyzhe)
        for t in range(self.args.timesteps):
            rgb, depth, states = self.sim.getPanos()
            spatial_map, mask, ftm = self.mapper(rgb, depth, states, spatial_map, mask)
            if self.args.multi_maps:
                all_spatial_maps.append(spatial_map.unsqueeze(1))
                all_masks.append(mask.unsqueeze(1))
            if self.args.timesteps != 1:
                # Advance the agent along the supervised path between map updates.
                self.sim.takePseudoSupervisedAction(dataloader.get_supervision)
        if self.args.multi_maps:
            spatial_map = torch.cat(all_spatial_maps, dim=1).flatten(0, 1)
            mask = torch.cat(all_masks, dim=1).flatten(0, 1)
            seq = seq.unsqueeze(1).expand(-1, self.args.timesteps, -1).flatten(0, 1)
            seq_lens = seq_lens.unsqueeze(1).expand(-1, self.args.timesteps).flatten(0, 1)
        return spatial_map, mask, seq, seq_lens

    @staticmethod
    def _peak_normalize(heatmap):
        """Scale each heatmap in the batch so its per-map maximum is 1.

        NOTE(review): an all-zero map would divide by zero here (as in the
        original inline code) — assumed not to occur for valid supervision.
        """
        peak, _ = torch.max(heatmap, dim=2, keepdim=True)
        peak, _ = torch.max(peak, dim=1, keepdim=True)
        return heatmap / peak

    def _build_targets(self, dataloader, mask):
        """Build peak-normalized goal/path target heatmaps for a minibatch.

        Shared by train() and validate() (previously duplicated inline).
        Applies optional belief downsampling (avg-pool on targets, max-pool on
        the observation mask) and multi_maps tiling across timesteps.

        Returns:
            (goal_map, path_map, mask)
        """
        goal_map = self._peak_normalize(
            self.mapper.heatmap(dataloader.goal_coords(), self.args.goal_heatmap_sigma))
        path_map = self._peak_normalize(
            self.mapper.heatmap(dataloader.path_coords(), self.args.path_heatmap_sigma))
        if self.args.belief_downsample_factor > 1:
            k = self.args.belief_downsample_factor
            goal_map = nn.functional.avg_pool2d(goal_map, kernel_size=k, stride=k)
            path_map = nn.functional.avg_pool2d(path_map, kernel_size=k, stride=k)
            mask = nn.functional.max_pool2d(mask, kernel_size=k, stride=k)
        if self.args.multi_maps:
            goal_map = goal_map.unsqueeze(1).expand(-1, self.args.timesteps, -1, -1).flatten(0, 1)
            path_map = path_map.unsqueeze(1).expand(-1, self.args.timesteps, -1, -1).flatten(0, 1)
        return goal_map, path_map, mask

    def validate(self, split):
        """Average navigator loss over args.validation_iterations minibatches.

        Args:
            split: "val_seen" or "val_unseen".

        Raises:
            ValueError: on an unknown split (previously an opaque NameError).
        """
        with torch.no_grad():
            if split == "val_seen":
                val_dataloader = self.valseendata
            elif split == "val_unseen":
                val_dataloader = self.valunseendata
            else:
                raise ValueError("Unknown validation split: %s" % split)
            total_loss = torch.tensor(0.0)
            for it in range(self.args.validation_iterations):
                # Load minibatch and reset the simulator to its episodes.
                seq, seq_mask, seq_lens, batch = val_dataloader.get_batch()
                self.sim.newEpisode(batch)
                spatial_map, mask, seq, seq_lens = self._run_mapper(val_dataloader, seq, seq_lens)
                # Predict goal/path heatmaps with the unet navigator.
                pred = self.navigator(seq, seq_lens, spatial_map)
                goal_pred = pred[:, 0, :, :]
                path_pred = pred[:, 1, :, :]
                goal_map, path_map, mask = self._build_targets(val_dataloader, mask)
                # NOTE(review): mask is never passed to the criteria; confirm
                # MaskedMSELoss does not need it as an explicit argument.
                loss = (self.goal_criterion(goal_pred, goal_map)
                        + self.path_criterion(path_pred, path_map))
                total_loss += (1.0 / self.args.validation_iterations) * loss
        return total_loss

    def logging_loop(self, it):
        """Periodic validation, console/visdom logging, and checkpointing."""
        if it % self.args.validate_every == 0:
            self.mapper.eval()
            self.navigator.eval()
            loss_val_seen = self.validate("val_seen")
            loss_val_unseen = self.validate("val_unseen")
            self.prev_time, time_taken = utils.time_it(self.prev_time)
            print("Iteration: %d Loss: %f Val Seen Loss: %f Val Unseen Loss: %f Time: %0.2f secs"
                  % (it, self.loss.item(), loss_val_seen.item(),
                     loss_val_unseen.item(), time_taken))
            if self.visdom:
                # visdom: X, Y, key, line_name, x_label, y_label, fig_title
                # Fixed: title was " Train Phase" (leading space), which made
                # visdom plot the same train_loss key on a second figure.
                self.visdom.line(it, self.loss.item(), "train_loss", "Train Loss",
                                 "Iterations", "Loss", title="Train Phase")
                self.visdom.line(it, loss_val_seen.item(), "val_loss", "Val Seen Loss",
                                 "Iterations", "Loss", title="Val Phase")
                self.visdom.line(it, loss_val_unseen.item(), "val_loss", "Val Unseen Loss",
                                 "Iterations", "Loss", title="Val Phase")
            self.mapper.train()
            self.navigator.train()
        elif it % self.args.log_every == 0:
            self.prev_time, time_taken = utils.time_it(self.prev_time)
            print("Iteration: %d Loss: %f Time: %0.2f secs"
                  % (it, self.loss.item(), time_taken))
            if self.visdom:
                self.visdom.line(it, self.loss.item(), "train_loss", "Train Loss",
                                 "Iterations", "Loss", title="Train Phase")
        if it % self.args.checkpoint_every == 0:
            saver = {"mapper": self.mapper.state_dict(),
                     "navigator": self.navigator.state_dict(),
                     "args": self.args}
            # 'ckpt_dir' avoids shadowing the builtin 'dir'; exist_ok avoids a
            # race between the old existence check and makedirs.
            ckpt_dir = "%s/%s" % (self.args.snapshot_dir, self.args.exp_name)
            os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(saver, "%s/%s_%d" % (ckpt_dir, self.args.exp_name, it))
            if self.visdom:
                self.visdom.save()

    def plot_grad_flow(self, named_parameters):
        """Save a plot of mean absolute gradient per layer to gradient.png.

        Call after loss.backward() to diagnose vanishing/exploding gradients.
        """
        ave_grads = []
        layers = []
        for n, p in named_parameters:
            if p.requires_grad and ("bias" not in n):
                layers.append(n)
                # .item() detaches and moves to CPU so matplotlib can also
                # handle gradients living on CUDA tensors.
                ave_grads.append(p.grad.abs().mean().item())
        plt.plot(ave_grads, alpha=0.3, color="b")
        plt.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
        plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
        # xmin/xmax kwargs were removed from matplotlib; left/right is the
        # supported spelling.
        plt.xlim(left=0, right=len(ave_grads))
        plt.xlabel("Layers")
        plt.ylabel("average gradient")
        plt.title("Gradient flow")
        plt.grid(True)
        plt.savefig("gradient.png")

    def train(self):
        """Supervised training of the mapper and navigator."""
        self.prev_time = time.time()
        map_opt = optim.Adam(self.mapper.parameters(), lr=self.args.learning_rate,
                             weight_decay=self.args.weight_decay)
        nav_opt = optim.Adam(self.navigator.parameters(), lr=self.args.learning_rate,
                             weight_decay=self.args.weight_decay)
        # Set models to train phase
        self.mapper.train()
        self.navigator.train()
        for it in range(1, self.args.max_iterations + 1):
            map_opt.zero_grad()
            nav_opt.zero_grad()
            # Load minibatch and reset the simulator to its episodes.
            seq, seq_mask, seq_lens, batch = self.traindata.get_batch()
            self.sim.newEpisode(batch)
            spatial_map, mask, seq, seq_lens = self._run_mapper(self.traindata, seq, seq_lens)
            # Predict goal/path heatmaps with the unet navigator.
            pred = self.navigator(seq, seq_lens, spatial_map)
            goal_pred = pred[:, 0, :, :]
            path_pred = pred[:, 1, :, :]
            goal_map, path_map, mask = self._build_targets(self.traindata, mask)
            # self.loss is stored on the instance: logging_loop reads it.
            self.loss = (self.goal_criterion(goal_pred, goal_map)
                         + self.path_criterion(path_pred, path_map))
            self.loss.backward()
            map_opt.step()
            nav_opt.step()
            self.logging_loop(it)
    def evaluate(self, split):
        """Evaluate the trained model on a full validation split.

        Runs the whole split once and reports per-timestep and episode-average
        navigation error, success rate, map coverage, and goal-seen rate via
        self.eval_logging.

        Args:
            split: "val_seen" or "val_unseen".
        """
        self.set_models_eval()
        with torch.no_grad():
            if split == "val_seen":
                self.dataloader = self.valseendata
            elif split == "val_unseen":
                self.dataloader = self.valunseendata
            # Number of batches needed to cover the split once; the last batch
            # only has last_batch_valid_idx real (non-repeated) episodes.
            iterations = int(
                math.ceil(
                    len(self.dataloader.data) / float(self.args.batch_size)))
            last_batch_valid_idx = len(
                self.dataloader.data) - (iterations - 1) * self.args.batch_size
            # Per-timestep metric lists (one list per timestep) and
            # episode-average metric lists, filled with Python floats below.
            timestep_nav_error = [[] for i in range(self.args.timesteps)]
            timestep_success_rate = [[] for i in range(self.args.timesteps)]
            average_nav_error = []
            average_success_rate = []
            timestep_map_coverage = [[] for i in range(self.args.timesteps)]
            timestep_goal_seen = [[] for i in range(self.args.timesteps)]
            average_map_coverage = []
            average_goal_seen = []
            # Fresh Mapper used only for map geometry (map_center) and target
            # heatmaps; predictions come from self.get_predictions below.
            mapper = Mapper(self.args).to(self.args.device)
            for it in tqdm(range(iterations),
                           desc="Evaluation Progress for %s split" % split):
                # Ignore padded episodes in the final (partial) batch.
                valid_batch_len = last_batch_valid_idx \
                    if it == iterations - 1 else self.args.batch_size
                seq, seq_mask, seq_lens, batch = self.dataloader.get_batch()
                self.sim.newEpisode(batch)
                # Floorplan renderings kept on self — presumably consumed by
                # debug visualization elsewhere; not used again in this method.
                self.floorplan_images = utils.get_floorplan_images(
                    self.sim.getState(),
                    self.floorplan,
                    self.args.map_range_x,
                    self.args.map_range_y,
                    scale_factor=self.args.debug_scale)
                xyzhe = self.sim.getXYZHE()
                mapper.init_map(xyzhe)
                # Ground-truth goal position on the (discrete) map grid, used
                # for the goal-seen metric.
                goal_pos = self.dataloader.get_goal_coords_on_map_grid(
                    mapper.map_center)
                pred_goal_map, pred_path_map, mask = self.get_predictions(
                    seq, seq_mask, seq_lens, batch, xyzhe,
                    self.simulator_next_action)
                # Target heatmaps, kept on self — TODO(review): confirm these
                # instance attributes are read by a caller/visualizer; they are
                # not used again within this method.
                # shape: (batch_size, 2)
                self.goal_map = mapper.heatmap(self.dataloader.goal_coords(),
                                               self.args.goal_heatmap_sigma)
                self.path_map = mapper.heatmap(self.dataloader.path_coords(),
                                               self.args.path_heatmap_sigma)
                if self.args.multi_maps:
                    # shape: (batch_size*timesteps, 2)
                    self.goal_map = self.goal_map.unsqueeze(1).expand(
                        -1, self.args.timesteps, -1, -1).flatten(0, 1)
                    self.path_map = self.path_map.unsqueeze(1).expand(
                        -1, self.args.timesteps, -1, -1).flatten(0, 1)
                # Predicted goal = argmax pixel of each predicted heatmap;
                # .flip(1) swaps (row, col) -> (x, y) ordering for conversion.
                # shape: (batch_size*timesteps, 2)
                goal_pred_argmax = utils.compute_argmax(pred_goal_map).flip(1)
                # shape: (batch_size, timesteps, 2)
                goal_pred_argmax = goal_pred_argmax.reshape(
                    -1, self.args.timesteps, 2)
                goal_pred_xy = self.dataloader.convert_map_pixels_to_xy_coords(
                    goal_pred_argmax, mapper.map_center,
                    multi_timestep_input=True)
                # shape: (batch_size, 2)
                goal_target_xy = self.dataloader.goal_coords()
                # shape: (batch_size, timesteps, 2) — same target repeated so
                # it can be compared against every timestep's prediction.
                goal_target_xy = goal_target_xy.unsqueeze(1).expand(
                    -1, self.args.timesteps, -1)
                # shape: (batch_size, timesteps, map_range_y, map_range_x)
                b_t_mask = mask.reshape(-1, self.args.timesteps,
                                        self.args.map_range_y,
                                        self.args.map_range_x)
                batch_timestep_map_coverage, batch_average_map_coverage = \
                    metrics.map_coverage(b_t_mask)
                batch_timestep_goal_seen, batch_average_goal_seen = \
                    metrics.goal_seen_rate(b_t_mask, goal_pos, self.args)
                batch_timestep_nav_error, batch_average_nav_error = \
                    metrics.nav_error(goal_target_xy, goal_pred_xy, self.args)
                batch_timestep_success_rate, batch_average_success_rate = \
                    metrics.success_rate(goal_target_xy, goal_pred_xy,
                                         self.args)
                # Accumulate scalar metrics for the valid (non-padded)
                # episodes only.
                for n in range(valid_batch_len):
                    average_nav_error.append(batch_average_nav_error[n].item())
                    average_success_rate.append(
                        batch_average_success_rate[n].item())
                    average_map_coverage.append(
                        batch_average_map_coverage[n].item())
                    average_goal_seen.append(batch_average_goal_seen[n].item())
                    for t in range(batch_timestep_map_coverage.shape[1]):
                        timestep_nav_error[t].append(
                            batch_timestep_nav_error[n][t].item())
                        timestep_success_rate[t].append(
                            batch_timestep_success_rate[n][t].item())
                        timestep_map_coverage[t].append(
                            batch_timestep_map_coverage[n][t].item())
                        timestep_goal_seen[t].append(
                            batch_timestep_goal_seen[n][t].item())
            self.eval_logging(split, timestep_nav_error, average_nav_error,
                              timestep_success_rate, average_success_rate,
                              timestep_map_coverage, average_map_coverage,
                              timestep_goal_seen, average_goal_seen)