def prepare_model(self):
    """Get the Keras model ready for training.

    Configures the Keras base directory, runs ``self.model`` through
    ``deserialize_keras_model`` and stores the result back on
    ``self.model``, then compiles it with ``self.loss``, ``self.optimizer``
    and ``self.metrics``.
    """
    set_keras_base_directory()
    serialized = self.model
    self.model = deserialize_keras_model(serialized)
    compile_kwargs = {
        "loss": self.loss,
        "optimizer": self.optimizer,
        "metrics": self.metrics,
    }
    self.model.compile(**compile_kwargs)
def __init__(self, keras_model, loss, worker_optimizer):
    """Initialize the trainer.

    Args:
        keras_model: Keras model; passed through ``serialize_keras_model``
            and kept as ``self.master_model``.
        loss: loss, stored as ``self.loss``.
        worker_optimizer: optimizer, stored as ``self.worker_optimizer``.
    """
    set_keras_base_directory()
    self.master_model = serialize_keras_model(keras_model)
    self.loss, self.worker_optimizer = loss, worker_optimizer
    self.history = []
    # Timing bookkeeping; every counter starts at zero.
    self.training_time_start = self.training_time_end = self.training_time = 0
def __init__(self, keras_model, loss, worker_optimizer, metrics=("accuracy",)):
    """Initialize the trainer.

    Args:
        keras_model: Keras model; passed through ``serialize_keras_model``
            and kept as ``self.master_model``.
        loss: loss, stored as ``self.loss``.
        worker_optimizer: optimizer, stored as ``self.worker_optimizer``.
        metrics: metrics, stored as ``self.metrics``. Defaults to
            ``["accuracy"]``. A tuple is converted into a new list for this
            instance; any other value (list, dict, ``None``) is stored as-is.
    """
    set_keras_base_directory()
    self.master_model = serialize_keras_model(keras_model)
    self.loss = loss
    self.worker_optimizer = worker_optimizer
    # An immutable tuple default avoids one mutable list being shared by
    # every instance; it is turned into a list because Keras' ``compile``
    # documents ``metrics`` as a list (or dict).
    self.metrics = list(metrics) if isinstance(metrics, tuple) else metrics
    self.history = []
    self.training_time_start = 0
    self.training_time_end = 0
    self.training_time = 0
    # NOTE(review): presumably caps how many mini-batches are prefetched —
    # confirm against the code that reads it.
    self.max_mini_batches_prefetch = 100
def __init__(self, keras_model, loss, worker_optimizer, metrics=("accuracy",), loss_weights=None):
    """Initialize the trainer.

    Args:
        keras_model: Keras model; passed through ``serialize_keras_model``
            and kept as ``self.master_model``.
        loss: loss, stored as ``self.loss``.
        worker_optimizer: optimizer, stored as ``self.worker_optimizer``.
        metrics: metrics, stored as ``self.metrics``. Defaults to
            ``["accuracy"]``. A tuple is converted into a new list for this
            instance; any other value (list, dict, ``None``) is stored as-is.
        loss_weights: optional loss weights, stored as ``self.loss_weights``.
    """
    set_keras_base_directory()
    self.master_model = serialize_keras_model(keras_model)
    self.loss = loss
    self.loss_weights = loss_weights
    self.worker_optimizer = worker_optimizer
    # An immutable tuple default avoids one mutable list being shared by
    # every instance; it is turned into a list because Keras' ``compile``
    # documents ``metrics`` as a list (or dict).
    self.metrics = list(metrics) if isinstance(metrics, tuple) else metrics
    self.history = []
    self.training_time_start = 0
    self.training_time_end = 0
    self.training_time = 0
    # NOTE(review): presumably caps how many mini-batches are prefetched —
    # confirm against the code that reads it.
    self.max_mini_batches_prefetch = 100
def prepare_model(self):
    """Get the Keras model (and its optimizer) ready for training.

    In order:
      1. configure the Keras base directory;
      2. on the TensorFlow backend, install a new session whose GPU option
         ``allow_growth`` is ``False``;
      3. deserialize ``self.model`` and ``self.optimizer`` in place;
      4. compile the model with loss, loss weights, optimizer and metrics.
    """
    set_keras_base_directory()
    # The session is installed before deserialization, presumably so the
    # model is built under it — keep this ordering.
    if K.backend() == 'tensorflow':
        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = False
        K.set_session(tf.Session(config=session_config))
    self.model = deserialize_keras_model(self.model)
    self.optimizer = deserialize(self.optimizer)
    self.model.compile(
        loss=self.loss,
        loss_weights=self.loss_weights,
        optimizer=self.optimizer,
        metrics=self.metrics,
    )