def init_resnet(hparams, model):
  """Initializes ResNet weights from a TF checkpoint if one is provided."""
  if not hparams['widget_encoder_checkpoint']:
    return
  reader = tf.train.NewCheckpointReader(hparams['widget_encoder_checkpoint'])

  # Run a single training batch through the model so that all variables are
  # created before their values are overwritten from the checkpoint.
  init_set = input_utils.input_fn(
      hparams['train_files'],
      1,
      hparams['vocab_size'],
      hparams['phrase_vocab_size'],
      hparams['max_pixel_pos'],
      hparams['max_dom_pos'],
      epoches=1,
      buffer_size=1)
  init_features = next(iter(init_set))
  init_target = model.compute_targets(init_features)
  model([init_features, init_target[0]], training=True)

  weight_value_tuples = []
  for layer in model._encoder._pixel_layers:  # pylint: disable=protected-access
    for param in layer.weights:
      # Map a Keras variable name such as `.../conv2d/kernel:0` to its
      # checkpoint key, e.g. `encoder/conv2d/kernel`.
      sublayer, varname = param.name.replace(':0', '').split('/')[-2:]
      var_name = 'encoder/{}/{}'.format(sublayer, varname)
      if reader.has_tensor(var_name):
        value = reader.get_tensor(var_name)
        logging.info('Found pretrained weights: %s, %s, %s', var_name,
                     param.shape, value.shape)
        weight_value_tuples.append((param, value))
  logging.info('Loaded %d pretrained weights.', len(weight_value_tuples))
  tf.keras.backend.batch_set_value(weight_value_tuples)
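
# Illustrative sketch (not called anywhere): how the checkpoint-name mapping
# in `init_resnet` behaves. The Keras variable name below is hypothetical;
# actual names depend on how the encoder layers are constructed.
def _demo_checkpoint_name_mapping():
  """Demonstrates the variable-name mapping used by `init_resnet`."""
  keras_name = 'encoder/pixel_layers/conv2d/kernel:0'  # Hypothetical name.
  # Drop the `:0` suffix and keep the last two path components.
  sublayer, varname = keras_name.replace(':0', '').split('/')[-2:]
  checkpoint_name = 'encoder/{}/{}'.format(sublayer, varname)
  assert checkpoint_name == 'encoder/conv2d/kernel'
  return checkpoint_name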
def main(argv=None):
  del argv
  hparams = create_hparams(FLAGS.experiment)
  if hparams['distribution_strategy'] == 'multi_worker_mirrored':
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
  elif hparams['distribution_strategy'] == 'mirrored':
    strategy = tf.distribute.MirroredStrategy()
  else:
    raise ValueError(
        'Unsupported distribution strategy: %s. Expected '
        '`multi_worker_mirrored` or `mirrored`.' %
        hparams['distribution_strategy'])

  # Create and compile the model under the distribution strategy scope.
  # `fit`, `evaluate` and `predict` will be distributed based on the strategy
  # the model was compiled with.
  with strategy.scope():
    # Build the train and eval datasets.
    train_set = input_utils.input_fn(
        hparams['train_files'],
        hparams['batch_size'],
        hparams['vocab_size'],
        hparams['phrase_vocab_size'],
        hparams['max_pixel_pos'],
        hparams['max_dom_pos'],
        epoches=1,
        buffer_size=hparams['train_buffer_size'])
    dev_set = input_utils.input_fn(
        hparams['eval_files'],
        hparams['eval_batch_size'],
        hparams['vocab_size'],
        hparams['phrase_vocab_size'],
        hparams['max_pixel_pos'],
        hparams['max_dom_pos'],
        epoches=100,
        buffer_size=hparams['eval_buffer_size'])

    model = WidgetCaptionModel(hparams)
    lr_schedule = optimizer.LearningRateSchedule(
        hparams['learning_rate_constant'], hparams['hidden_size'],
        hparams['learning_rate_warmup_steps'])
    opt = tf.keras.optimizers.Adam(
        lr_schedule,
        hparams['optimizer_adam_beta1'],
        hparams['optimizer_adam_beta2'],
        epsilon=hparams['optimizer_adam_epsilon'])
    model.compile(optimizer=opt)
    init_resnet(hparams, model)

  callbacks = [tf.keras.callbacks.TerminateOnNaN()]
  if FLAGS.model_dir:
    tensorboard_callback = TensorBoardCallBack(log_dir=FLAGS.model_dir)
    callbacks.append(tensorboard_callback)
  if FLAGS.ckpt_filepath:
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=FLAGS.ckpt_filepath, save_weights_only=True)
    callbacks.append(model_checkpoint_callback)

  # Train the model with the train dataset.
  history = model.fit(
      x=train_set,
      epochs=hparams['train_epoches'],
      validation_data=dev_set,
      validation_steps=10,
      callbacks=callbacks)
  logging.info('Training completed successfully. `model.fit()` result: %s',
               history.history)
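
# Entry point, assuming this module is executed as an absl app: the module
# already relies on absl-style FLAGS, and `from absl import app` is assumed
# to be among the top-level imports.
if __name__ == '__main__':
  app.run(main)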