class task_NER():
    def __init__(self):
        self.name = "NER_task_bieos"

        # Controller Params
        self.controller_size = 128
        self.controller_layers = 1

        # Head Params
        self.num_read_heads = 2
        self.num_write_heads = 2

        # Processor Params
        self.num_inputs = 200  # Length of Embeddings
        self.num_outputs = 13  # Class size

        # Memory Params
        self.memory_N = 128
        self.memory_M = 128

        # Training Params
        self.num_batches = -1
        self.save_batch = 500  # Saving model after every save_batch number of batches
        self.batch_size = 10
        self.num_epoch = 1

        # Optimizer Params
        self.adam_lr = 1e-4
        self.adam_betas = (0.9, 0.999)
        self.adam_eps = 1e-8

        # Handles
        self.machine = None
        self.loss = None
        self.optimizer = None

        # Class Dictionaries
        self.labelDict = None  # Label Dictionary - Labels to Index
        self.reverseDict = None  # Inverse Label Dictionary - Index to Labels

        # File Paths
        self.concept_path_train = "/media/ramkabir/PC Data/ASU Data/Semester 3/BMNLP/Projects/medical_data/train_data/concept"  # Path to train concept files
        self.text_path_train = "/media/ramkabir/PC Data/ASU Data/Semester 3/BMNLP/Projects/medical_data/train_data/txt"  # Path to train text summaries
        self.concept_path_test = "/media/ramkabir/PC Data/ASU Data/Semester 3/BMNLP/Projects/medical_data/test_data/concept"  # Path to test concept files
        self.text_path_test = "/media/ramkabir/PC Data/ASU Data/Semester 3/BMNLP/Projects/medical_data/test_data/txt"  # Path to test text summaries
        self.save_path = "/media/ramkabir/PC Data/ASU Data/Semester 3/BMNLP/Projects/medical_data/cleaned_files"  # Save path
        self.embed_dic_path = "/media/ramkabir/PC Data/ASU Data/Semester 3/BMNLP/Projects/medical_data/embeddings/bio_embedding_dictionary.dat"  # Word2Vec embeddings Dictionary path
        self.random_vec = "/media/ramkabir/PC Data/ASU Data/Semester 3/BMNLP/Projects/medical_data/embeddings/random_vec.dat"  # Path to random embedding (Used to create new vectors)

        # Miscellaneous
        self.padding_symbol = np.full((self.num_inputs),
                                      0.01)  # Padding symbol embedding

    def get_task_name(self):
        return self.name

    def init_dnc(self):
        self.machine = DNC_Module(self.num_inputs, self.num_outputs,
                                  self.controller_size, self.controller_layers,
                                  self.num_read_heads, self.num_write_heads,
                                  self.memory_N, self.memory_M)

    def init_loss(self):
        self.loss = nn.CrossEntropyLoss(
            reduction='mean'
        )  # Cross Entropy Loss -> Softmax Activation + Cross Entropy Loss

    def init_optimizer(self):
        self.optimizer = optim.Adam(self.machine.parameters(),
                                    lr=self.adam_lr,
                                    betas=self.adam_betas,
                                    eps=self.adam_eps)

    def calc_loss(self, Y_pred, Y):
        # Y: dim -> (sequence_len x batch_size)
        # Y_pred: dim -> (sequence_len x batch_size x num_outputs)
        loss_vec = torch.empty(Y.shape[0], dtype=torch.float32)
        for i in range(Y_pred.shape[0]):
            loss_vec[i] = self.loss(Y_pred[i], Y[i])
        return torch.mean(loss_vec)

    def calc_cost(self, Y_pred, Y):  # Returns (number of correct predictions, total number of tokens)
        # Y: dim -> (sequence_len x batch_size)
        # Y_pred: dim -> (sequence_len x batch_size x num_outputs)
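        # argmax over the class dimension gives the predicted label per token;
        # comparing with Y counts the number of correct predictions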
        return torch.sum(((F.softmax(Y_pred, dim=2).max(2)[1]) == Y).type(
            torch.long)).item(), Y.shape[0] * Y.shape[1]

    def print_word(self,
                   token_class):  # Prints the Class name from Class number
        word = self.reverseDict[token_class]
        print(word + "\n")

    def clip_grads(self):  # Clipping gradients for stability
        """Gradient clipping to the range [10, 10]."""
        parameters = list(
            filter(lambda p: p.grad is not None, self.machine.parameters()))
        for p in parameters:
            p.grad.data.clamp_(-10, 10)

    def initialize_labels(
        self
    ):  # Initializing label dictionaries for Labels->IDX and IDX->Labels
        self.labelDict = {}  # Label Dictionary - Labels to Index
        self.reverseDict = {}  # Inverse Label Dictionary - Index to Labels

        # Using BIEOS labelling scheme
        self.labelDict['b-problem'] = 0  # Problem - Beginning
        self.labelDict['i-problem'] = 1  # Problem - Inside
        self.labelDict['e-problem'] = 2  # Problem - End
        self.labelDict['s-problem'] = 3  # Problem - Single
        self.labelDict['b-test'] = 4  # Test - Beginning
        self.labelDict['i-test'] = 5  # Test - Inside
        self.labelDict['e-test'] = 6  # Test - End
        self.labelDict['s-test'] = 7  # Test - Single
        self.labelDict['b-treatment'] = 8  # Treatment - Beginning
        self.labelDict['i-treatment'] = 9  # Treatment - Inside
        self.labelDict['e-treatment'] = 10  # Treatment - End
        self.labelDict['s-treatment'] = 11  # Treatment - Single
        self.labelDict['o'] = 12  # Outside Token
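        # Illustrative example (not from the data): a 3-word "problem" concept such as
        # "chest wall pain" would be tagged ['b-problem', 'i-problem', 'e-problem'],
        # while a single-word concept like "fever" would be tagged ['s-problem'].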

        # Making Inverse Label Dictionary
        for k in self.labelDict.keys():
            self.reverseDict[self.labelDict[k]] = k

        # Saving the dictionaries into a file
        self.save_data([self.labelDict, self.reverseDict],
                       os.path.join(self.save_path, "label_dicts_bieos.dat"))

    def parse_concepts(
            self, file_path
    ):  # Parses the concept file to extract concepts and labels
        conceptList = []  # Stores all the Concept in the File

        f = open(file_path)  # Opening and reading a concept file
        content = f.readlines()  # Reading all the lines in the concept file
        f.close()  # Closing the concept file
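        # Assumed format of each concept line (inferred from the parsing below):
        #   c="word1 word2 ..." start_line:start_word end_line:end_word||t="label"
        # e.g. (illustrative)  c="chest wall pain" 12:4 12:6||t="problem"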

        for x in content:  # Reading each line in the concept file
            dic = {}

            # Cleaning and extracting the entities, labels and their positions in the corresponding medical summaries
            x = re.sub('\n', ' ', x)
            x = re.sub(r'\ +', ' ', x)
            x = x.strip().split('||')

            temp1, label = x[0].split(' '), x[1].split('=')[1][1:-1]

            temp1[0] = temp1[0][3:]
            temp1[-3] = temp1[-3][0:-1]
            entity = temp1[0:-2]

            if len(entity) > 1:
                lab = ['i'] * len(entity)
                lab[0] = 'b'
                lab[-1] = 'e'
                lab = [l + "-" + label for l in lab]
            elif len(entity) == 1:
                lab = ["s" + "-" + label]
            else:
                print("Data in File: " + file_path +
                      ", not in expected format..")
                exit()

            noLab = [self.labelDict[l] for l in lab]
            sLine, sCol = int(temp1[-2].split(":")[0]), int(
                temp1[-2].split(":")[1])
            eLine, eCol = int(temp1[-1].split(":")[0]), int(
                temp1[-1].split(":")[1])
            '''
            # Printing the information
            print("------------------------------------------------------------")
            print("Entity: " + str(entity))
            print("Entity Label: " + label)
            print("Labels - BIEOS form: " + str(lab))
            print("Labels  Index: " + str(noLab))
            print("Start Line: " + str(sLine) + ", Start Column: " + str(sCol))
            print("End Line: " + str(eLine) + ", End Column: " + str(eCol))
            print("------------------------------------------------------------")
            '''

            # Storing the information as a dictionary
            dic['entity'] = entity  # Entity Name (In the form of list of words)
            dic['label'] = label  # Common Label
            dic['BIEOS_labels'] = lab  # List of BIEOS labels for each word
            dic['label_index'] = noLab  # Labels in the index form
            dic['start_line'] = sLine  # Start line of the concept in the corresponding text summaries
            dic['start_word_no'] = sCol  # Starting word number of the concept in the corresponding start line
            dic['end_line'] = eLine  # End line of the concept in the corresponding text summaries
            dic['end_word_no'] = eCol  # Ending word number of the concept in the corresponding end line

            # Appending the concept dictionary to the list
            conceptList.append(dic)

        return conceptList  # Returning the all the concepts in the current file in the form of dictionary list

    def parse_summary(self, file_path):  # Parses the Text summaries
        file_lines = []  # Stores the lines of the file in list form
        tags = [
        ]  # Stores corresponding labels for each word in the file (Default label: 'o' [Outside])
        default_label = len(
            self.labelDict
        ) - 1  # default_label is 12 (corresponding to the 'o' / Outside label)
        # counter = 1                           # Temporary variable used during print

        f = open(file_path)  # Opening and reading a text summary file
        content = f.readlines()  # Reading all the lines in the summary file
        f.close()

        for x in content:
            x = re.sub('\n', ' ', x)
            x = re.sub(r'\ +', ' ', x)
            file_lines.append(
                x.strip().split(" ")
            )  # Splitting the line into a word list and appending it to the file list
            tags.append(
                [default_label] * len(file_lines[-1])
            )  # Assigning the default_label to all the words in a line
            '''
            # Printing the information
            print("------------------------------------------------------------")
            print("File Lines No: " + str(counter))
            print(file_lines[-1])
            print("\nCorresponding labels:")
            print(tags[-1])
            print("------------------------------------------------------------")
            counter += 1
            '''
            assert len(tags[-1]) == len(
                file_lines[-1]
            ), "Line length is not matching labels length..."  # Sanity Check
        return file_lines, tags

    def modify_labels(
        self, conceptList, tags
    ):  # Modifies the default labels of each word in text files with the true labels from the concept files
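        # Illustrative example: a concept with start_line 3, start_word_no 5,
        # end_line 4 and end_word_no 2 overwrites tags[2][5:] with the first part
        # of its label indices and tags[3][0:3] with the remainder.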
        for e in conceptList:  # Iterating over all the dictionary elements in the Concept List
            if e['start_line'] == e[
                    'end_line']:  # Checking whether concept is spanning over a single line or multiple line in the summary
                tags[e['start_line'] - 1][e['start_word_no']:e['end_word_no'] +
                                          1] = e['label_index'][:]
            else:
                start = e['start_line']
                end = e['end_line']
                beg = 0
                for i in range(
                        start, end + 1
                ):  # Distributing labels over multiple lines in the text summaries
                    if i == start:
                        tags[i - 1][e['start_word_no']:] = e['label_index'][
                            0:len(tags[i - 1]) - e['start_word_no']]
                        beg = len(tags[i - 1]) - e['start_word_no']
                    elif i == end:
                        tags[i - 1][0:e['end_word_no'] +
                                    1] = e['label_index'][beg:]
                    else:
                        tags[i - 1][:] = e['label_index'][beg:beg +
                                                          len(tags[i - 1])]
                        beg = beg + len(tags[i - 1])
        return tags

    def print_data(self, file, file_lines, tags):  # Prints the given data
        counter = 1

        print("\n************ Printing details of the file: " + file +
              " ************\n")
        for x in file_lines:
            print(
                "------------------------------------------------------------")
            print("File Lines No: " + str(counter))
            print(x)
            print("\nCorresponding labels:")
            print([self.reverseDict[i] for i in tags[counter - 1]])
            print("\nCorresponding Label Indices:")
            print(tags[counter - 1])
            print(
                "------------------------------------------------------------")
            counter += 1

    def save_data(self, obj_list,
                  s_path):  # Saves the objects into a binary file using Pickle
        # Note: 'obj_list' must be a list
        pickle.dump(tuple(obj_list), open(s_path, 'wb'))

    def acquire_data(
        self, task
    ):  # Read all the concept files to get concepts and labels, process them and save them
        data = {
        }  # Dictionary to store all the data objects (conceptList, file_lines, tags) each indexed by file name

        if task == 'train':  # Determining the task type to assign the data path accordingly
            t_path = self.text_path_train
            c_path = self.concept_path_train
        else:
            t_path = self.text_path_test
            c_path = self.concept_path_test

        for f in os.listdir(t_path):
            f1 = f.split('.')[0] + ".con"
            if os.path.isfile(os.path.join(c_path, f1)):
                conceptList = self.parse_concepts(
                    os.path.join(c_path, f1)
                )  # Parsing concepts and labels from the corresponding concept file
                file_lines, tags = self.parse_summary(
                    os.path.join(t_path, f)
                )  # Parses the document summaries to get the written notes
                tags = self.modify_labels(
                    conceptList, tags
                )  # Modifies the default labels of each word with the true labels from the concept files
                data[f1] = [conceptList, file_lines,
                            tags]  # Storing each object in dictionary
                # self.print_data(f, file_lines, tags)                              # Printing the details
        return data

    def structure_data(
            self, data_dict):  # Structures the data in proper trainable form
        final_line_list = [
        ]  # Stores words of all the files in separate sub-lists
        final_tag_list = [
        ]  # Stores tags of all the files in separate sub-lists

        for k in data_dict.keys(
        ):  # Extracting data from each pre-processed file in dictionary
            file_lines = data_dict[k][1]  # Extracting story
            tags = data_dict[k][2]  # Extracting corresponding labels

            # Creating empty lists
            temp1 = []
            temp2 = []

            # Merging all the lines in file into a single list. Same for corresponding labels
            for i in range(len(file_lines)):
                temp1.extend(file_lines[i])
                temp2.extend(tags[i])

            assert len(temp1) == len(
                temp2
            ), "Word length not matching Label length for story in " + str(
                k)  # Sanity Check

            final_line_list.append(temp1)
            final_tag_list.append(temp2)

        assert len(final_line_list) == len(
            final_tag_list
        ), "Number of stories not matching number of labels list"  # Sanity Check
        return final_line_list, final_tag_list

    def padding(
        self, line_list, tag_list
    ):  # Pads stories with padding symbol to make them of same length
        diff = 0
        max_len = 0
        outside_class = len(
            self.labelDict) - 1  # Classifying padding symbol as "outside" term

        # Calculating Max Summary Length
        for i in range(len(line_list)):
            if len(line_list[i]) > max_len:
                max_len = len(line_list[i])

        for i in range(len(line_list)):
            diff = max_len - len(line_list[i])
            line_list[i].extend([self.padding_symbol] * diff)
            tag_list[i].extend([outside_class] * diff)
            assert (len(line_list[i]) == max_len) and (len(
                line_list[i]) == len(
                    tag_list[i])), "Padding unsuccessful"  # Sanity check
        return np.asarray(line_list), np.asarray(
            tag_list
        )  # Making NumPy array of size (batch_size x story_length x word size) and (batch_size x story_length x 1) respectively

    def embed_input(self, line_list):  # Converts words to vector embeddings
        final_list = []  # Stores embedded words
        summary = None  # Temp variable
        word = None  # Temp variable
        temp = None  # Temp variable

        embed_dic = pickle.load(
            open(self.embed_dic_path,
                 'rb'))  # Loading word2vec dictionary using Pickle
        r_embed = pickle.load(open(self.random_vec,
                                   'rb'))  # Loading Random embedding

        for i in range(len(line_list)):  # Iterating over all the summaries
            summary = line_list[i]
            final_list.append([])  # Reserving space for the current summary

            for j in range(len(summary)):
                word = summary[j].lower()
                if word in embed_dic:  # Checking for existence of word in dictionary
                    final_list[-1].append(embed_dic[word])
                else:
                    temp = r_embed[:]  # Copying the values of the list
                    random.shuffle(
                        temp
                    )  # Randomly shuffling the word embedding to make it unique
                    temp = np.asarray(
                        temp, dtype=np.float32)  # Converting to NumPy array
                    final_list[-1].append(temp)
        return final_list

    def prepare_data(self, task='train'):  # Preparing all the data necessary
        line_list, tag_list = None, None
        '''
        line_list is the list of rows, where each row is a list of all the words in a medical summary
        Similar is the case for tag_list, except, it stores labels for each words
        '''

        if not os.path.exists(self.save_path):
            os.mkdir(
                self.save_path
            )  # Creating the save directory if it does not exist

        if not os.path.exists(
                os.path.join(self.save_path, "label_dicts_bieos.dat")):
            self.initialize_labels()  # Initialize label to index dictionaries
        else:
            self.labelDict, self.reverseDict = pickle.load(
                open(os.path.join(self.save_path, "label_dicts_bieos.dat"),
                     'rb'))  # Loading Label dictionaries

        if not os.path.exists(
                os.path.join(self.save_path,
                             "object_dict_bieos_" + str(task) + ".dat")):
            data_dict = self.acquire_data(task)  # Read data from file
            line_list, tag_list = self.structure_data(
                data_dict)  # Structures the data into proper form
            line_list = self.embed_input(
                line_list)  # Embeds input data (words) into embeddings
            self.save_data([line_list, tag_list],
                           os.path.join(
                               self.save_path,
                               "object_dict_bieos_" + str(task) + ".dat"))
        else:
            line_list, tag_list = pickle.load(
                open(
                    os.path.join(self.save_path,
                                 "object_dict_bieos_" + str(task) + ".dat"),
                    'rb'))  # Loading Data dictionary
        return line_list, tag_list

    def get_data(self, task='train'):
        line_list, tag_list = self.prepare_data(task)

        # Shuffling stories
        story_idx = list(range(0, len(line_list)))
        random.shuffle(story_idx)

        num_batch = int(len(story_idx) / self.batch_size)
        self.num_batches = num_batch

        # Out Data
        x_out = []
        y_out = []

        counter = 1

        for i in story_idx:
            if num_batch <= 0:
                break

            x_out.append(line_list[i])
            y_out.append(tag_list[i])

            if counter % self.batch_size == 0:
                counter = 0

                # Padding and converting labels to one hot vectors
                x_out_pad, y_out_pad = self.padding(x_out, y_out)
                x_out_array = torch.tensor(
                    x_out_pad.swapaxes(0, 1), dtype=torch.float32
                )  # Converting from (batch_size x story_length x word size) to (story_length x batch_size x word size)
                y_out_array = torch.tensor(
                    y_out_pad.swapaxes(0, 1), dtype=torch.long
                )  # Converting from (batch_size x story_length) to (story_length x batch_size)

                x_out = []
                y_out = []
                num_batch -= 1

                yield (self.num_batches - num_batch), x_out_array, y_out_array
            counter += 1

    def train_model(self):
        # The model is optimized with Cross Entropy Loss, but evaluated with token-level accuracy (calc_cost)
        loss_list = []
        seq_length = []
        last_batch = 0

        for j in range(self.num_epoch):
            for batch_num, X, Y in self.get_data(task='train'):
                self.optimizer.zero_grad(
                )  # Making old gradients zero before calculating the fresh ones
                self.machine.initialization(
                    self.batch_size)  # Initializing states
                Y_out = torch.empty(
                    (X.shape[0], X.shape[1], self.num_outputs),
                    dtype=torch.float32
                )  # dim: (seq_len x batch_size x num_output)

                # Each input step is fed to the DNC together with an embedding
                # computed from the reverse direction (backward_prediction), and
                # the per-step outputs are trained to match the labels

                embeddings = self.machine.backward_prediction(
                    X
                )  # Creating embeddings from data for backward calculation
                temp_size = X.shape[0]

                for i in range(temp_size):
                    Y_out[i, :, :], _ = self.machine(
                        X[i],
                        embeddings[temp_size - i -
                                   1])  # Passing Embeddings from backwards

                loss = self.calc_loss(Y_out, Y)
                loss.backward()
                self.clip_grads()
                self.optimizer.step()

                corr, tot = self.calc_cost(Y_out, Y)

                loss_list += [loss.item()]
                seq_length += [Y.shape[0]]

                if (batch_num % self.save_batch) == 0:
                    self.save_model(j, batch_num)

                last_batch = batch_num
                print("Epoch: " + str(j) + "/" + str(self.num_epoch) +
                      ", Batch: " + str(batch_num) + "/" +
                      str(self.num_batches) + ", Loss: " + str(loss.item()) +
                      ", Batch Accuracy: " +
                      str((float(corr) / float(tot)) * 100.0) + " %")
            self.save_model(j, last_batch)

    def test_model(self):  # Testing the model
        correct = 0
        total = 0
        print("\n")

        for batch_num, X, Y in self.get_data(task='test'):
            self.machine.initialization(self.batch_size)  # Initializing states
            Y_out = torch.empty((X.shape[0], X.shape[1], self.num_outputs),
                                dtype=torch.float32
                                )  # dim: (seq_len x batch_size x num_output)

            # Each input step is fed to the DNC together with an embedding
            # computed from the reverse direction (backward_prediction), and
            # the per-step outputs are compared against the labels

            embeddings = self.machine.backward_prediction(
                X)  # Creating embeddings from data for backward calculation
            temp_size = X.shape[0]

            for i in range(temp_size):
                Y_out[i, :, :], _ = self.machine(X[i],
                                                 embeddings[temp_size - i - 1])

            corr, tot = self.calc_cost(Y_out, Y)

            correct += corr
            total += tot
            print("Test Example " + str(batch_num) + "/" +
                  str(self.num_batches) + " processed, Batch Accuracy: " +
                  str((float(corr) / float(tot)) * 100.0) + " %")

        accuracy = (float(correct) / float(total)) * 100.0
        print("\nOverall Accuracy: " + str(accuracy) + " %")
        return accuracy  # in %

    def save_model(self, curr_epoch, curr_batch):
        # 'start_epoch' and 'start_batch' below record the epoch and batch from which to resume training after the next model load
        # Note: it is recommended to resume from 'start_epoch' rather than from a specific ('start_epoch', 'start_batch') pair, because batches are formed randomly
        if not os.path.exists("Saved_Models/" + self.name):
            os.mkdir("Saved_Models/" + self.name)
        state_dic = {
            'task_name': self.name,
            'start_epoch': curr_epoch + 1,
            'start_batch': curr_batch + 1,
            'state_dict': self.machine.state_dict(),
            'optimizer_dic': self.optimizer.state_dict()
        }
        filename = "Saved_Models/" + self.name + "/" + self.name + "_" + str(
            curr_epoch) + "_" + str(curr_batch) + "_saved_model.pth.tar"
        torch.save(state_dic, filename)

    def load_model(self, option, epoch, batch):
        path = "Saved_Models/" + self.name + "/" + self.name + "_" + str(
            epoch) + "_" + str(batch) + "_saved_model.pth.tar"
        if option == 1:  # Loading for training
            checkpoint = torch.load(path)
            self.machine.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_dic'])
        else:  # Loading for testing
            checkpoint = torch.load(path)
            self.machine.load_state_dict(checkpoint['state_dict'])
            self.machine.eval()
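    # Illustrative usage sketch (assumes a checkpoint saved at epoch 0, batch 500 exists):
    #   task = task_NER(); task.init_dnc(); task.init_loss(); task.init_optimizer()
    #   task.load_model(option=1, epoch=0, batch=500)  # resume training
    #   task.load_model(option=2, epoch=0, batch=500)  # load weights for evaluation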
class task_babi():
    def __init__(self):
        self.name = "bAbI_task"
        self.controller_size = 128
        self.controller_layers = 1
        self.num_read_heads = 2
        self.num_write_heads = 2
        self.sequence_width = -1  # Length of each word
        self.sequence_len = -1  # Word length of each story
        self.memory_N = 128
        self.memory_M = 128
        self.num_batches = -1
        self.num_epoch = 1
        self.batch_size = 1
        self.adam_lr = 1e-4
        self.adam_betas = (0.9, 0.999)
        self.adam_eps = 1e-8
        self.machine = None
        self.loss = None
        self.optimizer = None
        self.ind_to_word = None
        self.data_dir = "Data/bAbI/en-10k"  # Data directory

    def get_task_name(self):
        return self.name

    def init_dnc(self):
        if not os.path.isfile("Data/sequence_width.txt"):
            self.read_data()  # To set the sequence width
        else:
            self.sequence_width = pickle.load(
                open("Data/sequence_width.txt",
                     'rb'))  # To set the sequence width

        self.machine = DNC_Module(self.sequence_width, self.sequence_width,
                                  self.controller_size, self.controller_layers,
                                  self.num_read_heads, self.num_write_heads,
                                  self.memory_N, self.memory_M)

    def init_loss(self):
        self.loss = nn.CrossEntropyLoss(
            reduction='none'
        )  # Cross Entropy Loss -> Softmax Activation + Cross Entropy Loss

    def init_optimizer(self):
        self.optimizer = optim.Adam(self.machine.parameters(),
                                    lr=self.adam_lr,
                                    betas=self.adam_betas,
                                    eps=self.adam_eps)

    def calc_loss(self, Y_pred, Y, mask):
        # Y: dim -> (sequence_len x batch_size)
        # Y_pred: dim -> (sequence_len x batch_size x sequence_width)
        # mask: dim -> (sequence_len x batch_size)

        loss_vec = torch.empty(Y.shape, dtype=torch.float32)

        for i in range(Y_pred.shape[0]):
            loss_vec[i, :] = self.loss(Y_pred[i], Y[i])

        return torch.sum(loss_vec * mask) / torch.sum(mask)

    def calc_cost(self, Y_pred, Y, mask):  # Returns (number of correct predictions, total number of answer tokens)
        # Y: dim -> (sequence_len x batch_size)
        # Y_pred: dim -> (sequence_len x batch_size x sequence_width)
        # mask: dim -> (sequence_len x batch_size)
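        # Only positions where mask == 1 (answer positions) contribute to the counts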
        return torch.sum(
            ((F.softmax(Y_pred, dim=2).max(2)[1]) == Y).type(torch.long) *
            mask.type(torch.long)).item(), torch.sum(mask).item()

    def print_word(self, word_vec):  # Prints the word from word vector
        # "word_vect" dimension : (1 x sequence_width)
        idx = np.argmax(word_vec, axis=1)
        word = self.ind_to_word[idx]
        print(word + "\n")

    def to_one_hot(self, story):  # Converts a vector into one hot form
        out_token = []

        I = np.eye(self.sequence_width)
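        # Illustrative: with sequence_width = 4, index 2 maps to the one-hot row [0., 0., 1., 0.]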
        for idx in story:
            out_token.append(I[int(idx)])

        if len(out_token) > self.sequence_len:
            self.sequence_len = len(out_token)
        return out_token

    def padding_labels(
        self, stories
    ):  # A separate function to pad labels, since labels are class indices rather than one-hot vectors (as required by PyTorch's CrossEntropyLoss)
        padded_stories = []
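        # Pads with index 1, which corresponds to the pad symbol '*' defined in read_data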

        for story in stories:
            if len(story) < self.sequence_len:
                li = [1 for i in range(self.sequence_len - len(story))]
                story.extend(li)
            padded_stories.append(np.asarray(story, dtype=np.int64))
        return padded_stories

    def padding(
        self, stories
    ):  # Appends padding elements to make all the stories the same length
        padded_stories = []

        for story in stories:
            if len(story) < self.sequence_len:
                li = self.to_one_hot(np.ones(self.sequence_len - len(story)))
                story.extend(li)
            padded_stories.append(np.asarray(story, dtype=np.float32))
        return padded_stories

    def flatten_if_list(
        self, l
    ):  # Flattens any nested lists into the outer list. Example: ['you', '?', ['-']] -> ['you', '?', '-']
        newl = []
        for elem in l:
            if isinstance(
                    elem,
                    list):  # Checking whether the element is 'list' or not
                newl.extend(
                    elem
                )  # input.extend(li_2) method appends all the elements of 'li_2' list into 'input' list
            else:
                newl.append(elem)
        return newl

    def structure_data(self, x, y):  # Prepares data for bAbI task
        # Preparing  Data
        keys = list(x.keys())
        random.shuffle(keys)  # Randomly Shuffling the key list

        inp_story = []
        out_story = []

        for key in keys:
            inp_story.extend(x[key])
            out_story.extend(y[key])

        story_idx = list(range(0, len(inp_story)))
        random.shuffle(story_idx)

        # Stories are yielded batch by batch rather than materialized in a single list, to keep memory usage manageable
        num_batch = int(len(story_idx) / self.batch_size)
        self.num_batches = num_batch
        counter = 1

        # Out Data
        x_out = []
        y_out = []
        mask_inp = [
        ]  # Used to build the mask that makes non-answer output words from the DNC irrelevant to the loss

        for i in story_idx:
            if num_batch <= 0:
                break

            x_out.append(self.to_one_hot(inp_story[i]))
            y_out.append(out_story[i])
            mask_inp.append(
                inp_story[i])  # Appending input story For making the mask

            if counter % self.batch_size == 0:
                # Resetting Counter
                counter = 0

                # Padding
                x_out_array = torch.tensor(
                    np.asarray(self.padding(x_out)).swapaxes(0, 1)
                )  # Converting from (batch_size x story_length x word size) to (story_length x batch_size x word size)
                y_out_array = torch.tensor(
                    np.asarray(self.padding_labels(y_out)).swapaxes(0, 1),
                    dtype=torch.long
                )  # Converting from (batch_size x story_length) to (story_length x batch_size)
                m_inp_array = torch.tensor(
                    np.asarray(self.padding_labels(mask_inp)).swapaxes(0, 1),
                    dtype=torch.long
                )  # Converting from (batch_size x story_length) to (story_length x batch_size)

                # Renewing List and updating batch number
                x_out = []
                y_out = []
                mask_inp = []
                num_batch -= 1
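                # The yielded mask is 1 where the input token is the output symbol
                # '-' (index 0), i.e. at positions where the DNC must produce an answer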

                yield (self.num_batches -
                       num_batch), x_out_array, y_out_array, (
                           m_inp_array == 0).float()
            counter += 1

    def read_data(self):  # Reading and Cleaning data from the file
        storage_file = "Data/cleaned_data_bAbI_" + self.data_dir.split(
            '/')[2] + ".txt"

        if not os.path.isfile(storage_file):
            output_symbol = "-"  # Indicates an expectation of output to the DNC
            newstory_delimiter = " NEWSTORY "  # To separate stories
            pad_symbol = "*"  # Padding symbol

            file_paths = []

            word_to_ind = {
                output_symbol: 0,
                pad_symbol: 1
            }  # Dictionary to store indices of all the word in the bAbI dataset. Predefined symbols already stored
            all_input_stories = {}
            all_output_stories = {}

            # Making list of all the files in the data directory
            for f in os.listdir(self.data_dir):
                f_path = os.path.join(self.data_dir, f)
                if os.path.isfile(f_path):
                    file_paths.append(f_path)

            # Processing the text files
            for file_path in file_paths:
                # print(file_path)
                # Cleaning the text
                file = open(file_path).read().lower()
                file = re.sub("\n1 ", newstory_delimiter,
                              file)  # Adding a delimiter between two stories
                file = re.sub(r"\d+|\n|\t", " ",
                              file)  # Removing all numbers, newlines and tabs
                file = re.sub("([?.])", r" \1",
                              file)  # Adding a space before all punctuations
                stories = file.split(newstory_delimiter
                                     )  # Splitting whole text into the stories

                input_stories = [
                ]  # Stores the stories in index form, where each word has a unique index
                output_stories = []

                # Tokenizing the text
                for i, story in enumerate(stories):
                    input_tokens = story.split(
                    )  # Input stories are meant for inputting to the DNC
                    output_tokens = story.split(
                    )  # Output stories work as labels

                    for i, token in enumerate(
                            input_tokens
                    ):  # When a "?" is encountered, the answer token(s) that follow are replaced with "-" in the DNC input
                        if token == "?":
                            output_tokens[i + 1] = output_tokens[i +
                                                                 1].split(",")
                            input_tokens[i + 1] = [
                                output_symbol
                                for _ in range(len(output_tokens[i + 1]))
                            ]

                    input_tokens = self.flatten_if_list(input_tokens)
                    output_tokens = self.flatten_if_list(output_tokens)

                    # Calculating index of all the words
                    for token in output_tokens:
                        if token not in word_to_ind:
                            word_to_ind[token] = len(word_to_ind)

                    input_stories.append([
                        word_to_ind[elem] for elem in input_tokens
                    ])  # Storing each story into a list of word indices form
                    output_stories.append(
                        [word_to_ind[elem] for elem in output_tokens])

                all_input_stories[
                    file_path] = input_stories  # Storing all the stories for each file
                all_output_stories[file_path] = output_stories

            # Dumping all the cleaned data into a file
            pickle.dump((word_to_ind, all_input_stories, all_output_stories),
                        open(storage_file, 'wb'))
            pickle.dump(len(word_to_ind), open("Data/sequence_width.txt",
                                               'wb'))
            self.sequence_width = len(
                word_to_ind)  # Length of the one-hot word vectors
        else:
            word_to_ind, all_input_stories, all_output_stories = pickle.load(
                open(storage_file, 'rb'))
        return word_to_ind, all_input_stories, all_output_stories

    def get_training_data(self):
        word_to_ind, all_input_stories, all_output_stories = self.read_data()
        self.ind_to_word = {
            ind: word
            for word, ind in word_to_ind.items()
        }  # Reverse Index to Word dictionary to show final output

        # Separating Test and Train Data
        x_train_stories = {
            k: v
            for k, v in all_input_stories.items() if k[-9:] == "train.txt"
        }
        y_train_stories = {
            k: v
            for k, v in all_output_stories.items() if k[-9:] == "train.txt"
        }
        return self.structure_data(
            x_train_stories, y_train_stories
        )  # dim: x_train, y_train -> A list of (sequence_len x sequence_width) sized stories

    def get_test_data(self):  # Test data generator
        _, all_input_stories, all_output_stories = self.read_data()

        # Separating Test and Train Data
        x_test_stories = {
            k: v
            for k, v in all_input_stories.items() if k[-8:] == "test.txt"
        }
        y_test_stories = {
            k: v
            for k, v in all_output_stories.items() if k[-8:] == "test.txt"
        }
        return self.structure_data(
            x_test_stories, y_test_stories
        )  # dim: x_test, y_test -> A list of (sequence_len x sequence_width) sized stories

    def test_model(self):  # Testing the model
        correct = 0
        total = 0
        print("\n")

        for batch_num, X, Y, mask in self.get_test_data():
            self.machine.initialization(self.batch_size)  # Initializing states
            Y_out = torch.zeros(X.shape)

            # Each input step is fed to the DNC together with an embedding
            # computed from the reverse direction (backward_prediction), and
            # the per-step outputs are compared against the labels

            embeddings = self.machine.backward_prediction(
                X)  # Creating embeddings from data for backward calculation
            temp_size = X.shape[0]

            for i in range(temp_size):
                Y_out[i, :, :], _ = self.machine(X[i],
                                                 embeddings[temp_size - i - 1])

            corr, tot = self.calc_cost(Y_out, Y, mask)

            correct += corr
            total += tot
            print("Test Example " + str(batch_num) + "/" +
                  str(self.num_batches) + " processed, Batch Accuracy: " +
                  str((float(corr) / float(tot)) * 100.0) + " %")

        accuracy = (float(correct) / float(total)) * 100.0
        print("\nOverall Accuracy: " + str(accuracy) + " %")
        return accuracy  # in %

    def clip_grads(self):  # Clipping gradients for stability
        """Gradient clipping to the range [10, 10]."""
        parameters = list(
            filter(lambda p: p.grad is not None, self.machine.parameters()))
        for p in parameters:
            p.grad.data.clamp_(-10, 10)

    def train_model(self):
        # The model is optimized with masked Cross Entropy Loss, but evaluated with masked token accuracy (calc_cost)
        loss_list = []
        seq_length = []
        save_batch = 500
        last_batch = 0

        for j in range(self.num_epoch):
            for batch_num, X, Y, mask in self.get_training_data():
                self.optimizer.zero_grad(
                )  # Making old gradients zero before calculating the fresh ones
                self.machine.initialization(
                    self.batch_size)  # Initializing states
                Y_out = torch.zeros(X.shape)

                # Each input step is fed to the DNC together with an embedding
                # computed from the reverse direction (backward_prediction), and
                # the per-step outputs are trained to match the labels

                embeddings = self.machine.backward_prediction(
                    X
                )  # Creating embeddings from data for backward calculation
                temp_size = X.shape[0]

                for i in range(temp_size):
                    Y_out[i, :, :], _ = self.machine(
                        X[i],
                        embeddings[temp_size - i -
                                   1])  # Passing Embeddings from backwards

                loss = self.calc_loss(Y_out, Y, mask)
                loss.backward()
                self.clip_grads()
                self.optimizer.step()

                loss_list += [loss.item()]
                seq_length += [Y.shape[0]]

                if (batch_num % save_batch) == 0:
                    self.save_model(j, batch_num)

                last_batch = batch_num
                print("Epoch: " + str(j) + "/" + str(self.num_epoch) +
                      ", Batch: " + str(batch_num) + "/" +
                      str(self.num_batches) + ", Loss: " + str(loss.item()))
            self.save_model(j, last_batch)

    def save_model(self, curr_epoch, curr_batch):
        # 'start_epoch' and 'start_batch' below record the epoch and batch from which to resume training after the next model load
        # Note: it is recommended to resume from 'start_epoch' rather than from a specific ('start_epoch', 'start_batch') pair, because batches are formed randomly

        if not os.path.exists("Saved_Models/" + self.name):
            os.mkdir("Saved_Models/" + self.name)
        state_dic = {
            'task_name': self.name,
            'start_epoch': curr_epoch + 1,
            'start_batch': curr_batch + 1,
            'state_dict': self.machine.state_dict(),
            'optimizer_dic': self.optimizer.state_dict()
        }
        filename = "Saved_Models/" + self.name + "/" + self.name + "_" + str(
            curr_epoch) + "_" + str(curr_batch) + "_saved_model.pth.tar"
        torch.save(state_dic, filename)

    def load_model(self, option, epoch, batch):
        path = "Saved_Models/" + self.name + "/" + self.name + "_" + str(
            epoch) + "_" + str(batch) + "_saved_model.pth.tar"
        if option == 1:  # Loading for training
            checkpoint = torch.load(path)
            self.machine.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_dic'])
        else:  # Loading for testing
            checkpoint = torch.load(path)
            self.machine.load_state_dict(checkpoint['state_dict'])
            self.machine.eval()
class task_copy():
    def __init__(self):
        self.name = "copy_task"
        self.controller_size = 100
        self.controller_layers = 1
        self.num_read_heads = 2
        self.num_write_heads = 1
        self.sequence_width = 8
        self.sequence_min_len = 5
        self.sequence_max_len = 20  # Default: 20
        self.memory_N = 128
        self.memory_M = 20
        self.num_batches = 10000
        self.batch_size = 1
        self.rmsprop_lr = 1e-4
        self.rmsprop_momentum = 0.9
        self.rmsprop_alpha = 0.95
        self.machine = None
        self.loss = None
        self.optimizer = None

    def get_task_name(self):
        return self.name

    def _data_maker(self, num_batches, batch_size, seq_width, min_len,
                    max_len):  # Generates data for copy task
        # The input data vector will be of size (num_data_rows x batch_size x num_data_columns)
        #
        # 1 1 1 0 1  | 1 1 0 1 0 | 1 1 1 0 1 | 1 0 1 1 1
        # 0 0 1 0 1  | 0 1 0 1 1 | 0 1 0 0 1 | 0 0 1 1 0
        # 0 1 1 0 1  | 1 1 0 0 0 | 1 0 1 0 1 | 0 0 1 1 0
        #
        # Above is the example of data. num_data_rows = 3, num_data_columns = 5, batch_size = 4
        #
        # At each time step we give one row strip to the DNC for prediction, as shown below. Therefore the input size for one interaction will be (batch_size x num_data_columns)
        #
        # 1 1 1 0 1  | 1 1 0 1 0 | 1 1 1 0 1 | 1 0 1 1 1

        for batch_num in range(num_batches + 1):
            # All sequences within a batch share the same length; the length varies across batches
            seq_len = random.randint(min_len, max_len)
            seq = np.random.binomial(
                1, 0.5, (seq_len, batch_size, seq_width)
            )  # Here, seq_len = num_data_rows and seq_width = num_data_columns
            seq = torch.from_numpy(seq)

            # The input includes an additional channel used for the delimiter
            inp = torch.zeros(seq_len + 1, batch_size, seq_width + 1)
            inp[:seq_len, :, :seq_width] = seq
            inp[seq_len, :,
                seq_width] = 1.0  # delimiter in our control channel
            outp = seq.clone()

            yield batch_num + 1, inp.float(), outp.float()

    def init_dnc(self):
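        # The input size is sequence_width + 1 because of the extra delimiter channel added in _data_maker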
        self.machine = DNC_Module(self.sequence_width + 1, self.sequence_width,
                                  self.controller_size, self.controller_layers,
                                  self.num_read_heads, self.num_write_heads,
                                  self.memory_N, self.memory_M)

    def init_loss(self):
        self.loss = nn.BCEWithLogitsLoss(
        )  # Binary Cross Entropy Loss -> Sigmoid Activation + Cross Entropy Loss

    def init_optimizer(self):
        self.optimizer = optim.RMSprop(self.machine.parameters(),
                                       momentum=self.rmsprop_momentum,
                                       alpha=self.rmsprop_alpha,
                                       lr=self.rmsprop_lr)

    def calc_loss(self, Y_pred, Y):
        return self.loss(Y_pred, Y)

    def get_sample_data(self):  # Sample data for Testing
        batch_size = 1
        seq_len = random.randint(self.sequence_min_len, self.sequence_max_len)
        seq = np.random.binomial(
            1, 0.5, (seq_len, batch_size, self.sequence_width)
        )  # Here, seq_len = num_data_rows and seq_width = num_data_columns
        seq = torch.from_numpy(seq)

        # The input includes an additional channel used for the delimiter
        inp = torch.zeros(seq_len + 1, batch_size, self.sequence_width + 1)
        inp[:seq_len, :, :self.sequence_width] = seq
        inp[seq_len, :,
            self.sequence_width] = 1.0  # delimiter in our control channel
        outp = seq.clone()
        return inp.float(), outp.float()

    def calc_cost(self, Y_out, Y, batch_size):
        y_out_binarized = torch.sigmoid(Y_out).clone().data
        y_out_binarized.apply_(lambda x: 0 if x < 0.5 else 1)

        cost = torch.sum(torch.abs(y_out_binarized - Y.data))
        return cost.item() / batch_size

    def get_training_data(self):
        return self._data_maker(self.num_batches, self.batch_size,
                                self.sequence_width, self.sequence_min_len,
                                self.sequence_max_len)

    def test_model(self):
        self.machine.initialization(self.batch_size)  # Initializing states
        X, Y = self.get_sample_data()
        Y_out = torch.zeros(Y.shape)

        # The DNC is fed the whole input sequence first; the output is then
        # predicted by stepping the machine with no input, relying on its
        # previous read vectors and hidden state

        for i in range(X.shape[0]):
            self.machine(X[i])

        for i in range(Y.shape[0]):
            Y_out[i, :, :], _ = self.machine()

        loss = self.calc_loss(Y_out, Y)
        cost = self.calc_cost(
            Y_out, Y, self.batch_size
        )  # The cost is the number of error bits per sequence

        print("\n\nTest Data - Loss: " + str(loss.item()) + ", Cost: " +
              str(cost))

        X = X.squeeze(1)
        Y = Y.squeeze(1)
        Y_out = torch.sigmoid(Y_out.squeeze(1))

        print("\n------Input---------\n")
        print(X.data)
        print("\n------Labels---------\n")
        print(Y.data)
        print("\n------Output---------")
        print((Y_out.data).apply_(lambda x: 0 if x < 0.5 else 1))
        print("\n")

        return loss.item(), cost, X, Y, Y_out

    def clip_grads(self):  # Clipping gradients for stability
        """Gradient clipping to the range [10, 10]."""
        parameters = list(
            filter(lambda p: p.grad is not None, self.machine.parameters()))
        for p in parameters:
            p.grad.data.clamp_(-10, 10)

    def train_model(self):
        # The model is optimized using BCE Loss, but evaluated using the number of error bits between predicted and actual labels (cost)
        loss_list = []
        cost_list = []
        seq_length = []

        if (self.num_batches // 10) > 0:
            model_save_interval = (self.num_batches // 10)
        else:
            model_save_interval = 1

        for batch_num, X, Y in self.get_training_data():

            if batch_num > self.num_batches:
                break

            self.optimizer.zero_grad(
            )  # Making old gradients zero before calculating the fresh ones
            self.machine.initialization(self.batch_size)  # Initializing states
            Y_out = torch.zeros(Y.shape)

            # The DNC is fed the whole input sequence first; the output is then
            # predicted by stepping the machine with no input, relying on its
            # previous read vectors and hidden state, and trained to match the labels

            for i in range(X.shape[0]):
                self.machine(X[i])

            for i in range(Y.shape[0]):
                Y_out[i, :, :], _ = self.machine()

            loss = self.calc_loss(Y_out, Y)
            loss.backward()
            self.clip_grads()
            self.optimizer.step()

            # The cost is the number of error bits per sequence
            cost = self.calc_cost(Y_out, Y, self.batch_size)

            loss_list += [loss.item()]
            cost_list += [cost]
            seq_length += [Y.shape[0]]

            if batch_num % model_save_interval == 0:
                self.save_model(batch_num)

            print("Batch: " + str(batch_num) + "/" + str(self.num_batches) +
                  ", Loss: " + str(loss.item()) + ", Cost: " + str(cost) +
                  ", Sequence Length: " + str(Y.shape[0]))

    def save_model(self, curr_epoch):
        if not os.path.exists("Saved_Models/" + self.name):
            os.mkdir("Saved_Models/" + self.name)
        state_dic = {
            'task_name': self.name,
            'start_epoch': curr_epoch + 1,
            'state_dict': self.machine.state_dict(),
            'optimizer_dic': self.optimizer.state_dict()
        }
        filename = "Saved_Models/" + self.name + "/" + self.name + "_" + str(
            curr_epoch) + "_saved_model.pth.tar"
        torch.save(state_dic, filename)

    def load_model(self, option, epoch):
        path = "Saved_Models/" + self.name + "/" + self.name + "_" + str(
            epoch) + "_saved_model.pth.tar"
        if option == 1:  # Loading for training
            checkpoint = torch.load(path)
            self.machine.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_dic'])
        else:  # Loading for testing
            checkpoint = torch.load(path)
            self.machine.load_state_dict(checkpoint['state_dict'])
            self.machine.eval()