def validate_keras_model(serialized_model, weights, dataset_iterator, data_count):
    """Evaluate a serialized Keras model with the given weights.

    Args:
        serialized_model: Serialized model, rebuilt via ``model_from_serialized``.
        weights: Weights loaded with ``model.set_weights``; must not be None.
        dataset_iterator: Generator passed to ``model.evaluate_generator``.
        data_count: Passed as the ``steps`` argument of ``evaluate_generator``.

    Returns:
        ``{'val_metric': metrics}`` where ``metrics`` maps each name in
        ``model.metrics_names`` to its evaluated value.

    Raises:
        AssertionError: If ``weights`` is None (when assertions are enabled).
    """
    logging.info('Keras validation just started.')
    # `is not None` rather than `!= None`: `!=` against a NumPy array is
    # element-wise and would make the assert fail with an ambiguous-truth error.
    assert weights is not None, "weights must not be 'None'."
    model = model_from_serialized(serialized_model)
    model.set_weights(weights)
    # NOTE(review): training uses data_count // batch_size steps, whereas here
    # data_count itself is the step count -- confirm what each step of
    # dataset_iterator yields (one sample vs. one batch).
    history = model.evaluate_generator(dataset_iterator, steps=data_count)
    # Keras returns a bare scalar loss (not a list) when the model has a single
    # output and no metrics; wrap it so zip() can pair it with metrics_names.
    if not isinstance(history, (list, tuple)):
        history = [history]
    metrics = dict(zip(model.metrics_names, history))
    logging.info('Keras validation complete.')
    return {'val_metric': metrics}
def train_keras_model(serialized_model, weights, dataset_iterator, data_count,
                      hyperparams, config):
    """Train a serialized Keras model starting from the given weights.

    Args:
        serialized_model: Serialized model, rebuilt via ``model_from_serialized``.
        weights: Initial weights loaded with ``model.set_weights``; must not be None.
        dataset_iterator: Generator passed to ``model.fit_generator``.
        data_count: Number of items available; divided by
            ``hyperparams['batch_size']`` to get the steps per epoch.
        hyperparams: Mapping read for the keys ``'epochs'`` and ``'batch_size'``.
        config: Currently unused; kept for interface compatibility.

    Returns:
        A tuple ``(weights, {'training_history': hist.history})`` with the
        trained weights from ``model.get_weights()`` and the Keras history dict.

    Raises:
        AssertionError: If ``weights`` is None (when assertions are enabled).
    """
    logging.info('Keras training just started.')
    # `is not None` rather than `!= None`: `!=` against a NumPy array is
    # element-wise and would make the assert fail with an ambiguous-truth error.
    assert weights is not None, "Initial weights must not be 'None'."
    model = model_from_serialized(serialized_model)
    model.set_weights(weights)
    steps_per_epoch = data_count // hyperparams['batch_size']
    if data_count > 0 and steps_per_epoch == 0:
        # Fewer items than one batch: floor division would request 0 steps and
        # silently skip training, so run a single (partial) step instead.
        steps_per_epoch = 1
    hist = model.fit_generator(
        dataset_iterator,
        epochs=hyperparams['epochs'],
        steps_per_epoch=steps_per_epoch,
    )
    weights = model.get_weights()
    logging.info('Keras training complete.')
    return weights, {'training_history': hist.history}
def _initialize(self, job):
    """Initialize a model and return its initial weights in a DMLResult.

    Args:
        job: Object read for ``framework_type`` (only ``'keras'`` is
            supported) and ``serialized_model``.

    Returns:
        A ``DMLResult`` with ``status='successful'``, the given ``job``,
        ``results={'weights': <model.get_weights()>}`` and an empty
        ``error_message``.

    Raises:
        AssertionError: If ``job.framework_type`` is not supported.
    """
    # Explicit check instead of `assert`: an assert is stripped under
    # `python -O`, which would let unsupported frameworks fall through with
    # `initial_weights` never assigned. AssertionError is kept so callers
    # that already catch it see the same exception type and message.
    if job.framework_type not in ['keras']:
        raise AssertionError(
            "Model type '{0}' is not supported.".format(job.framework_type))
    logging.info("Initializing model...")
    model = model_from_serialized(job.serialized_model)
    initial_weights = model.get_weights()
    return DMLResult(
        status='successful',
        job=job,
        results={
            'weights': initial_weights,
        },
        error_message="",
    )