Beispiel #1
0
    def train(self, data_iterator):
        """Train a keras model on a worker.

        Rebuilds the model from ``self.yaml``, loads the weights held in
        ``self.parameters.value`` and fits the model on the samples of
        ``data_iterator``.

        :param data_iterator: iterable of ``(features, label)`` pairs
        :return: generator yielding exactly one item, the result of
            ``subtract_params(weights_before, weights_after)``
        """
        self.model = model_from_yaml(self.yaml, self.custom_objects)
        # A single compile is sufficient: the second compile that used to
        # follow the data loading passed the very same arguments.
        self.model.compile(optimizer=get_optimizer(self.master_optimizer),
                           loss=self.master_loss,
                           metrics=self.master_metrics)
        self.model.set_weights(self.parameters.value)

        # Materialize the partition once, then split it into features/labels.
        samples = list(data_iterator)
        x_train = np.asarray([x for x, _ in samples])
        y_train = np.asarray([y for _, y in samples])

        weights_before_training = self.model.get_weights()
        # Fit only when the partition holds more samples than one batch;
        # otherwise the weights are left untouched and the deltas are computed
        # from two identical weight lists.
        if x_train.shape[0] > self.train_config.get('batch_size'):
            self.model.fit(x_train, y_train, **self.train_config)
        weights_after_training = self.model.get_weights()
        deltas = subtract_params(weights_before_training,
                                 weights_after_training)
        yield deltas
Beispiel #2
0
    def _fit(self, df):
        """Fit the configured Keras model on ``df`` and return a transformer.

        The data frame is converted to a simple RDD, repartitioned to the
        configured number of workers and trained through ``SparkModel``; the
        resulting weights are broadcast and handed to ``ElephasTransformer``.
        """
        rdd_options = dict(categorical=self.get_categorical_labels(),
                           nb_classes=self.get_nb_classes(),
                           features_col=self.getFeaturesCol(),
                           label_col=self.getLabelCol())
        partitioned_rdd = df_to_simple_rdd(df, **rdd_options).repartition(
            self.get_num_workers())

        # Rebuild and compile the network from its serialized configuration.
        network = model_from_yaml(self.get_keras_model_config())
        metrics = self.get_metrics()
        loss = self.get_loss()
        optimizer = get_optimizer(self.get_optimizer_config())
        network.compile(loss=loss, optimizer=optimizer, metrics=metrics)

        distributed_model = SparkModel(model=network,
                                       mode=self.get_mode(),
                                       frequency=self.get_frequency(),
                                       num_workers=self.get_num_workers())
        fit_options = dict(epochs=self.get_epochs(),
                           batch_size=self.get_batch_size(),
                           verbose=self.get_verbosity(),
                           validation_split=self.get_validation_split())
        distributed_model.fit(partitioned_rdd, **fit_options)

        # Broadcast the trained weights through the RDD's Spark context.
        trained_weights = distributed_model.master_network.get_weights()
        broadcast_weights = partitioned_rdd.ctx.broadcast(trained_weights)
        transformer_options = dict(
            labelCol=self.getLabelCol(),
            outputCol='prediction',
            keras_model_config=distributed_model.master_network.to_yaml(),
            weights=broadcast_weights,
            loss=loss)
        return ElephasTransformer(**transformer_options)
Beispiel #3
0
    def train(self, data_iterator):
        """Train a keras model on a worker and send asynchronous updates
        to parameter server.

        Depending on ``self.frequency`` the worker exchanges weights with the
        parameter server (through ``self.client``) once per epoch or once per
        batch.

        :param data_iterator: iterable of ``(features, label)`` pairs
        :return: generator; yields nothing for an empty partition, otherwise a
            single empty list once training is done
        :raises ValueError: if ``self.frequency`` is neither ``'epoch'`` nor
            ``'batch'``
        """
        # Materialize the partition once, then split it into features/labels.
        samples = list(data_iterator)
        x_train = np.asarray([x for x, _ in samples])
        y_train = np.asarray([y for _, y in samples])

        if x_train.size == 0:
            return

        self.model = model_from_yaml(self.yaml, self.custom_objects)
        self.model.compile(optimizer=get_optimizer(self.master_optimizer),
                           loss=self.master_loss,
                           metrics=self.master_metrics)
        self.model.set_weights(self.parameters.value)

        epochs = self.train_config['epochs']
        batch_size = self.train_config.get('batch_size')
        nb_train_sample = x_train.shape[0]
        nb_batch = int(np.ceil(nb_train_sample / float(batch_size)))
        index_array = np.arange(nb_train_sample)
        batches = [(i * batch_size, min(nb_train_sample, (i + 1) * batch_size))
                   for i in range(0, nb_batch)]
        # Loop invariant: training only happens when the partition holds more
        # samples than a single batch.
        enough_samples = nb_train_sample > batch_size

        if self.frequency == 'epoch':
            # Fit one epoch at a time using a copy of the configuration, so
            # the shared ``self.train_config`` is never mutated (and cannot be
            # left in a modified state if ``fit`` raises).
            single_epoch_config = dict(self.train_config, epochs=1)
            for _ in range(epochs):
                weights_before_training = self.client.get_parameters()
                self.model.set_weights(weights_before_training)
                if enough_samples:
                    self.model.fit(x_train, y_train, **single_epoch_config)
                weights_after_training = self.model.get_weights()
                deltas = subtract_params(weights_before_training,
                                         weights_after_training)
                self.client.update_parameters(deltas)
        elif self.frequency == 'batch':
            if enough_samples:
                for _ in range(epochs):
                    for batch_start, batch_end in batches:
                        # Pull the latest server weights before every batch.
                        weights_before_training = self.client.get_parameters()
                        self.model.set_weights(weights_before_training)
                        batch_ids = index_array[batch_start:batch_end]
                        x = slice_arrays(x_train, batch_ids)
                        y = slice_arrays(y_train, batch_ids)
                        self.model.train_on_batch(x, y)
                        weights_after_training = self.model.get_weights()
                        deltas = subtract_params(weights_before_training,
                                                 weights_after_training)
                        self.client.update_parameters(deltas)
        else:
            raise ValueError(
                'frequency parameter can be `epoch` or `batch`, got {}'.format(
                    self.frequency))
        yield []
Beispiel #4
0
    def _fit(self, rdd: RDD, **kwargs):
        """Protected train method to make wrapping of modes easier.

        Compiles the master network, distributes training over ``rdd`` and
        writes the resulting parameters back into the master network.

        :param rdd: RDD whose partitions are passed to the workers' ``train``
        :param kwargs: training configuration handed to the workers unchanged
        :raises ValueError: if ``self.mode`` is not one of ``'asynchronous'``,
            ``'hogwild'`` or ``'synchronous'``
        """
        self._master_network.compile(
            optimizer=get_optimizer(self.master_optimizer),
            loss=self.master_loss,
            metrics=self.master_metrics)
        uses_server = self.mode in ['asynchronous', 'hogwild']
        if uses_server:
            self.start_server()
        # From here on the server (if any) is running; the ``finally`` clause
        # guarantees it is stopped even when training raises.
        try:
            train_config = kwargs
            freq = self.frequency
            optimizer = deserialize_optimizer(self.master_optimizer)
            loss = self.master_loss
            metrics = self.master_metrics
            custom = self.custom_objects

            yaml = self._master_network.to_yaml()
            init = self._master_network.get_weights()
            parameters = rdd.context.broadcast(init)

            if uses_server:
                print('>>> Initialize workers')
                worker = AsynchronousSparkWorker(yaml, parameters, self.client,
                                                 train_config, freq, optimizer,
                                                 loss, metrics, custom)
                print('>>> Distribute load')
                rdd.mapPartitions(worker.train).collect()
                print('>>> Async training complete.')
                new_parameters = self.client.get_parameters()
            elif self.mode == 'synchronous':
                worker = SparkWorker(yaml, parameters, train_config, optimizer,
                                     loss, metrics, custom)
                training_outcomes = rdd.mapPartitions(worker.train).collect()
                new_parameters = self._master_network.get_weights()
                number_of_sub_models = len(training_outcomes)
                for grad, history in training_outcomes:
                    self.training_histories.append(history)
                    # Each outcome contributes its deltas scaled by the number
                    # of collected outcomes.
                    weighted_grad = divide_by(grad, number_of_sub_models)
                    new_parameters = subtract_params(new_parameters,
                                                     weighted_grad)
                print('>>> Synchronous training complete.')
            else:
                raise ValueError("Unsupported mode {}".format(self.mode))
            self._master_network.set_weights(new_parameters)
        finally:
            if uses_server:
                self.stop_server()
    def _fit(self, rdd, **kwargs):
        """Protected train method to make wrapping of modes easier.

        Compiles the master network, distributes training over ``rdd`` and
        writes the resulting parameters back into the master network.

        :param rdd: RDD whose partitions are passed to the workers' ``train``
        :param kwargs: training configuration handed to the workers unchanged
        :raises ValueError: if ``self.mode`` is not one of ``'asynchronous'``,
            ``'hogwild'`` or ``'synchronous'``
        """
        self._master_network.compile(
            optimizer=get_optimizer(self.master_optimizer),
            loss=self.master_loss,
            metrics=self.master_metrics)
        uses_server = self.mode in ['asynchronous', 'hogwild']
        if uses_server:
            self.start_server()
        # From here on the server (if any) is running; the ``finally`` clause
        # guarantees it is stopped even when training raises.
        try:
            train_config = kwargs
            freq = self.frequency
            optimizer = self.master_optimizer
            loss = self.master_loss
            metrics = self.master_metrics
            custom = self.custom_objects

            yaml = self._master_network.to_yaml()
            init = self._master_network.get_weights()
            parameters = rdd.context.broadcast(init)

            if uses_server:
                print('>>> Initialize workers')
                worker = AsynchronousSparkWorker(
                    yaml, parameters, self.client, train_config, freq,
                    optimizer, loss, metrics, custom)
                print('>>> Distribute load')
                rdd.mapPartitions(worker.train).collect()
                print('>>> Async training complete.')
                new_parameters = self.client.get_parameters()
            elif self.mode == 'synchronous':
                worker = SparkWorker(yaml, parameters, train_config,
                                     optimizer, loss, metrics, custom)
                gradients = rdd.mapPartitions(worker.train).collect()
                new_parameters = self._master_network.get_weights()
                # Simply accumulate the collected deltas one by one.
                for grad in gradients:
                    new_parameters = subtract_params(new_parameters, grad)
                print('>>> Synchronous training complete.')
            else:
                raise ValueError("Unsupported mode {}".format(self.mode))
            self._master_network.set_weights(new_parameters)
        finally:
            if uses_server:
                self.stop_server()
Beispiel #6
0
    def _fit(self, df: DataFrame):
        """Fit the configured Keras model on ``df`` and return a transformer.

        The data frame is converted to a simple RDD, repartitioned to the
        configured number of workers and trained through ``SparkModel``; the
        trained weights and the recorded training histories are handed to
        ``ElephasTransformer``.
        """
        rdd_options = dict(categorical=self.get_categorical_labels(),
                           nb_classes=self.get_nb_classes(),
                           features_col=self.getFeaturesCol(),
                           label_col=self.getLabelCol())
        partitioned_rdd = df_to_simple_rdd(df, **rdd_options).repartition(
            self.get_num_workers())

        # Rebuild and compile the network from its serialized configuration.
        network = model_from_yaml(self.get_keras_model_config(),
                                  self.get_custom_objects())
        metrics = self.get_metrics()
        loss = self.get_loss()
        optimizer = get_optimizer(self.get_optimizer_config())
        network.compile(loss=loss, optimizer=optimizer, metrics=metrics)

        model_options = dict(model=network,
                             mode=self.get_mode(),
                             frequency=self.get_frequency(),
                             num_workers=self.get_num_workers(),
                             custom_objects=self.get_custom_objects())
        distributed_model = SparkModel(**model_options)
        fit_options = dict(epochs=self.get_epochs(),
                           batch_size=self.get_batch_size(),
                           verbose=self.get_verbosity(),
                           validation_split=self.get_validation_split())
        distributed_model.fit(partitioned_rdd, **fit_options)

        trained_weights = distributed_model.master_network.get_weights()
        transformer_options = dict(
            labelCol=self.getLabelCol(),
            outputCol=self.getOutputCol(),
            featuresCol=self.getFeaturesCol(),
            keras_model_config=distributed_model.master_network.to_yaml(),
            weights=trained_weights,
            custom_objects=self.get_custom_objects(),
            model_type=LossModelTypeMapper().get_model_type(loss),
            history=distributed_model.training_histories)
        return ElephasTransformer(**transformer_options)