# Example #1
# 0
def main():
    """Smoke-test ``RecurrentAttention`` on two sample images.

    Loads ``lenna.jpg`` and ``cat.jpg`` from ``data_dir`` resized to 512x512,
    concatenates them into one batch, runs a single model step from a random
    initial location and a zero hidden state, and prints the shapes of the
    returned hidden state and location.

    NOTE(review): the constructor arguments and the two-value return used here
    do not match how ``Trainer`` builds ``RecurrentAttention`` (nine positional
    args ending in ``num_classes``) or calls it (four / five return values);
    this appears to target an older model API -- confirm before relying on it.
    """
    # load images
    paths = [data_dir + './lenna.jpg', data_dir + './cat.jpg']
    imgs = [
        torch.from_numpy(img2array(path, desired_size=[512, 512], expand=True))
        for path in paths
    ]
    # ``Variable`` has been a no-op wrapper since PyTorch 0.4, so plain
    # tensors are passed straight to the model.
    imgs = torch.cat(imgs)

    # NOTE(review): unpacking assumes 4-D channels-last arrays from img2array.
    B, H, W, C = imgs.shape

    # Random initial glimpse location in [-1, 1] and an all-zero hidden state.
    l_t_prev = torch.Tensor(B, 2).uniform_(-1, 1)
    h_t_prev = torch.zeros(B, 256)

    ram = RecurrentAttention(64, 3, 2, 3, 128, 128, 256, 10, 0.11)
    h_t, l_t = ram(imgs, l_t_prev, h_t_prev)

    print("h_t: {}".format(h_t.shape))
    print("l_t: {}".format(l_t.shape))
class Trainer:
    """A Recurrent Attention Model trainer.

    All hyperparameters are provided by the user in the
    config file.

    NOTE(review): ``train_one_epoch`` and ``validate`` ignore the batches
    yielded by their data loaders and instead use a hard-coded dataset
    (``SWAT`` for training, ``gmdataset`` for validation). The loaders are
    still iterated, so they only determine the number of steps per epoch.
    """

    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args:
            config: object containing command line arguments.
            data_loader: when ``config.is_train`` is set, a pair
                ``(train_loader, valid_loader)`` whose samplers expose
                ``indices``; otherwise the test loader.
        """
        self.config = config

        if config.use_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M  # number of copies averaged at validation/test time

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 25  # hard-coded; an old inline comment said 10
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.0
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = "ram_{}_{}x{}_{}".format(
            config.num_glimpses,
            config.patch_size,
            config.patch_size,
            config.glimpse_scale,
        )

        # Tensors produced by the hard-coded dataset loaders, keyed by name,
        # so files are read once rather than on every step.
        self._dataset_cache = {}

        self.plot_dir = "./plots/" + self.model_name + "/"
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print("[*] Saving tensorboard logs to {}".format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        self.model.to(self.device)

        # initialize optimizer and scheduler
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.config.init_lr)
        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           "min",
                                           patience=self.lr_patience)

    def gmdataset(self):
        """Load the goodware/malware opcode tables as 4x4 single-channel samples.

        Reads ``malware.csv`` and ``goodware.csv`` (both indexed by ``name``),
        stacks the malware rows before the goodware rows, drops a fixed set of
        opcode columns and treats the last remaining column as the label.

        Returns:
            (x, Y): ``x`` has shape ``(N, 1, 4, 4)``. The features are
            right-padded with one zero column before reshaping, so exactly 15
            feature columns must remain. ``Y`` is the label column.
        """
        import numpy as np
        import pandas as pd

        gW = pd.read_csv("goodware.csv", index_col="name")
        mW = pd.read_csv("malware.csv", index_col="name")

        # DataFrame.append was removed in pandas 2.0; concat is its replacement.
        data = pd.concat([mW, gW])
        data = data.drop(columns=[
            '(BAD)', 'STD', 'SHLD', 'SETLE', 'SETB', 'SBB', 'RDTSC', 'PUSHF',
            'FSTCW', 'FDIVP', 'FILD', 'RETN', 'LEA', 'IMUL',
        ])

        M = data.values
        X = M[:, :-1]
        Y = M[:, -1]

        padie = np.pad(X, ((0, 0), (0, 1)), 'constant', constant_values=0)
        print(padie.shape)
        x = np.reshape(padie, (padie.shape[0], 1, 4, 4))
        return x, Y

    def alldatacsv(self):
        """Tokenise the text in ``AllData.csv`` into 224x224 single-channel samples.

        Text is lower-cased, stripped to ASCII letters/digits/whitespace,
        tokenised with a 2000-word Keras ``Tokenizer`` and padded with
        ``pad_sequences``. Each padded sequence is right-padded with 251 zeros
        and its trailing 224*224 values are reshaped into an image.

        Returns:
            (x, Y): ``x`` of shape ``(N, 1, 224, 224)`` and ``Y`` holding each
            row's position in the two-column one-hot ``sentiment`` encoding,
            as floats.
        """
        import re

        import numpy as np
        import pandas as pd
        from keras.preprocessing.sequence import pad_sequences
        from keras.preprocessing.text import Tokenizer

        side = 224
        max_features = 2000

        data = pd.read_csv("AllData.csv")
        data = data.drop(['Unnamed: 0'], axis=1)
        data = data.rename(columns={'Text': 'text', 'Label': 'sentiment'})
        # Keep only the necessary columns (copy to avoid chained assignment).
        data = data[['text', 'sentiment']].copy()

        # Keep ASCII letters, digits and whitespace. The former pattern used
        # the range ``A-z``, which also kept the characters [ \ ] ^ _ and `.
        non_alnum = re.compile(r'[^a-zA-Z0-9\s]')
        data['text'] = data['text'].apply(lambda s: non_alnum.sub('', s.lower()))

        tokenizer = Tokenizer(num_words=max_features, split=' ')
        tokenizer.fit_on_texts(data['text'].values)
        X = pad_sequences(tokenizer.texts_to_sequences(data['text'].values))

        Y = pd.get_dummies(data['sentiment']).values

        # The former slice ``[:, 459 * 459 - 50176:]`` only produced 224*224
        # columns when the padded width was exactly 459*459; taking the
        # trailing 224*224 columns is identical in that case and also works
        # for wider inputs.
        padie = np.pad(X, ((0, 0), (0, 251)), 'constant', constant_values=0)
        padie = padie[:, -(side * side):]
        x = np.reshape(padie, (padie.shape[0], 1, side, side))

        # Convert the one-hot rows back to a 0/1 label per sample.
        sk = pd.DataFrame(data=Y, columns=[0, 1])
        Y = np.rint(sk.idxmax(1).values)
        return x, Y

    def batadal(self):
        """Load ``BATADAL_dataset04.csv`` as 7x7 single-channel samples.

        Returns:
            (xData, yData): ``xData`` of shape ``(N, 1, 7, 7)`` built from the
            first 43 feature columns right-padded with 6 zeros, and ``yData``
            taken from column 43 (presumably the attack flag -- confirm).
        """
        import numpy as np
        import pandas as pd

        fD = pd.read_csv("newDatasets/BATADAL_dataset04.csv", header=None)
        # Drop the first column and row 0 (presumably the CSV header, since
        # the file is read with header=None).
        nData = fD.drop(columns=0).drop([0], axis=0).values

        xData = np.pad(nData[:, :43], ((0, 0), (0, 6)),
                       'constant',
                       constant_values=0)
        xData = xData.reshape(-1, 1, 7, 7)
        yData = nData[:, 43]
        return xData, yData

    def Malimg(self):
        """Load the Malimg dataset as standardised 32x32 single-channel images.

        Reads the ``arr`` entry of ``malimg.npz`` as ``(feature, label)`` rows,
        flattens and standardises the features with ``StandardScaler`` and
        truncates the data to a whole multiple of 256 samples.

        Returns:
            (feat, labels): ``feat`` of shape ``(N, 1, 32, 32)`` and the integer
            class labels (int64) for the same ``N`` samples.
        """
        import numpy as np
        from sklearn.preprocessing import StandardScaler

        BATCH_SIZE = 256

        # SECURITY: allow_pickle=True unpickles arbitrary objects; only load a
        # trusted malimg.npz.
        with np.load('malimg.npz', allow_pickle=True) as dataset:
            arr = dataset['arr']

        features = np.array([feature for feature in arr[:, 0]])
        features = features.reshape(features.shape[0], -1)
        r, c = features.shape

        print("Number of Samples", r)
        print("Number of Features", c)

        features = StandardScaler().fit_transform(features)
        labels = np.array([label for label in arr[:, 1]])

        # Keep a whole number of BATCH_SIZE samples.
        usable = r - (r % BATCH_SIZE)
        feat = features[:usable].reshape(-1, 1, 32, 32)
        labels = np.asarray(labels[:usable], dtype=np.int64)
        print("LABELS", labels)
        return feat, labels

    def SWAT(self):
        """Load the first 200 rows of the SWaT attack dataset as 8x8 samples.

        Label ``'Normal'`` maps to 0; every other label maps to 1.

        Returns:
            (xData, yData): ``xData`` of shape ``(200, 8, 8, 1)`` (51 features
            right-padded with 13 zeros) and the matching 0/1 labels.
        """
        import numpy as np
        import pandas as pd

        limit = 200
        # Only ``limit`` samples are returned, so parse just the two leading
        # rows dropped below plus ``limit`` data rows.
        fD = pd.read_excel("newDatasets/SWaT/SWaT_Dataset_Attack_v0.xlsx",
                           header=None,
                           nrows=limit + 2)
        Data = fD.drop(columns=0).drop([0, 1], axis=0)
        values = Data.values
        xData = values[:, :51]
        yData = np.where(values[:, 51] == 'Normal', 0, 1)

        xData = np.pad(xData, ((0, 0), (0, 13)), 'constant', constant_values=0)
        # NOTE(review): channels-last (N, 8, 8, 1), unlike the other loaders'
        # (N, 1, H, W) -- confirm the layout the model expects.
        xData = xData.reshape(xData.shape[0], 8, 8, 1)
        return xData[:limit], yData[:limit]

    def _override_batch(self, loader):
        """Return ``loader()``'s arrays as float32 x / int64 y tensors on the device.

        ``loader`` is one of the dataset methods (e.g. ``self.SWAT``). Its
        result is converted and moved to ``self.device`` once, then cached by
        the loader's name, so the underlying file is not re-read each step.
        """
        import numpy as np

        cache = getattr(self, "_dataset_cache", None)
        if cache is None:
            cache = self._dataset_cache = {}
        key = loader.__name__
        if key not in cache:
            x1, y1 = loader()
            x = torch.from_numpy(x1.astype(np.float32))
            y = torch.from_numpy(y1.astype(np.float32)).long()
            cache[key] = (x.to(self.device), y.to(self.device))
        return cache[key]

    def reset(self):
        """Return a fresh zero hidden state and a random location in [-1, 1].

        Shapes are ``(self.batch_size, self.hidden_size)`` and
        ``(self.batch_size, 2)``; both require gradients.
        """
        h_t = torch.zeros(
            self.batch_size,
            self.hidden_size,
            dtype=torch.float,
            device=self.device,
            requires_grad=True,
        )
        l_t = torch.FloatTensor(self.batch_size, 2).uniform_(-1,
                                                             1).to(self.device)
        l_t.requires_grad = True

        return h_t, l_t

    def _rollout(self, x):
        """Run every glimpse of the model over ``x`` from a fresh state.

        Sets ``self.batch_size`` from ``x`` and calls ``reset()`` first.

        Returns:
            (log_probas, baselines, log_pi, locs): the final classification
            log-probabilities; the per-glimpse baselines and location
            log-probabilities, stacked and transposed to batch-first; and the
            first 9 locations of each glimpse.
        """
        self.batch_size = x.shape[0]
        h_t, l_t = self.reset()

        locs = []
        log_pi = []
        baselines = []
        for _ in range(self.num_glimpses - 1):
            h_t, l_t, b_t, p = self.model(x, l_t, h_t)
            locs.append(l_t[0:9])
            baselines.append(b_t)
            log_pi.append(p)

        # The last glimpse also yields the classification log-probabilities.
        h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)
        locs.append(l_t[0:9])
        baselines.append(b_t)
        log_pi.append(p)

        baselines = torch.stack(baselines).transpose(1, 0)
        log_pi = torch.stack(log_pi).transpose(1, 0)
        return log_probas, baselines, log_pi, locs

    def _hybrid_loss(self, log_probas, baselines, log_pi, y):
        """Compute the RAM hybrid loss and the predicted classes.

        Returns:
            (loss, predicted): the scalar loss ``nll + baseline_mse +
            0.01 * reinforce`` and the argmax class per sample.
        """
        predicted = torch.max(log_probas, 1)[1]
        # Reward is 1 for a correct final prediction, credited to every glimpse.
        R = (predicted.detach() == y).float()
        R = R.unsqueeze(1).repeat(1, self.num_glimpses)

        # losses for differentiable modules
        loss_action = F.nll_loss(log_probas, y)
        loss_baseline = F.mse_loss(baselines, R)

        # REINFORCE loss: summed over timesteps and averaged across the batch
        adjusted_reward = R - baselines.detach()
        loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
        loss_reinforce = torch.mean(loss_reinforce, dim=0)

        loss = loss_action + loss_baseline + loss_reinforce * 0.01
        return loss, predicted

    @staticmethod
    def _confusion_counts(predicted, y):
        """Count misclassifications and binary confusion-matrix entries.

        Class 1 is the positive class; samples with other labels only count
        towards ``wrong``.

        Returns:
            dict of ints with keys ``wrong``, ``tp``, ``tn``, ``fp``, ``fn``.
        """
        def count(mask):
            return int(mask.sum().item())

        return {
            "wrong": count(predicted != y),
            "tp": count((predicted == 1) & (y == 1)),
            "tn": count((predicted == 0) & (y == 0)),
            "fp": count((predicted == 1) & (y == 0)),
            "fn": count((predicted == 0) & (y == 1)),
        }

    def train(self):
        """Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):

            print("\nEpoch: {}/{} - LR: {:.6f}".format(
                epoch + 1, self.epochs, self.optimizer.param_groups[0]["lr"]))

            train_loss, train_acc = self.train_one_epoch(epoch)
            valid_loss, valid_acc = self.validate(epoch)

            # The scheduler runs in "min" mode, so feed it negated accuracy.
            self.scheduler.step(-valid_acc)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f} - val err: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(
                msg.format(train_loss, train_acc, valid_loss, valid_acc,
                           100 - valid_acc))

            # early stopping after train_patience epochs without improvement
            if not is_best:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "model_state": self.model.state_dict(),
                    "optim_state": self.optimizer.state_dict(),
                    "best_valid_acc": self.best_valid_acc,
                },
                is_best,
            )

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        NOTE(review): every step trains on the cached ``SWAT()`` data rather
        than the loader's batch (alternatives: ``gmdataset``, ``batadal``,
        ``alldatacsv``); the loader only sets the number of steps.
        """
        self.model.train()
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, _batch in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                x, y = self._override_batch(self.SWAT)

                plot = (epoch % self.plot_freq == 0) and (i == 0)

                # save images
                imgs = [x[0:9]]

                log_probas, baselines, log_pi, locs = self._rollout(x)
                loss, predicted = self._hybrid_loss(log_probas, baselines,
                                                    log_pi, y)

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # compute gradients and update the parameters
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.item(), acc.item())))
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                    locs = [l.cpu().data.numpy() for l in locs]
                    with open(self.plot_dir + "g_{}.p".format(epoch + 1),
                              "wb") as f:
                        pickle.dump(imgs, f)
                    with open(self.plot_dir + "l_{}.p".format(epoch + 1),
                              "wb") as f:
                        pickle.dump(locs, f)

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value("train_loss", losses.avg, iteration)
                    log_value("train_acc", accs.avg, iteration)

            return losses.avg, accs.avg

    @torch.no_grad()
    def validate(self, epoch):
        """Evaluate the RAM model on the validation set.

        NOTE(review): every step evaluates the cached ``gmdataset()`` data
        rather than the loader's batch. Each step also prints the number of
        misclassifications and binary TP/TN/FN/FP counts (class 1 positive).
        """
        losses = AverageMeter()
        accs = AverageMeter()

        for i, _batch in enumerate(self.valid_loader):
            x, y = self._override_batch(self.gmdataset)
            # Duplicate the batch M times; outputs are averaged over the copies.
            x = x.repeat(self.M, 1, 1, 1)

            log_probas, baselines, log_pi, _ = self._rollout(x)

            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            baselines = baselines.contiguous().view(self.M, -1,
                                                    baselines.shape[-1])
            baselines = torch.mean(baselines, dim=0)

            log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1])
            log_pi = torch.mean(log_pi, dim=0)

            loss, predicted = self._hybrid_loss(log_probas, baselines, log_pi,
                                                y)

            # compute accuracy
            correct = (predicted == y).float()
            acc = 100 * (correct.sum() / len(y))

            # Vectorised counts; the former per-sample loop reused ``i`` and
            # clobbered the tensorboard iteration below.
            counts = self._confusion_counts(predicted, y)
            print("Total: ", len(correct), "Wrong: ", counts["wrong"])
            print("TP: ", counts["tp"], "TN: ", counts["tn"], "FN: ",
                  counts["fn"], "FP: ", counts["fp"])

            losses.update(loss.item(), x.size()[0])
            accs.update(acc.item(), x.size()[0])

            # log to tensorboard
            if self.use_tensorboard:
                iteration = epoch * len(self.valid_loader) + i
                log_value("valid_loss", losses.avg, iteration)
                log_value("valid_acc", accs.avg, iteration)

        return losses.avg, accs.avg

    @torch.no_grad()
    def test(self):
        """Test the RAM model.

        This function should only be called at the very
        end once the model has finished training.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        for x, y in self.test_loader:
            x, y = x.to(self.device), y.to(self.device)

            # Duplicate the batch M times; predictions are averaged over copies.
            x = x.repeat(self.M, 1, 1, 1)

            log_probas, _, _, _ = self._rollout(x)

            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            pred = log_probas.data.max(1, keepdim=True)[1]
            correct += pred.eq(y.data.view_as(pred)).cpu().sum().item()

        perc = (100.0 * correct) / (self.num_test)
        error = 100 - perc
        print("[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)".format(
            correct, self.num_test, perc, error))

    def save_checkpoint(self, state, is_best):
        """Saves a checkpoint of the model.

        If this model has reached the best validation accuracy thus
        far, a separate file with the suffix `best` is created.
        """
        filename = self.model_name + "_ckpt.pth.tar"
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)
        if is_best:
            filename = self.model_name + "_model_best.pth.tar"
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """Load the best copy of a model.

        This is useful for 2 cases:
        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Args:
            best: if set to True, loads the best model.
                Use this if you want to evaluate your model
                on the test data. Else, set to False in which
                case the most recent version of the checkpoint
                is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + "_ckpt.pth.tar"
        if best:
            filename = self.model_name + "_model_best.pth.tar"
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        ckpt = torch.load(ckpt_path, map_location=self.device)

        # load variables from checkpoint
        self.start_epoch = ckpt["epoch"]
        self.best_valid_acc = ckpt["best_valid_acc"]
        self.model.load_state_dict(ckpt["model_state"])
        self.optimizer.load_state_dict(ckpt["optim_state"])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt["epoch"], ckpt["best_valid_acc"]))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt["epoch"]))
# Example #4
# 0
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator. In training mode it is indexed as
          ``(train_loader, valid_loader)``; otherwise element ``[1]`` is
          used as the test loader.
        """
        self.config = config
        # distance threshold below which a predicted location counts as a
        # hit when building the binary evaluation reward
        self.dis_R_thres = config.dis_R_thres

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M
        # `self.device` is where reward tensors are moved, so it must agree
        # with `use_gpu`, which decides where the model and batches live.
        # Picking CUDA merely because it is available put the rewards on the
        # GPU while model and data stayed on the CPU when use_gpu was False.
        self.device = torch.device(
            "cuda" if config.use_gpu and torch.cuda.is_available() else "cpu")

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            # number of batches * batch size; may overcount when the last
            # batch of a loader is smaller than batch_size
            self.num_train = len(self.train_loader) * config.batch_size
            self.num_valid = len(self.valid_loader) * config.batch_size
        else:
            # NOTE(review): the test loader is taken from data_loader[1];
            # confirm this matches what the caller passes in test mode.
            self.test_loader = data_loader[1]
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 10
        self.num_channels = 3

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses,
                                                   config.patch_size,
                                                   config.patch_size,
                                                   config.glimpse_scale)

        self.plot_dir = './plots/' + self.model_name + '/'
        # exist_ok avoids the check-then-create race of os.path.exists
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum(p.data.nelement() for p in self.model.parameters())))

        # NOTE: Adam uses a fixed lr of 3e-4; config.init_lr (self.lr, which
        # train() prints every epoch) does not affect this optimizer.
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=3e-4,
        )
    def reset(self):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced.
        """
        dtype = (torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor)

        h_t = torch.zeros(self.batch_size, self.hidden_size)
        h_t = Variable(h_t).type(dtype)

        l_t = torch.Tensor(self.batch_size, 2).uniform_(-1, 1)
        l_t = Variable(l_t).type(dtype)

        return h_t, l_t

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):

            print('\nEpoch: {}/{} - LR: {:.6f}'.format(epoch + 1, self.epochs,
                                                       self.lr))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)
            # train_loss, train_acc = self.train_one_epoch(epoch)

            # # reduce lr if validation loss plateaus
            # self.scheduler.step(valid_loss)

            # is_best = valid_acc > self.best_valid_acc
            is_best = 1
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement
            if not is_best:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            is_best = 1
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                }, is_best)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An "epoch" here is capped at ``countTotal`` minibatches rather than a
        full pass over ``self.train_loader``. This is used by train() and
        should not be called manually.

        Every minibatch is unrolled for ``self.num_glimpses`` steps; the last
        step also yields the predicted location ``l_t_final`` and a scale.
        The loss combines a location MSE, a scale MSE, a baseline MSE and a
        REINFORCE term whose reward is ``1 - distance(l_t_final, y)``.

        Returns
        -------
        - (losses.avg, accs.avg): running averages of the hybrid loss and of
          the per-batch mean distance between predicted and target location.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()
        losses_action = AverageMeter()

        # maximum number of minibatches processed per epoch
        countTotal = 1000

        tic = time.time()
        count = 0
        with tqdm(total=self.num_train):
            for i, (x, fixation, y, speeds, courses, scale_gt, indexSeq,
                    frameEnd) in enumerate(self.train_loader):
                # `>=` so exactly countTotal batches run (`>` ran one extra
                # batch and printed e.g. 1001/1000)
                if count >= countTotal:
                    break
                count += 1
                y = y.squeeze().float()

                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                    speeds, courses = speeds.cuda(), courses.cuda()
                    scale_gt = scale_gt.cuda()
                x, y = Variable(x), Variable(y)
                speeds, courses = Variable(speeds), Variable(courses)
                scale_gt = Variable(scale_gt)

                # dump glimpse data only for the first batch of a plot epoch
                plot = (epoch % self.plot_freq == 0) and (i == 0)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # keep the first 9 samples for plotting
                imgs = [x[0:9]]

                # extract the glimpses
                locs = []
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    h_t, l_t, b_t, p = self.model(x, speeds, courses, l_t, h_t,
                                                  t)
                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)

                # last glimpse also returns the final location and the scale
                h_t, l_t, b_t, l_t_final, p, scale = self.model(
                    x,
                    speeds,
                    courses,
                    l_t,
                    h_t,
                    self.num_glimpses - 1,
                    last=True)
                log_pi.append(p)
                baselines.append(b_t)
                locs.append(l_t[0:9])

                # stack per-glimpse outputs with glimpses along dim 1
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # Per-sample Euclidean distance between predicted and target
                # location. It is computed on a *detached* prediction: the
                # REINFORCE reward and the baseline target must be constants,
                # otherwise loss_reinforce / loss_baseline push gradients into
                # l_t_final (which is already supervised by loss_action).
                pred = l_t_final.detach()
                distances = torch.sqrt(
                    torch.pow(pred[:, 0] - y[:, 0], 2) +
                    torch.pow(pred[:, 1] - y[:, 1], 2)).float()
                R = 1 - distances
                mean_R = torch.mean(R)
                R = R.unsqueeze(1).repeat(1, self.num_glimpses).to(self.device)

                # compute losses for differentiable modules
                loss_action = F.mse_loss(l_t_final, y)
                loss_scale = F.mse_loss(scale, scale_gt)
                loss_baseline = F.mse_loss(baselines, R)

                # REINFORCE loss: summed over timesteps, averaged over batch
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # hybrid loss; location and baseline terms are weighted 10x
                loss = (loss_action * 10 + loss_baseline * 10 + loss_reinforce
                        + loss_scale)

                # "accuracy" is the mean location distance over the batch
                # (previously only the last sample's distance was reported)
                dist = distances.mean()

                batch_n = x.size()[0]
                losses.update(loss.data, batch_n)
                accs.update(dist.data, batch_n)
                losses_action.update(loss_action.data, batch_n)

                # compute gradients and update the parameters
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # elapsed time since the start of the epoch
                toc = time.time()
                batch_time.update(toc - tic)

                print(
                    "Epoch: {} - {}/{} - {:.1f}s - loss: {:.3f} - dis: {:.3f} - loss_action : {:.3f} - loss_scale: {:.3f} "
                    "- mean_R : {:.3f} - mean-baseline: {:.3f} - mean_adjusted_reward: {:.3f} "
                    .format(epoch, count, countTotal, (toc - tic), loss.data,
                            dist.data, loss_action.data, loss_scale.data,
                            mean_R, torch.mean(baselines.data),
                            torch.mean(adjusted_reward.data)))

                # dump the glimpses and locs (.cpu() is a no-op on CPU)
                if plot:
                    imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                    locs = [loc.cpu().data.numpy() for loc in locs]
                    with open(self.plot_dir + "g_{}.p".format(epoch + 1),
                              "wb") as fh:
                        pickle.dump(imgs, fh)
                    with open(self.plot_dir + "l_{}.p".format(epoch + 1),
                              "wb") as fh:
                        pickle.dump(locs, fh)

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value('train_loss', losses.avg, iteration)
                    log_value('train_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    def validate(self, epoch):
        """
        Evaluate the model on at most ``countTotal`` batches of the
        validation set.

        Each batch is replicated ``self.M`` times along the batch dimension.
        The final predicted locations (and scales) are averaged over the
        replicas. When blending is enabled, one annotated frame per sample is
        written to ``logs/<epoch:02d>/``.

        Returns
        -------
        - (losses.avg, accs.avg): running averages of the hybrid loss and of
          the mean predicted-to-target location distance (lower is better).
        """
        losses = AverageMeter()
        accs = AverageMeter()
        # maximum number of validation batches evaluated per call
        countTotal = 50

        count = 0
        is_blend = 1
        save_dir = os.path.join('logs', '{:02d}'.format(epoch))
        # makedirs also creates a missing 'logs' parent (os.mkdir raised)
        os.makedirs(save_dir, exist_ok=True)

        # evaluation only: no backward pass, so skip building autograd graphs
        with torch.no_grad():
            for i, (x, fixs, y, speeds, courses, scale_gt, indexSeq,
                    frameEnd) in enumerate(self.valid_loader):
                # `>=` so exactly countTotal batches are evaluated
                if count >= countTotal:
                    break
                count += 1
                y = y.squeeze().float()

                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                    speeds, courses = speeds.cuda(), courses.cuda()
                x, y = Variable(x), Variable(y)
                speeds, courses = Variable(speeds), Variable(courses)

                # replicate the batch M times; outputs are averaged per
                # original sample below.
                # NOTE(review): speeds/courses are not replicated alongside x;
                # confirm the model expects that.
                x = x.repeat(self.M, 1, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    h_t, l_t, b_t, p = self.model(x, speeds, courses, l_t,
                                                  h_t, t)
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration
                h_t, l_t, b_t, l_t_final, p, scale = self.model(
                    x, speeds, courses, l_t, h_t, self.num_glimpses - 1,
                    last=True)
                log_pi.append(p)
                baselines.append(b_t)

                # stack per-glimpse outputs with glimpses along dim 1
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # average the M replicas -> one prediction per sample
                l_t_final = l_t_final.view(self.M, -1, l_t_final.shape[-1])
                l_t_final = torch.mean(l_t_final, dim=0)
                if scale.dim() > 0 and scale.shape[0] == self.batch_size:
                    scale = scale.reshape(self.M, -1,
                                          *scale.shape[1:]).mean(dim=0)

                if is_blend:
                    self._save_blended_frames(save_dir, l_t_final, y, scale,
                                              scale_gt, indexSeq, frameEnd)

                baselines = baselines.contiguous().view(self.M, -1,
                                                        baselines.shape[-1])
                baselines = torch.mean(baselines, dim=0)

                log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1])
                log_pi = torch.mean(log_pi, dim=0)

                # per-sample Euclidean distance between prediction and target
                distances = torch.sqrt(
                    torch.pow(l_t_final[:, 0] - y[:, 0], 2) +
                    torch.pow(l_t_final[:, 1] - y[:, 1], 2))
                # binary reward: 1 when the prediction is within dis_R_thres
                R = (distances < self.dis_R_thres).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses).to(self.device)

                # compute losses for differentiable modules
                loss_action = F.mse_loss(l_t_final, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # sum up into a hybrid loss
                loss = loss_action * 100 + loss_baseline + loss_reinforce

                # "accuracy" is the mean location distance of the batch
                acc = distances.sum() / len(y)

                print('avg dist is {}'.format(acc))

                losses.update(loss.data, x.size()[0])
                accs.update(acc.data, x.size()[0])

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.valid_loader) + i
                    log_value('valid_loss', losses.avg, iteration)
                    log_value('valid_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    @staticmethod
    def _save_blended_frames(save_dir, l_t_final, y, scale, scale_gt,
                             indexSeq, frameEnd):
        """Write one frame per sample with predicted/GT rectangles drawn."""
        # iterate over the averaged predictions (one row per original sample),
        # not over the M-times replicated batch
        for idx in range(l_t_final.shape[0]):
            seq_dir = os.path.join(dreyeve_dir,
                                   '{:02d}'.format(indexSeq[idx]))
            img = read_image(os.path.join(seq_dir, 'frames',
                                          '{:06d}.jpg'.format(frameEnd[idx])),
                             channels_first=False,
                             color=True)
            sal_map = read_image(os.path.join(
                seq_dir, 'saliency_fix', '{:06d}.png'.format(frameEnd[idx])),
                                 channels_first=False,
                                 color=False)
            loc = l_t_final[idx, :].cpu().detach().numpy()
            loc_gt = y[idx].cpu().numpy()
            scale_blend = scale[idx].cpu().detach().numpy()
            scale_gt_blend = scale_gt[idx].cpu().detach().numpy()

            # draw the predicted rectangle
            blend = blend_map_with_focus_rectangle(img,
                                                   sal_map,
                                                   loc,
                                                   scale=scale_blend,
                                                   color=(0, 0, 255))
            # draw the ground-truth rectangle unless the target is NaN
            if not (np.isnan(loc_gt[0]) or np.isnan(loc_gt[1])):
                blend = blend_map_with_focus_rectangle(blend,
                                                       sal_map,
                                                       loc_gt,
                                                       scale=scale_gt_blend,
                                                       color=(0, 255, 0))

            print('scale is {:.3f} and scale_gt is {:.3f}'.format(
                float(scale_blend), float(scale_gt_blend)))
            cv2.imwrite(
                os.path.join(save_dir, '{:06d}.jpg'.format(frameEnd[idx])),
                blend)

    def test(self):
        """
        Test the model on the held-out test data.

        Loads the checkpoint selected by ``self.best`` and replicates each
        batch ``self.M`` times. The final predicted locations (and scales)
        are averaged over the replicas. When output is enabled, one line per
        sample is written to ``output.txt``:
        ``seq frame loc_gt_h loc_gt_w loc_h loc_w scale_gt scale``. The mean
        location distance of every batch is printed.

        This function should only be called at the very end once the model
        has finished training.
        """
        is_output = 1
        out_file = open('output.txt', 'w') if is_output else None
        try:
            # load the best checkpoint
            self.load_checkpoint(best=self.best)

            # inference only: no autograd graphs are needed
            with torch.no_grad():
                for (x, _fixs, y, speeds, courses, scale_gt, indexSeq,
                     frameEnd) in self.test_loader:
                    y = y.squeeze().float()

                    if self.use_gpu:
                        x, y = x.cuda(), y.cuda()
                        speeds, courses = speeds.cuda(), courses.cuda()
                    x, y = Variable(x), Variable(y)
                    speeds, courses = Variable(speeds), Variable(courses)

                    # replicate the batch M times; averaged per sample below
                    x = x.repeat(self.M, 1, 1, 1, 1)

                    # initialize location vector and hidden state
                    self.batch_size = x.shape[0]
                    h_t, l_t = self.reset()

                    # run the glimpses; baselines/log-probs are not needed
                    for t in range(self.num_glimpses - 1):
                        h_t, l_t, _, _ = self.model(x, speeds, courses, l_t,
                                                    h_t, t)
                    _, _, _, l_t_final, _, scale = self.model(
                        x, speeds, courses, l_t, h_t, self.num_glimpses - 1,
                        last=True)

                    # average the M replicas -> one prediction per sample
                    l_t_final = l_t_final.view(self.M, -1,
                                               l_t_final.shape[-1])
                    l_t_final = torch.mean(l_t_final, dim=0)
                    if scale.dim() > 0 and scale.shape[0] == self.batch_size:
                        scale = scale.reshape(self.M, -1,
                                              *scale.shape[1:]).mean(dim=0)

                    if is_output:
                        self._write_test_outputs(out_file, l_t_final, y,
                                                 scale, scale_gt, indexSeq,
                                                 frameEnd)

                    # mean Euclidean distance between prediction and target
                    distances = torch.sqrt(
                        torch.pow(l_t_final[:, 0] - y[:, 0], 2) +
                        torch.pow(l_t_final[:, 1] - y[:, 1], 2))
                    acc = distances.sum() / len(y)

                    print('avg dist is {}'.format(acc))
        finally:
            # the output file used to be left open (and possibly unflushed)
            if out_file is not None:
                out_file.close()

    @staticmethod
    def _write_test_outputs(out_file, l_t_final, y, scale, scale_gt, indexSeq,
                            frameEnd):
        """Print and write one result line per (non-replicated) sample."""
        # iterate over the averaged predictions, not the replicated batch
        for idx in range(l_t_final.shape[0]):
            loc = l_t_final[idx, :].cpu().detach().numpy()
            loc_gt = y[idx].cpu().numpy()
            scale_blend = scale[idx].cpu().detach().numpy()
            scale_gt_blend = scale_gt[idx].cpu().detach().numpy()
            fields = (indexSeq[idx], frameEnd[idx], loc_gt[0], loc_gt[1],
                      loc[0], loc[1], float(scale_gt_blend),
                      float(scale_blend))

            line = '{:02} {:04d} {:.3f} {:.3f} {:.3f} {:.3f} {:.3f} {:.3f}\n'\
                .format(*fields)
            print('seq: {:02}- frame: {:04d} - loc_gt_h: {:.3f} - loc_gt_w: {:.3f} - loc_h: {:.3f} - loc_w: {:.3f} '
                  '- scale_gt: {:.3f}- scale: {:.3f}'.format(*fields))
            out_file.write(line)

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        If this model has reached the best validation accuracy thus
        far, a seperate file with the suffix `best` is created.
        """
        # print("[*] Saving model to {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """
        Load the best copy of a model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        if best:
            filename = self.model_name + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        ckpt = torch.load(ckpt_path)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator; a ``(train_loader, valid_loader)`` pair
          when ``config.is_train`` is set, otherwise the test loader itself.

        Raises
        ------
        - ValueError: if ``config.dataset_name`` or ``config.optimizer`` is
          not one of the values handled below.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]

            # Peek at one batch to learn dims 2 and 3 of the images
            # (presumably H, W of NCHW batches). Use the builtin next():
            # Python 3 iterators have no `.next()` method.
            image_tmp, _ = next(iter(self.train_loader))
            self.image_size = (image_tmp.shape[2], image_tmp.shape[3])

            if 'MNIST' in config.dataset_name or config.dataset_name == 'CIFAR':
                self.num_train = len(self.train_loader.sampler.indices)
                self.num_valid = len(self.valid_loader.sampler.indices)
            elif config.dataset_name == 'ImageNet':
                # The ImageNet loaders are not sampled, so the split sizes are
                # hard-coded to len(train_dataset) in data_loader.py
                # (len(self.train_loader) would be wrong here).
                self.num_train = 100000
                self.num_valid = 10000
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)

            image_tmp, _ = next(iter(self.test_loader))
            self.image_size = (image_tmp.shape[2], image_tmp.shape[3])

        # number of channels and classes of each supported dataset
        if 'MNIST' in config.dataset_name:
            self.num_channels = 1
            self.num_classes = 10
        elif config.dataset_name == 'ImageNet':
            self.num_channels = 3
            self.num_classes = 1000
        elif config.dataset_name == 'CIFAR':
            self.num_channels = 3
            self.num_classes = 10
        else:
            # previously fell through and failed later with AttributeError
            raise ValueError(
                'Unsupported dataset_name: {!r}'.format(config.dataset_name))

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr
        self.loss_fun_baseline = config.loss_fun_baseline
        self.loss_fun_action = config.loss_fun_action
        self.weight_decay = config.weight_decay

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.best_train_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq

        prefix = 'ram_gpu_' if config.use_gpu else 'ram_'
        self.model_name = prefix + '{0}_{1}_{2}x{3}_{4}_{5:1.2f}_{6}'.format(
            config.PBSarray_ID, config.num_glimpses, config.patch_size,
            config.patch_size, config.hidden_size, config.std,
            config.dropout)

        self.plot_dir = './plots/' + self.model_name + '/'
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            # `tensorboard_dir` used to be referenced without being defined
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)
            # SummaryWriter's keyword argument is `log_dir` (not `logs_dir`);
            # keep the writer on the instance instead of discarding it.
            self.writer = SummaryWriter(log_dir=tensorboard_dir)

        # build DRAMBUTD model
        self.model = RecurrentAttention(self.patch_size, self.num_channels,
                                        self.image_size, self.std,
                                        self.hidden_size, self.num_classes,
                                        config)
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum(p.data.nelement() for p in self.model.parameters())))

        # initialize optimizer (and, for 'ReduceLROnPlateau', a scheduler)
        if config.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.lr,
                                       momentum=self.momentum,
                                       weight_decay=self.weight_decay)
        elif config.optimizer == 'ReduceLROnPlateau':
            # ReduceLROnPlateau is a learning-rate scheduler, not an
            # optimizer: it wraps an existing optimizer and takes no
            # weight_decay argument. The old code passed a not-yet-created
            # self.optimizer; pair the scheduler with SGD instead.
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.lr,
                                       momentum=self.momentum,
                                       weight_decay=self.weight_decay)
            self.scheduler = ReduceLROnPlateau(self.optimizer,
                                               'min',
                                               patience=self.lr_patience)
        elif config.optimizer == 'Adadelta':
            self.optimizer = optim.Adadelta(self.model.parameters(),
                                            weight_decay=self.weight_decay)
        elif config.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=3e-4,
                                        weight_decay=self.weight_decay)
        elif config.optimizer == 'AdaBound':
            self.optimizer = adabound.AdaBound(self.model.parameters(),
                                               lr=3e-4,
                                               final_lr=0.1,
                                               weight_decay=self.weight_decay)
        elif config.optimizer == 'Ranger':
            self.optimizer = Ranger(self.model.parameters(),
                                    weight_decay=self.weight_decay)
        else:
            # previously left self.optimizer unset, failing later in training
            raise ValueError(
                'Unsupported optimizer: {!r}'.format(config.optimizer))
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.

    Every batch yielded by the data loaders is split by channel: channel 0
    is the image and channel 1 its saliency map (see `_split_input`).
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: in training mode a pair (train_loader, valid_loader);
          otherwise the test loader.

        Raises
        ------
        - ValueError: if `config.dataset_name` or `config.optimizer` is not
          one of the supported choices.
        """
        self.config = config

        # number of channels and classes of the dataset; resolved first so
        # that an unsupported dataset fails before any loader is touched
        self.num_channels, self.num_classes = self._dataset_info(
            config.dataset_name)

        # glimpse network params
        self.patch_size = config.patch_size

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.image_size = self._probe_image_size(self.train_loader)

            if config.dataset_name == 'ImageNet':
                # The ImageNet loaders are not sampled through
                # `sampler.indices`, so the split sizes are hard-coded; they
                # must match len(train_dataset) in data_loader.py
                # (len(self.train_loader) would count batches, not samples).
                self.num_train = 100000
                self.num_valid = 10000
            else:
                # MNIST variants and CIFAR use index-based samplers
                self.num_train = len(self.train_loader.sampler.indices)
                self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
            self.image_size = self._probe_image_size(self.test_loader)

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr
        self.loss_fun_baseline = config.loss_fun_baseline
        self.loss_fun_action = config.loss_fun_action
        self.weight_decay = config.weight_decay

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.best_train_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq

        # checkpoint, plot and log paths are all keyed on this name
        prefix = 'ram_gpu' if config.use_gpu else 'ram'
        self.model_name = prefix + '_{0}_{1}_{2}x{3}_{4}_{5:1.2f}_{6}'.format(
            config.PBSarray_ID, config.num_glimpses, config.patch_size,
            config.patch_size, config.hidden_size, config.std,
            config.dropout)

        self.plot_dir = './plots/' + self.model_name + '/'
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging; the writer lives on the instance so
        # that train_one_epoch() and validate() can reach it
        self.writer = None
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            self.writer = SummaryWriter(log_dir=tensorboard_dir)

        # build DRAMBUTD model
        self.model = RecurrentAttention(self.patch_size, self.num_channels,
                                        self.image_size, self.std,
                                        self.hidden_size, self.num_classes,
                                        config)
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum(p.numel() for p in self.model.parameters())))

        # initialize optimizer and (optionally) scheduler
        self._build_optimizer(config.optimizer)

    @staticmethod
    def _dataset_info(dataset_name):
        """
        Return ``(num_channels, num_classes)`` for a supported dataset.

        Any name containing 'MNIST' maps to (1, 10), 'ImageNet' to
        (3, 1000) and 'CIFAR' to (3, 10).

        Raises
        ------
        - ValueError: for any other name (previously the attributes were
          left unset and construction died later with an AttributeError).
        """
        if 'MNIST' in dataset_name:
            return 1, 10
        if dataset_name == 'ImageNet':
            return 3, 1000
        if dataset_name == 'CIFAR':
            return 3, 10
        raise ValueError(
            "Unsupported dataset_name {!r}: expected a name containing "
            "'MNIST', 'ImageNet' or 'CIFAR'".format(dataset_name))

    @staticmethod
    def _probe_image_size(loader):
        """Return ``(H, W)`` of the first batch yielded by `loader`."""
        # next(iter(...)) works on every Python 3 iterator; the old
        # `iter(loader).next()` only existed on legacy DataLoader iterators.
        images, _ = next(iter(loader))
        return images.shape[2], images.shape[3]

    def _build_optimizer(self, name):
        """
        Create ``self.optimizer`` and, when requested, ``self.scheduler``.

        Params
        ------
        - name: value of ``config.optimizer``.

        Raises
        ------
        - ValueError: if `name` is not a supported choice, so that a typo
          fails here instead of with an AttributeError inside train().
        """
        params = self.model.parameters()
        self.scheduler = None
        if name in ('SGD', 'ReduceLROnPlateau'):
            self.optimizer = optim.SGD(params,
                                       lr=self.lr,
                                       momentum=self.momentum,
                                       weight_decay=self.weight_decay)
            if name == 'ReduceLROnPlateau':
                # ReduceLROnPlateau is a scheduler, not an optimizer: it wraps
                # the SGD optimizer above (it takes no weight_decay) and is
                # stepped on the validation loss in train().
                self.scheduler = ReduceLROnPlateau(self.optimizer,
                                                   'min',
                                                   patience=self.lr_patience)
        elif name == 'Adadelta':
            self.optimizer = optim.Adadelta(params,
                                            weight_decay=self.weight_decay)
        elif name == 'Adam':
            # NOTE: lr is fixed at 3e-4 here, independent of config.init_lr
            self.optimizer = optim.Adam(params,
                                        lr=3e-4,
                                        weight_decay=self.weight_decay)
        elif name == 'AdaBound':
            self.optimizer = adabound.AdaBound(params,
                                               lr=3e-4,
                                               final_lr=0.1,
                                               weight_decay=self.weight_decay)
        elif name == 'Ranger':
            self.optimizer = Ranger(params, weight_decay=self.weight_decay)
        else:
            raise ValueError('Unsupported optimizer {!r}'.format(name))

    def reset(self, x, SM):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced; `self.batch_size` must already hold its size.

        Returns
        -------
        ``(h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth)``:
        h_t2, l_t and SM_local_smooth come from
        ``self.model.initialize(x, SM)``; the other three are zero tensors
        of shape (batch_size, hidden_size).
        """
        dtype = (torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor)
        h_t2, l_t, SM_local_smooth = self.model.initialize(x, SM)

        # hidden state 1 starts as a zero vector to avoid classifying
        # directly from the context
        h_t1 = torch.zeros(self.batch_size, self.hidden_size).type(dtype)
        cell_state1 = torch.zeros(self.batch_size,
                                  self.hidden_size).type(dtype)
        cell_state2 = torch.zeros(self.batch_size,
                                  self.hidden_size).type(dtype)

        return h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth

    @staticmethod
    def _split_input(x_raw):
        """
        Split a (B, C, H, W) batch into its image (channel 0) and saliency
        map (channel 1), each returned with shape (B, 1, H, W).
        """
        return x_raw[:, 0, ...].unsqueeze(1), x_raw[:, 1, ...].unsqueeze(1)

    def _rollout(self, x, SM):
        """
        Run one full glimpse sequence of `self.model` over a batch.

        Sets `self.batch_size`, calls `reset`, then steps the model
        `self.num_glimpses` times; the final step passes ``last=True`` and
        additionally yields the class log-probabilities.

        Returns
        -------
        - log_probas: class log-probabilities from the final step.
        - log_pi: the per-step `p` outputs stacked batch-first.
        - baselines: the per-step baselines stacked batch-first.
        - locs: list with the first 9 rows of each step's location.
        """
        self.batch_size = x.shape[0]
        h_t1, h_t2, l_t, c_t1, c_t2, SM_smooth = self.reset(x, SM)

        locs, log_pi, baselines = [], [], []
        for _ in range(self.num_glimpses - 1):
            h_t1, h_t2, l_t, b_t, p, c_t1, c_t2, SM_smooth = self.model(
                x, l_t, h_t1, h_t2, c_t1, c_t2, SM, SM_smooth)
            locs.append(l_t[0:9])
            baselines.append(b_t)
            log_pi.append(p)

        # last iteration also returns the classification output
        h_t1, h_t2, l_t, b_t, log_probas, p, c_t1, c_t2, SM_smooth = self.model(
            x, l_t, h_t1, h_t2, c_t1, c_t2, SM, SM_smooth, last=True)
        locs.append(l_t[0:9])
        baselines.append(b_t)
        log_pi.append(p)

        # (num_glimpses, batch, ...) -> (batch, num_glimpses, ...)
        baselines = torch.stack(baselines).transpose(1, 0)
        log_pi = torch.stack(log_pi).transpose(1, 0)
        return log_probas, log_pi, baselines, locs

    def _average_over_M(self, t):
        """
        Average a tensor computed on a batch that was repeated `self.M`
        times along dim 0 (``x.repeat(self.M, ...)``) back to one row per
        original sample.
        """
        return t.contiguous().view(self.M, -1, t.shape[-1]).mean(dim=0)

    def _compute_losses(self, log_probas, log_pi, baselines, y):
        """
        Compute the hybrid loss and the batch accuracy.

        The reward R is 1 for a correct prediction and 0 otherwise, repeated
        over all glimpses. The loss is action loss + baseline loss +
        REINFORCE loss (summed over timesteps, averaged over the batch).

        Returns
        -------
        - loss: scalar tensor to back-propagate.
        - acc: scalar tensor, accuracy in percent.
        """
        predicted = torch.max(log_probas, 1)[1]
        hits = (predicted.detach() == y)
        if self.loss_fun_baseline == 'cross_entropy':
            # cross-entropy needs a long class-index target, but R must
            # afterwards be subtracted from the (N x num_glimpses) baselines
            R = hits.long()
            loss_action, loss_baseline = self.choose_loss_fun(
                log_probas, y, baselines, R)
            R = R.float().unsqueeze(1).repeat(1, self.num_glimpses)
        else:
            R = hits.float().unsqueeze(1).repeat(1, self.num_glimpses)
            loss_action, loss_baseline = self.choose_loss_fun(
                log_probas, y, baselines, R)

        # reinforce loss, summed over timesteps and averaged across batch
        adjusted_reward = R - baselines.detach()
        loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
        loss_reinforce = torch.mean(loss_reinforce, dim=0)

        loss = loss_action + loss_baseline + loss_reinforce

        correct = (predicted == y).float()
        acc = 100 * (correct.sum() / len(y))
        return loss, acc

    def _dump_glimpses(self, imgs, locs, imgs_file, locs_file, mat_file,
                       with_patch):
        """
        Write glimpse images and locations into `self.plot_dir`.

        `imgs` and `locs` are pickled into `imgs_file` / `locs_file`, and a
        MATLAB file `mat_file` is written with key 'location' (plus 'patch'
        when `with_patch` is true). Files are closed deterministically.
        """
        # .cpu() is a no-op on CPU tensors, so no device branch is needed
        imgs = [g.detach().cpu().numpy().squeeze() for g in imgs]
        locs = [l.detach().cpu().numpy() for l in locs]
        with open(self.plot_dir + imgs_file, "wb") as f:
            pickle.dump(imgs, f)
        with open(self.plot_dir + locs_file, "wb") as f:
            pickle.dump(locs, f)
        mdict = {'location': locs}
        if with_patch:
            mdict['patch'] = imgs
        sio.savemat(self.plot_dir + mat_file, mdict=mdict)

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        Training stops early once more than `self.train_patience`
        consecutive epochs pass without a validation improvement.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):
            # report the rate actually in use (Adam/AdaBound ignore init_lr
            # and the plateau scheduler may have lowered it)
            current_lr = self.optimizer.param_groups[0]['lr']
            print('\nEpoch: {}/{} - LR: {:.6f}'.format(epoch + 1, self.epochs,
                                                       current_lr))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # reduce lr if validation loss plateaus
            if self.scheduler is not None:
                self.scheduler.step(valid_loss)

            is_best_valid = valid_acc > self.best_valid_acc
            is_best_train = train_acc > self.best_train_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"

            if is_best_train:
                msg1 += " [*]"

            if is_best_valid:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement
            if not is_best_valid:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.best_train_acc = max(train_acc, self.best_train_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                    'best_train_acc': self.best_train_acc,
                }, is_best_valid)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        Returns
        -------
        ``(average loss, average accuracy in percent)`` over the epoch.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()
        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x_raw, y) in enumerate(self.train_loader):
                if self.use_gpu:
                    x_raw, y = x_raw.cuda(), y.cuda()

                # detach images and their saliency maps
                x, SM = self._split_input(x_raw)

                plot = (epoch % self.plot_freq == 0) and (i == 0)
                imgs = [x[0:9]]

                log_probas, log_pi, baselines, locs = self._rollout(x, SM)
                loss, acc = self._compute_losses(log_probas, log_pi,
                                                 baselines, y)

                losses.update(loss.item(), x.size(0))
                accs.update(acc.item(), x.size(0))

                # compute gradients and update the parameters
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # elapsed time since the start of the epoch
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.item(), acc.item())))
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    self._dump_glimpses(imgs, locs,
                                        "g_{}.p".format(epoch + 1),
                                        "l_{}.p".format(epoch + 1),
                                        "data_train_{}.mat".format(epoch + 1),
                                        with_patch=True)

                # log running averages to tensorboard
                if self.writer is not None:
                    iteration = epoch * len(self.train_loader) + i
                    self.writer.add_scalar('Loss/train', losses.avg, iteration)
                    self.writer.add_scalar('Accuracy/train', accs.avg,
                                           iteration)

            return losses.avg, accs.avg

    def validate(self, epoch):
        """
        Evaluate the model on the validation set.

        Each batch is repeated `self.M` times and the outputs are averaged
        over the repetitions. Runs under ``torch.no_grad()`` because no
        backward pass happens here.

        Returns
        -------
        ``(average loss, average accuracy in percent)``.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        # NOTE(review): the model is never switched to eval() mode here;
        # confirm whether dropout should be disabled during validation.
        with torch.no_grad():
            for i, (x_raw, y) in enumerate(self.valid_loader):
                if self.use_gpu:
                    x_raw, y = x_raw.cuda(), y.cuda()

                # detach images and their saliency maps
                x, SM = self._split_input(x_raw)

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)
                SM = SM.repeat(self.M, 1, 1, 1)

                log_probas, log_pi, baselines, _ = self._rollout(x, SM)

                # average over the M repetitions
                log_probas = self._average_over_M(log_probas)
                baselines = self._average_over_M(baselines)
                log_pi = self._average_over_M(log_pi)

                loss, acc = self._compute_losses(log_probas, log_pi,
                                                 baselines, y)

                losses.update(loss.item(), x.size(0))
                accs.update(acc.item(), x.size(0))

                # log running averages to tensorboard
                if self.writer is not None:
                    iteration = epoch * len(self.valid_loader) + i
                    self.writer.add_scalar('Accuracy/valid', accs.avg,
                                           iteration)
                    self.writer.add_scalar('Loss/valid', losses.avg, iteration)

        return losses.avg, accs.avg

    def choose_loss_fun(self, log_probas, y, baselines, R):
        """
        Compute the action and baseline losses selected by name.

        A dictionary of function handles replaces a switch-case:
        `self.loss_fun_action` is applied to ``(log_probas, y)`` and
        `self.loss_fun_baseline` to ``(baselines, R)``.

        Be careful of the argument data type and shape: e.g. 'nll' and
        'cross_entropy' expect a long class-index target.

        Returns
        -------
        ``(loss_action, loss_baseline)``.
        """
        loss_fun_pool = {
            'mse': F.mse_loss,
            'l1': F.l1_loss,
            'nll': F.nll_loss,
            'smooth_l1': F.smooth_l1_loss,
            'kl_div': F.kl_div,
            'cross_entropy': F.cross_entropy
        }

        return loss_fun_pool[self.loss_fun_action](
            log_probas, y), loss_fun_pool[self.loss_fun_baseline](baselines, R)

    def test(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.

        Test batches are split like training batches (image in channel 0,
        saliency map in channel 1) and predictions are averaged over
        `self.M` repetitions. The glimpse data of the last batch is written
        to `self.plot_dir`.
        """
        correct = 0
        imgs, locs = None, None

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        with torch.no_grad():
            for x_raw, y in self.test_loader:
                if self.use_gpu:
                    x_raw, y = x_raw.cuda(), y.cuda()

                x, SM = self._split_input(x_raw)

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)
                SM = SM.repeat(self.M, 1, 1, 1)

                # same rows as the recorded locations
                imgs = [x[0:9]]
                log_probas, _, _, locs = self._rollout(x, SM)

                log_probas = self._average_over_M(log_probas)
                pred = log_probas.max(1, keepdim=True)[1]
                correct += pred.eq(y.view_as(pred)).sum().item()

        # dump test data (files are rewritten on every call)
        if imgs is not None:
            self._dump_glimpses(imgs, locs, "g_test.p", "l_test.p",
                                "test_transient.mat", with_patch=False)

        perc = (100. * correct) / self.num_test if self.num_test else 0.
        error = 100 - perc
        print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
            correct, self.num_test, perc, error))

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        If this model has reached the best validation accuracy thus
        far, a separate file with the suffix `best` is created.
        """
        os.makedirs(self.ckpt_dir, exist_ok=True)

        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """
        Load the best copy of a model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        if best:
            filename = self.model_name + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        # map to CPU when not using a GPU so GPU-saved checkpoints still load
        ckpt = torch.load(ckpt_path,
                          map_location=None if self.use_gpu else 'cpu')

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        # keep the current value if the checkpoint lacks this key
        self.best_train_acc = ckpt.get('best_train_acc', self.best_train_acc)
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
# Example #7
# 0
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: in training mode a pair (train_loader, valid_loader),
          where train_loader must expose `trainSamples`; otherwise the test
          loader.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            # Only available in training mode: reading it unconditionally
            # raised AttributeError in test mode (no train_loader is set).
            self.trainSamplesSize = len(self.train_loader.trainSamples)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        # hard-coded for the dataset this variant targets
        self.num_classes = 83
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard

        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses,
                                                   config.patch_size,
                                                   config.patch_size,
                                                   config.glimpse_scale)

        self.plot_dir = './plots/' + self.model_name + '/'
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # Adam with a fixed learning rate; config.init_lr (self.lr) is not
        # passed to the optimizer.
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=3e-4,
        )

    def reset(self):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced.
        """
        dtype = (torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor)

        h_t = torch.zeros(self.batch_size, self.hidden_size)
        h_t[:, 1:2:self.hidden_size] = -1
        h_t = Variable(h_t).type(dtype)
        l_t = torch.ones(self.batch_size, 2)
        l_t[:, 0] *= -1
        l_t[:, 1] *= 0
        #l_t = torch.stack([, torch.zeros(self.batch_size,1)], dim=1)
        #print(l_t, l_t.shape)
        #l_t = torch.Tensor(self.batch_size, 2)#.uniform_(-1, 1)
        l_t = Variable(l_t).type(dtype)

        return h_t, l_t

    def train(self):
        """
        Run the epoch loop over ``trainDataset`` / ``validateDataset``.

        If ``self.resume`` is set, the latest checkpoint is loaded first.
        Every epoch that is not early-stopped ends with a checkpoint save.
        The checkpoint is flagged as best when validation accuracy beats
        ``self.best_valid_acc``. Training returns early once the
        no-improvement counter exceeds ``self.train_patience``.
        """
        if self.resume:
            self.load_checkpoint(best=False)

        for epoch in range(self.start_epoch, self.epochs):
            print('\nEpoch: {}/{} - LR: {:.6f}'.format(epoch + 1, self.epochs,
                                                       self.lr))

            train_loss, train_acc = self.trainDataset(epoch)
            valid_loss, valid_acc = self.validateDataset(epoch)

            improved = valid_acc > self.best_valid_acc
            report = ("train loss: {:.3f} - train acc: {:.3f} "
                      + "val loss: {:.3f} - val acc: {:.3f}")
            if improved:
                report += " [*]"
                self.counter = 0
            else:
                self.counter += 1
            print(report.format(train_loss, train_acc, valid_loss, valid_acc))

            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return

            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            checkpoint = {
                'epoch': epoch + 1,
                'model_state': self.model.state_dict(),
                'optim_state': self.optimizer.state_dict(),
                'best_valid_acc': self.best_valid_acc,
            }
            self.save_checkpoint(checkpoint, improved)

    def trainDataset(self, epoch):
        """
        Train the model for one pass over the training split of
        ``self.train_loader``.

        Each batch holds images and padded label sequences. Labels are
        truncated to the longest non-padding length ``bmax`` in the batch.
        The model takes ``onecharglimpse`` glimpses per character and is
        scored on that character after its last glimpse. The optimised loss
        is the sum of three terms:

        * the per-character NLL loss,
        * the MSE between the baselines and the reward, and
        * the REINFORCE loss.

        The reward is a 0/1 correctness term plus a penalty on the distance
        between consecutive glimpse locations.

        Args:
            epoch: 0-based epoch index. Used for the pickle dump file names
                and the tensorboard step.

        Returns:
            (average loss, average character accuracy in percent) for the
            epoch, as reported by the ``AverageMeter`` instances.
        """
        # NOTE(review): 82 is treated as the padding label of `gtTexts` — confirm.
        pad_label = 82
        # number of glimpses spent on each character
        onecharglimpse = 4

        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.trainSamplesSize) as pbar:
            self.train_loader.trainSet()
            i = 0
            while self.train_loader.hasNext():
                i += 1
                self.train_loader.getIteratorInfo()  # result unused here
                batch = self.train_loader.getNext()

                x = torch.tensor(batch.imgs)[:, None, :, :].type(torch.FloatTensor)
                y = torch.tensor(batch.gtTexts)
                self.batch_size = x.shape[0]

                # Longest label in the batch, counted as non-padding entries.
                bmax = int((y != pad_label).sum(dim=1).max()) if y.shape[0] else 0
                if bmax == 0:
                    # No characters to predict. The reward/loss lists below
                    # would stay empty, so skip the batch instead of crashing.
                    pbar.update(self.batch_size)
                    continue
                y = y[:, :bmax]
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                h_t, l_t = self.reset()

                # The first four images and their locations are kept for the dump below.
                imgs = [x[0:4]]
                locs = [l_t[0:4]]

                log_pi = []
                baselines = []
                log_probas_list = []
                predicted_list = []
                R_list = []
                y0new = []
                Rdist = []
                for t in range(bmax * onecharglimpse):
                    if t % onecharglimpse == 0:
                        # target character for the next `onecharglimpse` glimpses
                        y0 = y[:, t // onecharglimpse]
                        y0new.append(y0)
                    l_t_Prev = l_t
                    h_t, l_t, b_t, log_probas1, p = self.model(x,
                                                               l_t,
                                                               h_t,
                                                               last=True)
                    if (t + 1) % onecharglimpse == 0:
                        # Score the character on its last glimpse. The 0/1
                        # reward is shared by all glimpses spent on it.
                        predicted1 = torch.max(log_probas1, 1)[1]
                        log_probas_list.append(log_probas1)
                        predicted_list.append(predicted1)
                        R1 = (predicted1.detach() == y0).float()
                        R_list.append(R1.unsqueeze(1).repeat(1, onecharglimpse))
                    locs.append(l_t[0:4])
                    baselines.append(b_t)
                    # Negative L2 distance between consecutive locations.
                    # NOTE(review): this term is not detached; if l_t carries
                    # gradients, they flow through the reward — confirm intent.
                    Rdist.append(-1 * (torch.norm(l_t_Prev - l_t, p=2, dim=1)))
                    log_pi.append(p)

                # Per-timestep terms, transposed so time runs along dim 1.
                # The shapes of b_t and p are assumed to be (batch,) — TODO confirm.
                R = torch.cat(R_list, dim=1)
                Rdist = torch.stack(Rdist).transpose(1, 0)
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                loss_action = sum(
                    F.nll_loss(lp, target)
                    for lp, target in zip(log_probas_list, y0new))
                reward = R + Rdist
                loss_baseline = F.mse_loss(baselines, reward)

                # REINFORCE: summed over timesteps and averaged across the batch.
                adjusted_reward = reward - baselines.detach()
                loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # Per-character accuracy over all characters in the batch.
                predicted_all = torch.cat(predicted_list)
                targets_all = torch.cat(y0new)
                correct = (predicted_all == targets_all).float()
                acc = 100 * (correct.sum() / len(targets_all))

                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # `tic` is set once per epoch, so this is the elapsed time since the epoch started.
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.item(), acc.item())))
                pbar.update(self.batch_size)

                # Dump the glimpses and locations. This runs every batch, and
                # each dump overwrites the previous one for this epoch. The
                # `with` blocks make sure the file handles are closed.
                imgs = [g.detach().cpu().numpy().squeeze() for g in imgs]
                locs = [loc.detach().cpu().numpy() for loc in locs]
                with open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb") as f:
                    pickle.dump(imgs, f)
                with open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb") as f:
                    pickle.dump(locs, f)

                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value('train_loss', losses.avg, iteration)
                    log_value('train_acc', accs.avg, iteration)
        return losses.avg, accs.avg

    def validateDataset(self, epoch):
        """
        Evaluate the model on the validation split of ``self.train_loader``.

        Each batch is repeated ``self.M`` times along dim 0. The following
        are averaged over the ``M`` copies before the losses and
        per-character accuracy are computed:

        * class log-probabilities,
        * baselines,
        * log-policy terms, and
        * glimpse locations used for the distance reward.

        Nothing is back-propagated here, so the whole pass runs under
        ``torch.no_grad()``.

        Args:
            epoch: 0-based epoch index. Used for the tensorboard step.

        Returns:
            (average loss, average character accuracy in percent).
        """
        # NOTE(review): 82 is treated as the padding label of `gtTexts` — confirm.
        pad_label = 82
        # number of glimpses spent on each character
        onecharglimpse = 4

        losses = AverageMeter()
        accs = AverageMeter()
        self.train_loader.validationSet()
        i = 0
        # Evaluation only: avoid building an autograd graph for every batch.
        with torch.no_grad():
            while self.train_loader.hasNext():
                i += 1
                self.train_loader.getIteratorInfo()  # result unused here
                batch = self.train_loader.getNext()

                x = torch.tensor(batch.imgs)[:, None, :, :].type(torch.FloatTensor)
                y = torch.tensor(batch.gtTexts)

                # Longest label in the batch, counted as non-padding entries.
                bmax = int((y != pad_label).sum(dim=1).max()) if y.shape[0] else 0
                if bmax == 0:
                    # nothing to score; the reward/loss lists would stay empty
                    continue
                y = y[:, :bmax]
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                # Stack M copies of the batch: row m * B + b is copy m of sample b.
                x = x.repeat(self.M, 1, 1, 1)
                self.batch_size = x.shape[0]

                h_t, l_t = self.reset()

                log_pi = []
                baselines = []
                log_probas_list = []
                R_list = []
                y0new = []
                Rdist = []
                for t in range(bmax * onecharglimpse):
                    if t % onecharglimpse == 0:
                        # target character for the next `onecharglimpse` glimpses
                        y0 = y[:, t // onecharglimpse]
                        y0new.append(y0)
                    l_t_Prev = l_t
                    h_t, l_t, b_t, log_probas1, p = self.model(x,
                                                               l_t,
                                                               h_t,
                                                               last=True)
                    if (t + 1) % onecharglimpse == 0:
                        # average the class log-probs over the M copies
                        log_probas1 = torch.mean(
                            log_probas1.view(self.M, -1, log_probas1.shape[-1]),
                            dim=0)
                        log_probas_list.append(log_probas1)
                        predicted1 = torch.max(log_probas1, 1)[1]
                        R1 = (predicted1 == y0).float()
                        R_list.append(R1.unsqueeze(1).repeat(1, onecharglimpse))
                    baselines.append(b_t)
                    log_pi.append(p)
                    # distance penalty between M-averaged consecutive locations
                    prev_mean = torch.mean(
                        l_t_Prev.view(self.M, -1, l_t_Prev.shape[-1]), dim=0)
                    cur_mean = torch.mean(l_t.view(self.M, -1, l_t.shape[-1]), dim=0)
                    Rdist.append(-1 * (torch.norm(prev_mean - cur_mean, p=2, dim=1)))

                # Per-timestep terms, transposed so time runs along dim 1.
                R = torch.cat(R_list, dim=1)
                Rdist = torch.stack(Rdist).transpose(1, 0)
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                loss_action = sum(
                    F.nll_loss(lp, target)
                    for lp, target in zip(log_probas_list, y0new))
                predicted_all = torch.cat(
                    [torch.max(lp, 1)[1] for lp in log_probas_list])
                targets_all = torch.cat(y0new)

                # average baselines / log-policy over the M copies
                baselines = torch.mean(
                    baselines.contiguous().view(self.M, -1, baselines.shape[-1]),
                    dim=0)
                log_pi = torch.mean(
                    log_pi.contiguous().view(self.M, -1, log_pi.shape[-1]), dim=0)

                reward = R + Rdist
                loss_baseline = F.mse_loss(baselines, reward)

                # REINFORCE: summed over timesteps and averaged across the batch.
                adjusted_reward = reward - baselines.detach()
                loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # per-character accuracy over all characters in the batch
                correct = (predicted_all == targets_all).float()
                acc = 100 * (correct.sum() / len(targets_all))

                # Weighted by the repeated batch size (M * B).
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                if self.use_tensorboard:
                    iteration = epoch * len(self.valid_loader) + i
                    log_value('valid_loss', losses.avg, iteration)
                    log_value('valid_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    def testData(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        i = 0
        while self.train_loader.hasNext():
            #if(i>2): break
            i += 1
            iterInfo = self.train_loader.getIteratorInfo()
            batch = self.train_loader.getNext()

            x = batch.imgs
            x = torch.tensor(x)
            #x = x.type(torch.cuda.FloatTensor)
            x = x[:, None, :, :]
            x = x.type(torch.FloatTensor)

            y = batch.gtTexts

            self.batch_size = x.shape[0]

            h_t, l_t = self.reset()

            log_pi = []
            locs = []
            baselines = []
            log_probas_list = []
            predicted_list = []
            R_list = []

            y0new = []
            y0 = []
            # extract the glimpses
            # extract the glimpses
            for t in range(self.num_glimpses):
                # forward pass through model
                if (t % 8 == 0):
                    y0 = []
                    for b in range(self.batch_size):
                        print(b, t // 8, t)
                        y0.append(y[b][t // 8])
                    y0new += y0

                y0 = torch.tensor(y0)

                h_t, l_t, b_t, log_probas1, p = self.model(x,
                                                           l_t,
                                                           h_t,
                                                           last=True)
                if (t + 1) % 8 == 1:
                    log_probas_list.append(log_probas1)
                    predicted_list.append(torch.max(log_probas1, 1)[1])

                locs.append(l_t[0:9])
                baselines.append(b_t)
                log_pi.append(p)

                predicted1 = torch.max(log_probas1, 1)[1]
                R1 = (predicted1.detach() == y0).float()
                R_list.append(R1)

                # store
                #baselines.append(b_t)
                #log_pi.append(p)

#            # last iteration
#            h_t, l_t, b_t, log_probas, p = self.model(
#                x, l_t, h_t, last=True
#            )
#            log_pi.append(p)
#            baselines.append(b_t)

            R = R_list
            R = torch.stack(R).transpose(1, 0)

            pred = log_probas_list.data.max(1, keepdim=True)[1]
            correct += pred.eq(y.data.view_as(pred)).cpu().sum()

        perc = (100. * correct) / (self.num_test)
        error = 100 - perc
        print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
            correct, self.num_test, perc, error))

    def test(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        for i, (x, y) in enumerate(self.test_loader):
            if self.use_gpu:
                x, y = x.cuda(), y.cuda()
            x, y = Variable(x, volatile=True), Variable(y)
            y = y - 1
            # duplicate 10 times
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            for t in range(self.num_glimpses - 1):
                # forward pass through model
                h_t, l_t, b_t, p = self.model(x, l_t, h_t)

            # last iteration
            h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)

            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            pred = log_probas.data.max(1, keepdim=True)[1]
            correct += pred.eq(y.data.view_as(pred)).cpu().sum()

        perc = (100. * correct) / (self.num_test)
        error = 100 - perc
        print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
            correct, self.num_test, perc, error))

    def save_checkpoint(self, state, is_best):
        """
        Persist ``state`` with ``torch.save`` so it can be reloaded later,
        for example to evaluate on the test data.

        The state is always written to ``<model_name>_ckpt.pth.tar`` in
        ``self.ckpt_dir``. When ``is_best`` is true, that file is also copied
        to ``<model_name>_model_best.pth.tar`` in the same directory.
        """
        latest_path = os.path.join(self.ckpt_dir,
                                   self.model_name + '_ckpt.pth.tar')
        torch.save(state, latest_path)

        if not is_best:
            return

        best_path = os.path.join(self.ckpt_dir,
                                 self.model_name + '_model_best.pth.tar')
        shutil.copyfile(latest_path, best_path)

    def load_checkpoint(self, best=False):
        """
        Restore training state from a checkpoint in ``self.ckpt_dir``.

        Typical uses:

        - Resuming training from the most recent checkpoint (``best=False``,
          reads ``<model_name>_ckpt.pth.tar``).
        - Loading the best-validation model for evaluation on the test data
          (``best=True``, reads ``<model_name>_model_best.pth.tar``).

        Sets ``self.start_epoch`` and ``self.best_valid_acc``, and loads the
        model and optimizer state dicts.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        suffix = '_model_best.pth.tar' if best else '_ckpt.pth.tar'
        filename = self.model_name + suffix
        ckpt = torch.load(os.path.join(self.ckpt_dir, filename))

        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            message = ("[*] Loaded {} checkpoint @ epoch {} "
                       "with best valid acc of {:.3f}").format(
                           filename, ckpt['epoch'], ckpt['best_valid_acc'])
        else:
            message = "[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch'])
        print(message)
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator. When ``config.is_train`` is true it is
          indexed as ``data_loader[0]`` (train) and ``data_loader[1]``
          (validation). Otherwise it is used directly as the test loader.

        Side effects: creates ``./plots/<model_name>/`` and, when tensorboard
        is enabled, ``<logs_dir><model_name>``. Builds the model (moved to
        CUDA when ``config.use_gpu`` is set) and prints its parameter count.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 10  #1000 #365
        self.num_channels = 3

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        # NOTE(review): self.lr is only reported by train(). The Adam
        # optimizer below uses a hard-coded lr=3e-4 — confirm which is intended.
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_file = config.ckpt_file
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}_{}'.format(
            config.num_glimpses, config.patch_size, config.patch_size,
            config.glimpse_scale,
            datetime.date.today().strftime("%y-%m-%d"))

        # exist_ok=True avoids the race between a separate exists() check and makedirs()
        self.plot_dir = './plots/' + self.model_name + '/'
        os.makedirs(self.plot_dir, exist_ok=True)

        self.alternating_learning = False
        self.train_loc_flag = False

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum(p.data.nelement() for p in self.model.parameters())))

        # Adam with a fixed learning rate; no LR scheduler is used.
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=3e-4,
        )

    def reset(self):
        """
        Create a fresh hidden state and a random starting glimpse location.

        Called once for every new minibatch, after ``self.batch_size`` has
        been set to that minibatch's size.

        Returns:
            (h_t, l_t): ``h_t`` is an all-zero ``(batch_size, hidden_size)``
            tensor. ``l_t`` is a ``(batch_size, 2)`` tensor drawn uniformly
            from [-1, 1]. Both are cast to the CUDA or CPU float tensor type
            depending on ``self.use_gpu``.
        """
        tensor_type = torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor

        hidden = Variable(
            torch.zeros(self.batch_size, self.hidden_size)).type(tensor_type)
        location = Variable(
            torch.Tensor(self.batch_size, 2).uniform_(-1, 1)).type(tensor_type)

        return hidden, location

    def train(self):
        """
        Run the epoch loop over ``train_one_epoch`` / ``validate``.

        Behaviour:

        - If ``self.resume`` is set, the latest checkpoint is loaded first.
        - Every epoch that is not early-stopped ends with a checkpoint save,
          flagged as best when validation accuracy beats
          ``self.best_valid_acc``.
        - With ``self.alternating_learning`` enabled, five epochs without
          improvement toggle ``self.train_loc_flag`` and reset the counter.
        - Training returns early once the no-improvement counter exceeds
          ``self.train_patience``.
        """
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):
            print('\nEpoch: {}/{} - LR: {:.6f}'.format(epoch + 1, self.epochs,
                                                       self.lr))

            train_loss, train_acc = self.train_one_epoch(epoch)
            valid_loss, valid_acc = self.validate(epoch)

            improved = valid_acc > self.best_valid_acc
            report = ("train loss: {:.3f} - train acc: {:.3f} "
                      + "- val loss: {:.3f} - val acc: {:.3f}")
            if improved:
                report += " [*]"
                self.counter = 0
            else:
                self.counter += 1
            print(report.format(train_loss, train_acc, valid_loss, valid_acc))

            # swap which sub-network is trained after a plateau
            if self.alternating_learning and self.counter >= 5:
                self.train_loc_flag = not self.train_loc_flag
                print(
                    "[!] No improvement in a while. Switch loss. Now training:",
                    ["ActionNet", "LocationNet"][self.train_loc_flag])
                self.counter = 0

            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return

            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            checkpoint = {
                'epoch': epoch + 1,
                'model_state': self.model.state_dict(),
                'optim_state': self.optimizer.state_dict(),
                'best_valid_acc': self.best_valid_acc,
            }
            self.save_checkpoint(checkpoint, improved)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                # for i, (x, y), f in enumerate(self.train_loader): # uncomment when using dataset with fixation proposals
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                x, y = Variable(x), Variable(y)

                plot = False
                if (epoch % self.plot_freq == 0) and (i == 0):
                    plot = True

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # save images
                imgs = []
                imgs.append(x[0:9])

                # extract the glimpses
                locs = []
                log_pi = []
                log_probs = []
                baselines = []

                for t in range(self.num_glimpses):

                    locs.append(l_t)

                    h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t)

                    #l_t = f[:,t].float() # uncomment when using dataset with fixation proposals

                    # store
                    baselines.append(b_t)
                    log_pi.append(p)
                    log_probs.append(log_probas[0:9])

                # convert list to tensors and reshape
                baselines = torch.stack(baselines)
                baselines = baselines.transpose(1, 0)
                #log_pi = torch.stack(log_pi).transpose(1, 0)  # only when using RL
                log_probs = torch.stack(log_probs).transpose(1, 0)

                # calculate reward
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)
                #R = R - self.get_loc_reward(locs)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss
                # summed over timesteps and averaged across batch
                adjusted_reward = R - baselines.detach()
                #loss_reinforce = torch.sum(-log_pi*adjusted_reward, dim=1)  # only when using RL
                #loss_reinforce = torch.mean(loss_reinforce, dim=0)  # only when using RL

                loss = loss_action
                #loss = loss_action + loss_baseline + loss_reinforce * 0.01  # only when using RL

                # sum up into a hybrid loss
                # if self.alternating_learning:
                #     if self.train_loc_flag:
                #         loss = loss_reinforce
                #     else:
                #         loss = loss_action
                # else:
                #     loss = loss_action + loss_baseline + loss_reinforce
                #loss = loss_action

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.data.item(), x.size()[0])
                accs.update(acc.data.item(), x.size()[0])

                # compute gradients and update SGD
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.data.item(), acc.data.item())))
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    if self.use_gpu:
                        imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                        locs = [l.cpu().data.numpy()[:9] for l in locs]
                        log_probs = [p.cpu().data.numpy() for p in log_probs]
                        ys = [g.cpu().data.numpy() for g in y[:9]]
                    else:
                        imgs = [g.data.numpy().squeeze() for g in imgs]
                        locs = [l.data.numpy()[:9] for l in locs]
                        log_probs = [p.data.numpy() for p in log_probs]
                        ys = [g.data.numpy() for g in y[:9]]
                    pickle.dump(
                        imgs,
                        open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb"))
                    pickle.dump(
                        locs,
                        open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb"))
                    pickle.dump(
                        log_probs,
                        open(self.plot_dir + "p_{}.p".format(epoch + 1), "wb"))
                    pickle.dump(
                        ys,
                        open(self.plot_dir + "y_{}.p".format(epoch + 1), "wb"))

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value('train_loss', losses.avg, iteration)
                    log_value('train_acc', accs.avg, iteration)

            return losses.avg, accs.avg

    def validate(self, epoch):
        """
        Evaluate the model on the validation set.

        Every batch is repeated ``self.M`` times along dimension 0 and the
        class log-probabilities and baselines are averaged over those
        replicates before the loss and accuracy are computed.

        Args:
            epoch: current epoch index; only used to compute the tensorboard
                iteration number.

        Returns:
            Tuple ``(average loss, average accuracy in percent)``.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        # Nothing in this method calls backward(), so do not record the
        # autograd graph; otherwise it is kept alive across all glimpse steps.
        with torch.no_grad():
            for i, (x, y) in enumerate(self.valid_loader):
                # for i, (x, y), f in enumerate(self.valid_loader): # uncomment when using dataset with fixation proposals
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                x, y = Variable(x), Variable(y)

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses):
                    # forward pass through model
                    h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t)
                    #l_t = f[:,t].float()  # uncomment when using dataset with fixation proposals

                    baselines.append(b_t)
                    log_pi.append(p)

                # stack over glimpses, then put the batch dimension first
                baselines = torch.stack(baselines).transpose(1, 0)
                #log_pi = torch.stack(log_pi).transpose(1, 0)  # only when using RL

                # average over the M Monte Carlo replicates
                log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                baselines = baselines.contiguous().view(self.M, -1,
                                                        baselines.shape[-1])
                baselines = torch.mean(baselines, dim=0)

                #log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1])  # only when using RL
                #log_pi = torch.mean(log_pi, dim=0)  # only when using RL

                # reward: 1 for a correct final prediction, repeated per glimpse
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # losses for differentiable modules (baseline term is only
                # used by the RL loss, which is currently disabled)
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                adjusted_reward = R - baselines.detach()
                #loss_reinforce = torch.sum(-log_pi*adjusted_reward, dim=1)  # only when using RL
                #loss_reinforce = torch.mean(loss_reinforce, dim=0)  # only when using RL

                #loss = loss_action + loss_baseline + loss_reinforce * 0.01  # only when using RL
                loss = loss_action

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.data.item(), x.size()[0])
                accs.update(acc.data.item(), x.size()[0])

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.valid_loader) + i
                    log_value('valid_loss', losses.avg, iteration)
                    log_value('valid_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    def test(self):
        """
        Test the model on the held-out test data.

        This function should only be called at the very end once the model
        has finished training. It loads the checkpoint selected by
        ``self.best`` / ``self.ckpt_file`` and evaluates every test batch with
        ``self.M`` Monte Carlo replicates, averaging the class
        log-probabilities over the replicates. It prints the overall accuracy
        and writes per-sample diagnostics to ``temp.csv``.

        CSV columns, in order:
        - two zero-filled offset placeholders (offset_x, offset_y);
        - for each of the first (up to) three glimpses, the glimpse location
          transformed as ``(l_t + 1) / 2`` and averaged over the M replicates;
        - for the same glimpses, the averaged class log-probabilities after
          that glimpse;
        - the target label;
        - whether the final prediction was correct (1.0 / 0.0).
        """
        import pandas as pd

        torch.manual_seed(15)

        # Number of leading glimpses written to the CSV; capped by
        # num_glimpses so short runs don't fail with an IndexError.
        num_logged = min(3, self.num_glimpses)
        l_ts = [[] for _ in range(num_logged)]
        probas = [[] for _ in range(num_logged)]
        ys = []
        corrs = []
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best, ckpt_file=self.ckpt_file)

        # inference only: no autograd bookkeeping needed
        with torch.no_grad():
            for i, (x, y) in enumerate(self.test_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                x, y = Variable(x, requires_grad=False), Variable(y)
                # duplicate M times (replicates are stacked along dim 0)
                x = x.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses
                l_ts_temp = []
                probas_temp = []
                for t in range(self.num_glimpses):
                    # Location used for this glimpse. It has one row per
                    # replicate (B*M rows); average over the M replicates so
                    # it lines up row-for-row with y and the averaged
                    # probabilities (concatenating B*M rows with B rows
                    # would fail for M > 1).
                    loc = (l_t + 1) / 2.0
                    loc = loc.reshape(self.M, -1, loc.shape[-1]).mean(dim=0)
                    l_ts_temp.append(loc)

                    # forward pass through model
                    h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t)

                    # average class log-probabilities over the M replicates
                    log_probas = log_probas.view(self.M, -1,
                                                 log_probas.shape[-1])
                    log_probas = torch.mean(log_probas, dim=0)
                    probas_temp.append(log_probas)

                    pred = log_probas.max(1, keepdim=True)[1]

                for k in range(num_logged):
                    l_ts[k].append(l_ts_temp[k].cpu())
                    probas[k].append(probas_temp[k].cpu())

                # correctness of the prediction after the final glimpse
                hits = pred.eq(y.view_as(pred)).cpu()
                ys.append(y.cpu())
                corrs.append(hits)
                correct += int(hits.sum().item())

        ys = torch.unsqueeze(torch.cat(ys), dim=1).float()
        corrs = torch.cat(corrs).float()

        # Ground-truth offsets are not provided by this loader; keep zero
        # placeholders so the CSV layout stays fixed.
        offset_x = torch.zeros(ys.shape)
        offset_y = torch.zeros(ys.shape)

        columns = ([offset_x, offset_y]
                   + [torch.cat(chunks) for chunks in l_ts]
                   + [torch.cat(chunks) for chunks in probas]
                   + [ys, corrs])
        data = pd.DataFrame(torch.cat(columns, dim=1).numpy())
        data.to_csv('temp.csv')

        perc = (100.0 * correct) / (self.num_test)
        error = 100 - perc
        print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
            correct, self.num_test, perc, error))

    def save_checkpoint(self, state, is_best):
        """
        Serialize ``state`` to ``<ckpt_dir>/<model_name>_ckpt.pth.tar``.

        When ``is_best`` is true, the freshly written file is also copied to
        ``<ckpt_dir>/<model_name>_model_best.pth.tar``. That copy is the one
        to reload later, e.g. for evaluation on the test data.

        Args:
            state: object handed to ``torch.save``.
            is_best: whether this checkpoint has the best validation accuracy
                seen so far.
        """
        print("[*] Saving model to {}".format(self.ckpt_dir))

        latest_path = os.path.join(self.ckpt_dir,
                                   self.model_name + '_ckpt.pth.tar')
        torch.save(state, latest_path)

        if not is_best:
            return
        best_path = os.path.join(self.ckpt_dir,
                                 self.model_name + '_model_best.pth.tar')
        shutil.copyfile(latest_path, best_path)

    def load_checkpoint(self, best=False, ckpt_file=None):
        """
        Load a model checkpoint. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Restores ``self.start_epoch``, ``self.best_valid_acc``, the model
        weights and the optimizer state from the file.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
          When ``ckpt_file`` is given, ``best`` only affects the log message.
        - ckpt_file: optional file name, joined onto ``self.ckpt_dir``, that
          overrides the file name chosen from ``best``.
        """
        if ckpt_file is None:
            suffix = '_model_best.pth.tar' if best else '_ckpt.pth.tar'
            filename = self.model_name + suffix
        else:
            filename = ckpt_file

        ckpt_path = os.path.join(self.ckpt_dir, filename)
        # A checkpoint written during a GPU run holds CUDA tensors; when this
        # trainer runs on CPU, remap them so loading does not require CUDA.
        map_location = None if self.use_gpu else 'cpu'
        ckpt = torch.load(ckpt_path, map_location=map_location)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))

    def get_loc_reward(self, locs):
        """
        Compute a proximity term for glimpses that revisit earlier locations.

        For each glimpse ``t``, take the Euclidean distance ``d`` to the
        nearest earlier glimpse location and map it to ``exp(-10 * d)``. The
        value is close to 1 when glimpse ``t`` lands on (or very near) an
        earlier glimpse, and it decays towards 0 as the glimpse moves away.
        The first glimpse has no predecessor, so its distance is +inf and its
        value is 0. The result is non-negative; it is meant to be subtracted
        from the reward to act as a penalty.

        Args:
        ----
        locs: List of location tensors, one per glimpse, each shaped
            (batch, dim) (presumably dim == 2 for (x, y) locations).

        Returns:
        ----
        reward: tensor of shape (batch, len(locs)) with values in [0, 1].
        """
        pdist = torch.nn.PairwiseDistance(p=2)

        min_dists = []
        for i, l_t in enumerate(locs):
            if i == 0:
                # No earlier glimpse: use +inf so exp(-10 * inf) == 0.
                # Allocated with l_t's device and dtype so the later
                # concatenation never mixes devices.
                nearest = l_t.new_full((l_t.shape[0], 1), float("inf"))
            else:
                # distance to every previous glimpse, one column each
                dists = torch.stack([pdist(prev, l_t) for prev in locs[:i]],
                                    dim=1)
                nearest = torch.min(dists, dim=1).values.unsqueeze(1)
            min_dists.append(nearest)

        min_dists = torch.cat(min_dists, dim=1)
        return torch.exp(-10 * min_dists)
Exemple #9
0
    def __init__(self, data_loader):
        """
        Build a Trainer whose hyperparameters are all hard-coded.

        Unlike the config-driven variants, no config object is taken.

        Args
        ----
        - data_loader: indexable pair ``(train_loader, valid_loader)``. The
          sampler of each loader must expose ``indices`` because the dataset
          sizes are read from it.
        """
        # glimpse network params
        self.patch_size, self.glimpse_scale, self.num_patches = 8, 2, 3
        self.loc_hidden = self.glimpse_hidden = 128

        # core network params
        self.num_glimpses, self.hidden_size = 6, 256

        # reinforce params
        self.std, self.M = 0.17, 10

        # data params
        train_loader, valid_loader = data_loader[0], data_loader[1]
        self.train_loader, self.valid_loader = train_loader, valid_loader
        self.num_train = len(train_loader.sampler.indices)
        self.num_valid = len(valid_loader.sampler.indices)

        self.num_classes, self.num_channels = 27, 3

        # training params; decay_rate is the per-epoch lr step from init_lr
        # towards min_lr over saturate_epoch epochs
        self.epochs, self.start_epoch, self.saturate_epoch = 200, 0, 150
        self.init_lr, self.min_lr = 0.001, 1e-06
        self.decay_rate = (self.min_lr - self.init_lr) / (self.saturate_epoch)
        self.momentum = 0.5
        self.lr = self.init_lr

        # misc params
        self.use_gpu = False
        self.best = True
        self.best_valid_acc = 0.
        self.counter = 0

        # build RAM model
        ram = RecurrentAttention(
            self.patch_size, self.num_patches, self.glimpse_scale,
            self.num_channels, self.loc_hidden, self.glimpse_hidden,
            self.std, self.hidden_size, self.num_classes,
        )
        self.model = ram
        if self.use_gpu:
            ram.cuda()

        n_params = sum(p.numel() for p in ram.parameters())
        print('[*] Number of model parameters: {:,}'.format(n_params))

        # optimizer and plateau-based lr scheduler
        self.optimizer = SGD(ram.parameters(), lr=self.lr,
                             momentum=self.momentum)
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')
Exemple #10
0
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: when ``config.is_train`` is true, an indexable pair
          ``(train_loader, valid_loader)`` whose samplers expose ``indices``;
          otherwise the test loader.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size  # size of extracted patch at highest res (default 8)
        self.glimpse_scale = config.glimpse_scale  # scale of successive patches (default 2)
        self.num_patches = config.num_patches  # number of downscaled patches per glimpse
        self.loc_hidden = config.loc_hidden  # hidden size of loc fc (default 128)
        self.glimpse_hidden = config.glimpse_hidden  # hidden size of glimpse fc (default 128)

        # core network params
        self.num_glimpses = config.num_glimpses  # number of glimpses, i.e. BPTT iterations (default 6)
        self.hidden_size = config.hidden_size  # hidden size of rnn (default 256)

        # reinforce params
        self.std = config.std  # gaussian policy standard deviation (default 0.17)
        self.M = config.M  # Monte Carlo samples for valid and test sets (default 10)

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 10
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses,
                                                   config.patch_size,
                                                   config.patch_size,
                                                   config.glimpse_scale)

        self.plot_dir = './plots/' + self.model_name + '/'
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # Use the configured initial learning rate (config.init_lr) rather
        # than a hard-coded 3e-4, so the value stored in self.lr is the one
        # the optimizer actually applies.
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.lr,
        )
Exemple #11
0
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: when ``config.is_train`` is true, an indexable pair
          ``(train_loader, valid_loader)`` whose samplers expose ``indices``;
          otherwise the test loader.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.sampler)
        self.num_classes = config.num_classes
        # single-channel input is assumed; a 3-channel variant for
        # non-MNIST datasets was disabled:
        #  self.num_channels = 1 if config.dataset == 'mnist' else 3
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.image_size = config.image_size
        self.model_name = '{}-{}_gnum:{}_gsize:{}x{}_imgsize:{}x{}'.format(
            config.dataset, config.selected_attrs[0], config.num_glimpses,
            config.patch_size, config.patch_size, config.image_size,
            config.image_size)

        self.model_checkpoints = self.ckpt_dir + '/' + self.model_name + '/'
        if not os.path.exists(self.model_checkpoints):
            os.makedirs(self.model_checkpoints)

        self.plot_dir = './plots/' + self.model_name + '/'
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # Use the configured initial learning rate (config.init_lr) rather
        # than a hard-coded 3e-4, so the value stored in self.lr is the one
        # the optimizer actually applies.
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.lr,
        )

    def reset(self):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced.
        """
        dtype = (torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor)

        h_t = torch.zeros(self.batch_size, self.hidden_size)
        h_t = Variable(h_t).type(dtype)

        l_t = torch.Tensor(self.batch_size, 2).uniform_(-1, 1)
        l_t = Variable(l_t).type(dtype)

        return h_t, l_t

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        """
        # load the most recent checkpoint
        if self.resume:
            # TODO !!!!!!!
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):

            print('\nEpoch: {}/{} - LR: {:.6f}'.format(epoch + 1, self.epochs,
                                                       self.lr))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # # reduce lr if validation loss plateaus
            # self.scheduler.step(valid_loss)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement
            if not is_best:
                self.counter += 1
            #  if self.counter > self.train_patience:
            #      print("[!] No improvement in a while, stopping training.")
            #      return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                }, is_best)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire training set
        in successive mini-batches. The optimized loss is the sum of three
        terms: the classification NLL, the baseline MSE against the 0/1
        reward, and the REINFORCE term weighted by the baseline-adjusted
        reward.

        When ``epoch`` is a multiple of ``self.plot_freq``, the first
        minibatch's first 9 images and glimpse locations are pickled to
        ``self.plot_dir`` as ``g_<epoch+1>.p`` and ``l_<epoch+1>.p``.

        This is used by train() and should not be called manually.

        Returns:
            Tuple ``(average loss, average accuracy in percent)``.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        # the progress bar reports time elapsed since the start of the epoch
        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                # Labels may arrive as (B, 1); flatten to (B,) for nll_loss.
                # Explicit check instead of a bare try/except that would
                # also hide unrelated errors.
                if y.dim() > 1:
                    y = y.squeeze(1)
                x, y = Variable(x), Variable(y)

                plot = (epoch % self.plot_freq == 0) and (i == 0)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # keep the first 9 images for plotting
                imgs = [x[0:9]]

                # extract the glimpses
                locs = []
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration also returns the class log-probabilities
                h_t, l_t, b_t, log_probas, p = self.model(x,
                                                          l_t,
                                                          h_t,
                                                          last=True)
                log_pi.append(p)
                baselines.append(b_t)
                locs.append(l_t[0:9])

                # stack over glimpses, then put the batch dimension first
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # reward: 1 for a correct final prediction, repeated per glimpse
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # REINFORCE loss: summed over timesteps, averaged over batch
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # compute gradients and update parameters
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                toc = time.time()
                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.item(), acc.item())))
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    # .cpu() is a no-op for CPU tensors, so one path serves
                    # both devices
                    imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                    locs = [l.cpu().data.numpy() for l in locs]
                    for obj, pattern in ((imgs, "g_{}.p"), (locs, "l_{}.p")):
                        path = self.plot_dir + pattern.format(epoch + 1)
                        # close the file deterministically (was leaked before)
                        with open(path, "wb") as fh:
                            pickle.dump(obj, fh)

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value('train_loss', losses.avg, iteration)
                    log_value('train_acc', accs.avg, iteration)

            return losses.avg, accs.avg

    def validate(self, epoch):
        """
        Evaluate the model on the validation set.

        Each batch is replicated ``self.M`` times; the class log-probabilities,
        baselines and location log-likelihoods are averaged over the replicas
        before computing the same hybrid (action + baseline + REINFORCE) loss
        that training uses.

        Args:
            epoch: current epoch index, used only to compute the tensorboard
                logging step.

        Returns:
            ``(average loss, average accuracy in percent)`` over the
            validation set.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        # Nothing is back-propagated here, so do not build autograd graphs.
        with torch.no_grad():
            for i, (x, y) in enumerate(self.valid_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                try:
                    x, y = Variable(x), Variable(y.squeeze(1))
                except (IndexError, RuntimeError):
                    # y has no dimension 1 (already 1-D): use it as-is
                    x, y = Variable(x), Variable(y)

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    # store
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration also yields the class log-probabilities
                h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)
                log_pi.append(p)
                baselines.append(b_t)

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # average over the M replicas
                log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                baselines = baselines.contiguous().view(self.M, -1,
                                                        baselines.shape[-1])
                baselines = torch.mean(baselines, dim=0)

                log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1])
                log_pi = torch.mean(log_pi, dim=0)

                # calculate reward: 1 for a correct prediction, at every glimpse
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.valid_loader) + i
                    log_value('valid_loss', losses.avg, iteration)
                    log_value('valid_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    def test(self):
        """
        Evaluate every saved per-epoch checkpoint on the held-out test data.

        Starting at epoch 1, checkpoints are loaded one after another until
        the first epoch whose checkpoint file does not exist. For each one the
        test accuracy and an F1 score are printed; finally the F1 and
        accuracy curves across epochs are plotted.

        This function should only be called at the very
        end once the model has finished training.
        """
        epoch = 1
        f1s = []
        accs = []

        # len(self.test_loader) counts batches; report the number of examples.
        print("Testing trained model with ", self.num_test, " examples")
        while True:
            try:
                self.load_checkpoint(epoch=epoch)
            except FileNotFoundError:
                # No checkpoint for this epoch: all saved epochs are evaluated.
                break

            correct = 0
            f1_correct = 0
            f1_reported = 0
            f1_relevant = 0
            with torch.no_grad():
                for x, y in self.test_loader:
                    if self.use_gpu:
                        x, y = x.cuda(), y.cuda()
                    try:
                        x, y = Variable(x), Variable(y.squeeze(1))
                    except (IndexError, RuntimeError):
                        # y has no dimension 1 (already 1-D): use it as-is
                        x, y = Variable(x), Variable(y)

                    # duplicate M times
                    x = x.repeat(self.M, 1, 1, 1)

                    # initialize location vector and hidden state
                    self.batch_size = x.shape[0]
                    h_t, l_t = self.reset()

                    # extract the glimpses
                    for _ in range(self.num_glimpses - 1):
                        h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    # last iteration also yields the class log-probabilities
                    h_t, l_t, b_t, log_probas, p = self.model(x,
                                                              l_t,
                                                              h_t,
                                                              last=True)

                    # average the class scores over the M replicas
                    log_probas = log_probas.view(self.M, -1,
                                                 log_probas.shape[-1])
                    log_probas = torch.mean(log_probas, dim=0)

                    pred = log_probas.data.max(1, keepdim=True)[1]
                    correct += pred.eq(y.data.view_as(pred)).sum().item()

                    # F1 bookkeeping with class 0 as the negative class: a
                    # true positive is a correct, non-zero prediction.
                    # reported/relevant sum the predicted/true labels, which
                    # equals the positive counts only for 0/1 labels
                    # (assumed here - TODO confirm).
                    # Everything stays on y's device (no .cpu() mixing).
                    preds = pred.flatten()
                    f1_reported += preds.sum().item()
                    f1_relevant += y.sum().item()
                    f1_correct += ((preds == y.view_as(preds)) &
                                   (preds != 0)).sum().item()

            perc = (100. * correct) / self.num_test
            error = 100 - perc
            # Guard against epochs with no positive predictions or labels.
            precision = (float(f1_correct) / float(f1_reported)
                         if f1_reported else 0.0)
            recall = (float(f1_correct) / float(f1_relevant)
                      if f1_relevant else 0.0)
            if precision + recall > 0:
                f1_score = 2 * (precision * recall / (precision + recall))
            else:
                f1_score = 0.0
            accuracy = float(correct) / float(self.num_test)

            print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%) : F1 Score - {} \n'.
                  format(correct, self.num_test, perc, error, f1_score))
            epoch += 1
            f1s.append(f1_score)
            accs.append(accuracy)

        fig, ax = plt.subplots()
        ax.plot(np.arange(len(f1s)), f1s)
        ax.plot(np.arange(len(accs)), accs)
        plt.show()

    def kde(self, epoch=5):
        """
        Plot a kernel density estimate of the model's final glimpse locations.

        Loads the checkpoint for ``epoch`` and runs the model over the test
        set without Monte Carlo duplication. It collects the location emitted
        by the last glimpse step for every sample and converts it to pixel
        coordinates with ``denormalize``, shifted by half a patch. A Gaussian
        KDE of those points is then shown over ``[0, image_size]^2``.

        Args:
            epoch: checkpoint epoch to load (defaults to 5, the value that was
                previously hard-coded).
        """
        # len(self.test_loader) counts batches; report the number of examples.
        print("plotting kde of trained model with ", self.num_test,
              " examples")
        self.load_checkpoint(epoch=epoch)
        fig, ax = plt.subplots()

        img_min = 0
        img_max = self.image_size

        # 100x100 evaluation grid covering the image
        X, Y = np.mgrid[img_min:img_max:100j, img_min:img_max:100j]
        positions = np.vstack([X.ravel(), Y.ravel()])

        # Collect per-batch locations on the CPU so concatenation works
        # regardless of the device the model runs on.
        locations = []
        with torch.no_grad():
            for x, _ in self.test_loader:
                if self.use_gpu:
                    x = x.cuda()
                x = Variable(x)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses
                for _ in range(self.num_glimpses - 1):
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                # last iteration
                h_t, l_t, b_t, log_probas, p = self.model(x,
                                                          l_t,
                                                          h_t,
                                                          last=True)

                locations.append(l_t.cpu())

        all_locations = torch.cat(locations)

        coords = denormalize(self.image_size, all_locations)
        coords = coords + (self.patch_size / 2)
        # flip y so the density is drawn with image orientation
        values = torch.stack((coords[:, 0], (self.image_size - coords[:, 1])))
        kernel = stats.gaussian_kde(values.numpy())
        Z = np.reshape(kernel(positions).T, X.shape)
        # extent must match the grid the KDE was evaluated on
        ax.imshow(np.rot90(Z),
                  cmap=plt.cm.gist_earth_r,
                  extent=[img_min, img_max, img_min, img_max])
        plt.show()

    def save_checkpoint(self, state, is_best):
        """
        Persist ``state`` for later evaluation or resumption.

        The dictionary is written with ``torch.save`` to
        ``<model_name>_<state['epoch']>_ckpt.pth.tar`` inside
        ``self.model_checkpoints``. When ``is_best`` is true, that file is
        additionally copied to ``<model_name>_model_best.pth.tar`` in the same
        directory, marking the best validation model seen so far.
        """
        ckpt_dir = self.model_checkpoints
        epoch_name = '_'.join(
            (self.model_name, str(state['epoch']), 'ckpt.pth.tar'))
        epoch_path = os.path.join(ckpt_dir, epoch_name)
        torch.save(state, epoch_path)

        if not is_best:
            return

        best_path = os.path.join(ckpt_dir,
                                 self.model_name + '_model_best.pth.tar')
        shutil.copyfile(epoch_path, best_path)

    def load_checkpoint(self, epoch=1):
        """
        Restore model and optimizer state from the checkpoint of one epoch.

        Reads ``<model_name>_<epoch>_ckpt.pth.tar`` from
        ``self.model_checkpoints`` and restores ``start_epoch``,
        ``best_valid_acc``, the model weights and the optimizer state from it.
        Useful for resuming training or for evaluating a given epoch.

        Params
        ------
        - epoch: which per-epoch checkpoint to load (default 1).

        Raises
        ------
        - FileNotFoundError: if no checkpoint file exists for ``epoch``.
        """
        print("[*] Loading model from {}".format(self.model_checkpoints))

        filename = self.model_name + '_' + str(epoch) + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.model_checkpoints, filename)

        # Checkpoints written on a GPU hold CUDA tensors; remap them to the
        # CPU when this trainer is not using the GPU so they still load.
        map_location = None if self.use_gpu else 'cpu'
        ckpt = torch.load(ckpt_path, map_location=map_location)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        print("[*] Loaded {} checkpoint @ epoch {}".format(
            filename, ckpt['epoch']))
Exemple #12
0
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator exposing ``train_loader``,
          ``test_loader`` and ``output_size``.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader.train_loader
            # NOTE(review): validation runs on the test split here.
            self.valid_loader = data_loader.test_loader

            self.num_train = len(self.train_loader.dataset)
            self.num_valid = len(self.valid_loader.dataset)
        else:
            self.test_loader = data_loader.test_loader
            self.num_test = len(self.test_loader.dataset)

        self.num_classes = data_loader.output_size
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.cuda
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(
            config.num_glimpses, config.patch_size,
            config.patch_size, config.glimpse_scale
        )

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size, self.num_patches, self.glimpse_scale,
            self.num_channels, self.loc_hidden, self.glimpse_hidden,
            self.std, self.hidden_size, self.num_classes,
        )

        if self.use_gpu:
            self.model.cuda()
            # Use GPUs 0 and 1 as before, but only those that exist, so a
            # single-GPU machine does not fail with an invalid device id.
            device_ids = list(range(min(2, torch.cuda.device_count())))
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=device_ids)

        self.plot_dir = './plots/' + self.model_name + '/'
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)

        # visdom
        self.use_visdom = config.visdom
        self.visdom_images = config.visdom_images
        self.visdom_env = config.visdom_env

        if self.use_visdom:
            self.grapher = Grapher('visdom',
                                   env=self.visdom_env,
                                   server=config.visdom_url,
                                   port=config.visdom_port)

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # Use the configured learning rate; it was previously hard-coded to
        # 3e-4, silently ignoring config.init_lr (stored in self.lr).
        self.optimizer = optim.Adam(
            self.model.parameters(), lr=self.lr,
        )
Exemple #13
0
    def __init__(self, config, data_loader):
        """
        Set up the point-cloud trainer from parsed options.

        Args:
            config: parsed command-line options.
            data_loader: indexable pair whose items 0 and 1 are the training
                and validation loaders; each loader's sampler must expose
                ``indices``.
        """
        self.config = config

        self.train_loader = data_loader[0]
        self.valid_loader = data_loader[1]
        self.num_train = len(self.train_loader.sampler.indices)
        self.num_valid = len(self.valid_loader.sampler.indices)

        # Task selection: 2-way (binary) vs 8-way classification, each with
        # its own loss and model-name tags.
        if self.config.binary:
            self.num_classes, self.loss = 2, F.binary_cross_entropy_with_logits
            task_tag, cat_tag = 'binary', str(config.cat)
        else:
            self.num_classes, self.loss = 8, F.cross_entropy
            task_tag, cat_tag = 'all', 'nocat'

        # Semi-supervised inputs carry num_classes extra features per point
        # after the first 3 columns (presumably the label encoding).
        if self.config.semi:
            supervision_tag, extra_features = 'semi', self.num_classes
        else:
            supervision_tag, extra_features = 'fully', 0
        self.input_dim = 3 + extra_features
        self.output_dim = self.num_classes
        self.mask_rate = config.mask_rate
        self.pc_size = config.pc_size

        # optimisation schedule
        self.epochs = config.epochs
        self.start_epoch = 0
        self.lr = config.init_lr

        # device, checkpointing and early-stopping bookkeeping
        self.use_gpu = config.use_gpu
        self.ckpt_dir = config.ckpt_dir
        self.best = config.best
        self.best_mIoU = -10
        self.best_acc = 0
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.resume = config.resume
        self.model_name = 'dsseg_{}_{}_{}_{}_{}'.format(
            config.init_lr, supervision_tag, task_tag, cat_tag,
            config.pc_size)

        # glimpse network
        self.num_points_per_pc = config.pc_size
        self.num_points_per_sample = config.num_points_per_sample
        self.box_size = config.box_size
        self.glimpse_scale = config.glimpse_scale
        self.num_samples = config.num_samples
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # recurrent core
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # REINFORCE policy
        self.std = config.std
        self.M = config.M

        # NOTE(review): self.num_channels is never assigned in this
        # constructor, so this call raises AttributeError unless it is
        # provided elsewhere (e.g. as a class attribute) - confirm.
        self.model = RecurrentAttention(
            self.num_points_per_pc, self.num_points_per_sample,
            self.num_samples, self.box_size, self.glimpse_scale,
            self.num_channels, self.loc_hidden, self.glimpse_hidden, self.std,
            self.hidden_size, self.num_classes, self.use_gpu)
        if self.use_gpu:
            self.model.cuda()

        n_params = sum(p.data.nelement() for p in self.model.parameters())
        print('[*] Number of model parameters: {:,}'.format(n_params))

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
Exemple #14
0
class Trainer(object):
    def __init__(self, config, data_loader):
        """
        Set up the point-cloud trainer from parsed options.

        Args:
            config: parsed command-line options.
            data_loader: indexable pair whose items 0 and 1 are the training
                and validation loaders; each loader's sampler must expose
                ``indices``.
        """
        self.config = config

        self.train_loader = data_loader[0]
        self.valid_loader = data_loader[1]
        self.num_train = len(self.train_loader.sampler.indices)
        self.num_valid = len(self.valid_loader.sampler.indices)

        # Task selection: 2-way (binary) vs 8-way classification, each with
        # its own loss and model-name tags.
        if self.config.binary:
            self.num_classes, self.loss = 2, F.binary_cross_entropy_with_logits
            task_tag, cat_tag = 'binary', str(config.cat)
        else:
            self.num_classes, self.loss = 8, F.cross_entropy
            task_tag, cat_tag = 'all', 'nocat'

        # Semi-supervised inputs carry num_classes extra features per point
        # after the first 3 columns (presumably the label encoding).
        if self.config.semi:
            supervision_tag, extra_features = 'semi', self.num_classes
        else:
            supervision_tag, extra_features = 'fully', 0
        self.input_dim = 3 + extra_features
        self.output_dim = self.num_classes
        self.mask_rate = config.mask_rate
        self.pc_size = config.pc_size

        # optimisation schedule
        self.epochs = config.epochs
        self.start_epoch = 0
        self.lr = config.init_lr

        # device, checkpointing and early-stopping bookkeeping
        self.use_gpu = config.use_gpu
        self.ckpt_dir = config.ckpt_dir
        self.best = config.best
        self.best_mIoU = -10
        self.best_acc = 0
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.resume = config.resume
        self.model_name = 'dsseg_{}_{}_{}_{}_{}'.format(
            config.init_lr, supervision_tag, task_tag, cat_tag,
            config.pc_size)

        # glimpse network
        self.num_points_per_pc = config.pc_size
        self.num_points_per_sample = config.num_points_per_sample
        self.box_size = config.box_size
        self.glimpse_scale = config.glimpse_scale
        self.num_samples = config.num_samples
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # recurrent core
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # REINFORCE policy
        self.std = config.std
        self.M = config.M

        # NOTE(review): self.num_channels is never assigned in this
        # constructor, so this call raises AttributeError unless it is
        # provided elsewhere (e.g. as a class attribute) - confirm.
        self.model = RecurrentAttention(
            self.num_points_per_pc, self.num_points_per_sample,
            self.num_samples, self.box_size, self.glimpse_scale,
            self.num_channels, self.loc_hidden, self.glimpse_hidden, self.std,
            self.hidden_size, self.num_classes, self.use_gpu)
        if self.use_gpu:
            self.model.cuda()

        n_params = sum(p.data.nelement() for p in self.model.parameters())
        print('[*] Number of model parameters: {:,}'.format(n_params))

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

    def mask_tensor(self, x, rate):
        """
        Masks a percentage of the entries in tensor x randomly
        """

        tensor_len = x.shape[1]
        if (rate == 0.):
            return x, np.arange(tensor_len)

        num_index = int(rate * tensor_len)
        permute_indices = np.random.RandomState(
            seed=42).permutation(tensor_len)[:num_index]
        zero_mask = torch.zeros(x.shape[-1] - 3, dtype=torch.float32)
        if self.use_gpu:
            zero_mask = zero_mask.cuda()
        x[:, permute_indices, 3:] = zero_mask

        return x, permute_indices

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        if self.config.binary:
            for epoch in range(self.start_epoch, self.epochs):

                print('\nEpoch: {}/{} - LR: {:.6f}'.format(
                    epoch + 1, self.epochs, self.lr))

                # train for 1 epoch
                # train_loss, train_acc = self.train_one_epoch(epoch)
                train_loss, train_acc, zeros_acc, ones_acc = self.train_one_epoch(
                    epoch)

                # evaluate on validation set
                # valid_loss, valid_acc = self.validate(epoch)
                valid_loss, valid_acc, val_zeros_acc, val_ones_acc = self.validate(
                    epoch)

                # mIoU = (np.mean(valid_IoUs))

                # is_best = mIoU > self.best_mIoU
                is_best = valid_acc > self.best_acc
                msg1 = "train loss: {:.3f} - train acc: {:.3f} - zeros acc: {:.3f} - ones acc: {:.3f}\n"
                # msg2 = "- val loss: {:.3f} - val acc: {:.3f} - val_mIoU: {:.3f}"
                msg2 = "- val loss: {:.3f} - val acc: {:.3f} - val_zeros acc: {:.3f} - val_ones acc: {:.3f}"
                if is_best:
                    self.counter = 0
                    msg2 += " [*]"
                msg = msg1 + msg2
                # print(msg.format(train_loss, train_acc, valid_loss, valid_acc, mIoU))
                print(
                    msg.format(train_loss, train_acc, zeros_acc, ones_acc,
                               valid_loss, valid_acc, val_zeros_acc,
                               val_ones_acc))

                # check for improvement
                if not is_best:
                    self.counter += 1
                if self.counter > self.train_patience:
                    print("[!] No improvement in a while, stopping training.")
                    return
                # self.best_mIoU = max(mIoU, self.best_mIoU)
                self.best_acc = max(valid_acc, self.best_acc)
                self.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'model_state': self.model.state_dict(),
                        'optim_state': self.optimizer.state_dict(),
                        # 'best_valid_mIoU': self.best_mIoU,
                        'best_acc': self.best_acc
                    },
                    is_best)
        else:
            for epoch in range(self.start_epoch, self.epochs):

                print('\nEpoch: {}/{} - LR: {:.6f}'.format(
                    epoch + 1, self.epochs, self.lr))

                # train for 1 epoch
                # train_loss, train_acc = self.train_one_epoch(epoch)
                train_loss, train_acc, rand_acc, maj_acc = self.train_one_epoch(
                    epoch)

                # evaluate on validation set
                # valid_loss, valid_acc = self.validate(epoch)
                valid_loss, valid_acc, val_rand_acc, val_maj_acc = self.validate(
                    epoch)

                # mIoU = (np.mean(valid_IoUs))

                # is_best = mIoU > self.best_mIoU
                is_best = valid_acc > self.best_acc
                msg1 = "train loss: {:.3f} - train acc: {:.3f} - rand acc: {:.3f} - maj acc: {:.3f}\n"
                # msg2 = "- val loss: {:.3f} - val acc: {:.3f} - val_mIoU: {:.3f}"
                msg2 = "- val loss: {:.3f} - val acc: {:.3f} - val rand acc: {:.3f} - val maj acc: {:.3f}"
                if is_best:
                    self.counter = 0
                    msg2 += " [*]"
                msg = msg1 + msg2
                # print(msg.format(train_loss, train_acc, valid_loss, valid_acc, mIoU))
                print(
                    msg.format(train_loss, train_acc, rand_acc, maj_acc,
                               valid_loss, valid_acc, val_rand_acc,
                               val_maj_acc))

                # check for improvement
                if not is_best:
                    self.counter += 1
                if self.counter > self.train_patience:
                    print("[!] No improvement in a while, stopping training.")
                    return
                # self.best_mIoU = max(mIoU, self.best_mIoU)
                self.best_acc = max(valid_acc, self.best_acc)
                self.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'model_state': self.model.state_dict(),
                        'optim_state': self.optimizer.state_dict(),
                        # 'best_valid_mIoU': self.best_mIoU,
                        'best_acc': self.best_acc
                    },
                    is_best)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        Besides the model's accuracy, two constant/naive baselines are
        tracked on the same labels:
        - binary task: always predicting class 0, always predicting class 1;
        - multi-class task: uniformly random classes in
          ``[0, self.num_classes)``, and always predicting class 0 (reported
          as the "majority" baseline).

        This is used by train() and should not be called manually.

        Args:
            epoch: index of the current epoch (not used by the body).

        Returns:
            ``(loss, acc, baseline_a, baseline_b)`` epoch averages. The
            baselines are (zeros, ones) for the binary task and
            (random, majority) otherwise; accuracies are percentages.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()
        all_zeros = AverageMeter()
        all_ones = AverageMeter()
        all_rand = AverageMeter()
        all_majority = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                if self.config.binary:
                    x, y = Variable(x).float(), Variable(y).float()
                else:
                    x, y = Variable(x).float(), Variable(y).long()
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                self.batch_size = x.shape[0]
                x = x.view(self.batch_size, self.pc_size, self.input_dim)

                # Do the masking of indices to create the semi-supervised learning problem
                if self.config.semi:
                    x, mask_indices = self.mask_tensor(x, self.mask_rate)

                out = self.model(x)

                # TODO: Instead of squeeze, change view to handle tensors of batch_size != 1
                # In the semi-supervised case only the masked points count.
                if self.config.semi:
                    pred = out.squeeze()[mask_indices]
                    labels = y.squeeze()[mask_indices]
                else:
                    pred = out.squeeze()
                    labels = y.squeeze()

                n = labels.shape[0]
                if self.config.binary:
                    # BCE-with-logits consumes the float label rows directly
                    loss = self.loss(pred, labels)

                    # compute accuracy against the arg-max class of the labels
                    true = torch.max(labels, 1)[1]
                    acc = self._percent_correct(torch.max(pred, 1)[1], true)

                    acc_zeros = self._percent_correct(
                        torch.zeros(n, dtype=torch.long), true)
                    acc_ones = self._percent_correct(
                        torch.ones(n, dtype=torch.long), true)

                    all_zeros.update(acc_zeros.item(), n)
                    all_ones.update(acc_ones.item(), n)
                else:
                    # cross-entropy needs class indices, not label rows
                    labels = torch.max(labels, 1)[1]
                    loss = self.loss(pred, labels)

                    # compute accuracy
                    true = labels
                    acc = self._percent_correct(torch.max(pred, 1)[1], true)

                    # Random baseline drawn on the CPU generator, then moved
                    # to the labels' device; uses the configured class count
                    # instead of a hard-coded 8.
                    acc_rand = self._percent_correct(
                        torch.zeros(n, dtype=torch.long).random_(
                            0, self.num_classes), true)
                    acc_maj = self._percent_correct(
                        torch.zeros(n, dtype=torch.long), true)

                    all_rand.update(acc_rand.item(), n)
                    all_majority.update(acc_maj.item(), n)

                # store
                losses.update(loss.item(), n)
                accs.update(acc.item(), n)

                # compute gradients and update SGD
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time (tic is set once, so this is the time
                # since the epoch started)
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.item(), acc.item())))
                pbar.update(self.batch_size)

            if self.config.binary:
                return losses.avg, accs.avg, all_zeros.avg, all_ones.avg
            else:
                return losses.avg, accs.avg, all_rand.avg, all_majority.avg

    @staticmethod
    def _percent_correct(predicted, true):
        """Return 100 * fraction of ``predicted`` equal to ``true`` (a 0-dim tensor).

        ``predicted`` is moved to ``true``'s device first, so CPU-built
        baseline predictions can be compared with GPU labels.
        """
        correct = (predicted.to(true.device) == true).float()
        return 100 * (correct.sum() / true.shape[0])

    def validate(self, epoch):
        """
        Evaluate the model on the validation set.

        Runs one pass over ``self.valid_loader`` inside ``torch.no_grad()``
        (nothing is back-propagated here, so no autograd graph is needed).
        Running averages are weighted by the number of labelled points in
        each batch.

        Args:
            epoch: current epoch index; not used by this method.

        Returns:
            A 4-tuple of averages. In binary mode:
            ``(loss, acc, all_zeros_acc, all_ones_acc)``; otherwise
            ``(loss, acc, random_acc, majority_acc)``. Accuracies are
            percentages.
        """
        losses = AverageMeter()
        accs = AverageMeter()
        all_zeros = AverageMeter()
        all_ones = AverageMeter()
        all_rand = AverageMeter()
        all_majority = AverageMeter()

        def to_device(t):
            # Move a freshly created (CPU) baseline tensor next to the labels.
            return t.cuda() if self.use_gpu else t

        def percent_correct(predicted, true):
            # Percentage of positions where `predicted` equals `true`.
            correct = (predicted == true).float()
            return 100 * (correct.sum() / true.shape[0])

        # Validation never calls backward(), so skip gradient bookkeeping.
        with torch.no_grad():
            for i, (x, y) in enumerate(self.valid_loader):
                if self.config.binary:
                    x, y = Variable(x).float(), Variable(y).float()
                else:
                    x, y = Variable(x).float(), Variable(y).long()
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                self.batch_size = x.shape[0]
                x = x.view(self.batch_size, self.pc_size, self.input_dim)

                if self.config.semi:
                    # Mask indices to create the semi-supervised learning problem
                    x, mask_indices = self.mask_tensor(x, self.mask_rate)

                out = self.model(x)

                # TODO: Instead of squeeze, change view to handle tensors of batch_size != 1
                if self.config.semi:
                    pred = out.squeeze()[mask_indices]
                    labels = y.squeeze()[mask_indices]
                else:
                    pred = out.squeeze()
                    labels = y.squeeze()

                num_points = labels.shape[0]

                if self.config.binary:
                    loss = self.loss(pred, labels)

                    # Target class = argmax over dim 1 (labels presumably
                    # one-hot -- confirm against the data loader).
                    true = torch.max(labels, 1)[1]
                    acc = percent_correct(torch.max(pred, 1)[1], true)

                    # Constant-prediction baselines: always 0 / always 1.
                    zeros = to_device(torch.zeros(num_points, dtype=torch.long))
                    ones = to_device(torch.ones(num_points, dtype=torch.long))
                    all_zeros.update(percent_correct(zeros, true).item(), num_points)
                    all_ones.update(percent_correct(ones, true).item(), num_points)
                else:
                    labels = torch.max(labels, 1)[1]
                    loss = self.loss(pred, labels)

                    true = labels
                    acc = percent_correct(torch.max(pred, 1)[1], true)

                    # For the 1-of-8 problem: a uniform random class in [0, 8)
                    # and an always-class-0 ("majority") baseline.
                    rand = to_device(
                        torch.zeros(num_points, dtype=torch.long).random_(0, 8))
                    majority = to_device(torch.zeros(num_points, dtype=torch.long))
                    all_rand.update(percent_correct(rand, true).item(), num_points)
                    all_majority.update(
                        percent_correct(majority, true).item(), num_points)

                losses.update(loss.item(), num_points)
                accs.update(acc.item(), num_points)

        if self.config.binary:
            return losses.avg, accs.avg, all_zeros.avg, all_ones.avg
        return losses.avg, accs.avg, all_rand.avg, all_majority.avg

    def test(self):
        """
        Print final point-wise accuracy of the selected checkpoint.

        Loads the checkpoint chosen by ``self.best`` and iterates over
        ``self.valid_loader``. NOTE(review): this evaluates the validation
        loader, not a separate test loader -- confirm that is intended.

        Besides model accuracy it reports two baselines: all-0 / all-1
        predictions in binary mode, or a uniform random class in [0, 8) and
        always-class-0 otherwise. Accuracies are printed as fractions of the
        total number of evaluated points.
        """
        total_acc = 0.
        total_zeros = 0.
        total_ones = 0.
        total_rand = 0.
        total_majority = 0.
        total_num_points = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        # Progress is reported per batch, so divide by the number of batches.
        num_batches = len(self.valid_loader)

        def to_device(t):
            # Move a freshly created (CPU) baseline tensor next to the labels.
            return t.cuda() if self.use_gpu else t

        def num_correct(predicted, true):
            # Count of matching positions, as a Python number.
            return (predicted == true).float().sum().item()

        # Pure evaluation: no gradients are needed.
        with torch.no_grad():
            for i, (x, y) in enumerate(self.valid_loader):
                if self.config.binary:
                    x, y = Variable(x).float(), Variable(y).float()
                else:
                    x, y = Variable(x).float(), Variable(y).long()
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                self.batch_size = x.shape[0]
                x = x.view(self.batch_size, self.pc_size, self.input_dim)

                # Mask indices to create the semi-supervised learning problem
                if self.config.semi:
                    x, mask_indices = self.mask_tensor(x, self.mask_rate)

                out = self.model(x)

                # TODO: Instead of squeeze, change view to handle tensors of batch_size != 1
                if self.config.semi:
                    pred = out.squeeze()[mask_indices]
                    labels = y.squeeze()[mask_indices]
                else:
                    pred = out.squeeze()
                    labels = y.squeeze()

                num_points = labels.shape[0]
                total_num_points += num_points

                # Compare argmax class indices of prediction and label.
                predicted = torch.max(pred, 1)[1]
                true = torch.max(labels, 1)[1]
                total_acc += num_correct(predicted, true)

                if self.config.binary:
                    zeros = to_device(torch.zeros(num_points, dtype=torch.long))
                    ones = to_device(torch.ones(num_points, dtype=torch.long))
                    total_zeros += num_correct(zeros, true)
                    total_ones += num_correct(ones, true)
                else:
                    # For the 1-of-8 problem, we use a random tensor as a baseline
                    # as well as a majority class baseline
                    rand = to_device(
                        torch.zeros(num_points, dtype=torch.long).random_(0, 8))
                    majority = to_device(torch.zeros(num_points, dtype=torch.long))
                    total_rand += num_correct(rand, true)
                    total_majority += num_correct(majority, true)

                print("Done with %.3f%%" % ((i + 1) / num_batches * 100.))

        print()

        # Avoid ZeroDivisionError when the loader yields no points.
        denom = max(total_num_points, 1)

        if self.config.binary:
            msg = "Final Accuracy: {:.3f} - Background Baseline: {:.3f} - Foreground Baseline: {:.3f}\n"
            print(
                msg.format(total_acc / denom,
                           total_zeros / denom,
                           total_ones / denom))
        else:
            msg = "Final Accuracy: {:.3f} - Random Baseline: {:.3f} - Majority Class Baseline: {:.3f}\n"
            print(
                msg.format(total_acc / denom,
                           total_rand / denom,
                           total_majority / denom))

    def save_checkpoint(self, state, is_best):
        """
        Persist ``state`` so it can be reloaded later (to resume training or
        to evaluate on the test data).

        The state is always written to
        ``<ckpt_dir>/<model_name>_ckpt.pth.tar``. When ``is_best`` is true,
        that file is additionally copied to
        ``<ckpt_dir>/<model_name>_model_best.pth.tar``.

        Args:
            state: object handed to ``torch.save`` (e.g. a dict of model and
                optimizer state).
            is_best: whether this checkpoint has the best validation score
                seen so far.
        """
        latest = os.path.join(self.ckpt_dir, self.model_name + '_ckpt.pth.tar')
        torch.save(state, latest)

        if not is_best:
            return

        best_copy = os.path.join(self.ckpt_dir,
                                 self.model_name + '_model_best.pth.tar')
        shutil.copyfile(latest, best_copy)

    def load_checkpoint(self, best=False):
        """
        Load a saved checkpoint into the model and optimizer. This is useful
        for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Storages are mapped to the CPU while deserializing, so a checkpoint
        written from a GPU run also loads on a CPU-only machine;
        ``load_state_dict`` then copies the values onto the device of the
        live model/optimizer parameters.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        if best:
            filename = self.model_name + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        print("Successfully loaded model...")

        if best:
            # NOTE(review): reads the 'best_acc' key; confirm the code that
            # saves this class's checkpoints stores it under that name.
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
Exemple #15
0
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator. When `config.is_train` is set it is
          indexed as (train_loader, valid_loader); otherwise it is used
          directly as the test loader.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 10
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.patience = config.patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses,
                                                   config.patch_size,
                                                   config.patch_size,
                                                   config.glimpse_scale)

        self.plot_dir = './plots/' + self.model_name + '/'
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # initialize optimizer and scheduler
        self.optimizer = SGD(
            self.model.parameters(),
            lr=self.lr,
            momentum=self.momentum,
        )
        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           'min',
                                           patience=self.patience)

    def reset(self):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced.

        Returns
        -------
        - h_t: zeros of shape (self.batch_size, self.hidden_size).
        - l_t: shape (self.batch_size, 2), drawn uniformly from [-1, 1).
        Both are float tensors, on the GPU when `self.use_gpu` is set.
        """
        dtype = torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor

        h_t = torch.zeros(self.batch_size, self.hidden_size)
        h_t = Variable(h_t).type(dtype)

        l_t = torch.Tensor(self.batch_size, 2).uniform_(-1, 1)
        l_t = Variable(l_t).type(dtype)

        return h_t, l_t

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.

        Training stops early once more than `self.patience` consecutive
        epochs pass without a new best validation accuracy.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):
            # Read the live LR: ReduceLROnPlateau lowers it inside the
            # optimizer, so `self.lr` would only ever show the initial value.
            current_lr = self.optimizer.param_groups[0]['lr']
            print('\nEpoch: {}/{} - LR: {:.6f}'.format(epoch + 1, self.epochs,
                                                       current_lr))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # reduce lr if validation loss plateaus
            self.scheduler.step(valid_loss)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"
            if is_best:
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # Early stopping counts *consecutive* epochs without improvement,
            # so any new best resets the counter.
            if is_best:
                self.counter = 0
            else:
                self.counter += 1
            if self.counter > self.patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    # number of completed epochs == next epoch index to run
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                }, is_best)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        Returns
        -------
        - (average hybrid loss, average accuracy in percent) for the epoch.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        # Measured once per epoch: the progress bar shows cumulative time.
        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                x, y = Variable(x), Variable(y)

                # dump glimpses only for the first batch of every
                # `plot_freq`-th epoch
                plot = (epoch % self.plot_freq == 0) and (i == 0)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # save (up to) the first 9 images for plotting
                imgs = [x[0:9]]

                # extract the glimpses
                locs = []
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    # store
                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration also yields the class log-probabilities
                h_t, l_t, b_t, log_probas, p = self.model(x,
                                                          l_t,
                                                          h_t,
                                                          last=True)
                log_pi.append(p)
                baselines.append(b_t)
                locs.append(l_t[0:9])

                # convert list to tensors and reshape to (batch, num_glimpses)
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # reward is 1 for a correct final prediction, at every glimpse
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss (baseline detached so the policy
                # gradient does not train the baseline)
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.mean(-log_pi * adjusted_reward)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store; .item() reads a 0-dim tensor (indexing with [0]
                # raises IndexError on PyTorch >= 0.4)
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # compute gradients and update SGD
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.item(), acc.item())))
                pbar.update(self.batch_size)

                # dump the glimpses and locs (.cpu() is a no-op on CPU tensors)
                if plot:
                    imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                    locs = [loc.cpu().data.numpy() for loc in locs]
                    g_path = self.plot_dir + "g_{}.p".format(epoch + 1)
                    l_path = self.plot_dir + "l_{}.p".format(epoch + 1)
                    with open(g_path, "wb") as f:
                        pickle.dump(imgs, f)
                    with open(l_path, "wb") as f:
                        pickle.dump(locs, f)

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value('train_loss', losses.avg, iteration)
                    log_value('train_acc', accs.avg, iteration)

            return losses.avg, accs.avg

    def validate(self, epoch):
        """
        Evaluate the model on the validation set.

        Each batch is tiled `self.M` times; log-probabilities, baselines and
        log-policies are averaged over the M copies before the hybrid loss is
        computed. Runs under `torch.no_grad()` since nothing is updated here.

        Returns
        -------
        - (average hybrid loss, average accuracy in percent).
        """
        losses = AverageMeter()
        accs = AverageMeter()

        with torch.no_grad():
            for i, (x, y) in enumerate(self.valid_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                x, y = Variable(x), Variable(y)

                # duplicate M times (x is 4-D, tiled along the batch dim)
                x = x.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    # store
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration
                h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)
                log_pi.append(p)
                baselines.append(b_t)

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # average over the M copies
                log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                baselines = baselines.contiguous().view(self.M, -1,
                                                        baselines.shape[-1])
                baselines = torch.mean(baselines, dim=0)

                log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1])
                log_pi = torch.mean(log_pi, dim=0)

                # calculate reward
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.mean(-log_pi * adjusted_reward)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.valid_loader) + i
                    log_value('valid_loss', losses.avg, iteration)
                    log_value('valid_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    def test(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.

        Predictions are averaged over `self.M` copies of each batch.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        # torch.no_grad() replaces the removed `volatile=True` flag.
        with torch.no_grad():
            for i, (x, y) in enumerate(self.test_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                x, y = Variable(x), Variable(y)

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                # last iteration
                h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)

                log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                pred = log_probas.data.max(1, keepdim=True)[1]
                correct += pred.eq(y.data.view_as(pred)).cpu().sum().item()

        perc = (100. * correct) / (self.num_test)
        print('[*] Test Acc: {}/{} ({:.2f}%)'.format(correct, self.num_test,
                                                     perc))

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        If this model has reached the best validation accuracy thus
        far, a separate file with the suffix `best` is created.
        """
        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """
        Load the best copy of a model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        if best:
            filename = self.model_name + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        # Map storages to CPU so GPU-saved checkpoints load on CPU-only
        # hosts; load_state_dict copies values onto the parameters' device.
        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        # train() saves 'epoch' as epoch + 1, i.e. it already is the 1-based
        # number of the last completed epoch -- do not add 1 again.
        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
Exemple #16
0
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 10
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.no_tqdm = config.no_tqdm
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses,
                                                   config.patch_size,
                                                   config.patch_size,
                                                   config.glimpse_scale)

        if config.uncertainty == True:
            self.model_name += '_uncertainty_1'
        else:
            self.model_name += '_uncertainty_0'
        if config.intrinsic == True:
            self.model_name += '_intrinsic_1'
        else:
            self.model_name += '_intrinsic_0'

        self.plot_dir = './plots/' + self.model_name + '/'
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(self.patch_size, self.num_patches,
                                        self.glimpse_scale, self.num_channels,
                                        self.loc_hidden, self.glimpse_hidden,
                                        self.std, self.hidden_size,
                                        self.num_classes, self.config)
        if self.use_gpu:
            self.model.cuda()

        self.dtypeFloat = (torch.cuda.FloatTensor
                           if self.use_gpu else torch.FloatTensor)
        self.dtypeLong = (torch.cuda.LongTensor
                          if self.use_gpu else torch.LongTensor)

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # # initialize optimizer and scheduler
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.config.init_lr,
        )
        lambda_of_lr = lambda epoch: 0.95**epoch
        self.scheduler = LambdaLR(self.optimizer, lr_lambda=lambda_of_lr)
        # self.scheduler = StepLR(self.optimizer,step_size=20,gamma=0.1)
        # self.scheduler = ReduceLROnPlateau(
        #     self.optimizer, 'min', patience=self.lr_patience
        # )

    def reset(self):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced.
        """
        dtype = (torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor)

        h_t = torch.zeros(self.batch_size, self.hidden_size)
        h_t = Variable(h_t).type(dtype)

        l_t = torch.Tensor(self.batch_size, 2).uniform_(-1, 1)
        l_t = Variable(l_t).type(dtype)

        return h_t, l_t

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print(
            "\n[*] Train on {} samples, validate on {} samples, learn rate {}".
            format(self.num_train, self.num_valid, self.scheduler.get_lr()))

        for epoch in range(self.start_epoch, self.epochs):

            print('\nEpoch: {}/{} . lr: {:.4e} '.format(
                epoch + 1, self.epochs,
                self.scheduler.get_lr()[0]))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            self.scheduler.step()

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement
            if not is_best:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                }, is_best)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train, disable=self.no_tqdm) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                if self.config.use_translate:
                    x = translate_function(x, original_dataset=x)
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                x, y = Variable(x), Variable(y)

                plot = False
                if (epoch % self.plot_freq == 0) and (i == 0):
                    plot = True

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # save images
                imgs = []
                imgs.append(x[0:9])

                # extract the glimpses
                locs = []
                log_pi = []
                baselines = []
                all_log_probas = []  # the prediction at each glimpse step
                uncertainities = [
                ]  # the self-uncertainty at each glimpse step
                uncertainities_baseline = [
                ]  # the self-uncertainty at each glimpse step, but this baseline is only used for the loss of training self-uncertainty, which only involves the error network.

                # by default it needs to run `self.num_glimpse` times
                num_glimpses_taken = [
                    self.num_glimpses - 1 for _ in range(self.batch_size)
                ]

                for t in range(self.num_glimpses):

                    # forward pass through model
                    h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model(
                        x, l_t, h_t, last=True)

                    # store
                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)
                    all_log_probas.append(log_probas)
                    uncertainities.append(diff_uncertainty)
                    uncertainities_baseline.append(diff_uncertainty_baseline)

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)
                # if self.config.uncertainty == True:
                if self.config.uncertainty == True:
                    uncertainities = torch.stack(uncertainities).transpose(
                        1, 0)
                    uncertainities_baseline = torch.stack(
                        uncertainities_baseline).transpose(1, 0)
                all_log_probas = torch.stack(all_log_probas).transpose(1, 0)

                # calculate reward
                num_glimpses_taken_indices = torch.LongTensor(
                    num_glimpses_taken).type(self.dtypeLong)
                log_probas = torch.cat([
                    torch.index_select(a, 0, i).unsqueeze(0)
                    for a, i in zip(all_log_probas, num_glimpses_taken_indices)
                ]).squeeze()
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                num_glimpses_taken = Variable(
                    torch.LongTensor(num_glimpses_taken),
                    requires_grad=False).type(self.dtypeLong)

                # the mask is used to take only the result of the last glimpse
                mask = _sequence_mask(sequence_length=num_glimpses_taken,
                                      max_len=self.num_glimpses)
                loss_action = F.nll_loss(log_probas, y, reduction='none')
                loss_action = torch.mean(loss_action)

                loss_baseline = F.mse_loss(baselines, R, reduction='none')
                loss_baseline = torch.mean(loss_baseline * mask)
                # loss_baseline = torch.mean( loss_baseline  )

                # compute reinforce loss
                # summed over timesteps and averaged across batch
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.sum(-log_pi * adjusted_reward * mask,
                                           dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                if self.config.uncertainty == True:
                    y_real_value = F.one_hot(
                        y, self.num_classes).float().detach()
                    diff_ = Variable(torch.abs(
                        y_real_value.unsqueeze(1).expand(
                            -1, self.num_glimpses, -1).data -
                        torch.exp(all_log_probas).data),
                                     requires_grad=False)
                    # loss_self_uncertaintiy_baseline = F.mse_loss(uncertainities_baseline, diff_)
                    loss_self_uncertaintiy_baseline = F.mse_loss(
                        uncertainities_baseline, diff_,
                        reduction='none').mean()
                    loss_self_uncertaintiy_baseline = torch.mean(
                        loss_self_uncertaintiy_baseline)

                    loss += loss_self_uncertaintiy_baseline

                if self.config.intrinsic == True:
                    # the intrinsic sparsity belief
                    reg = self.config.lambda_intrinsic
                    intrinsic_term = torch.sum(-(1.0 / self.num_classes) *
                                               log_probas)
                    loss_intrinsic = reg * intrinsic_term
                    loss += loss_intrinsic
                if self.config.uncertainty == True:
                    # the second reinforce loss: minimizing the uncertainty
                    reg = self.config.lambda_uncertainty
                    loss_self_uncertaintiy_minimizing = reg * torch.sum(
                        uncertainities)
                    loss += loss_self_uncertaintiy_minimizing

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.data, list(x.size())[0])
                accs.update(acc.data, list(x.size())[0])

                # compute gradients and update SGD
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                if self.no_tqdm is not True:
                    pbar.set_description(
                        ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                            (toc - tic), loss.data, acc.data)))
                    pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    if self.use_gpu:
                        imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                        locs = [l.cpu().data.numpy() for l in locs]
                    else:
                        imgs = [g.data.numpy().squeeze() for g in imgs]
                        locs = [l.data.numpy() for l in locs]
                    pickle.dump(
                        imgs,
                        open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb"))
                    pickle.dump(
                        locs,
                        open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb"))

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value('train_loss', losses.avg, iteration)
                    log_value('train_acc', accs.avg, iteration)

            return losses.avg, accs.avg

    def validate(self, epoch, M=1):
        """
        Evaluate the model on the validation set.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        for i, (x, y) in enumerate(self.valid_loader):
            if self.config.use_translate:
                x = translate_function(x, original_dataset=x)
            if self.use_gpu:
                x, y = x.cuda(), y.cuda()
            x, y = Variable(x), Variable(y)

            # duplicate M times
            x = x.repeat(M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            locs = []
            log_pi = []
            baselines = []
            all_log_probas = []
            uncertainities = []
            uncertainities_baseline = []

            # by default it needs to run `self.num_glimpse` times
            num_glimpses_taken = [
                self.num_glimpses - 1 for _ in range(self.batch_size)
            ]

            for t in range(self.num_glimpses):

                # forward pass through model
                h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model(
                    x, l_t, h_t, last=True)

                # store
                locs.append(l_t[0:9])
                baselines.append(b_t)
                log_pi.append(p)
                all_log_probas.append(log_probas)
                uncertainities.append(diff_uncertainty)
                uncertainities_baseline.append(diff_uncertainty_baseline)

            # convert list to tensors and reshape
            baselines = torch.stack(baselines).transpose(1, 0)
            log_pi = torch.stack(log_pi).transpose(1, 0)
            if self.config.uncertainty == True:
                uncertainities = torch.stack(uncertainities).transpose(1, 0)
                uncertainities_baseline = torch.stack(
                    uncertainities_baseline).transpose(1, 0)
            all_log_probas = torch.stack(all_log_probas).transpose(1, 0)

            # calculate reward
            num_glimpses_taken_indices = torch.LongTensor(
                num_glimpses_taken).type(self.dtypeLong)
            log_probas = torch.cat([
                torch.index_select(a, 0, i).unsqueeze(0)
                for a, i in zip(all_log_probas, num_glimpses_taken_indices)
            ]).squeeze()
            # average the `self.M` times of prediction
            log_probas = log_probas.view(M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)
            predicted = torch.max(log_probas, 1)[1]
            R = (predicted.detach() == y).float()
            R = R.unsqueeze(1).repeat(M, self.num_glimpses)

            # compute losses for differentiable modules
            num_glimpses_taken = Variable(torch.LongTensor(num_glimpses_taken),
                                          requires_grad=False).type(
                                              self.dtypeLong)

            mask = _sequence_mask(sequence_length=num_glimpses_taken,
                                  max_len=self.num_glimpses)
            loss_action = F.nll_loss(log_probas, y, reduction='none')
            loss_action = torch.mean(loss_action)

            loss_baseline = F.mse_loss(baselines, R, reduction='none')
            loss_baseline = torch.mean(loss_baseline * mask)

            adjusted_reward = R - baselines.detach()
            loss_reinforce = torch.sum(-log_pi * adjusted_reward * mask, dim=1)
            loss_reinforce = torch.mean(loss_reinforce, dim=0)

            # sum up into a hybrid loss
            loss = loss_action + loss_baseline + loss_reinforce

            if self.config.uncertainty == True:
                y_real_value = F.one_hot(y, self.num_classes).float().detach()
                diff_ = Variable(torch.abs(
                    y_real_value.unsqueeze(1).expand(-1, self.num_glimpses,
                                                     -1).data -
                    torch.exp(all_log_probas).data),
                                 requires_grad=False)

                loss_self_uncertaintiy_baseline = F.mse_loss(
                    uncertainities_baseline, diff_, reduction='none').mean()
                loss_self_uncertaintiy_baseline = torch.mean(
                    loss_self_uncertaintiy_baseline)
                loss += loss_self_uncertaintiy_baseline

            if self.config.intrinsic == True:
                # the intrinsic sparsity belief
                reg = self.config.lambda_intrinsic
                loss_intrinsic = reg * torch.sum(
                    -(1.0 / self.num_classes) * log_probas)
                loss += loss_intrinsic
            if self.config.uncertainty == True:
                # the second reinforce loss: minimizing the uncertainty
                reg = self.config.lambda_uncertainty
                loss_self_uncertaintiy_minimizing = reg * torch.sum(
                    uncertainities)
                loss += loss_self_uncertaintiy_minimizing

            # compute accuracy
            correct = (predicted == y).float()
            acc = 100 * (correct.sum() / len(y))

            # store
            losses.update(loss.data, list(x.size())[0])
            accs.update(acc.data, list(x.size())[0])

            # log to tensorboard
            if self.use_tensorboard:
                iteration = epoch * len(self.valid_loader) + i
                log_value('valid_loss', losses.avg, iteration)
                log_value('valid_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    def test(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        self.num_test = len(self.test_loader.sampler)

        all_num_glimpses_taken = []
        for i, (x, y) in enumerate(self.test_loader):
            torch.manual_seed(self.config.random_seed)
            if self.use_gpu:
                x, y = x.cuda(), y.cuda()
            x, y = Variable(x), Variable(y)

            # duplicate 10 times
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            locs = []
            log_pi = []
            baselines = []
            all_log_probas = []
            uncertainities = []

            # by default it needs to run `self.num_glimpse` times
            num_glimpses_taken = [
                self.config.num_glimpses - 1 for _ in range(self.batch_size)
            ]

            for t in range(self.config.num_glimpses):

                # forward pass through model
                h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model(
                    x, l_t, h_t, last=True)
                # store
                locs.append(l_t[0:9])
                baselines.append(b_t)
                log_pi.append(p)
                all_log_probas.append(log_probas)
                uncertainities.append(diff_uncertainty)

                if self.config.dynamic == True:
                    # determine if it has achieve a threshold uncertainty
                    probs_data = torch.exp(log_probas).data.tolist()
                    diff_uncertainty_data = diff_uncertainty.data.tolist()
                    for instance_idx, (prediction, uncertainty) in enumerate(
                            zip(probs_data, diff_uncertainty_data)):
                        a_star_idx = max(enumerate(prediction),
                                         key=lambda x: x[1])[0]
                        a_prime_idx = max(
                            [(idx, pred +
                              self.config.exploration_rate * uncertainty[idx])
                             for idx, pred in enumerate(prediction)
                             if idx != a_star_idx],
                            key=lambda x: x[1])[0]
                        a_star_lower_bound = prediction[
                            a_star_idx] - self.config.exploration_rate * uncertainty[
                                a_star_idx]
                        a_prime_upper_bound = prediction[
                            a_prime_idx] - self.config.exploration_rate * uncertainty[
                                a_prime_idx]
                        if a_star_lower_bound >= a_prime_upper_bound:
                            num_glimpses_taken[instance_idx] = t

                    if all([
                            num < self.config.num_glimpses - 1
                            for num in num_glimpses_taken
                    ]):
                        # print(num_glimpses_taken)
                        break
                        # print('strange! end now!:',t)

            # convert list to tensors and reshape
            baselines = torch.stack(baselines).transpose(1, 0)
            log_pi = torch.stack(log_pi).transpose(1, 0)
            if self.config.uncertainty == True or self.config.dynamic == True:
                uncertainities = torch.stack(uncertainities).transpose(1, 0)
            all_log_probas = torch.stack(all_log_probas).transpose(1, 0)

            all_num_glimpses_taken.extend(num_glimpses_taken)

            # calculate reward
            num_glimpses_taken_indices = torch.LongTensor(
                num_glimpses_taken).type(self.dtypeLong)
            log_probas = torch.cat([
                torch.index_select(a, 0, i).unsqueeze(0)
                for a, i in zip(all_log_probas, num_glimpses_taken_indices)
            ]).squeeze()
            # average the `self.M` times of prediction
            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            pred = log_probas.data.max(1, keepdim=True)[1]
            correct += pred.eq(y.data.view_as(pred)).cpu().sum()

        perc = (100. * correct) / (self.num_test)
        error = 100 - perc
        print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
            correct, self.num_test, perc, error))
        if self.config.dynamic == True:
            print('use dynamic')
            avg_num_glimpses_taken = sum(all_num_glimpses_taken) / len(
                all_num_glimpses_taken) + 1
            return (avg_num_glimpses_taken,
                    1.0 * correct.tolist() / self.num_test)
        return 1.0 * correct.tolist() / self.num_test
        # return perc.tolist()

    def test_for_all(
        self,
        range_all=100,
    ):
        """
        Test the model on the held-out test data.
        This is used to run the model under different number of glimpses
        """
        correct = []
        for _ in range(range_all):
            correct.append(0)

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        self.num_test = len(self.test_loader.sampler)

        all_num_glimpses_taken = []
        for i, (x, y) in enumerate(tqdm(self.test_loader)):
            torch.manual_seed(self.config.random_seed)
            if self.use_gpu:
                x, y = x.cuda(), y.cuda()
            x, y = Variable(x), Variable(y)

            # duplicate 10 times
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            locs = []
            log_pi = []
            baselines = []
            all_log_probas = []
            uncertainities = []

            # by default it needs to run `self.num_glimpse` times
            num_glimpses_taken = [
                range_all - 1 for _ in range(self.batch_size)
            ]

            for t in range(self.config.num_glimpses):

                # forward pass through model
                h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model(
                    x, l_t, h_t, last=True)
                # store
                locs.append(l_t[0:9])
                baselines.append(b_t)
                log_pi.append(p)
                all_log_probas.append(log_probas)
                uncertainities.append(diff_uncertainty)

                if self.config.dynamic == True:
                    # determine if it has achieve a threshold uncertainty
                    probs_data = torch.exp(log_probas).data.tolist()
                    diff_uncertainty_data = diff_uncertainty.data.tolist()
                    for instance_idx, (prediction, uncertainty) in enumerate(
                            zip(probs_data, diff_uncertainty_data)):
                        a_star_idx = max(enumerate(prediction),
                                         key=lambda x: x[1])[0]
                        a_prime_idx = max(
                            [(idx, pred +
                              self.config.exploration_rate * uncertainty[idx])
                             for idx, pred in enumerate(prediction)
                             if idx != a_star_idx],
                            key=lambda x: x[1])[0]
                        a_star_lower_bound = prediction[
                            a_star_idx] - self.config.exploration_rate * uncertainty[
                                a_star_idx]
                        a_prime_upper_bound = prediction[
                            a_prime_idx] - self.config.exploration_rate * uncertainty[
                                a_prime_idx]
                        if a_star_lower_bound >= a_prime_upper_bound:
                            num_glimpses_taken[instance_idx] = t

                    if all([
                            num < self.config.num_glimpses - 1
                            for num in num_glimpses_taken
                    ]):
                        # print(num_glimpses_taken)
                        break

            # convert list to tensors and reshape
            baselines = torch.stack(baselines).transpose(1, 0)
            log_pi = torch.stack(log_pi).transpose(1, 0)
            if self.config.uncertainty == True or self.config.dynamic == True:
                uncertainities = torch.stack(uncertainities).transpose(1, 0)
            all_log_probas = torch.stack(all_log_probas).transpose(1, 0)

            all_num_glimpses_taken.extend(num_glimpses_taken)

            # calculate reward
            for num in range(range_all):
                num_glimpses_taken = [num for _ in range(self.batch_size)]
                num_glimpses_taken_indices = torch.LongTensor(
                    num_glimpses_taken).type(self.dtypeLong)
                # log_probas = torch.cat([ torch.index_select(a, 0, i).unsqueeze(0) for a, i in zip(all_log_probas, num_glimpses_taken_indices) ]).squeeze()

                log_probas = all_log_probas[:, num]
                # print(all_log_probas.size(),log_probas.size())
                # average the `self.M` times of prediction
                log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                pred = log_probas.data.max(1, keepdim=True)[1]
                correct[num] += pred.eq(y.data.view_as(pred)).cpu().sum()

        return [1.0 * cor.tolist() / self.num_test for cor in correct]

        # return 1.0 * correct.tolist() / self.num_test

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        If this model has reached the best validation accuracy thus
        far, a seperate file with the suffix `best` is created.
        """
        # print("[*] Saving model to {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """
        Load a saved copy of the model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        The file is deserialized onto the CPU (`map_location='cpu'`).
        `load_state_dict` then copies the tensors into the model's and
        optimizer's existing parameters, wherever those live. As a result,
        a checkpoint written on a GPU machine can also be restored on a
        CPU-only machine.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.

        Side effects
        ------------
        Sets `self.start_epoch` and `self.best_valid_acc` from the checkpoint
        and restores the state of `self.model` and `self.optimizer`.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        if best:
            filename = self.model_name + '_model_best.pth.tar'
        else:
            filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        # Without map_location, CUDA tensors stored in the file would be
        # restored onto a CUDA device and fail on a machine without one.
        ckpt = torch.load(ckpt_path, map_location='cpu')

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
            val_dataset,
            val_split=args.val_split,
            random_split=args.random_split,
            batch_size=args.batch_size,
            **kwargs)
        args.num_class = train_loader.dataset.num_class
        args.num_channels = train_loader.dataset.num_channels

    else:
        test_dataset = get_MNIST_test_dataset(args.data_dir)
        test_loader = get_test_loader(test_dataset, args.batch_size, **kwargs)
        args.num_class = test_loader.dataset.num_class
        args.num_channels = test_loader.dataset.num_channels

    # build RAM model
    model = RecurrentAttention(args)
    if args.use_gpu:
        model.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.init_lr,
                                momentum=args.momentum)

    logger.info('Number of model parameters: {:,}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    trainer = Trainer(model, optimizer, watch=['acc'], val_watch=['acc'])

    if args.is_train:
        logger.info("Train on {} samples, validate on {} samples".format(
            len(train_loader.dataset), len(val_loader.dataset)))
        start_epoch = 0
        if args.resume:
Exemple #18
0
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator. In training mode (`config.is_train`)
          it is indexed as `data_loader[0]` (train) and `data_loader[1]`
          (validation); otherwise it is used directly as the test loader.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 10
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.no_tqdm = config.no_tqdm
        # Only use the GPU when one is actually present; requesting it on a
        # CPU-only machine would otherwise crash at `.cuda()` below.
        self.use_gpu = config.use_gpu and torch.cuda.is_available()
        if config.use_gpu and not self.use_gpu:
            print('[!] CUDA is not available, falling back to the CPU')
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses,
                                                   config.patch_size,
                                                   config.patch_size,
                                                   config.glimpse_scale)

        # The explicit `== True` comparisons are kept on purpose: these
        # suffixes are part of checkpoint/plot file names, so the result must
        # not change for any value an existing config might hold.
        if config.uncertainty == True:
            self.model_name += '_uncertainty_1'
        else:
            self.model_name += '_uncertainty_0'
        if config.intrinsic == True:
            self.model_name += '_intrinsic_1'
        else:
            self.model_name += '_intrinsic_0'

        self.plot_dir = './plots/' + self.model_name + '/'
        # exist_ok avoids the race between a separate exists() check and
        # the directory creation.
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            # os.path.join adds the separator even when logs_dir has no
            # trailing '/'.
            tensorboard_dir = os.path.join(self.logs_dir, self.model_name)
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(self.patch_size, self.num_patches,
                                        self.glimpse_scale, self.num_channels,
                                        self.loc_hidden, self.glimpse_hidden,
                                        self.std, self.hidden_size,
                                        self.num_classes, self.config)
        if self.use_gpu:
            self.model.cuda()

        # legacy tensor types matching the device the model was placed on
        self.dtypeFloat = (torch.cuda.FloatTensor
                           if self.use_gpu else torch.FloatTensor)
        self.dtypeLong = (torch.cuda.LongTensor
                          if self.use_gpu else torch.LongTensor)

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # Initialize optimizer and scheduler. LambdaLR scales the initial
        # learning rate by 0.95 ** step, i.e. a 5% decay per scheduler step.
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.config.init_lr,
        )
        self.scheduler = LambdaLR(self.optimizer,
                                  lr_lambda=lambda epoch: 0.95**epoch)
Exemple #19
0
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator. In training mode (`config.is_train`)
          it is indexed as `data_loader[0]` (train) and `data_loader[1]`
          (validation); otherwise it is used directly as the test loader.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            # The training-set size is read from the loader's `trainSamples`
            # attribute (the sampler-index based counts are not used here).
            # This is only possible in training mode: there is no train
            # loader otherwise, and reading it unconditionally raised
            # AttributeError whenever the trainer was built for testing.
            self.trainSamplesSize = len(self.train_loader.trainSamples)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 83
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard

        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses,
                                                   config.patch_size,
                                                   config.patch_size,
                                                   config.glimpse_scale)

        self.plot_dir = './plots/' + self.model_name + '/'
        # exist_ok avoids the race between a separate exists() check and
        # the directory creation.
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            # os.path.join adds the separator even when logs_dir has no
            # trailing '/'.
            tensorboard_dir = os.path.join(self.logs_dir, self.model_name)
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # NOTE(review): the learning rate is hard-coded to 3e-4 and does not
        # follow `config.init_lr` (stored in `self.lr` above). It is kept
        # as-is so existing training runs are reproduced. No LR scheduler is
        # created; an SGD + ReduceLROnPlateau setup was previously
        # commented out here.
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=3e-4,
        )
Exemple #20
0
class Trainer:
    """A Recurrent Attention Model trainer.

    All hyperparameters are provided by the user in the
    config file.
    """

    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args:
            config: object containing command line arguments.
            data_loader: A data iterator. In training mode (`config.is_train`)
                it is indexed as `(train_loader, valid_loader)`; otherwise it
                is used directly as the test loader.
        """
        self.config = config

        # run on CUDA only when it was requested AND is actually available
        if config.use_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 10
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.0
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = "ram_{}_{}x{}_{}".format(
            config.num_glimpses,
            config.patch_size,
            config.patch_size,
            config.glimpse_scale,
        )

        self.plot_dir = "./plots/" + self.model_name + "/"
        # exist_ok avoids the race between a separate exists() check and
        # the directory creation.
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            # os.path.join adds the separator even when logs_dir has no
            # trailing "/" (plain concatenation produced e.g. "./logsram_...").
            tensorboard_dir = os.path.join(self.logs_dir, self.model_name)
            print("[*] Saving tensorboard logs to {}".format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        self.model.to(self.device)

        # Initialize optimizer and scheduler. The scheduler runs in "min"
        # mode; train() feeds it the negated validation accuracy.
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.config.init_lr
        )
        self.scheduler = ReduceLROnPlateau(
            self.optimizer, "min", patience=self.lr_patience
        )

    def reset(self):
        h_t = torch.zeros(
            self.batch_size,
            self.hidden_size,
            dtype=torch.float,
            device=self.device,
            requires_grad=True,
        )
        l_t = torch.FloatTensor(self.batch_size, 2).uniform_(-1, 1).to(self.device)
        l_t.requires_grad = True

        return h_t, l_t

    def train(self):
        """Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print(
            "\n[*] Train on {} samples, validate on {} samples".format(
                self.num_train, self.num_valid
            )
        )

        for epoch in range(self.start_epoch, self.epochs):

            print(
                "\nEpoch: {}/{} - LR: {:.6f}".format(
                    epoch + 1, self.epochs, self.optimizer.param_groups[0]["lr"]
                )
            )

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # # reduce lr if validation loss plateaus
            self.scheduler.step(-valid_acc)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f} - val err: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(
                msg.format(
                    train_loss, train_acc, valid_loss, valid_acc, 100 - valid_acc
                )
            )

            # check for improvement
            if not is_best:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "model_state": self.model.state_dict(),
                    "optim_state": self.optimizer.state_dict(),
                    "best_valid_acc": self.best_valid_acc,
                },
                is_best,
            )

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.
        """
        self.model.train()
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                x, y = x.to(self.device), y.to(self.device)

                plot = False
                if (epoch % self.plot_freq == 0) and (i == 0):
                    plot = True

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # save images
                imgs = []
                imgs.append(x[0:9])

                # extract the glimpses
                locs = []
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    # store
                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration
                h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)
                log_pi.append(p)
                baselines.append(b_t)
                locs.append(l_t[0:9])

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # calculate reward
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss
                # summed over timesteps and averaged across batch
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce * 0.01

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # compute gradients and update SGD
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    (
                        "{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                            (toc - tic), loss.item(), acc.item()
                        )
                    )
                )
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                    locs = [l.cpu().data.numpy() for l in locs]
                    pickle.dump(
                        imgs, open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb")
                    )
                    pickle.dump(
                        locs, open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb")
                    )

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value("train_loss", losses.avg, iteration)
                    log_value("train_acc", accs.avg, iteration)

            return losses.avg, accs.avg

    @torch.no_grad()
    def validate(self, epoch):
        """Evaluate the RAM model on the validation set.

        Each batch is repeated `self.M` times and the class log-probabilities,
        baselines and log-policies are averaged over those repeats before the
        losses are computed.

        The model is switched to eval mode first, so layers such as dropout
        or batch-norm behave deterministically. `train_one_epoch` switches it
        back to training mode.

        Args:
            epoch: zero-based epoch index, used for tensorboard step numbers.

        Returns:
            A `(loss, accuracy)` tuple of running averages over the
            validation set, with accuracy in percent.
        """
        self.model.eval()
        losses = AverageMeter()
        accs = AverageMeter()

        for i, (x, y) in enumerate(self.valid_loader):
            x, y = x.to(self.device), y.to(self.device)

            # duplicate M times
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            log_pi = []
            baselines = []
            for t in range(self.num_glimpses - 1):
                # forward pass through model
                h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                baselines.append(b_t)
                log_pi.append(p)

            # last iteration also returns the class log-probabilities
            h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)
            log_pi.append(p)
            baselines.append(b_t)

            # stack per-glimpse values with the batch dimension first
            baselines = torch.stack(baselines).transpose(1, 0)
            log_pi = torch.stack(log_pi).transpose(1, 0)

            # average over the M repeats (x.repeat stacked them along dim 0)
            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            baselines = baselines.contiguous().view(self.M, -1, baselines.shape[-1])
            baselines = torch.mean(baselines, dim=0)

            log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1])
            log_pi = torch.mean(log_pi, dim=0)

            # reward is 1 for a correct final prediction, repeated per glimpse
            predicted = torch.max(log_probas, 1)[1]
            R = (predicted.detach() == y).float()
            R = R.unsqueeze(1).repeat(1, self.num_glimpses)

            # compute losses for differentiable modules
            loss_action = F.nll_loss(log_probas, y)
            loss_baseline = F.mse_loss(baselines, R)

            # compute reinforce loss
            adjusted_reward = R - baselines.detach()
            loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
            loss_reinforce = torch.mean(loss_reinforce, dim=0)

            # sum up into a hybrid loss
            loss = loss_action + loss_baseline + loss_reinforce * 0.01

            # compute accuracy (percent)
            correct = (predicted == y).float()
            acc = 100 * (correct.sum() / len(y))

            losses.update(loss.item(), x.size()[0])
            accs.update(acc.item(), x.size()[0])

            # log to tensorboard
            if self.use_tensorboard:
                iteration = epoch * len(self.valid_loader) + i
                log_value("valid_loss", losses.avg, iteration)
                log_value("valid_acc", accs.avg, iteration)

        return losses.avg, accs.avg

    @torch.no_grad()
    def test(self):
        """Test the RAM model.

        This function should only be called at the very
        end once the model has finished training.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        for i, (x, y) in enumerate(self.test_loader):
            x, y = x.to(self.device), y.to(self.device)

            # duplicate M times
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            for t in range(self.num_glimpses - 1):
                # forward pass through model
                h_t, l_t, b_t, p = self.model(x, l_t, h_t)

            # last iteration
            h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)

            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            pred = log_probas.data.max(1, keepdim=True)[1]
            correct += pred.eq(y.data.view_as(pred)).cpu().sum()

        perc = (100.0 * correct) / (self.num_test)
        error = 100 - perc
        print(
            "[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)".format(
                correct, self.num_test, perc, error
            )
        )

    def save_checkpoint(self, state, is_best):
        """Saves a checkpoint of the model.

        If this model has reached the best validation accuracy thus
        far, a seperate file with the suffix `best` is created.
        """
        filename = self.model_name + "_ckpt.pth.tar"
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)
        if is_best:
            filename = self.model_name + "_model_best.pth.tar"
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """Load the best copy of a model.

        This is useful for 2 cases:
        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Args:
            best: if set to True, loads the best model.
                Use this if you want to evaluate your model
                on the test data. Else, set to False in which
                case the most recent version of the checkpoint
                is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + "_ckpt.pth.tar"
        if best:
            filename = self.model_name + "_model_best.pth.tar"
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        ckpt = torch.load(ckpt_path)

        # load variables from checkpoint
        self.start_epoch = ckpt["epoch"]
        self.best_valid_acc = ckpt["best_valid_acc"]
        self.model.load_state_dict(ckpt["model_state"])
        self.optimizer.load_state_dict(ckpt["optim_state"])

        if best:
            print(
                "[*] Loaded {} checkpoint @ epoch {} "
                "with best valid acc of {:.3f}".format(
                    filename, ckpt["epoch"], ckpt["best_valid_acc"]
                )
            )
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt["epoch"]))