def test_tf_model(self):
    """Check that the TensorFlow checkpoint reloads with weights equal to the live model.

    Calls UpdateOp.forward with an MSE loss computed under a GradientTape
    (through ``strategy.run`` when the active strategy is a MirroredStrategy),
    triggers ``BestModelSaver.on_epoch_end``, then loads 'tf_best_loss.h5'
    from ``self.save_dir`` into a freshly built model and compares the
    trainable variables of both models.
    """
    # NOTE(review): another `test_tf_model` is defined later in this file; if
    # both live in the same class the later one replaces this one - confirm.
    op = UpdateOp(model=self.tf_model, loss_name='loss')
    op.build("tf")

    def update():
        with tf.GradientTape(persistent=True) as tape:
            # The op receives the tape through the shared state dict.
            self.state['tape'] = tape
            prediction = fe.backend.feed_forward(self.tf_model, self.tf_input_data)
            mse = fe.backend.mean_squared_error(y_pred=prediction, y_true=self.tf_y)
            op.forward(data=mse, state=self.state)

    strategy = tf.distribute.get_strategy()
    if not isinstance(strategy, tf.distribute.MirroredStrategy):
        update()
    else:
        strategy.run(update, args=())

    saver = BestModelSaver(model=self.tf_model, save_dir=self.save_dir)
    saver.on_epoch_end(data=self.data)

    # Load the checkpoint into a model built without preset weights.
    checkpoint_path = os.path.join(self.save_dir, 'tf_best_loss.h5')
    reloaded = fe.build(model_fn=one_layer_model_without_weights, optimizer_fn='adam')
    fe.backend.load_model(reloaded, checkpoint_path)

    weights_match = is_equal(reloaded.trainable_variables, self.tf_model.trainable_variables)
    self.assertTrue(weights_match)
def test_torch_model(self):
    """Check that the PyTorch checkpoint reloads with weights equal to the live model.

    Calls UpdateOp.forward with an MSE loss, triggers
    ``BestModelSaver.on_epoch_end``, then loads 'torch_best_loss.pt' from
    ``self.save_dir`` into a freshly built model and compares the parameter
    lists of both models.
    """
    op = UpdateOp(model=self.torch_model, loss_name='loss')
    pred = fe.backend.feed_forward(self.torch_model, self.torch_input_data)
    loss = fe.backend.mean_squared_error(y_pred=pred, y_true=self.torch_y)
    # The return value of forward is not used by this test.
    op.forward(data=loss, state=self.state)

    bms = BestModelSaver(model=self.torch_model, save_dir=self.save_dir)
    bms.on_epoch_end(data=self.data)

    # Load the checkpoint into a model built without preset weights.
    m2 = fe.build(model_fn=MultiLayerTorchModelWithoutWeights, optimizer_fn='adam')
    fe.backend.load_model(m2, os.path.join(self.save_dir, 'torch_best_loss.pt'))

    self.assertTrue(is_equal(list(m2.parameters()), list(self.torch_model.parameters())))
def test_tf_model(self):
    """Check that the TensorFlow checkpoint reloads with weights equal to the live model.

    Calls UpdateOp.forward with an MSE loss computed under a GradientTape,
    triggers ``BestModelSaver.on_epoch_end``, then loads 'tf_best_loss.h5'
    from ``self.save_dir`` into a freshly built model and compares the
    trainable variables of both models.
    """
    # NOTE(review): an earlier `test_tf_model` is defined in this file; if both
    # live in the same class this definition replaces the earlier one - confirm.
    op = UpdateOp(model=self.tf_model, loss_name='loss')
    with tf.GradientTape(persistent=True) as tape:
        # The op receives the tape through the shared state dict.
        self.state['tape'] = tape
        pred = fe.backend.feed_forward(self.tf_model, self.tf_input_data)
        loss = fe.backend.mean_squared_error(y_pred=pred, y_true=self.tf_y)
        # The return value of forward is not used by this test.
        op.forward(data=loss, state=self.state)

    bms = BestModelSaver(model=self.tf_model, save_dir=self.save_dir)
    bms.on_epoch_end(data=self.data)

    # Load the checkpoint into a model built without preset weights.
    m2 = fe.build(model_fn=one_layer_model_without_weights, optimizer_fn='adam')
    fe.backend.load_model(m2, os.path.join(self.save_dir, 'tf_best_loss.h5'))

    self.assertTrue(is_equal(m2.trainable_variables, self.tf_model.trainable_variables))