Example #1
0
    def fit(self, X, y):
        """Build a fresh TensorFlow graph for this estimator and train it on (X, y).

        Every call creates a new ``tf.Graph`` and a new ``tf.Session``; the
        placeholders, model tensors, trainer and session from any previous
        ``fit`` are replaced on ``self``.

        Args:
            X: Training inputs. Must expose ``.shape``; every dimension after
                the first becomes part of the input placeholder's static shape
                (the batch dimension is left as ``None``).
            y: Training targets, handed to ``DataFeeder`` together with ``X``.

        Returns:
            self, so calls can be chained (``est.fit(X, y).predict(X)``).
        """
        # A fresh graph per call, so refitting never mixes ops from an earlier
        # fit into the new model.
        with tf.Graph().as_default():
            # Graph-level seed so initialization is reproducible across reruns.
            tf.set_random_seed(self.tf_random_seed)
            self._global_step = tf.Variable(0, name="global_step", trainable=False)

            # Input placeholder: batch dimension left open, remaining
            # dimensions taken from the training data.
            input_shape = [None] + list(X.shape[1:])
            self._inp = tf.placeholder(tf.float32, input_shape, name="input")
            # Output placeholder: 1-D when n_classes < 2 (presumably the
            # regression case), otherwise one column per class.
            output_shape = [None] if self.n_classes < 2 else [None, self.n_classes]
            self._out = tf.placeholder(tf.float32, output_shape, name="output")

            # Create model's graph.
            self._model_predictions, self._model_loss = self.model_fn(
                self._inp, self._out)

            # Create data feeder, to sample inputs from dataset.
            self._data_feeder = DataFeeder(X, y, self.n_classes, self.batch_size)

            # Create trainer and augment graph with gradients and optimizer.
            self._trainer = TensorFlowTrainer(
                self._model_loss, self._global_step,
                self.optimizer, self.learning_rate)
            # Created inside the `with` block so the session launches this
            # graph (tf.Session uses the default graph when none is passed).
            self._session = tf.Session(self.tf_master)

            # Initialize and train model.
            self._trainer.initialize(self._session)
            feed_dict_fn = self._data_feeder.get_feed_dict_fn(self._inp, self._out)
            self._trainer.train(self._session, feed_dict_fn, self.steps)
        # Returning self follows the scikit-learn estimator contract.
        return self
Example #2
0
class TensorFlowEstimator(BaseEstimator):
    """Base class for all TensorFlow estimators.

    Parameters:
        model_fn: Model function that takes input X and y tensors and returns
                  a (predictions, loss) pair of tensors.
        n_classes: Number of classes in the target. Values below 2 make the
                   output placeholder 1-D and ``predict`` return the raw
                   predictions (presumably the regression case); otherwise
                   ``predict`` returns the argmax over axis 1.
        tf_master: TensorFlow master. Empty string is default for local.
        batch_size: Mini batch size.
        steps: Number of steps to run over data.
        optimizer: Optimizer name (or class), for example "SGD", "Adam",
                   "Adagrad".
        learning_rate: Learning rate for optimizer.
        tf_random_seed: Random seed for TensorFlow initializers.
            Setting this value allows consistency between reruns.
    """

    def __init__(self, model_fn, n_classes, tf_master="", batch_size=32, steps=50, optimizer="SGD",
                 learning_rate=0.1, tf_random_seed=42):
        # Constructor arguments are stored verbatim under their own names;
        # presumably BaseEstimator (scikit-learn) introspects them for
        # get_params/set_params, so no validation or conversion happens here.
        self.n_classes = n_classes
        self.tf_master = tf_master
        self.batch_size = batch_size
        self.steps = steps
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.tf_random_seed = tf_random_seed
        self.model_fn = model_fn

    def fit(self, X, y):
        """Build a fresh TensorFlow graph for this estimator and train it on (X, y).

        Every call creates a new ``tf.Graph`` and a new ``tf.Session``; the
        placeholders, model tensors, trainer and session from any previous
        ``fit`` are replaced on ``self``.

        Args:
            X: Training inputs. Must expose ``.shape``; every dimension after
                the first becomes part of the input placeholder's static shape
                (the batch dimension is left as ``None``).
            y: Training targets, handed to ``DataFeeder`` together with ``X``.

        Returns:
            self, so calls can be chained (``est.fit(X, y).predict(X)``).
        """
        # A fresh graph per call, so refitting never mixes ops from an earlier
        # fit into the new model.
        with tf.Graph().as_default():
            # Graph-level seed so initialization is reproducible across reruns.
            tf.set_random_seed(self.tf_random_seed)
            self._global_step = tf.Variable(0, name="global_step", trainable=False)

            # Input placeholder: batch dimension left open, remaining
            # dimensions taken from the training data.
            input_shape = [None] + list(X.shape[1:])
            self._inp = tf.placeholder(tf.float32, input_shape, name="input")
            # Output placeholder: 1-D when n_classes < 2 (presumably the
            # regression case), otherwise one column per class.
            output_shape = [None] if self.n_classes < 2 else [None, self.n_classes]
            self._out = tf.placeholder(tf.float32, output_shape, name="output")

            # Create model's graph.
            self._model_predictions, self._model_loss = self.model_fn(
                self._inp, self._out)

            # Create data feeder, to sample inputs from dataset.
            self._data_feeder = DataFeeder(X, y, self.n_classes, self.batch_size)

            # Create trainer and augment graph with gradients and optimizer.
            self._trainer = TensorFlowTrainer(
                self._model_loss, self._global_step,
                self.optimizer, self.learning_rate)
            # Created inside the `with` block so the session launches this
            # graph (tf.Session uses the default graph when none is passed).
            self._session = tf.Session(self.tf_master)

            # Initialize and train model.
            self._trainer.initialize(self._session)
            feed_dict_fn = self._data_feeder.get_feed_dict_fn(self._inp, self._out)
            self._trainer.train(self._session, feed_dict_fn, self.steps)
        # Returning self follows the scikit-learn estimator contract.
        return self

    def predict(self, X):
        """Run the trained model on X.

        Args:
            X: Inputs fed into the ``input`` placeholder built by ``fit``.

        Returns:
            The raw model predictions when ``n_classes < 2``; otherwise the
            index of the highest-scoring column of each row (argmax over
            axis 1).

        Raises:
            AttributeError: If called before ``fit``.
        """
        # Same exception type as before, but with an actionable message
        # instead of a bare "no attribute '_session'".
        if not hasattr(self, "_session"):
            raise AttributeError(
                "This %s instance is not fitted yet; call 'fit' first."
                % type(self).__name__)
        pred = self._session.run(self._model_predictions,
                                 feed_dict={self._inp.name: X})
        if self.n_classes < 2:
            return pred
        return pred.argmax(axis=1)