Example #1
    def train(self):

        # Switch model to train mode
        self.model.train()

        # Check if maxEpochs have elapsed
        if self.curEpoch >= self.maxEpochs:
            print('Max epochs elapsed! Returning ...')
            return

        # Increment iters
        self.iters += 1

        # Variables to store stats
        rotLosses = []
        transLosses = []
        totalLosses = []
        rotLoss_seq = []
        transLoss_seq = []
        totalLoss_seq = []

        # Handle debug mode here
        if self.args.debug is True:
            numTrainIters = self.args.debugIters
        else:
            numTrainIters = len(self.train_set)

        # Initialize a variable to hold the number of samples in the current batch
        # Here, 'batch' refers to the length of a subsequence that can be processed
        # before performing a 'detach' operation
        elapsedBatches = 0

        # Choose a generator for iterating over the dataset, based on whether or not
        # the sbatch flag is set. If sbatch is True, we're probably running on a cluster
        # and do not want interactive output, so we suppress the tqdm progress bar.
        if self.args.sbatch is True:
            gen = range(numTrainIters)
        else:
            gen = trange(numTrainIters)

        # Run a pass of the dataset
        for i in gen:

            if self.args.profileGPUUsage is True:
                gpu_memory_map = get_gpu_memory_map()
                tqdm.write('GPU usage: ' + str(gpu_memory_map[0]),
                           file=sys.stdout)

            # Get the next frame
            inp, rot_gt, trans_gt, _, _, _, endOfSeq = self.train_set[i]

            # Feed it through the model
            rot_pred, trans_pred = self.model.forward(inp)

            # Compute loss
            if self.args.outputParameterization == 'mahalanobis':
                # Compute a Mahalanobis norm on the output 6-vector.
                # Note that, although we seem to compute the loss only over rotation
                # variables, rot_pred and rot_gt are here 6-vectors that also include
                # the translation variables.
                self.loss += self.loss_fn(rot_pred, rot_gt,
                                          self.train_set.infoMat)
                tmpLossVar = Variable(
                    torch.mm(
                        rot_pred - rot_gt,
                        torch.mm(self.train_set.infoMat,
                                 (rot_pred - rot_gt).t())),
                    requires_grad=False).detach().cpu().numpy()
                totalLosses.append(tmpLossVar[0])
                totalLoss_seq.append(tmpLossVar[0])
            else:
                curloss_rot = Variable(self.args.scf *
                                       (torch.dist(rot_pred, rot_gt)**2),
                                       requires_grad=False)
                curloss_trans = Variable(torch.dist(trans_pred, trans_gt)**2,
                                         requires_grad=False)
                self.loss_rot += curloss_rot
                self.loss_trans += curloss_trans

                # Occasionally (with probability ~0.18, via a standard-normal draw)
                # print predictions vs. ground truth for monitoring
                if np.random.normal() < -0.9:
                    tqdm.write('rot: ' + str(rot_pred.data) + ' ' +
                               str(rot_gt.data),
                               file=sys.stdout)
                    tqdm.write('trans: ' + str(trans_pred.data) + ' ' +
                               str(trans_gt.data),
                               file=sys.stdout)

                self.loss += self.args.scf * self.loss_fn(rot_pred, rot_gt) + \
                    self.loss_fn(trans_pred, trans_gt)

                # Store losses (for further analysis)
                curloss_rot = curloss_rot.detach().cpu().numpy()
                curloss_trans = curloss_trans.detach().cpu().numpy()
                rotLosses.append(curloss_rot)
                transLosses.append(curloss_trans)
                totalLosses.append(curloss_rot + curloss_trans)
                rotLoss_seq.append(curloss_rot)
                transLoss_seq.append(curloss_trans)
                totalLoss_seq.append(curloss_rot + curloss_trans)

            # In debug mode, force the end-of-sequence handling below to run on the
            # last iteration
            if self.args.debug is True and i == numTrainIters - 1:
                endOfSeq = True

            elapsedBatches += 1

            if elapsedBatches >= self.args.trainBatch or endOfSeq is True:

                elapsedBatches = 0


                # Regularize only LSTM(s)
                if self.args.gamma > 0.0:
                    paramsDict = self.model.state_dict()
                    if self.args.numLSTMCells == 1:
                        reg_loss = paramsDict['lstm1.weight_ih'].norm(2)
                        reg_loss += paramsDict['lstm1.weight_hh'].norm(2)
                        reg_loss += paramsDict['lstm1.bias_ih'].norm(2)
                        reg_loss += paramsDict['lstm1.bias_hh'].norm(2)
                    else:
                        # Regularize both LSTM cells
                        reg_loss = paramsDict['lstm1.weight_ih'].norm(2)
                        reg_loss += paramsDict['lstm1.weight_hh'].norm(2)
                        reg_loss += paramsDict['lstm1.bias_ih'].norm(2)
                        reg_loss += paramsDict['lstm1.bias_hh'].norm(2)
                        reg_loss += paramsDict['lstm2.weight_ih'].norm(2)
                        reg_loss += paramsDict['lstm2.weight_hh'].norm(2)
                        reg_loss += paramsDict['lstm2.bias_ih'].norm(2)
                        reg_loss += paramsDict['lstm2.bias_hh'].norm(2)
                    self.loss = self.args.gamma * reg_loss + self.loss

                # Print stats
                if self.args.outputParameterization != 'mahalanobis':
                    tqdm.write('Rot Loss: ' + str(np.mean(rotLoss_seq)) +
                               ' Trans Loss: ' + str(np.mean(transLoss_seq)),
                               file=sys.stdout)
                else:
                    tqdm.write('Total Loss: ' + str(np.mean(totalLoss_seq)),
                               file=sys.stdout)
                rotLoss_seq = []
                transLoss_seq = []
                totalLoss_seq = []

                # Compute gradients
                self.loss.backward()

                # Monitor gradients: global L2 norm over all parameters with gradients
                paramList = [p for p in self.model.parameters() if p.grad is not None]
                totalNorm = sum([(p.grad.data.norm(2.)**2.)
                                 for p in paramList])**(1. / 2)
                tqdm.write('gradNorm: ' + str(totalNorm.item()))

                # Perform gradient clipping, if enabled
                if self.args.gradClip is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.args.gradClip)

                # Update parameters
                self.optimizer.step()

                # If it's the end of sequence, reset hidden states
                if endOfSeq is True:
                    self.model.reset_LSTM_hidden()
                # Detach hidden states so the next subsequence starts a fresh graph
                self.model.detach_LSTM_hidden()

                # Reset loss variables
                self.loss_rot = torch.zeros(1, dtype=torch.float32).cuda()
                self.loss_trans = torch.zeros(1, dtype=torch.float32).cuda()
                self.loss = torch.zeros(1, dtype=torch.float32).cuda()

                # Flush gradient buffers for next forward pass
                self.model.zero_grad()

        # Return loss logs for further analysis
        if self.args.outputParameterization == 'mahalanobis':
            return [], [], totalLosses
        else:
            return rotLosses, transLosses, totalLosses
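The 'mahalanobis' branch above both accumulates `loss_fn(rot_pred, rot_gt, infoMat)` and re-derives the same quantity by hand for logging. A minimal standalone sketch of such a loss, assuming a 6x6 information (inverse-covariance) matrix; `mahalanobis_loss` is an illustrative name, not the `loss_fn` used above:

import torch

def mahalanobis_loss(pred, gt, info_mat):
    # Squared Mahalanobis norm e^T * info_mat * e of the prediction error.
    # pred, gt: (1, 6) tensors; info_mat: (6, 6) information matrix.
    err = pred - gt
    return torch.mm(err, torch.mm(info_mat, err.t()))

# Illustrative usage: with an identity information matrix, this reduces
# to the squared L2 distance between prediction and ground truth.
pred = torch.randn(1, 6, requires_grad=True)
gt = torch.randn(1, 6)
loss = mahalanobis_loss(pred, gt, torch.eye(6))
loss.backward()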
Example #2
    def validate(self):

        # Switch model to eval mode
        self.model.eval()

        # Run a pass of the dataset
        traj_pred = None

        # Variables to store stats
        rotLosses = []
        transLosses = []
        totalLosses = []
        rotLoss_seq = []
        transLoss_seq = []
        totalLoss_seq = []

        # Handle debug switch here
        if self.args.debug is True:
            numValIters = self.args.debugIters
        else:
            numValIters = len(self.val_set)

        # Choose a generator for iterating over the dataset, based on whether or not
        # the sbatch flag is set. If sbatch is True, we're probably running on a cluster
        # and do not want interactive output, so we suppress the tqdm progress bar.
        if self.args.sbatch is True:
            gen = range(numValIters)
        else:
            gen = trange(numValIters)

        for i in gen:

            if self.args.profileGPUUsage is True:
                gpu_memory_map = get_gpu_memory_map()
                tqdm.write('GPU usage: ' + str(gpu_memory_map[0]),
                           file=sys.stdout)

            # Get the next frame
            inp, rot_gt, trans_gt, seq, frame1, frame2, endOfSeq = self.val_set[i]
            metadata = np.asarray([[seq, frame1, frame2]])

            # Feed it through the model
            rot_pred, trans_pred = self.model.forward(inp)

            if self.args.outputParameterization == 'mahalanobis':
                if traj_pred is None:
                    traj_pred = np.concatenate(
                        (metadata, rot_pred.data.cpu().numpy()), axis=1)
                else:
                    cur_pred = np.concatenate(
                        (metadata, rot_pred.data.cpu().numpy()), axis=1)
                    traj_pred = np.concatenate((traj_pred, cur_pred), axis=0)
            else:
                if traj_pred is None:
                    traj_pred = np.concatenate(
                        (metadata, rot_pred.data.cpu().numpy(),
                         trans_pred.data.cpu().numpy()), axis=1)
                else:
                    cur_pred = np.concatenate(
                        (metadata, rot_pred.data.cpu().numpy(),
                         trans_pred.data.cpu().numpy()), axis=1)
                    traj_pred = np.concatenate((traj_pred, cur_pred), axis=0)

            # Store losses (for further analysis)
            if self.args.outputParameterization == 'mahalanobis':
                # rot_pred and rot_gt are 6-vectors here, and include the translations too.
                # The information matrix comes from the training set, hence train_set.infoMat.
                tmpLossVar = self.loss_fn(
                    rot_pred, rot_gt,
                    self.train_set.infoMat).detach().cpu().numpy()
                totalLosses.append(tmpLossVar[0])
                totalLoss_seq.append(tmpLossVar[0])
            else:
                curloss_rot = (
                    self.args.scf *
                    self.loss_fn(rot_pred, rot_gt)).detach().cpu().numpy()
                curloss_trans = (self.loss_fn(
                    trans_pred, trans_gt)).detach().cpu().numpy()
                rotLosses.append(curloss_rot)
                transLosses.append(curloss_trans)
                totalLosses.append(curloss_rot + curloss_trans)
                rotLoss_seq.append(curloss_rot)
                transLoss_seq.append(curloss_trans)
                totalLoss_seq.append(curloss_rot + curloss_trans)

            # Detach hidden states and outputs of LSTM
            self.model.detach_LSTM_hidden()

            if endOfSeq is True:

                # Print stats
                if self.args.outputParameterization != 'mahalanobis':
                    tqdm.write('Rot Loss: ' + str(np.mean(rotLoss_seq)) +
                               ' Trans Loss: ' + str(np.mean(transLoss_seq)),
                               file=sys.stdout)
                else:
                    tqdm.write('Total Loss: ' + str(np.mean(totalLoss_seq)),
                               file=sys.stdout)
                rotLoss_seq = []
                transLoss_seq = []
                totalLoss_seq = []

                # Write predicted trajectory to file
                saveFile = os.path.join(self.args.expDir, 'plots', 'traj',
                                        str(seq).zfill(2),
                                        'traj_' + str(self.curEpoch).zfill(3) + '.txt')
                np.savetxt(saveFile, traj_pred, newline='\n')

                # Reset variable, to store new trajectory later on
                traj_pred = None

                # Detach LSTM hidden states
                self.model.detach_LSTM_hidden()

                # Reset LSTM hidden states
                self.model.reset_LSTM_hidden()

        # Return loss logs for further analysis
        if self.args.outputParameterization == 'mahalanobis':
            return [], [], totalLosses
        else:
            return rotLosses, transLosses, totalLosses
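validate() above relies on detach_LSTM_hidden() to keep the autograd graph from growing across frames. On PyTorch 0.4+, wrapping the pass in torch.no_grad() achieves the same memory behavior without building a graph at all; a minimal sketch under that assumption (hidden states must still be reset at sequence boundaries, which is model-specific and omitted here):

import torch

def validate_no_grad(model, val_set, loss_fn):
    model.eval()
    losses = []
    with torch.no_grad():  # no graph is built, so nothing needs detaching
        for i in range(len(val_set)):
            inp, rot_gt, trans_gt, *_ = val_set[i]
            rot_pred, trans_pred = model.forward(inp)
            losses.append(loss_fn(rot_pred, rot_gt).item() +
                          loss_fn(trans_pred, trans_gt).item())
    return losses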
Example #3
    def train(self):

        # Switch model to train mode
        self.model.train()

        # Check if maxEpochs have elapsed
        if self.curEpoch >= self.maxEpochs:
            print('Max epochs elapsed! Returning ...')
            return

        # Increment iters
        self.iters += 1

        # Variables to store stats
        r6Losses = []
        poseLosses = []
        totalLosses = []
        r6Loss_seq = []
        poseLoss_seq = []
        totalLoss_seq = []

        # Handle debug mode here
        if self.args.debug is True:
            numTrainIters = self.args.debugIters
        else:
            numTrainIters = len(self.train_set)

        elapsedBatches = 0
        traj_pred = None
        gen = trange(numTrainIters)

        # Run a pass of the dataset
        for i in gen:
            if self.args.profileGPUUsage is True:
                gpu_memory_map = get_gpu_memory_map()
                tqdm.write('GPU usage: ' + str(gpu_memory_map[0]),
                           file=sys.stdout)

            # Get the next frame
            inp, imu, r6, xyzq, _, _, _, timestamp, endOfSeq = self.train_set[
                i]
            pred_r6 = self.model.forward(inp, imu, xyzq)
            if self.abs_traj is None:
                # TODO: verify that the initial pose is read correctly here and that
                # the accumulation (integration) below is computed correctly.
                self.abs_traj = xyzq.data.cpu()[0][0]
            numarr = pred_r6.data.cpu().numpy()[0][0]

            # Accumulate the relative 6-DoF prediction into the absolute trajectory
            self.abs_traj = se3qua.accu(self.abs_traj, numarr)

            abs_traj_input = np.expand_dims(self.abs_traj, axis=0)
            abs_traj_input = np.expand_dims(abs_traj_input, axis=0)
            abs_traj_input = Variable(
                torch.from_numpy(abs_traj_input).type(
                    torch.FloatTensor)).cuda()

            curloss_r6 = Variable(self.args.scf * (torch.dist(pred_r6, r6)**2),
                                  requires_grad=False)
            curloss_xyzq = Variable(torch.dist(abs_traj_input, xyzq)**2,
                                    requires_grad=False)

            curloss_xyzq_trans = Variable(
                self.args.scf * 10 *
                (torch.dist(abs_traj_input[:, :, :3], xyzq[:, :, :3])**2),
                requires_grad=False)
            curloss_xyzq_rot = Variable(torch.dist(abs_traj_input[:, :, 3:],
                                                   xyzq[:, :, 3:])**2,
                                        requires_grad=False)
            self.loss_r6 = curloss_r6
            self.loss_xyzq = curloss_xyzq


            # Accumulate the training loss. Note that abs_traj_input is rebuilt from
            # numpy above, so only the r6 term can carry gradient back to the model.
            self.loss += self.args.scf * self.loss_fn(pred_r6, r6) + \
                self.args.scf * 10 * self.loss_fn(
                    abs_traj_input[:, :, :3], xyzq[:, :, :3]) + \
                self.loss_fn(abs_traj_input[:, :, 3:], xyzq[:, :, 3:])

            curloss_r6 = curloss_r6.detach().cpu().numpy()
            curloss_xyzq = curloss_xyzq.detach().cpu().numpy()
            curloss_xyzq_rot = curloss_xyzq_rot.detach().cpu().numpy()
            curloss_xyzq_trans = curloss_xyzq_trans.detach().cpu().numpy()
            r6Losses.append(curloss_r6)
            r6Loss_seq.append(curloss_r6)
            poseLosses.append(curloss_xyzq_rot + curloss_xyzq_trans)
            poseLoss_seq.append(curloss_xyzq_rot + curloss_xyzq_trans)
            totalLosses.append(curloss_r6 + curloss_xyzq_rot +
                               curloss_xyzq_trans)
            totalLoss_seq.append(curloss_r6 + curloss_xyzq_rot +
                                 curloss_xyzq_trans)
            del curloss_r6
            del curloss_xyzq

            # In debug mode, force the end-of-sequence handling below to run on the
            # last iteration
            if self.args.debug is True and i == numTrainIters - 1:
                endOfSeq = True

            elapsedBatches += 1

            if endOfSeq is True:
                elapsedBatches = 0

                # Print stats
                tqdm.write('r6 Loss: ' + str(np.mean(r6Loss_seq)) +
                           ' pose Loss: ' + str(np.mean(poseLoss_seq)),
                           file=sys.stdout)
                r6Loss_seq = []
                poseLoss_seq = []
                totalLoss_seq = []

                # Compute gradients
                self.loss.backward()

                paramList = list(
                    filter(lambda p: p.grad is not None,
                           [param for param in self.model.parameters()]))
                totalNorm = sum([(p.grad.data.norm(2.)**2.)
                                 for p in paramList])**(1. / 2)
                tqdm.write('gradNorm: ' + str(totalNorm.item()))

                # Perform gradient clipping, if enabled
                if self.args.gradClip is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.args.gradClip)

                # Update parameters
                self.optimizer.step()


                # Reset loss variables
                self.loss_r6 = torch.zeros(1, dtype=torch.float32).cuda()
                self.loss_xyzq = torch.zeros(1, dtype=torch.float32).cuda()
                self.loss = torch.zeros(1, dtype=torch.float32).cuda()

                # Flush gradient buffers for next forward pass
                self.model.zero_grad()
                self.abs_traj = None

        return r6Losses, poseLosses, totalLosses
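The backward/monitor/clip/step tail is identical across these train() variants and could be factored out. A minimal sketch, assuming any nn.Module and an already-computed loss tensor; the helper name is illustrative, not part of the code above:

import torch
import torch.nn as nn

def backward_and_step(loss, model: nn.Module, optimizer, grad_clip=None):
    # Compute gradients
    loss.backward()

    # Monitor the global L2 gradient norm
    params = [p for p in model.parameters() if p.grad is not None]
    total_norm = sum(p.grad.data.norm(2.) ** 2 for p in params) ** 0.5
    print('gradNorm:', float(total_norm))

    # Optionally clip, then update parameters and flush gradient buffers
    if grad_clip is not None:
        torch.nn.utils.clip_grad_norm_(params, grad_clip)
    optimizer.step()
    model.zero_grad()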
Example #4
    def validate(self):

        # Switch model to eval mode
        self.model.eval()

        # Run a pass of the dataset
        traj_pred = None
        self.abs_traj = None

        # Variables to store stats
        r6Losses = []
        poseLosses = []
        totalLosses = []
        r6Loss_seq = []
        poseLoss_seq = []
        totalLoss_seq = []

        # Handle debug switch here
        if self.args.debug is True:
            numValIters = self.args.debugIters
        else:
            numValIters = len(self.val_set)

        # Choose a generator for iterating over the dataset, based on whether or not
        # the sbatch flag is set. If sbatch is True, we're probably running on a cluster
        # and do not want interactive output, so we suppress the tqdm progress bar.
        if self.args.sbatch is True:
            gen = range(numValIters)
        else:
            gen = trange(numValIters)

        for i in gen:

            if self.args.profileGPUUsage is True:
                gpu_memory_map = get_gpu_memory_map()
                tqdm.write('GPU usage: ' + str(gpu_memory_map[0]),
                           file=sys.stdout)

            # Get the next frame
            inp, imu, r6, xyzq, seq, frame1, frame2, timestamp, endOfSeq = self.val_set[
                i]

            metadata = np.asarray([timestamp])

            # Feed it through the model
            pred_r6 = self.model.forward(inp, imu, xyzq)
            numarr = pred_r6.detach().cpu().numpy()[0][0]

            if self.abs_traj is None:
                self.abs_traj = xyzq.detach().cpu()[0][0]
            if traj_pred is None:
                traj_pred = np.concatenate((metadata, self.abs_traj.numpy()),
                                           axis=0)
                traj_pred = np.resize(traj_pred, (1, -1))

            self.abs_traj = se3qua.accu(self.abs_traj, numarr)

            cur_pred = np.concatenate((metadata, self.abs_traj), axis=0)
            traj_pred = np.append(traj_pred,
                                  np.resize(cur_pred, (1, -1)),
                                  axis=0)

            abs_traj_input = np.expand_dims(self.abs_traj, axis=0)
            abs_traj_input = np.expand_dims(abs_traj_input, axis=0)
            abs_traj_input = Variable(
                torch.from_numpy(abs_traj_input).type(
                    torch.FloatTensor)).cuda()

            # Store losses (for further analysis)
            curloss_r6 = Variable(self.args.scf * (torch.dist(pred_r6, r6)**2),
                                  requires_grad=False)
            curloss_xyzq = Variable(self.args.scf *
                                    (torch.dist(abs_traj_input, xyzq)**2),
                                    requires_grad=False)
            curloss_xyzq_trans = Variable(
                self.args.scf * 10 *
                (torch.dist(abs_traj_input[:, :, :3], xyzq[:, :, :3])**2),
                requires_grad=False)
            curloss_xyzq_rot = Variable(torch.dist(abs_traj_input[:, :, 3:],
                                                   xyzq[:, :, 3:])**2,
                                        requires_grad=False)

            curloss_r6 = curloss_r6.detach().cpu().numpy()
            curloss_xyzq = curloss_xyzq.detach().cpu().numpy()
            curloss_xyzq_rot = curloss_xyzq_rot.detach().cpu().numpy()
            curloss_xyzq_trans = curloss_xyzq_trans.detach().cpu().numpy()

            r6Losses.append(curloss_r6)
            r6Loss_seq.append(curloss_r6)
            poseLosses.append(curloss_xyzq_rot + curloss_xyzq_trans)
            poseLoss_seq.append(curloss_xyzq_rot + curloss_xyzq_trans)
            totalLosses.append(curloss_r6 + curloss_xyzq_rot +
                               curloss_xyzq_trans)
            totalLoss_seq.append(curloss_r6 + curloss_xyzq_rot +
                                 curloss_xyzq_trans)
            del curloss_r6
            del curloss_xyzq

            if endOfSeq is True:

                # Print stats (before resetting the per-sequence logs)
                tqdm.write('Total Loss: ' + str(np.mean(totalLoss_seq)),
                           file=sys.stdout)
                r6Loss_seq = []
                poseLoss_seq = []
                totalLoss_seq = []

                # Write predicted trajectory to file
                saveFile = os.path.join(self.args.expDir, 'plots', 'traj',
                                        str(seq).zfill(2),
                                        'traj_' + str(self.curEpoch).zfill(3) + '.txt')
                # TODO: check why only one trajectory gets saved, and why the last
                # save looks odd (short, and with only 6 columns).
                np.savetxt(saveFile, traj_pred, newline='\n')

                # Reset variable, to store new trajectory later on
                traj_pred = None

                self.abs_traj = None
                self.model.zero_grad()

        return r6Losses, poseLosses, totalLosses
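se3qua.accu() above folds a relative 6-DoF prediction into the absolute pose, but its conventions aren't visible in this listing. Purely as an illustration of the idea, a sketch that composes a [x, y, z, qx, qy, qz, qw] pose with a [dx, dy, dz, rx, ry, rz] increment whose rotation part is an axis-angle vector; the layout and frame conventions are assumptions, not the se3qua API:

import numpy as np

def quat_mul(q1, q2):
    # Hamilton product of quaternions stored as [x, y, z, w]
    x1, y1, z1, w1 = q1
    x2, y2, z2, w2 = q2
    return np.array([w1*x2 + x1*w2 + y1*z2 - z1*y2,
                     w1*y2 - x1*z2 + y1*w2 + z1*x2,
                     w1*z2 + x1*y2 - y1*x2 + z1*w2,
                     w1*w2 - x1*x2 - y1*y2 - z1*z2])

def quat_rotate(q, v):
    # Rotate vector v by unit quaternion q: q * (v, 0) * conj(q)
    q_conj = np.array([-q[0], -q[1], -q[2], q[3]])
    return quat_mul(quat_mul(q, np.append(v, 0.0)), q_conj)[:3]

def rotvec_to_quat(r):
    # Axis-angle vector -> unit quaternion [x, y, z, w] (exponential map)
    theta = np.linalg.norm(r)
    if theta < 1e-8:
        return np.array([0.0, 0.0, 0.0, 1.0])
    return np.append(r / theta * np.sin(theta / 2.0), np.cos(theta / 2.0))

def accumulate_pose(abs_pose, rel_motion):
    # abs_pose: [x, y, z, qx, qy, qz, qw]; rel_motion: [dx, dy, dz, rx, ry, rz]
    t, q = abs_pose[:3], abs_pose[3:]
    t_new = t + quat_rotate(q, rel_motion[:3])   # move in the current body frame
    q_new = quat_mul(q, rotvec_to_quat(rel_motion[3:]))
    return np.concatenate([t_new, q_new / np.linalg.norm(q_new)])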
Example #5
	def train(self):

		# Switch model to train mode
		self.model.train()

		# Check if maxEpochs have elapsed
		if self.curEpoch >= self.maxEpochs:
			print('Max epochs elapsed! Returning ...')
			return

		# Variables to store stats
		rotLosses = []
		transLosses = []
		totalLosses = []
		

		# Handle debug mode here
		if self.args.debug is True:
			numTrainIters = self.args.debugIters
		else:
			numTrainIters = len(self.train_set)

		elapsedFrames = 0

		# Choose a generator for iterating over the dataset, based on whether or not
		# the sbatch flag is set. If sbatch is True, we're probably running on a cluster
		# and do not want interactive output, so we suppress the tqdm progress bar.
		if self.args.sbatch is True:
			gen = range(numTrainIters)
		else:
			gen = trange(numTrainIters)

		# Store input and label tensors
		
		inputTensor = None
		labelTensor_trans = None
		labelTensor_rot = None

		# Run a pass of the dataset
		for i in gen:

			if self.args.profileGPUUsage is True:
				gpu_memory_map = get_gpu_memory_map()
				tqdm.write('GPU usage: ' + str(gpu_memory_map[0]), file = sys.stdout)

			# Get the next frame
			inp, rot_gt, trans_gt, seq, frame1, frame2, endOfSeq = self.train_set[i]
			if inputTensor is None:
				inputTensor = inp.clone()
				labelTensor_rot = rot_gt.unsqueeze(0).clone()
				labelTensor_trans = trans_gt.unsqueeze(0).clone()
			else:
				inputTensor = torch.cat((inputTensor,inp.clone()),0)
				labelTensor_rot = torch.cat((labelTensor_rot,rot_gt.unsqueeze(0).clone()),0)
				labelTensor_trans = torch.cat((labelTensor_trans, trans_gt.unsqueeze(0).clone()),0)


			# Handle debug mode here. Force execute the below if statement in the
			# last debug iteration
			if self.args.debug is True:
				if i == numTrainIters - 1:
					endOfSeq = True

			elapsedFrames += 1

			
			if elapsedFrames >= self.args.seqLen or endOfSeq is True:

				# Flush gradient buffers for next forward pass
				self.model.zero_grad()

				
				rot_pred, trans_pred, tmp = self.model.forward(inputTensor)

				self.loss = sum([100 * self.loss_fn(rot_pred, labelTensor_rot),
					self.loss_fn(trans_pred, labelTensor_trans)])
				curloss_rot = Variable(torch.dist(rot_pred, labelTensor_rot) ** 2, requires_grad = False)
				curloss_trans = Variable(torch.dist(trans_pred, labelTensor_trans) ** 2, requires_grad = False)

				# re initialize
				inputTensor = None
				labelTensor_trans = None
				labelTensor_rot = None
				elapsedFrames = 0

				paramsDict = self.model.state_dict()
				reg_loss_R = paramsDict['LSTM_R.weight_ih_l0'].norm(2)
				reg_loss_R += paramsDict['LSTM_R.weight_hh_l0'].norm(2)
				reg_loss_R += paramsDict['LSTM_R.bias_ih_l0'].norm(2)
				reg_loss_R += paramsDict['LSTM_R.bias_hh_l0'].norm(2)
				if self.args.numLSTMCells == 2:
					# Add the layer-1 terms as well
					reg_loss_R += paramsDict['LSTM_R.weight_ih_l1'].norm(2)
					reg_loss_R += paramsDict['LSTM_R.weight_hh_l1'].norm(2)
					reg_loss_R += paramsDict['LSTM_R.bias_ih_l1'].norm(2)
					reg_loss_R += paramsDict['LSTM_R.bias_hh_l1'].norm(2)

				reg_loss_T = paramsDict['LSTM_T.weight_ih_l0'].norm(2)
				reg_loss_T += paramsDict['LSTM_T.weight_hh_l0'].norm(2)
				reg_loss_T += paramsDict['LSTM_T.bias_ih_l0'].norm(2)
				reg_loss_T += paramsDict['LSTM_T.bias_hh_l0'].norm(2)
				if self.args.numLSTMCells == 2:
					reg_loss_T += paramsDict['LSTM_T.weight_ih_l1'].norm(2)
					reg_loss_T += paramsDict['LSTM_T.weight_hh_l1'].norm(2)
					reg_loss_T += paramsDict['LSTM_T.bias_ih_l1'].norm(2)
					reg_loss_T += paramsDict['LSTM_T.bias_hh_l1'].norm(2)

				totalregLoss = reg_loss_R + reg_loss_T
				self.loss = self.args.gamma * totalregLoss + self.loss
				
				# Compute gradients
				self.loss.backward()

				# Parameters 19-26 (1-indexed) belong to the rotation LSTM,
				# parameters 27-34 to the translation LSTM
				rotParameters = [p for idx, p in enumerate(self.model.parameters(), 1)
					if idx in range(19, 27)]
				transParameters = [p for idx, p in enumerate(self.model.parameters(), 1)
					if idx in range(27, 35)]

				def gradNorm(params):
					return sum((p.grad.data.norm(2.) ** 2. for p in params)) ** (1. / 2)

				tqdm.write('Before clipping, Rotation gradNorm: ' + str(gradNorm(rotParameters)) + \
					' Translation gradNorm: ' + str(gradNorm(transParameters)))

				# Perform gradient clipping, if enabled
				if self.args.gradClip is not None:
					torch.nn.utils.clip_grad_norm_(rotParameters, self.args.gradClip)
					torch.nn.utils.clip_grad_norm_(transParameters, self.args.gradClip)
					tqdm.write('After clipping, Rotation gradNorm: ' + str(gradNorm(rotParameters)) + \
						' Translation gradNorm: ' + str(gradNorm(transParameters)))


				# Update parameters
				self.optimizer.step()
	
				curloss_rot = curloss_rot.detach().cpu().numpy()	
				curloss_trans = curloss_trans.detach().cpu().numpy()
				rotLosses.append(curloss_rot)
				transLosses.append(curloss_trans)
				totalLosses.append(curloss_rot + curloss_trans)

				# Print stats
				tqdm.write('Rot Loss: ' + str(np.mean(rotLosses)) + ' Trans Loss: ' + \
					str(np.mean(transLosses)), file = sys.stdout)
				tqdm.write('Total Loss: ' + str(np.mean(totalLosses)), file = sys.stdout)
				# If it's the end of sequence, reset hidden states
				if endOfSeq is True:
					self.model.reset_LSTM_hidden()
				self.model.detach_LSTM_hidden()
	
				# Reset loss variables
				self.loss = torch.zeros(1, dtype = torch.float32).cuda()
		# Return loss logs for further analysis
		return rotLosses, transLosses, totalLosses
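Example 5 picks out the rotation and translation LSTM parameters by their position in model.parameters() (ranges 19-27 and 27-35), which silently breaks if the architecture changes. A more robust sketch selects them by name, using the 'LSTM_R'/'LSTM_T' prefixes visible in the state_dict keys above; the helper names are illustrative:

def params_with_prefix(model, prefix):
    # Parameters whose names start with a given submodule prefix
    return [p for name, p in model.named_parameters() if name.startswith(prefix)]

def grad_norm(params):
    # Global L2 norm over the gradients of a parameter group
    sq = sum(p.grad.data.norm(2.) ** 2 for p in params if p.grad is not None)
    return float(sq) ** 0.5

# Usage sketch, after loss.backward():
#   rotParameters = params_with_prefix(self.model, 'LSTM_R')
#   transParameters = params_with_prefix(self.model, 'LSTM_T')
#   tqdm.write('Rotation gradNorm: ' + str(grad_norm(rotParameters)))
#   torch.nn.utils.clip_grad_norm_(rotParameters, self.args.gradClip)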
Example #6
	def validate(self):
		
		# Switch model to eval mode
		self.model.eval()

		# Run a pass of the dataset
		traj_pred = None
		
		# Variables to store stats
		rotLosses = []
		transLosses = []
		totalLosses = []
		
		# Handle debug switch here
		if self.args.debug is True:
			numValIters = self.args.debugIters
		else:
			numValIters = len(self.val_set)

		elapsedFrames = 0
		if self.args.sbatch is True:
			gen = range(numValIters)
		else:
			gen = trange(numValIters)

		inputTensor = None
		labelTensor_rot = None
		labelTensor_trans = None

		for i in gen:

			if self.args.profileGPUUsage is True:
				gpu_memory_map = get_gpu_memory_map()
				tqdm.write('GPU usage: ' + str(gpu_memory_map[0]), file = sys.stdout)

			# Get the next frame
			inp, rot_gt, trans_gt, seq, frame1, frame2, endOfSeq = self.val_set[i]
			if inputTensor is None:
				inputTensor = inp.clone()
				labelTensor_rot = rot_gt.unsqueeze(0).clone()
				labelTensor_trans = trans_gt.unsqueeze(0).clone()
			else:
				inputTensor = torch.cat((inputTensor,inp.clone()),0)
				labelTensor_rot = torch.cat((labelTensor_rot,rot_gt.unsqueeze(0).clone()),0)
				labelTensor_trans = torch.cat((labelTensor_trans, trans_gt.unsqueeze(0).clone()),0)
			
			
			elapsedFrames += 1

			if elapsedFrames >= self.args.seqLen or endOfSeq is True:

				
				# Feed it through the model
				rot_pred, trans_pred,_ = self.model.forward(inputTensor)

				curloss_rot = Variable(torch.dist(rot_pred, labelTensor_rot) ** 2, requires_grad = False)
				curloss_trans = Variable(torch.dist(trans_pred, labelTensor_trans) ** 2 , requires_grad = False)
				
				inputTensor = None
				labelTensor_trans = None
				labelTensor_rot = None
				elapsedFrames = 0
			
				if traj_pred is None:
					traj_pred = np.concatenate((rot_pred.data.cpu().numpy().squeeze(1), \
						trans_pred.data.cpu().numpy().squeeze(1)), axis = 1)
				else:
					
					cur_pred = np.concatenate((rot_pred.data.cpu().numpy().squeeze(1), \
						trans_pred.data.cpu().numpy().squeeze(1)), axis = 1)
					
					traj_pred = np.concatenate((traj_pred, cur_pred), axis = 0)

				# Convert to numpy before logging (np.mean below cannot average CUDA tensors)
				curloss_rot = curloss_rot.detach().cpu().numpy()
				curloss_trans = curloss_trans.detach().cpu().numpy()
				rotLosses.append(curloss_rot)
				transLosses.append(curloss_trans)
				totalLosses.append(curloss_rot + curloss_trans)

				# Detach hidden states for the next forward pass
				self.model.detach_LSTM_hidden()

				

				if endOfSeq is True:
					# Print stats
					tqdm.write('Rot Loss: ' + str(np.mean(rotLosses)) + ' Trans Loss: ' + \
						str(np.mean(transLosses)), file = sys.stdout)
					tqdm.write('Total Loss: ' + str(np.mean(totalLosses)), file = sys.stdout)
					# Write predicted trajectory to file
					saveFile = os.path.join(self.args.expDir, 'plots', 'traj', str(seq).zfill(2), \
						'traj_' + str(self.curEpoch).zfill(3) + '.txt')
					np.savetxt(saveFile, traj_pred, newline = '\n')
				
					# Reset variable, to store new trajectory later on
					traj_pred = None
		
					# Reset LSTM hidden states
					self.model.reset_LSTM_hidden()
					# Detach hidden states for the next forward pass
					self.model.detach_LSTM_hidden()

					rotLosses = []
					transLosses = []
					totalLosses = []

		# Return loss logs for further analysis
		return rotLosses, transLosses, totalLosses
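Examples 5 and 6 grow the subsequence tensors one frame at a time with torch.cat, which reallocates on every append. Collecting frames in Python lists and concatenating once per subsequence is the usual pattern; a minimal sketch, assuming the same dataset item layout as the loops above (the helper name is illustrative):

import torch

def collect_subsequence(dataset, start, length):
    # Gather `length` consecutive frames into batched input/label tensors
    frames, rots, transs = [], [], []
    for i in range(start, start + length):
        inp, rot_gt, trans_gt = dataset[i][:3]
        frames.append(inp)
        rots.append(rot_gt.unsqueeze(0))
        transs.append(trans_gt.unsqueeze(0))
    # One allocation per tensor instead of one per frame
    return (torch.cat(frames, 0),
            torch.cat(rots, 0),
            torch.cat(transs, 0))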