예제 #1
0
def main(argv):
  """Trains a ResNet-20 on a single GPU and writes weights to output_dir.

  Builds the TFDS train/test pipelines, compiles the model with cross-entropy
  loss plus ELBO/KL tracking metrics, fits with TensorBoard + LR-schedule
  callbacks, and finally checkpoints the trained weights.
  """
  del argv  # unused arg
  # Multi-core and CPU-only runs are explicitly rejected.
  if FLAGS.num_cores > 1 or not FLAGS.use_gpu:
    raise ValueError('Only single GPU is currently supported.')
  tf.enable_v2_behavior()
  tf.io.gfile.makedirs(FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  train_ds, ds_info = utils.load_dataset(tfds.Split.TRAIN, with_info=True)
  train_size = ds_info.splits['train'].num_examples
  train_ds = (train_ds
              .repeat()
              .shuffle(10 * FLAGS.batch_size)
              .batch(FLAGS.batch_size))
  eval_batch_size = 100
  num_eval_steps = ds_info.splits['test'].num_examples // eval_batch_size
  eval_ds = utils.load_dataset(tfds.Split.TEST).repeat().batch(eval_batch_size)

  model = resnet_v1(input_shape=ds_info.features['image'].shape,
                    depth=20,
                    num_classes=ds_info.features['label'].num_classes,
                    batch_norm=FLAGS.batch_norm,
                    prior_stddev=FLAGS.prior_stddev,
                    dataset_size=train_size)
  kl, elbo = get_metrics(model, train_size)

  # NLL and accuracy are standard Keras metrics; ELBO/KL come from the
  # model-aware helper above.
  nll_metric = tf.keras.metrics.SparseCategoricalCrossentropy(
      name='negative_log_likelihood', from_logits=True)
  acc_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')
  model.compile(
      tf.keras.optimizers.Adam(FLAGS.init_learning_rate),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[nll_metric, acc_metric, elbo, kl])
  logging.info('Model input shape: %s', model.input_shape)
  logging.info('Model output shape: %s', model.output_shape)
  logging.info('Model number of weights: %s', model.count_params())

  callbacks = [
      tf.keras.callbacks.TensorBoard(log_dir=FLAGS.output_dir,
                                     write_graph=False),
      utils.make_lr_scheduler(FLAGS.init_learning_rate),
  ]
  model.fit(train_ds,
            steps_per_epoch=train_size // FLAGS.batch_size,
            epochs=FLAGS.train_epochs,
            validation_data=eval_ds,
            validation_steps=num_eval_steps,
            callbacks=callbacks)

  logging.info('Saving model to output_dir.')
  model.save_weights(FLAGS.output_dir + '/model.ckpt')
예제 #2
0
def main(argv):
    """Trains the model with an ELBO-style loss and saves its weights.

    Prepares the TFDS train/test pipelines, compiles with the
    negative-log-likelihood loss and ELBO/log-marginal/KL/accuracy metrics
    from get_metrics, fits with TensorBoard + LR-schedule callbacks, then
    checkpoints the weights under FLAGS.output_dir.
    """
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    train_data, ds_info = utils.load_dataset(tfds.Split.TRAIN, with_info=True)
    n_train = ds_info.splits['train'].num_examples
    train_data = train_data.repeat()
    train_data = train_data.shuffle(10 * FLAGS.batch_size)
    train_data = train_data.batch(FLAGS.batch_size)

    eval_batch = 100
    eval_steps = ds_info.splits['test'].num_examples // eval_batch
    test_data = utils.load_dataset(tfds.Split.TEST).repeat().batch(eval_batch)

    model = resnet_v1(input_shape=ds_info.features['image'].shape,
                      depth=20,
                      num_classes=ds_info.features['label'].num_classes,
                      batch_norm=FLAGS.batch_norm,
                      prior_stddev=FLAGS.prior_stddev,
                      dataset_size=n_train)
    # get_metrics supplies both the training loss (NLL) and tracking metrics.
    (negative_log_likelihood, accuracy, log_marginal, kl,
     elbo) = get_metrics(model, n_train)

    model.compile(tf.keras.optimizers.Adam(FLAGS.init_learning_rate),
                  loss=negative_log_likelihood,
                  metrics=[elbo, log_marginal, kl, accuracy])
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    callbacks = [
        tf.keras.callbacks.TensorBoard(log_dir=FLAGS.output_dir,
                                       write_graph=False),
        utils.make_lr_scheduler(FLAGS.init_learning_rate),
    ]
    model.fit(train_data,
              steps_per_epoch=n_train // FLAGS.batch_size,
              epochs=FLAGS.train_epochs,
              validation_data=test_data,
              validation_steps=eval_steps,
              callbacks=callbacks)

    logging.info('Saving model to output_dir.')
    model.save_weights(FLAGS.output_dir + '/model.ckpt')
예제 #3
0
    # NOTE(review): this is the interior of a function whose `def` line is
    # above this view, and the code may continue past the last visible line —
    # comments only, code left byte-identical.
    cfg = options.get_arguments()

    # Output locations are keyed by "<model>_<experiment>" so concurrent runs
    # don't clobber each other's checkpoints or logs.
    EXPERIMENT = f"{cfg.model}_{cfg.experiment}"
    MODEL_PATH = f"models/{EXPERIMENT}"
    LOG_PATH = f"logs/{EXPERIMENT}"

    utils.make_folder(MODEL_PATH)
    utils.make_folder(LOG_PATH)

    # All training objects come from `utils` factory helpers driven by cfg.
    criterions = utils.define_losses()
    dataloaders = utils.make_data_novel(cfg)

    model = utils.build_structure_generator(cfg).to(cfg.device)
    optimizer = utils.make_optimizer(cfg, model)
    scheduler = utils.make_lr_scheduler(cfg, optimizer)

    logger = utils.make_logger(LOG_PATH)
    writer = utils.make_summary_writer(EXPERIMENT)

    # Per-epoch hook: persist best/periodic checkpoints, then push the epoch
    # history and sample images to the text logger and summary writer.
    def on_after_epoch(model, df_hist, images, epoch, saveEpoch):
        utils.save_best_model(MODEL_PATH, model, df_hist)
        utils.checkpoint_model(MODEL_PATH, model, epoch, saveEpoch)
        utils.log_hist(logger, df_hist)
        utils.write_on_board_losses_stg2(writer, df_hist)
        utils.write_on_board_images_stg2(writer, images, epoch)

    # Per-batch hook is only defined when an LR schedule was requested; it
    # records the current LR before stepping the scheduler.
    if cfg.lrSched is not None:
        def on_after_batch(iteration):
            utils.write_on_board_lr(writer, scheduler.get_lr(), iteration)
            scheduler.step(iteration)