class Trainer:
    """A Recurrent Attention Model trainer.

    All hyperparameters are provided by the user in the
    config file.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args:
            config: object containing command line arguments.
            data_loader: A data iterator.
        """
        self.config = config

        if config.use_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std  # std dev of the Gaussian location policy
        self.M = config.M  # Monte Carlo samples per image at evaluation time

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 25  # 25-class setup (e.g. Malimg); the stock value was 10
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.0
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = "ram_{}_{}x{}_{}".format(
            config.num_glimpses,
            config.patch_size,
            config.patch_size,
            config.glimpse_scale,
        )

        self.plot_dir = "./plots/" + self.model_name + "/"
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print("[*] Saving tensorboard logs to {}".format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        self.model.to(self.device)

        # initialize optimizer and scheduler
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.config.init_lr)
        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           "min",
                                           patience=self.lr_patience)
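
    # ------------------------------------------------------------------
    # Custom dataset loaders. Each one returns (x, y): x as a float array
    # shaped (N, 1, H, W) that replaces the batches coming from the data
    # loader, and y as the class labels.
    # ------------------------------------------------------------------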

    def gmdataset(self):
        import pandas as pd
        #data = pd.read_csv("check.csv", index_col ="name")
        gW = pd.read_csv("goodware.csv", index_col="name")
        mW = pd.read_csv("malware.csv", index_col="name")

        # concatenate the malware and goodware samples into one frame
        out = pd.concat([mW, gW])
        data = out  # `data` and `out` refer to the same DataFrame
        # drop a fixed set of opcode columns
        out.drop([
            '(BAD)', 'STD', 'SHLD', 'SETLE', 'SETB', 'SBB', 'RDTSC', 'PUSHF',
            'FSTCW', 'FDIVP', 'FILD', 'RETN', 'LEA', 'IMUL'
        ], axis=1, inplace=True)

        from sklearn.model_selection import train_test_split
        #print(data['labels'])
        M = data.values
        X = M[:, :-1]
        Y = M[:, -1]
        #print(Y)
        #X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=42)
        #print("HHHHHHH: ",X_train.shape, y_train.shape)
        import numpy as np

        # print(X_train.shape)
        #x_train = np.reshape(X_train, (X_train.shape[0],5, 5,1))
        #padie=np.pad(X_train,  ((0,0),(0,759)), 'constant', constant_values=0)
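        # pad each sample with one zero column so the remaining opcode-count
        # features fill a 1x4x4 grid for the glimpse network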
        padie = np.pad(X, ((0, 0), (0, 1)), 'constant', constant_values=0)
        print(padie.shape)
        x = np.reshape(padie, (padie.shape[0], 1, 4, 4))
        return x, Y

    def alldatacsv(self):
        import numpy as np  # linear algebra
        import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
        import re

        from keras.preprocessing.text import Tokenizer
        from keras.preprocessing.sequence import pad_sequences

        # keep only the necessary columns of the text/label CSV
        data = pd.read_csv("AllData.csv")

        data = data.drop(['Unnamed: 0'], axis=1)

        data = data.rename(columns={'Text': 'text', 'Label': 'sentiment'})

        data = data[['text', 'sentiment']]

        #data = data[data.sentiment != "Neutral"]
        data['text'] = data['text'].apply(lambda x: x.lower())
        data['text'] = data['text'].apply(
            lambda x: re.sub(r'[^a-zA-Z0-9\s]', '', x))

        #print(data[ data['sentiment'] == 1].size)
        #print(data[ data['sentiment'] == 0].size)

        # remove the 'rt' token; assigning back to the column is more reliable
        # than mutating row views inside iterrows()
        data['text'] = data['text'].str.replace('rt', ' ')

        max_features = 2000
        tokenizer = Tokenizer(num_words=max_features, split=' ')
        tokenizer.fit_on_texts(data['text'].values)
        X = tokenizer.texts_to_sequences(data['text'].values)
        X = pad_sequences(X)

        """# Train and Test Dataset Declaration"""

        Y = pd.get_dummies(data['sentiment']).values
        #X_train, X_test, Y_train, Y_test = train_test_split(X,Y, test_size = 0.33, random_state = 78)

        #print(X_train.shape)
        #x_train = np.reshape(X_train, (X_train.shape[0],5, 5,1))
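        # pad each tokenized sequence, keep the last 224*224 = 50176 values,
        # and reshape into a single-channel 224x224 "image"; the slice below
        # assumes the padded width comes out to 459*459 columns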
        padie = np.pad(X, ((0, 0), (0, 251)), 'constant', constant_values=0)
        padie = padie[:, 459 * 459 - 50176:]
        x = np.reshape(padie, (padie.shape[0], 1, 224, 224))
        #x_test = np.reshape(x_test, (x_test.shape[0],2,2, 1))
        #print(Y)
        # collapse the one-hot sentiment labels back to integer class indices
        Y = np.argmax(Y, axis=1)
        return x, Y

    def batadal(self):
        import pandas as pd

        import numpy as np

        fD = pd.read_csv("newDatasets/BATADAL_dataset04.csv", header=None)

        #fD = pd.read_csv("/content/drive/My Drive/newDatasets/BATADAL_dataset02 (1).csv" , header=None)

        test = pd.read_csv("newDatasets/BATADAL_test_dataset.csv", header=None)

        # drop the first column and the original header row (the CSVs were
        # read with header=None)
        test = test.drop(columns=0)
        test = test.drop([0], axis=0)

        Data = fD.drop(columns=0)
        Data = Data.drop([0], axis=0)

        nData = Data.values
        testData = test.values

        xData = nData[:, :43]
        yData = nData[:, 43]

        # pad the 43 feature columns to 49 so each sample reshapes into a
        # 1-channel 7x7 grid
        xData = np.pad(xData, ((0, 0), (0, 6)), 'constant', constant_values=0)
        testData = np.pad(testData, ((0, 0), (0, 6)),
                          'constant',
                          constant_values=0)

        xData = xData.reshape(-1, 1, 7, 7)
        testData = testData.reshape(-1, 1, 7, 7)

        from sklearn.model_selection import train_test_split
        X_train, X_test, y_train, y_test = train_test_split(xData,
                                                            yData,
                                                            test_size=0.1,
                                                            random_state=42)
        #print(X_train, y_train)
        #print("JJ: ", type(xData))
        return xData, yData

    def Malimg(self):
        import numpy as np
        from sklearn.preprocessing import StandardScaler
        from sklearn.model_selection import train_test_split
        dataset = np.load('malimg.npz', allow_pickle=True)
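        # dataset['arr'] is an object array: column 0 holds the image arrays,
        # column 1 the integer class labels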

        BATCH_SIZE = 256
        CELL_SIZE = 256
        DROPOUT_RATE = 0.85
        LEARNING_RATE = 1e-3
        NODE_SIZE = [512, 256, 128]
        NUM_LAYERS = 5

        features = dataset['arr'][:, 0]
        features = np.array([feature for feature in features])
        features = np.reshape(
            features,
            (features.shape[0], features.shape[1] * features.shape[2]))
        r, c = features.shape

        print("Number of Samples", r)
        print("Number of Features", c)

        # standardize the flattened pixel features
        features = StandardScaler().fit_transform(features)

        labels = dataset['arr'][:, 1]
        labels = np.array([label for label in labels])

        # one-hot encode the integer labels
        one_hot = np.zeros((labels.shape[0], labels.max() + 1))
        one_hot[np.arange(labels.shape[0]), labels] = 1
        labels = one_hot
        num_features = features.shape[1]
        num_classes = labels.shape[1]

        Y = labels
        X = features

        print("Shape of Labels", Y.shape)
        print("Shape of Features", X.shape)
        train_features, test_features, train_labels, test_labels = train_test_split(
            features, labels, test_size=0.1, stratify=labels)  #10% Test size

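        # trim every split to an exact multiple of BATCH_SIZE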
        train_size = int(train_features.shape[0])
        train_features = train_features[:train_size -
                                        (train_size % BATCH_SIZE)]
        train_labels = train_labels[:train_size - (train_size % BATCH_SIZE)]

        test_size = int(test_features.shape[0])
        test_features = test_features[:test_size - (test_size % BATCH_SIZE)]
        test_labels = test_labels[:test_size - (test_size % BATCH_SIZE)]

        fsize = int(features.shape[0])
        features = features[:fsize - (fsize % BATCH_SIZE)]
        labels = labels[:fsize - (fsize % BATCH_SIZE)]

        r, c = train_features.shape
        print("Number of Training Samples", r)
        print("Number of Training Features", c)

        r, c = test_features.shape
        print("Number of Test Samples", r)
        print("Number of Test Features", c)
        #print(train_labels.shape)

        #print(tf.reshape(test_features[1], [32,32]))

        print(train_features.shape, test_features.shape, train_labels.shape,
              test_labels.shape)

        #print(train_labels)
        # reshape the (trimmed) feature matrix into 1-channel 32x32 images and
        # collapse the one-hot labels back to integer class indices
        feat = features.reshape(-1, 1, 32, 32)
        y_non_category = np.asarray([np.argmax(t) for t in labels])
        print("LABELS", y_non_category)
        return feat, y_non_category

    def SWAT(self):
        import pandas as pd
        import numpy as np

        fD = pd.read_excel("newDatasets/SWaT/SWaT_Dataset_Attack_v0.xlsx",
                           header=None)
        Data = fD.drop(columns=0)
        Data = Data.drop([0, 1], axis=0)
        xData = Data.values[:, :51]
        yData = Data.values[:, 51]
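        # binarize the labels: 'Normal' -> 0, any attack label -> 1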
        count = 0
        for i in yData:
            if i == 'Normal':
                yData[count] = 0
            else:
                yData[count] = 1
            count = count + 1
        # pad the 51 sensor/actuator features to 64 and reshape each sample
        # into a 1-channel 8x8 grid (channels-first, matching the other
        # loaders); only the first 200 samples are returned
        xData = np.pad(xData, ((0, 0), (0, 13)), 'constant', constant_values=0)
        xData = xData.reshape(xData.shape[0], 1, 8, 8)
        return xData[:200], yData[:200]

    def reset(self):
        h_t = torch.zeros(
            self.batch_size,
            self.hidden_size,
            dtype=torch.float,
            device=self.device,
            requires_grad=True,
        )
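        # initial glimpse location, drawn uniformly from the normalized
        # coordinate range [-1, 1] used by the glimpse sensor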
        l_t = torch.FloatTensor(self.batch_size, 2).uniform_(-1,
                                                             1).to(self.device)
        l_t.requires_grad = True

        return h_t, l_t

    def train(self):
        """Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):

            print("\nEpoch: {}/{} - LR: {:.6f}".format(
                epoch + 1, self.epochs, self.optimizer.param_groups[0]["lr"]))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # reduce lr if validation accuracy plateaus
            # (the scheduler minimizes, so it is passed the negated accuracy)
            self.scheduler.step(-valid_acc)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f} - val err: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(
                msg.format(train_loss, train_acc, valid_loss, valid_acc,
                           100 - valid_acc))

            # check for improvement
            if not is_best:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "model_state": self.model.state_dict(),
                    "optim_state": self.optimizer.state_dict(),
                    "best_valid_acc": self.best_valid_acc,
                },
                is_best,
            )

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.
        """

        import pandas as pd
        import numpy as np

        self.model.train()
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                x, y = x.to(self.device), y.to(self.device)
                # override the loader batch with the custom dataset
                # (alternatives: self.gmdataset(), self.batadal(),
                # self.alldatacsv()); note the file is re-read on every batch
                x1, y1 = self.SWAT()
                x1 = x1.astype(np.float32)
                y1 = y1.astype(np.float32)
                x = torch.from_numpy(x1).float().to(self.device)
                y = torch.from_numpy(y1).long().to(self.device)
                #print("Here", y)
                plot = False
                if (epoch % self.plot_freq == 0) and (i == 0):
                    plot = True

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # save images
                imgs = []
                imgs.append(x[0:9])

                # extract the glimpses
                locs = []
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    # store
                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration
                h_t, l_t, b_t, log_probas, p = self.model(x,
                                                          l_t,
                                                          h_t,
                                                          last=True)
                log_pi.append(p)
                baselines.append(b_t)
                locs.append(l_t[0:9])

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # calculate reward
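                # R = 1 if the final classification is correct and 0 otherwise,
                # replicated across glimpses so every timestep receives the
                # same terminal reward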
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss
                # summed over timesteps and averaged across batch
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce * 0.01

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))
                #print("Predicted: ", predicted, "\nTrue", y)

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # backpropagate and take an optimizer (Adam) step
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.item(), acc.item())))
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                    locs = [l.cpu().data.numpy() for l in locs]
                    pickle.dump(
                        imgs,
                        open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb"))
                    pickle.dump(
                        locs,
                        open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb"))

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value("train_loss", losses.avg, iteration)
                    log_value("train_acc", accs.avg, iteration)

            return losses.avg, accs.avg

    @torch.no_grad()
    def validate(self, epoch):
        """Evaluate the RAM model on the validation set.
        """
        import torch
        import numpy as np
        losses = AverageMeter()
        accs = AverageMeter()

        for i, (x, y) in enumerate(self.valid_loader):
            x, y = x.to(self.device), y.to(self.device)
            # override the loader batch with the custom dataset
            # (alternatives: self.batadal(), self.alldatacsv(), self.SWAT())
            x1, y1 = self.gmdataset()
            x1 = x1.astype(np.float32)
            y1 = y1.astype(np.float32)
            x = torch.from_numpy(x1).float().to(self.device)
            y = torch.from_numpy(y1).long().to(self.device)
            # duplicate M times
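            # (M Monte Carlo glimpse trajectories per sample; their
            # predictions and baselines are averaged below)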
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            log_pi = []
            baselines = []
            for t in range(self.num_glimpses - 1):
                # forward pass through model
                h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                # store
                baselines.append(b_t)
                log_pi.append(p)

            # last iteration
            h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)
            log_pi.append(p)
            baselines.append(b_t)

            # convert list to tensors and reshape
            baselines = torch.stack(baselines).transpose(1, 0)
            log_pi = torch.stack(log_pi).transpose(1, 0)

            # average
            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            baselines = baselines.contiguous().view(self.M, -1,
                                                    baselines.shape[-1])
            baselines = torch.mean(baselines, dim=0)

            log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1])
            log_pi = torch.mean(log_pi, dim=0)

            # calculate reward
            predicted = torch.max(log_probas, 1)[1]
            R = (predicted.detach() == y).float()
            R = R.unsqueeze(1).repeat(1, self.num_glimpses)

            # compute losses for differentiable modules
            loss_action = F.nll_loss(log_probas, y)
            loss_baseline = F.mse_loss(baselines, R)

            # compute reinforce loss
            adjusted_reward = R - baselines.detach()
            loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
            loss_reinforce = torch.mean(loss_reinforce, dim=0)

            # sum up into a hybrid loss
            loss = loss_action + loss_baseline + loss_reinforce * 0.01

            # compute accuracy
            correct = (predicted == y).float()
            acc = 100 * (correct.sum() / len(y))
            # per-batch confusion-matrix counts
            count = 0
            countFP = 0
            countFN = 0
            countTN = 0
            countTP = 0
            for j in range(len(correct)):
                if correct[j] == 0:
                    count = count + 1
                if predicted[j] == 1 and y[j] == 0:  # false positive
                    countFP = countFP + 1
                if predicted[j] == 0 and y[j] == 1:  # false negative
                    countFN = countFN + 1
                if predicted[j] == 0 and y[j] == 0:  # true negative
                    countTN = countTN + 1
                if predicted[j] == 1 and y[j] == 1:  # true positive
                    countTP = countTP + 1

            print("Total: ", len(correct), "Wrong: ", count)
            print("TP: ", countTP, "TN: ", countTN, "FN: ", countFN, "FP: ",
                  countFP)

            # store
            losses.update(loss.item(), x.size()[0])
            accs.update(acc.item(), x.size()[0])

            # log to tensorboard
            if self.use_tensorboard:
                iteration = epoch * len(self.valid_loader) + i
                log_value("valid_loss", losses.avg, iteration)
                log_value("valid_acc", accs.avg, iteration)

        return losses.avg, accs.avg

    @torch.no_grad()
    def test(self):
        """Test the RAM model.

        This function should only be called at the very
        end once the model has finished training.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        for i, (x, y) in enumerate(self.test_loader):
            x, y = x.to(self.device), y.to(self.device)

            # duplicate M times
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            for t in range(self.num_glimpses - 1):
                # forward pass through model
                h_t, l_t, b_t, p = self.model(x, l_t, h_t)

            # last iteration
            h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)

            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            pred = log_probas.data.max(1, keepdim=True)[1]
            correct += pred.eq(y.data.view_as(pred)).cpu().sum()

        perc = (100.0 * correct) / (self.num_test)
        error = 100 - perc
        print("[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)".format(
            correct, self.num_test, perc, error))

    def save_checkpoint(self, state, is_best):
        """Saves a checkpoint of the model.

        If this model has reached the best validation accuracy thus
        far, a separate file with the suffix `best` is created.
        """
        filename = self.model_name + "_ckpt.pth.tar"
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)
        if is_best:
            filename = self.model_name + "_model_best.pth.tar"
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """Load the best copy of a model.

        This is useful for 2 cases:
        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Args:
            best: if set to True, loads the best model.
                Use this if you want to evaluate your model
                on the test data. Else, set to False in which
                case the most recent version of the checkpoint
                is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + "_ckpt.pth.tar"
        if best:
            filename = self.model_name + "_model_best.pth.tar"
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        ckpt = torch.load(ckpt_path)

        # load variables from checkpoint
        self.start_epoch = ckpt["epoch"]
        self.best_valid_acc = ckpt["best_valid_acc"]
        self.model.load_state_dict(ckpt["model_state"])
        self.optimizer.load_state_dict(ckpt["optim_state"])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt["epoch"], ckpt["best_valid_acc"]))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt["epoch"]))
Example #2
class Trainer:
    """A Recurrent Attention Model trainer.

    All hyperparameters are provided by the user in the
    config file.
    """

    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args:
            config: object containing command line arguments.
            data_loader: A data iterator.
        """
        self.config = config

        if config.use_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 10
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.0
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = "ram_{}_{}x{}_{}".format(
            config.num_glimpses,
            config.patch_size,
            config.patch_size,
            config.glimpse_scale,
        )

        self.plot_dir = "./plots/" + self.model_name + "/"
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print("[*] Saving tensorboard logs to {}".format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        self.model.to(self.device)

        # initialize optimizer and scheduler
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.config.init_lr
        )
        self.scheduler = ReduceLROnPlateau(
            self.optimizer, "min", patience=self.lr_patience
        )

    def reset(self):
        h_t = torch.zeros(
            self.batch_size,
            self.hidden_size,
            dtype=torch.float,
            device=self.device,
            requires_grad=True,
        )
        l_t = torch.FloatTensor(self.batch_size, 2).uniform_(-1, 1).to(self.device)
        l_t.requires_grad = True

        return h_t, l_t

    def train(self):
        """Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print(
            "\n[*] Train on {} samples, validate on {} samples".format(
                self.num_train, self.num_valid
            )
        )

        for epoch in range(self.start_epoch, self.epochs):

            print(
                "\nEpoch: {}/{} - LR: {:.6f}".format(
                    epoch + 1, self.epochs, self.optimizer.param_groups[0]["lr"]
                )
            )

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # reduce lr if validation accuracy plateaus
            # (the scheduler minimizes, so it is passed the negated accuracy)
            self.scheduler.step(-valid_acc)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f} - val err: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(
                msg.format(
                    train_loss, train_acc, valid_loss, valid_acc, 100 - valid_acc
                )
            )

            # check for improvement
            if not is_best:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "model_state": self.model.state_dict(),
                    "optim_state": self.optimizer.state_dict(),
                    "best_valid_acc": self.best_valid_acc,
                },
                is_best,
            )

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.
        """
        self.model.train()
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                x, y = x.to(self.device), y.to(self.device)

                plot = False
                if (epoch % self.plot_freq == 0) and (i == 0):
                    plot = True

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # save images
                imgs = []
                imgs.append(x[0:9])

                # extract the glimpses
                locs = []
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    # store
                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration
                h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)
                log_pi.append(p)
                baselines.append(b_t)
                locs.append(l_t[0:9])

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # calculate reward
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss
                # summed over timesteps and averaged across batch
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce * 0.01

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # backpropagate and take an optimizer (Adam) step
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    (
                        "{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                            (toc - tic), loss.item(), acc.item()
                        )
                    )
                )
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                    locs = [l.cpu().data.numpy() for l in locs]
                    pickle.dump(
                        imgs, open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb")
                    )
                    pickle.dump(
                        locs, open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb")
                    )

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value("train_loss", losses.avg, iteration)
                    log_value("train_acc", accs.avg, iteration)

            return losses.avg, accs.avg

    @torch.no_grad()
    def validate(self, epoch):
        """Evaluate the RAM model on the validation set.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        for i, (x, y) in enumerate(self.valid_loader):
            x, y = x.to(self.device), y.to(self.device)

            # duplicate M times
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            log_pi = []
            baselines = []
            for t in range(self.num_glimpses - 1):
                # forward pass through model
                h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                # store
                baselines.append(b_t)
                log_pi.append(p)

            # last iteration
            h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)
            log_pi.append(p)
            baselines.append(b_t)

            # convert list to tensors and reshape
            baselines = torch.stack(baselines).transpose(1, 0)
            log_pi = torch.stack(log_pi).transpose(1, 0)

            # average
            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            baselines = baselines.contiguous().view(self.M, -1, baselines.shape[-1])
            baselines = torch.mean(baselines, dim=0)

            log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1])
            log_pi = torch.mean(log_pi, dim=0)

            # calculate reward
            predicted = torch.max(log_probas, 1)[1]
            R = (predicted.detach() == y).float()
            R = R.unsqueeze(1).repeat(1, self.num_glimpses)

            # compute losses for differentiable modules
            loss_action = F.nll_loss(log_probas, y)
            loss_baseline = F.mse_loss(baselines, R)

            # compute reinforce loss
            adjusted_reward = R - baselines.detach()
            loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
            loss_reinforce = torch.mean(loss_reinforce, dim=0)

            # sum up into a hybrid loss
            loss = loss_action + loss_baseline + loss_reinforce * 0.01

            # compute accuracy
            correct = (predicted == y).float()
            acc = 100 * (correct.sum() / len(y))

            # store
            losses.update(loss.item(), x.size()[0])
            accs.update(acc.item(), x.size()[0])

            # log to tensorboard
            if self.use_tensorboard:
                iteration = epoch * len(self.valid_loader) + i
                log_value("valid_loss", losses.avg, iteration)
                log_value("valid_acc", accs.avg, iteration)

        return losses.avg, accs.avg

    @torch.no_grad()
    def test(self):
        """Test the RAM model.

        This function should only be called at the very
        end once the model has finished training.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        for i, (x, y) in enumerate(self.test_loader):
            x, y = x.to(self.device), y.to(self.device)

            # duplicate M times
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            for t in range(self.num_glimpses - 1):
                # forward pass through model
                h_t, l_t, b_t, p = self.model(x, l_t, h_t)

            # last iteration
            h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)

            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            pred = log_probas.data.max(1, keepdim=True)[1]
            correct += pred.eq(y.data.view_as(pred)).cpu().sum()

        perc = (100.0 * correct) / (self.num_test)
        error = 100 - perc
        print(
            "[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)".format(
                correct, self.num_test, perc, error
            )
        )

    def save_checkpoint(self, state, is_best):
        """Saves a checkpoint of the model.

        If this model has reached the best validation accuracy thus
        far, a separate file with the suffix `best` is created.
        """
        filename = self.model_name + "_ckpt.pth.tar"
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)
        if is_best:
            filename = self.model_name + "_model_best.pth.tar"
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """Load the best copy of a model.

        This is useful for 2 cases:
        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Args:
            best: if set to True, loads the best model.
                Use this if you want to evaluate your model
                on the test data. Else, set to False in which
                case the most recent version of the checkpoint
                is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + "_ckpt.pth.tar"
        if best:
            filename = self.model_name + "_model_best.pth.tar"
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        ckpt = torch.load(ckpt_path)

        # load variables from checkpoint
        self.start_epoch = ckpt["epoch"]
        self.best_valid_acc = ckpt["best_valid_acc"]
        self.model.load_state_dict(ckpt["model_state"])
        self.optimizer.load_state_dict(ckpt["optim_state"])

        if best:
            print(
                "[*] Loaded {} checkpoint @ epoch {} "
                "with best valid acc of {:.3f}".format(
                    filename, ckpt["epoch"], ckpt["best_valid_acc"]
                )
            )
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt["epoch"]))