Example 1
    def runHardThrsd(self, sess):
        '''
        Function to run the iterative hard thresholding (IHT) routine on the Bonsai object
        '''
        # Pull the current parameter values out of the graph.
        currW = self.bonsaiObj.W.eval()
        currV = self.bonsaiObj.V.eval()
        currT = self.bonsaiObj.T.eval()

        # Project each matrix onto its sparsity budget (sW, sV, sT).
        self.__thrsdW = utils.hardThreshold(currW, self.sW)
        self.__thrsdV = utils.hardThreshold(currV, self.sV)
        self.__thrsdT = utils.hardThreshold(currT, self.sT)

        # Feed the thresholded values back and assign them to the graph
        # variables in a single grouped op.
        fd_thrsd = {self.__Wth: self.__thrsdW, self.__Vth: self.__thrsdV,
                    self.__Tth: self.__thrsdT}
        sess.run(self.hardThresholdGroup, feed_dict=fd_thrsd)
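For reference, utils.hardThreshold projects a matrix onto a sparsity budget by keeping only its largest-magnitude entries and zeroing the rest. Below is a minimal NumPy sketch of such a helper; the standalone name hard_threshold and the reading of s as the fraction of entries to keep are assumptions for illustration, not the library's exact implementation.

    import numpy as np

    def hard_threshold(A, s):
        # Keep the s-fraction (0 < s <= 1) of largest-magnitude entries of A,
        # zeroing out everything else. Ties are broken arbitrarily.
        flat = A.ravel().copy()
        keep = int(np.ceil(s * flat.size))
        if keep == 0:
            return np.zeros_like(A)
        if keep < flat.size:
            # Indices of the (size - keep) smallest-magnitude entries.
            drop = np.argpartition(np.abs(flat), flat.size - keep)[:flat.size - keep]
            flat[drop] = 0.0
        return flat.reshape(A.shape)

    # Example: retain the top 30% of entries by magnitude.
    W = np.random.randn(4, 5)
    W_sparse = hard_threshold(W, 0.3)
    print(np.count_nonzero(W_sparse), "of", W.size, "entries retained")  # 6 of 20

The thresholded values computed this way are then pushed back into the graph through the placeholder feed and the grouped assign op shown above.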
Example 2
    def runHardThrsd(self, sess):
        '''
        Function to run the iterative hard thresholding (IHT) routine on the FastObj parameters
        '''
        # Threshold the first numMatrices[0] parameter matrices with sparsity
        # sW and the remaining ones with sparsity sU.
        self.thrsdParams = []
        for i in range(0, self.numMatrices[0]):
            self.thrsdParams.append(
                utils.hardThreshold(self.FastParams[i].eval(), self.sW))
        for i in range(self.numMatrices[0], self.totalMatrices):
            self.thrsdParams.append(
                utils.hardThreshold(self.FastParams[i].eval(), self.sU))

        # Assign every thresholded matrix back to its variable in one grouped op.
        fd_thrsd = {}
        for i in range(0, self.totalMatrices):
            fd_thrsd[self.paramPlaceholders[i]] = self.thrsdParams[i]
        sess.run(self.hardThresholdGroup, feed_dict=fd_thrsd)
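This snippet assumes self.paramPlaceholders and self.hardThresholdGroup were wired up when the trainer was built. A plausible TF1-style construction is sketched below; build_hard_threshold_group is a hypothetical helper, not the class's actual constructor code.

    import tensorflow as tf  # TF1-style graph API, matching the snippets above

    def build_hard_threshold_group(param_vars):
        # One placeholder and one assign op per parameter matrix, grouped so
        # that a single sess.run() writes all thresholded matrices back at once.
        placeholders, assign_ops = [], []
        for var in param_vars:
            ph = tf.placeholder(var.dtype.base_dtype, shape=var.shape)
            placeholders.append(ph)
            assign_ops.append(tf.assign(var, ph))
        return placeholders, tf.group(*assign_ops)

Grouping the assignments keeps the projection step down to one session call per IHT round instead of one call per matrix.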
Example 3
    def train(self,
              batchSize,
              totalEpochs,
              sess,
              x_train,
              x_val,
              y_train,
              y_val,
              noInit=False,
              redirFile=None,
              printStep=10,
              valStep=3):
        '''
        Performs dense training of ProtoNN followed by iterative hard
        thresholding to enforce sparsity constraints.

        batchSize: Batch size per update
        totalEpochs: The number of epochs to run training for. One epoch is
            defined as one pass over the entire training data.
        sess: The Tensorflow session to use for running various graph
            operators.
        x_train, x_val, y_train, y_val: The numpy arrays containing the train
            and validation data. x data is assumed to be of shape [-1,
            featureDimension] while y should have shape [-1, numberLabels].
        noInit: By default, all the tensors of the computation graph are
            initialized at the start of the training session. Set noInit=True
            to disable this behaviour.
        redirFile: File object to redirect the printed output to; defaults to
            None (standard output).
        printStep: Number of batches between printing of loss and train
            accuracy.
        valStep: Number of epochs between evaluations on the validation set.
        '''
        d, d_cap, m, L, gamma = self.protoNNObj.getHyperParams()
        assert batchSize >= 1, 'Batch size should be positive integer'
        assert totalEpochs >= 1, 'Total epochs should be positive integer'
        assert x_train.ndim == 2, 'Expected training data to be of rank 2'
        assert x_train.shape[1] == d, 'Expected x_train to be [-1, %d]' % d
        assert x_val.ndim == 2, 'Expected validation data to be of rank 2'
        assert x_val.shape[1] == d, 'Expected x_val to be [-1, %d]' % d
        assert y_train.ndim == 2, 'Expected training labels to be of rank 2'
        assert y_train.shape[1] == L, 'Expected y_train to be [-1, %d]' % L
        assert y_val.ndim == 2, 'Expected validation labels to be of rank 2'
        assert y_val.shape[1] == L, 'Expected y_val to be [-1, %d]' % L

        # The NumPy array arguments are covered by the asserts above; the
        # session needs an explicit check.
        if sess is None:
            raise ValueError('sess must be valid Tensorflow session.')

        trainNumBatches = int(np.ceil(len(x_train) / batchSize))
        valNumBatches = int(np.ceil(len(x_val) / batchSize))
        x_train_batches = np.array_split(x_train, trainNumBatches)
        y_train_batches = np.array_split(y_train, trainNumBatches)
        x_val_batches = np.array_split(x_val, valNumBatches)
        y_val_batches = np.array_split(y_val, valNumBatches)
        if not noInit:
            sess.run(tf.global_variables_initializer())
        X, Y = self.X, self.Y
        W, B, Z, _ = self.protoNNObj.getModelMatrices()
        for epoch in range(totalEpochs):
            for i in range(len(x_train_batches)):
                batch_x = x_train_batches[i]
                batch_y = y_train_batches[i]
                feed_dict = {X: batch_x, Y: batch_y}
                sess.run(self.trainStep, feed_dict=feed_dict)
                if i % printStep == 0:
                    loss, acc = sess.run([self.loss, self.accuracy],
                                         feed_dict=feed_dict)
                    msg = "Epoch: %3d Batch: %3d" % (epoch, i)
                    msg += " Loss: %3.5f Accuracy: %2.5f" % (loss, acc)
                    print(msg, file=redirFile)

            # Perform Hard thresholding
            if self.sparseTraining:
                W_, B_, Z_ = sess.run([W, B, Z])
                fd_thrsd = {
                    self.W_th: utils.hardThreshold(W_, self.__sW),
                    self.B_th: utils.hardThreshold(B_, self.__sB),
                    self.Z_th: utils.hardThreshold(Z_, self.__sZ)
                }
                sess.run(self.__hthOp, feed_dict=fd_thrsd)

            if (epoch + 1) % valStep == 0:
                acc = 0.0
                loss = 0.0
                for j in range(len(x_val_batches)):
                    batch_x = x_val_batches[j]
                    batch_y = y_val_batches[j]
                    feed_dict = {X: batch_x, Y: batch_y}
                    acc_, loss_ = sess.run([self.accuracy, self.loss],
                                           feed_dict=feed_dict)
                    acc += acc_
                    loss += loss_
                acc /= len(y_val_batches)
                loss /= len(y_val_batches)
                print("Test Loss: %2.5f Accuracy: %2.5f" % (loss, acc))