def run_imagenet(flags_obj):
    """Run ResNet ImageNet training and eval loop.

    Args:
        flags_obj: An object containing parsed flag values.

    Returns:
        Dict of results of the run, as returned by
        `resnet_run_loop.resnet_main`. Contains the keys `eval_results` and
        `train_hooks`. `eval_results` contains accuracy (top_1) and
        accuracy_top_5. `train_hooks` is a list of the hook instances used
        during training.
    """
    if flags_obj.use_fs:
        # Must run before resnet_main builds the model.
        # NOTE(review): presumably patches the ReLU activation globally; confirm.
        replace_relu_with_fs()

    # A real conditional expression instead of the `cond and a or b` idiom.
    # The old form silently fell back to `input_fn` whenever the synthetic
    # input fn was falsy. The synthetic fn is only constructed when requested.
    input_function = (
        get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
        if flags_obj.use_synthetic_data
        else input_fn
    )

    return resnet_run_loop.resnet_main(
        flags_obj,
        imagenet_model_fn,
        input_function,
        DATASET_NAME,
        shape=[DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE, NUM_CHANNELS],
    )
def run_cifar(flags_obj):
    """Run ResNet CIFAR-10 training and eval loop.

    Args:
        flags_obj: An object containing parsed flag values.

    Returns:
        Dictionary of results, including final accuracy, as returned by
        `resnet_run_loop.resnet_main`. Returns None after logging a fatal
        message if `--image_bytes_as_serving_input` is set, because that flag
        only applies to ImageNet.
    """
    # Validate flags before touching any global state. This way
    # replace_relu_with_fs() is never applied for a configuration that is
    # about to be rejected.
    if flags_obj.image_bytes_as_serving_input:
        tf.compat.v1.logging.fatal(
            '--image_bytes_as_serving_input cannot be set to True for CIFAR. '
            'This flag is only applicable to ImageNet.')
        # Return explicitly rather than relying on the log call to abort.
        return None

    if flags_obj.use_fs:
        # Must run before resnet_main builds the model.
        # NOTE(review): presumably patches the ReLU activation globally; confirm.
        replace_relu_with_fs()

    # A real conditional expression instead of the `cond and a or b` idiom.
    # The synthetic input fn is only constructed when requested.
    input_function = (
        get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
        if flags_obj.use_synthetic_data
        else input_fn
    )

    return resnet_run_loop.resnet_main(
        flags_obj,
        cifar10_model_fn,
        input_function,
        DATASET_NAME,
        shape=[HEIGHT, WIDTH, NUM_CHANNELS],
    )