Example #1
0
def define_cifar_flags():
    """Register the Keras flags and install the CIFAR-10 default values.

    Calls ``common.define_keras_flags`` with ``dynamic_loss_scale=False`` and
    then overrides the defaults of ``data_dir``, ``model_dir``,
    ``epochs_between_evals`` and ``batch_size`` through
    ``flags_core.set_defaults``.
    """
    common.define_keras_flags(dynamic_loss_scale=False)

    # Default values specific to the CIFAR-10 run, keyed by flag name.
    cifar_defaults = {
        'data_dir': '/tmp/cifar10_data/cifar-10-batches-bin',
        'model_dir': '/tmp/cifar10_model',
        'epochs_between_evals': 10,
        'batch_size': 128,
    }
    flags_core.set_defaults(**cifar_defaults)
Example #2
0
def define_imagenet_keras_flags():
    """Register the flags used by the ImageNet Keras entry point.

    Defines the Keras flags with the ``model``, ``optimizer`` and
    ``pretrained_filepath`` switches enabled, defines the pruning flags, and
    adopts the key flags of the ``common`` module.
    """
    keras_flag_switches = {
        'model': True,
        'optimizer': True,
        'pretrained_filepath': True,
    }
    common.define_keras_flags(**keras_flag_switches)
    common.define_pruning_flags()
    # NOTE(review): called without any overrides; presumably a no-op kept for
    # symmetry with the other flag definers -- confirm against flags_core.
    flags_core.set_defaults()
    flags.adopt_module_key_flags(common)
Example #3
0
        strategy,
        runnable.train,
        runnable.evaluate,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        train_steps=per_epoch_steps * train_epochs,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        eval_steps=eval_steps,
        eval_interval=eval_interval)

    time_callback.on_train_begin()
    resnet_controller.train(evaluate=not flags_obj.skip_eval)
    time_callback.on_train_end()

    stats = build_stats(runnable, time_callback)
    return stats


def main(_):
    """Run the job described by the global flags and log its statistics.

    The single positional argument is ignored. The flags object is passed to
    ``model_helpers.apply_clean``, ``run`` is executed inside
    ``logger.benchmark_context``, and the returned stats are logged at INFO
    level.
    """
    flags_obj = flags.FLAGS
    model_helpers.apply_clean(flags_obj)

    with logger.benchmark_context(flags_obj):
        run_stats = run(flags_obj)

    logging.info('Run stats:\n%s', run_stats)


# Script entry point: the statements below run only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    logging.set_verbosity(logging.INFO)
    # Flags are defined before the app is started.
    # NOTE(review): presumably `app.run` parses the command line against the
    # flags registered here, so this order matters -- confirm.
    common.define_keras_flags()
    app.run(main)
Example #4
0
 def setUpClass(cls):  # pylint: disable=invalid-name
   """Run the parent class-level setup, then define the Keras flags.

   NOTE(review): the enclosing class header and any ``@classmethod``
   decorator are outside this snippet -- confirm they are present.
   """
   super(CtlImagenetTest, cls).setUpClass()
   # Flags are defined once for the whole test class rather than per test.
   common.define_keras_flags()