def init_resnet(hparams, model):
    """Init resnet weights from a TF model if provided.

    Runs one forward pass on a single training example so that the model's
    variables exist, then copies into the pixel-encoder layers every weight
    whose mapped name `encoder/<sublayer>/<variable>` is present in the
    checkpoint. Weights with no match in the checkpoint are left untouched.

    Args:
      hparams: dict of hyperparameters. Reads 'widget_encoder_checkpoint'
        (a falsy value disables loading entirely), plus 'train_files',
        'vocab_size', 'phrase_vocab_size', 'max_pixel_pos' and 'max_dom_pos'
        to build the one-example input pipeline.
      model: the model to initialize. It must provide `compute_targets` and
        `_encoder._pixel_layers`, and be callable as
        `model([features, target], training=...)`.

    Returns:
      None. Matching weights are assigned to `model` in place.
    """
    checkpoint_path = hparams['widget_encoder_checkpoint']
    if not checkpoint_path:
        return

    reader = tf.train.NewCheckpointReader(checkpoint_path)

    # Initialize model weights: run a single forward pass (batch size 1, one
    # epoch, no shuffling buffer) so the layer variables are created before
    # we try to assign to them.
    init_set = input_utils.input_fn(hparams['train_files'],
                                    1,
                                    hparams['vocab_size'],
                                    hparams['phrase_vocab_size'],
                                    hparams['max_pixel_pos'],
                                    hparams['max_dom_pos'],
                                    epoches=1,
                                    buffer_size=1)
    init_features = next(iter(init_set))
    init_target = model.compute_targets(init_features)
    model([init_features, init_target[0]], training=True)

    weight_value_tuples = []
    for layer in model._encoder._pixel_layers:  # pylint: disable=protected-access
        for param in layer.weights:
            # Drop the ':<output index>' suffix, then keep only the last two
            # scope components (sublayer/variable) so the name can be remapped
            # under the checkpoint's 'encoder/' scope.
            base_name = param.name.split(':', 1)[0]
            sublayer, varname = base_name.split('/')[-2:]
            var_name = 'encoder/{}/{}'.format(sublayer, varname)
            if not reader.has_tensor(var_name):
                continue
            # Read the tensor once; it is used for both logging and assignment.
            pretrained_value = reader.get_tensor(var_name)
            logging.info('Found pretrained weights: %s, %s, %s', var_name,
                         param.shape, pretrained_value.shape)
            weight_value_tuples.append((param, pretrained_value))
    logging.info('Load pretrained %s weights', len(weight_value_tuples))
    # Assign all matched values in one batched call.
    tf.keras.backend.batch_set_value(weight_value_tuples)
def main(argv=None):
    """Train the widget captioning model under a tf.distribute strategy.

    Builds the hyperparameters for `FLAGS.experiment`, selects a distribution
    strategy from hparams['distribution_strategy'], creates and compiles the
    model inside the strategy scope, optionally loads pretrained pixel-encoder
    weights, and then runs `model.fit` on the train set with periodic
    validation on the eval set.

    Args:
      argv: unused command-line arguments (deleted immediately).

    Raises:
      ValueError: if hparams['distribution_strategy'] is neither
        'multi_worker_mirrored' nor 'mirrored'.
    """
    del argv

    hparams = create_hparams(FLAGS.experiment)

    if hparams['distribution_strategy'] == 'multi_worker_mirrored':
        strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    elif hparams['distribution_strategy'] == 'mirrored':
        strategy = tf.distribute.MirroredStrategy()
    else:
        raise ValueError(
            'Only `multi_worker_mirrored` and `mirrored` are supported '
            'distribution strategies at this time. Strategy passed '
            'in is %s' % hparams['distribution_strategy'])

    # Create and compile the model under Distribution strategy scope.
    # `fit`, `evaluate` and `predict` will be distributed based on the strategy
    # model was compiled with.
    with strategy.scope():
        # Build the train and eval datasets from the configured input files.
        train_set = input_utils.input_fn(
            hparams['train_files'],
            hparams['batch_size'],
            hparams['vocab_size'],
            hparams['phrase_vocab_size'],
            hparams['max_pixel_pos'],
            hparams['max_dom_pos'],
            epoches=1,
            buffer_size=hparams['train_buffer_size'])

        # The eval set is built with many epochs so that every per-epoch
        # validation pass (validation_steps below) has data to draw from.
        dev_set = input_utils.input_fn(hparams['eval_files'],
                                       hparams['eval_batch_size'],
                                       hparams['vocab_size'],
                                       hparams['phrase_vocab_size'],
                                       hparams['max_pixel_pos'],
                                       hparams['max_dom_pos'],
                                       epoches=100,
                                       buffer_size=hparams['eval_buffer_size'])

        model = WidgetCaptionModel(hparams)
        lr_schedule = optimizer.LearningRateSchedule(
            hparams['learning_rate_constant'], hparams['hidden_size'],
            hparams['learning_rate_warmup_steps'])
        opt = tf.keras.optimizers.Adam(
            lr_schedule,
            hparams['optimizer_adam_beta1'],
            hparams['optimizer_adam_beta2'],
            epsilon=hparams['optimizer_adam_epsilon'])
        model.compile(optimizer=opt)

        # Optionally warm-start the pixel encoder from a TF checkpoint.
        init_resnet(hparams, model)

    callbacks = [tf.keras.callbacks.TerminateOnNaN()]
    if FLAGS.model_dir:
        tensorboard_callback = TensorBoardCallBack(log_dir=FLAGS.model_dir)
        callbacks.append(tensorboard_callback)
    if FLAGS.ckpt_filepath:
        model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=FLAGS.ckpt_filepath, save_weights_only=True)
        callbacks.append(model_checkpoint_callback)

    # Train the model with the train dataset.
    history = model.fit(x=train_set,
                        epochs=hparams['train_epoches'],
                        validation_data=dev_set,
                        validation_steps=10,
                        callbacks=callbacks)

    logging.info('Training ends successfully. `model.fit()` result: %s',
                 history.history)