def __init__(self):
    """Store the hard-coded dataset names and paths and build the CUDA models.

    The encoder and decoder are moved to the GPU with ``.cuda()``, so a CUDA
    device is required at construction time.
    """
    # Width shared by the encoder units and the decoder units.
    hidden_units = 256

    # Dataset names kept as the train / test selections.
    self.trainset = ['H36_S1_S9', 'AMASS_HDM05']
    self.testset = ['H36_S11']

    # Data root (alternative kept from the original:
    # '/ds2/synthetic60FPS/synthetic60FPS/').
    self.datapath = '/data/Guha/GR/synthetic60FPS/'
    self.dataset = RawDataset()
    self.validset = RawDataset()
    self.use_cuda = True

    # Model directory (alternative kept from the original:
    # '/b_test/suparna/20/').
    self.modelPath = '/data/Guha/GR/model/attn/'

    self.encoder = Encoder(input_dim=24, enc_units=hidden_units).cuda()
    self.decoder = Decoder(
        output_dim=60,
        dec_units=hidden_units,
        enc_units=hidden_units,
    ).cuda()
    def __init__(self):
        """Build the CUDA encoder/decoder and restore their weights.

        Weights come from a fixed checkpoint file whose ``'encoder_dict'``
        and ``'decoder_dict'`` entries are loaded into the two modules.
        """
        hidden_units = 256

        self.datapath = '/data/Guha/GR/synthetic60FPS/'
        self.dataset = RawDataset()
        # Earlier model directory, kept for reference: '/data/Guha/GR/model/18/'
        self.encoder = Encoder(input_dim=24, enc_units=hidden_units).cuda()
        self.decoder = Decoder(
            output_dim=60,
            dec_units=hidden_units,
            enc_units=hidden_units,
        ).cuda()

        # Output directory prefix for saved results.
        self.base = '/data/Guha/GR/Output/TestSet/attn/'

        checkpoint_file = '/data/Guha/GR/model/attn_greedy/epoch_6.pth.tar'
        with open(checkpoint_file, 'rb') as handle:
            saved = torch.load(handle)
        self.encoder.load_state_dict(saved['encoder_dict'])
        self.decoder.load_state_dict(saved['decoder_dict'])
def attention_lstm(units, steps, input_dims):
    """Build a functional model: per-step attention ``Encoder`` + one shared LSTM.

    Args:
        units: size of the LSTM state and of the ``h0`` / ``c0`` inputs.
        steps: number of time steps; one ``Encoder`` layer is created per step.
        input_dims: feature size of each time step of the ``inputs`` tensor.

    Returns:
        A ``Model`` taking ``[inputs, h0, c0]`` and producing the per-step LSTM
        hidden states concatenated along the last axis.
    """
    # Symbolic inputs: the sequence and the initial hidden / cell state.
    seq_in = Input(shape=(steps, input_dims), name='inputs')
    init_hidden = Input(shape=(units, ), name='h0')
    init_cell = Input(shape=(units, ), name='c0')
    initial_state = [init_hidden, init_cell]

    # Layers shared by every step.
    masking = Masking(mask_value=-1.)
    shared_lstm = LSTM(units,
                       activation='relu',
                       return_sequences=True,
                       return_state=True)

    masked = masking(seq_in)

    # NOTE(review): every iteration is fed the *initial* h0/c0 and the LSTM's
    # returned cell state is discarded, so no state is carried from one step
    # to the next - confirm this is intended.
    step_states = []
    for step in range(steps):
        # The attention weights are only needed by the debugging variant
        # Model(inputs=[inputs, h, c], outputs=[attention, x1]).
        _attention, attended = Encoder(units, step)(masked,
                                                    states=initial_state)
        _sequence, step_hidden, _step_cell = shared_lstm(
            attended, initial_state=initial_state)
        step_states.append(step_hidden)

    merged = Concatenate(axis=-1)(step_states)
    return Model(inputs=[seq_in, init_hidden, init_cell], outputs=merged)
    def __init__(self):
        """Store dataset names, paths, CUDA models and the teacher-forcing schedule.

        The encoder and decoder are moved to the GPU with ``.cuda()``, so a
        CUDA device is required at construction time.
        """
        hidden_units = 256

        # Dataset names kept as the train / test selections.
        self.trainset = ['H36_S1_S9', 'AMASS_HDM05']
        self.testset = ['H36_S11']

        # Data root (alternative kept from the original:
        # '/ds2/synthetic60FPS/synthetic60FPS/').
        self.datapath = '/data/Guha/GR/synthetic60FPS/'
        self.dataset = RawDataset()
        self.validset = RawDataset()
        self.use_cuda = True

        # Model directory (alternative kept from the original:
        # '/b_test/suparna/20/').
        self.modelPath = '/data/Guha/GR/model/attn_greedy/'

        self.encoder = Encoder(input_dim=24, enc_units=hidden_units).cuda()
        self.decoder = Decoder(
            output_dim=60,
            dec_units=hidden_units,
            enc_units=hidden_units,
        ).cuda()

        # Optional warm start from an earlier checkpoint (disabled):
        # baseModelPath = '/data/Guha/GR/model/19/epoch_3.pth.tar'
        #
        # with open(baseModelPath, 'rb') as tar:
        #     checkpoint = torch.load(tar)
        #     self.encoder.load_state_dict(checkpoint['encoder_dict'])
        #     self.decoder.load_state_dict(checkpoint['decoder_dict'])

        # Teacher-forcing schedule: starting rate, per-step decrement, floor.
        self.tf_rate = 1
        self.tf_decay = .0001
        self.tf_low = 0.1
class TrainingEngine:
    """Trains the encoder/decoder pair with greedy decoding and teacher forcing.

    Dataset names, paths and hyper-parameters are hard-coded in ``__init__``.
    The teacher-forcing rate starts at ``tf_rate`` and is lowered by
    ``tf_decay`` after every training batch until it reaches ``tf_low``.
    """

    def __init__(self):
        """Create datasets and CUDA models; requires a CUDA device (``.cuda()``)."""
        self.trainset = ['H36_S1_S9','AMASS_HDM05']
        self.testset = ['H36_S11']
        self.datapath = '/data/Guha/GR/synthetic60FPS/'
        #self.datapath = '/ds2/synthetic60FPS/synthetic60FPS/'
        self.dataset = RawDataset()
        self.validset = RawDataset()
        self.use_cuda = True
        self.modelPath = '/data/Guha/GR/model/attn_greedy/'
        #self.modelPath = '/b_test/suparna/20/'
        self.encoder = Encoder(input_dim=24, enc_units=256).cuda()
        self.decoder = Decoder(output_dim=60, dec_units=256, enc_units=256).cuda()
        # Optional warm start from an earlier checkpoint (disabled):
        # baseModelPath = '/data/Guha/GR/model/19/epoch_3.pth.tar'
        #
        # with open(baseModelPath, 'rb') as tar:
        #     checkpoint = torch.load(tar)
        #     self.encoder.load_state_dict(checkpoint['encoder_dict'])
        #     self.decoder.load_state_dict(checkpoint['decoder_dict'])

        # Teacher-forcing schedule: per-batch decrement, current rate, floor.
        self.tf_decay = .0001
        self.tf_rate = 1
        self.tf_low = 0.1

    def train(self, n_epochs):
        """Train for ``n_epochs`` epochs, validating and checkpointing each one.

        Files written under ``self.modelPath``:
          * ``model_details``     - text log (model repr, per-epoch losses)
          * ``epoch_{k}.pth.tar`` - weights after every epoch
          * ``validation.pth.tar``- weights of the best validation loss so far
          * ``error.pth.tar``     - weights saved when interrupted with Ctrl-C

        Args:
            n_epochs: number of passes over the training set.
        """
        group_sz = 10
        # ``with`` closes (and flushes) the log on every exit path, including
        # the KeyboardInterrupt branch that previously leaked the handle.
        with open(self.modelPath + 'model_details', 'w') as f:
            f.write(str(self.encoder))
            f.write('\n')
            f.write(str(self.decoder))
            f.write('\n')
            np.random.seed(1234)
            optimizer = optim.Adam(
                list(self.encoder.parameters()) + list(self.decoder.parameters()),
                lr=0.001)

            # None means "no validation result yet".
            min_valid_loss = None
            print('Training for {} epochs'.format(n_epochs))
            self.dataset.loadfiles(self.datapath, self.trainset)
            print('total no of files {}'.format(len(self.dataset.files)))
            f.write('total no of files {} \n'.format(len(self.dataset.files)))

            # Defined up front so the interrupt handler can always read it.
            epoch = -1
            try:
                ################ epoch loop ###################
                epoch_loss = {'train': [], 'validation': []}
                for epoch in range(n_epochs):
                    start_time = time()
                    ####################### training #######################
                    self.encoder.train()
                    self.decoder.train()
                    train_loss = []
                    self.dataset.loadfiles(self.datapath, self.trainset)
                    while len(self.dataset.files) > 0:
                        self.dataset.prepareBatchOfMotion(group_sz)
                        # divide the data into chunks of seq_len along dim 1
                        chunk_in = list(torch.split(self.dataset.input, cfg.seq_len, dim=1))
                        chunk_target = list(torch.split(self.dataset.target, cfg.seq_len, dim=1))
                        if not chunk_in:
                            continue
                        print('chunk list size', len(chunk_in))

                        # NOTE(review): all chunks are encoded before the first
                        # optimizer step, so later chunks back-propagate through
                        # encoder activations computed with pre-step weights -
                        # confirm this is intended.
                        enc_output, enc_hidden = self._encode_chunks(chunk_in)
                        dec_input = self._initial_decoder_input(chunk_in[0].shape[0])

                        # decode chunk by chunk; one optimizer step per chunk
                        for c_enc_out, c_enc_hidden, c_target in zip(enc_output, enc_hidden, chunk_target):
                            loss, n_steps, dec_input = self._decode_chunk(
                                dec_input, c_enc_out, c_enc_hidden, c_target)
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()
                            train_loss.append(loss.item() / n_steps)

                        # decrease teacher forcing over time.
                        self._decay_teacher_forcing()

                    train_loss = torch.mean(torch.FloatTensor(train_loss))
                    epoch_loss['train'].append(train_loss)
                    # we save the model after each epoch : epoch_{}.pth.tar
                    state = {
                        'epoch': epoch + 1,
                        'encoder_dict': self.encoder.state_dict(),
                        'decoder_dict': self.decoder.state_dict(),
                        'epoch_loss': train_loss
                    }
                    torch.save(state, self.modelPath + 'epoch_{}.pth.tar'.format(epoch + 1))
                    # elapsed = now - start (the reverse order logged a negative time)
                    debug_string = 'epoch No {}, epoch loss {}, Time taken {} \n'.format(
                        epoch + 1, epoch_loss, time() - start_time
                    )
                    print(debug_string)
                    f.write(debug_string)
                    f.write('\n')

                    ####################### Validation #######################
                    # Evaluation mode + no_grad: no training-only layer
                    # behaviour and no autograd graph while validating.
                    self.encoder.eval()
                    self.decoder.eval()
                    self.validset.loadfiles(self.datapath, self.testset)
                    valid_loss = []
                    with torch.no_grad():
                        while len(self.validset.files) > 0:
                            self.validset.prepareBatchOfMotion(10)
                            chunk_in = list(torch.split(self.validset.input, cfg.seq_len, dim=1))
                            chunk_target = list(torch.split(self.validset.target, cfg.seq_len, dim=1))
                            if not chunk_in:
                                continue
                            enc_output, enc_hidden = self._encode_chunks(chunk_in)
                            dec_input = self._initial_decoder_input(chunk_in[0].shape[0])
                            for c_enc_out, c_enc_hidden, c_target in zip(enc_output, enc_hidden, chunk_target):
                                loss, n_steps, dec_input = self._decode_chunk(
                                    dec_input, c_enc_out, c_enc_hidden, c_target)
                                valid_loss.append(loss.item() / n_steps)

                    valid_loss = torch.mean(torch.FloatTensor(valid_loss))
                    epoch_loss['validation'].append(valid_loss)
                    # we save the model if current validation loss is less than prev : validation.pth.tar
                    if min_valid_loss is None or valid_loss < min_valid_loss:
                        min_valid_loss = valid_loss
                        state = {
                            'epoch': epoch + 1,
                            'encoder_dict': self.encoder.state_dict(),
                            'decoder_dict': self.decoder.state_dict(),
                            'validation_loss': valid_loss
                        }
                        torch.save(state, self.modelPath + 'validation.pth.tar')

                    # logging to track
                    debug_string = 'epoch No {}, validation loss {}, Time taken {} \n'.format(
                        epoch + 1, valid_loss, time() - start_time
                    )
                    print(debug_string)
                    f.write(debug_string)
                    f.write('\n')

                f.write('{}'.format(epoch_loss))
            except KeyboardInterrupt:
                state = {
                    'epoch': epoch + 1,
                    'encoder_dict': self.encoder.state_dict(),
                    'decoder_dict': self.decoder.state_dict(),

                }
                torch.save(state, self.modelPath + 'error.pth.tar')
                print('Training aborted.')

    def _encode_chunks(self, chunk_in):
        """Encode every chunk; return parallel lists (encoder outputs, hidden states)."""
        enc_output = []
        enc_hidden = []
        for c_in in chunk_in:
            c_enc_out, c_enc_hidden = self.encoder(c_in)
            enc_output.append(c_enc_out)
            enc_hidden.append(c_enc_hidden)
        return enc_output, enc_hidden

    def _initial_decoder_input(self, batch_sz):
        """Return the first decoder input: [1, 0, 0, 0] x 15 per sample, shape (batch_sz, 1, 60), on CUDA."""
        tpose = np.array([[1, 0, 0, 0] * 15] * batch_sz)
        return torch.FloatTensor(tpose.reshape(batch_sz, 1, 60)).cuda()

    def _decode_chunk(self, dec_input, c_enc_out, c_enc_hidden, c_target):
        """Decode one chunk step by step with scheduled teacher forcing.

        Args:
            dec_input: decoder input for the first step of this chunk.
            c_enc_out: encoder output of this chunk.
            c_enc_hidden: encoder hidden state, used as initial decoder state.
            c_target: target chunk; time is indexed along dim 1.

        Returns:
            ``(loss, n_steps, dec_input)``: the loss summed over the steps,
            the number of steps, and the input to feed to the next chunk.
        """
        dec_hidden = c_enc_hidden
        loss = 0.0
        n_steps = c_target.shape[1]
        for t in range(n_steps):
            pred_t, dec_hidden = self.decoder(dec_input, dec_hidden, c_enc_out)
            # np.random.random() is in [0.0, 1.0): feed the ground truth with
            # probability ~tf_rate, otherwise the model's own (detached) output.
            if np.random.random() > 1 - self.tf_rate:
                dec_input = c_target[:, t].unsqueeze(1)
            else:
                dec_input = pred_t.detach().unsqueeze(1)
            loss += self._loss_impl(c_target[:, t], pred_t)
        return loss, n_steps, dec_input

    def _decay_teacher_forcing(self):
        """Lower ``tf_rate`` by ``tf_decay``, never going below ``tf_low``."""
        self.tf_rate = max(self.tf_rate - self.tf_decay, self.tf_low)

    def _loss_impl(self, predicted, expected):
        """Return the mean over rows of the L2 norm (dim 1) of ``predicted - expected``."""
        diff = predicted - expected
        return torch.mean((torch.norm(diff, 2, 1)))
class TestEngine:
    """Runs a trained encoder/decoder over test files and saves the predictions."""

    def __init__(self):
        """Build the CUDA models and restore weights from a fixed checkpoint."""
        self.datapath = '/data/Guha/GR/synthetic60FPS/'
        self.dataset = RawDataset()
        #self.modelPath = '/data/Guha/GR/model/18/'
        self.encoder = Encoder(input_dim=24, enc_units=256).cuda()
        self.decoder = Decoder(output_dim=60, dec_units=256,
                               enc_units=256).cuda()
        baseModelPath = '/data/Guha/GR/model/attn_greedy/epoch_6.pth.tar'
        # Output prefix for the saved .npz results.
        self.base = '/data/Guha/GR/Output/TestSet/attn/'

        with open(baseModelPath, 'rb') as tar:
            checkpoint = torch.load(tar)
            self.encoder.load_state_dict(checkpoint['encoder_dict'])
            self.decoder.load_state_dict(checkpoint['decoder_dict'])

    def test(self):
        """Predict every file of the 'H36' set and save target + predictions.

        For each file an ``.npz`` named after the file's basename is written
        under ``self.base`` with arrays ``target`` and ``predictions``;
        ``predictions`` holds unit-normalised quaternions, one per row.
        Ctrl-C stops the run and prints a message.
        """
        try:
            self.dataset.loadfiles(self.datapath, ['H36'])
            self.encoder.eval()
            self.decoder.eval()
            for file in self.dataset.files:
                ################# test on synthetic data #########################
                # (readDIPfile / readDFKIfile are the alternatives for
                # DIP_IMU and DFKI recordings)
                chunk_in, chunk_target = self.dataset.readfile(
                    file, cfg.seq_len)

                if (len(chunk_in) == 0):
                    continue
                print('chunk list size', len(chunk_in))

                # Inference only: without no_grad the autograd graph would keep
                # growing because each prediction is fed back as next input.
                with torch.no_grad():
                    # pass all the chunks through encoder and accumulate
                    # outputs and hidden states in lists
                    enc_output = []
                    enc_hidden = []
                    for c_in in chunk_in:
                        # add a batch dimension of size 1
                        c_enc_out, c_enc_hidden = self.encoder(c_in.unsqueeze(0))
                        enc_output.append(c_enc_out)
                        enc_hidden.append(c_enc_hidden)

                    # The decoder is seeded with the first ground-truth frame.
                    # (The T-pose seed that used to be built here was
                    # overwritten before use and has been removed.)
                    dec_input = chunk_target[0][0, :].reshape(1, 1, 60)

                    # decode all chunks sequentially, feeding back predictions
                    predictions = []
                    for c_enc_out, c_enc_hidden, c_target in zip(
                            enc_output, enc_hidden, chunk_target):
                        dec_hidden = c_enc_hidden
                        for _ in range(c_target.shape[0]):
                            pred_t, dec_hidden = self.decoder(
                                dec_input, dec_hidden, c_enc_out)
                            dec_input = pred_t.unsqueeze(1)
                            predictions.append(pred_t.detach().cpu().numpy())

                target = torch.cat(chunk_target).detach().cpu().numpy()
                predictions = self._normalize_quaternions(
                    np.asarray(predictions).reshape(-1, 15, 4))
                np.savez_compressed(self.base + file.split('/')[-1],
                                    target=target,
                                    predictions=predictions)

        except KeyboardInterrupt:
            print('Testing aborted.')

    @staticmethod
    def _normalize_quaternions(predictions):
        """Scale every quaternion to unit length.

        The earlier code divided all quaternions by the norm of the very first
        one (``norms[0, 0]``); each one is now divided by its own norm.

        Args:
            predictions: array of shape (frames, 15, 4).

        Returns:
            Array of shape (frames * 15, 4), rows in frame-major order (the
            same layout the earlier code produced). A zero quaternion yields
            NaNs.
        """
        norms = np.linalg.norm(predictions, axis=2, keepdims=True)
        return (predictions / norms).reshape(-1, 4)

    def _loss_impl(self, predicted, expected):
        """Return the mean over rows of the L2 norm (dim 1) of ``predicted - expected``."""
        diff = predicted - expected
        return torch.mean((torch.norm(diff, 2, 1)))