def runHardThrsd(self, sess):
    '''
    Function to run the IHT routine on Bonsai Obj.

    Reads the current values of the Bonsai matrices W, V and T through
    ``sess``. Each one is sparsified with utils.hardThreshold, using the
    sparsity settings self.sW, self.sV and self.sT respectively. The
    thresholded arrays are cached on the trainer, then fed through the
    W/V/T placeholders while running self.hardThresholdGroup.

    sess: The Tensorflow session holding the Bonsai variables.
    '''
    # Fetch all three matrices in one run on the session we were given.
    # Tensor.eval() without a session argument uses the *default* session,
    # which is not necessarily ``sess`` (and fails if none is installed).
    currW, currV, currT = sess.run(
        [self.bonsaiObj.W, self.bonsaiObj.V, self.bonsaiObj.T])

    self.__thrsdW = utils.hardThreshold(currW, self.sW)
    self.__thrsdV = utils.hardThreshold(currV, self.sV)
    self.__thrsdT = utils.hardThreshold(currT, self.sT)

    fd_thrsd = {self.__Wth: self.__thrsdW,
                self.__Vth: self.__thrsdV,
                self.__Tth: self.__thrsdT}
    sess.run(self.hardThresholdGroup, feed_dict=fd_thrsd)
def runHardThrsd(self, sess):
    '''
    Function to run the IHT routine on FastObj.

    The first self.numMatrices[0] parameter matrices are sparsified with
    utils.hardThreshold using sparsity self.sW. The remaining ones, up to
    self.totalMatrices, use self.sU. The thresholded arrays are stored in
    self.thrsdParams (same order as self.FastParams). They are then fed
    through self.paramPlaceholders while running self.hardThresholdGroup.

    sess: The Tensorflow session holding the FastObj parameters.
    '''
    numW = self.numMatrices[0]
    params = [self.FastParams[i] for i in range(self.totalMatrices)]

    # A single run on the given session instead of per-matrix .eval(),
    # which would silently use the default session rather than ``sess``.
    values = sess.run(params)

    # Matrices [0, numW) use the sW sparsity; the rest use sU.
    self.thrsdParams = [
        utils.hardThreshold(value, self.sW if i < numW else self.sU)
        for i, value in enumerate(values)
    ]

    fd_thrsd = {self.paramPlaceholders[i]: self.thrsdParams[i]
                for i in range(self.totalMatrices)}
    sess.run(self.hardThresholdGroup, feed_dict=fd_thrsd)
def train(self, batchSize, totalEpochs, sess,
          x_train, x_val, y_train, y_val, noInit=False,
          redirFile=None, printStep=10, valStep=3):
    '''
    Performs dense training of ProtoNN followed by iterative hard
    thresholding to enforce sparsity constraints. When
    self.sparseTraining is set, hard thresholding is applied after every
    training batch.

    batchSize: Batch size per update
    totalEpochs: The number of epochs to run training for. One epoch is
        defined as one pass over the entire training data.
    sess: The Tensorflow session to use for running various graph
        operators.
    x_train, x_val, y_train, y_val: The numpy arrays containing train and
        validation data. x data is assumed to be of shape
        [-1, featureDimension] while y should have shape
        [-1, numberLabels].
    noInit: By default, all the tensors of the computation graph are
        initialized at the start of the training session. Set
        noInit=True to disable this behaviour.
    redirFile: File-like object that all progress messages (per-batch
        training lines and the validation summary) are printed to.
        None prints to stdout.
    printStep: Number of batches between echoing of loss and train
        accuracy.
    valStep: Number of epochs between evaluations on validation set.

    Raises ValueError if sess is None.
    '''
    d, d_cap, m, L, gamma = self.protoNNObj.getHyperParams()
    assert batchSize >= 1, 'Batch size should be positive integer'
    assert totalEpochs >= 1, 'Total epochs should be positive integer'
    assert x_train.ndim == 2, 'Expected training data to be of rank 2'
    assert x_train.shape[1] == d, 'Expected x_train to be [-1, %d]' % d
    assert x_val.ndim == 2, 'Expected validation data to be of rank 2'
    assert x_val.shape[1] == d, 'Expected x_val to be [-1, %d]' % d
    assert y_train.ndim == 2, 'Expected training labels to be of rank 2'
    assert y_train.shape[1] == L, 'Expected y_train to be [-1, %d]' % L
    assert y_val.ndim == 2, 'Expected validation labels to be of rank 2'
    assert y_val.shape[1] == L, 'Expected y_val to be [-1, %d]' % L

    # Numpy will throw asserts for arrays
    if sess is None:
        raise ValueError('sess must be valid Tensorflow session.')

    # float() keeps the ceiling correct even under Python 2 integer
    # division, where len(x) / batchSize would already be floored.
    trainNumBatches = int(np.ceil(len(x_train) / float(batchSize)))
    valNumBatches = int(np.ceil(len(x_val) / float(batchSize)))
    x_train_batches = np.array_split(x_train, trainNumBatches)
    y_train_batches = np.array_split(y_train, trainNumBatches)
    x_val_batches = np.array_split(x_val, valNumBatches)
    y_val_batches = np.array_split(y_val, valNumBatches)

    if not noInit:
        sess.run(tf.global_variables_initializer())

    X, Y = self.X, self.Y
    W, B, Z, _ = self.protoNNObj.getModelMatrices()

    for epoch in range(totalEpochs):
        trainBatches = zip(x_train_batches, y_train_batches)
        for i, (batch_x, batch_y) in enumerate(trainBatches):
            feed_dict = {X: batch_x, Y: batch_y}
            sess.run(self.trainStep, feed_dict=feed_dict)
            if i % printStep == 0:
                loss, acc = sess.run([self.loss, self.accuracy],
                                     feed_dict=feed_dict)
                msg = "Epoch: %3d Batch: %3d" % (epoch, i)
                msg += " Loss: %3.5f Accuracy: %2.5f" % (loss, acc)
                print(msg, file=redirFile)

            # Perform Hard thresholding
            if self.sparseTraining:
                W_, B_, Z_ = sess.run([W, B, Z])
                fd_thrsd = {
                    self.W_th: utils.hardThreshold(W_, self.__sW),
                    self.B_th: utils.hardThreshold(B_, self.__sB),
                    self.Z_th: utils.hardThreshold(Z_, self.__sZ)
                }
                sess.run(self.__hthOp, feed_dict=fd_thrsd)

        if (epoch + 1) % valStep == 0:
            acc = 0.0
            loss = 0.0
            for batch_x, batch_y in zip(x_val_batches, y_val_batches):
                feed_dict = {X: batch_x, Y: batch_y}
                acc_, loss_ = sess.run([self.accuracy, self.loss],
                                       feed_dict=feed_dict)
                acc += acc_
                loss += loss_
            # Unweighted mean of per-batch values; array_split keeps batch
            # sizes within one of each other.
            acc /= len(y_val_batches)
            loss /= len(y_val_batches)
            # Honour redirFile here too, like the per-batch messages above.
            print("Test Loss: %2.5f Accuracy: %2.5f" % (loss, acc),
                  file=redirFile)