def predict(self, X):
    """Run batched inference on ``X`` using the model restored by ``self.model.load``.

    A new ``tf.Session`` is opened on ``self.model.graph`` and the model is
    restored through ``self.model.load(session, saver)``. ``X`` is passed
    through ``prepare_dataset`` and fed to ``self.model.model_graph.predict``
    in chunks of ``self.model.batch_size`` rows. The last chunk is
    zero-padded up to ``batch_size`` rows, and the predictions for the padding
    rows are discarded. The padding presumably exists because the graph
    expects a fixed batch dimension; TODO confirm.

    Args:
        X: Raw samples accepted by ``prepare_dataset``. The prepared result
            must support ``len``, slicing along axis 0, ``.shape`` and
            ``.dtype`` (e.g. a NumPy array).

    Returns:
        ``np.vstack`` of the first output of every ``predict`` call, with the
        padding rows removed. Returns ``None`` if ``self.model.load`` reports
        failure.
    """
    with tf.Session(graph=self.model.graph) as session:
        saver = tf.train.Saver()
        if not self.model.load(session, saver):
            # No restored weights to predict with; signal this with None.
            return None

        # Pass the session by keyword: positionally, Tensor.eval's first
        # parameter is feed_dict (Variable.eval accepts either form).
        num_epochs_trained = self.model.model_graph.cur_epoch_tensor.eval(
            session=session)
        print('EPOCHS trained: ', num_epochs_trained)

        X_ = prepare_dataset(X)
        # Bound the loop by the data actually being sliced, not by the raw input.
        n_samples = len(X_)
        batch_size = self.model.batch_size
        y_l = []

        # Feed full batches while more than one batch's worth remains. The
        # final chunk (which may itself be a full batch, or empty when there
        # are no samples) is handled below so it can be padded.
        start = 0
        while start + batch_size < n_samples:
            end = start + batch_size
            print('from {} to {}'.format(start, end))
            y_pred = self.model.model_graph.predict(session, X_[start:end])
            y_l.append(np.array(y_pred[0]))
            start = end

        x = X_[start:]
        xsize = len(x)
        print('from {} to {}'.format(start, n_samples))

        # Pad with zeros of the same dtype as the data, so the last batch is
        # fed with the same dtype as the full batches (np.zeros would
        # otherwise default to float64 and upcast the concatenation).
        padding = np.zeros([batch_size - xsize] + list(x.shape[1:]),
                           dtype=x.dtype)
        y_pred = self.model.model_graph.predict(
            session, np.concatenate((x, padding), axis=0))
        # Keep only the rows that correspond to real samples.
        y_l.append(np.array(y_pred[0][:xsize]))

    return np.vstack(y_l)
示例#2
0
 def reconst_loss(self, inputs):
     """Delegate to ``self.model.reconst_loss`` after preprocessing ``inputs``.

     ``inputs`` is first transformed with ``utils.prepare_dataset``. The
     model's return value is passed back unchanged.
     """
     prepared_inputs = utils.prepare_dataset(inputs)
     loss = self.model.reconst_loss(prepared_inputs)
     return loss
示例#3
0
 def interpolate(self, input1, input2):
     """Delegate to ``self.model.interpolate`` after preprocessing both inputs.

     ``input1`` and then ``input2`` are each transformed with
     ``utils.prepare_dataset``. The model's return value is passed back
     unchanged.
     """
     prepared = [utils.prepare_dataset(raw) for raw in (input1, input2)]
     return self.model.interpolate(*prepared)
示例#4
0
 def encode(self, inputs):
     """Delegate to ``self.model.encode`` after preprocessing ``inputs``.

     ``inputs`` is first transformed with ``utils.prepare_dataset``. The
     model's return value is passed back unchanged.
     """
     dataset = utils.prepare_dataset(inputs)
     encoded = self.model.encode(dataset)
     return encoded