def train_agent(iterations, modeldir, logdir):
    """Train the policy-gradient agent and export it as a TFLite model.

    Each iteration plays one game through ``common.play_game`` using the
    current parameters, logs the game length, and applies one ``train_step``
    update. After training, the model is converted with jax2tf and the TFLite
    converter and written to ``<modeldir>/planestrike.tflite``.

    Args:
      iterations: Number of games to play; one training step follows each
        game.
      modeldir: Directory that receives ``planestrike.tflite``. It is created
        if it does not exist yet.
      logdir: Directory handed to the TensorBoard summary writer.
    """
    # Create the output directory up front so a bad path fails before training
    # rather than after it. An empty modeldir means "current directory", which
    # os.makedirs would reject, so it is skipped in that case.
    if modeldir:
        os.makedirs(modeldir, exist_ok=True)

    summary_writer = tensorboard.SummaryWriter(logdir)
    try:
        # Fixed seed: parameter initialisation is the same on every run.
        _, init_rng = random.split(random.PRNGKey(0))
        policygradient = PolicyGradient()
        params = policygradient.init(
            init_rng,
            jnp.ones([1, common.BOARD_SIZE, common.BOARD_SIZE]))['params']
        optimizer = create_optimizer(
            model_params=params, learning_rate=LEARNING_RATE)

        # Main training loop
        progress_bar = tf.keras.utils.Progbar(iterations)
        for i in range(iterations):
            # Bind the latest parameters so the game is played with the
            # current policy.
            predict_fn = functools.partial(run_inference, optimizer.target)
            board_log, action_log, result_log = common.play_game(predict_fn)
            rewards = common.compute_rewards(result_log)
            summary_writer.scalar('game_length', len(board_log), i)
            optimizer = train_step(optimizer, board_log, action_log, rewards)

            summary_writer.flush()
            progress_bar.add(1)
    finally:
        # Close the writer even when a game or a training step raises.
        summary_writer.close()

    # Convert to tflite model
    model = PolicyGradient()

    def jax_predict_fn(board):
        """Run the trained policy on a single board."""
        return model.apply({'params': optimizer.target}, board)

    # NOTE(review): enable_xla=False presumably keeps the converted graph to
    # ops the TFLite converter accepts -- confirm against the jax2tf docs.
    tf_predict = tf.function(
        jax2tf.convert(jax_predict_fn, enable_xla=False),
        input_signature=[
            tf.TensorSpec(
                shape=[1, common.BOARD_SIZE, common.BOARD_SIZE],
                dtype=tf.float32,
                name='input')
        ],
        autograph=False)

    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [tf_predict.get_concrete_function()], tf_predict)
    tflite_model = converter.convert()

    # Save the model
    with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f:
        f.write(tflite_model)

    print('TFLite model generated!')
def train_agent(iterations, modeldir, logdir):
    """Train the Keras policy network and export it as a TFLite model.

    Builds a three-layer dense network over the flattened board, then for each
    iteration plays one game through ``common.play_game``, logs the game
    length, and fits the network on the logged boards and actions with the
    computed rewards as sample weights. The trained model is converted to
    TFLite and written to ``<modeldir>/planestrike.tflite``.

    Args:
      iterations: Number of games to play; one ``model.fit`` call follows each
        game.
      modeldir: Directory that receives ``planestrike.tflite``. It is created
        if it does not exist yet.
      logdir: Directory handed to ``tf.summary.create_file_writer``.
    """
    # Create the output directory up front so a bad path fails before training
    # rather than after it. An empty modeldir means "current directory", which
    # os.makedirs would reject, so it is skipped in that case.
    if modeldir:
        os.makedirs(modeldir, exist_ok=True)

    num_cells = common.BOARD_SIZE**2
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(
            input_shape=(common.BOARD_SIZE, common.BOARD_SIZE)),
        tf.keras.layers.Dense(2 * num_cells, activation='relu'),
        tf.keras.layers.Dense(num_cells, activation='relu'),
        # One softmax output per board cell.
        tf.keras.layers.Dense(num_cells, activation='softmax')
    ])

    sgd = tf.keras.optimizers.SGD(learning_rate=LEARNING_RATE)
    model.compile(loss='sparse_categorical_crossentropy', optimizer=sgd)

    def predict_fn(board):
        # NOTE(review): Keras recommends calling model(x) directly for small
        # inputs inside loops; predict() is kept because the return type that
        # common.play_game expects is not visible here.
        return model.predict(board)

    summary_writer = tf.summary.create_file_writer(logdir)
    try:
        # Main training loop
        progress_bar = tf.keras.utils.Progbar(iterations)
        for i in range(iterations):
            board_log, action_log, result_log = common.play_game(predict_fn)
            with summary_writer.as_default():
                tf.summary.scalar('game_length', len(action_log), step=i)
            rewards = common.compute_rewards(result_log)

            # Rewards weight each (board, action) sample in the
            # cross-entropy loss.
            model.fit(
                x=board_log,
                y=action_log,
                batch_size=1,
                verbose=0,
                epochs=1,
                sample_weight=rewards)

            summary_writer.flush()
            progress_bar.add(1)
    finally:
        # Close the writer even when a game or a fit call raises.
        summary_writer.close()

    # Convert to tflite model
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()

    # Save the model
    with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f:
        f.write(tflite_model)

    print('TFLite model generated!')