コード例 #1
0
 def testAggregateH5FormatSaveLoad(self):
     """Checks that an AggregateFunction model survives an HDF5 save/load.

     Builds a premade AggregateFunction model and fits it on the fake training
     data. It then saves the model in HDF5 format without the optimizer and
     reloads it with the premade custom objects. Finally it asserts that the
     original and reloaded models produce the same predictions on the eval
     data.
     """
     # Local import: only needed to build the temporary file path below.
     import os
     model_config = configs.AggregateFunctionConfig(
         feature_configs=feature_configs,
         regularizer_configs=[
             configs.RegularizerConfig('calib_hessian', l2=1e-4),
             configs.RegularizerConfig('torsion', l2=1e-3),
         ],
         middle_calibration=True,
         middle_monotonicity='increasing',
         output_min=0.0,
         output_max=1.0,
         output_calibration=True,
         output_calibration_num_keypoints=8,
         output_initialization=[0.0, 1.0])
     model = premade.AggregateFunction(model_config)
     # Compile and fit model.
     model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(0.1))
     model.fit(fake_data['train_xs'], fake_data['train_ys'])
     # Save model using H5 format. Use a temporary directory rather than an
     # open NamedTemporaryFile: save_model/load_model open the file by path,
     # and on some platforms (e.g. Windows) a NamedTemporaryFile that is still
     # open cannot be opened a second time by name.
     with tempfile.TemporaryDirectory() as tmp_dir:
         model_path = os.path.join(tmp_dir, 'model.h5')
         # Note: because of naming clashes in the optimizer, we cannot include it
         # when saving in HDF5. The keras team has informed us that we should not
         # push to support this since SavedModel format is the new default and no
         # new HDF5 functionality is desired.
         tf.keras.models.save_model(model, model_path, include_optimizer=False)
         loaded_model = tf.keras.models.load_model(
             model_path, custom_objects=premade.get_custom_objects())
         self.assertAllClose(model.predict(fake_data['eval_xs']),
                             loaded_model.predict(fake_data['eval_xs']))
コード例 #2
0
 def testAggregateFromConfig(self):
   """Round-trips an AggregateFunction model through get_config/from_config.

   A model rebuilt with `from_config` must report a config that serializes
   to exactly the same JSON (sorted keys, `self.Encoder`) as the original's.
   """
   regularizers = [
       configs.RegularizerConfig('calib_hessian', l2=1e-4),
       configs.RegularizerConfig('torsion', l2=1e-3),
   ]
   aggregate_config = configs.AggregateFunctionConfig(
       feature_configs=feature_configs,
       regularizer_configs=regularizers,
       middle_calibration=True,
       middle_monotonicity='increasing',
       output_min=0.0,
       output_max=1.0,
       output_calibration=True,
       output_calibration_num_keypoints=8,
       output_initialization=[0.0, 1.0])
   original = premade.AggregateFunction(aggregate_config)
   rebuilt = premade.AggregateFunction.from_config(original.get_config())

   def to_json(some_model):
     # Fetch a fresh config each time; sorted keys make the comparison
     # independent of dict insertion order.
     return json.dumps(
         some_model.get_config(), sort_keys=True, cls=self.Encoder)

   self.assertEqual(to_json(original), to_json(rebuilt))