def main(argv):
  # print(argv)
  # print(FLAGS)
  current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
  out_dir = os.path.join(FLAGS.output_dir, 'Deterministic', current_time)

  ###########################
  # Hyperparameters & Model #
  ###########################
  input_shape = dict(medium=(256, 256, 3), realworld=(512, 512, 3))[FLAGS.level]
  hparams = dict(dropout_rate=FLAGS.dropout_rate,
                 num_base_filters=FLAGS.num_base_filters,
                 learning_rate=FLAGS.learning_rate,
                 l2_reg=FLAGS.l2_reg,
                 input_shape=input_shape)
  classifier = VGGDrop(**hparams)
  # classifier.summary()

  #############
  # Load Task #
  #############
  dtask = bdlb.load(
      benchmark="diabetic_retinopathy_diagnosis",
      level=FLAGS.level,
      batch_size=FLAGS.batch_size,
      download_and_prepare=False,  # do not download data from this script
  )
  ds_train, ds_validation, ds_test = dtask.datasets

  #################
  # Training Loop #
  #################
  history = classifier.fit(
      ds_train,
      epochs=FLAGS.num_epochs,
      validation_data=ds_validation,
      class_weight=dtask.class_weight(),
      callbacks=[
          tfk.callbacks.TensorBoard(
              log_dir=os.path.join(out_dir, "tensorboard"),
              update_freq="epoch",
              write_graph=True,
              histogram_freq=1,
          ),
          tfk.callbacks.ModelCheckpoint(
              filepath=os.path.join(
                  out_dir,
                  "checkpoints",
                  "weights-{epoch}.ckpt",
              ),
              verbose=1,
              save_weights_only=True,
          ),
      ],
  )
  plotting.tfk_history(history, output_dir=os.path.join(out_dir, "history"))

  ##############
  # Evaluation #
  ##############
  additional_metrics = []
  try:
    import sail.metrics
    additional_metrics.append(('ECE', sail.metrics.GPleissCalibrationError()))
  except ImportError:
    import warnings
    warnings.warn('Could not import SAIL metrics.')
  dtask.evaluate(
      functools.partial(predict, model=classifier, type=FLAGS.uncertainty),
      dataset=ds_test,
      output_dir=os.path.join(out_dir, 'evaluation'),
      additional_metrics=additional_metrics,
  )
def main(argv):
  print(argv)
  print(FLAGS)

  ###########################
  # Hyperparameters & Model #
  ###########################
  input_shape = dict(medium=(256, 256, 3), realworld=(512, 512, 3))[FLAGS.level]
  hparams = dict(dropout_rate=FLAGS.dropout_rate,
                 num_base_filters=FLAGS.num_base_filters,
                 learning_rate=FLAGS.learning_rate,
                 l2_reg=FLAGS.l2_reg,
                 input_shape=input_shape)
  classifier = VGGDrop(**hparams)
  classifier.summary()

  #############
  # Load Task #
  #############
  dtask = bdlb.load(
      benchmark="diabetic_retinopathy_diagnosis",
      level=FLAGS.level,
      batch_size=FLAGS.batch_size,
      download_and_prepare=False,  # do not download data from this script
  )
  ds_train, ds_validation, ds_test = dtask.datasets

  #################
  # Training Loop #
  #################
  history = classifier.fit(
      ds_train,
      epochs=FLAGS.num_epochs,
      validation_data=ds_validation,
      class_weight=dtask.class_weight(),
      callbacks=[
          tfk.callbacks.TensorBoard(
              log_dir=os.path.join(FLAGS.output_dir, "tensorboard"),
              update_freq="epoch",
              write_graph=True,
              histogram_freq=1,
          ),
          tfk.callbacks.ModelCheckpoint(
              filepath=os.path.join(
                  FLAGS.output_dir,
                  "checkpoints",
                  "weights-{epoch}.ckpt",
              ),
              verbose=1,
              save_weights_only=True,
          ),
      ],
  )
  plotting.tfk_history(history, output_dir=os.path.join(FLAGS.output_dir, "history"))

  ##############
  # Evaluation #
  ##############
  dtask.evaluate(
      functools.partial(predict, model=classifier, type=FLAGS.uncertainty),
      dataset=ds_test,
      output_dir=FLAGS.output_dir,
  )
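# ----------------------------------------------------------------------------
# Minimal sketch of the pieces `main` assumes are defined elsewhere in this
# script: the absl flags / entry point, and the `predict` callable that is
# handed to `dtask.evaluate` via functools.partial. Everything below is
# illustrative only: the flag defaults, the `num_samples` argument, and the
# (mean, uncertainty) return contract are assumptions inferred from how
# `predict` is used above, not the benchmark's published implementation.
# `main` additionally relies on `os`, `datetime`, `functools`, `bdlb`,
# `plotting`, `tfk` (tf.keras), and `VGGDrop` being imported at module level.
# ----------------------------------------------------------------------------
import tensorflow as tf
from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string("output_dir", "/tmp/diabetic_retinopathy",  # placeholder default
                    "Directory for checkpoints, TensorBoard logs and evaluation output.")
flags.DEFINE_enum("level", "medium", ["medium", "realworld"],
                  "Benchmark level (sets the input resolution).")
flags.DEFINE_enum("uncertainty", "entropy", ["entropy", "stddev"],
                  "Per-example uncertainty measure used at evaluation time.")
# Remaining flags (batch_size, num_epochs, dropout_rate, num_base_filters,
# learning_rate, l2_reg) would be defined analogously.


def predict(x, model, type="entropy", num_samples=10):
  """Illustrative estimator returning (mean prediction, per-example uncertainty).

  Runs `num_samples` stochastic forward passes with dropout active and reduces
  them to a mean sigmoid probability plus either the sample standard deviation
  or the binary predictive entropy of the mean. A purely deterministic
  baseline would instead use a single pass with `training=False`.
  """
  samples = tf.stack([model(x, training=True) for _ in range(num_samples)], axis=0)
  mean = tf.reduce_mean(samples, axis=0)
  if type == "stddev":
    uncertainty = tf.math.reduce_std(samples, axis=0)
  else:
    # Binary predictive entropy of the mean probability.
    uncertainty = -(mean * tf.math.log(mean + 1e-12) +
                    (1.0 - mean) * tf.math.log(1.0 - mean + 1e-12))
  return mean, uncertainty


if __name__ == "__main__":
  app.run(main)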