Example #1
def main():
    # Derive per-epoch and total step counts from the dataset-size flags.
    epoch_train_steps = FLAGS.train_count // FLAGS.batch_count
    num_train_steps = int(epoch_train_steps * FLAGS.epoch_count)
    print("Epoch train steps %d" % epoch_train_steps)
    print("Total train steps %d" % num_train_steps)
    run_config = tf.estimator.RunConfig(save_checkpoints_secs=1300, keep_checkpoint_max=2)
    model_fn = build_model_fn(num_train_steps)
    estimator = tf.estimator.Estimator(model_fn, model_dir=FLAGS.output_dir, config=run_config)

    if FLAGS.do_train:
        train_input_fn = data.build_input_fn(True, "train_X", "train_Y", True, FLAGS)
        test_input_fn = data.build_input_fn(False, "test_X", "test_Y", False, FLAGS)

        train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=num_train_steps)
        eval_spec = tf.estimator.EvalSpec(input_fn=test_input_fn, throttle_secs=3, steps=None)

        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

    if FLAGS.do_eval:
        test_input_fn = data.build_input_fn(False, "test_X", "test_Y", False, FLAGS)
        estimator.evaluate(test_input_fn)
    if FLAGS.do_predict:
        test_input_fn = data.build_input_fn(False, "test_X", "test_Y", False, FLAGS)
        predict_labels(estimator, test_input_fn)
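Example #1 calls two helpers that are not shown, data.build_input_fn and predict_labels. Below is a minimal sketch of what the input-fn factory could look like, assuming its positional arguments are (is_training, features_name, labels_name, shuffle, FLAGS) as suggested by the call sites, and assuming a hypothetical FLAGS.data_dir holding one NumPy .npy file per split:

import numpy as np
import tensorflow as tf

def build_input_fn(is_training, x_name, y_name, shuffle, FLAGS):
    """Hypothetical input_fn factory matching the call sites above."""
    def input_fn():
        # Assumed storage format: one .npy array per split under data_dir.
        x = np.load("%s/%s.npy" % (FLAGS.data_dir, x_name))
        y = np.load("%s/%s.npy" % (FLAGS.data_dir, y_name))
        dataset = tf.data.Dataset.from_tensor_slices((x, y))
        if shuffle:
            dataset = dataset.shuffle(buffer_size=10000)
        if is_training:
            dataset = dataset.repeat()  # loop until max_steps stops training
        return dataset.batch(FLAGS.batch_count)
    return input_fn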
Example #2
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Enable soft device placement when training summaries are requested.
    if FLAGS.train_summary_steps > 0:
        tf.config.set_soft_device_placement(True)

    builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir)
    builder.download_and_prepare()
    num_train_examples = builder.info.splits[FLAGS.train_split].num_examples
    num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples
    num_classes = builder.info.features['label'].num_classes

    train_steps = model_util.get_train_steps(num_train_examples)
    eval_steps = int(math.ceil(num_eval_examples / FLAGS.eval_batch_size))
    epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size))

    resnet.BATCH_NORM_DECAY = FLAGS.batch_norm_decay
    model = resnet.resnet_v1(resnet_depth=FLAGS.resnet_depth,
                             width_multiplier=FLAGS.width_multiplier,
                             cifar_stem=FLAGS.image_size <= 32)

    checkpoint_steps = (FLAGS.checkpoint_steps
                        or (FLAGS.checkpoint_epochs * epoch_steps))

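    # Resolve the TPU cluster only when targeting a TPU without an explicit
    # master address; otherwise training falls back to the local device.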
    cluster = None
    if FLAGS.use_tpu and FLAGS.master is None:
        if FLAGS.tpu_name:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        else:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(cluster)
            tf.tpu.experimental.initialize_tpu_system(cluster)

    default_eval_mode = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1
    sliced_eval_mode = tf.estimator.tpu.InputPipelineConfig.SLICED
    run_config = tf.estimator.tpu.RunConfig(
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=checkpoint_steps,
            eval_training_input_configuration=sliced_eval_mode
            if FLAGS.use_tpu else default_eval_mode),
        model_dir=FLAGS.model_dir,
        save_summary_steps=checkpoint_steps,
        save_checkpoints_steps=checkpoint_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        master=FLAGS.master,
        cluster=cluster)
    estimator = tf.estimator.tpu.TPUEstimator(
        model_lib.build_model_fn(model, num_classes, num_train_examples),
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        use_tpu=FLAGS.use_tpu)

    if FLAGS.mode == 'eval':
        for ckpt in tf.train.checkpoints_iterator(run_config.model_dir,
                                                  min_interval_secs=15):
            try:
                result = perform_evaluation(estimator=estimator,
                                            input_fn=data_lib.build_input_fn(
                                                builder, False),
                                            eval_steps=eval_steps,
                                            model=model,
                                            num_classes=num_classes,
                                            checkpoint_path=ckpt)
            except tf.errors.NotFoundError:
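                # A checkpoint can be garbage-collected (keep_checkpoint_max)
                # between being listed and being loaded; skip and wait for the
                # next one.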
                continue
            if result['global_step'] >= train_steps:
                return
    else:
        estimator.train(data_lib.build_input_fn(builder, True),
                        max_steps=train_steps)
        if FLAGS.mode == 'train_then_eval':
            perform_evaluation(estimator=estimator,
                               input_fn=data_lib.build_input_fn(
                                   builder, False),
                               eval_steps=eval_steps,
                               model=model,
                               num_classes=num_classes)
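Examples #2 through #4 derive train_steps from model_util.get_train_steps, which is not shown. In SimCLR-style training code this helper is typically something like the sketch below, where an explicit train_steps flag overrides the epoch-derived count (the flag names are assumptions; check the actual model_util module):

def get_train_steps(num_examples):
    """Total number of training steps; an explicit flag takes precedence."""
    return FLAGS.train_steps or (
        num_examples * FLAGS.train_epochs // FLAGS.train_batch_size + 1)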
Example #3
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Import here because these modules depend on the flags defined above.
    from mobilenetv3.mobilenet_v3 import V3_LARGE_MINIMALISTIC, V3_SMALL_MINIMALISTIC, mobilenet_func

    # Enable soft device placement when training summaries are requested.
    if FLAGS.train_summary_steps > 0:
        tf.config.set_soft_device_placement(True)

    builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir)
    builder.download_and_prepare()
    num_train_examples = builder.info.splits[FLAGS.train_split].num_examples
    num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples
    num_classes = builder.info.features['label'].num_classes

    train_steps = model_util.get_train_steps(num_train_examples)
    eval_steps = int(math.ceil(num_eval_examples / FLAGS.eval_batch_size))
    epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size))

    if FLAGS.backbone == "resnet":
        resnet.BATCH_NORM_DECAY = FLAGS.batch_norm_decay
        model = resnet.resnet_v1(resnet_depth=FLAGS.resnet_depth,
                                 width_multiplier=FLAGS.width_multiplier,
                                 cifar_stem=FLAGS.image_size <= 32,
                                 conv1_stride1=(FLAGS.image_size == 112)
                                 or (FLAGS.image_size == 84))
    elif FLAGS.backbone == "mobilenet_v3_large_minimalistic":
        model = mobilenet_func(conv_defs=V3_LARGE_MINIMALISTIC,
                               depth_multiplier=FLAGS.width_multiplier)
    elif FLAGS.backbone == "mobilenet_v3_small_minimalistic":
        model = mobilenet_func(conv_defs=V3_SMALL_MINIMALISTIC,
                               depth_multiplier=FLAGS.width_multiplier)
    else:
        raise ValueError("wrong backbone:" + FLAGS.backbone)

    checkpoint_steps = (FLAGS.checkpoint_steps
                        or (FLAGS.checkpoint_epochs * epoch_steps))

    cluster = None
    if FLAGS.use_tpu and FLAGS.master is None:
        if FLAGS.tpu_name:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        else:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(cluster)
            tf.tpu.experimental.initialize_tpu_system(cluster)

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))

    if FLAGS.use_fp16:
        # Automatic mixed precision via the Grappler graph rewrite; note that
        # this inserts casts only and does not apply loss scaling.
        config.graph_options.rewrite_options.auto_mixed_precision = 1

    if FLAGS.use_tpu:
        eval_training_input_config = tf.estimator.tpu.InputPipelineConfig.SLICED
    else:
        eval_training_input_config = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1

    run_config = tf.estimator.tpu.RunConfig(
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=checkpoint_steps,
            eval_training_input_configuration=eval_training_input_config),
        model_dir=FLAGS.model_dir,
        save_summary_steps=checkpoint_steps,
        save_checkpoints_steps=checkpoint_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        master=FLAGS.master,
        cluster=cluster,
        session_config=config)
    estimator = tf.estimator.tpu.TPUEstimator(
        model_lib.build_model_fn(model, num_classes, num_train_examples),
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        use_tpu=FLAGS.use_tpu,
        model_dir=FLAGS.model_dir)

    if FLAGS.mode == 'eval':
        for ckpt in tf.train.checkpoints_iterator(run_config.model_dir,
                                                  min_interval_secs=15):
            try:
                result = perform_evaluation(estimator=estimator,
                                            input_fn=data_lib.build_input_fn(
                                                builder, False),
                                            eval_steps=eval_steps,
                                            model=model,
                                            num_classes=num_classes,
                                            checkpoint_path=ckpt)
            except tf.errors.NotFoundError:
                continue
            if result['global_step'] >= train_steps:
                return
    else:
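        # With save_steps this large, the profiler trace is effectively
        # written only once, early in training.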
        profile_hook = tf.estimator.ProfilerHook(save_steps=100000000,
                                                 save_secs=None,
                                                 output_dir=FLAGS.model_dir,
                                                 show_dataflow=True,
                                                 show_memory=True)
        estimator.train(data_lib.build_input_fn(builder, True),
                        max_steps=train_steps,
                        hooks=[profile_hook])
        if FLAGS.mode == 'train_then_eval':
            perform_evaluation(estimator=estimator,
                               input_fn=data_lib.build_input_fn(
                                   builder, False),
                               eval_steps=eval_steps,
                               model=model,
                               num_classes=num_classes)
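Example #3 dispatches on a backbone flag whose definition is not shown. If it is defined with absl, a declaration consistent with the branches above might look like this (the default value and help string are assumptions):

from absl import flags

flags.DEFINE_enum(
    'backbone', 'resnet',
    ['resnet', 'mobilenet_v3_large_minimalistic',
     'mobilenet_v3_small_minimalistic'],
    'Encoder architecture to train.')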
Example #4
def main(argv):
    # boostx: TODO: the latest VS Code (4-12-2020) converts
    # "--variable_schema='(?!global_step|(?:.*/|^)LARSOptimizer|head)'"
    # into "--variable_schema=\'(?!global_step|(?:.*/|^)LARSOptimizer|head)\'",
    # so set the flag directly here as a workaround.  //@audit workaround
    FLAGS.variable_schema = '(?!global_step|(?:.*/|^)LARSOptimizer|head)'

    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Enable soft device placement when training summaries are requested.
    if FLAGS.train_summary_steps > 0:
        tf.config.set_soft_device_placement(True)

    # //@follow-up Estimator Evaluation (3)
    builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir)
    builder.download_and_prepare()
    num_train_examples = builder.info.splits[FLAGS.train_split].num_examples
    num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples
    num_classes = builder.info.features['label'].num_classes

    train_steps = model_util.get_train_steps(num_train_examples)
    eval_steps = int(math.ceil(num_eval_examples / FLAGS.eval_batch_size))
    epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size))

    resnet.BATCH_NORM_DECAY = FLAGS.batch_norm_decay
    model = resnet.resnet_v1(resnet_depth=FLAGS.resnet_depth,
                             width_multiplier=FLAGS.width_multiplier,
                             cifar_stem=FLAGS.image_size <= 32)

    checkpoint_steps = (FLAGS.checkpoint_steps
                        or (FLAGS.checkpoint_epochs * epoch_steps))

    master = FLAGS.master
    if FLAGS.use_tpu and master is None:
        if FLAGS.tpu_name:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        else:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(cluster)
        tf.tpu.experimental.initialize_tpu_system(cluster)
        master = cluster.master()

    run_config = tf.estimator.tpu.RunConfig(
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=checkpoint_steps),
        model_dir=FLAGS.model_dir,
        save_summary_steps=checkpoint_steps,
        save_checkpoints_steps=checkpoint_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        master=master)
    estimator = tf.estimator.tpu.TPUEstimator(  # //@follow-up Estimator Evaluation (4)
        model_lib.build_model_fn(
            model, num_classes,
            num_train_examples),  # //@follow-up Estimator Evaluation (5)
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        use_tpu=FLAGS.use_tpu)

    if FLAGS.mode == 'eval':
        for ckpt in tf.train.checkpoints_iterator(run_config.model_dir,
                                                  min_interval_secs=15):
            try:
                result = perform_evaluation(estimator=estimator,
                                            input_fn=data_lib.build_input_fn(
                                                builder, False),
                                            eval_steps=eval_steps,
                                            model=model,
                                            num_classes=num_classes,
                                            checkpoint_path=ckpt)
            except tf.errors.NotFoundError:
                continue
            if result['global_step'] >= train_steps:
                return
    else:
        estimator.train(data_lib.build_input_fn(builder, True),
                        max_steps=train_steps)
        if FLAGS.mode == 'train_then_eval':  # //@follow-up Estimator Evaluation (10)
            perform_evaluation(estimator=estimator,
                               input_fn=data_lib.build_input_fn(
                                   builder, False),
                               eval_steps=eval_steps,
                               model=model,
                               num_classes=num_classes)
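The variable_schema workaround at the top of Example #4 matters because the schema is later used as a regular expression to decide which variables to restore from a pretrained checkpoint. A sketch of that consumption, with an illustrative helper name (the real logic lives elsewhere in the training library):

import re
import tensorflow as tf  # TF 1.x, as in the examples above

def variables_to_restore(variable_schema):
    """Map variable names to variables whose names match the schema regex.

    The schema '(?!global_step|(?:.*/|^)LARSOptimizer|head)' is a negative
    lookahead: it skips the global step, LARS optimizer slots, and the
    classification head, restoring only the encoder weights.
    """
    return {v.op.name: v
            for v in tf.global_variables()
            if re.match(variable_schema, v.op.name)}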