import itertools
from time import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import transforms3d.euler
from pyquaternion import Quaternion

# Project-local dependencies (module paths are repo-specific and assumed here):
# cfg (n_layers, hid_dim, batch_len, seq_len), IMUDataset, BiRNN, BiLSTM


class TestEngine:
    def __init__(self):
        self.testSet = ['s_11']
        self.datapath = '/data/Guha/GR/Dataset.old/'
        self.test_dataset = IMUDataset(self.datapath, self.testSet)
        self.use_cuda = True
        self.testModel = BiRNN().cuda()
        #self.testModel = BiLSTM().cuda()

        modelPath = '/data/Guha/GR/model/9/validation.pth.tar'
        self.base = '/data/Guha/GR/Output/TestSet/8/'
        with open(modelPath, 'rb') as tar:
            checkpoint = torch.load(tar)
            model_weights = checkpoint['state_dict']
            #epoch_loss = checkpoint['validation_loss']
        self.testModel.load_state_dict(model_weights)

    def test(self):
        # initialize hidden and cell state for each new sequence
        # (only consumed by the commented-out BiLSTM variant below)
        hidden = torch.zeros(cfg.n_layers * 2, 1, cfg.hid_dim,
                             dtype=torch.double).cuda()
        cell = torch.zeros(cfg.n_layers * 2, 1, cfg.hid_dim,
                           dtype=torch.double).cuda()
        # evaluate a single file (the loop over self.test_dataset.files is disabled)
        #for f in self.test_dataset.files:
        f = '/data/Guha/GR/Dataset/CMU/02_02_05.npz'
        self.test_dataset.readfile(f)
        input = torch.FloatTensor(self.test_dataset.input)
        input = torch.unsqueeze(input, 0)
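        # input now has shape (1, seq_len, 20): a batch of one full sequence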
        target = torch.FloatTensor(self.test_dataset.target)

        if self.use_cuda:
            # input = [input.cuda()]
            input = input.cuda()
            self.testModel.cuda()

        self.testModel.eval()
        # bilstm
        # prediction,_,_ = self.testModel(input,hidden,cell)

        # birnn
        prediction = self.testModel(input)

        prediction = prediction.detach().reshape_as(target).cpu()
        loss = self._loss_impl(prediction, target)
        # Renormalize prediction
        prediction = prediction.numpy().reshape(-1, 15, 4)
        seq_len = prediction.shape[0]
        norms = np.linalg.norm(prediction, axis=2)
        prediction = np.asarray([
            prediction[k, j, :] / norms[k, j]
            for k, j in itertools.product(range(seq_len), range(15))
        ])
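        # equivalent vectorized form (assuming no zero-norm quaternions):
        # prediction = prediction / norms[:, :, None]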
        # save GT and prediction
        #np.savez_compressed(self.base + f.split('/')[-1], target=target.cpu().numpy(), predictions=prediction)
        print(f, '------------', loss.item())


    ##################### read one file and return input and target values
    def readfile(self, file):
        # allow_pickle is required by newer NumPy versions for pickled archives
        data_dict = np.load(file, encoding='latin1', allow_pickle=True)
        sample_pose = data_dict['pose'].reshape(-1, 15, 3, 3)
        sample_ori = data_dict['ori']
        sample_acc = data_dict['acc']
        seq_len = sample_pose.shape[0]
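        # array shapes: pose (seq_len, 15, 3, 3), ori (seq_len, 5, 3, 3),
        # acc (seq_len, 5, 3) -- 5 IMU sensors, 15 target joints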

        #################### convert orientation matrices to quaternion ###############
        ori_quat = np.asarray([
            Quaternion(matrix=sample_ori[k, j, :, :]).elements
            for k, j in itertools.product(range(seq_len), range(5))
        ])
        ori_quat = ori_quat.reshape(-1, 5 * 4)
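        # pyquaternion's .elements is in (w, x, y, z) order; 5 sensors x 4 = 20-dim input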

        #################### convert orientation matrices to euler ###############
        # ori_euler = np.asarray([transforms3d.euler.mat2euler(sample_ori[k, j, :, :]) for k, j in
        #                         itertools.product(range(seq_len), range(5))])
        # ori_euler = ori_euler.reshape(-1, 5, 3)
        # ori_euler = ori_euler[:, :, 0:2].reshape(-1, 5 * 2)

        #################### convert pose matrices to quaternion ###############
        pose_quat = np.asarray([
            Quaternion(matrix=sample_pose[k, j, :, :]).elements
            for k, j in itertools.product(range(seq_len), range(15))
        ])

        pose_quat = pose_quat.reshape(-1, 15, 4)
        #################### normalize acceleration (disabled) ###############
        # imu_dip = dict(
        #     np.load('/data/Guha/GR/code/dip18/train_and_eval/data/dipIMU/imu_own_validation.npz', encoding='latin1'))
        # data_stats = imu_dip.get('statistics').tolist()
        # acc_stats = data_stats['acceleration']
        # sample_acc = sample_acc.reshape(-1, 5 * 3)
        # sample_acc = (sample_acc - acc_stats['mean_channel']) / acc_stats['std_channel']
        #
        # concat = np.concatenate((ori_quat, sample_acc), axis=1)

        self.input = ori_quat
        self.target = pose_quat

    def testWindow(self, len_past, len_future):
        # initialize hidden and cell state for each new sequence
        # (only consumed by the commented-out BiLSTM variant below)
        hidden = torch.zeros(cfg.n_layers * 2, 1, cfg.hid_dim,
                             dtype=torch.double).cuda()
        cell = torch.zeros(cfg.n_layers * 2, 1, cfg.hid_dim,
                           dtype=torch.double).cuda()
        loss_file = open('/data/Guha/GR/Output/loss/loss_9_H36.txt', 'w')
        # loop through all the files
        for ct, f in enumerate(self.test_dataset.files):
            #f = '/data/Guha/GR/Dataset/DIP_IMU2/test/s_10_05.npz'
            #f = '/data/Guha/GR/Dataset.old/AMASS_Transition/mazen_c3dairkick_jumpinplace.npz'
            self.readfile(f)
            input = self.input
            target = self.target
            seq_len = input.shape[0]
            predictions = []
            self.testModel.eval()
            # loop over all frames; predict each timestep t from a window around it
            for step in range(seq_len):
                start_idx = max(step - len_past, 0)
                end_idx = min(step + len_future + 1, seq_len)
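                # window covers frames [t - len_past, t + len_future], clamped to the sequence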
                in_window = input[start_idx:end_idx]
                in_window = torch.FloatTensor(in_window).unsqueeze(0).cuda()
                # target_window = target[start_idx:end_idx]

                # bilstm
                #output,_,_ = self.testModel(in_window,hidden,cell)
                # birnn
                output = self.testModel(in_window)
                # index of frame t inside the (possibly truncated) window
                prediction_step = min(step, len_past)
                pred = output[:, prediction_step:prediction_step + 1]
                pred = pred.detach().cpu().numpy().reshape(15, 4)
                predictions.append(pred)

            ################## Renormalize prediction
            predictions = np.asarray(predictions)
            norms = np.linalg.norm(predictions, axis=2)
            predictions = np.asarray([
                predictions[k, j, :] / norms[k, j]
                for k, j in itertools.product(range(seq_len), range(15))
            ])
            predictions = predictions.reshape(seq_len, 15, 4)
            ################### convert to euler
            target_euler = np.asarray([
                transforms3d.euler.quat2euler(target[k, j])
                for k, j in itertools.product(range(seq_len), range(15))
            ])
            target_euler = (target_euler * 180) / np.pi
            pred_euler = np.asarray([
                transforms3d.euler.quat2euler(predictions[k, j])
                for k, j in itertools.product(range(seq_len), range(15))
            ])
            pred_euler = (pred_euler * 180) / np.pi
            ##################calculate loss
            loss = self.loss_impl(target_euler.reshape(-1, 15, 3),
                                  pred_euler.reshape(-1, 15, 3))
            #loss_file.write('{}-- {}\n'.format(f,loss))
            print(f + '-------' + str(loss))
            loss_file.write('{}\n'.format(loss))
            # save GT and prediction
            #np.savez_compressed(self.base + f.split('/')[-1], target=target, predictions=predictions)
            #print(f)
            if ct == 30:  # stop after the first 31 files
                break
        loss_file.close()

    def loss_impl(self, predicted, expected):
        # mean per-joint, per-frame Euclidean error in degrees (numpy; used by testWindow)
        error = predicted - expected
        error_norm = np.linalg.norm(error, axis=2)
        error_per_joint = np.mean(error_norm, axis=1)
        error_per_frame_per_joint = np.mean(error_per_joint, axis=0)
        return error_per_frame_per_joint
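
    def _loss_impl(self, predicted, expected):
        # torch counterpart of loss_impl, called by test() above; mirrors
        # TrainingEngine._loss_impl (mean over frames of summed per-joint L2 error)
        diff = predicted - expected
        batch_size = predicted.shape[0]
        dist = torch.sum(torch.norm(diff, 2, 2))
        return dist / batch_size
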
class TrainingEngine:
    def __init__(self):
        self.trainset = ['AMASS_ACCAD', 'AMASS_BioMotion', 'AMASS_CMU_Kitchen', 'AMASS_Eyes', 'AMASS_MIXAMO',
                         'AMASS_SSM', 'AMASS_Transition', 'CMU', 'H36']
        self.testSet = ['AMASS_HDM05', 'HEva', 'JointLimit']
        self.datapath = '/data/Guha/GR/Dataset'
        self.dataset = IMUDataset(self.datapath, self.trainset)
        self.use_cuda = True
        self.modelPath = '/data/Guha/GR/model/13/'
        self.model = BiRNN().cuda()
        self.mseloss = nn.MSELoss()
        baseModelPath = '/data/Guha/GR/model/13/epoch_1.pth.tar'

        with open(baseModelPath, 'rb') as tar:
            checkpoint = torch.load(tar)
            model_weights = checkpoint['state_dict']
        self.model.load_state_dict(model_weights)

    def train(self, n_epochs):
        f = open(self.modelPath + 'model_details', 'w')
        f.write(str(self.model))
        f.write('\n')

        np.random.seed(1234)
        lr = 0.001
        gradient_clip = 0.1
        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        print('Training for %d epochs' % (n_epochs))
        no_of_trainbatch = int(len(self.dataset.files) / cfg.batch_len)

        print('batch size--> %d, Seq len--> %d, no of batches--> %d' % (cfg.batch_len, cfg.seq_len, no_of_trainbatch))
        f.write('batch size--> %d, Seq len--> %d, no of batches--> %d \n' % (cfg.batch_len, cfg.seq_len, no_of_trainbatch))

        min_batch_loss = 0.0
        min_valid_loss = 0.0

        try:
            for epoch in range(n_epochs):  # epoch + 1 is reported in logs and checkpoints
                epoch_loss = []
                start_time = time()
                self.dataset.loadfiles(self.datapath,self.trainset)
                ####################### training #######################
                # while(len(self.dataset.files) > 0):
                #     # Pick a random chunk from each sequence
                #     self.dataset.createbatch_no_replacement()
                #     inputs = torch.FloatTensor(self.dataset.input)
                #     outputs = torch.FloatTensor(self.dataset.target)
                #
                #     if self.use_cuda:
                #         inputs = inputs.cuda()
                #         outputs = outputs.cuda()
                #         self.model.cuda()
                #
                #     chunk_in = list(torch.split(inputs, cfg.seq_len))[:-1]
                #     chunk_out = list(torch.split(outputs, cfg.seq_len))[:-1]
                #     random.shuffle(chunk_in)
                #     random.shuffle(chunk_out)
                #     chunk_in = torch.stack(chunk_in, dim=0)
                #     chunk_out = torch.stack(chunk_out, dim=0)
                #     print('no of chunks %d \n' % (len(chunk_in)))
                #     f.write('no of chunks %d  \n' % (len(chunk_in)))
                #     self.model.train()
                #     optimizer.zero_grad()
                #     predictions = self.model(chunk_in)
                #
                #     loss = self._loss_impl(predictions, chunk_out)
                #     loss.backward()
                #     nn.utils.clip_grad_norm_(self.model.parameters(), gradient_clip)
                #     optimizer.step()
                #     loss.detach()
                #
                #     epoch_loss.append(loss.item())
                #     if (min_batch_loss == 0 or loss < min_batch_loss):
                #         min_batch_loss = loss
                #         print ('training loss %f ' % (loss.item()))
                #         f.write('training loss %f \n' % (loss.item()))
                #
                # epoch_loss = torch.mean(torch.FloatTensor(epoch_loss))
                # # we save the model after each epoch : epoch_{}.pth.tar
                # state = {
                #     'epoch': epoch + 1,
                #     'state_dict': self.model.state_dict(),
                #     'epoch_loss': epoch_loss
                # }
                # torch.save(state, self.modelPath + 'epoch_{}.pth.tar'.format(epoch+1))

                ####################### Validation #######################
                valid_meanloss = []
                valid_maxloss = []
                data_to_plot = []
                self.model.eval()

                for d in self.testSet:
                    self.dataset.loadfiles(self.datapath, [d])
                    dset_loss = []
                    for f in self.dataset.files:
                        self.dataset.readfile(f)
                        input = torch.FloatTensor(self.dataset.input)
                        input = torch.unsqueeze(input, 0)
                        target = torch.FloatTensor(self.dataset.target)
                        if self.use_cuda:
                            input = input.cuda()

                        prediction = self.model(input)
                        prediction = prediction.detach().reshape_as(target).cpu()
                        # per-frame, per-joint L2 error over the quaternion components
                        # (dim=2 matches _loss_impl; the original dim=1 mixed joints)
                        loss = torch.norm(prediction - target, 2, 2)
                        mean_loss = torch.mean(loss)
                        max_loss = torch.max(loss)
                        dset_loss.extend(loss.numpy().flatten())
                        valid_meanloss.append(mean_loss)
                        valid_maxloss.append(max_loss)
                        print('mean loss %f, max loss %f' % (mean_loss, max_loss))

                    data_to_plot.append(dset_loss)

                # save a box plot of per-joint errors for the three test datasets
                fig = plt.figure('epoch: ' + str(epoch))
                ax = fig.add_subplot(111)
                ax.boxplot(data_to_plot)
                ax.set_xticklabels(self.testSet)
                fig.savefig(self.modelPath + 'epoch_' + str(epoch) + '.png', bbox_inches='tight')
                plt.close(fig)  # release the figure to avoid accumulating memory across epochs

                mean_valid_loss = torch.mean(torch.FloatTensor(valid_meanloss))
                max_valid_loss = torch.max(torch.FloatTensor(valid_maxloss))
                # we save the model if current validation loss is less than prev : validation.pth.tar
                if min_valid_loss == 0 or mean_valid_loss < min_valid_loss:
                    min_valid_loss = mean_valid_loss
                    state = {
                        'epoch': epoch + 1,
                        'state_dict': self.model.state_dict(),
                        'validation_loss': mean_valid_loss
                    }
                    torch.save(state, self.modelPath + 'validation.pth.tar')

                # logging to track progress; epoch_loss stays empty while the
                # training block above is commented out
                epoch_mean = torch.mean(torch.FloatTensor(epoch_loss)).item() if len(epoch_loss) > 0 else 0.0
                log = 'epoch No %d, epoch loss %f, validation mean loss %f, validation max loss %f, Time taken %f' % (
                    epoch + 1, epoch_mean, mean_valid_loss.item(), max_valid_loss.item(), time() - start_time)
                print(log)
                f.write(log + '\n')

        except KeyboardInterrupt:
            print('Training aborted.')
        finally:
            f.close()

    def _loss_impl(self, predicted, expected):
        # mean over the batch of the summed per-joint L2 errors
        diff = predicted - expected
        batch_size = predicted.shape[0]
        dist = torch.sum(torch.norm(diff, 2, 2))
        return dist / batch_size
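

# Minimal usage sketch; assumes the hard-coded data/model paths above exist on
# this machine. The window lengths passed to testWindow are illustrative only.
if __name__ == '__main__':
    engine = TestEngine()
    engine.test()
    # engine.testWindow(20, 5)
    # TrainingEngine().train(n_epochs=50)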