Example #1
class FilterTrainer(Trainer):
    """ Train a filter """
    def __init__(self, args, filepath=None):
        super(FilterTrainer, self).__init__(args, filepath)

        # load models
        self.mapper = Mapper(self.args).to(self.args.device)
        self.model = Filter(self.args).to(self.args.device)

        self.map_opt = self.optimizer(self.mapper.parameters())
        self.model_opt = self.optimizer(self.model.parameters())

        if filepath:
            loader = torch.load(filepath)
            self.mapper.load_state_dict(loader["mapper"])
            self.model.load_state_dict(loader["filter"])
            self.map_opt.load_state_dict(loader["mapper_optimizer"])
            self.model_opt.load_state_dict(loader["filter_optimizer"])
            print("Loaded Mapper and Filter from: %s" % filepath)
        elif args:
            self.model.init_weights()
        self.criterion = LogBeliefLoss()

    def validate(self, split):
        """
        split: "val_seen" or "val_unseen"
        """
        with torch.no_grad():

            if (split == "val_seen"):
                val_dataloader = self.valseendata
            elif (split == "val_unseen"):
                val_dataloader = self.valunseendata

            total_loss = torch.tensor(0.0)
            normalizer = 0

            for it in range(self.args.validation_iterations):

                # Load minibatch and simulator
                seq, seq_mask, seq_lens, batch = val_dataloader.get_batch()
                self.sim.newEpisode(batch)

                # Initialize the mapper
                xyzhe = self.sim.getXYZHE()
                spatial_map, mask = self.mapper.init_map(xyzhe)

                # Note: the model is being validated on the GT path, not the path taken by the Mapper.
                path_xyzhe, path_len = val_dataloader.path_xyzhe()

                for t in range(self.args.timesteps):
                    rgb, depth, states = self.sim.getPanos()
                    spatial_map, mask, ftm = self.mapper(
                        rgb, depth, states, spatial_map, mask)

                    state = None
                    steps = self.args.max_steps - 1

                    belief = self.mapper.belief_map(
                        path_xyzhe[0], self.args.filter_input_sigma).log()
                    for k in range(steps):
                        input_belief = belief
                        # Run the filter
                        new_belief, obs_likelihood, state, _, _, _ = self.model(
                            k, seq, seq_mask, seq_lens, input_belief,
                            spatial_map, state)
                        belief = new_belief + obs_likelihood
                        # Renormalize in log space by subtracting the per-example logsumexp
                        belief = belief - belief.reshape(
                            belief.shape[0], -1).logsumexp(
                                dim=1).unsqueeze(1).unsqueeze(1).unsqueeze(1)

                        # Determine target and loss
                        target_heatmap = self.mapper.belief_map(
                            path_xyzhe[k + 1], self.args.filter_heatmap_sigma)
                        # To make the loss independent of heading, sum over all heading states (logsumexp) for validation
                        total_loss += self.criterion(
                            belief.logsumexp(dim=1).unsqueeze(1),
                            target_heatmap.sum(dim=1).unsqueeze(1))
                        normalizer += self.args.batch_size

                    # Take action in the sim
                    self.sim.takePseudoSupervisedAction(
                        val_dataloader.get_supervision)

        return total_loss / normalizer

    def logging_loop(self, it):
        """ Logging and checkpointing stuff """

        if it % self.args.validate_every == 0:
            self.mapper.eval()
            self.model.eval()
            loss_val_seen = self.validate("val_seen")
            loss_val_unseen = self.validate("val_unseen")

            self.prev_time, time_taken = utils.time_it(self.prev_time)

            print(
                "Iteration: %d Loss: %f Val Seen Loss: %f Val Unseen Loss: %f Time: %0.2f secs"
                % (it, self.loss.item(), loss_val_seen.item(),
                   loss_val_unseen.item(), time_taken))

            if self.visdom:
                # visdom: X, Y, key, line_name, x_label, y_label, fig_title
                self.visdom.line(it,
                                 self.loss.item(),
                                 "train_loss",
                                 "Train Loss",
                                 "Iterations",
                                 "Loss",
                                 title=" Train Phase")
                self.visdom.line(it,
                                 loss_val_seen.item(),
                                 "val_loss",
                                 "Val Seen Loss",
                                 "Iterations",
                                 "Loss",
                                 title="Val Phase")
                self.visdom.line(it,
                                 loss_val_unseen.item(),
                                 "val_loss",
                                 "Val Unseen Loss",
                                 "Iterations",
                                 "Loss",
                                 title="Val Phase")
            self.mapper.train()
            self.model.train()

        elif it % self.args.log_every == 0:
            self.prev_time, time_taken = utils.time_it(self.prev_time)
            print("Iteration: %d Loss: %f Time: %0.2f secs" %
                  (it, self.loss.item(), time_taken))
            if self.visdom:
                self.visdom.line(it,
                                 self.loss.item(),
                                 "train_loss",
                                 "Train Loss",
                                 "Iterations",
                                 "Loss",
                                 title="Train Phase")

        if it % self.args.checkpoint_every == 0:
            saver = {
                "mapper": self.mapper.state_dict(),
                "filter": self.model.state_dict(),
                "mapper_optimizer": self.map_opt.state_dict(),
                "filter_optimizer": self.model_opt.state_dict(),
                "epoch": it,
                "args": self.args,
            }
            dir = "%s/%s" % (self.args.snapshot_dir, self.args.exp_name)
            if not os.path.exists(dir):
                os.makedirs(dir)
            torch.save(saver, "%s/%s_%d" % (dir, self.args.exp_name, it))
            if self.visdom:
                self.visdom.save()

    def map_to_image(self):
        """ Show where map features are found in the RGB images """
        traindata = self.traindata

        # Load minibatch and simulator
        seq, seq_mask, seq_lens, batch = traindata.get_batch(False)
        self.sim.newEpisode(batch)

        # Initialize the mapper
        xyzhe = self.sim.getXYZHE()
        spatial_map, mask = self.mapper.init_map(xyzhe)

        for t in range(self.args.timesteps):
            rgb, depth, states = self.sim.getPanos()

            # ims = (torch.flip(rgb, [1])).permute(0,2,3,1).cpu().detach().numpy()
            # for n in range(ims.shape[0]):
            #     cv2.imwrite('im %d.png' % n, ims[n])
            spatial_map, mask, ftm = self.mapper(rgb, depth, states,
                                                 spatial_map, mask)

            im_ix = torch.arange(0, ftm.shape[0],
                                 step=self.args.batch_size).to(ftm.device)
            ims = self.feature_sources(rgb[im_ix], ftm[im_ix], 56, 48)
            for n in range(ims.shape[0]):
                cv2.imwrite('im %d-%d.png' % (t, n), ims[n])

            # Take action in the sim
            self.sim.takePseudoSupervisedAction(traindata.get_supervision)

    def train(self):
        """
        Supervised training of the filter
        """

        torch.autograd.set_detect_anomaly(True)
        self.prev_time = time.time()

        # Set models to train phase
        self.mapper.train()
        self.model.train()

        for it in range(self.args.start_epoch, self.args.max_iterations + 1):

            self.map_opt.zero_grad()
            self.model_opt.zero_grad()

            traindata = self.traindata
            # Load minibatch and simulator
            seq, seq_mask, seq_lens, batch = traindata.get_batch()
            self.sim.newEpisode(batch)
            if self.args.debug_mode:
                debug_path = traindata.get_path()

            # Initialize the mapper
            xyzhe = self.sim.getXYZHE()

            spatial_map, mask = self.mapper.init_map(xyzhe)

            # Note: the model is trained on the GT path, not the path taken by the Mapper.
            path_xyzhe, path_len = traindata.path_xyzhe()

            loss = 0
            normalizer = 0

            for t in range(self.args.timesteps):
                rgb, depth, states = self.sim.getPanos()

                spatial_map, mask, ftm = self.mapper(rgb, depth, states,
                                                     spatial_map, mask)
                del states
                del depth
                del ftm
                if not self.args.debug_mode:
                    del rgb

                state = None
                steps = self.args.max_steps - 1

                belief = self.mapper.belief_map(
                    path_xyzhe[0], self.args.filter_input_sigma).log()
                for k in range(steps):
                    input_belief = belief
                    # Train a filter
                    new_belief, obs_likelihood, state, _, _, _ = self.model(
                        k, seq, seq_mask, seq_lens, input_belief, spatial_map,
                        state)
                    belief = new_belief + obs_likelihood
                    # Renormalize
                    belief = belief - belief.reshape(
                        belief.shape[0], -1).logsumexp(
                            dim=1).unsqueeze(1).unsqueeze(1).unsqueeze(1)

                    # Determine target and loss
                    target_heatmap = self.mapper.belief_map(
                        path_xyzhe[k + 1], self.args.filter_heatmap_sigma)
                    loss += self.criterion(belief, target_heatmap)
                    normalizer += self.args.batch_size

                    if self.args.debug_mode:
                        input_belief = F.interpolate(
                            input_belief.exp(),
                            scale_factor=self.args.belief_downsample_factor,
                            mode='bilinear',
                            align_corners=False).sum(dim=1).unsqueeze(1)
                        belief_up = F.interpolate(
                            belief.exp(),
                            scale_factor=self.args.belief_downsample_factor,
                            mode='bilinear',
                            align_corners=False).sum(dim=1).unsqueeze(1)
                        target_heatmap_up = F.interpolate(
                            target_heatmap,
                            scale_factor=self.args.belief_downsample_factor,
                            mode='bilinear',
                            align_corners=False).sum(dim=1).unsqueeze(1)
                        debug_map = self.mapper.debug_maps(debug_path)

                        self.visual_debug(t, rgb, mask, debug_map,
                                          input_belief, belief_up,
                                          target_heatmap_up)

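                # Truncated backprop: every truncate_after steps (or at the final
                # timestep), backprop the accumulated loss and detach the map and state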
                trunc = self.args.truncate_after
                if (t % trunc) == (trunc - 1) or t + 1 == self.args.timesteps:
                    self.loss = loss / normalizer
                    self.loss.backward()
                    self.loss.detach_()
                    loss = 0
                    normalizer = 0
                    spatial_map.detach_()
                    if state is not None:
                        state = self.model.detach_state(state)
                        state = None  # Recalculate to get gradients from later in the decoding. TODO: keep this part of the graph without re-running the forward pass

                # Take action in the sim
                self.sim.takePseudoSupervisedAction(traindata.get_supervision)

            del spatial_map
            del mask

            self.map_opt.step()
            self.model_opt.step()
            self.logging_loop(it)
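
A note on the renormalization step used above: after each filter update, belief is kept a valid log-probability map by subtracting the per-example logsumexp over all heading and grid cells, which is the log-space equivalent of dividing by the total probability mass. Below is a minimal standalone sketch of just that operation; the tensor sizes are made-up illustration values, not the map dimensions produced by Mapper.

import torch

# Hypothetical dimensions: batch of 2, 4 heading states, 12 x 12 grid
belief = torch.randn(2, 4, 12, 12)  # unnormalized log scores

# Subtract each example's logsumexp so the map sums to 1 in probability space
log_norm = belief.reshape(belief.shape[0], -1).logsumexp(dim=1)
belief = belief - log_norm.unsqueeze(1).unsqueeze(1).unsqueeze(1)

print(belief.exp().sum(dim=(1, 2, 3)))  # ~1.0 for every example
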
Example #2
class PVNBaselineTrainer(Trainer):
    """ Train the PVN Baseline """

    def __init__(self, args):
        super(PVNBaselineTrainer, self).__init__(args, filepath=None)

        self.mapper = Mapper(self.args).to(self.args.device)
        self.navigator = UnetNavigator(self.args).to(self.args.device)
        self.navigator.init_weights()

        self.goal_criterion = MaskedMSELoss()
        self.path_criterion = MaskedMSELoss()

    def validate(self, split):
        """
        split: "val_seen" or "val_unseen"
        """

        with torch.no_grad():

            if(split=="val_seen"):
                val_dataloader = self.valseendata
            elif(split=="val_unseen"):
                val_dataloader = self.valunseendata

            total_loss = torch.tensor(0.0)

            for it in range(self.args.validation_iterations):

                # Load minibatch and simulator
                seq, seq_mask, seq_lens, batch = val_dataloader.get_batch()
                self.sim.newEpisode(batch)

                if self.args.multi_maps:
                    all_spatial_maps = []
                    all_masks = []

                # Run the mapper for one step
                xyzhe = self.sim.getXYZHE()
                spatial_map,mask = self.mapper.init_map(xyzhe)
                for t in range(self.args.timesteps):
                    rgb,depth,states = self.sim.getPanos()
                    spatial_map,mask,ftm = self.mapper(rgb,depth,states,spatial_map,mask)
                    if self.args.multi_maps:
                        all_spatial_maps.append(spatial_map.unsqueeze(1))
                        all_masks.append(mask.unsqueeze(1))
                    if self.args.timesteps != 1:
                        self.sim.takePseudoSupervisedAction(val_dataloader.get_supervision)

                if self.args.multi_maps:
                    spatial_map = torch.cat(all_spatial_maps, dim=1).flatten(0, 1)
                    mask = torch.cat(all_masks, dim=1).flatten(0, 1)

                    seq = seq.unsqueeze(1).expand(-1, self.args.timesteps, -1).flatten(0, 1)
                    seq_lens = seq_lens.unsqueeze(1).expand(-1, self.args.timesteps).flatten(0, 1)

                # Predict with unet
                pred = self.navigator(seq, seq_lens, spatial_map)
                goal_pred = pred[:,0,:,:]
                path_pred = pred[:,1,:,:]

                goal_map = self.mapper.heatmap(val_dataloader.goal_coords(), self.args.goal_heatmap_sigma)
                path_map = self.mapper.heatmap(val_dataloader.path_coords(), self.args.path_heatmap_sigma)

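                # Rescale each target heatmap so its per-example maximum is 1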
                goal_map_max, _ = torch.max(goal_map, dim=2, keepdim=True)
                goal_map_max, _ = torch.max(goal_map_max, dim=1, keepdim=True)
                goal_map /= goal_map_max

                path_map_max, _ = torch.max(path_map, dim=2, keepdim=True)
                path_map_max, _ = torch.max(path_map_max, dim=1, keepdim=True)
                path_map /= path_map_max

                if self.args.belief_downsample_factor > 1:
                    goal_map = nn.functional.avg_pool2d(goal_map, kernel_size=self.args.belief_downsample_factor, stride=self.args.belief_downsample_factor)
                    path_map = nn.functional.avg_pool2d(path_map, kernel_size=self.args.belief_downsample_factor, stride=self.args.belief_downsample_factor)
                    mask = nn.functional.max_pool2d(mask, kernel_size=self.args.belief_downsample_factor, stride=self.args.belief_downsample_factor)

                if self.args.multi_maps:
                    goal_map = goal_map.unsqueeze(1).expand(-1, self.args.timesteps, -1, -1).flatten(0, 1)
                    path_map = path_map.unsqueeze(1).expand(-1, self.args.timesteps, -1, -1).flatten(0, 1)

                loss = self.goal_criterion(goal_pred,goal_map) + self.path_criterion(path_pred,path_map)

                total_loss += (1.0 / self.args.validation_iterations) * loss

        return total_loss

    def logging_loop(self, it):
        """ Logging and checkpointing stuff """

        if it % self.args.validate_every == 0:

            self.mapper.eval()
            self.navigator.eval()

            loss_val_seen = self.validate("val_seen")
            loss_val_unseen = self.validate("val_unseen")

            self.prev_time, time_taken = utils.time_it(self.prev_time)

            print("Iteration: %d Loss: %f Val Seen Loss: %f Val Unseen Loss: %f Time: %0.2f secs"
                    %(it, self.loss.item(), loss_val_seen.item(), loss_val_unseen.item(), time_taken))

            if self.visdom:
                # visdom: X, Y, key, line_name, x_label, y_label, fig_title
                self.visdom.line(it, self.loss.item(), "train_loss", "Train Loss", "Iterations", "Loss", title="Train Phase")
                self.visdom.line(it, loss_val_seen.item(), "val_loss", "Val Seen Loss", "Iterations", "Loss", title="Val Phase")
                self.visdom.line(it, loss_val_unseen.item(), "val_loss", "Val Unseen Loss", "Iterations", "Loss", title="Val Phase")

            self.mapper.train()
            self.navigator.train()

        elif it % self.args.log_every == 0:
            self.prev_time, time_taken = utils.time_it(self.prev_time)
            print("Iteration: %d Loss: %f Time: %0.2f secs" % (it, self.loss.item(), time_taken))

            if self.visdom:
                self.visdom.line(it, self.loss.item(), "train_loss", "Train Loss", "Iterations", "Loss", title="Train Phase")

        if it % self.args.checkpoint_every == 0:
            saver = {"mapper": self.mapper.state_dict(),
                     "navigator": self.navigator.state_dict(),
                     "args": self.args}
            dir = "%s/%s" % (self.args.snapshot_dir, self.args.exp_name)
            if not os.path.exists(dir):
                os.makedirs(dir)
            torch.save(
                saver, "%s/%s_%d" % (dir, self.args.exp_name, it)
            )
            if self.visdom:
                self.visdom.save()

    def plot_grad_flow(self, named_parameters):
        ave_grads = []
        layers = []

        for n, p in named_parameters:
            if p.requires_grad and "bias" not in n and p.grad is not None:
                layers.append(n)
                ave_grads.append(p.grad.abs().mean().cpu().item())
        plt.plot(ave_grads, alpha=0.3, color="b")
        plt.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
        plt.xticks(range(len(ave_grads)), layers, rotation="vertical")
        plt.xlim(left=0, right=len(ave_grads))
        plt.xlabel("Layers")
        plt.ylabel("average gradient")
        plt.title("Gradient flow")
        plt.grid(True)
        plt.savefig("gradient.png")

    def train(self):
        """
        Supervised training of the mapper and navigator
        """

        self.prev_time = time.time()

        map_opt = optim.Adam(self.mapper.parameters(), lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay)
        nav_opt = optim.Adam(self.navigator.parameters(), lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay)

        # Set models to train phase
        self.mapper.train()
        self.navigator.train()

        for it in range(1, self.args.max_iterations + 1):
            map_opt.zero_grad()
            nav_opt.zero_grad()

            # Load minibatch and simulator
            seq, seq_mask, seq_lens, batch = self.traindata.get_batch()
            self.sim.newEpisode(batch)

            if self.args.multi_maps:
                all_spatial_maps = []
                all_masks = []

            # Run the mapper for multiple timesteps and keep each map
            xyzhe = self.sim.getXYZHE()
            spatial_map,mask = self.mapper.init_map(xyzhe)
            for t in range(self.args.timesteps):
                rgb,depth,states = self.sim.getPanos()
                spatial_map,mask,ftm = self.mapper(rgb,depth,states,spatial_map,mask)
                if self.args.multi_maps:
                    all_spatial_maps.append(spatial_map.unsqueeze(1))
                    all_masks.append(mask.unsqueeze(1))
                if self.args.timesteps != 1:
                    self.sim.takePseudoSupervisedAction(self.traindata.get_supervision)

            if self.args.multi_maps:
                spatial_map = torch.cat(all_spatial_maps, dim=1).flatten(0, 1)
                mask = torch.cat(all_masks, dim=1).flatten(0, 1)

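                # Repeat the instruction for every timestep so it lines up with the stacked maps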
                seq = seq.unsqueeze(1).expand(-1, self.args.timesteps, -1).flatten(0, 1)
                seq_lens = seq_lens.unsqueeze(1).expand(-1, self.args.timesteps).flatten(0, 1)

            # Predict with unet
            pred = self.navigator(seq, seq_lens, spatial_map)

            goal_pred = pred[:,0,:,:]
            path_pred = pred[:,1,:,:]

            goal_map = self.mapper.heatmap(self.traindata.goal_coords(), self.args.goal_heatmap_sigma)
            path_map = self.mapper.heatmap(self.traindata.path_coords(), self.args.path_heatmap_sigma)

            goal_map_max, _ = torch.max(goal_map, dim=2, keepdim=True)
            goal_map_max, _ = torch.max(goal_map_max, dim=1, keepdim=True)
            goal_map /= goal_map_max

            path_map_max, _ = torch.max(path_map, dim=2, keepdim=True)
            path_map_max, _ = torch.max(path_map_max, dim=1, keepdim=True)
            path_map /= path_map_max

            if self.args.belief_downsample_factor > 1:
                goal_map = nn.functional.avg_pool2d(goal_map, kernel_size=self.args.belief_downsample_factor, stride=self.args.belief_downsample_factor)
                path_map = nn.functional.avg_pool2d(path_map, kernel_size=self.args.belief_downsample_factor, stride=self.args.belief_downsample_factor)
                mask = nn.functional.max_pool2d(mask, kernel_size=self.args.belief_downsample_factor, stride=self.args.belief_downsample_factor)

            if self.args.multi_maps:
                goal_map = goal_map.unsqueeze(1).expand(-1, self.args.timesteps, -1, -1).flatten(0, 1)
                path_map = path_map.unsqueeze(1).expand(-1, self.args.timesteps, -1, -1).flatten(0, 1)

            self.loss = self.goal_criterion(goal_pred,goal_map) + self.path_criterion(path_pred,path_map)

            self.loss.backward()
            map_opt.step()
            nav_opt.step()
            self.logging_loop(it)
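
A note on the multi_maps branch above: the map after every timestep is kept, stacked along a new time dimension, and flattened into the batch dimension, and the instruction tensors are repeated so that each stacked map is paired with its own copy of the instruction. Below is a shape-only sketch of that bookkeeping; the batch size, number of timesteps and feature sizes are invented for illustration.

import torch

B, T, C, H, W, L = 2, 3, 8, 16, 16, 10  # hypothetical sizes
all_spatial_maps = [torch.randn(B, C, H, W).unsqueeze(1) for _ in range(T)]
seq = torch.randint(0, 100, (B, L))

spatial_map = torch.cat(all_spatial_maps, dim=1).flatten(0, 1)
seq = seq.unsqueeze(1).expand(-1, T, -1).flatten(0, 1)

print(spatial_map.shape)  # torch.Size([6, 8, 16, 16]) -> one map per (example, timestep)
print(seq.shape)          # torch.Size([6, 10]) -> the instruction repeated per timestep
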
Example #3
class PolicyTrainer(Trainer):
    """ Train a filter and a policy """
    def __init__(self, args, filepath=None):
        super(PolicyTrainer, self).__init__(args, filepath, load_sim=False)
        self.sim = PanoSimulatorWithGraph(self.args)

        # load models
        self.mapper = Mapper(self.args).to(self.args.device)
        self.model = Filter(self.args).to(self.args.device)
        self.policy = ReactivePolicy(self.args).to(self.args.device)

        if filepath:
            loader = torch.load(filepath)
            self.mapper.load_state_dict(loader["mapper"])
            self.model.load_state_dict(loader["filter"])
            self.policy.load_state_dict(loader["policy"])
            print("Loaded Mapper, Filter and Policy from: %s" % filepath)
        elif args:
            self.model.init_weights()
            self.policy.init_weights()
        self.belief_criterion = LogBeliefLoss()
        self.policy_criterion = torch.nn.NLLLoss(
            ignore_index=self.args.action_ignore_index, reduction='none')

    def validate(self, split):
        """
        split: "val_seen" or "val_unseen"
        """
        vln_eval = Evaluation(split)
        self.sim.record_traj(True)
        with torch.no_grad():

            if (split == "val_seen"):
                val_dataloader = self.valseendata
            elif (split == "val_unseen"):
                val_dataloader = self.valunseendata
            elif (split == "train"):
                val_dataloader = self.traindata

            for it in range(self.args.validation_iterations):

                # Load minibatch and simulator
                seq, seq_mask, seq_lens, batch = val_dataloader.get_batch()
                self.sim.newEpisode(batch)

                # Initialize the mapper
                xyzhe = self.sim.getXYZHE()
                spatial_map, mask = self.mapper.init_map(xyzhe)

                ended = torch.zeros(self.args.batch_size,
                                    device=self.args.device).byte()
                for t in range(self.args.timesteps):
                    if self.args.policy_gt_belief:
                        path_xyzhe, path_len = val_dataloader.path_xyzhe()
                    else:
                        rgb, depth, states = self.sim.getPanos()
                        spatial_map, mask, ftm = self.mapper(
                            rgb, depth, states, spatial_map, mask)
                        del states
                        del depth
                        del ftm
                        del rgb

                    features, _, _ = self.sim.getGraphNodes()
                    P = features.shape[0]  # num graph nodes
                    belief_features = torch.empty(P,
                                                  self.args.max_steps,
                                                  device=self.args.device)

                    state = None
                    steps = self.args.max_steps - 1

                    belief = self.mapper.belief_map(
                        xyzhe, self.args.filter_input_sigma).log()
                    sigma = self.args.filter_heatmap_sigma
                    gridcellsize = self.args.gridcellsize * self.args.belief_downsample_factor
                    belief_features[:, 0] = belief_at_nodes(
                        belief.exp(), features[:, :4], sigma, gridcellsize)
                    for k in range(steps):
                        if self.args.policy_gt_belief:
                            target_heatmap = self.mapper.belief_map(
                                path_xyzhe[k + 1],
                                self.args.filter_heatmap_sigma)
                            belief_features[:, k + 1] = belief_at_nodes(
                                target_heatmap, features[:, :4], sigma,
                                gridcellsize)
                        else:
                            input_belief = belief
                            # Run the filter
                            new_belief, obs_likelihood, state, _, _, _ = self.model(
                                k, seq, seq_mask, seq_lens, input_belief,
                                spatial_map, state)
                            belief = new_belief + obs_likelihood
                            # Renormalize
                            belief = belief - belief.reshape(
                                belief.shape[0],
                                -1).logsumexp(dim=1).unsqueeze(1).unsqueeze(
                                    1).unsqueeze(1)

                            belief_features[:, k + 1] = belief_at_nodes(
                                belief.exp(), features[:, :4], sigma,
                                gridcellsize)

                    # Probs from policy
                    aug_features = torch.cat([features, belief_features],
                                             dim=1)
                    log_prob = self.policy(aug_features)

                    # Take argmax action in the sim
                    _, action_idx = log_prob.exp().max(dim=1)
                    ended |= self.sim.takeMultiStepAction(action_idx)

                    if ended.all():
                        break

        # Eval
        scores = vln_eval.score(self.sim.traj, check_all_trajs=False)
        self.sim.record_traj(False)
        return scores['result'][0][split]

    def logging_loop(self, it):
        """ Logging and checkpointing stuff """

        if it % self.args.validate_every == 0:
            self.mapper.eval()
            self.model.eval()
            self.policy.eval()
            scores_train = self.validate("train")
            scores_val_seen = self.validate("val_seen")
            scores_val_unseen = self.validate("val_unseen")

            self.prev_time, time_taken = utils.time_it(self.prev_time)

            print(
                "Iteration: %d Loss: %f Train Success: %f Val Seen Success: %f Val Unseen Success: %f Time: %0.2f secs"
                % (it, self.loss.item(), scores_train['success'],
                   scores_val_seen['success'], scores_val_unseen['success'],
                   time_taken))

            if self.visdom:
                # visdom: X, Y, key, line_name, x_label, y_label, fig_title
                self.visdom.line(it,
                                 self.loss.item(),
                                 "train_loss",
                                 "Train Loss",
                                 "Iterations",
                                 "Loss",
                                 title=" Train Phase")

                units = {
                    'length': 'm',
                    'error': 'm',
                    'oracle success': '%',
                    'success': '%',
                    'spl': '%'
                }
                sub = self.args.validation_iterations * self.args.batch_size
                for metric, score in scores_train.items():
                    m = metric.title()
                    self.visdom.line(it,
                                     score,
                                     metric,
                                     "Train (%d)" % sub,
                                     "Iterations",
                                     units[metric],
                                     title=m)
                for metric, score in scores_val_seen.items():
                    m = metric.title()
                    self.visdom.line(it,
                                     score,
                                     metric,
                                     "Val Seen (%d)" % sub,
                                     "Iterations",
                                     units[metric],
                                     title=m)
                for metric, score in scores_val_unseen.items():
                    m = metric.title()
                    self.visdom.line(it,
                                     score,
                                     metric,
                                     "Val Unseen (%d)" % sub,
                                     "Iterations",
                                     units[metric],
                                     title=m)
            self.mapper.train()
            self.model.train()
            self.policy.train()

        elif it % self.args.log_every == 0:
            self.prev_time, time_taken = utils.time_it(self.prev_time)
            print("Iteration: %d Loss: %f Time: %0.2f secs" %
                  (it, self.loss.item(), time_taken))
            if self.visdom:
                self.visdom.line(it,
                                 self.loss.item(),
                                 "train_loss",
                                 "Train Loss",
                                 "Iterations",
                                 "Loss",
                                 title="Train Phase")

        if it % self.args.checkpoint_every == 0:
            saver = {
                "mapper": self.mapper.state_dict(),
                "args": self.args,
                "filter": self.model.state_dict(),
                "policy": self.policy.state_dict()
            }
            dir = "%s/%s" % (self.args.snapshot_dir, self.args.exp_name)
            if not os.path.exists(dir):
                os.makedirs(dir)
            torch.save(saver, "%s/%s_%d" % (dir, self.args.exp_name, it))
            if self.visdom:
                self.visdom.save()

    def train(self):
        """
        Supervised training of the mapper, filter and policy
        """

        torch.autograd.set_detect_anomaly(True)
        self.prev_time = time.time()

        map_opt = self.optimizer(self.mapper.parameters())
        model_opt = self.optimizer(self.model.parameters())
        policy_opt = self.optimizer(self.policy.parameters())

        # Set models to train phase
        self.mapper.train()
        self.model.train()
        self.policy.train()

        for it in range(1, self.args.max_iterations + 1):
            map_opt.zero_grad()
            model_opt.zero_grad()
            policy_opt.zero_grad()

            traindata = self.traindata

            # Load minibatch and simulator
            seq, seq_mask, seq_lens, batch = traindata.get_batch()
            self.sim.newEpisode(batch)

            # Initialize the mapper
            xyzhe = self.sim.getXYZHE()
            spatial_map, mask = self.mapper.init_map(xyzhe)

            # Note: the Filter model is trained to predict the GT path, not the path taken by the Policy.
            path_xyzhe, path_len = traindata.path_xyzhe()
            self.loss = 0
            ended = torch.zeros(self.args.batch_size,
                                device=self.args.device).byte()
            for t in range(self.args.timesteps):
                if not self.args.policy_gt_belief:
                    rgb, depth, states = self.sim.getPanos()
                    spatial_map, mask, ftm = self.mapper(
                        rgb, depth, states, spatial_map, mask)

                    del states
                    del depth
                    del ftm
                    del rgb

                features, _, _ = self.sim.getGraphNodes()
                P = features.shape[0]  # num graph nodes
                belief_features = torch.empty(P,
                                              self.args.max_steps,
                                              device=self.args.device)

                state = None
                steps = self.args.max_steps - 1

                gt_input_belief = None
                belief = self.mapper.belief_map(
                    path_xyzhe[0], self.args.filter_input_sigma).log()
                sigma = self.args.filter_heatmap_sigma
                gridcellsize = self.args.gridcellsize * self.args.belief_downsample_factor
                belief_features[:, 0] = belief_at_nodes(
                    belief.exp(), features[:, :4], sigma, gridcellsize)
                for k in range(steps):
                    target_heatmap = self.mapper.belief_map(
                        path_xyzhe[k + 1], self.args.filter_heatmap_sigma)
                    if self.args.policy_gt_belief:
                        belief_features[:, k + 1] = belief_at_nodes(
                            target_heatmap, features[:, :4], sigma,
                            gridcellsize)
                    else:
                        input_belief = belief
                        # Train a filter
                        if self.args.teacher_force_motion_model:
                            gt_input_belief = self.mapper.belief_map(
                                path_xyzhe[k],
                                self.args.filter_heatmap_sigma).log()
                            new_belief, obs_likelihood, state, _, _, new_gt_belief = self.model(
                                k, seq, seq_mask, seq_lens, input_belief,
                                spatial_map, state, gt_input_belief)
                            # Don't backprop through the belief
                            belief = new_belief.detach() + obs_likelihood
                        else:
                            new_belief, obs_likelihood, state, _, _, new_gt_belief = self.model(
                                k, seq, seq_mask, seq_lens, input_belief,
                                spatial_map, state, gt_input_belief)
                            belief = new_belief + obs_likelihood
                        # Renormalize
                        belief = belief - belief.reshape(
                            belief.shape[0], -1).logsumexp(
                                dim=1).unsqueeze(1).unsqueeze(1).unsqueeze(1)

                        # Determine target and loss for the filter
                        belief_loss = self.belief_criterion(
                            belief, target_heatmap,
                            valid=~ended) / (~ended).sum()
                        if self.args.teacher_force_motion_model:  # Separate loss for the motion model
                            belief_loss += self.belief_criterion(
                                new_gt_belief, target_heatmap,
                                valid=~ended) / (~ended).sum()
                        self.loss += belief_loss

                        belief_features[:, k + 1] = belief_at_nodes(
                            belief.exp(), features[:, :4], sigma, gridcellsize)

                # Probs from policy
                aug_features = torch.cat([features, belief_features], dim=1)
                log_prob = self.policy(aug_features)

                # Train a policy
                target_idx = traindata.closest_to_goal(self.sim.G)
                policy_loss = self.args.policy_loss_lambda * self.policy_criterion(
                    log_prob, target_idx).sum() / (~ended).sum()
                self.loss += policy_loss

                # Take action in the sim
                if self.args.supervision_prob < 0:
                    supervision_prob = 1.0 - float(
                        it) / self.args.max_iterations
                else:
                    supervision_prob = self.args.supervision_prob
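                # Scheduled sampling: per example, follow the teacher action with
                # probability supervision_prob, otherwise sample from the policy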
                sampled_action = D.Categorical(log_prob.exp()).sample()
                weights = torch.tensor(
                    [1.0 - supervision_prob, supervision_prob],
                    dtype=torch.float,
                    device=log_prob.device)
                ix = torch.multinomial(weights,
                                       self.args.batch_size,
                                       replacement=True).bool()
                action_idx = torch.where(ix, target_idx, sampled_action)
                ended |= self.sim.takeMultiStepAction(action_idx)

                trunc = self.args.truncate_after
                if ((t % trunc) == (trunc - 1) or t + 1 == self.args.timesteps
                        or ended.all()):
                    self.loss.backward()
                    self.loss.detach_()
                    spatial_map.detach_()
                    if state is not None:
                        state = self.model.detach_state(state)
                        state = None  # Recalculate to get gradients from later in the decoding. TODO: keep this part of the graph without re-running the forward pass

                if ended.all():
                    break

            del spatial_map
            del mask

            map_opt.step()
            model_opt.step()
            policy_opt.step()
            self.logging_loop(it)
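
A note on the action selection in PolicyTrainer.train: each example in the batch independently follows the ground-truth action with probability supervision_prob and otherwise executes an action sampled from the policy. Below is a minimal sketch of that per-example mix; the batch size, number of actions and probability are made-up values, and a bool mask is used since newer PyTorch requires bool conditions for torch.where.

import torch
import torch.distributions as D

batch_size, num_actions = 4, 6
supervision_prob = 0.5

log_prob = torch.log_softmax(torch.randn(batch_size, num_actions), dim=1)
target_idx = torch.randint(0, num_actions, (batch_size,))  # stand-in teacher actions

sampled_action = D.Categorical(log_prob.exp()).sample()
weights = torch.tensor([1.0 - supervision_prob, supervision_prob])
take_teacher = torch.multinomial(weights, batch_size, replacement=True).bool()
action_idx = torch.where(take_teacher, target_idx, sampled_action)
print(action_idx)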