Example #1
0
def train_agent(iterations, modeldir, logdir):
    """Train the policy-gradient agent and export it as a TFLite model.

    Each iteration plays one game with the current parameters via
    ``common.play_game``, converts the result log into rewards with
    ``common.compute_rewards`` and applies one ``train_step``. After
    training, the JAX predict function is converted through ``jax2tf`` and
    written to ``<modeldir>/planestrike.tflite``.

    Args:
        iterations: Number of play-then-train iterations to run.
        modeldir: Directory that receives ``planestrike.tflite``. It is
            created if it does not exist yet.
        logdir: Directory passed to the TensorBoard summary writer.
    """
    # Create the output directory up front so an unusable path fails before
    # training starts instead of after every iteration has already run.
    # An empty modeldir means "current directory" to os.path.join, so there
    # is nothing to create in that case.
    if modeldir:
        os.makedirs(modeldir, exist_ok=True)

    _, init_rng = random.split(random.PRNGKey(0))
    policygradient = PolicyGradient()
    params = policygradient.init(
        init_rng, jnp.ones([1, common.BOARD_SIZE,
                            common.BOARD_SIZE]))['params']
    optimizer = create_optimizer(model_params=params,
                                 learning_rate=LEARNING_RATE)

    summary_writer = tensorboard.SummaryWriter(logdir)
    try:
        # Main training loop
        progress_bar = tf.keras.utils.Progbar(iterations)
        for i in range(iterations):
            # Bind the latest parameters so the game is played with the
            # policy produced by the previous training step.
            predict_fn = functools.partial(run_inference, optimizer.target)
            board_log, action_log, result_log = common.play_game(predict_fn)
            rewards = common.compute_rewards(result_log)
            summary_writer.scalar('game_length', len(board_log), i)
            optimizer = train_step(optimizer, board_log, action_log, rewards)

            summary_writer.flush()
            progress_bar.add(1)
    finally:
        # Release the writer even when a game or a training step raises.
        summary_writer.close()

    # Convert to tflite model
    trained_params = optimizer.target

    def jax_predict_fn(boards):
        """Apply the trained policy network to a batch of boards."""
        return policygradient.apply({'params': trained_params}, boards)

    tf_predict = tf.function(
        jax2tf.convert(jax_predict_fn, enable_xla=False),
        input_signature=[
            tf.TensorSpec(shape=[1, common.BOARD_SIZE, common.BOARD_SIZE],
                          dtype=tf.float32,
                          name='input')
        ],
        autograph=False)

    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [tf_predict.get_concrete_function()], tf_predict)

    tflite_model = converter.convert()

    # Save the model
    with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f:
        f.write(tflite_model)

    print('TFLite model generated!')
Example #2
0
def train_agent(iterations, modeldir, logdir):
  """Train the Keras policy network and export it as a TFLite model.

  Each iteration plays one game with the current model via
  ``common.play_game``, converts the result log into rewards with
  ``common.compute_rewards`` and fits the model for one epoch on the logged
  boards and actions, using the rewards as sample weights. After training,
  the model is converted and written to ``<modeldir>/planestrike.tflite``.

  Args:
    iterations: Number of play-then-fit iterations to run.
    modeldir: Directory that receives ``planestrike.tflite``. It is created
      if it does not exist yet.
    logdir: Directory passed to ``tf.summary.create_file_writer``.
  """
  # Create the output directory up front so an unusable path fails before
  # training starts instead of after every iteration has already run.
  # An empty modeldir means "current directory" to os.path.join, so there is
  # nothing to create in that case.
  if modeldir:
    os.makedirs(modeldir, exist_ok=True)

  model = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(
          input_shape=(common.BOARD_SIZE, common.BOARD_SIZE)),
      tf.keras.layers.Dense(2 * common.BOARD_SIZE**2, activation='relu'),
      tf.keras.layers.Dense(common.BOARD_SIZE**2, activation='relu'),
      tf.keras.layers.Dense(common.BOARD_SIZE**2, activation='softmax')
  ])

  sgd = tf.keras.optimizers.SGD(learning_rate=LEARNING_RATE)

  model.compile(loss='sparse_categorical_crossentropy', optimizer=sgd)

  def predict_fn(board):
    """Return the current model's prediction for the given board input."""
    return model.predict(board)

  summary_writer = tf.summary.create_file_writer(logdir)
  try:
    # Main training loop
    progress_bar = tf.keras.utils.Progbar(iterations)
    for i in range(iterations):
      board_log, action_log, result_log = common.play_game(predict_fn)
      with summary_writer.as_default():
        tf.summary.scalar('game_length', len(action_log), step=i)
      rewards = common.compute_rewards(result_log)

      # The rewards weight each logged (board, action) sample in the loss.
      model.fit(
          x=board_log,
          y=action_log,
          batch_size=1,
          verbose=0,
          epochs=1,
          sample_weight=rewards)

      summary_writer.flush()
      progress_bar.add(1)
  finally:
    # Release the writer even when a game or a fit call raises.
    summary_writer.close()

  # Convert to tflite model
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  tflite_model = converter.convert()

  # Save the model
  with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f:
    f.write(tflite_model)

  print('TFLite model generated!')