def test_estimator_keras_xshards_checkpoint(self):
        """Checkpoint a Keras-based Estimator trained on XShards, then resume it.

        Trains for 6 epochs with a checkpoint every 4 iterations
        (``SeveralIteration(4)``) and evaluates. It then rebuilds the model in
        a fresh default graph, restores the latest checkpoint from
        ``model_dir``, trains again for 10 epochs and evaluates once more.
        Evaluation results are printed, not asserted.

        The temporary directory holding the checkpoints is removed even when
        a step fails, so failing runs do not leak files on disk.
        """
        import zoo.orca.data.pandas

        tf.reset_default_graph()

        model = self.create_model()
        file_path = os.path.join(self.resource_path, "orca/learn/ncf.csv")
        data_shard = zoo.orca.data.pandas.read_csv(file_path)

        def transform(df):
            # Each input column is reshaped to (n, 1); presumably one array per
            # model input (user, item) -- TODO confirm against create_model.
            result = {
                "x": (df['user'].to_numpy().reshape([-1, 1]),
                      df['item'].to_numpy().reshape([-1, 1])),
                "y":
                df['label'].to_numpy()
            }
            return result

        data_shard = data_shard.transform_shard(transform)

        temp = tempfile.mkdtemp()
        try:
            model_dir = os.path.join(temp, "test_model")

            est = Estimator.from_keras(keras_model=model, model_dir=model_dir)
            est.fit(data=data_shard,
                    batch_size=8,
                    epochs=6,
                    validation_data=data_shard,
                    checkpoint_trigger=SeveralIteration(4))

            eval_result = est.evaluate(data_shard)
            print(eval_result)

            # Rebuild the model in a clean graph so the second estimator starts
            # from the restored checkpoint rather than the in-memory weights.
            tf.reset_default_graph()

            model = self.create_model()

            est = Estimator.from_keras(keras_model=model, model_dir=model_dir)
            est.load_latest_checkpoint(model_dir)
            est.fit(data=data_shard,
                    batch_size=8,
                    epochs=10,
                    validation_data=data_shard,
                    checkpoint_trigger=SeveralIteration(4))

            eval_result = est.evaluate(data_shard)
            print(eval_result)
        finally:
            # Always clean up; ignore_errors keeps a cleanup failure from
            # masking the exception that caused the test to fail.
            shutil.rmtree(temp, ignore_errors=True)
    def test_estimator_graph_checkpoint(self):
        """Checkpoint a graph-based Estimator trained on XShards, then resume it.

        Trains ``SimpleModel`` for 6 epochs with a checkpoint every 4
        iterations. It then rebuilds the model in a fresh default graph,
        restores the latest checkpoint from ``model_dir`` and trains for 10
        more epochs. Finally it checks that evaluation reports a ``"loss"``
        entry.

        Each estimator's ``sess`` is closed and the temporary checkpoint
        directory is removed even if a step fails. Previously the second
        session was never closed, and a failure leaked the directory.
        """
        import zoo.orca.data.pandas
        tf.reset_default_graph()

        model = SimpleModel()
        file_path = os.path.join(resource_path, "orca/learn/ncf.csv")
        data_shard = zoo.orca.data.pandas.read_csv(file_path)

        def transform(df):
            # User and item columns are passed as raw column arrays (no reshape).
            result = {
                "x": (df['user'].to_numpy(), df['item'].to_numpy()),
                "y": df['label'].to_numpy()
            }
            return result

        data_shard = data_shard.transform_shard(transform)

        temp = tempfile.mkdtemp()
        try:
            model_dir = os.path.join(temp, "test_model")

            est = Estimator.from_graph(
                inputs=[model.user, model.item],
                labels=[model.label],
                loss=model.loss,
                optimizer=tf.train.AdamOptimizer(),
                metrics={"loss": model.loss},
                model_dir=model_dir
            )
            try:
                est.fit(data=data_shard,
                        batch_size=8,
                        epochs=6,
                        validation_data=data_shard,
                        checkpoint_trigger=SeveralIteration(4))
            finally:
                est.sess.close()

            # Rebuild the model in a clean graph so training resumes from the
            # checkpoint written above, not from the first session's state.
            tf.reset_default_graph()

            model = SimpleModel()

            est = Estimator.from_graph(
                inputs=[model.user, model.item],
                labels=[model.label],
                loss=model.loss,
                optimizer=tf.train.AdamOptimizer(),
                metrics={"loss": model.loss},
                model_dir=model_dir
            )
            try:
                est.load_latest_checkpoint(model_dir)

                est.fit(data=data_shard,
                        batch_size=8,
                        epochs=10,
                        validation_data=data_shard)

                result = est.evaluate(data_shard)
                assert "loss" in result
                print(result)
            finally:
                est.sess.close()
        finally:
            # ignore_errors keeps a cleanup failure from masking the real error.
            shutil.rmtree(temp, ignore_errors=True)
Пример #3
0
 def __init__(self, interval):
     """Store a BigDL ``SeveralIteration`` trigger built from ``interval``.

     :param interval: forwarded unchanged to
         ``bigdl.optim.optimizer.SeveralIteration``; presumably the number
         of iterations between firings -- confirm against the BigDL docs.
     """
     # Function-local import: bigdl is only imported when an instance is
     # constructed. NOTE(review): inside this method the local name also
     # shadows any module-level ``SeveralIteration`` -- confirm that is intended.
     from bigdl.optim.optimizer import SeveralIteration
     # The wrapped BigDL trigger object is kept on ``self.trigger``.
     self.trigger = SeveralIteration(interval)