class PVNBaselineEvaluator(Evaluator):

    def __init__(self, args, filepath):
        super(PVNBaselineEvaluator, self).__init__(args, filepath)

    def init_models(self):
        self.args.belief_downsample_factor = 1
        self.mapper = Mapper(self.args).to(self.args.device)
        self.navigator = UnetNavigator(self.args).to(self.args.device)

    def load_model_weights(self):
        self.mapper.load_state_dict(self.loader["mapper"])
        self.navigator.load_state_dict(self.loader["navigator"])

    def set_models_eval(self):
        self.mapper.eval()
        self.navigator.eval()

    def get_predictions(self, seq, seq_mask, seq_lens, batch, xyzhe,
                        simulator_next_action):
        if self.args.multi_maps:
            all_spatial_maps = []
            all_masks = []
        spatial_map, mask = self.mapper.init_map(xyzhe)
        for t in range(self.args.timesteps):
            rgb, depth, states = self.sim.getPanos()
            spatial_map, mask, ftm = self.mapper(rgb, depth, states,
                                                 spatial_map, mask)
            if self.args.multi_maps:
                all_spatial_maps.append(spatial_map.unsqueeze(1))
                all_masks.append(mask)
            if self.args.timesteps != 1:
                simulator_next_action()
        if self.args.multi_maps:
            spatial_map = torch.cat(all_spatial_maps, dim=1).flatten(0, 1)
            mask = torch.cat(all_masks, dim=1).flatten(0, 1)
            seq = seq.unsqueeze(1).expand(-1, self.args.timesteps,
                                          -1).flatten(0, 1)
            seq_lens = seq_lens.unsqueeze(1).expand(
                -1, self.args.timesteps).flatten(0, 1)

        # Predict with unet
        pred = self.navigator(seq, seq_lens, spatial_map)
        # shape: (batch_size*timesteps, 2, map_range_y, map_range_x)
        goal_pred = pred[:, 0, :, :]
        path_pred = pred[:, 1, :, :]
        return goal_pred, path_pred, mask
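# Usage sketch (assumptions: the shared `Evaluator.evaluate` method -- shown
# later in this file -- drives `get_predictions`, and `make_eval_args()` is a
# hypothetical stand-in for however args are actually parsed):
#
#     args = make_eval_args()
#     evaluator = PVNBaselineEvaluator(args, filepath="snapshots/pvn_baseline_5000")
#     evaluator.evaluate("val_unseen")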
class HandCodedBaselineEvaluator(Evaluator):

    def __init__(self, args, filepath=None):
        args.eval = True
        super(HandCodedBaselineEvaluator, self).__init__(args,
                                                         filepath=filepath)
        radius_in_pixels = int(self.args.handcoded_radius /
                               self.args.gridcellsize)
        self.avg_dist_goal_mask = utils.get_gaussian_blurred_map_mask(
            self.args.map_range_y, self.args.map_range_x, radius_in_pixels,
            (self.args.blur_kernel_x, self.args.blur_kernel_y))

    def init_models(self):
        self.args.belief_downsample_factor = 1
        self.mapper = Mapper(self.args).to(self.args.device)

    def load_model_weights(self):
        pass

    def set_models_eval(self):
        self.mapper.eval()

    def get_predictions(self, seq, seq_mask, seq_lens, batch, xyzhe,
                        simulator_next_action):
        all_masks = []
        spatial_map, mask = self.mapper.init_map(xyzhe)
        for t in range(self.args.timesteps):
            rgb, depth, states = self.sim.getPanos()
            spatial_map, mask, ftm = self.mapper(rgb, depth, states,
                                                 spatial_map, mask)
            all_masks.append(mask)
            if self.args.timesteps != 1:
                simulator_next_action()
        mask = torch.cat(all_masks, dim=1).flatten(0, 1)
        goal_pred = mask * self.avg_dist_goal_mask
        path_pred = None
        return goal_pred, path_pred, mask
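# `utils.get_gaussian_blurred_map_mask` is assumed to build a map-sized prior:
# ones within `radius_in_pixels` of the map center, Gaussian-blurred with the
# given kernel. A minimal sketch of that assumption (hypothetical helper, not
# the real util):


def _gaussian_radius_mask_sketch(range_y, range_x, radius_px, kernel_xy):
    import cv2
    import numpy as np
    import torch
    mask = np.zeros((range_y, range_x), dtype=np.float32)
    # Disk of ones around the map center.
    cv2.circle(mask, (range_x // 2, range_y // 2), radius_px, 1.0,
               thickness=-1)
    # Smooth the disk so nearby cells get partial credit.
    mask = cv2.GaussianBlur(mask, kernel_xy, 0)
    return torch.from_numpy(mask)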
class PVNBaselineTrainer(Trainer):
    """ Train the PVN Baseline """

    def __init__(self, args):
        super(PVNBaselineTrainer, self).__init__(args, filepath=None)
        self.mapper = Mapper(self.args).to(self.args.device)
        self.navigator = UnetNavigator(self.args).to(self.args.device)
        self.navigator.init_weights()
        self.goal_criterion = MaskedMSELoss()
        self.path_criterion = MaskedMSELoss()

    def validate(self, split):
        """ split: "val_seen" or "val_unseen" """
        with torch.no_grad():
            if split == "val_seen":
                val_dataloader = self.valseendata
            elif split == "val_unseen":
                val_dataloader = self.valunseendata
            total_loss = torch.tensor(0.0)
            for it in range(self.args.validation_iterations):
                # Load minibatch and simulator
                seq, seq_mask, seq_lens, batch = val_dataloader.get_batch()
                self.sim.newEpisode(batch)
                if self.args.multi_maps:
                    all_spatial_maps = []
                    all_masks = []

                # Run the mapper over the episode's timesteps
                xyzhe = self.sim.getXYZHE()
                spatial_map, mask = self.mapper.init_map(xyzhe)
                for t in range(self.args.timesteps):
                    rgb, depth, states = self.sim.getPanos()
                    spatial_map, mask, ftm = self.mapper(
                        rgb, depth, states, spatial_map, mask)
                    if self.args.multi_maps:
                        all_spatial_maps.append(spatial_map.unsqueeze(1))
                        all_masks.append(mask.unsqueeze(1))
                    if self.args.timesteps != 1:
                        self.sim.takePseudoSupervisedAction(
                            val_dataloader.get_supervision)
                if self.args.multi_maps:
                    spatial_map = torch.cat(all_spatial_maps,
                                            dim=1).flatten(0, 1)
                    mask = torch.cat(all_masks, dim=1).flatten(0, 1)
                    seq = seq.unsqueeze(1).expand(-1, self.args.timesteps,
                                                  -1).flatten(0, 1)
                    seq_lens = seq_lens.unsqueeze(1).expand(
                        -1, self.args.timesteps).flatten(0, 1)

                # Predict with unet
                pred = self.navigator(seq, seq_lens, spatial_map)
                goal_pred = pred[:, 0, :, :]
                path_pred = pred[:, 1, :, :]

                goal_map = self.mapper.heatmap(val_dataloader.goal_coords(),
                                               self.args.goal_heatmap_sigma)
                path_map = self.mapper.heatmap(val_dataloader.path_coords(),
                                               self.args.path_heatmap_sigma)

                # Normalize each target heatmap to a max of 1
                goal_map_max, _ = torch.max(goal_map, dim=2, keepdim=True)
                goal_map_max, _ = torch.max(goal_map_max, dim=1, keepdim=True)
                goal_map /= goal_map_max
                path_map_max, _ = torch.max(path_map, dim=2, keepdim=True)
                path_map_max, _ = torch.max(path_map_max, dim=1, keepdim=True)
                path_map /= path_map_max

                if self.args.belief_downsample_factor > 1:
                    goal_map = nn.functional.avg_pool2d(
                        goal_map,
                        kernel_size=self.args.belief_downsample_factor,
                        stride=self.args.belief_downsample_factor)
                    path_map = nn.functional.avg_pool2d(
                        path_map,
                        kernel_size=self.args.belief_downsample_factor,
                        stride=self.args.belief_downsample_factor)
                    mask = nn.functional.max_pool2d(
                        mask,
                        kernel_size=self.args.belief_downsample_factor,
                        stride=self.args.belief_downsample_factor)

                if self.args.multi_maps:
                    goal_map = goal_map.unsqueeze(1).expand(
                        -1, self.args.timesteps, -1, -1).flatten(0, 1)
                    path_map = path_map.unsqueeze(1).expand(
                        -1, self.args.timesteps, -1, -1).flatten(0, 1)

                loss = self.goal_criterion(goal_pred, goal_map) + \
                    self.path_criterion(path_pred, path_map)
                total_loss += (1.0 / self.args.validation_iterations) * loss
            return total_loss

    def logging_loop(self, it):
        """ Logging and checkpointing stuff """
        if it % self.args.validate_every == 0:
            self.mapper.eval()
            self.navigator.eval()
            loss_val_seen = self.validate("val_seen")
            loss_val_unseen = self.validate("val_unseen")
            self.prev_time, time_taken = utils.time_it(self.prev_time)
            print("Iteration: %d Loss: %f Val Seen Loss: %f "
                  "Val Unseen Loss: %f Time: %0.2f secs" %
                  (it, self.loss.item(), loss_val_seen.item(),
                   loss_val_unseen.item(), time_taken))
            if self.visdom:
                # visdom: X, Y, key, line_name, x_label, y_label, fig_title
                self.visdom.line(it, self.loss.item(), "train_loss",
                                 "Train Loss", "Iterations", "Loss",
                                 title="Train Phase")
                self.visdom.line(it, loss_val_seen.item(), "val_loss",
                                 "Val Seen Loss", "Iterations", "Loss",
                                 title="Val Phase")
                self.visdom.line(it, loss_val_unseen.item(), "val_loss",
                                 "Val Unseen Loss", "Iterations", "Loss",
                                 title="Val Phase")
            self.mapper.train()
            self.navigator.train()
        elif it % self.args.log_every == 0:
            self.prev_time, time_taken = utils.time_it(self.prev_time)
            print("Iteration: %d Loss: %f Time: %0.2f secs" %
                  (it, self.loss.item(), time_taken))
            if self.visdom:
                self.visdom.line(it, self.loss.item(), "train_loss",
                                 "Train Loss", "Iterations", "Loss",
                                 title="Train Phase")
        if it % self.args.checkpoint_every == 0:
            saver = {
                "mapper": self.mapper.state_dict(),
                "navigator": self.navigator.state_dict(),
                "args": self.args
            }
            save_dir = "%s/%s" % (self.args.snapshot_dir, self.args.exp_name)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            torch.save(saver, "%s/%s_%d" % (save_dir, self.args.exp_name, it))
            if self.visdom:
                self.visdom.save()

    def plot_grad_flow(self, named_parameters):
        ave_grads = []
        layers = []
        for n, p in named_parameters:
            if p.requires_grad and ("bias" not in n):
                layers.append(n)
                ave_grads.append(p.grad.abs().mean())
        plt.plot(ave_grads, alpha=0.3, color="b")
        plt.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
        plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
        plt.xlim(left=0, right=len(ave_grads))
        plt.xlabel("Layers")
        plt.ylabel("average gradient")
        plt.title("Gradient flow")
        plt.grid(True)
        plt.savefig("gradient.png")

    def train(self):
        """ Supervised training of the mapper and navigator """
        self.prev_time = time.time()

        map_opt = optim.Adam(self.mapper.parameters(),
                             lr=self.args.learning_rate,
                             weight_decay=self.args.weight_decay)
        nav_opt = optim.Adam(self.navigator.parameters(),
                             lr=self.args.learning_rate,
                             weight_decay=self.args.weight_decay)

        # Set models to train phase
        self.mapper.train()
        self.navigator.train()

        for it in range(1, self.args.max_iterations + 1):
            map_opt.zero_grad()
            nav_opt.zero_grad()

            # Load minibatch and simulator
            seq, seq_mask, seq_lens, batch = self.traindata.get_batch()
            self.sim.newEpisode(batch)
            if self.args.multi_maps:
                all_spatial_maps = []
                all_masks = []

            # Run the mapper for multiple timesteps and keep each map
            xyzhe = self.sim.getXYZHE()
            spatial_map, mask = self.mapper.init_map(xyzhe)
            for t in range(self.args.timesteps):
                rgb, depth, states = self.sim.getPanos()
                spatial_map, mask, ftm = self.mapper(rgb, depth, states,
                                                     spatial_map, mask)
                if self.args.multi_maps:
                    all_spatial_maps.append(spatial_map.unsqueeze(1))
                    all_masks.append(mask.unsqueeze(1))
                if self.args.timesteps != 1:
                    self.sim.takePseudoSupervisedAction(
                        self.traindata.get_supervision)
            if self.args.multi_maps:
                spatial_map = torch.cat(all_spatial_maps, dim=1).flatten(0, 1)
                mask = torch.cat(all_masks, dim=1).flatten(0, 1)
                seq = seq.unsqueeze(1).expand(-1, self.args.timesteps,
                                              -1).flatten(0, 1)
                seq_lens = seq_lens.unsqueeze(1).expand(
                    -1, self.args.timesteps).flatten(0, 1)

            # Predict with unet
            pred = self.navigator(seq, seq_lens, spatial_map)
            goal_pred = pred[:, 0, :, :]
            path_pred = pred[:, 1, :, :]

            goal_map = self.mapper.heatmap(self.traindata.goal_coords(),
                                           self.args.goal_heatmap_sigma)
            path_map = self.mapper.heatmap(self.traindata.path_coords(),
                                           self.args.path_heatmap_sigma)

            # Normalize each target heatmap to a max of 1
            goal_map_max, _ = torch.max(goal_map, dim=2, keepdim=True)
            goal_map_max, _ = torch.max(goal_map_max, dim=1, keepdim=True)
            goal_map /= goal_map_max
            path_map_max, _ = torch.max(path_map, dim=2, keepdim=True)
            path_map_max, _ = torch.max(path_map_max, dim=1, keepdim=True)
            path_map /= path_map_max

            if self.args.belief_downsample_factor > 1:
                goal_map = nn.functional.avg_pool2d(
                    goal_map,
                    kernel_size=self.args.belief_downsample_factor,
                    stride=self.args.belief_downsample_factor)
                path_map = nn.functional.avg_pool2d(
                    path_map,
                    kernel_size=self.args.belief_downsample_factor,
                    stride=self.args.belief_downsample_factor)
                mask = nn.functional.max_pool2d(
                    mask,
                    kernel_size=self.args.belief_downsample_factor,
                    stride=self.args.belief_downsample_factor)

            if self.args.multi_maps:
                goal_map = goal_map.unsqueeze(1).expand(
                    -1, self.args.timesteps, -1, -1).flatten(0, 1)
                path_map = path_map.unsqueeze(1).expand(
                    -1, self.args.timesteps, -1, -1).flatten(0, 1)

            self.loss = self.goal_criterion(goal_pred, goal_map) + \
                self.path_criterion(path_pred, path_map)
            self.loss.backward()
            map_opt.step()
            nav_opt.step()
            self.logging_loop(it)
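# `MaskedMSELoss` (used by the trainer above) is imported from elsewhere in
# the repo; the sketch below is an assumed minimal implementation (MSE
# averaged over valid cells only), named distinctly so it cannot shadow the
# real class.


class _MaskedMSELossSketch(torch.nn.Module):

    def forward(self, pred, target, mask=None):
        sq_err = (pred - target) ** 2
        if mask is None:
            return sq_err.mean()
        # Average the squared error over unmasked cells only.
        return (sq_err * mask).sum() / mask.sum().clamp(min=1)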
class FilterEvaluator(Evaluator):

    def __init__(self, args, filepath):
        super(FilterEvaluator, self).__init__(args, filepath)
        self.filepath = filepath

    def init_models(self):
        self.mapper = Mapper(self.args).to(self.args.device)
        self.model = Filter(self.args).to(self.args.device)

    def load_model_weights(self):
        self.mapper.load_state_dict(self.loader["mapper"])
        self.model.load_state_dict(self.loader["filter"])
        print("Loaded Mapper and Filter from: %s" % self.filepath)

    def set_models_eval(self):
        self.mapper.eval()
        self.model.eval()

    def viz_attention_weights(self, act_att_weights, obs_att_weights, seqs,
                              seq_lens, split, it, save_folder="viz"):
        # act_att_weights: (N, T, self.args.max_steps-1, torch.max(seq_lens))
        # seqs: (N, max_seq_len)
        # seq_lens: (N,)
        batch_size = act_att_weights.shape[0]
        timesteps = act_att_weights.shape[1]
        for n in range(batch_size):
            instruction = self.dataloader.tokenizer.decode_sentence(
                seqs[n]).split()
            for t in range(timesteps):
                act_att = act_att_weights[n][t].cpu().numpy()
                obs_att = obs_att_weights[n][t].cpu().numpy()
                valid_act_att = act_att[:, :seq_lens[n]].transpose()
                valid_obs_att = obs_att[:, :seq_lens[n]].transpose()
                fig = viz_utils.plot_att_graph(valid_act_att, valid_obs_att,
                                               instruction,
                                               seq_lens[n].item())
                if t == 0:
                    # Currently doesn't use the map, so can just save once
                    fig.savefig("%s/attention-%s-it%d-n%d-t%d.png" %
                                (save_folder, split, it, n, t))
                plt.close('all')
        print("Saved Attention viz: %s/attention-%s-it%d-n-t.png" %
              (save_folder, split, it))

    def eval_viz(self, goal_pred, path_pred, mask, split, it):
        if self.args.viz_folder == "":
            save_folder = "tracker/viz/%s_%d" % (self.args.exp_name,
                                                 self.args.val_epoch)
        else:
            save_folder = "%s/%s_%d" % (self.args.viz_folder,
                                        self.args.exp_name,
                                        self.args.val_epoch)
        if not os.path.exists(save_folder):
            os.mkdir(save_folder)

        self.viz_attention_weights(self.act_att_weights,
                                   self.obs_att_weights, self.seqs,
                                   self.seq_lens, split, it, save_folder)

        all_floorplan_images = []
        for im in self.floorplan_images:
            all_floorplan_images.extend(
                [im for t in range(self.args.timesteps)])
        self.floorplan_images = all_floorplan_images  # len: batch_size * timesteps

        self.goal_map = F.interpolate(
            self.goal_map.unsqueeze(1),
            scale_factor=self.args.debug_scale).squeeze(1)
        self.path_map = F.interpolate(
            self.path_map.unsqueeze(1),
            scale_factor=self.args.debug_scale).squeeze(1)
        goal_pred = F.interpolate(
            goal_pred.unsqueeze(1),
            scale_factor=self.args.debug_scale).squeeze(1)
        path_pred = F.interpolate(
            path_pred.unsqueeze(1),
            scale_factor=self.args.debug_scale).squeeze(1)
        mask = F.interpolate(mask.unsqueeze(1),
                             scale_factor=self.args.debug_scale).squeeze(1)

        goal_pred = utils.minmax_normalize(goal_pred)
        path_pred = utils.minmax_normalize(path_pred)

        map_masks = viz_utils.get_masks(self.floorplan_images, mask,
                                        goal_pred)
        target_images = viz_utils.get_floorplan_with_goal_path_maps(
            self.floorplan_images, self.goal_map, self.path_map,
            scale_factor=self.args.debug_scale, target=True)
        predicted_images = viz_utils.get_floorplan_with_goal_path_maps(
            self.floorplan_images, goal_pred, None,
            scale_factor=self.args.debug_scale)

        path_pred = path_pred.reshape(
            self.args.batch_size * self.args.timesteps,
            self.args.max_steps - 1, 3, path_pred.shape[-2],
            path_pred.shape[-1])
        path_pred = (255 * torch.flip(path_pred, [3])).type(
            torch.ByteTensor).cpu().numpy()
        new_belief = path_pred[:, :, 0, :, :]
        obs_likelihood = path_pred[:, :, 1, :, :]
        belief = path_pred[:, :, 2, :, :]

        viz_utils.save_floorplans_with_belief_maps(
            [map_masks, target_images, predicted_images],
            [new_belief, obs_likelihood, belief], self.args.max_steps - 1,
            self.args.batch_size, self.args.timesteps, split, it,
            save_folder=save_folder)

    def get_predictions(self, seq, seq_mask, seq_lens, batch, xyzhe,
                        simulator_next_action):
        all_masks = []
        spatial_map, mask = self.mapper.init_map(xyzhe)
        N = seq.shape[0]
        T = self.args.timesteps
        self.act_att_weights = torch.zeros(N, T, self.args.max_steps - 1,
                                           torch.max(seq_lens))
        self.obs_att_weights = torch.zeros(N, T, self.args.max_steps - 1,
                                           torch.max(seq_lens))
        goal_pred = torch.zeros(N, T, self.args.map_range_y,
                                self.args.map_range_x,
                                device=spatial_map.device)
        path_pred = torch.zeros(N, T, self.args.max_steps - 1, 3,
                                self.args.map_range_y, self.args.map_range_x,
                                device=spatial_map.device)
        if self.args.debug_mode:
            floor_maps = self.floor_maps(self.sim.getState())

        for t in range(self.args.timesteps):
            rgb, depth, states = self.sim.getPanos()
            spatial_map, mask, ftm = self.mapper(rgb, depth, states,
                                                 spatial_map, mask)
            belief = self.mapper.belief_map(
                xyzhe, self.args.filter_input_sigma).log()
            if t == 0:
                belief_pred = torch.zeros(N, T, self.args.max_steps,
                                          self.args.heading_states,
                                          belief.shape[-2], belief.shape[-1],
                                          device=spatial_map.device)
            belief_pred[:, t, 0, :, :, :] = belief.exp()

            state = None
            for k in range(self.args.max_steps - 1):
                input_belief = belief
                new_belief, obs_likelihood, state, act_att_weights, \
                    obs_att_weights, _ = self.model(k, seq, seq_mask,
                                                    seq_lens, belief,
                                                    spatial_map, state)
                belief = new_belief + obs_likelihood
                self.act_att_weights[:, t, k, :] = act_att_weights
                self.obs_att_weights[:, t, k, :] = obs_att_weights

                # Renormalize
                belief = belief - belief.reshape(
                    belief.shape[0], -1).logsumexp(dim=1).unsqueeze(
                        1).unsqueeze(1).unsqueeze(1)

                path_pred[:, t, k, 0, :, :] = F.interpolate(
                    new_belief.exp().sum(dim=1).unsqueeze(1),
                    scale_factor=self.args.belief_downsample_factor).squeeze(1)
                path_pred[:, t, k, 1, :, :] = F.interpolate(
                    obs_likelihood.exp().sum(dim=1).unsqueeze(1),
                    scale_factor=self.args.belief_downsample_factor).squeeze(1)
                path_pred[:, t, k, 2, :, :] = F.interpolate(
                    belief.exp().sum(dim=1).unsqueeze(1),
                    scale_factor=self.args.belief_downsample_factor).squeeze(1)
                belief_pred[:, t, k + 1, :, :, :] = belief.exp()

            if self.args.debug_mode:
                base_maps = self.overlay_mask(floor_maps, mask)
                belief_maps = self.overlay_belief(base_maps, belief_pred[:, t])
                cv2.imshow('belief', belief_maps[0])
                cv2.waitKey(0)

            goal_pred[:, t, :, :] = F.interpolate(
                belief.exp().sum(dim=1).unsqueeze(1),
                scale_factor=self.args.belief_downsample_factor).squeeze(1)
            all_masks.append(mask)
            simulator_next_action()

        mask = torch.cat(all_masks, dim=1).flatten(0, 1)
        self.seq_lens = seq_lens
        self.seqs = seq
        return goal_pred.flatten(0, 1), path_pred.flatten(0, 3), mask
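# `utils.minmax_normalize` (used in `eval_viz` above) is assumed to rescale
# each map in the batch to [0, 1] independently. A minimal sketch of that
# assumption (hypothetical helper, not the real util):


def _minmax_normalize_sketch(x, eps=1e-8):
    # x: (N, H, W) -- flatten the spatial dims to take per-sample min/max.
    flat = x.flatten(1)
    lo = flat.min(dim=1).values.view(-1, 1, 1)
    hi = flat.max(dim=1).values.view(-1, 1, 1)
    return (x - lo) / (hi - lo + eps)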
class FilterTrainer(Trainer):
    """ Train a filter """

    def __init__(self, args, filepath=None):
        super(FilterTrainer, self).__init__(args, filepath)

        # load models
        self.mapper = Mapper(self.args).to(self.args.device)
        self.model = Filter(self.args).to(self.args.device)
        self.map_opt = self.optimizer(self.mapper.parameters())
        self.model_opt = self.optimizer(self.model.parameters())
        if filepath:
            loader = torch.load(filepath)
            self.mapper.load_state_dict(loader["mapper"])
            self.model.load_state_dict(loader["filter"])
            self.map_opt.load_state_dict(loader["mapper_optimizer"])
            self.model_opt.load_state_dict(loader["filter_optimizer"])
            print("Loaded Mapper and Filter from: %s" % filepath)
        elif args:
            self.model.init_weights()
        self.criterion = LogBeliefLoss()

    def validate(self, split):
        """ split: "val_seen" or "val_unseen" """
        with torch.no_grad():
            if split == "val_seen":
                val_dataloader = self.valseendata
            elif split == "val_unseen":
                val_dataloader = self.valunseendata
            total_loss = torch.tensor(0.0)
            normalizer = 0
            for it in range(self.args.validation_iterations):
                # Load minibatch and simulator
                seq, seq_mask, seq_lens, batch = val_dataloader.get_batch()
                self.sim.newEpisode(batch)

                # Initialize the mapper
                xyzhe = self.sim.getXYZHE()
                spatial_map, mask = self.mapper.init_map(xyzhe)

                # Note the model is validated on the GT path. This is not the
                # path taken by the Mapper.
                path_xyzhe, path_len = val_dataloader.path_xyzhe()

                for t in range(self.args.timesteps):
                    rgb, depth, states = self.sim.getPanos()
                    spatial_map, mask, ftm = self.mapper(
                        rgb, depth, states, spatial_map, mask)

                    state = None
                    steps = self.args.max_steps - 1
                    belief = self.mapper.belief_map(
                        path_xyzhe[0], self.args.filter_input_sigma).log()
                    for k in range(steps):
                        input_belief = belief
                        # Run the filter
                        new_belief, obs_likelihood, state, _, _, _ = self.model(
                            k, seq, seq_mask, seq_lens, input_belief,
                            spatial_map, state)
                        belief = new_belief + obs_likelihood

                        # Renormalize
                        belief = belief - belief.reshape(
                            belief.shape[0], -1).logsumexp(dim=1).unsqueeze(
                                1).unsqueeze(1).unsqueeze(1)

                        # Determine target and loss
                        target_heatmap = self.mapper.belief_map(
                            path_xyzhe[k + 1], self.args.filter_heatmap_sigma)
                        # To make the loss independent of heading_states, sum
                        # over all heading states (logsumexp) for validation.
                        total_loss += self.criterion(
                            belief.logsumexp(dim=1).unsqueeze(1),
                            target_heatmap.sum(dim=1).unsqueeze(1))
                        normalizer += self.args.batch_size

                    # Take action in the sim
                    self.sim.takePseudoSupervisedAction(
                        val_dataloader.get_supervision)
            return total_loss / normalizer

    def logging_loop(self, it):
        """ Logging and checkpointing stuff """
        if it % self.args.validate_every == 0:
            self.mapper.eval()
            self.model.eval()
            loss_val_seen = self.validate("val_seen")
            loss_val_unseen = self.validate("val_unseen")
            self.prev_time, time_taken = utils.time_it(self.prev_time)
            print("Iteration: %d Loss: %f Val Seen Loss: %f "
                  "Val Unseen Loss: %f Time: %0.2f secs" %
                  (it, self.loss.item(), loss_val_seen.item(),
                   loss_val_unseen.item(), time_taken))
            if self.visdom:
                # visdom: X, Y, key, line_name, x_label, y_label, fig_title
                self.visdom.line(it, self.loss.item(), "train_loss",
                                 "Train Loss", "Iterations", "Loss",
                                 title="Train Phase")
                self.visdom.line(it, loss_val_seen.item(), "val_loss",
                                 "Val Seen Loss", "Iterations", "Loss",
                                 title="Val Phase")
                self.visdom.line(it, loss_val_unseen.item(), "val_loss",
                                 "Val Unseen Loss", "Iterations", "Loss",
                                 title="Val Phase")
            self.mapper.train()
            self.model.train()
        elif it % self.args.log_every == 0:
            self.prev_time, time_taken = utils.time_it(self.prev_time)
            print("Iteration: %d Loss: %f Time: %0.2f secs" %
                  (it, self.loss.item(), time_taken))
            if self.visdom:
                self.visdom.line(it, self.loss.item(), "train_loss",
                                 "Train Loss", "Iterations", "Loss",
                                 title="Train Phase")
        if it % self.args.checkpoint_every == 0:
            saver = {
                "mapper": self.mapper.state_dict(),
                "filter": self.model.state_dict(),
                "mapper_optimizer": self.map_opt.state_dict(),
                "filter_optimizer": self.model_opt.state_dict(),
                "epoch": it,
                "args": self.args,
            }
            save_dir = "%s/%s" % (self.args.snapshot_dir, self.args.exp_name)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            torch.save(saver, "%s/%s_%d" % (save_dir, self.args.exp_name, it))
            if self.visdom:
                self.visdom.save()

    def map_to_image(self):
        """ Show where map features are found in the RGB images """
        traindata = self.traindata

        # Load minibatch and simulator
        seq, seq_mask, seq_lens, batch = traindata.get_batch(False)
        self.sim.newEpisode(batch)

        # Initialize the mapper
        xyzhe = self.sim.getXYZHE()
        spatial_map, mask = self.mapper.init_map(xyzhe)

        for t in range(self.args.timesteps):
            rgb, depth, states = self.sim.getPanos()
            # ims = (torch.flip(rgb, [1])).permute(0,2,3,1).cpu().detach().numpy()
            # for n in range(ims.shape[0]):
            #     cv2.imwrite('im %d.png' % n, ims[n])
            spatial_map, mask, ftm = self.mapper(rgb, depth, states,
                                                 spatial_map, mask)
            im_ix = torch.arange(0, ftm.shape[0],
                                 step=self.args.batch_size).to(ftm.device)
            ims = self.feature_sources(rgb[im_ix], ftm[im_ix], 56, 48)
            for n in range(ims.shape[0]):
                cv2.imwrite('im %d-%d.png' % (t, n), ims[n])

            # Take action in the sim
            self.sim.takePseudoSupervisedAction(traindata.get_supervision)

    def train(self):
        """ Supervised training of the filter """
        torch.autograd.set_detect_anomaly(True)
        self.prev_time = time.time()

        # Set models to train phase
        self.mapper.train()
        self.model.train()

        for it in range(self.args.start_epoch, self.args.max_iterations + 1):
            self.map_opt.zero_grad()
            self.model_opt.zero_grad()
            traindata = self.traindata

            # Load minibatch and simulator
            seq, seq_mask, seq_lens, batch = traindata.get_batch()
            self.sim.newEpisode(batch)
            if self.args.debug_mode:
                debug_path = traindata.get_path()

            # Initialize the mapper
            xyzhe = self.sim.getXYZHE()
            spatial_map, mask = self.mapper.init_map(xyzhe)

            # Note the model is trained on the GT path. This is not the path
            # taken by the Mapper.
            path_xyzhe, path_len = traindata.path_xyzhe()

            loss = 0
            normalizer = 0
            for t in range(self.args.timesteps):
                rgb, depth, states = self.sim.getPanos()
                spatial_map, mask, ftm = self.mapper(rgb, depth, states,
                                                     spatial_map, mask)
                del states
                del depth
                del ftm
                if not self.args.debug_mode:
                    del rgb

                state = None
                steps = self.args.max_steps - 1
                belief = self.mapper.belief_map(
                    path_xyzhe[0], self.args.filter_input_sigma).log()
                for k in range(steps):
                    input_belief = belief
                    # Train the filter
                    new_belief, obs_likelihood, state, _, _, _ = self.model(
                        k, seq, seq_mask, seq_lens, input_belief, spatial_map,
                        state)
                    belief = new_belief + obs_likelihood

                    # Renormalize
                    belief = belief - belief.reshape(
                        belief.shape[0], -1).logsumexp(dim=1).unsqueeze(
                            1).unsqueeze(1).unsqueeze(1)

                    # Determine target and loss
                    target_heatmap = self.mapper.belief_map(
                        path_xyzhe[k + 1], self.args.filter_heatmap_sigma)
                    loss += self.criterion(belief, target_heatmap)
                    normalizer += self.args.batch_size

                    if self.args.debug_mode:
                        input_belief = F.interpolate(
                            input_belief.exp(),
                            scale_factor=self.args.belief_downsample_factor,
                            mode='bilinear',
                            align_corners=False).sum(dim=1).unsqueeze(1)
                        belief_up = F.interpolate(
                            belief.exp(),
                            scale_factor=self.args.belief_downsample_factor,
                            mode='bilinear',
                            align_corners=False).sum(dim=1).unsqueeze(1)
                        target_heatmap_up = F.interpolate(
                            target_heatmap,
                            scale_factor=self.args.belief_downsample_factor,
                            mode='bilinear',
                            align_corners=False).sum(dim=1).unsqueeze(1)
                        debug_map = self.mapper.debug_maps(debug_path)
                        self.visual_debug(t, rgb, mask, debug_map,
                                          input_belief, belief_up,
                                          target_heatmap_up)

                # Truncated backpropagation through time
                trunc = self.args.truncate_after
                if (t % trunc) == (trunc - 1) or t + 1 == self.args.timesteps:
                    self.loss = loss / normalizer
                    self.loss.backward()
                    self.loss.detach_()
                    loss = 0
                    normalizer = 0
                    spatial_map.detach_()
                    if state is not None:
                        state = self.model.detach_state(state)
                        # Recalc to get gradients from later in the decoding.
                        # TODO: keep this part of the graph without re-running
                        # the forward part
                        state = None

                # Take action in the sim
                self.sim.takePseudoSupervisedAction(traindata.get_supervision)

            del spatial_map
            del mask
            self.map_opt.step()
            self.model_opt.step()
            self.logging_loop(it)
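# The filter update above works in log space: beliefs are log-probabilities,
# `new_belief + obs_likelihood` multiplies densities, and subtracting the
# logsumexp renormalizes to a proper distribution. The repeated
# `.unsqueeze(1).unsqueeze(1).unsqueeze(1)` pattern is equivalent to this
# standalone sketch:


def _renormalize_log_belief_sketch(belief):
    # belief: (N, heading_states, H, W) log-probabilities.
    log_z = belief.reshape(belief.shape[0], -1).logsumexp(dim=1)
    return belief - log_z.view(-1, 1, 1, 1)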
def evaluate(self, split):
    self.set_models_eval()
    with torch.no_grad():
        if split == "val_seen":
            self.dataloader = self.valseendata
        elif split == "val_unseen":
            self.dataloader = self.valunseendata

        iterations = int(
            math.ceil(
                len(self.dataloader.data) / float(self.args.batch_size)))
        last_batch_valid_idx = len(
            self.dataloader.data) - (iterations - 1) * self.args.batch_size

        timestep_nav_error = [[] for i in range(self.args.timesteps)]
        timestep_success_rate = [[] for i in range(self.args.timesteps)]
        average_nav_error = []
        average_success_rate = []
        timestep_map_coverage = [[] for i in range(self.args.timesteps)]
        timestep_goal_seen = [[] for i in range(self.args.timesteps)]
        average_map_coverage = []
        average_goal_seen = []

        mapper = Mapper(self.args).to(self.args.device)

        for it in tqdm(range(iterations),
                       desc="Evaluation Progress for %s split" % split):
            valid_batch_len = last_batch_valid_idx \
                if it == iterations - 1 else self.args.batch_size
            seq, seq_mask, seq_lens, batch = self.dataloader.get_batch()
            self.sim.newEpisode(batch)
            self.floorplan_images = utils.get_floorplan_images(
                self.sim.getState(), self.floorplan, self.args.map_range_x,
                self.args.map_range_y, scale_factor=self.args.debug_scale)

            xyzhe = self.sim.getXYZHE()
            mapper.init_map(xyzhe)
            goal_pos = self.dataloader.get_goal_coords_on_map_grid(
                mapper.map_center)

            pred_goal_map, pred_path_map, mask = self.get_predictions(
                seq, seq_mask, seq_lens, batch, xyzhe,
                self.simulator_next_action)

            # shape: (batch_size, map_range_y, map_range_x)
            self.goal_map = mapper.heatmap(self.dataloader.goal_coords(),
                                           self.args.goal_heatmap_sigma)
            self.path_map = mapper.heatmap(self.dataloader.path_coords(),
                                           self.args.path_heatmap_sigma)
            if self.args.multi_maps:
                # shape: (batch_size*timesteps, map_range_y, map_range_x)
                self.goal_map = self.goal_map.unsqueeze(1).expand(
                    -1, self.args.timesteps, -1, -1).flatten(0, 1)
                self.path_map = self.path_map.unsqueeze(1).expand(
                    -1, self.args.timesteps, -1, -1).flatten(0, 1)

            # shape: (batch_size*timesteps, 2)
            goal_pred_argmax = utils.compute_argmax(pred_goal_map).flip(1)
            # shape: (batch_size, timesteps, 2)
            goal_pred_argmax = goal_pred_argmax.reshape(
                -1, self.args.timesteps, 2)
            goal_pred_xy = self.dataloader.convert_map_pixels_to_xy_coords(
                goal_pred_argmax, mapper.map_center,
                multi_timestep_input=True)

            # shape: (batch_size, 2)
            goal_target_xy = self.dataloader.goal_coords()
            # shape: (batch_size, timesteps, 2)
            goal_target_xy = goal_target_xy.unsqueeze(1).expand(
                -1, self.args.timesteps, -1)

            # shape: (batch_size, timesteps, map_range_y, map_range_x)
            b_t_mask = mask.reshape(-1, self.args.timesteps,
                                    self.args.map_range_y,
                                    self.args.map_range_x)

            batch_timestep_map_coverage, batch_average_map_coverage = \
                metrics.map_coverage(b_t_mask)
            batch_timestep_goal_seen, batch_average_goal_seen = \
                metrics.goal_seen_rate(b_t_mask, goal_pos, self.args)
            batch_timestep_nav_error, batch_average_nav_error = \
                metrics.nav_error(goal_target_xy, goal_pred_xy, self.args)
            batch_timestep_success_rate, batch_average_success_rate = \
                metrics.success_rate(goal_target_xy, goal_pred_xy, self.args)

            for n in range(valid_batch_len):
                average_nav_error.append(batch_average_nav_error[n].item())
                average_success_rate.append(
                    batch_average_success_rate[n].item())
                average_map_coverage.append(
                    batch_average_map_coverage[n].item())
                average_goal_seen.append(batch_average_goal_seen[n].item())
                for t in range(batch_timestep_map_coverage.shape[1]):
                    timestep_nav_error[t].append(
                        batch_timestep_nav_error[n][t].item())
                    timestep_success_rate[t].append(
                        batch_timestep_success_rate[n][t].item())
                    timestep_map_coverage[t].append(
                        batch_timestep_map_coverage[n][t].item())
                    timestep_goal_seen[t].append(
                        batch_timestep_goal_seen[n][t].item())

    self.eval_logging(split, timestep_nav_error, average_nav_error,
                      timestep_success_rate, average_success_rate,
                      timestep_map_coverage, average_map_coverage,
                      timestep_goal_seen, average_goal_seen)
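# `utils.compute_argmax` (used above) is assumed to return per-sample
# (row, col) indices of the maximum cell of each 2-D map; `.flip(1)` then
# converts them to (x, y) order. Sketch of that assumption:


def _compute_argmax_sketch(maps):
    # maps: (N, H, W) -> (N, 2) as (row, col) indices of the max cell.
    flat_idx = maps.flatten(1).argmax(dim=1)
    rows = torch.div(flat_idx, maps.shape[-1], rounding_mode='floor')
    cols = flat_idx % maps.shape[-1]
    return torch.stack([rows, cols], dim=1)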
class PolicyTrainer(Trainer):
    """ Train a filter and a policy """

    def __init__(self, args, filepath=None):
        super(PolicyTrainer, self).__init__(args, filepath, load_sim=False)
        self.sim = PanoSimulatorWithGraph(self.args)

        # load models
        self.mapper = Mapper(self.args).to(self.args.device)
        self.model = Filter(self.args).to(self.args.device)
        self.policy = ReactivePolicy(self.args).to(self.args.device)
        if filepath:
            loader = torch.load(filepath)
            self.mapper.load_state_dict(loader["mapper"])
            self.model.load_state_dict(loader["filter"])
            self.policy.load_state_dict(loader["policy"])
            print("Loaded Mapper, Filter and Policy from: %s" % filepath)
        elif args:
            self.model.init_weights()
            self.policy.init_weights()
        self.belief_criterion = LogBeliefLoss()
        self.policy_criterion = torch.nn.NLLLoss(
            ignore_index=self.args.action_ignore_index, reduction='none')

    def validate(self, split):
        """ split: "val_seen" or "val_unseen" """
        vln_eval = Evaluation(split)
        self.sim.record_traj(True)
        with torch.no_grad():
            if split == "val_seen":
                val_dataloader = self.valseendata
            elif split == "val_unseen":
                val_dataloader = self.valunseendata
            elif split == "train":
                val_dataloader = self.traindata
            for it in range(self.args.validation_iterations):
                # Load minibatch and simulator
                seq, seq_mask, seq_lens, batch = val_dataloader.get_batch()
                self.sim.newEpisode(batch)

                # Initialize the mapper
                xyzhe = self.sim.getXYZHE()
                spatial_map, mask = self.mapper.init_map(xyzhe)

                ended = torch.zeros(self.args.batch_size,
                                    device=self.args.device).byte()
                for t in range(self.args.timesteps):
                    if self.args.policy_gt_belief:
                        path_xyzhe, path_len = val_dataloader.path_xyzhe()
                    else:
                        rgb, depth, states = self.sim.getPanos()
                        spatial_map, mask, ftm = self.mapper(
                            rgb, depth, states, spatial_map, mask)
                        del states
                        del depth
                        del ftm
                        del rgb

                    features, _, _ = self.sim.getGraphNodes()
                    P = features.shape[0]  # num graph nodes
                    belief_features = torch.empty(P, self.args.max_steps,
                                                  device=self.args.device)

                    state = None
                    steps = self.args.max_steps - 1
                    belief = self.mapper.belief_map(
                        xyzhe, self.args.filter_input_sigma).log()
                    sigma = self.args.filter_heatmap_sigma
                    gridcellsize = self.args.gridcellsize * \
                        self.args.belief_downsample_factor
                    belief_features[:, 0] = belief_at_nodes(
                        belief.exp(), features[:, :4], sigma, gridcellsize)
                    for k in range(steps):
                        if self.args.policy_gt_belief:
                            target_heatmap = self.mapper.belief_map(
                                path_xyzhe[k + 1],
                                self.args.filter_heatmap_sigma)
                            belief_features[:, k + 1] = belief_at_nodes(
                                target_heatmap, features[:, :4], sigma,
                                gridcellsize)
                        else:
                            input_belief = belief
                            # Run the filter
                            new_belief, obs_likelihood, state, _, _, _ = \
                                self.model(k, seq, seq_mask, seq_lens,
                                           input_belief, spatial_map, state)
                            belief = new_belief + obs_likelihood

                            # Renormalize
                            belief = belief - belief.reshape(
                                belief.shape[0], -1).logsumexp(
                                    dim=1).unsqueeze(1).unsqueeze(
                                        1).unsqueeze(1)
                            belief_features[:, k + 1] = belief_at_nodes(
                                belief.exp(), features[:, :4], sigma,
                                gridcellsize)

                    # Probs from policy
                    aug_features = torch.cat([features, belief_features],
                                             dim=1)
                    log_prob = self.policy(aug_features)

                    # Take argmax action in the sim
                    _, action_idx = log_prob.exp().max(dim=1)
                    ended |= self.sim.takeMultiStepAction(action_idx)
                    if ended.all():
                        break

        # Eval
        scores = vln_eval.score(self.sim.traj, check_all_trajs=False)
        self.sim.record_traj(False)
        return scores['result'][0][split]

    def logging_loop(self, it):
        """ Logging and checkpointing stuff """
        if it % self.args.validate_every == 0:
            self.mapper.eval()
            self.model.eval()
            self.policy.eval()
            scores_train = self.validate("train")
            scores_val_seen = self.validate("val_seen")
            scores_val_unseen = self.validate("val_unseen")
            self.prev_time, time_taken = utils.time_it(self.prev_time)
            print("Iteration: %d Loss: %f Train Success: %f "
                  "Val Seen Success: %f Val Unseen Success: %f "
                  "Time: %0.2f secs" %
                  (it, self.loss.item(), scores_train['success'],
                   scores_val_seen['success'], scores_val_unseen['success'],
                   time_taken))
            if self.visdom:
                # visdom: X, Y, key, line_name, x_label, y_label, fig_title
                self.visdom.line(it, self.loss.item(), "train_loss",
                                 "Train Loss", "Iterations", "Loss",
                                 title="Train Phase")
                units = {
                    'length': 'm',
                    'error': 'm',
                    'oracle success': '%',
                    'success': '%',
                    'spl': '%'
                }
                sub = self.args.validation_iterations * self.args.batch_size
                for metric, score in scores_train.items():
                    m = metric.title()
                    self.visdom.line(it, score, metric, "Train (%d)" % sub,
                                     "Iterations", units[metric], title=m)
                for metric, score in scores_val_seen.items():
                    m = metric.title()
                    self.visdom.line(it, score, metric, "Val Seen (%d)" % sub,
                                     "Iterations", units[metric], title=m)
                for metric, score in scores_val_unseen.items():
                    m = metric.title()
                    self.visdom.line(it, score, metric,
                                     "Val Unseen (%d)" % sub, "Iterations",
                                     units[metric], title=m)
            self.mapper.train()
            self.model.train()
            self.policy.train()
        elif it % self.args.log_every == 0:
            self.prev_time, time_taken = utils.time_it(self.prev_time)
            print("Iteration: %d Loss: %f Time: %0.2f secs" %
                  (it, self.loss.item(), time_taken))
            if self.visdom:
                self.visdom.line(it, self.loss.item(), "train_loss",
                                 "Train Loss", "Iterations", "Loss",
                                 title="Train Phase")
        if it % self.args.checkpoint_every == 0:
            saver = {
                "mapper": self.mapper.state_dict(),
                "args": self.args,
                "filter": self.model.state_dict(),
                "policy": self.policy.state_dict()
            }
            save_dir = "%s/%s" % (self.args.snapshot_dir, self.args.exp_name)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            torch.save(saver, "%s/%s_%d" % (save_dir, self.args.exp_name, it))
            if self.visdom:
                self.visdom.save()

    def train(self):
        """ Supervised training of the mapper, filter and policy """
        torch.autograd.set_detect_anomaly(True)
        self.prev_time = time.time()

        map_opt = self.optimizer(self.mapper.parameters())
        model_opt = self.optimizer(self.model.parameters())
        policy_opt = self.optimizer(self.policy.parameters())

        # Set models to train phase
        self.mapper.train()
        self.model.train()
        self.policy.train()

        for it in range(1, self.args.max_iterations + 1):
            map_opt.zero_grad()
            model_opt.zero_grad()
            policy_opt.zero_grad()
            traindata = self.traindata

            # Load minibatch and simulator
            seq, seq_mask, seq_lens, batch = traindata.get_batch()
            self.sim.newEpisode(batch)

            # Initialize the mapper
            xyzhe = self.sim.getXYZHE()
            spatial_map, mask = self.mapper.init_map(xyzhe)

            # Note the Filter model is trained to predict the GT path. This
            # is not the path taken by the Policy.
            path_xyzhe, path_len = traindata.path_xyzhe()

            self.loss = 0
            ended = torch.zeros(self.args.batch_size,
                                device=self.args.device).byte()
            for t in range(self.args.timesteps):
                if not self.args.policy_gt_belief:
                    rgb, depth, states = self.sim.getPanos()
                    spatial_map, mask, ftm = self.mapper(
                        rgb, depth, states, spatial_map, mask)
                    del states
                    del depth
                    del ftm
                    del rgb

                features, _, _ = self.sim.getGraphNodes()
                P = features.shape[0]  # num graph nodes
                belief_features = torch.empty(P, self.args.max_steps,
                                              device=self.args.device)

                state = None
                steps = self.args.max_steps - 1
                gt_input_belief = None
                belief = self.mapper.belief_map(
                    path_xyzhe[0], self.args.filter_input_sigma).log()
                sigma = self.args.filter_heatmap_sigma
                gridcellsize = self.args.gridcellsize * \
                    self.args.belief_downsample_factor
                belief_features[:, 0] = belief_at_nodes(
                    belief.exp(), features[:, :4], sigma, gridcellsize)
                for k in range(steps):
                    target_heatmap = self.mapper.belief_map(
                        path_xyzhe[k + 1], self.args.filter_heatmap_sigma)
                    if self.args.policy_gt_belief:
                        belief_features[:, k + 1] = belief_at_nodes(
                            target_heatmap, features[:, :4], sigma,
                            gridcellsize)
                    else:
                        input_belief = belief
                        # Train the filter
                        if self.args.teacher_force_motion_model:
                            gt_input_belief = self.mapper.belief_map(
                                path_xyzhe[k],
                                self.args.filter_heatmap_sigma).log()
                            new_belief, obs_likelihood, state, _, _, \
                                new_gt_belief = self.model(
                                    k, seq, seq_mask, seq_lens, input_belief,
                                    spatial_map, state, gt_input_belief)
                            # Don't backprop through the motion model's belief
                            belief = new_belief.detach() + obs_likelihood
                        else:
                            new_belief, obs_likelihood, state, _, _, \
                                new_gt_belief = self.model(
                                    k, seq, seq_mask, seq_lens, input_belief,
                                    spatial_map, state, gt_input_belief)
                            belief = new_belief + obs_likelihood

                        # Renormalize
                        belief = belief - belief.reshape(
                            belief.shape[0], -1).logsumexp(dim=1).unsqueeze(
                                1).unsqueeze(1).unsqueeze(1)

                        # Determine target and loss for the filter
                        belief_loss = self.belief_criterion(
                            belief, target_heatmap,
                            valid=~ended) / (~ended).sum()
                        if self.args.teacher_force_motion_model:
                            # Separate loss for the motion model
                            belief_loss += self.belief_criterion(
                                new_gt_belief, target_heatmap,
                                valid=~ended) / (~ended).sum()
                        self.loss += belief_loss
                        belief_features[:, k + 1] = belief_at_nodes(
                            belief.exp(), features[:, :4], sigma,
                            gridcellsize)

                # Probs from policy
                aug_features = torch.cat([features, belief_features], dim=1)
                log_prob = self.policy(aug_features)

                # Train the policy
                target_idx = traindata.closest_to_goal(self.sim.G)
                policy_loss = self.args.policy_loss_lambda * \
                    self.policy_criterion(log_prob,
                                          target_idx).sum() / (~ended).sum()
                self.loss += policy_loss

                # Take action in the sim (scheduled sampling between the
                # supervised action and the policy's sampled action)
                if self.args.supervision_prob < 0:
                    supervision_prob = 1.0 - \
                        float(it) / self.args.max_iterations
                else:
                    supervision_prob = self.args.supervision_prob
                sampled_action = D.Categorical(log_prob.exp()).sample()
                weights = torch.tensor(
                    [1.0 - supervision_prob, supervision_prob],
                    dtype=torch.float, device=log_prob.device)
                ix = torch.multinomial(weights, self.args.batch_size,
                                       replacement=True).byte()
                action_idx = torch.where(ix, target_idx, sampled_action)
                ended |= self.sim.takeMultiStepAction(action_idx)

                # Truncated backpropagation through time
                trunc = self.args.truncate_after
                if (t % trunc) == (trunc - 1) or \
                        t + 1 == self.args.timesteps or ended.all():
                    self.loss.backward()
                    self.loss.detach_()
                    spatial_map.detach_()
                    if state is not None:
                        state = self.model.detach_state(state)
                        # Recalc to get gradients from later in the decoding.
                        # TODO: keep this part of the graph without re-running
                        # the forward part
                        state = None
                if ended.all():
                    break

            del spatial_map
            del mask
            map_opt.step()
            model_opt.step()
            policy_opt.step()
            self.logging_loop(it)
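# The action selection in `PolicyTrainer.train` is scheduled sampling: with
# probability `supervision_prob` the supervised action (`target_idx`) is
# taken, otherwise an action sampled from the policy. The same mixing step,
# isolated as a sketch:


def _mix_actions_sketch(log_prob, target_idx, supervision_prob):
    sampled = torch.distributions.Categorical(log_prob.exp()).sample()
    # Per-sample Bernoulli draw: True -> follow supervision, False -> policy.
    take_gt = torch.bernoulli(
        torch.full_like(sampled, supervision_prob,
                        dtype=torch.float)).bool()
    return torch.where(take_gt, target_idx, sampled)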
class PolicyEvaluator(Trainer):

    def __init__(self, args, filepath):
        super(PolicyEvaluator, self).__init__(args, filepath, load_sim=False)

        # load models
        self.mapper = Mapper(self.args).to(self.args.device)
        self.model = Filter(self.args).to(self.args.device)
        self.policy = ReactivePolicy(self.args).to(self.args.device)
        loader = torch.load(filepath)
        self.mapper.load_state_dict(loader["mapper"])
        self.model.load_state_dict(loader["filter"])
        self.policy.load_state_dict(loader["policy"])
        print("Loaded Mapper, Filter and Policy from: %s" % filepath)

        if self.args.viz_folder == "":
            self.args.viz_folder = "tracker/viz/%s_%d" % (
                self.args.exp_name, self.args.val_epoch)
        else:
            self.args.viz_folder = "%s/%s_%d" % (self.args.viz_folder,
                                                 self.args.exp_name,
                                                 self.args.val_epoch)
        if not os.path.exists(self.args.viz_folder) and self.args.viz_eval:
            os.makedirs(self.args.viz_folder)

    def evaluate(self, split):
        sim = PanoSimulatorWithGraph(self.args, disable_rendering=True)
        sim.record_traj(True)
        self.mapper.eval()
        self.model.eval()
        self.policy.eval()
        vln_eval = Evaluation(split)
        with torch.no_grad():
            if split == "val_seen":
                self.dataloader = self.valseendata
            elif split == "val_unseen":
                self.dataloader = self.valunseendata
            else:
                self.dataloader = DataLoader(self.args, splits=[split])
            iterations = int(
                math.ceil(
                    len(self.dataloader.data) / float(self.args.batch_size)))
            for it in tqdm(range(iterations),
                           desc="Evaluation Progress for %s split" % split):
                # Load minibatch and simulator
                seq, seq_mask, seq_lens, batch = self.dataloader.get_batch()
                sim.newEpisode(batch)

                # Initialize the mapper
                xyzhe = sim.getXYZHE()
                spatial_map, mask = self.mapper.init_map(xyzhe)
                if self.args.viz_eval and it < self.args.viz_iterations:
                    floor_maps = self.floor_maps(sim.getState())

                ended = torch.zeros(self.args.batch_size,
                                    device=self.args.device).byte()
                viz_counter = 0
                for t in range(self.args.timesteps):
                    if self.args.policy_gt_belief:
                        path_xyzhe, path_len = self.dataloader.path_xyzhe()
                    else:
                        rgb, depth, states = sim.getPanos()
                        spatial_map, mask, ftm = self.mapper(
                            rgb, depth, states, spatial_map, mask)
                        del states
                        del ftm

                    features, _, _ = sim.getGraphNodes()
                    P = features.shape[0]  # num graph nodes
                    belief_features = torch.empty(P, self.args.max_steps,
                                                  device=self.args.device)

                    state = None
                    steps = self.args.max_steps - 1
                    belief = self.mapper.belief_map(
                        xyzhe, self.args.filter_input_sigma).log()
                    if t == 0:
                        belief_pred = torch.zeros(seq.shape[0],
                                                  self.args.max_steps,
                                                  self.args.heading_states,
                                                  belief.shape[-2],
                                                  belief.shape[-1],
                                                  device=spatial_map.device)
                    belief_pred[:, 0, :, :, :] = belief.exp()
                    sigma = self.args.filter_heatmap_sigma
                    gridcellsize = self.args.gridcellsize * \
                        self.args.belief_downsample_factor
                    belief_features[:, 0] = belief_at_nodes(
                        belief.exp(), features[:, :4], sigma, gridcellsize)
                    act_att_weights = torch.zeros(seq.shape[0], steps,
                                                  torch.max(seq_lens))
                    obs_att_weights = torch.zeros(seq.shape[0], steps,
                                                  torch.max(seq_lens))
                    for k in range(steps):
                        if self.args.policy_gt_belief:
                            target_heatmap = self.mapper.belief_map(
                                path_xyzhe[k + 1],
                                self.args.filter_heatmap_sigma)
                            belief_features[:, k + 1] = belief_at_nodes(
                                target_heatmap, features[:, :4], sigma,
                                gridcellsize)
                            belief_pred[:, k + 1, :, :, :] = target_heatmap
                        else:
                            input_belief = belief
                            # Run the filter
                            new_belief, obs_likelihood, state, act_att, \
                                obs_att, _ = self.model(
                                    k, seq, seq_mask, seq_lens, input_belief,
                                    spatial_map, state)
                            belief = new_belief + obs_likelihood

                            # Renormalize
                            belief = belief - belief.reshape(
                                belief.shape[0], -1).logsumexp(
                                    dim=1).unsqueeze(1).unsqueeze(
                                        1).unsqueeze(1)
                            belief_pred[:, k + 1, :, :, :] = belief.exp()
                            belief_features[:, k + 1] = belief_at_nodes(
                                belief.exp(), features[:, :4], sigma,
                                gridcellsize)
                            act_att_weights[:, k] = act_att
                            obs_att_weights[:, k] = obs_att

                    # Probs from policy
                    aug_features = torch.cat([features, belief_features],
                                             dim=1)
                    log_prob = self.policy(aug_features)

                    # Take argmax action in the sim
                    _, action_idx = log_prob.exp().max(dim=1)

                    if self.args.viz_eval and it < self.args.viz_iterations:
                        num_cam_views = self.args.num_pano_views * \
                            self.args.num_pano_sweeps
                        rgb = rgb.permute(0, 2, 3, 1).reshape(
                            num_cam_views, self.args.batch_size,
                            rgb.shape[-2], rgb.shape[-1], 3)
                        depth = depth.expand(-1, 3, -1, -1).permute(
                            0, 2, 3, 1).reshape(num_cam_views,
                                                self.args.batch_size,
                                                depth.shape[-2],
                                                depth.shape[-1], 3)

                        # Save attention over instruction
                        if t == 0:
                            att_ims = []
                            belief_ims = [[]
                                          for n in range(self.args.batch_size)]
                            for n in range(seq.shape[0]):
                                instruction = \
                                    self.dataloader.tokenizer.decode_sentence(
                                        seq[n]).split()
                                act_att = act_att_weights[n].cpu().numpy()
                                obs_att = obs_att_weights[n].cpu().numpy()
                                valid_act_att = \
                                    act_att[:, :seq_lens[n]].transpose()
                                valid_obs_att = \
                                    obs_att[:, :seq_lens[n]].transpose()
                                fig = viz_utils.plot_att_graph(
                                    valid_act_att, valid_obs_att,
                                    instruction, seq_lens[n].item(),
                                    black_background=True)
                                if self.args.viz_gif:
                                    att_ims.append(
                                        viz_utils.figure_to_rgb(fig))
                                else:
                                    fig.savefig(
                                        "%s/attention-%s-it%d-n%d.png" %
                                        (self.args.viz_folder, split, it, n),
                                        facecolor=fig.get_facecolor(),
                                        transparent=True)
                                plt.close('all')

                        for k in range(3):
                            if self.args.policy_gt_belief:
                                viz = floor_maps
                            else:
                                viz = self.overlay_mask(floor_maps, mask)
                            if k >= 1:
                                viz = self.overlay_belief(viz, belief_pred)
                                viz = self.overlay_goal(
                                    viz, self.dataloader.goal_coords() +
                                    self.mapper.map_center[:, :2])
                            if k >= 2:
                                viz = self.overlay_local_graph(
                                    viz, features, action_idx)
                            else:
                                viz = self.overlay_local_graph(viz, features)
                            for n in range(len(viz)):
                                if not ended[n]:
                                    if self.args.viz_gif:
                                        image = viz[n] * 255
                                        # Add attention image on left
                                        min_val = image.shape[0] // 2 - \
                                            att_ims[n].shape[0] // 2
                                        max_val = min_val + \
                                            att_ims[n].shape[0]
                                        image = np.flip(
                                            image[min_val:max_val,
                                                  min_val:max_val, :], 2)
                                        image = np.concatenate(
                                            [att_ims[n], image], axis=1)
                                        # Add rgb images at bottom
                                        new_width = int(
                                            image.shape[-2] /
                                            float(rgb.shape[0]))
                                        new_height = int(
                                            new_width * rgb.shape[-3] /
                                            float(rgb.shape[-2]))
                                        rgb_ims = [
                                            cv2.resize(
                                                rgb[i, n].cpu().detach(
                                                ).numpy(),
                                                (new_width, new_height))
                                            for i in range(rgb.shape[0])
                                        ]
                                        rgb_ims = np.concatenate(rgb_ims,
                                                                 axis=1)
                                        # Add depth images at bottom
                                        depth_ims = [
                                            cv2.resize(
                                                depth[i, n].cpu().detach(
                                                ).numpy() / 200.0,
                                                (new_width, new_height))
                                            for i in range(depth.shape[0])
                                        ]
                                        depth_ims = np.concatenate(depth_ims,
                                                                   axis=1)
                                        image = np.concatenate(
                                            [image, rgb_ims, depth_ims],
                                            axis=0)
                                        belief_ims[n].append(
                                            image.astype(np.uint8))
                                    else:
                                        filename = \
                                            '%s/belief-%s-it%d-n%d-t%d_%d.png' % (
                                                self.args.viz_folder, split,
                                                it, n, t, viz_counter)
                                        cv2.imwrite(filename, viz[n] * 255)
                            viz_counter += 1

                    ended |= sim.takeMultiStepAction(action_idx)
                    if ended.all():
                        break

                if self.args.viz_gif and self.args.viz_eval and \
                        it < self.args.viz_iterations:
                    import imageio
                    for n in range(self.args.batch_size):
                        filename = '%s/%s-%s-it%d-n%d' % (
                            self.args.viz_folder, split,
                            batch[n]['instr_id'], it, n)
                        if not os.path.exists(filename):
                            os.makedirs(filename)
                        with imageio.get_writer(filename + '.gif', mode='I',
                                                format='GIF-PIL',
                                                subrectangles=True,
                                                fps=1) as writer:
                            for i, image in enumerate(belief_ims[n]):
                                writer.append_data(image)
                                cv2.imwrite("%s/%04d.png" % (filename, i),
                                            np.flip(image, 2))

        # Eval
        out_dir = "%s/%s" % (self.args.result_dir, self.args.exp_name)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        output_file = "%s/%s_%s_%d.json" % (out_dir, split,
                                            self.args.exp_name,
                                            self.args.val_epoch)
        with open(output_file, 'w') as f:
            json.dump(sim.traj, f)
        scores = vln_eval.score(output_file)
        print(scores)
        with open(output_file.replace('.json', '_scores.json'), 'w') as f:
            json.dump(scores, f)
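# Usage sketch for the policy evaluator (assumptions: args carry
# snapshot_dir, exp_name and val_epoch as used above, and checkpoints follow
# the "%s/%s_%d" naming used by the trainers in this file):
#
#     ckpt = "%s/%s/%s_%d" % (args.snapshot_dir, args.exp_name,
#                             args.exp_name, args.val_epoch)
#     PolicyEvaluator(args, ckpt).evaluate("val_unseen")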