Ejemplo n.º 1
0
 def _v1_single_metagraph_saved_model(self, use_resource):
   """Export a single-MetaGraph TF1 SavedModel computing start * v * local.

   Args:
     use_resource: if true, `v` is a ResourceVariable; otherwise it is a
       RefVariable. The "distractor" variable is a RefVariable either way.

   Returns:
     The directory the SavedModel was written to.
   """
   graph = ops.Graph()
   with graph.as_default():
     start = array_ops.placeholder(
         shape=[None], dtype=dtypes.float32, name="start")
     # "distractor" is saved in the checkpoint and hence used by the restore
     # function, but it is not part of the pruned signature function. It
     # checks node naming: names must stay consistent (ideally identical to
     # the node in the original GraphDef) so the resource manager finds the
     # right variable. It is created before `v` in both configurations.
     distractor = variables.RefVariable(-1., name="distractor")
     variable_cls = (resource_variable_ops.ResourceVariable if use_resource
                     else variables.RefVariable)
     v = variable_cls(3., name="v")
     local_var = variables.VariableV1(
         1.,
         collections=[ops.GraphKeys.LOCAL_VARIABLES],
         trainable=False,
         use_resource=True)
     product = start * v * local_var
     output = array_ops.identity(product, name="output")
     with session_lib.Session() as sess:
       sess.run([v.initializer, distractor.initializer, local_var.initializer])
       export_dir = os.path.join(
           self.get_temp_dir(), "saved_model", str(ops.uid()))
       simple_save.simple_save(
           sess,
           export_dir,
           inputs={"start": start},
           outputs={"output": output},
           legacy_init_op=local_var.initializer)
   return export_dir
Ejemplo n.º 2
0
  def _v1_nested_while_saved_model(self):
    """Export a TF1 SavedModel whose output comes from nested while loops.

    For every index in [0, loop_iterations] the outer loop adds the result of
    an inner loop that sums 0..index.

    Returns:
      The directory the SavedModel was written to.
    """
    graph = ops.Graph()
    with graph.as_default():

      def _inner_while(limit):
        inner_vars = [constant_op.constant(0), constant_op.constant(0)]
        inner_result = control_flow_ops.while_loop(
            lambda i, acc: i <= limit,
            lambda i, acc: (i + 1, acc + i),
            inner_vars)
        return inner_result[1]

      loop_iterations = array_ops.placeholder(
          name="loop_iterations", shape=[], dtype=dtypes.int32)
      outer_vars = [constant_op.constant(0), constant_op.constant(0)]
      outer_result = control_flow_ops.while_loop(
          lambda i, acc: i <= loop_iterations,
          lambda i, acc: (i + 1, acc + _inner_while(i)),
          outer_vars)
      output = outer_result[1]
      with session_lib.Session() as sess:
        export_dir = os.path.join(
            self.get_temp_dir(), "saved_model", str(ops.uid()))
        simple_save.simple_save(
            sess,
            export_dir,
            inputs={"loop_iterations": loop_iterations},
            outputs={"output": output})
    return export_dir
Ejemplo n.º 3
0
    def _model_with_defun(self):
        """Export a V1 SavedModel whose output is computed by nested Defuns.

        f calls g, which calls z, and each adds 1, so the "output" signature
        returns the int64 "start" input (shape [1]) plus 3.

        Returns:
            The directory the SavedModel was written to.
        """
        graph = ops.Graph()
        with graph.as_default():
            # Defun names are left unchanged; Defun presumably derives the
            # graph function names from these Python names.
            @framework_function.Defun(dtypes.int64)
            def z(x):
                return x + 1

            @framework_function.Defun(dtypes.int64)
            def g(x):
                return z(x) + 1

            @framework_function.Defun(dtypes.int64)
            def f(x):
                return g(x) + 1

            start = array_ops.placeholder(dtype=dtypes.int64, shape=[1])
            result = f(start)
            with session_lib.Session() as sess:
                export_dir = os.path.join(
                    self.get_temp_dir(), "saved_model", str(ops.uid()))
                simple_save.simple_save(
                    sess,
                    export_dir,
                    inputs={"start": start},
                    outputs={"output": result})
        return export_dir
Ejemplo n.º 4
0
 def _v1_single_metagraph_saved_model(self, use_resource):
   """Export a single-MetaGraph TF1 SavedModel with an unknown-shape input.

   The signature maps "start" (float32, shape unconstrained) to "output",
   computed as start * v * local_variable.

   Args:
     use_resource: if true, `v` is a ResourceVariable; otherwise it is a
       RefVariable. The "distractor" variable is a RefVariable either way.

   Returns:
     The directory the SavedModel was written to.
   """
   graph = ops.Graph()
   with graph.as_default():
     start = array_ops.placeholder(
         shape=None, dtype=dtypes.float32, name="start")
     # "distractor" ends up in the checkpoint (so the restore function uses
     # it) but not in the pruned signature function. This exercises node
     # naming: it must stay consistent, ideally matching the original
     # GraphDef node, for the resource manager to find the right variable.
     distractor = variables.RefVariable(-1., name="distractor")
     if use_resource:
       v = resource_variable_ops.ResourceVariable(3., name="v")
     else:
       v = variables.RefVariable(3., name="v")
     local_var = variables.VariableV1(
         1.,
         collections=[ops.GraphKeys.LOCAL_VARIABLES],
         trainable=False,
         use_resource=True)
     product = start * v * local_var
     output = array_ops.identity(product, name="output")
     with session_lib.Session() as sess:
       sess.run([v.initializer, distractor.initializer, local_var.initializer])
       export_dir = os.path.join(
           self.get_temp_dir(), "saved_model", str(ops.uid()))
       simple_save.simple_save(
           sess,
           export_dir,
           inputs={"start": start},
           outputs={"output": output},
           legacy_init_op=local_var.initializer)
   return export_dir
Ejemplo n.º 5
0
    def test_resave_signature(self):
        """A TF1 signature survives load + TF2 save + load (b/211666001)."""
        graph = ops.Graph()
        with graph.as_default():
            # Node names ("input_2"/"input_1") are the reverse of the
            # signature keys ("a"/"b"), presumably to check that arguments
            # are bound by signature key rather than by node name.
            input_a = array_ops.placeholder(
                shape=[None, 1], dtype=dtypes.float32, name="input_2")
            input_b = array_ops.placeholder(
                shape=[None, 2], dtype=dtypes.float32, name="input_1")
            output_c = array_ops.identity(input_a)
            with session_lib.Session() as sess:
                v1_dir = os.path.join(
                    self.get_temp_dir(), "saved_model", str(ops.uid()))
                simple_save.simple_save(
                    sess,
                    v1_dir,
                    inputs={"a": input_a, "b": input_b},
                    outputs={"c": output_c})

        first_load = load.load(v1_dir)
        v2_dir = os.path.join(
            self.get_temp_dir(), "saved_model", str(ops.uid()))
        save.save(first_load, v2_dir, first_load.signatures)

        second_load = load.load(v2_dir)
        serving_fn = second_load.signatures["serving_default"]
        result = serving_fn(a=constant_op.constant([5.]),
                            b=constant_op.constant([1., 3.]))
        self.assertEqual(result["c"].numpy(), 5.)
Ejemplo n.º 6
0
    def test_load_and_restore_partitioned_variables(self):
        """A loaded TF1 SavedModel restores partitioned vars from a ckpt."""
        graph = ops.Graph()
        with graph.as_default():
            # Six elements split into two partitions of three.
            partitioner = partitioned_variables.fixed_size_partitioner(2)
            var = variable_scope.get_variable(
                "a",
                shape=[6],
                initializer=init_ops.constant_initializer(13),
                partitioner=partitioner,
                use_resource=True)
            x = array_ops.placeholder(shape=[], dtype=dtypes.float32)
            y = x * var
            with session_lib.Session() as sess:
                sess.run(variables.global_variables_initializer())
                export_dir = os.path.join(
                    self.get_temp_dir(), "saved_model", str(ops.uid()))
                simple_save.simple_save(
                    sess, export_dir, inputs={"x": x}, outputs={"y": y})

                # Overwrite the values, then write a name-based checkpoint.
                sess.run(var.assign([[5, 4, 3], [2, 1, 0]]))
                ckpt_prefix = os.path.join(
                    self.get_temp_dir(), "restore_ckpt")
                saver.Saver().save(sess, ckpt_prefix)

        imported = load.load(export_dir)
        self.assertAllClose(self.evaluate(imported.variables),
                            [[13, 13, 13], [13, 13, 13]])

        self.evaluate(imported.restore(ckpt_prefix))
        self.assertAllClose(self.evaluate(imported.variables),
                            [[5, 4, 3], [2, 1, 0]])
        serving_fn = imported.signatures["serving_default"]
        self.assertAllClose(
            self.evaluate(serving_fn(constant_op.constant(2.))),
            {"y": [10, 8, 6, 4, 2, 0]})
Ejemplo n.º 7
0
 def _v1_asset_saved_model(self):
     """Export a TF1 SavedModel with a vocabulary-file lookup table.

     A temporary vocabulary file ("alpha", "beta", "gamma", one per line)
     initializes a HashTable that maps each whole line to its line number,
     with -1 for unknown keys. The exported signature feeds the string input
     "start" through the table to produce "output".

     The temporary vocabulary file is deleted before returning, and also when
     graph construction or export fails. This is presumably so that loading
     the model exercises the asset copy stored with the SavedModel rather
     than the original file.

     Returns:
         The directory the SavedModel was written to.
     """
     export_graph = ops.Graph()
     vocab_path = os.path.join(self.get_temp_dir(), "vocab.txt")
     with open(vocab_path, "w", encoding="utf-8") as f:
         f.write("alpha\nbeta\ngamma\n")
     try:
         with export_graph.as_default():
             initializer = lookup_ops.TextFileInitializer(
                 vocab_path,
                 key_dtype=dtypes.string,
                 key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
                 value_dtype=dtypes.int64,
                 value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
             table = lookup_ops.HashTable(initializer, default_value=-1)
             start = array_ops.placeholder(shape=None,
                                           dtype=dtypes.string,
                                           name="in")
             output = table.lookup(start, name="out")
             with session_lib.Session() as session:
                 session.run([table.initializer])
                 path = os.path.join(self.get_temp_dir(), "saved_model",
                                     str(ops.uid()))
                 simple_save.simple_save(session,
                                         path,
                                         inputs={"start": start},
                                         outputs={"output": output},
                                         legacy_init_op=table.initializer)
     finally:
         # Previously the file leaked whenever the export raised; remove the
         # source vocabulary on both the success and the failure path.
         file_io.delete_file(vocab_path)
     return path
Ejemplo n.º 8
0
    def _v1_nested_while_saved_model(self):
        """Export a TF1 SavedModel computed by a while loop inside a while loop.

        For each index in [0, loop_iterations] the outer loop accumulates the
        result of an inner loop that sums 0..index.

        Returns:
            The directory the SavedModel was written to.
        """
        graph = ops.Graph()
        with graph.as_default():

            def _inner_while(limit):
                inner_vars = [constant_op.constant(0),
                              constant_op.constant(0)]
                _, inner_sum = control_flow_ops.while_loop(
                    lambda i, acc: i <= limit,
                    lambda i, acc: (i + 1, acc + i),
                    inner_vars)
                return inner_sum

            loop_iterations = array_ops.placeholder(
                name="loop_iterations", shape=[], dtype=dtypes.int32)
            outer_vars = [constant_op.constant(0), constant_op.constant(0)]
            _, output = control_flow_ops.while_loop(
                lambda i, acc: i <= loop_iterations,
                lambda i, acc: (i + 1, acc + _inner_while(i)),
                outer_vars)
            with session_lib.Session() as sess:
                export_dir = os.path.join(
                    self.get_temp_dir(), "saved_model", str(ops.uid()))
                signature_inputs = {"loop_iterations": loop_iterations}
                signature_outputs = {"output": output}
                simple_save.simple_save(
                    sess,
                    export_dir,
                    inputs=signature_inputs,
                    outputs=signature_outputs)
        return export_dir
Ejemplo n.º 9
0
 def _v1_asset_saved_model(self):
   """Export a TF1 SavedModel with a vocabulary-file lookup table.

   A temporary vocabulary file ("alpha", "beta", "gamma", one per line)
   initializes a HashTable that maps each whole line to its line number, with
   -1 for unknown keys. The exported signature feeds the string input "start"
   through the table to produce "output".

   The temporary vocabulary file is deleted before returning, and also when
   graph construction or export fails. This is presumably so that loading the
   model exercises the asset copy stored with the SavedModel rather than the
   original file.

   Returns:
     The directory the SavedModel was written to.
   """
   export_graph = ops.Graph()
   vocab_path = os.path.join(self.get_temp_dir(), "vocab.txt")
   with open(vocab_path, "w", encoding="utf-8") as f:
     f.write("alpha\nbeta\ngamma\n")
   try:
     with export_graph.as_default():
       initializer = lookup_ops.TextFileInitializer(
           vocab_path,
           key_dtype=dtypes.string,
           key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
           value_dtype=dtypes.int64,
           value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
       table = lookup_ops.HashTable(
           initializer, default_value=-1)
       start = array_ops.placeholder(
           shape=None, dtype=dtypes.string, name="in")
       output = table.lookup(start, name="out")
       with session_lib.Session() as session:
         session.run([table.initializer])
         path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
         simple_save.simple_save(
             session,
             path,
             inputs={"start": start},
             outputs={"output": output},
             legacy_init_op=table.initializer)
   finally:
     # Previously the file leaked whenever the export raised; remove the
     # source vocabulary on both the success and the failure path.
     file_io.delete_file(vocab_path)
   return path
Ejemplo n.º 10
0
 def _v1_single_metagraph_saved_model(self, use_resource):
     """Export a single-MetaGraph TF1 SavedModel computing start * v * local.

     Args:
         use_resource: if true, `v` is a ResourceVariable; otherwise it is a
             RefVariable. Neither variable is given an explicit name.

     Returns:
         The directory the SavedModel was written to.
     """
     graph = ops.Graph()
     with graph.as_default():
         start = array_ops.placeholder(
             shape=[None], dtype=dtypes.float32, name="start")
         make_variable = (resource_variable_ops.ResourceVariable
                          if use_resource else variables.RefVariable)
         v = make_variable(3.)
         local_var = variables.VariableV1(
             1.,
             collections=[ops.GraphKeys.LOCAL_VARIABLES],
             trainable=False,
             use_resource=True)
         output = array_ops.identity(start * v * local_var, name="output")
         with session_lib.Session() as sess:
             sess.run([v.initializer, local_var.initializer])
             export_dir = os.path.join(
                 self.get_temp_dir(), "saved_model", str(ops.uid()))
             simple_save.simple_save(
                 sess,
                 export_dir,
                 inputs={"start": start},
                 outputs={"output": output},
                 legacy_init_op=local_var.initializer)
     return export_dir
Ejemplo n.º 11
0
 def _model_with_ragged_input(self):
   """Export a V1 SavedModel that doubles a float32 ragged-tensor input.

   Returns:
     The directory the SavedModel was written to.
   """
   graph = ops.Graph()
   with graph.as_default():
     ragged_in = ragged_factory_ops.placeholder(dtypes.float32, 1, [])
     doubled = ragged_in * 2
     with session_lib.Session() as sess:
       export_dir = os.path.join(
           self.get_temp_dir(), "saved_model", str(ops.uid()))
       signature_inputs = {"x": ragged_in}
       signature_outputs = {"y": doubled}
       simple_save.simple_save(
           sess, export_dir, inputs=signature_inputs, outputs=signature_outputs)
   return export_dir
Ejemplo n.º 12
0
 def save(self, directory="export_dir"):
     """Export ``self.session`` to ``directory`` with ``simple_save``.

     If ``directory`` already exists as a directory, a warning is printed and,
     after a 5-second pause, the directory is removed with ``shutil.rmtree``.
     This happens whether or not it actually contains anything. The exported
     signature maps "inputs_validation" to ``self.inputs_validation`` and
     "logits_validation" to ``self.logits_validation``.

     Args:
         directory: export path (default "export_dir").
     """
     if os.path.isdir(directory):
         # NOTE(review): the message says "NOT EMPTY", but any existing
         # directory is deleted, including an empty one.
         warning = "DIRECTORY {} IS NOT EMPTY. REMOVING IN 5 SEC...".format(
             directory)
         print(warning)
         time.sleep(5)  # presumably a window to abort before deletion
         shutil.rmtree(directory)
     simple_save(
         self.session,
         directory,
         inputs={"inputs_validation": self.inputs_validation},
         outputs={"logits_validation": self.logits_validation},
     )
Ejemplo n.º 13
0
    def testSimpleSave(self):
        """Round-trip a prediction graph through simple_save's defaults."""
        export_dir = os.path.join(test.get_temp_dir(), "test_simple_save")
        serving_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

        # simple_save is a deprecated v1 API: it needs a session and relies on
        # functionality (e.g. build_tensor_info, called by
        # predict_signature_def) that does not work with eager tensors, so
        # the whole test runs in graph mode.
        with ops.Graph().as_default():
            # Export: var_x is the only input ("x"), var_y the only output.
            with self.session(graph=ops.Graph()) as sess:
                var_x = self._init_and_validate_variable("var_x", 1)
                var_y = self._init_and_validate_variable("var_y", 2)
                simple_save.simple_save(
                    sess, export_dir, {"x": var_x}, {"y": var_y})

            # Import under the SERVING tag and inspect what was restored.
            with self.session(graph=ops.Graph()) as sess:
                meta_graph = loader.load(
                    sess, [tag_constants.SERVING], export_dir)
                restored = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

                # Values and metadata of the restored variables.
                self.assertEqual(len(restored), 2)
                self.assertEqual(1, restored[0].eval())
                self.assertEqual(2, restored[1].eval())
                self._check_variable_info(restored[0], var_x)
                self._check_variable_info(restored[1], var_y)

                # Exactly one signature, under the default serving key.
                sig_map = meta_graph.signature_def
                self.assertEqual(1, len(sig_map))
                self.assertEqual(serving_key, list(sig_map.keys())[0])

                # It uses the predict method and the given input/output.
                sig = sig_map[serving_key]
                self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
                                 sig.method_name)
                self.assertEqual(1, len(sig.inputs))
                self._check_tensor_info(sig.inputs["x"], var_x)
                self.assertEqual(1, len(sig.outputs))
                self._check_tensor_info(sig.outputs["y"], var_y)
Ejemplo n.º 14
0
 def _model_with_sparse_output(self):
   """Export a V1 SavedModel whose "output" signature is a SparseTensor.

   The int64 "start" input (shape [1]) supplies the values of a one-element
   SparseTensor, which is then multiplied by 2.

   Returns:
     The directory the SavedModel was written to.
   """
   graph = ops.Graph()
   with graph.as_default():
     start = array_ops.placeholder(dtype=dtypes.int64, shape=[1])
     sparse = sparse_tensor.SparseTensor(
         indices=[[0]], values=start, dense_shape=[1])
     doubled = sparse * 2
     with session_lib.Session() as sess:
       export_dir = os.path.join(
           self.get_temp_dir(), "saved_model", str(ops.uid()))
       simple_save.simple_save(
           sess,
           export_dir,
           inputs={"start": start},
           outputs={"output": doubled})
   return export_dir
Ejemplo n.º 15
0
 def _v1_cond_saved_model(self):
   """Export a TF1 SavedModel that picks between ones and zeros with cond.

   The "output" signature evaluates array_ops.ones([]) when the boolean
   "branch_selector" input is true and array_ops.zeros([]) otherwise.

   Returns:
     The directory the SavedModel was written to.
   """
   graph = ops.Graph()
   with graph.as_default():
     branch_selector = array_ops.placeholder(
         name="branch_selector", shape=[], dtype=dtypes.bool)
     output = control_flow_ops.cond(
         branch_selector,
         true_fn=lambda: array_ops.ones([]),
         false_fn=lambda: array_ops.zeros([]))
     with session_lib.Session() as sess:
       export_dir = os.path.join(
           self.get_temp_dir(), "saved_model", str(ops.uid()))
       simple_save.simple_save(
           sess,
           export_dir,
           inputs={"branch_selector": branch_selector},
           outputs={"output": output})
   return export_dir
Ejemplo n.º 16
0
 def _v1_cond_saved_model(self):
   """Export a TF1 SavedModel that picks between ones and zeros with cond.

   The "output" signature evaluates array_ops.ones([]) when the boolean
   "branch_selector" input is true and array_ops.zeros([]) otherwise.

   Returns:
     The directory the SavedModel was written to.
   """
   graph = ops.Graph()
   with graph.as_default():
     branch_selector = array_ops.placeholder(
         name="branch_selector", shape=[], dtype=dtypes.bool)
     output = control_flow_ops.cond(
         branch_selector,
         true_fn=lambda: array_ops.ones([]),
         false_fn=lambda: array_ops.zeros([]))
     with session_lib.Session() as sess:
       export_dir = os.path.join(
           self.get_temp_dir(), "saved_model", str(ops.uid()))
       simple_save.simple_save(
           sess,
           export_dir,
           inputs={"branch_selector": branch_selector},
           outputs={"output": output})
   return export_dir
Ejemplo n.º 17
0
def main():
    """Restore a CodePoSModel and export it, with its configuration.

    Usage: <script> MODEL_DIR TARGET_DIR

    The model is restored from MODEL_DIR (``sys.argv[1]``) and exported with
    ``simple_save`` to TARGET_DIR (``sys.argv[2]``). The attributes of the
    ``Configuration`` instance are written to TARGET_DIR/configuration.json.
    NOTE(review): this assumes every attribute of the configuration is
    JSON-serializable; ``json.dump`` raises TypeError otherwise.
    """
    # create instance of config
    config = Configuration()
    config.dir_model = sys.argv[1]
    target_dir = sys.argv[2]

    # build model
    model = CodePoSModel(config)
    model.build()
    model.restore_session(config.dir_model)

    # fully export model
    simple_save(model.sess,
                target_dir,
                inputs=model.get_input_dict(),
                outputs=model.get_output_dict())
    # json.dumps() only returned a string, which was discarded and left the
    # file empty. json.dump() writes the serialized config into the file.
    config_path = os.path.join(target_dir, 'configuration.json')
    with open(config_path, 'w', encoding='utf-8') as config_file:
        json.dump(config.__dict__, config_file)
Ejemplo n.º 18
0
  def testSimpleSave(self):
    """Round-trip a prediction graph through simple_save's defaults."""
    export_dir = os.path.join(test.get_temp_dir(), "test_simple_save")
    serving_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

    # Export: var_x is the only input ("x"), var_y the only output ("y").
    with self.session(graph=ops.Graph()) as sess:
      var_x = self._init_and_validate_variable(sess, "var_x", 1)
      var_y = self._init_and_validate_variable(sess, "var_y", 2)
      simple_save.simple_save(sess, export_dir, {"x": var_x}, {"y": var_y})

    # Import under the SERVING tag and inspect what was restored.
    with self.session(graph=ops.Graph()) as sess:
      meta_graph = loader.load(sess, [tag_constants.SERVING], export_dir)
      restored = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

      # Values and metadata of the restored variables.
      self.assertEqual(len(restored), 2)
      self.assertEqual(1, restored[0].eval())
      self.assertEqual(2, restored[1].eval())
      self._check_variable_info(restored[0], var_x)
      self._check_variable_info(restored[1], var_y)

      # Exactly one signature, stored under the default serving key.
      sig_map = meta_graph.signature_def
      self.assertEqual(1, len(sig_map))
      self.assertEqual(serving_key, list(sig_map.keys())[0])

      # It uses the predict method and the given input and output.
      sig = sig_map[serving_key]
      self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
                       sig.method_name)
      self.assertEqual(1, len(sig.inputs))
      self._check_tensor_info(sig.inputs["x"], var_x)
      self.assertEqual(1, len(sig.outputs))
      self._check_tensor_info(sig.outputs["y"], var_y)
Ejemplo n.º 19
0
def main(_):
    """Caption the first matching image with the show-and-tell model, then export.

    Builds the inference graph and restores it from FLAGS.checkpoint_path.
    It then runs caption generation on the first file matched by the
    comma-separated patterns in FLAGS.input_files and prints the caption.
    Finally, it saves the session with simple_save to "./saved_model/0010"
    and stops.

    NOTE(review): if no file matches FLAGS.input_files, nothing is captioned
    and nothing is exported. No error is reported in that case.
    """
    # Build the inference graph.
    g = tf.Graph()
    with g.as_default():
        model = ShowAndTellModel(ModelConfig(), mode='inference')
        restore_fn = model.build_graph_from_config(FLAGS.checkpoint_path)

    # Create the vocabulary.
    vocab = Vocabulary(FLAGS.vocab_file)

    # Expand every comma-separated glob pattern into concrete file names.
    filenames = []
    for file_pattern in FLAGS.input_files.split(","):
        filenames.extend(tf.gfile.Glob(file_pattern))
    tf.logging.info("Running caption generation on %d files matching %s",
                    len(filenames), FLAGS.input_files)

    with tf.Session(graph=g) as sess:
        # Load the model from checkpoint.
        restore_fn(sess)

        # The CaptionGenerator-based beam search (default beam parameters, see
        # caption_generator.py) is disabled. Instead, word ids are fetched
        # directly from a tensor in the graph below.
        # generator = CaptionGenerator(model, vocab)

        for filename in filenames:
            with tf.gfile.GFile(filename, "rb") as f:
                image = f.read()

                # presumably resets the per-image sentence state -- confirm
                sess.run(model.sentence.initializer)
                # "loopx/Exit_1:0" looks like an exit tensor of an in-graph
                # while loop named "loopx"; its value is treated as a
                # sequence of word ids below -- confirm against the model.
                beamed = sess.run(fetches=["loopx/Exit_1:0"],
                                  feed_dict={"image_bytes:0": image})
                #print(beamed)
                sentence = [vocab.id_to_word(w) for w in beamed[0]]
                sentence = " ".join(sentence)
                # Drops the first 3 and last 4 characters, presumably the
                # sentence start/end markers -- confirm the exact tokens.
                print(sentence[3:-4])
                #continue
                # Export once, after the first image, then stop the loop.
                simple_save.simple_save(sess, "./saved_model/0010",
                                        {"image": model.image_feed},
                                        {"caption": model.output})
                break
Ejemplo n.º 20
0
def test_task(endpoint: str, bucket: str, model_file: str, examples_file: str) -> str:
    """Connects to served model and tests example MNIST images.

    Steps:
    1. Download a Keras model and an examples archive from a Minio bucket.
    2. Export the model as a TF1 SavedModel under /output/mnist/1/.
    3. Wait for the model server on localhost:9001 to serve version 1.
    4. Compare the served signature metadata against the expected signature.
    5. Request predictions for the examples and check their accuracy.

    Args:
        endpoint: Minio server address; the client connects with
            ``secure=False``.
        bucket: Minio bucket holding both objects below.
        model_file: object name of the Keras model; downloaded to
            /models/model.h5.
        examples_file: object name of an ``.npz`` archive with arrays ``X``
            (shape (10, 28, 28, 1)) and ``y`` (shape (10, 10)); downloaded
            to /models/examples.npz.

    Returns:
        NOTE(review): despite the ``-> str`` annotation, no code path returns
        a value, so callers receive None. The annotation may be consumed by a
        pipeline framework, so confirm before changing it.

    Raises:
        AssertionError: if the served metadata or the example array shapes
            differ from what is expected.
        Exception: if prediction accuracy on the examples is below 0.8.
        requests exceptions (e.g. HTTPError): if a request to the model
            server fails.
    """

    from minio import Minio
    from pathlib import Path
    from retrying import retry
    from tensorflow.python.keras.backend import get_session
    from tensorflow.python.keras.saving import load_model
    from tensorflow.python.saved_model.simple_save import simple_save
    import numpy as np
    import requests

    # Credentials are read verbatim from the secret files. NOTE(review):
    # read_text() keeps any trailing newline, so confirm the files have none.
    mclient = Minio(
        endpoint,
        access_key=Path('/secrets/accesskey').read_text(),
        secret_key=Path('/secrets/secretkey').read_text(),
        secure=False,
    )

    print('Downloading model')

    mclient.fget_object(bucket, model_file, '/models/model.h5')
    mclient.fget_object(bucket, examples_file, '/models/examples.npz')

    print('Downloaded model, converting it to serving format')

    # Export from the Keras backend session. Each output tensor's name doubles
    # as its signature key (e.g. 'dense_1/Softmax:0' in the metadata below).
    with get_session() as sess:
        model = load_model('/models/model.h5')
        simple_save(
            sess,
            '/output/mnist/1/',
            inputs={'input_image': model.input},
            outputs={t.name: t for t in model.outputs},
        )

    model_url = 'http://localhost:9001/v1/models/mnist'

    # Retry until the version endpoint succeeds, giving up after 30 * 1000 ms.
    @retry(stop_max_delay=30 * 1000)
    def wait_for_model():
        # NOTE(review): no timeout= is passed, so a single hung request could
        # block longer than the retry window.
        requests.get(f'{model_url}/versions/1').raise_for_status()

    wait_for_model()

    # The served signature must match exactly: a float 'input_image' of shape
    # (-1, 28, 28, 1) and a float softmax output of shape (-1, 10).
    response = requests.get(f'{model_url}/metadata')
    response.raise_for_status()
    assert response.json() == {
        'model_spec': {'name': 'mnist', 'signature_name': '', 'version': '1'},
        'metadata': {
            'signature_def': {
                'signature_def': {
                    'serving_default': {
                        'inputs': {
                            'input_image': {
                                'dtype': 'DT_FLOAT',
                                'tensor_shape': {
                                    'dim': [
                                        {'size': '-1', 'name': ''},
                                        {'size': '28', 'name': ''},
                                        {'size': '28', 'name': ''},
                                        {'size': '1', 'name': ''},
                                    ],
                                    'unknown_rank': False,
                                },
                                'name': 'conv2d_input:0',
                            }
                        },
                        'outputs': {
                            'dense_1/Softmax:0': {
                                'dtype': 'DT_FLOAT',
                                'tensor_shape': {
                                    'dim': [{'size': '-1', 'name': ''}, {'size': '10', 'name': ''}],
                                    'unknown_rank': False,
                                },
                                'name': 'dense_1/Softmax:0',
                            }
                        },
                        'method_name': 'tensorflow/serving/predict',
                    }
                }
            }
        },
    }

    # Ten examples; y presumably holds one-hot labels, since its argmax is
    # used as the true class below.
    examples = np.load('/models/examples.npz')
    assert examples['X'].shape == (10, 28, 28, 1)
    assert examples['y'].shape == (10, 10)

    response = requests.post(
        f'{model_url}:predict', json={'instances': examples['X'].tolist()}
    )
    response.raise_for_status()

    # Predicted and true classes are the argmax over the 10 scores / labels.
    predicted = np.argmax(response.json()['predictions'], axis=1).tolist()
    actual = np.argmax(examples['y'], axis=1).tolist()
    accuracy = sum(1 for (p, a) in zip(predicted, actual) if p == a) / len(predicted)

    if accuracy >= 0.8:
        print(f'Got accuracy of {accuracy:0.2f} in mnist model')
    else:
        raise Exception(f'Low accuracy in mnist model: {accuracy}')
Ejemplo n.º 21
0
def _reset_frame_sequence(env, rl_conf):
    """Reset ``env`` and return the initial frame history for an episode.

    The history is ``rl_conf.num_recent_obs - 1`` all-zero padding frames
    (the same array object repeated) followed by the first preprocessed frame.
    """
    first = preprocess_frame(env.reset(), rl_conf.shape_of_frame)
    padding = [np.zeros(first.shape)] * (rl_conf.num_recent_obs - 1)
    return padding + [first]


def main(save_dir, distant_dir, walltime):
    """Train a prioritized-replay double DQN on SpaceInvaders and export it.

    Atari games have two kinds of inputs,
    ram: size (128,)
    human: size (screen_height, screen_width, 3)
    Since the preprocessing stacks several frames (RL_config.num_recent_obs)
    together, the input size should be modified accordingly.

    Args:
        save_dir: Local directory for the TF checkpoint, the exported
            SavedModel (``model``, or the first free ``model-N``) and the
            ``q_values.txt`` / ``rewards.txt`` per-episode metric logs.
        distant_dir: Directory the top-level files of ``save_dir`` are copied
            to with ``cp`` every ``rl_conf.episodes_to_save`` episodes.
        walltime: Time budget in seconds (anything ``float()`` accepts). Once
            90% of it has elapsed, the next episode start checkpoints and
            calls ``sys.exit()``.
    """
    import glob  # only needed for the periodic copy to distant_dir

    game = "SpaceInvaders-v0"
    print("Playing game {}".format(game.split("-")[0]))
    env = gym.make(game)
    checkpoint = os.path.join(save_dir, game + "-dqn.ckpt")
    time0 = time.time()

    num_actions = env.action_space.n
    print("Number of actions is {}".format(num_actions))
    rl_conf = RL_config()
    rl_model = Prioritized_replay(env, rl_conf)
    rl_model.initialize_replay()
    obs_shape = rl_model.obs_shape
    print("Shape of the input:{}".format(obs_shape))
    if len(obs_shape) == 3:
        obs_h, obs_w, obs_c = rl_model.obs_shape
        obs_ph = tf.placeholder(tf.float32,
                                shape=(None, obs_h, obs_w, obs_c),
                                name="obs_ph")
        obs_dim = 3
    elif len(obs_shape) < 3:
        # Previously `obs_shape < 3`, which compared the shape itself with an
        # int (a TypeError for a tuple on Python 3) instead of its length.
        if len(obs_shape) == 1:
            obs_h = obs_shape[0]
            obs_c = 1
        else:
            obs_h, obs_c = obs_shape
        obs_ph = tf.placeholder(tf.float32,
                                shape=(None, obs_h, obs_c),
                                name="obs_ph")
        obs_dim = 2
    else:
        print("obs_shape inconsistent with the model: {}".format(obs_shape))
        sys.exit()
    indexed_action_ph = tf.placeholder(tf.int32,
                                       shape=(None, 2),
                                       name="indexed_action_ph")
    y_ph = tf.placeholder(tf.float32, shape=(None, 1), name="y_ph")
    # Per-sample weights for importance sampling (IS).
    is_weight_ph = tf.placeholder(tf.float32, shape=None, name="is_weight_ph")

    # 3 cnn layers: 2-D convolutions for image input, 1-D otherwise.
    if obs_dim == 3:
        w1 = tf.get_variable("cnn_w1", [8, 8, obs_c, 32], dtype=tf.float32,
                             initializer=tf.contrib.layers.xavier_initializer())
        w2 = tf.get_variable("cnn_w2", [4, 4, 32, 64], dtype=tf.float32,
                             initializer=tf.contrib.layers.xavier_initializer())
        w3 = tf.get_variable("cnn_w3", [3, 3, 64, 64], dtype=tf.float32,
                             initializer=tf.contrib.layers.xavier_initializer())
    else:
        w1 = tf.get_variable("cnn_w1", [8, obs_c, 32], dtype=tf.float32,
                             initializer=tf.contrib.layers.xavier_initializer())
        w2 = tf.get_variable("cnn_w2", [4, 32, 64], dtype=tf.float32,
                             initializer=tf.contrib.layers.xavier_initializer())
        w3 = tf.get_variable("cnn_w3", [3, 64, 64], dtype=tf.float32,
                             initializer=tf.contrib.layers.xavier_initializer())

    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                         1 / 2 * tf.norm(w1))
    if obs_dim == 3:
        z1 = tf.nn.conv2d(obs_ph, w1, strides=[1, 4, 4, 1], padding="VALID")
    else:
        z1 = tf.nn.conv1d(obs_ph, w1, strides=[1, 4, 1], padding="VALID")
    b1 = tf.get_variable("cnn_b1", z1.get_shape().as_list()[1:], dtype=tf.float32,
                         initializer=tf.zeros_initializer())
    a1 = tf.nn.relu(z1 + b1)

    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                         1 / 2 * tf.norm(w2))
    if obs_dim == 3:
        z2 = tf.nn.conv2d(a1, w2, strides=[1, 2, 2, 1], padding="VALID")
    else:
        z2 = tf.nn.conv1d(a1, w2, strides=[1, 2, 1], padding="VALID")
    b2 = tf.get_variable("cnn_b2", z2.get_shape().as_list()[1:], dtype=tf.float32,
                         initializer=tf.zeros_initializer())
    a2 = tf.nn.relu(z2 + b2)

    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                         1 / 2 * tf.norm(w3))
    if obs_dim == 3:
        z3 = tf.nn.conv2d(a2, w3, strides=[1, 1, 1, 1], padding="VALID")
    else:
        z3 = tf.nn.conv1d(a2, w3, strides=[1, 1, 1], padding="VALID")
    b3 = tf.get_variable("cnn_b3", z3.get_shape().as_list()[1:], dtype=tf.float32,
                         initializer=tf.zeros_initializer())
    a3 = tf.nn.relu(z3 + b3)

    # Fully connected relu layer.
    a3_flat = tf.contrib.layers.flatten(a3)
    w4 = tf.get_variable("fc_w4", [a3_flat.get_shape().as_list()[1], 512],
                         dtype=tf.float32,
                         initializer=tf.contrib.layers.xavier_initializer())
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                         1 / 2 * tf.norm(w4))
    b4 = tf.get_variable("fc_b4", [1, 512], dtype=tf.float32,
                         initializer=tf.zeros_initializer())
    a4 = tf.nn.relu(tf.matmul(a3_flat, w4) + b4)

    # Linear output layer: one Q value per action.
    wo = tf.get_variable("wo", [512, num_actions], dtype=tf.float32,
                         initializer=tf.contrib.layers.xavier_initializer())
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                         1 / 2 * tf.norm(wo))
    bo = tf.get_variable("bo", [1, num_actions], dtype=tf.float32,
                         initializer=tf.zeros_initializer())
    q = tf.add(tf.matmul(a4, wo), bo, name="q")
    print("Shape of q: {}".format(q.get_shape().as_list()))

    # Huber loss on the Q value of the taken action, IS-weighted.
    preds = tf.reshape(tf.gather_nd(q, indexed_action_ph), [-1, 1],
                       name="preds")
    loss = tf.losses.huber_loss(y_ph, preds, weights=is_weight_ph)
    reg_loss = loss + tf.reduce_sum(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    train_op = tf.train.AdamOptimizer(
        learning_rate=rl_conf.learning_rate).minimize(reg_loss)

    # Feeding a snapshot of these variables' values in place of the variables
    # turns the online network into the frozen target network "q1".
    variable_set = [w1, b1, w2, b2, w3, b3, w4, b4, wo, bo]
    with tf.Session() as sess:
        saver = tf.train.Saver()
        if tf.train.checkpoint_exists(checkpoint):
            saver.restore(sess, checkpoint)
            print("Restore model from checkpoint.")
        else:
            sess.run(tf.global_variables_initializer())
            print("Files saved to {}".format(save_dir))
        # Save model state.
        print("Save model to save_dir:")
        inputs_dict = {
            "obs_ph": obs_ph,
            "indexed_action_ph": indexed_action_ph,
            "y_ph": y_ph,
            "is_weight_ph": is_weight_ph
        }
        outputs_dict = {"preds": preds}
        # simple_save refuses an existing non-empty export dir, so use the
        # first free name: model, model-0, model-1, ... and let simple_save
        # create it (the old code crashed once model-0 already existed).
        model_dir = os.path.join(save_dir, "model")
        suffix = 0
        while os.path.exists(model_dir):
            model_dir = os.path.join(save_dir, "model-{}".format(suffix))
            suffix += 1
        simple_save(sess, model_dir, inputs_dict, outputs_dict)

        q_file = os.path.join(save_dir, "q_values.txt")
        reward_file = os.path.join(save_dir, "rewards.txt")

        # Compute priorities for the initial replay one batch at a time.
        # Slices past the end are clamped, so the last short batch needs no
        # special case.
        priorities = np.zeros((len(rl_model.replay), 1))
        for begin in range(0, len(rl_model.replay), rl_conf.batch):
            end = begin + rl_conf.batch
            samples = rl_model.replay[begin:end]
            sample_size = len(samples)
            obs_samples = np.array([sample[0] for sample in samples])
            obs_next_samples = np.array([sample[3] for sample in samples])
            rew_samples = np.array([sample[2] for sample in samples]).reshape(
                (sample_size, 1))
            done_samples = np.array([sample[4] for sample in samples]).reshape(
                (sample_size, 1))
            action_samples = np.array([sample[1] for sample in samples]).astype(int)

            q_targets = sess.run(q, feed_dict={obs_ph: obs_next_samples})
            q_for_priorities = sess.run(q, feed_dict={obs_ph: obs_samples})

            q_targets_selected_actions = np.amax(q_targets, axis=-1, keepdims=True)
            labels = np.where(
                done_samples, rew_samples,
                rew_samples + rl_conf.gamma * q_targets_selected_actions)
            q_priorities_selected_actions = q_for_priorities[
                np.arange(sample_size), action_samples].reshape((sample_size, 1))
            priorities[begin:end] = np.absolute(
                labels - q_priorities_selected_actions)**(
                    rl_conf.alpha_prioritized_replay)
        rl_model.initialize_priority_list(priorities.flatten())

        update_q1_steps = 0
        for episode in np.arange(1, rl_conf.max_episodes + 1):
            # Make sure the running time does not exceed walltime.
            start = time.time()
            if start - time0 > float(walltime) * 0.9:
                saver.save(sess, checkpoint)
                print("Running time limit achieves.")
                sys.exit()
            # Metrics.
            total_rew = 0.0
            average_q = 0.0

            sequence = _reset_frame_sequence(env, rl_conf)
            # Epsilon schedule.
            epsilon = max(
                1.0 - float(episode) / rl_conf.final_exploration_episodes, 0.1)
            # beta_prioritized_replay schedule.
            beta_prioritized_replay = min(
                rl_conf.beta_prioritized_replay +
                float(episode) / rl_conf.final_exploration_episodes, 1.0)
            for step in np.arange(1, rl_conf.max_steps):
                # Refresh the target-network (q1) snapshot periodically; the
                # counter starts at 0, so the first step always refreshes.
                if update_q1_steps % rl_conf.steps_for_updating_q1 == 0:
                    variable_values_q1 = sess.run(variable_set)
                    var_feed_dict_list_q1 = list(zip(variable_set, variable_values_q1))
                # Whether to recompute the priority list.
                if (len(rl_model.replay) >= rl_conf.replay_capacity and
                        update_q1_steps %
                        rl_model.frequency_to_update_priorities_globally == 0):
                    rl_model.compute_priority_list(rl_model.priority_list)
                if (len(rl_model.replay) >= rl_conf.replay_capacity and
                        rl_model.group_rbs == []):
                    rl_model.compute_priority_list(rl_model.priority_list)
                update_q1_steps += 1

                # Double DQN update on a prioritized minibatch.
                samples, sample_p_idxes = rl_model.sample_from_replay()
                obs_samples = np.array([sample[0] for sample in samples])
                obs_next_samples = np.array([sample[3] for sample in samples])
                rew_samples = np.array([sample[2] for sample in samples]).reshape(
                    (rl_conf.batch, 1))
                done_samples = np.array([sample[4] for sample in samples]).reshape(
                    (rl_conf.batch, 1))
                action_samples = np.array([sample[1] for sample in samples]).astype(int)
                indexed_action_samples = np.array(
                    [[idx, sample[1]] for idx, sample in enumerate(samples)]).astype(int)

                # The online network selects the next action; the q1 snapshot
                # evaluates it.
                q_for_selecting_actions = sess.run(
                    q, feed_dict={obs_ph: obs_next_samples})
                selected_actions = np.argmax(q_for_selecting_actions, axis=-1)

                q_for_priorities = sess.run(q, feed_dict={obs_ph: obs_samples})

                q_targets = sess.run(
                    q,
                    feed_dict=dict([(obs_ph, obs_next_samples)] +
                                   var_feed_dict_list_q1))
                q_targets_selected_actions = q_targets[
                    np.arange(len(selected_actions)),
                    selected_actions].reshape((rl_conf.batch, 1))
                labels = np.where(
                    done_samples, rew_samples,
                    rew_samples + rl_conf.gamma * q_targets_selected_actions)

                # Update priorities for the replay.
                if sample_p_idxes is not None:
                    priorities = np.absolute(labels - q_for_priorities[
                        np.arange(len(action_samples)),
                        action_samples].reshape((rl_conf.batch, 1)))**(
                            rl_conf.alpha_prioritized_replay)
                    p_indexed_priorities = np.concatenate(
                        (sample_p_idxes.reshape((rl_conf.batch, 1)), priorities),
                        axis=1)
                    rl_model.update_priority_after_sampling(p_indexed_priorities)
                    weights = (len(rl_model.replay) *
                               priorities)**(-beta_prioritized_replay)
                    weights = weights / np.amax(weights)
                else:
                    weights = 1.0

                sess.run(train_op,
                         feed_dict={
                             obs_ph: obs_samples,
                             y_ph: labels,
                             indexed_action_ph: indexed_action_samples,
                             is_weight_ph: weights
                         })

                # Choose a new action on the first step of each action_repeat
                # window. `step % action_repeat == 1` never held for
                # action_repeat == 1, leaving `action` unbound;
                # `(step - 1) % k == 0` picks the same steps for k >= 2.
                if (step - 1) % rl_conf.action_repeat == 0:
                    if step == 1:
                        obs_input = preprocess_obs(
                            sequence[:rl_conf.num_recent_obs])
                    q_values_for_new_sample = sess.run(
                        q, feed_dict={obs_ph: np.array([obs_input])})
                    max_action_for_new_sample = np.argmax(
                        q_values_for_new_sample, axis=-1)
                    average_q += q_values_for_new_sample[0][
                        max_action_for_new_sample]
                    # Epsilon-greedy.
                    dice = np.random.uniform()
                    if dice < epsilon:
                        action = np.random.randint(num_actions)
                    else:
                        action = max_action_for_new_sample
                obs, rew, done, _ = env.step(action)
                total_rew += rew
                sequence.append(preprocess_frame(obs, rl_conf.shape_of_frame))
                last_obs_input = obs_input
                obs_input = preprocess_obs(
                    sequence[step:step + rl_conf.num_recent_obs])
                # Priority for the new transition.
                if done:
                    priority = np.absolute(
                        rew - q_values_for_new_sample[0][action])**(
                            rl_conf.alpha_prioritized_replay)
                else:
                    q_target_for_new_sample = sess.run(
                        q,
                        feed_dict=dict([(obs_ph, np.array([obs_input]))] +
                                       var_feed_dict_list_q1))
                    # NOTE(review): the next-state target is indexed with the
                    # argmax of the *current* state's Q values; confirm this
                    # is intended rather than the next state's argmax.
                    priority = np.absolute(
                        rew + rl_conf.gamma *
                        q_target_for_new_sample[0][max_action_for_new_sample] -
                        q_values_for_new_sample[0][action])**(
                            rl_conf.alpha_prioritized_replay)
                if np.isscalar(priority):
                    priority = np.array([priority])
                else:
                    priority = priority.flatten()
                rl_model.update_replay(
                    np.array([(last_obs_input, action, rew, obs_input, done)]),
                    priority)
                if done:
                    with open(q_file, "a+") as out:
                        # NOTE(review): 4 presumably mirrors action_repeat
                        # (Q is accumulated once per decision) — confirm.
                        out.write(str(average_q * 4 / step) + "\n")
                    with open(reward_file, "a+") as out:
                        out.write(str(total_rew) + "\n")
                    break
            if episode % rl_conf.episodes_to_save == 0:
                saver.save(sess, checkpoint)
                # Equivalent of `cp <save_dir>/* <distant_dir>` without a
                # shell, so paths with spaces/metacharacters are safe. Exit
                # status is ignored, as before.
                subprocess.call(
                    ["cp"] + sorted(glob.glob(os.path.join(save_dir, "*"))) +
                    [distant_dir])
            print("Time taken to complete this episode: {}s.".format(
                time.time() - start))

            # Evaluate the agent by the total reward of a few greedy-ish runs.
            if episode % rl_conf.episodes_to_validate == 0:
                total_rew_eval = []
                eval_epsilon = 0.05
                for _ in np.arange(rl_conf.evaluation_trials):
                    sequence = _reset_frame_sequence(env, rl_conf)
                    rew_eval = 0
                    for step in np.arange(1, rl_conf.max_steps):
                        if (step - 1) % rl_conf.action_repeat == 0:
                            if step == 1:
                                obs_input = preprocess_obs(
                                    sequence[:rl_conf.num_recent_obs])
                            q_values_for_evaluation = sess.run(
                                q, feed_dict={obs_ph: np.array([obs_input])})
                            max_action_for_evaluation = np.argmax(
                                q_values_for_evaluation, axis=-1)
                            # Epsilon-greedy.
                            dice = np.random.uniform()
                            if dice < eval_epsilon:
                                action = np.random.randint(num_actions)
                            else:
                                action = max_action_for_evaluation
                        obs, rew, done, _ = env.step(action)
                        rew_eval += rew
                        sequence.append(
                            preprocess_frame(obs, rl_conf.shape_of_frame))
                        obs_input = preprocess_obs(
                            sequence[step:step + rl_conf.num_recent_obs])
                        if done:
                            total_rew_eval.append(rew_eval)
                            break
                print("Evaluating the agent at episode-{}:".format(episode))
                if total_rew_eval:
                    total_rew_eval = np.array(total_rew_eval)
                    print("Total rewards of trials have max-{}, average-{}, std-{}.".format(
                        np.amax(total_rew_eval), np.mean(total_rew_eval),
                        np.std(total_rew_eval)))
                else:
                    # np.amax would raise on an empty array.
                    print("No evaluation trial finished within {} steps.".format(
                        rl_conf.max_steps - 1))
Ejemplo n.º 22
0
    decoder = tf.matmul(noise_input,
                        weights['decoder_h1']) + biases['decoder_b1']
    decoder = tf.nn.tanh(decoder)
    decoder = tf.matmul(decoder,
                        weights['decoder_out']) + biases['decoder_out']
    decoder = tf.nn.sigmoid(decoder)

    # Building a manifold of generated digits
    n = 20
    x_axis = np.linspace(-3, 3, n)
    y_axis = np.linspace(-3, 3, n)

    canvas = np.empty((28 * n, 28 * n))
    for i, yi in enumerate(x_axis):
        for j, xi in enumerate(y_axis):
            z_mu = np.array([[xi, yi]] * batch_size)
            x_mean = sess.run(decoder, feed_dict={noise_input: z_mu})
            canvas[(n - i - 1) * 28:(n - i) * 28, j * 28:(j + 1) * 28] = \
            x_mean[0].reshape(28, 28)

    plt.figure(figsize=(8, 10))
    Xi, Yi = np.meshgrid(x_axis, y_axis)
    plt.imshow(canvas, origin="upper", cmap="gray")
    plt.show()

    print("Saving the model")
    simple_save(sess,
                export_dir='./saved_variational_autoencoder',
                inputs={"inp": input_image},
                outputs={"out": decoder_export})
 def save_model(sess, path_to_model, X, Y):
     """Export ``sess`` as a SavedModel at ``path_to_model``, replacing any old export.

     Args:
         sess: Session holding the variable values to export.
         path_to_model: Export directory; deleted first if it exists.
         X: Tensor exposed as the signature input ``myInput``.
         Y: Tensor exposed as the signature output ``myOutput``.
     """
     try:
         shutil.rmtree(path_to_model)
     except FileNotFoundError:
         # Nothing to replace yet. Other failures (permissions, path is a
         # file) now surface here instead of being silently swallowed by a
         # bare `except:` and resurfacing as a confusing simple_save error.
         pass
     simple_save(sess, path_to_model, inputs={"myInput": X}, outputs={"myOutput": Y})
Ejemplo n.º 24
0
def _export_prob_predictor(sess, saver, checkpoint_dir, x, y):
    """Replace ``checkpoint_dir`` with a fresh SavedModel plus Saver checkpoint.

    The directory is removed first because simple_save will not export into
    an existing non-empty directory.
    """
    if os.path.exists(checkpoint_dir):
        shutil.rmtree(checkpoint_dir)
    simple_save(sess,
                checkpoint_dir,
                inputs={'inputs': x},
                outputs={'outputs': y})
    saver.save(sess, os.path.join(checkpoint_dir, 'prob_predictor.ckpt'))


def main(mode: str, user_feedback=None, detections=None):
    """
    Trains a model or uses an existent model to make predictions.

    Checkpoints live in numerically named sub-directories of
    ``<model_dir>/accept_prob_predictor``. Empty ones, newest first, are
    deleted; the newest non-empty one is restored, and new results are
    written to the next number (``1`` when none exist).

    Args:
        mode: Defines whether a model should be trained ('train')
            or an existent model should be used for making predictions
            ('predict'). ``FLAGS.mode == 'train'`` also selects training.
        user_feedback: Used in live-train-mode to train the model based on
            user input; its 'x' and 'y_' entries are fed to
            tf.train.shuffle_batch.
        detections: Used in detection-mode; a mapping whose values are lists
            of detections, each turned into a feature vector by
            compute_feature_vector.

    Returns:
        In predict mode with at least one checkpoint, the rounded model
        outputs (one row per feature vector); otherwise None.
    """
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.1)

    accept_prob_model_dir = os.path.join(model_dir, 'accept_prob_predictor')

    if not os.path.exists(accept_prob_model_dir):
        os.mkdir(accept_prob_model_dir)

    # Only numerically named directories are checkpoints; any other entry
    # (e.g. a stray file) previously made sort(key=int) raise ValueError.
    existent_checkpoints = sorted(
        (name for name in os.listdir(accept_prob_model_dir)
         if name.isdecimal()
         and os.path.isdir(os.path.join(accept_prob_model_dir, name))),
        key=int)

    x = tf.placeholder(tf.float32, [None, 1005])
    y = prob_model(x)
    _y = tf.placeholder(tf.float32, [None, 1])
    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y, labels=_y))

    saver = tf.train.Saver()

    # Walk back from the newest checkpoint, deleting empty directories, until
    # a non-empty one is found.
    while True:
        actual_checkpoint_dir = ''
        if len(existent_checkpoints) == 0:
            new_checkpoint_dir = os.path.join(accept_prob_model_dir, '1')
            break
        actual_checkpoint = existent_checkpoints[-1]
        actual_checkpoint_dir = os.path.join(accept_prob_model_dir, actual_checkpoint)
        if len(os.listdir(actual_checkpoint_dir)) > 0:
            new_checkpoint_dir = os.path.join(accept_prob_model_dir,
                                              str(int(actual_checkpoint) + 1))
            break

        existent_checkpoints.pop()
        shutil.rmtree(actual_checkpoint_dir)

    if (FLAGS and FLAGS.mode == 'train') or mode == 'train':
        if not user_feedback:
            # Initial train mode with fixed training data.
            iterations = FLAGS.iterations
            learning_rate = FLAGS.learning_rate
            base = FLAGS.path_to_training_data
            path_to_train_data = '{}_{}'.format(base, 'train.tfrecords')
            train_features, train_labels = parse_tf_records(path_to_train_data)
        else:
            # Live train mode with user feedback.
            iterations = 1
            learning_rate = 0.001
            batch_size = 64
            train_features, train_labels = tf.train.shuffle_batch([user_feedback['x'],
                                                                   user_feedback['y_']],
                                                                  batch_size=batch_size,
                                                                  capacity=50000,
                                                                  min_after_dequeue=0,
                                                                  allow_smaller_final_batch=True)

        # NOTE(review): path_to_test_data is not defined in this function;
        # presumably a module-level global — confirm.
        test_feat = []
        test_lbl = []
        example = tf.train.Example()
        for record in tf.python_io.tf_record_iterator(path_to_test_data):
            example.ParseFromString(record)
            f = example.features.feature
            test_feat.append(np.asarray(f['test/feature'].float_list.value))
            test_lbl.append(f['test/label'].float_list.value[0])
        test_feat = np.reshape(test_feat, (-1, 1005))
        test_lbl = np.reshape(test_lbl, (-1, 1))

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())

        best_acc = 0.0
        best_acc_ann = 0.0
        best_acc_ver = 0.0
        # Tracked but currently not acted on (early stopping is disabled).
        early_stopping_counter = 0

        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:

            sess.run(init_op)

            if len(existent_checkpoints) > 0:
                saver.restore(sess, os.path.join(actual_checkpoint_dir, 'prob_predictor.ckpt'))

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            try:
                for batch_index in range(iterations):
                    try:
                        feat, lbl = sess.run([train_features, train_labels])
                    except OutOfRangeError:
                        print('No more Training Data available')
                        break

                    train_op.run(feed_dict={x: np.reshape(feat, (-1, 1005)),
                                            _y: np.reshape(lbl, (-1, 1))})

                    if batch_index % 100 == 0 or user_feedback:
                        y_test = y.eval(feed_dict={x: test_feat, _y: test_lbl})
                        acc, acc_ann, acc_ver = evaluate_prediction_record(y_test, test_lbl)
                        print('step {},\t Annotation acc.:{}\tVerification acc.:{}'.format(
                            batch_index, acc_ann, acc_ver))

                        # Live feedback always exports the latest weights.
                        if user_feedback:
                            _export_prob_predictor(sess, saver, new_checkpoint_dir, x, y)

                        if acc_ann + acc_ver > best_acc_ann + best_acc_ver:
                            best_acc = acc
                            best_acc_ann = acc_ann
                            best_acc_ver = acc_ver
                            early_stopping_counter = 0
                            _export_prob_predictor(sess, saver, new_checkpoint_dir, x, y)
                        else:
                            early_stopping_counter += 1

                print('Accuracy:\t{}'.format(best_acc))
                print('Accuracy Ann:\t{}'.format(best_acc_ann))
                print('Accuracy Ver:\t{}'.format(best_acc_ver))
            finally:
                # Stop and join the queue-runner threads even if training
                # raised; the `with` block closes the session itself.
                coord.request_stop()
                coord.join(threads)

    elif mode == 'predict' and len(existent_checkpoints) > 0:
        feature_data = [compute_feature_vector(detection)
                        for key in detections
                        for detection in detections[key]]

        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
            saver.restore(sess, os.path.join(actual_checkpoint_dir, 'prob_predictor.ckpt'))

            result = y.eval(feed_dict={
                x: np.reshape(feature_data, (-1, 1005))})
        # np.round rounds halves to even like tf.round, without adding a new
        # op to the default graph on every call.
        # NOTE(review): y is treated as logits by the training loss, so this
        # rounds logits rather than probabilities — confirm intent.
        prediction = np.round(result)

        for i, val in enumerate(prediction):
            print('{}: {}'.format(i, val))
        return prediction
Ejemplo n.º 25
0
def main():
    """Train a small Keras classifier on the iris CSV and export it.

    Fits a model on the two sepal columns of ``iris_with_header.csv`` (the
    label is column 4), re-applies it to tensors parsed from serialized
    tf.Example placeholders, and writes a SavedModel with simple_save to
    ``./<lib.get_batch_name()>``.
    """
    logging.getLogger().setLevel(logging.DEBUG)
    # Reference variable setup.
    sess = tf.Session()
    K.set_session(sess)
    model_version = lib.get_batch_name()
    K.set_learning_phase(0)

    # Data setup.
    logging.info('Loading data')

    dataframe = pandas.read_csv("iris_with_header.csv")
    numerical_cols = ['sepal_length', 'sepal_width']
    input_data = [dataframe[col].values for col in numerical_cols]

    def feature_width(col):
        # Width of one row of ``col``; a 1-D column has width 1.
        if len(dataframe[col].shape) > 1:
            return dataframe[col].shape[1]
        return 1

    logging.info('OHE-ing response variable')
    encoder = LabelEncoder()
    encoded_Y = encoder.fit_transform(dataframe.values[:, 4])
    one_hot_labels = np_utils.to_categorical(encoded_Y)

    # Model setup.
    logging.info('Creating model')

    input_layers = []
    for col in numerical_cols:
        logging.info('Creating input for {}'.format(col))
        shape = feature_width(col)
        logging.info('Inferring variable {} has shape: {}'.format(col, shape))
        input_layers.append(Input(shape=(shape, ),
                                  name='{}_input'.format(col)))

    layers = Concatenate()(input_layers)
    layers = Dense(32)(layers)
    # categorical_crossentropy (from_logits=False) expects class
    # probabilities; a linear Dense(3) can emit negative scores.
    layers = Dense(3, activation='softmax')(layers)

    model = Model(input_layers, layers)

    model.compile(loss=losses.categorical_crossentropy, optimizer='adam')

    model.fit(input_data, one_hot_labels)

    tf_inputs = []
    tf_examples = []
    # Register input placeholders: one serialized tf.Example per column.
    for col in numerical_cols:
        logging.info('Creating tf placeholder for col: {}'.format(col))
        shape = feature_width(col)
        logging.info('Inferring variable {} has shape: {}'.format(col, shape))

        serialized_tf_example = tf.placeholder(tf.string, name=col)
        tf_examples.append(serialized_tf_example)

        # TODO Better type lookup based on numpy types
        feature_configs = {
            col: tf.FixedLenFeature(shape=shape, dtype=tf.float32),
        }
        tf_example = tf.parse_example(serialized_tf_example, feature_configs)
        tf_inputs.append(tf.identity(tf_example[col], name=col))

    # Generate output tensor by feeding inputs into model.
    y = model(tf_inputs)

    # Classification outputs (not exported yet; the lookup table still ends
    # up in the graph).
    labels = encoder.classes_
    values, indices = tf.nn.top_k(y, len(labels))
    OHE_index_to_string_lookup_table = tf.contrib.lookup.index_to_string_table_from_tensor(
        tf.constant(labels))
    prediction_classes = OHE_index_to_string_lookup_table.lookup(
        tf.to_int64(indices))

    # Save model.
    # NOTE(review): the signature inputs are the parsed float tensors, so the
    # serialized-string placeholders in tf_examples are bypassed — confirm.
    output_path = './' + model_version
    logging.info('Saving model to {}'.format(output_path))
    simple_save_inputs = dict(zip(numerical_cols, tf_inputs))
    logging.info('Inputs to simple save: {}'.format(simple_save_inputs))
    simple_save(sess, output_path, inputs=simple_save_inputs, outputs={'y': y})

    # Signatures can be inspected with the saved_model CLI:
    # https://www.tensorflow.org/guide/saved_model#cli_to_inspect_and_execute_savedmodel
    # TODO Create, validate and serialize a tensorflow-serving signature,
    # save the graph, and return the path to it.
Ejemplo n.º 26
0
        # Reshape data to get 28 seq of 28 elements
        batch_x = batch_x.reshape((batch_size, timesteps, num_input))
        # Run optimization op (backprop)
        sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})
        if step % display_step == 0 or step == 1:
            # Calculate batch loss and accuracy
            loss, acc = sess.run([loss_op, accuracy],
                                 feed_dict={
                                     X: batch_x,
                                     Y: batch_y
                                 })
            print("Step " + str(step) + ", Minibatch Loss= " + \
                  "{:.4f}".format(loss) + ", Training Accuracy= " + \
                  "{:.3f}".format(acc))

    print("Optimization Finished!")

    # Calculate accuracy for 128 mnist test images
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape(
        (-1, timesteps, num_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:", \
        sess.run(accuracy, feed_dict={X: test_data, Y: test_label}))

    print("Saving the model")
    simple_save(sess,
                export_dir='./saved_recurrent_network',
                inputs={"images": X},
                outputs={"out": prediction})
Ejemplo n.º 27
0
def train(Xl, yl, Xt, yt, training_steps=500, save_model=False,
          export_dir='./lstm'):
    """Train an LSTM + dense classifier with SGD and print test accuracy.

    Builds a fresh graph and session, trains on shuffled mini-batches drawn
    from ``(Xl, yl)``, prints the loss/accuracy of the current batch every
    100 steps, and finally prints the accuracy on ``(Xt, yt)``.

    :param Xl: training examples, shape [num_examples, num_inputs, num_timesteps]
    :param yl: one-hot encoded training labels, shape [num_examples, num_classes]
    :param Xt: test examples, shape [num_examples, num_inputs, num_timesteps]
    :param yt: one-hot encoded test labels, shape [num_examples, num_classes]
    :param training_steps: last optimisation step index; steps
        ``0..training_steps`` inclusive are run. The default of 500 is a quick
        run (roughly 10000 is needed to actually train).
    :param save_model: when True, export the trained model with ``simple_save``.
    :param export_dir: SavedModel export directory used when ``save_model``.
    """
    batch_size = 64
    num_examples = Xl.shape[0]
    num_inputs = Xl.shape[1]  # length of each row
    num_timesteps = Xl.shape[2]  # each row of an example is one timestep
    num_classes = yl.shape[1]
    num_hidden = 128
    learning_rate = 0.001

    with tf.Session(graph=tf.Graph()) as sess:
        X, lstm_out = lstm_and_dense_layer(num_timesteps,
                                           num_inputs,
                                           num_classes=num_classes,
                                           num_hidden=num_hidden)
        y = tf.placeholder(tf.float32, shape=(None, num_classes))

        loss_op = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=lstm_out,
                                                       labels=y))
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
        train_op = optimizer.minimize(loss_op)

        prediction = tf.nn.softmax(lstm_out, name='prediction')
        correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
        accuracy_op = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

        # Public tf alias; avoids relying on a separately imported internal
        # ``variables`` module.
        sess.run(tf.global_variables_initializer())

        epoch = -1
        batch_start = 0
        batch_end = batch_size

        def next_batch():
            """Return the feed dict for the next mini-batch.

            The training data is reshuffled at the start of every epoch. When
            the next slice would extend past ``num_examples``, a new epoch is
            started instead, so leftover examples at the end are not used.
            """
            nonlocal epoch, batch_start, batch_end, Xl, yl
            if batch_end > num_examples or epoch == -1:
                epoch += 1
                batch_start = 0
                batch_end = batch_size
                perm = np.random.permutation(num_examples)
                Xl = Xl[perm]
                yl = yl[perm]
            X_batch = Xl[batch_start:batch_end, :, :]
            y_batch = yl[batch_start:batch_end, :]
            batch_start = batch_end
            batch_end = batch_start + batch_size
            return {X: X_batch, y: y_batch}

        for step in range(training_steps + 1):
            batch = next_batch()
            sess.run(train_op, feed_dict=batch)

            if step % 100 == 0:
                # Reported metrics are measured on the batch just trained on.
                loss_, acc_ = sess.run((loss_op, accuracy_op), feed_dict=batch)
                print("epoch", epoch, "step", step, "loss",
                      "{:.4f}".format(loss_), "acc", "{:.2f}".format(acc_))

        print("Optimization Finished!")
        print("Testing Accuracy:",
              sess.run(accuracy_op, feed_dict={
                  X: Xt,
                  y: yt
              }))

        if save_model:
            print("Saving the model")
            simple_save(sess,
                        export_dir=export_dir,
                        inputs={"images": X},
                        outputs={"out": prediction})
Ejemplo n.º 28
0
        # Reshape data to get 28 seq of 28 elements
        batch_x = batch_x.reshape((batch_size, timesteps, num_input))
        # Run optimization op (backprop)
        sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})
        if step % display_step == 0 or step == 1:
            # Calculate batch loss and accuracy
            loss, acc = sess.run([loss_op, accuracy],
                                 feed_dict={
                                     X: batch_x,
                                     Y: batch_y
                                 })
            print("Step " + str(step) + ", Minibatch Loss= " + \
                  "{:.4f}".format(loss) + ", Training Accuracy= " + \
                  "{:.3f}".format(acc))

    print("Optimization Finished!")

    # Calculate accuracy for 128 mnist test images
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape(
        (-1, timesteps, num_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:", \
        sess.run(accuracy, feed_dict={X: test_data, Y: test_label}))

    print("Saving the model")
    simple_save(sess,
                export_dir='./saved_bidirectional_rnn',
                inputs={"inp": X},
                outputs={"out": prediction})
Ejemplo n.º 29
0
def amazon_attribute_train(generator: rmc_att_topic.generator,
                           discriminator: rmc_att_topic.discriminator,
                           oracle_loader: RealDataAmazonLoader, config, args):
    """Pre-train and adversarially train the attribute-conditioned generator.

    Builds the GAN graph: a generator conditioned on user/product/rating and
    a discriminator applied to real and generated one-hot sequences. Then:

    1. restores a pre-trained generator from
       ``checkpoint_folder/PretrainGenerator`` when its meta graph can be
       opened, otherwise pre-trains with MLE for ``npre_epochs`` epochs;
    2. runs ``nadv_steps`` adversarial iterations (``gsteps`` generator and
       ``dsteps`` discriminator updates each), updating the temperature;
    3. exports the model with ``simple_save`` under
       ``trained_models/<timestamp>``.

    Metrics, losses and sample sentences go to TensorBoard summaries and to
    ``experiment-log-rmcgan.csv`` in the log directory.

    Side effect: sets ``config['bleu_amazon']`` and
    ``config['bleu_amazon_validation']`` to True on the caller's dict.

    :param generator: generator constructor, called with the placeholders.
    :param discriminator: discriminator constructor, called on one-hot inputs.
    :param oracle_loader: source of real train/validation batches.
    :param config: experiment configuration dictionary.
    :param args: run arguments; ``str(args)`` is written to the summary.
    """
    batch_size = config['batch_size']
    vocab_size = config['vocabulary_size']
    seq_len = config['seq_len']
    npre_epochs = config['npre_epochs']
    nadv_steps = config['nadv_steps']
    temper = config['temperature']
    adapt = config['adapt']

    # changed to match resources path
    data_dir = resources_path(config['data_dir'], "Amazon_Attribute")
    log_dir = resources_path(config['log_dir'])
    sample_dir = resources_path(config['sample_dir'])

    # output files
    json_file = os.path.join(sample_dir, 'json_file.txt')
    json_file_validation = os.path.join(sample_dir, 'json_file_validation.txt')
    csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv')

    # create necessary directories
    for directory in (data_dir, sample_dir, log_dir):
        os.makedirs(directory, exist_ok=True)

    # placeholder definitions
    x_real = tf.placeholder(tf.int32, [batch_size, seq_len],
                            name="x_real")  # tokens of oracle sequences
    x_user = tf.placeholder(tf.int32, [batch_size], name="x_user")
    x_product = tf.placeholder(tf.int32, [batch_size], name="x_product")
    x_rating = tf.placeholder(tf.int32, [batch_size], name="x_rating")

    temperature = tf.Variable(1., trainable=False, name='temperature')

    x_real_onehot = tf.one_hot(x_real,
                               vocab_size)  # batch_size x seq_len x vocab_size
    assert x_real_onehot.get_shape().as_list() == [
        batch_size, seq_len, vocab_size
    ]

    # generator and discriminator outputs
    generator_obj = generator(x_real=x_real,
                              temperature=temperature,
                              x_user=x_user,
                              x_product=x_product,
                              x_rating=x_rating)
    discriminator_real = discriminator(x_onehot=x_real_onehot)
    discriminator_fake = discriminator(
        x_onehot=generator_obj.gen_x_onehot_adv)

    # GAN / Divergence type
    log_pg, g_loss, d_loss = get_losses(generator_obj, discriminator_real,
                                        discriminator_fake, config)

    # Global step
    global_step = tf.Variable(0, trainable=False)
    global_step_op = global_step.assign_add(1)

    # Train ops (the fourth, topic-discriminator pretrain op is unused here)
    g_pretrain_op, g_train_op, d_train_op, _ = get_train_ops(
        config, generator_obj.pretrain_loss, g_loss, d_loss, None, log_pg,
        temperature, global_step)

    # Record wall clock time
    time_diff = tf.placeholder(tf.float32)
    Wall_clock_time = tf.Variable(0., trainable=False)
    update_Wall_op = Wall_clock_time.assign_add(time_diff)

    # Temperature placeholder
    temp_var = tf.placeholder(tf.float32)
    update_temperature_op = temperature.assign(temp_var)

    # Loss summaries
    loss_summaries = [
        tf.summary.scalar('adv_loss/discriminator/total', d_loss),
        tf.summary.scalar('adv_loss/generator/total_g_loss', g_loss),
        tf.summary.scalar('adv_loss/log_pg', log_pg),
        tf.summary.scalar('adv_loss/Wall_clock_time', Wall_clock_time),
        tf.summary.scalar('adv_loss/temperature', temperature),
    ]
    loss_summary_op = tf.summary.merge(loss_summaries)

    # Metric Summaries
    config['bleu_amazon'] = True
    config['bleu_amazon_validation'] = True
    metrics_pl, metric_summary_op = get_metric_summary_op(config)

    # Summaries
    gen_pretrain_loss_summary = CustomSummary(name='pretrain_loss',
                                              scope='generator')
    gen_sentences_summary = CustomSummary(name='generated_sentences',
                                          scope='generator',
                                          summary_type=tf.summary.text,
                                          item_type=tf.string)
    topic_discr_pretrain_summary = CustomSummary(name='pretrain_loss',
                                                 scope='topic_discriminator')
    topic_discr_accuracy_summary = CustomSummary(name='pretrain_accuracy',
                                                 scope='topic_discriminator')
    run_information = CustomSummary(name='run_information',
                                    scope='info',
                                    summary_type=tf.summary.text,
                                    item_type=tf.string)
    custom_summaries = [
        gen_pretrain_loss_summary, gen_sentences_summary,
        topic_discr_pretrain_summary, topic_discr_accuracy_summary,
        run_information
    ]

    # To save the trained model; currently only referenced by the
    # commented-out saver.save call at the end.
    saver = tf.train.Saver()

    def padded(sentences):
        """Copy ``sentences`` row by row into a zero (batch_size, seq_len) array."""
        out = np.zeros((batch_size, seq_len))
        for row, tokens in enumerate(sentences):
            out[row] = tokens
        return out

    # ------------- initial the graph --------------
    # The CSV log is managed by the same `with` so it is always closed.
    with init_sess() as sess, open(csv_file, 'w') as log:
        # NOTE(review): the return value was never used; the call is kept in
        # case it has side effects (e.g. printing parameter counts) — confirm.
        get_parameters_division()

        sum_writer = tf.summary.FileWriter(os.path.join(log_dir, 'summary'),
                                           sess.graph)
        for custom_summary in custom_summaries:
            custom_summary.set_file_writer(sum_writer, sess)
        run_information.write_summary(str(args), 0)
        print("Information stored in the summary!")

        metrics = get_metrics(config, oracle_loader, sess, json_file,
                              json_file_validation, generator_obj)

        def evaluate(step, msg_prefix):
            """Generate samples, score all metrics and log them at ``step``."""
            json_object = generator_obj.generate_samples(sess,
                                                         oracle_loader,
                                                         dataset="train")
            write_json(json_file, json_object)
            json_object = generator_obj.generate_samples(
                sess, oracle_loader, dataset="validation")
            write_json(json_file_validation, json_object)

            # Sentences shown in TensorBoard come from the validation samples.
            sent = take_sentences_attribute(json_object)
            gen_sentences_summary.write_summary(sent, step)

            scores = [metric.get_score() for metric in metrics]
            metrics_summary_str = sess.run(metric_summary_op,
                                           feed_dict=dict(
                                               zip(metrics_pl, scores)))
            sum_writer.add_summary(metrics_summary_str, step)

            msg = msg_prefix
            metric_names = [metric.get_name() for metric in metrics]
            for (name, score) in zip(metric_names, scores):
                msg += ', ' + name + ': %.4f' % score
            tqdm.write(msg)
            log.write(msg)
            log.write('\n')

            gc.collect()

        gc.collect()
        # Check if there is a pretrained generator saved
        model_dir = "PretrainGenerator"
        model_path = resources_path(
            os.path.join("checkpoint_folder", model_dir))
        try:
            # NOTE(review): import_meta_graph adds the saved graph's nodes to
            # the current default graph and returns a Saver for those; confirm
            # the restored weights reach generator_obj's variables (restoring
            # with `saver` above may have been intended).
            new_saver = tf.train.import_meta_graph(
                os.path.join(model_path, "model.ckpt.meta"))
            new_saver.restore(sess, os.path.join(model_path, "model.ckpt"))
            print("Used saved model for generator pretrain")
        except OSError:
            print('Start pre-training...')
            # Pre-train the generator using MLE, evaluating every ntest_pre
            # epochs.
            ntest_pre = 40
            progress = tqdm(range(npre_epochs))
            for epoch in progress:
                g_pretrain_loss_np = generator_obj.pretrain_epoch(
                    oracle_loader, sess, g_pretrain_op=g_pretrain_op)
                gen_pretrain_loss_summary.write_summary(
                    g_pretrain_loss_np, epoch)

                if np.mod(epoch, ntest_pre) == 0:
                    generator_obj.generated_num = 200
                    evaluate(
                        epoch, 'pre_gen_epoch:' + str(epoch) +
                        ', g_pre_loss: %.4f' % g_pretrain_loss_np)

        print('Start adversarial training...')
        progress = tqdm(range(nadv_steps))
        for adv_epoch in progress:
            gc.collect()
            niter = sess.run(global_step)

            t0 = time.time()
            # Adversarial training
            for _ in range(config['gsteps']):
                user, product, rating, _sentence = oracle_loader.random_batch(
                    dataset="train")
                sess.run(g_train_op,
                         feed_dict={
                             generator_obj.x_user: user,
                             generator_obj.x_product: product,
                             generator_obj.x_rating: rating
                         })
            for _ in range(config['dsteps']):
                user, product, rating, sentence = oracle_loader.random_batch(
                    dataset="train")
                sess.run(d_train_op,
                         feed_dict={
                             generator_obj.x_user: user,
                             generator_obj.x_product: product,
                             generator_obj.x_rating: rating,
                             x_real: padded(sentence)
                         })

            t1 = time.time()
            sess.run(update_Wall_op, feed_dict={time_diff: t1 - t0})

            # temperature
            temp_var_np = get_fixed_temperature(temper, niter, nadv_steps,
                                                adapt)
            sess.run(update_temperature_op, feed_dict={temp_var: temp_var_np})

            user, product, rating, sentence = oracle_loader.random_batch(
                dataset="train")
            g_loss_np, d_loss_np, loss_summary_str = sess.run(
                [g_loss, d_loss, loss_summary_op],
                feed_dict={
                    generator_obj.x_user: user,
                    generator_obj.x_product: product,
                    generator_obj.x_rating: rating,
                    x_real: padded(sentence)
                })
            sum_writer.add_summary(loss_summary_str, niter)

            sess.run(global_step_op)

            progress.set_description('g_loss: %4.4f, d_loss: %4.4f' %
                                     (g_loss_np, d_loss_np))

            # Test
            if np.mod(adv_epoch, 300) == 0 or adv_epoch == nadv_steps - 1:
                generator_obj.generated_num = generator_obj.batch_size * 10
                # g_pretrain_loss_np is unbound when the pretrained checkpoint
                # was restored (or npre_epochs == 0), so report the current
                # adversarial losses instead.
                evaluate(
                    adv_epoch + npre_epochs, 'adv_epoch:' + str(adv_epoch) +
                    ', g_loss: %.4f, d_loss: %.4f' % (g_loss_np, d_loss_np))

        sum_writer.close()

        save_model = True
        if save_model:
            model_dir = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            model_path = os.path.join(resources_path("trained_models"),
                                      model_dir)
            simple_save(sess,
                        model_path,
                        inputs={
                            "x_user": x_user,
                            "x_rating": x_rating,
                            "x_product": x_product
                        },
                        outputs={"gen_x": generator_obj.gen_x})
            # save_path = saver.save(sess, os.path.join(model_path, "model.ckpt"))
            print("Model saved in path: %s" % model_path)
Ejemplo n.º 30
0
def customer_reviews_train(generator: ReviewGenerator,
                           discriminator_positive: ReviewDiscriminator,
                           discriminator_negative: ReviewDiscriminator,
                           oracle_loader: RealDataCustomerReviewsLoader,
                           config, args):
    """Pre-train and adversarially train the sentiment-conditioned review GAN.

    Builds a generator conditioned on ``x_sentiment`` and two discriminators
    (one per sentiment). Each discriminator is applied to real positive,
    real negative and generated one-hot sequences. Then:

    1. restores a pre-trained generator from
       ``checkpoint_folder/PretrainGenerator`` when its meta graph can be
       opened, otherwise pre-trains with MLE for ``npre_epochs`` epochs;
    2. runs ``nadv_steps`` adversarial iterations, updating the temperature
       and logging losses and metrics.

    Metrics, losses and sample sentences go to TensorBoard summaries (in a
    timestamped sub-directory of ``<log_dir>/summary``) and to
    ``experiment-log-rmcgan.csv``.

    :param generator: generator constructor.
    :param discriminator_positive: constructor of the positive discriminator.
    :param discriminator_negative: constructor of the negative discriminator.
    :param oracle_loader: source of real review batches.
    :param config: experiment configuration dictionary.
    :param args: run arguments; ``str(args)`` is written to the summary.
    """
    batch_size = config['batch_size']
    vocab_size = config['vocabulary_size']
    seq_len = config['seq_len']
    npre_epochs = config['npre_epochs']
    nadv_steps = config['nadv_steps']
    temper = config['temperature']
    adapt = config['adapt']

    # changed to match resources path
    # NOTE(review): confirm "Amazon_Attribute" is the intended data
    # sub-directory for customer reviews.
    data_dir = resources_path(config['data_dir'], "Amazon_Attribute")
    log_dir = resources_path(config['log_dir'])
    sample_dir = resources_path(config['sample_dir'])

    # filename
    json_file = os.path.join(sample_dir, 'json_file.txt')
    csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv')

    # create necessary directories
    for directory in (data_dir, sample_dir, log_dir):
        os.makedirs(directory, exist_ok=True)

    # placeholder definitions
    x_real = tf.placeholder(tf.int32, [batch_size, seq_len], name="x_real")
    x_pos = tf.placeholder(tf.int32, [batch_size, seq_len], name="x_pos")
    x_neg = tf.placeholder(tf.int32, [batch_size, seq_len], name="x_neg")
    x_sentiment = tf.placeholder(tf.int32, [batch_size], name="x_sentiment")

    temperature = tf.Variable(1., trainable=False, name='temperature')

    x_real_pos_onehot = tf.one_hot(
        x_pos, vocab_size)  # batch_size x seq_len x vocab_size
    x_real_neg_onehot = tf.one_hot(
        x_neg, vocab_size)  # batch_size x seq_len x vocab_size
    assert x_real_pos_onehot.get_shape().as_list() == [
        batch_size, seq_len, vocab_size
    ]

    # generator and discriminator outputs
    generator_obj = generator(x_real=x_real,
                              temperature=temperature,
                              x_sentiment=x_sentiment)
    # discriminator for positive sentences
    discriminator_positive_real_pos = discriminator_positive(
        x_onehot=x_real_pos_onehot)
    discriminator_positive_real_neg = discriminator_positive(
        x_onehot=x_real_neg_onehot)
    discriminator_positive_fake = discriminator_positive(
        x_onehot=generator_obj.gen_x_onehot_adv)
    # discriminator for negative sentences
    discriminator_negative_real_pos = discriminator_negative(
        x_onehot=x_real_pos_onehot)
    discriminator_negative_real_neg = discriminator_negative(
        x_onehot=x_real_neg_onehot)
    discriminator_negative_fake = discriminator_negative(
        x_onehot=generator_obj.gen_x_onehot_adv)

    # GAN / Divergence type
    log_pg, g_loss, d_loss = get_losses(
        generator_obj, discriminator_positive_real_pos,
        discriminator_positive_real_neg, discriminator_positive_fake,
        discriminator_negative_real_pos, discriminator_negative_real_neg,
        discriminator_negative_fake)

    # Global step
    global_step = tf.Variable(0, trainable=False)
    global_step_op = global_step.assign_add(1)

    # Train ops (the fourth, topic-discriminator pretrain op is unused here).
    # NOTE(review): g_train_op is never run; the adversarial generator step
    # below runs g_pretrain_op (MLE) instead — confirm this is intended.
    g_pretrain_op, g_train_op, d_train_op, _ = get_train_ops(
        config, generator_obj.pretrain_loss, g_loss, d_loss, None, log_pg,
        temperature, global_step)

    # Record wall clock time
    time_diff = tf.placeholder(tf.float32)
    Wall_clock_time = tf.Variable(0., trainable=False)
    update_Wall_op = Wall_clock_time.assign_add(time_diff)

    # Temperature placeholder
    temp_var = tf.placeholder(tf.float32)
    update_temperature_op = temperature.assign(temp_var)

    # Loss summaries
    loss_summaries = [
        tf.summary.scalar('adv_loss/discriminator/total', d_loss),
        tf.summary.scalar('adv_loss/generator/total_g_loss', g_loss),
        tf.summary.scalar('adv_loss/log_pg', log_pg),
        tf.summary.scalar('adv_loss/Wall_clock_time', Wall_clock_time),
        tf.summary.scalar('adv_loss/temperature', temperature),
    ]
    loss_summary_op = tf.summary.merge(loss_summaries)

    # Metric Summaries
    metrics_pl, metric_summary_op = get_metric_summary_op(config)

    # Summaries
    gen_pretrain_loss_summary = CustomSummary(name='pretrain_loss',
                                              scope='generator')
    gen_sentences_summary = CustomSummary(name='generated_sentences',
                                          scope='generator',
                                          summary_type=tf.summary.text,
                                          item_type=tf.string)
    run_information = CustomSummary(name='run_information',
                                    scope='info',
                                    summary_type=tf.summary.text,
                                    item_type=tf.string)
    custom_summaries = [
        gen_pretrain_loss_summary, gen_sentences_summary, run_information
    ]

    def positive_negative_feed():
        """Fetch a batch with positive/negative references as a feed dict.

        Element ``[0]`` of each ``pos``/``neg`` entry is used (presumably the
        first reference sentence of that example — confirm with the loader).
        """
        sentiment, sentence, pos, neg = (
            oracle_loader.get_positive_negative_batch())
        shape = (generator_obj.batch_size, generator_obj.seq_len)
        real = np.zeros(shape)
        positive = np.zeros(shape)
        negative = np.zeros(shape)
        for row, (s, p, n) in enumerate(zip(sentence, pos, neg)):
            real[row] = s
            positive[row] = p[0]
            negative[row] = n[0]
        return {
            x_real: real,
            x_pos: positive,
            x_neg: negative,
            x_sentiment: sentiment
        }

    # ------------- initial the graph --------------
    # The CSV log is managed by the same `with` so it is always closed.
    with init_sess() as sess, open(csv_file, 'w') as log:
        summary_dir = os.path.join(log_dir, 'summary', str(time.time()))
        os.makedirs(summary_dir, exist_ok=True)
        sum_writer = tf.summary.FileWriter(summary_dir, sess.graph)
        for custom_summary in custom_summaries:
            custom_summary.set_file_writer(sum_writer, sess)

        run_information.write_summary(str(args), 0)
        print("Information stored in the summary!")

        def build_metrics():
            """Create the evaluation metrics enabled in ``config``."""
            metrics = []
            if config['nll_gen']:
                metrics.append(
                    NllReview(oracle_loader,
                              generator_obj,
                              sess,
                              name='nll_gen_review'))
            if config['KL']:
                metrics.append(
                    KL_divergence(oracle_loader,
                                  json_file,
                                  name='KL_divergence'))
            if config['jaccard_similarity']:
                metrics.append(
                    JaccardSimilarity(oracle_loader,
                                      json_file,
                                      name='jaccard_similarity'))
            if config['jaccard_diversity']:
                metrics.append(
                    JaccardDiversity(oracle_loader,
                                     json_file,
                                     name='jaccard_diversity'))
            return metrics

        metrics = build_metrics()
        generator_obj.generated_num = 200  # num_sentences

        def evaluate(step, msg_prefix):
            """Generate samples, score all metrics and log them at ``step``."""
            json_object = generator_obj.generate_json(oracle_loader, sess)
            write_json(json_file, json_object)

            # take sentences from saved files
            sent = take_sentences_json(json_object)
            gen_sentences_summary.write_summary(sent, step)

            scores = [metric.get_score() for metric in metrics]
            metrics_summary_str = sess.run(metric_summary_op,
                                           feed_dict=dict(
                                               zip(metrics_pl, scores)))
            sum_writer.add_summary(metrics_summary_str, step)

            msg = msg_prefix
            metric_names = [metric.get_name() for metric in metrics]
            for (name, score) in zip(metric_names, scores):
                msg += ', ' + name + ': %.4f' % score
            tqdm.write(msg)
            log.write(msg)
            log.write('\n')

            gc.collect()

        gc.collect()
        # Check if there is a pretrained generator saved
        model_dir = "PretrainGenerator"
        model_path = resources_path(
            os.path.join("checkpoint_folder", model_dir))
        try:
            new_saver = tf.train.import_meta_graph(
                os.path.join(model_path, "model.ckpt.meta"))
            new_saver.restore(sess, os.path.join(model_path, "model.ckpt"))
            print("Used saved model for generator pretrain")
        except OSError:
            print('Start pre-training...')
            # Pre-train the generator using MLE, evaluating every ntest_pre
            # epochs and after the last one.
            ntest_pre = 30
            progress = tqdm(range(npre_epochs))
            for epoch in progress:
                oracle_loader.reset_pointer()
                g_pretrain_loss_np = generator_obj.pretrain_epoch(
                    oracle_loader, sess, g_pretrain_op=g_pretrain_op)
                gen_pretrain_loss_summary.write_summary(
                    g_pretrain_loss_np, epoch)
                msg = 'pre_gen_epoch:' + str(
                    epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np
                progress.set_description(msg)

                if np.mod(epoch, ntest_pre) == 0 or epoch == npre_epochs - 1:
                    evaluate(epoch, msg)

        gc.collect()

        print('Start adversarial training...')
        progress = tqdm(range(nadv_steps))
        for adv_epoch in progress:
            gc.collect()
            niter = sess.run(global_step)

            t0 = time.time()
            # Adversarial training
            for _ in range(config['gsteps']):
                sentiment, sentence = oracle_loader.random_batch()
                real = np.zeros(
                    (generator_obj.batch_size, generator_obj.seq_len))
                for row, tokens in enumerate(sentence):
                    real[row] = tokens
                # NOTE(review): MLE op, not g_train_op — see note at get_train_ops.
                sess.run(g_pretrain_op,
                         feed_dict={
                             generator_obj.x_real: real,
                             generator_obj.x_sentiment: sentiment
                         })
            for _ in range(config['dsteps']):
                sess.run(d_train_op, feed_dict=positive_negative_feed())

            t1 = time.time()
            sess.run(update_Wall_op, feed_dict={time_diff: t1 - t0})

            # temperature
            temp_var_np = get_fixed_temperature(temper, niter, nadv_steps,
                                                adapt)
            sess.run(update_temperature_op, feed_dict={temp_var: temp_var_np})

            g_loss_np, d_loss_np, loss_summary_str = sess.run(
                [g_loss, d_loss, loss_summary_op],
                feed_dict=positive_negative_feed())
            sum_writer.add_summary(loss_summary_str, niter)

            sess.run(global_step_op)

            progress.set_description('g_loss: %4.4f, d_loss: %4.4f' %
                                     (g_loss_np, d_loss_np))

            # Test
            if np.mod(adv_epoch, 100) == 0 or adv_epoch == nadv_steps - 1:
                evaluate(niter + npre_epochs, 'adv_step: ' + str(niter))

        sum_writer.close()

        save_model = False
        if save_model:
            model_dir = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            model_path = os.path.join(resources_path("trained_models"),
                                      model_dir)
            # The original referenced undefined names (x_topic, x_fake).
            # NOTE(review): assumes ReviewGenerator exposes `gen_x` — confirm.
            simple_save(sess,
                        model_path,
                        inputs={"x_sentiment": x_sentiment},
                        outputs={"gen_x": generator_obj.gen_x})
            print("Model saved in path: %s" % model_path)
Ejemplo n.º 31
0
def real_topic_train(generator_obj, discriminator_obj, topic_discriminator_obj,
                     oracle_loader: RealDataTopicLoader, config, args):
    """Train a topic-conditioned text GAN on a real dataset.

    Builds the TF1 graph (placeholders, generator, sentence discriminator and,
    unless ``args.no_topic`` is set, a topic discriminator), then runs in a
    single session:

    1. generator MLE pre-training, skipped when a checkpoint is found under
       ``checkpoint_folder/PretrainGenerator``;
    2. topic-discriminator pre-training (only when topics are enabled);
    3. adversarial training for ``config['nadv_steps']`` steps.

    Losses and metrics go to a TensorBoard summary directory and to a CSV log
    file in ``config['log_dir']``.

    Args:
        generator_obj: callable building the generator. Its result is used
            for ``gen_x_onehot_adv``, ``pretrain_loss``, ``gen_x`` and the
            ``pretrain_epoch`` / ``generate_samples_topic`` /
            ``get_sentences`` helpers.
        discriminator_obj: callable building the sentence discriminator.
        topic_discriminator_obj: callable building the topic discriminator.
            Its outputs expose ``logits``.
        oracle_loader: loader yielding (text, topic) batches.
        config: dict of hyper-parameters and paths (keys read below).
        args: parsed CLI arguments. ``no_topic`` and ``summary_name`` are
            read, and ``str(args)`` is logged to the summary.

    Raises:
        NotImplementedError: if ``config['dataset']`` is neither
            ``'image_coco'`` nor ``'emnlp_news'``.
    """
    batch_size = config['batch_size']
    num_sentences = config['num_sentences']
    vocab_size = config['vocab_size']
    seq_len = config['seq_len']
    dataset = config['dataset']
    npre_epochs = config['npre_epochs']
    n_topic_pre_epochs = config['n_topic_pre_epochs']
    nadv_steps = config['nadv_steps']
    temper = config['temperature']
    adapt = config['adapt']

    # changed to match resources path
    data_dir = resources_path(config['data_dir'])
    log_dir = resources_path(config['log_dir'])
    sample_dir = resources_path(config['sample_dir'])

    # filename
    oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(dataset))
    gen_text_file = os.path.join(sample_dir, 'generator_text.txt')
    json_file = os.path.join(sample_dir, 'json_file.txt')
    csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv')
    if dataset == 'image_coco':
        test_file = os.path.join(data_dir, 'testdata/test_coco.txt')
    elif dataset == 'emnlp_news':
        test_file = os.path.join(data_dir, 'testdata/test_emnlp.txt')
    else:
        raise NotImplementedError('Unknown dataset!')

    # create necessary directories
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # placeholder definitions
    x_real = tf.placeholder(tf.int32, [batch_size, seq_len],
                            name="x_real")  # tokens of oracle sequences
    # TODO: same open question about the "+1" in the topic width (original
    # note: "stessa cosa del +1").
    x_topic = tf.placeholder(tf.float32,
                             [batch_size, oracle_loader.vocab_size + 1],
                             name="x_topic")
    x_topic_random = tf.placeholder(tf.float32,
                                    [batch_size, oracle_loader.vocab_size + 1],
                                    name="x_topic_random")

    temperature = tf.Variable(1., trainable=False, name='temperature')

    x_real_onehot = tf.one_hot(x_real,
                               vocab_size)  # batch_size x seq_len x vocab_size
    assert x_real_onehot.get_shape().as_list() == [
        batch_size, seq_len, vocab_size
    ]

    # generator and discriminator outputs
    generator = generator_obj(x_real=x_real,
                              temperature=temperature,
                              x_topic=x_topic)
    d_real = discriminator_obj(x_onehot=x_real_onehot)
    d_fake = discriminator_obj(x_onehot=generator.gen_x_onehot_adv)
    if not args.no_topic:
        # Real text paired with its own topic (positive) and with a random
        # topic (negative), plus generated text paired with the real topic.
        d_topic_real_pos = topic_discriminator_obj(x_onehot=x_real_onehot,
                                                   x_topic=x_topic)
        d_topic_real_neg = topic_discriminator_obj(x_onehot=x_real_onehot,
                                                   x_topic=x_topic_random)
        d_topic_fake = topic_discriminator_obj(
            x_onehot=generator.gen_x_onehot_adv, x_topic=x_topic)
    else:
        d_topic_real_pos = None
        d_topic_real_neg = None
        d_topic_fake = None

    # GAN / Divergence type
    losses = get_losses(generator, d_real, d_fake, d_topic_real_pos,
                        d_topic_real_neg, d_topic_fake, config)
    if not args.no_topic:
        d_topic_loss = losses['d_topic_loss_real_pos'] + losses[
            'd_topic_loss_real_neg']  # only from real data for pretrain
        d_topic_accuracy = get_accuracy(d_topic_real_pos.logits,
                                        d_topic_real_neg.logits)
    else:
        d_topic_loss = None
        d_topic_accuracy = None

    # Global step
    global_step = tf.Variable(0, trainable=False)
    global_step_op = global_step.assign_add(1)

    # Train ops
    g_pretrain_op, g_train_op, d_train_op, d_topic_pretrain_op = get_train_ops(
        config, generator.pretrain_loss, losses['g_loss'], losses['d_loss'],
        d_topic_loss, losses['log_pg'], temperature, global_step)
    generator.g_pretrain_op = g_pretrain_op

    # Record wall clock time
    time_diff = tf.placeholder(tf.float32)
    Wall_clock_time = tf.Variable(0., trainable=False)
    update_Wall_op = Wall_clock_time.assign_add(time_diff)

    # Temperature placeholder
    temp_var = tf.placeholder(tf.float32)
    update_temperature_op = temperature.assign(temp_var)

    # Loss summaries
    loss_summaries = [
        tf.summary.scalar('adv_loss/discriminator/classic/d_loss_real',
                          losses['d_loss_real']),
        tf.summary.scalar('adv_loss/discriminator/classic/d_loss_fake',
                          losses['d_loss_fake']),
        tf.summary.scalar('adv_loss/discriminator/total', losses['d_loss']),
        tf.summary.scalar('adv_loss/generator/g_sentence_loss',
                          losses['g_sentence_loss']),
        tf.summary.scalar('adv_loss/generator/total_g_loss', losses['g_loss']),
        tf.summary.scalar('adv_loss/log_pg', losses['log_pg']),
        tf.summary.scalar('adv_loss/Wall_clock_time', Wall_clock_time),
        tf.summary.scalar('adv_loss/temperature', temperature),
    ]
    if not args.no_topic:
        loss_summaries += [
            tf.summary.scalar(
                'adv_loss/discriminator/topic_discriminator/d_topic_loss_real_pos',
                losses['d_topic_loss_real_pos']),
            tf.summary.scalar(
                'adv_loss/discriminator/topic_discriminator/d_topic_loss_fake',
                losses['d_topic_loss_fake']),
            tf.summary.scalar('adv_loss/generator/g_topic_loss',
                              losses['g_topic_loss'])
        ]

    loss_summary_op = tf.summary.merge(loss_summaries)

    # Metric Summaries
    metrics_pl, metric_summary_op = get_metric_summary_op(config)

    # Summaries
    gen_pretrain_loss_summary = CustomSummary(name='pretrain_loss',
                                              scope='generator')
    gen_sentences_summary = CustomSummary(name='generated_sentences',
                                          scope='generator',
                                          summary_type=tf.summary.text,
                                          item_type=tf.string)
    topic_discr_pretrain_summary = CustomSummary(name='pretrain_loss',
                                                 scope='topic_discriminator')
    topic_discr_accuracy_summary = CustomSummary(name='pretrain_accuracy',
                                                 scope='topic_discriminator')
    run_information = CustomSummary(name='run_information',
                                    scope='info',
                                    summary_type=tf.summary.text,
                                    item_type=tf.string)
    custom_summaries = [
        gen_pretrain_loss_summary, gen_sentences_summary,
        topic_discr_pretrain_summary, topic_discr_accuracy_summary,
        run_information
    ]

    # To save the trained model
    saver = tf.compat.v1.train.Saver()

    # ------------- initial the graph --------------
    # The CSV log is opened together with the session so it is flushed and
    # closed when training ends, including when an exception escapes.
    with init_sess() as sess, open(csv_file, 'w') as log:
        # Return value is unused here; the call is kept in case it has side
        # effects.
        get_parameters_division()

        print("Total paramter number: {}".format(
            np.sum([
                np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()
            ])))

        now = datetime.datetime.now()
        additional_text = now.strftime(
            "%Y-%m-%d_%H-%M") + "_" + args.summary_name
        summary_dir = os.path.join(log_dir, 'summary', additional_text)
        if not os.path.exists(summary_dir):
            os.makedirs(summary_dir)
        sum_writer = tf.compat.v1.summary.FileWriter(os.path.join(summary_dir),
                                                     sess.graph)
        for custom_summary in custom_summaries:
            custom_summary.set_file_writer(sum_writer, sess)

        run_information.write_summary(str(args), 0)
        print("Information stored in the summary!")

        # generate oracle data and create batches
        oracle_loader.create_batches(oracle_file)

        metrics = get_metrics(config, oracle_loader, test_file, gen_text_file,
                              generator.pretrain_loss, x_real, x_topic, sess,
                              json_file)

        gc.collect()

        # Check if there is a pretrained generator saved
        model_dir = "PretrainGenerator"
        model_path = resources_path(
            os.path.join("checkpoint_folder", model_dir))
        try:
            new_saver = tf.train.import_meta_graph(
                os.path.join(model_path, "model.ckpt.meta"))
            new_saver.restore(sess, os.path.join(model_path, "model.ckpt"))
            print("Used saved model for generator pretrain")
        except OSError:
            # No usable checkpoint: pre-train the generator from scratch.
            print('Start pre-training...')
            progress = tqdm(range(npre_epochs), dynamic_ncols=True)
            for epoch in progress:
                # pre-training
                g_pretrain_loss_np = generator.pretrain_epoch(
                    sess, oracle_loader)
                gen_pretrain_loss_summary.write_summary(
                    g_pretrain_loss_np, epoch)

                # Test
                ntest_pre = 30
                if np.mod(epoch, ntest_pre) == 0:
                    json_object = generator.generate_samples_topic(
                        sess, oracle_loader, num_sentences)
                    write_json(json_file, json_object)

                    # take sentences from saved files
                    sent = generator.get_sentences(json_object)
                    gen_sentences_summary.write_summary(sent, epoch)

                    # write summaries
                    t = time.time()
                    scores = [metric.get_score() for metric in metrics]
                    metrics_summary_str = sess.run(metric_summary_op,
                                                   feed_dict=dict(
                                                       zip(metrics_pl,
                                                           scores)))
                    sum_writer.add_summary(metrics_summary_str, epoch)

                    msg = 'pre_gen_epoch:' + str(
                        epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np
                    metric_names = [metric.get_name() for metric in metrics]
                    for (name, score) in zip(metric_names, scores):
                        # 'Earth' scores are scaled by 1e5 for display.
                        score = score * 1e5 if 'Earth' in name else score
                        msg += ', ' + name + ': %.4f' % score
                    progress.set_description(
                        msg + " in {:.2f} sec".format(time.time() - t))
                    log.write(msg)
                    log.write('\n')

                    gc.collect()

        if not args.no_topic:
            gc.collect()
            print('Start Topic Discriminator pre-training...')
            progress = tqdm(range(n_topic_pre_epochs))
            for epoch in progress:
                # Pre-train the topic discriminator for one epoch on real text
                # with matching and random topics.
                supervised_g_losses = []
                supervised_accuracy = []
                oracle_loader.reset_pointer()

                for it in range(oracle_loader.num_batch):
                    text_batch, topic_batch = oracle_loader.next_batch(
                        only_text=False)
                    _, topic_loss, accuracy = sess.run(
                        [d_topic_pretrain_op, d_topic_loss, d_topic_accuracy],
                        feed_dict={
                            x_real: text_batch,
                            x_topic: topic_batch,
                            x_topic_random: oracle_loader.random_topic()
                        })
                    supervised_g_losses.append(topic_loss)
                    supervised_accuracy.append(accuracy)

                d_topic_pretrain_loss = np.mean(supervised_g_losses)
                accuracy_mean = np.mean(supervised_accuracy)
                topic_discr_pretrain_summary.write_summary(
                    d_topic_pretrain_loss, epoch)
                topic_discr_accuracy_summary.write_summary(
                    accuracy_mean, epoch)
                progress.set_description(
                    'topic_loss: %4.4f, accuracy: %4.4f' %
                    (d_topic_pretrain_loss, accuracy_mean))

        print('Start adversarial training...')
        progress = tqdm(range(nadv_steps))
        for adv_epoch in progress:
            gc.collect()
            niter = sess.run(global_step)

            t0 = time.time()
            # Adversarial training
            for _ in range(config['gsteps']):
                text_batch, topic_batch = oracle_loader.random_batch(
                    only_text=False)
                sess.run(g_train_op,
                         feed_dict={
                             x_real: text_batch,
                             x_topic: topic_batch
                         })
            for _ in range(config['dsteps']):
                # normal + topic discriminator together
                text_batch, topic_batch = oracle_loader.random_batch(
                    only_text=False)
                sess.run(d_train_op,
                         feed_dict={
                             x_real: text_batch,
                             x_topic: topic_batch,
                             x_topic_random: oracle_loader.random_topic()
                         })

            t1 = time.time()
            sess.run(update_Wall_op, feed_dict={time_diff: t1 - t0})

            # temperature
            temp_var_np = get_fixed_temperature(temper, niter, nadv_steps,
                                                adapt)
            sess.run(update_temperature_op, feed_dict={temp_var: temp_var_np})

            text_batch, topic_batch = oracle_loader.random_batch(
                only_text=False)
            feed = {
                x_real: text_batch,
                x_topic: topic_batch,
                x_topic_random: oracle_loader.random_topic()
            }
            g_loss_np, d_loss_np, loss_summary_str = sess.run(
                [losses['g_loss'], losses['d_loss'], loss_summary_op],
                feed_dict=feed)
            sum_writer.add_summary(loss_summary_str, niter)

            sess.run(global_step_op)

            progress.set_description('g_loss: %4.4f, d_loss: %4.4f' %
                                     (g_loss_np, d_loss_np))

            # Test
            if np.mod(adv_epoch, 400) == 0 or adv_epoch == nadv_steps - 1:
                json_object = generator.generate_samples_topic(
                    sess, oracle_loader, num_sentences)
                write_json(json_file, json_object)

                # Text and metric summaries in this phase are offset by the
                # pre-training epochs so they don't share steps with the
                # pre-training summaries written above.
                summary_step = niter + npre_epochs

                # take sentences from saved files
                sent = generator.get_sentences(json_object)
                gen_sentences_summary.write_summary(sent, summary_step)

                # write summaries
                scores = [metric.get_score() for metric in metrics]
                metrics_summary_str = sess.run(metric_summary_op,
                                               feed_dict=dict(
                                                   zip(metrics_pl, scores)))
                sum_writer.add_summary(metrics_summary_str, summary_step)

                msg = 'adv_step: ' + str(niter)
                metric_names = [metric.get_name() for metric in metrics]
                # NOTE(review): unlike the pre-training log lines, 'Earth'
                # scores are not scaled by 1e5 here, so the CSV mixes units for
                # that metric across phases -- confirm which is intended.
                for (name, score) in zip(metric_names, scores):
                    msg += ', ' + name + ': %.4f' % score
                tqdm.write(msg)
                log.write(msg)
                log.write('\n')

        sum_writer.close()

        save_model = False
        if save_model:
            model_dir = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            model_path = os.path.join(resources_path("trained_models"),
                                      model_dir)
            simple_save(sess,
                        model_path,
                        inputs={"x_topic": x_topic},
                        outputs={"gen_x": generator.gen_x})
            # save_path = saver.save(sess, os.path.join(model_path, "model.ckpt"))
            print("Model saved in path: %s" % model_path)
Ejemplo n.º 32
0
def real_topic_train_NoDiscr(generator, oracle_loader: RealDataTopicLoader,
                             config, args):
    """Pre-train a topic-conditioned generator by MLE, without discriminators.

    Builds the generator graph and runs ``config['npre_epochs']`` pre-training
    epochs. Every 40 epochs it samples sentences, dumps them to JSON and to a
    plain-text file capped at 200 sentences, and computes and logs metrics.
    Finally it exports the session with ``simple_save`` to
    ``trained_models/<timestamp>``, with input ``x_topic`` and output
    ``x_fake``.

    Args:
        generator: callable returning the 6-tuple ``(x_fake_onehot_appr,
            x_fake, g_pretrain_loss, gen_o, lambda_values_returned,
            gen_x_no_lambda)``.
        oracle_loader: loader yielding (text, topic) batches.
        config: dict of hyper-parameters and paths (keys read below).
        args: parsed CLI arguments. ``json_file`` selects the sample-file
            suffix, and ``str(args)`` is logged to the summary.

    Raises:
        NotImplementedError: if ``config['dataset']`` is neither
            ``'image_coco'`` nor ``'emnlp_news'``.
    """
    batch_size = config['batch_size']
    num_sentences = config['num_sentences']
    vocab_size = config['vocab_size']
    seq_len = config['seq_len']
    dataset = config['dataset']
    npre_epochs = config['npre_epochs']

    # changed to match resources path
    data_dir = resources_path(config['data_dir'])
    log_dir = resources_path(config['log_dir'])
    sample_dir = resources_path(config['sample_dir'])

    # filename
    oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(dataset))
    gen_text_file = os.path.join(sample_dir, 'generator_text.txt')
    json_file = os.path.join(sample_dir,
                             'json_file{}.txt'.format(args.json_file))
    csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv')
    if dataset == 'image_coco':
        test_file = os.path.join(data_dir, 'testdata/test_coco.txt')
    elif dataset == 'emnlp_news':
        test_file = os.path.join(data_dir, 'testdata/test_emnlp.txt')
    else:
        raise NotImplementedError('Unknown dataset!')

    # create necessary directories
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # placeholder definitions
    x_real = placeholder(tf.int32, [batch_size, seq_len],
                         name="x_real")  # tokens of oracle sequences
    # TODO: same open question about the "+1" in the topic width (original
    # note: "stessa cosa del +1").
    x_topic = placeholder(tf.float32,
                          [batch_size, oracle_loader.vocab_size + 1],
                          name="x_topic")
    # Not fed or used below; kept so the graph keeps this named placeholder.
    x_topic_random = placeholder(tf.float32,
                                 [batch_size, oracle_loader.vocab_size + 1],
                                 name="x_topic_random")

    temperature = tf.Variable(1., trainable=False, name='temperature')

    x_real_onehot = tf.one_hot(x_real,
                               vocab_size)  # batch_size x seq_len x vocab_size
    assert x_real_onehot.get_shape().as_list() == [
        batch_size, seq_len, vocab_size
    ]

    # generator outputs
    x_fake_onehot_appr, x_fake, g_pretrain_loss, gen_o, \
    lambda_values_returned, gen_x_no_lambda = generator(x_real=x_real,
                                                        temperature=temperature,
                                                        x_topic=x_topic)

    # A function to calculate the gradients and get training operations
    def get_train_ops(config, g_pretrain_loss):
        """Return an Adam op minimizing ``g_pretrain_loss`` over 'generator' vars.

        Gradients are clipped to a global norm of 5.0 before being applied.
        """
        gpre_lr = config['gpre_lr']

        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='generator')
        grad_clip = 5.0  # keep the same with the previous setting

        # generator pre-training
        pretrain_opt = tf.train.AdamOptimizer(gpre_lr,
                                              beta1=0.9,
                                              beta2=0.999,
                                              name="gen_pretrain_adam")
        pretrain_grad, _ = tf.clip_by_global_norm(
            tf.gradients(g_pretrain_loss, g_vars, name="gradients_g_pretrain"),
            grad_clip,
            name="g_pretrain_clipping")  # gradient clipping
        g_pretrain_op = pretrain_opt.apply_gradients(zip(
            pretrain_grad, g_vars))

        return g_pretrain_op

    # Train ops
    g_pretrain_op = get_train_ops(config, g_pretrain_loss)

    # Metric Summaries
    metrics_pl, metric_summary_op = get_metric_summary_op(config)

    # Summaries
    gen_pretrain_loss_summary = CustomSummary(name='pretrain_loss',
                                              scope='generator')
    gen_sentences_summary = CustomSummary(name='generated_sentences',
                                          scope='generator',
                                          summary_type=tf.summary.text,
                                          item_type=tf.string)
    run_information = CustomSummary(name='run_information',
                                    scope='info',
                                    summary_type=tf.summary.text,
                                    item_type=tf.string)
    custom_summaries = [
        gen_pretrain_loss_summary, gen_sentences_summary, run_information
    ]

    # Maximum number of generated sentences dumped to ``gen_text_file``.
    max_text_sentences = 200

    # ------------- initial the graph --------------
    # The CSV log is opened together with the session so it is flushed and
    # closed when training ends, including when an exception escapes.
    with init_sess() as sess, open(csv_file, 'w') as log:
        # Return value is unused here; the call is kept in case it has side
        # effects.
        get_parameters_division()

        sum_writer = tf.summary.FileWriter(os.path.join(log_dir, 'summary'),
                                           sess.graph)
        for custom_summary in custom_summaries:
            custom_summary.set_file_writer(sum_writer, sess)

        run_information.write_summary(str(args), 0)
        print("Information stored in the summary!")

        oracle_loader.create_batches(oracle_file)

        metrics = get_metrics(config, oracle_loader, test_file, gen_text_file,
                              g_pretrain_loss, x_real, x_topic, sess,
                              json_file)

        gc.collect()
        # Check if there is a pretrained generator saved
        model_dir = "PretrainGenerator"
        model_path = resources_path(
            os.path.join("checkpoint_folder", model_dir))
        try:
            new_saver = tf.train.import_meta_graph(
                os.path.join(model_path, "model.ckpt.meta"))
            new_saver.restore(sess, os.path.join(model_path, "model.ckpt"))
            print("Used saved model for generator pretrain")
        except OSError:
            print('Start pre-training...')

        # NOTE(review): this loop runs whether or not the checkpoint above was
        # restored. real_topic_train only pre-trains when no checkpoint is
        # found -- confirm that continuing pre-training here is intended.
        progress = tqdm(range(npre_epochs))
        for epoch in progress:
            # pre-training
            g_pretrain_loss_np = pre_train_epoch(sess, g_pretrain_op,
                                                 g_pretrain_loss, x_real,
                                                 oracle_loader, x_topic)
            gen_pretrain_loss_summary.write_summary(g_pretrain_loss_np, epoch)
            progress.set_description(
                "Pretrain_loss: {}".format(g_pretrain_loss_np))

            # Test
            ntest_pre = 40
            if np.mod(epoch, ntest_pre) == 0:
                json_object = generate_sentences(sess,
                                                 x_fake,
                                                 batch_size,
                                                 num_sentences,
                                                 oracle_loader=oracle_loader,
                                                 x_topic=x_topic)
                write_json(json_file, json_object)

                # Dump only the first ``max_text_sentences`` generated
                # sentences as plain text, one per line.
                with open(gen_text_file, 'w') as outfile:
                    for i, generated in enumerate(json_object['sentences']):
                        if i >= max_text_sentences:
                            break
                        outfile.write(generated['generated_sentence'] + "\n")

                # take sentences from saved files
                sent = take_sentences_json(json_object,
                                           first_elem='generated_sentence',
                                           second_elem='real_starting')
                gen_sentences_summary.write_summary(sent, epoch)

                # write summaries
                scores = [metric.get_score() for metric in metrics]
                metrics_summary_str = sess.run(metric_summary_op,
                                               feed_dict=dict(
                                                   zip(metrics_pl, scores)))
                sum_writer.add_summary(metrics_summary_str, epoch)

                msg = 'pre_gen_epoch:' + str(
                    epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np
                metric_names = [metric.get_name() for metric in metrics]
                for (name, score) in zip(metric_names, scores):
                    msg += ', ' + name + ': %.4f' % score
                tqdm.write(msg)
                log.write(msg)
                log.write('\n')

                gc.collect()

        gc.collect()
        sum_writer.close()

        save_model = True
        if save_model:
            model_dir = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            model_path = os.path.join(resources_path("trained_models"),
                                      model_dir)
            simple_save(sess,
                        model_path,
                        inputs={"x_topic": x_topic},
                        outputs={"gen_x": x_fake})
            # save_path = saver.save(sess, os.path.join(model_path, "model.ckpt"))
            print("Model saved in path: %s" % model_path)
Ejemplo n.º 33
0
def test_task(
        model_file: InputBinaryFile(str),
        examples_file: InputBinaryFile(str),
        confusion_matrix: OutputTextFile(str),
        results: OutputTextFile(str),
):
    """Connects to served model and tests example MNIST images.

    Exports the Keras model read from ``model_file`` as a SavedModel under
    ``/output/mnist/1/``. It then polls the model server at
    ``localhost:9001`` (up to 60 attempts, 5 s apart) until version 1 is
    reachable and checks the served signature metadata. Finally it posts the
    100 validation images from ``examples_file`` (an ``.npz`` with ``val_x``
    and ``val_y``) to the ``:predict`` endpoint and prints the accuracy.

    Every HTTP request has a timeout, so a hung connection cannot stall the
    task beyond the polling budget.

    Raises:
        RuntimeError: if the model server is not reachable in time.
        AssertionError: if the served metadata or the example shapes differ
            from the expected values.
        requests.HTTPError: if the metadata or predict request fails.
    """
    # NOTE(review): ``confusion_matrix`` and ``results`` are declared as
    # outputs but are never written here; confirm whether downstream steps
    # expect content in them.

    # Imports live inside the function so it stays self-contained when
    # shipped as a standalone component.
    import time

    import numpy as np
    import requests
    from tensorflow.python.keras.backend import get_session
    from tensorflow.python.keras.saving import load_model
    from tensorflow.python.saved_model.simple_save import simple_save

    # Seconds to wait for a single HTTP response before giving up on it.
    probe_timeout = 10
    predict_timeout = 60

    with get_session() as sess:
        model = load_model(model_file)
        simple_save(
            sess,
            '/output/mnist/1/',
            inputs={'input_image': model.input},
            outputs={t.name: t
                     for t in model.outputs},
        )

    model_url = 'http://localhost:9001/v1/models/mnist'

    for _ in range(60):
        try:
            requests.get(f'{model_url}/versions/1',
                         timeout=probe_timeout).raise_for_status()
            break
        except requests.RequestException:
            # Covers connection errors, timeouts and non-2xx responses.
            time.sleep(5)
    else:
        raise RuntimeError("Waited too long for sidecar to come up!")

    response = requests.get(f'{model_url}/metadata', timeout=probe_timeout)
    response.raise_for_status()
    assert response.json() == {
        'model_spec': {
            'name': 'mnist',
            'signature_name': '',
            'version': '1'
        },
        'metadata': {
            'signature_def': {
                'signature_def': {
                    'serving_default': {
                        'inputs': {
                            'input_image': {
                                'dtype': 'DT_FLOAT',
                                'tensor_shape': {
                                    'dim': [
                                        {
                                            'size': '-1',
                                            'name': ''
                                        },
                                        {
                                            'size': '28',
                                            'name': ''
                                        },
                                        {
                                            'size': '28',
                                            'name': ''
                                        },
                                        {
                                            'size': '1',
                                            'name': ''
                                        },
                                    ],
                                    'unknown_rank':
                                    False,
                                },
                                'name': 'conv2d_input:0',
                            }
                        },
                        'outputs': {
                            'dense_1/Softmax:0': {
                                'dtype': 'DT_FLOAT',
                                'tensor_shape': {
                                    'dim': [{
                                        'size': '-1',
                                        'name': ''
                                    }, {
                                        'size': '10',
                                        'name': ''
                                    }],
                                    'unknown_rank':
                                    False,
                                },
                                'name': 'dense_1/Softmax:0',
                            }
                        },
                        'method_name': 'tensorflow/serving/predict',
                    }
                }
            }
        },
    }

    # Materialize the arrays and close the NpzFile immediately.
    with np.load(examples_file) as examples:
        val_x = examples['val_x']
        val_y = examples['val_y']
    assert val_x.shape == (100, 28, 28, 1)
    assert val_y.shape == (100, 10)

    response = requests.post(f'{model_url}:predict',
                             json={'instances': val_x.tolist()},
                             timeout=predict_timeout)
    response.raise_for_status()

    predicted = np.argmax(response.json()['predictions'], axis=1).tolist()
    actual = np.argmax(val_y, axis=1).tolist()
    accuracy = sum(1 for p, a in zip(predicted, actual) if p == a) / len(
        predicted)

    print(f"Accuracy: {accuracy:0.2f}")
Ejemplo n.º 34
0
                     y: batch_y,
                     seqlen: batch_seqlen
                 })
        if step % display_step == 0 or step == 1:
            # Calculate batch accuracy & loss
            acc, loss = sess.run([accuracy, cost],
                                 feed_dict={
                                     x: batch_x,
                                     y: batch_y,
                                     seqlen: batch_seqlen
                                 })
            print("Step " + str(step*batch_size) + ", Minibatch Loss= " + \
                  "{:.6f}".format(loss) + ", Training Accuracy= " + \
                  "{:.5f}".format(acc))

    print("Optimization Finished!")

    # Calculate accuracy
    test_data = testset.data
    test_label = testset.labels
    test_seqlen = testset.seqlen
    print("Testing Accuracy:", \
        sess.run(accuracy, feed_dict={x: test_data, y: test_label,
                                      seqlen: test_seqlen}))

    print("Saving the model")
    simple_save(sess,
                export_dir='./saved_dynamic_rnn',
                inputs={"inp": x},
                outputs={"out": prediction})
Ejemplo n.º 35
0
                    [train_op, cost, weighted_action_loss],
                    feed_dict={
                        food: x_train[i:i + step],
                        individual_values: individual_values_train[i:i + step],
                        reward: reward_train[i:i + step],
                        actions_performed: actions_train[i:i + step],
                        actions_target: actions_target_train,
                        next_pred_reward: next_y_pred_v,
                        keep_prob: 0.99
                    })

                print("cost_train: " + str(cost_train) + " reward_loss_v: " +
                      str(reward_loss_v) + " weighted_action_loss_v: " +
                      str(weighted_action_loss_v))

        shutil.rmtree('model', ignore_errors=True)

        simple_save(sess,
                    "model",
                    inputs={"input": food},
                    outputs={
                        "action_pred": action_pred,
                        "reward_pred": reward_pred
                    })
        save_path = saver.save(sess, "model_tmp/model.ckpt")
        r.flushall()
        #action_pred_v = sess.run(
        #    [action_pred],
        #    feed_dict={food: x_train, individual_values: individual_values_train, keep_prob: 1.0})
        #print(action_pred_v)