class Trainer: """A Recurrent Attention Model trainer. All hyperparameters are provided by the user in the config file. """ def __init__(self, config, data_loader): """ Construct a new Trainer instance. Args: config: object containing command line arguments. data_loader: A data iterator. """ self.config = config if config.use_gpu and torch.cuda.is_available(): self.device = torch.device("cuda") else: self.device = torch.device("cpu") # glimpse network params self.patch_size = config.patch_size self.glimpse_scale = config.glimpse_scale self.num_patches = config.num_patches self.loc_hidden = config.loc_hidden self.glimpse_hidden = config.glimpse_hidden # core network params self.num_glimpses = config.num_glimpses self.hidden_size = config.hidden_size # reinforce params self.std = config.std self.M = config.M # data params if config.is_train: self.train_loader = data_loader[0] self.valid_loader = data_loader[1] self.num_train = len(self.train_loader.sampler.indices) self.num_valid = len(self.valid_loader.sampler.indices) else: self.test_loader = data_loader self.num_test = len(self.test_loader.dataset) self.num_classes = 25 #10 self.num_channels = 1 # training params self.epochs = config.epochs self.start_epoch = 0 self.momentum = config.momentum self.lr = config.init_lr # misc params self.best = config.best self.ckpt_dir = config.ckpt_dir self.logs_dir = config.logs_dir self.best_valid_acc = 0.0 self.counter = 0 self.lr_patience = config.lr_patience self.train_patience = config.train_patience self.use_tensorboard = config.use_tensorboard self.resume = config.resume self.print_freq = config.print_freq self.plot_freq = config.plot_freq self.model_name = "ram_{}_{}x{}_{}".format( config.num_glimpses, config.patch_size, config.patch_size, config.glimpse_scale, ) self.plot_dir = "./plots/" + self.model_name + "/" if not os.path.exists(self.plot_dir): os.makedirs(self.plot_dir) # configure tensorboard logging if self.use_tensorboard: tensorboard_dir = self.logs_dir + self.model_name print("[*] Saving tensorboard logs to {}".format(tensorboard_dir)) if not os.path.exists(tensorboard_dir): os.makedirs(tensorboard_dir) configure(tensorboard_dir) # build RAM model self.model = RecurrentAttention( self.patch_size, self.num_patches, self.glimpse_scale, self.num_channels, self.loc_hidden, self.glimpse_hidden, self.std, self.hidden_size, self.num_classes, ) self.model.to(self.device) # initialize optimizer and scheduler self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.init_lr) self.scheduler = ReduceLROnPlateau(self.optimizer, "min", patience=self.lr_patience) def gmdataset(self): import pandas as pd #data = pd.read_csv("check.csv", index_col ="name") gW = pd.read_csv("goodware.csv", index_col="name") mW = pd.read_csv("malware.csv", index_col="name") out = mW.append(gW) data = out out.drop('(BAD)', axis=1, inplace=True) out.drop('STD', axis=1, inplace=True) out.drop('SHLD', axis=1, inplace=True) out.drop('SETLE', axis=1, inplace=True) out.drop('SETB', axis=1, inplace=True) out.drop('SBB', axis=1, inplace=True) out.drop('RDTSC', axis=1, inplace=True) out.drop('PUSHF', axis=1, inplace=True) out.drop('FSTCW', axis=1, inplace=True) out.drop('FDIVP', axis=1, inplace=True) out.drop('FILD', axis=1, inplace=True) out.drop('RETN', axis=1, inplace=True) out.drop('LEA', axis=1, inplace=True) out.drop('IMUL', axis=1, inplace=True) from sklearn.model_selection import train_test_split #print(data['labels']) M = data.values X = M[:, :-1] Y = M[:, -1] #print(Y) #X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=42) #print("HHHHHHH: ",X_train.shape, y_train.shape) import numpy as np # print(X_train.shape) #x_train = np.reshape(X_train, (X_train.shape[0],5, 5,1)) #padie=np.pad(X_train, ((0,0),(0,759)), 'constant', constant_values=0) padie = np.pad(X, ((0, 0), (0, 1)), 'constant', constant_values=0) print(padie.shape) x = np.reshape(padie, (padie.shape[0], 1, 4, 4)) return x, Y def alldatacsv(self): import numpy as np # linear algebra import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv) from sklearn.feature_extraction.text import CountVectorizer from keras.preprocessing.text import Tokenizer from keras.preprocessing.sequence import pad_sequences from keras.models import Sequential from keras.layers import Dense, Embedding, LSTM, SpatialDropout1D, Dropout from sklearn.model_selection import train_test_split from keras.utils.np_utils import to_categorical import re # Input data files are available in the "../input/" directory. # For example, running this (by clicking run or pressing Shift+Enter) will list the files in the input directory """Only keeping the necessary columns.""" import pandas as pd data = pd.read_csv("AllData.csv") data = data.drop(['Unnamed: 0'], axis=1) data = data.rename(columns={'Text': 'text', 'Label': 'sentiment'}) data #data = pd.read_csv('../input/Sentiment.csv') # Keeping only the neccessary columns data = data[['text', 'sentiment']] pos = data[data['sentiment'] == 1] pos.shape[0] #data = data[data.sentiment != "Neutral"] data['text'] = data['text'].apply(lambda x: x.lower()) data['text'] = data['text'].apply( (lambda x: re.sub('[^a-zA-z0-9\s]', '', x))) #print(data[ data['sentiment'] == 1].size) #print(data[ data['sentiment'] == 0].size) for idx, row in data.iterrows(): row[0] = row[0].replace('rt', ' ') max_fatures = 2000 tokenizer = Tokenizer(num_words=max_fatures, split=' ') tokenizer.fit_on_texts(data['text'].values) X = tokenizer.texts_to_sequences(data['text'].values) X = pad_sequences(X) data['sentiment'] pd.DataFrame(data=X[1:, -20000:], index=X[1:, 0]) # 1st row as the column name #X=X[0:,-20000:] X.shape """# Train and Test Dataset Declaration""" Y = pd.get_dummies(data['sentiment']).values #X_train, X_test, Y_train, Y_test = train_test_split(X,Y, test_size = 0.33, random_state = 78) #conu = 0 # for x in Y_test: # if x.argmax()== 1: # conu = conu + 1 # conu import numpy as np import matplotlib.pyplot as plt #print(X_train.shape) #x_train = np.reshape(X_train, (X_train.shape[0],5, 5,1)) padie = np.pad(X, ((0, 0), (0, 251)), 'constant', constant_values=0) padie = padie[:, 459 * 459 - 50176:] x = np.reshape(padie, (padie.shape[0], 1, 224, 224)) #x_test = np.reshape(x_test, (x_test.shape[0],2,2, 1)) #print(Y) sk = pd.DataFrame(data=Y, columns=[0, 1]) inverted = sk.idxmax(1).values ss = np.rint(inverted) Y = ss #for i in range(0,Y.shape[0]): # if Y[i] == 1: # #print("YESSSSSSSS") # string = "imgs/" + str(i) + ".png" # plt.imsave(string,x[i][0,:,:]) # qq = i #x = x[:] #Y = Y[int(qq-qq/2):] #print(type(X)) return x, Y def batadal(self): import pandas as pd import numpy as np fD = pd.read_csv("newDatasets/BATADAL_dataset04.csv", header=None) #fD = pd.read_csv("/content/drive/My Drive/newDatasets/BATADAL_dataset02 (1).csv" , header=None) test = pd.read_csv("newDatasets/BATADAL_test_dataset.csv", header=None) test = test.drop(columns=0) test = test.drop([0], axis=0) Data = fD.drop(columns=0) Data = Data.drop([0], axis=0) nData = Data.values nData.shape testData = test.values xData = nData[:, :43] yData = nData[:, 43] xData.shape testData.shape xData = np.pad(xData, ((0, 0), (0, 6)), 'constant', constant_values=0) testData = np.pad(testData, ((0, 0), (0, 6)), 'constant', constant_values=0) test[:] xData.shape xData = xData.reshape(-1, 1, 7, 7) testData = testData.reshape(-1, 1, 7, 7) from sklearn.model_selection import train_test_split X_train, X_test, y_train, y_test = train_test_split(xData, yData, test_size=0.1, random_state=42) #print(X_train, y_train) #print("JJ: ", type(xData)) return xData, yData def Malimg(self): import tensorflow as tf import keras import numpy as np from sklearn.preprocessing import StandardScaler from sklearn.model_selection import train_test_split dataset = np.load('malimg.npz', allow_pickle=True) BATCH_SIZE = 256 CELL_SIZE = 256 DROPOUT_RATE = 0.85 LEARNING_RATE = 1e-3 NODE_SIZE = [512, 256, 128] NUM_LAYERS = 5 features = dataset['arr'][:, 0] features = np.array([feature for feature in features]) features = np.reshape( features, (features.shape[0], features.shape[1] * features.shape[2])) r, c = features.shape print("Number of Samples", r) print("Number of Features", c) if 1 == 1: features = StandardScaler().fit_transform(features) labels = dataset['arr'][:, 1] labels = np.array([label for label in labels]) one_hot = np.zeros((labels.shape[0], labels.max() + 1)) one_hot[np.arange(labels.shape[0]), labels] = 1 labels = one_hot labels[labels == 0] = 0 num_features = features.shape[1] num_classes = labels.shape[1] Y = labels X = features print("Shape of Labels", Y.shape) print("Shape of Features", X.shape) train_features, test_features, train_labels, test_labels = train_test_split( features, labels, test_size=0.1, stratify=labels) #10% Test size train_size = int(train_features.shape[0]) train_features = train_features[:train_size - (train_size % BATCH_SIZE)] train_labels = train_labels[:train_size - (train_size % BATCH_SIZE)] test_size = int(test_features.shape[0]) test_features = test_features[:test_size - (test_size % BATCH_SIZE)] test_labels = test_labels[:test_size - (test_size % BATCH_SIZE)] fsize = int(features.shape[0]) features = features[:fsize - (fsize % BATCH_SIZE)] labels = labels[:fsize - (fsize % BATCH_SIZE)] r, c = train_features.shape print("Number of Training Samples", r) print("Number of Training Features", c) r, c = test_features.shape print("Number of Test Samples", r) print("Number of Test Features", c) #print(train_labels.shape) #print(tf.reshape(test_features[1], [32,32])) print(train_features.shape, test_features.shape, train_labels.shape, test_labels.shape) #print(train_labels) train_X = train_features.reshape(-1, 1, 32, 32) feat = features.reshape(-1, 1, 32, 32) test_X = test_features.reshape(-1, 32, 32, 1) Unchanined = X.reshape(-1, 32, 32, 1) y_test_non_category = [np.argmax(t) for t in labels] print("LABELS", np.asarray(y_test_non_category)) return feat, np.asarray(y_test_non_category) def SWAT(self): import pandas as pd import numpy as np fD = pd.read_excel("newDatasets/SWaT/SWaT_Dataset_Attack_v0.xlsx", header=None) Data = fD.drop(columns=0) Data = Data.drop([0, 1], axis=0) xData = Data.values[:, :51] yData = Data.values[:, 51] count = 0 for i in yData: if i == 'Normal': yData[count] = 0 else: yData[count] = 1 count = count + 1 xData.shape xData = np.pad(xData, ((0, 0), (0, 13)), 'constant', constant_values=0) xData = xData.reshape(xData.shape[0], 8, 8, 1) return xData[:200], yData[:200] def reset(self): h_t = torch.zeros( self.batch_size, self.hidden_size, dtype=torch.float, device=self.device, requires_grad=True, ) l_t = torch.FloatTensor(self.batch_size, 2).uniform_(-1, 1).to(self.device) l_t.requires_grad = True return h_t, l_t def train(self): """Train the model on the training set. A checkpoint of the model is saved after each epoch and if the validation accuracy is improved upon, a separate ckpt is created for use on the test set. """ # load the most recent checkpoint if self.resume: self.load_checkpoint(best=False) print("\n[*] Train on {} samples, validate on {} samples".format( self.num_train, self.num_valid)) for epoch in range(self.start_epoch, self.epochs): print("\nEpoch: {}/{} - LR: {:.6f}".format( epoch + 1, self.epochs, self.optimizer.param_groups[0]["lr"])) # train for 1 epoch train_loss, train_acc = self.train_one_epoch(epoch) # evaluate on validation set valid_loss, valid_acc = self.validate(epoch) # # reduce lr if validation loss plateaus self.scheduler.step(-valid_acc) is_best = valid_acc > self.best_valid_acc msg1 = "train loss: {:.3f} - train acc: {:.3f} " msg2 = "- val loss: {:.3f} - val acc: {:.3f} - val err: {:.3f}" if is_best: self.counter = 0 msg2 += " [*]" msg = msg1 + msg2 print( msg.format(train_loss, train_acc, valid_loss, valid_acc, 100 - valid_acc)) # check for improvement if not is_best: self.counter += 1 if self.counter > self.train_patience: print("[!] No improvement in a while, stopping training.") return self.best_valid_acc = max(valid_acc, self.best_valid_acc) self.save_checkpoint( { "epoch": epoch + 1, "model_state": self.model.state_dict(), "optim_state": self.optimizer.state_dict(), "best_valid_acc": self.best_valid_acc, }, is_best, ) def train_one_epoch(self, epoch): """ Train the model for 1 epoch of the training set. An epoch corresponds to one full pass through the entire training set in successive mini-batches. This is used by train() and should not be called manually. """ import pandas as pd import numpy as np self.model.train() batch_time = AverageMeter() losses = AverageMeter() accs = AverageMeter() tic = time.time() with tqdm(total=self.num_train) as pbar: for i, (x, y) in enumerate(self.train_loader): self.optimizer.zero_grad() x, y = x.to(self.device), y.to(self.device) x1, y1 = self.SWAT( ) #self.gmdataset()#self.batadal()#self.alldatacsv()#self.gmdataset() x1 = x1.astype(np.float32) y1 = y1.astype(np.float32) x, y = torch.from_numpy(x1).float(), torch.from_numpy( y1).long() #print("Here", y) plot = False if (epoch % self.plot_freq == 0) and (i == 0): plot = True # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # save images imgs = [] imgs.append(x[0:9]) # extract the glimpses locs = [] log_pi = [] baselines = [] for t in range(self.num_glimpses - 1): # forward pass through model h_t, l_t, b_t, p = self.model(x, l_t, h_t) # store locs.append(l_t[0:9]) baselines.append(b_t) log_pi.append(p) # last iteration h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True) log_pi.append(p) baselines.append(b_t) locs.append(l_t[0:9]) # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) # calculate reward predicted = torch.max(log_probas, 1)[1] R = (predicted.detach() == y).float() R = R.unsqueeze(1).repeat(1, self.num_glimpses) # compute losses for differentiable modules loss_action = F.nll_loss(log_probas, y) loss_baseline = F.mse_loss(baselines, R) # compute reinforce loss # summed over timesteps and averaged across batch adjusted_reward = R - baselines.detach() loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1) loss_reinforce = torch.mean(loss_reinforce, dim=0) # sum up into a hybrid loss loss = loss_action + loss_baseline + loss_reinforce * 0.01 # compute accuracy correct = (predicted == y).float() acc = 100 * (correct.sum() / len(y)) #print("Predicted: ", predicted, "\nTrue", y) # store losses.update(loss.item(), x.size()[0]) accs.update(acc.item(), x.size()[0]) # compute gradients and update SGD loss.backward() self.optimizer.step() # measure elapsed time toc = time.time() batch_time.update(toc - tic) pbar.set_description( ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format( (toc - tic), loss.item(), acc.item()))) pbar.update(self.batch_size) # dump the glimpses and locs if plot: imgs = [g.cpu().data.numpy().squeeze() for g in imgs] locs = [l.cpu().data.numpy() for l in locs] pickle.dump( imgs, open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb")) pickle.dump( locs, open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb")) # log to tensorboard if self.use_tensorboard: iteration = epoch * len(self.train_loader) + i log_value("train_loss", losses.avg, iteration) log_value("train_acc", accs.avg, iteration) return losses.avg, accs.avg @torch.no_grad() def validate(self, epoch): """Evaluate the RAM model on the validation set. """ import torch import numpy as np losses = AverageMeter() accs = AverageMeter() for i, (x, y) in enumerate(self.valid_loader): x, y = x.to(self.device), y.to(self.device) x1, y1 = self.gmdataset( ) #self.batadal()#self.alldatacsv()#self.gmdataset() x1 = x1.astype(np.float32) y1 = y1.astype(np.float32) x, y = torch.from_numpy(x1).float(), torch.from_numpy(y1).long() # duplicate M times x = x.repeat(self.M, 1, 1, 1) # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # extract the glimpses log_pi = [] baselines = [] for t in range(self.num_glimpses - 1): # forward pass through model h_t, l_t, b_t, p = self.model(x, l_t, h_t) # store baselines.append(b_t) log_pi.append(p) # last iteration h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True) log_pi.append(p) baselines.append(b_t) # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) # average log_probas = log_probas.view(self.M, -1, log_probas.shape[-1]) log_probas = torch.mean(log_probas, dim=0) baselines = baselines.contiguous().view(self.M, -1, baselines.shape[-1]) baselines = torch.mean(baselines, dim=0) log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1]) log_pi = torch.mean(log_pi, dim=0) # calculate reward predicted = torch.max(log_probas, 1)[1] R = (predicted.detach() == y).float() R = R.unsqueeze(1).repeat(1, self.num_glimpses) # compute losses for differentiable modules loss_action = F.nll_loss(log_probas, y) loss_baseline = F.mse_loss(baselines, R) # compute reinforce loss adjusted_reward = R - baselines.detach() loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1) loss_reinforce = torch.mean(loss_reinforce, dim=0) # sum up into a hybrid loss loss = loss_action + loss_baseline + loss_reinforce * 0.01 # compute accuracy correct = (predicted == y).float() acc = 100 * (correct.sum() / len(y)) count = 0 countFP = 0 countFN = 0 countTN = 0 countTP = 0 for i in range(len(correct)): if (correct[i] == 0): count = count + 1 if (predicted[i] == 1 and y[i] == 0): #False Positive countFP = countFP + 1 if (predicted[i] == 0 and y[i] == 1): #False Negative countFN = countFN + 1 if (predicted[i] == 0 and y[i] == 0): #True Negative countTN = countTN + 1 if (predicted[i] == 1 and y[i] == 1): #True Positive countTP = countTP + 1 print("Total: ", len(correct), "Wrong: ", count) print("TP: ", countTP, "TN: ", countTN, "FN: ", countFN, "FP: ", countFP) # store losses.update(loss.item(), x.size()[0]) accs.update(acc.item(), x.size()[0]) # log to tensorboard if self.use_tensorboard: iteration = epoch * len(self.valid_loader) + i log_value("valid_loss", losses.avg, iteration) log_value("valid_acc", accs.avg, iteration) return losses.avg, accs.avg @torch.no_grad() def test(self): """Test the RAM model. This function should only be called at the very end once the model has finished training. """ correct = 0 # load the best checkpoint self.load_checkpoint(best=self.best) for i, (x, y) in enumerate(self.test_loader): x, y = x.to(self.device), y.to(self.device) # duplicate M times x = x.repeat(self.M, 1, 1, 1) # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # extract the glimpses for t in range(self.num_glimpses - 1): # forward pass through model h_t, l_t, b_t, p = self.model(x, l_t, h_t) # last iteration h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True) log_probas = log_probas.view(self.M, -1, log_probas.shape[-1]) log_probas = torch.mean(log_probas, dim=0) pred = log_probas.data.max(1, keepdim=True)[1] correct += pred.eq(y.data.view_as(pred)).cpu().sum() perc = (100.0 * correct) / (self.num_test) error = 100 - perc print("[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)".format( correct, self.num_test, perc, error)) def save_checkpoint(self, state, is_best): """Saves a checkpoint of the model. If this model has reached the best validation accuracy thus far, a seperate file with the suffix `best` is created. """ filename = self.model_name + "_ckpt.pth.tar" ckpt_path = os.path.join(self.ckpt_dir, filename) torch.save(state, ckpt_path) if is_best: filename = self.model_name + "_model_best.pth.tar" shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename)) def load_checkpoint(self, best=False): """Load the best copy of a model. This is useful for 2 cases: - Resuming training with the most recent model checkpoint. - Loading the best validation model to evaluate on the test data. Args: best: if set to True, loads the best model. Use this if you want to evaluate your model on the test data. Else, set to False in which case the most recent version of the checkpoint is used. """ print("[*] Loading model from {}".format(self.ckpt_dir)) filename = self.model_name + "_ckpt.pth.tar" if best: filename = self.model_name + "_model_best.pth.tar" ckpt_path = os.path.join(self.ckpt_dir, filename) ckpt = torch.load(ckpt_path) # load variables from checkpoint self.start_epoch = ckpt["epoch"] self.best_valid_acc = ckpt["best_valid_acc"] self.model.load_state_dict(ckpt["model_state"]) self.optimizer.load_state_dict(ckpt["optim_state"]) if best: print("[*] Loaded {} checkpoint @ epoch {} " "with best valid acc of {:.3f}".format( filename, ckpt["epoch"], ckpt["best_valid_acc"])) else: print("[*] Loaded {} checkpoint @ epoch {}".format( filename, ckpt["epoch"]))
class Trainer: """A Recurrent Attention Model trainer. All hyperparameters are provided by the user in the config file. """ def __init__(self, config, data_loader): """ Construct a new Trainer instance. Args: config: object containing command line arguments. data_loader: A data iterator. """ self.config = config if config.use_gpu and torch.cuda.is_available(): self.device = torch.device("cuda") else: self.device = torch.device("cpu") # glimpse network params self.patch_size = config.patch_size self.glimpse_scale = config.glimpse_scale self.num_patches = config.num_patches self.loc_hidden = config.loc_hidden self.glimpse_hidden = config.glimpse_hidden # core network params self.num_glimpses = config.num_glimpses self.hidden_size = config.hidden_size # reinforce params self.std = config.std self.M = config.M # data params if config.is_train: self.train_loader = data_loader[0] self.valid_loader = data_loader[1] self.num_train = len(self.train_loader.sampler.indices) self.num_valid = len(self.valid_loader.sampler.indices) else: self.test_loader = data_loader self.num_test = len(self.test_loader.dataset) self.num_classes = 10 self.num_channels = 1 # training params self.epochs = config.epochs self.start_epoch = 0 self.momentum = config.momentum self.lr = config.init_lr # misc params self.best = config.best self.ckpt_dir = config.ckpt_dir self.logs_dir = config.logs_dir self.best_valid_acc = 0.0 self.counter = 0 self.lr_patience = config.lr_patience self.train_patience = config.train_patience self.use_tensorboard = config.use_tensorboard self.resume = config.resume self.print_freq = config.print_freq self.plot_freq = config.plot_freq self.model_name = "ram_{}_{}x{}_{}".format( config.num_glimpses, config.patch_size, config.patch_size, config.glimpse_scale, ) self.plot_dir = "./plots/" + self.model_name + "/" if not os.path.exists(self.plot_dir): os.makedirs(self.plot_dir) # configure tensorboard logging if self.use_tensorboard: tensorboard_dir = self.logs_dir + self.model_name print("[*] Saving tensorboard logs to {}".format(tensorboard_dir)) if not os.path.exists(tensorboard_dir): os.makedirs(tensorboard_dir) configure(tensorboard_dir) # build RAM model self.model = RecurrentAttention( self.patch_size, self.num_patches, self.glimpse_scale, self.num_channels, self.loc_hidden, self.glimpse_hidden, self.std, self.hidden_size, self.num_classes, ) self.model.to(self.device) # initialize optimizer and scheduler self.optimizer = torch.optim.Adam( self.model.parameters(), lr=self.config.init_lr ) self.scheduler = ReduceLROnPlateau( self.optimizer, "min", patience=self.lr_patience ) def reset(self): h_t = torch.zeros( self.batch_size, self.hidden_size, dtype=torch.float, device=self.device, requires_grad=True, ) l_t = torch.FloatTensor(self.batch_size, 2).uniform_(-1, 1).to(self.device) l_t.requires_grad = True return h_t, l_t def train(self): """Train the model on the training set. A checkpoint of the model is saved after each epoch and if the validation accuracy is improved upon, a separate ckpt is created for use on the test set. """ # load the most recent checkpoint if self.resume: self.load_checkpoint(best=False) print( "\n[*] Train on {} samples, validate on {} samples".format( self.num_train, self.num_valid ) ) for epoch in range(self.start_epoch, self.epochs): print( "\nEpoch: {}/{} - LR: {:.6f}".format( epoch + 1, self.epochs, self.optimizer.param_groups[0]["lr"] ) ) # train for 1 epoch train_loss, train_acc = self.train_one_epoch(epoch) # evaluate on validation set valid_loss, valid_acc = self.validate(epoch) # # reduce lr if validation loss plateaus self.scheduler.step(-valid_acc) is_best = valid_acc > self.best_valid_acc msg1 = "train loss: {:.3f} - train acc: {:.3f} " msg2 = "- val loss: {:.3f} - val acc: {:.3f} - val err: {:.3f}" if is_best: self.counter = 0 msg2 += " [*]" msg = msg1 + msg2 print( msg.format( train_loss, train_acc, valid_loss, valid_acc, 100 - valid_acc ) ) # check for improvement if not is_best: self.counter += 1 if self.counter > self.train_patience: print("[!] No improvement in a while, stopping training.") return self.best_valid_acc = max(valid_acc, self.best_valid_acc) self.save_checkpoint( { "epoch": epoch + 1, "model_state": self.model.state_dict(), "optim_state": self.optimizer.state_dict(), "best_valid_acc": self.best_valid_acc, }, is_best, ) def train_one_epoch(self, epoch): """ Train the model for 1 epoch of the training set. An epoch corresponds to one full pass through the entire training set in successive mini-batches. This is used by train() and should not be called manually. """ self.model.train() batch_time = AverageMeter() losses = AverageMeter() accs = AverageMeter() tic = time.time() with tqdm(total=self.num_train) as pbar: for i, (x, y) in enumerate(self.train_loader): self.optimizer.zero_grad() x, y = x.to(self.device), y.to(self.device) plot = False if (epoch % self.plot_freq == 0) and (i == 0): plot = True # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # save images imgs = [] imgs.append(x[0:9]) # extract the glimpses locs = [] log_pi = [] baselines = [] for t in range(self.num_glimpses - 1): # forward pass through model h_t, l_t, b_t, p = self.model(x, l_t, h_t) # store locs.append(l_t[0:9]) baselines.append(b_t) log_pi.append(p) # last iteration h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True) log_pi.append(p) baselines.append(b_t) locs.append(l_t[0:9]) # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) # calculate reward predicted = torch.max(log_probas, 1)[1] R = (predicted.detach() == y).float() R = R.unsqueeze(1).repeat(1, self.num_glimpses) # compute losses for differentiable modules loss_action = F.nll_loss(log_probas, y) loss_baseline = F.mse_loss(baselines, R) # compute reinforce loss # summed over timesteps and averaged across batch adjusted_reward = R - baselines.detach() loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1) loss_reinforce = torch.mean(loss_reinforce, dim=0) # sum up into a hybrid loss loss = loss_action + loss_baseline + loss_reinforce * 0.01 # compute accuracy correct = (predicted == y).float() acc = 100 * (correct.sum() / len(y)) # store losses.update(loss.item(), x.size()[0]) accs.update(acc.item(), x.size()[0]) # compute gradients and update SGD loss.backward() self.optimizer.step() # measure elapsed time toc = time.time() batch_time.update(toc - tic) pbar.set_description( ( "{:.1f}s - loss: {:.3f} - acc: {:.3f}".format( (toc - tic), loss.item(), acc.item() ) ) ) pbar.update(self.batch_size) # dump the glimpses and locs if plot: imgs = [g.cpu().data.numpy().squeeze() for g in imgs] locs = [l.cpu().data.numpy() for l in locs] pickle.dump( imgs, open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb") ) pickle.dump( locs, open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb") ) # log to tensorboard if self.use_tensorboard: iteration = epoch * len(self.train_loader) + i log_value("train_loss", losses.avg, iteration) log_value("train_acc", accs.avg, iteration) return losses.avg, accs.avg @torch.no_grad() def validate(self, epoch): """Evaluate the RAM model on the validation set. """ losses = AverageMeter() accs = AverageMeter() for i, (x, y) in enumerate(self.valid_loader): x, y = x.to(self.device), y.to(self.device) # duplicate M times x = x.repeat(self.M, 1, 1, 1) # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # extract the glimpses log_pi = [] baselines = [] for t in range(self.num_glimpses - 1): # forward pass through model h_t, l_t, b_t, p = self.model(x, l_t, h_t) # store baselines.append(b_t) log_pi.append(p) # last iteration h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True) log_pi.append(p) baselines.append(b_t) # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) # average log_probas = log_probas.view(self.M, -1, log_probas.shape[-1]) log_probas = torch.mean(log_probas, dim=0) baselines = baselines.contiguous().view(self.M, -1, baselines.shape[-1]) baselines = torch.mean(baselines, dim=0) log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1]) log_pi = torch.mean(log_pi, dim=0) # calculate reward predicted = torch.max(log_probas, 1)[1] R = (predicted.detach() == y).float() R = R.unsqueeze(1).repeat(1, self.num_glimpses) # compute losses for differentiable modules loss_action = F.nll_loss(log_probas, y) loss_baseline = F.mse_loss(baselines, R) # compute reinforce loss adjusted_reward = R - baselines.detach() loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1) loss_reinforce = torch.mean(loss_reinforce, dim=0) # sum up into a hybrid loss loss = loss_action + loss_baseline + loss_reinforce * 0.01 # compute accuracy correct = (predicted == y).float() acc = 100 * (correct.sum() / len(y)) # store losses.update(loss.item(), x.size()[0]) accs.update(acc.item(), x.size()[0]) # log to tensorboard if self.use_tensorboard: iteration = epoch * len(self.valid_loader) + i log_value("valid_loss", losses.avg, iteration) log_value("valid_acc", accs.avg, iteration) return losses.avg, accs.avg @torch.no_grad() def test(self): """Test the RAM model. This function should only be called at the very end once the model has finished training. """ correct = 0 # load the best checkpoint self.load_checkpoint(best=self.best) for i, (x, y) in enumerate(self.test_loader): x, y = x.to(self.device), y.to(self.device) # duplicate M times x = x.repeat(self.M, 1, 1, 1) # initialize location vector and hidden state self.batch_size = x.shape[0] h_t, l_t = self.reset() # extract the glimpses for t in range(self.num_glimpses - 1): # forward pass through model h_t, l_t, b_t, p = self.model(x, l_t, h_t) # last iteration h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True) log_probas = log_probas.view(self.M, -1, log_probas.shape[-1]) log_probas = torch.mean(log_probas, dim=0) pred = log_probas.data.max(1, keepdim=True)[1] correct += pred.eq(y.data.view_as(pred)).cpu().sum() perc = (100.0 * correct) / (self.num_test) error = 100 - perc print( "[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)".format( correct, self.num_test, perc, error ) ) def save_checkpoint(self, state, is_best): """Saves a checkpoint of the model. If this model has reached the best validation accuracy thus far, a seperate file with the suffix `best` is created. """ filename = self.model_name + "_ckpt.pth.tar" ckpt_path = os.path.join(self.ckpt_dir, filename) torch.save(state, ckpt_path) if is_best: filename = self.model_name + "_model_best.pth.tar" shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename)) def load_checkpoint(self, best=False): """Load the best copy of a model. This is useful for 2 cases: - Resuming training with the most recent model checkpoint. - Loading the best validation model to evaluate on the test data. Args: best: if set to True, loads the best model. Use this if you want to evaluate your model on the test data. Else, set to False in which case the most recent version of the checkpoint is used. """ print("[*] Loading model from {}".format(self.ckpt_dir)) filename = self.model_name + "_ckpt.pth.tar" if best: filename = self.model_name + "_model_best.pth.tar" ckpt_path = os.path.join(self.ckpt_dir, filename) ckpt = torch.load(ckpt_path) # load variables from checkpoint self.start_epoch = ckpt["epoch"] self.best_valid_acc = ckpt["best_valid_acc"] self.model.load_state_dict(ckpt["model_state"]) self.optimizer.load_state_dict(ckpt["optim_state"]) if best: print( "[*] Loaded {} checkpoint @ epoch {} " "with best valid acc of {:.3f}".format( filename, ckpt["epoch"], ckpt["best_valid_acc"] ) ) else: print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt["epoch"]))