예제 #1
0
def main(argv):
  """Train a VGGDrop classifier on diabetic retinopathy diagnosis and evaluate it.

  All artifacts go under a timestamped run directory
  ``<FLAGS.output_dir>/Deterministic/<YYYYmmdd-HHMMSS>``:

  * ``tensorboard/``: TensorBoard logs.
  * ``checkpoints/``: per-epoch weight checkpoints.
  * ``history/``: training-history plots.
  * ``evaluation/``: test-set evaluation outputs.

  If the optional ``sail`` package is importable, an ``'ECE'`` calibration
  metric is added to the evaluation; otherwise a warning is emitted and
  evaluation proceeds without it.

  Args:
    argv: Command-line arguments remaining after flag parsing; unused, all
      configuration is read from ``FLAGS``.

  Raises:
    ValueError: If ``FLAGS.level`` is not a supported benchmark level.
  """
  del argv  # Unused; configuration comes entirely from FLAGS.

  current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
  out_dir = os.path.join(FLAGS.output_dir, 'Deterministic', current_time)

  ###########################
  # Hyperparameters & Model #
  ###########################
  input_shapes = dict(medium=(256, 256, 3), realworld=(512, 512, 3))
  if FLAGS.level not in input_shapes:
    # Fail with an actionable message instead of a bare KeyError.
    raise ValueError(
        "Unsupported level {!r}; expected one of {}.".format(
            FLAGS.level, sorted(input_shapes)))
  input_shape = input_shapes[FLAGS.level]

  hparams = dict(dropout_rate=FLAGS.dropout_rate,
                 num_base_filters=FLAGS.num_base_filters,
                 learning_rate=FLAGS.learning_rate,
                 l2_reg=FLAGS.l2_reg,
                 input_shape=input_shape)
  classifier = VGGDrop(**hparams)

  #############
  # Load Task #
  #############
  dtask = bdlb.load(
      benchmark="diabetic_retinopathy_diagnosis",
      level=FLAGS.level,
      batch_size=FLAGS.batch_size,
      download_and_prepare=False,  # do not download data from this script
  )
  ds_train, ds_validation, ds_test = dtask.datasets

  #################
  # Training Loop #
  #################
  history = classifier.fit(
      ds_train,
      epochs=FLAGS.num_epochs,
      validation_data=ds_validation,
      class_weight=dtask.class_weight(),
      callbacks=[
          tfk.callbacks.TensorBoard(
              log_dir=os.path.join(out_dir, "tensorboard"),
              update_freq="epoch",
              write_graph=True,
              histogram_freq=1,
          ),
          tfk.callbacks.ModelCheckpoint(
              filepath=os.path.join(
                  out_dir,
                  "checkpoints",
                  "weights-{epoch}.ckpt",
              ),
              verbose=1,
              save_weights_only=True,
          )
      ],
  )
  plotting.tfk_history(history,
                       output_dir=os.path.join(out_dir, "history"))

  ##############
  # Evaluation #
  ##############
  additional_metrics = []
  # Keep the try body to the optional import only, so an ImportError raised
  # while *constructing* the metric is not misreported as a missing package.
  try:
    import sail.metrics
  except ImportError:
    import warnings
    warnings.warn('Could not import SAIL metrics.')
  else:
    additional_metrics.append(('ECE', sail.metrics.GPleissCalibrationError()))
  dtask.evaluate(functools.partial(predict, model=classifier, type=FLAGS.uncertainty),
                 dataset=ds_test,
                 output_dir=os.path.join(out_dir, 'evaluation'),
                 additional_metrics=additional_metrics)
예제 #2
0
def main(argv):
    """Train a VGGDrop classifier on diabetic retinopathy diagnosis and evaluate it.

    The leftover ``argv`` and the parsed ``FLAGS`` are echoed to stdout first.
    TensorBoard logs, per-epoch weight checkpoints, training-history plots and
    the test-set evaluation outputs are all written under ``FLAGS.output_dir``.

    Args:
        argv: Command-line arguments remaining after flag parsing; only printed.
    """
    for item in (argv, FLAGS):
        print(item)

    # --- Hyperparameters & model ------------------------------------------
    shape_by_level = {
        'medium': (256, 256, 3),
        'realworld': (512, 512, 3),
    }
    input_shape = shape_by_level[FLAGS.level]

    model = VGGDrop(
        dropout_rate=FLAGS.dropout_rate,
        num_base_filters=FLAGS.num_base_filters,
        learning_rate=FLAGS.learning_rate,
        l2_reg=FLAGS.l2_reg,
        input_shape=input_shape,
    )
    model.summary()

    # --- Task & data splits -----------------------------------------------
    task = bdlb.load(
        benchmark="diabetic_retinopathy_diagnosis",
        level=FLAGS.level,
        batch_size=FLAGS.batch_size,
        download_and_prepare=False,  # data must already be prepared; never fetch here
    )
    train_split, valid_split, test_split = task.datasets

    # --- Training ---------------------------------------------------------
    root = FLAGS.output_dir
    weights = task.class_weight()
    tensorboard_cb = tfk.callbacks.TensorBoard(
        log_dir=os.path.join(root, "tensorboard"),
        update_freq="epoch",
        write_graph=True,
        histogram_freq=1,
    )
    checkpoint_cb = tfk.callbacks.ModelCheckpoint(
        filepath=os.path.join(root, "checkpoints", "weights-{epoch}.ckpt"),
        verbose=1,
        save_weights_only=True,
    )
    history = model.fit(
        train_split,
        epochs=FLAGS.num_epochs,
        validation_data=valid_split,
        class_weight=weights,
        callbacks=[tensorboard_cb, checkpoint_cb],
    )
    plotting.tfk_history(history, output_dir=os.path.join(root, "history"))

    # --- Evaluation -------------------------------------------------------
    predict_fn = functools.partial(predict, model=model, type=FLAGS.uncertainty)
    task.evaluate(predict_fn, dataset=test_split, output_dir=root)