Example #1
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    if FLAGS.create_hub:
        # Call directly rather than re-entering app.run(); the flags are
        # already parsed by the time main() runs.
        create_module_from_checkpoints(argv)
    else:

        if not os.path.exists(FLAGS.model_dir):
            os.makedirs(FLAGS.model_dir)
            print("Created directory: {0}".format(os.path.abspath(FLAGS.model_dir)))    
                
        # Enable training summary.
        if FLAGS.train_summary_steps > 0:
            tf.config.set_soft_device_placement(True)
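            # Soft placement lets ops without a kernel on the chosen device
            # (e.g. summaries on TPU) fall back to a supported device.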

        # Choose dataset.
        if FLAGS.dataset == "chest_xray":
            # Not really a TFDS builder, but it exposes a compatible interface.
            # TODO: move these settings into the config.
            data_path = FLAGS.data_dir
            data_split = FLAGS.train_data_split
            print("*" * 80)
            print("")
            print(f"WARNING: X-ray train/eval split is {data_split}; it should be 0.9.")
            print("")
            print("*" * 80)

            builder, info = chest_xray.XRayDataSet(
                data_path, config=None, train=True,
                return_tf_dataset=False, split=data_split)
            build_input_fn = partial(data_lib.build_chest_xray_fn, FLAGS.use_multi_gpus, data_path)
            num_train_examples = info.get('num_examples')
            num_classes = info.get('num_classes')
            num_eval_examples = info.get('num_eval_examples')
            print(f"num_train_examples:{num_train_examples}, num_eval_examples:{num_eval_examples}")
        else:
            builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir)
            builder.download_and_prepare()
            num_train_examples = builder.info.splits[FLAGS.train_split].num_examples
            num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples
            num_classes = builder.info.features['label'].num_classes
            build_input_fn = data_lib.build_input_fn

        train_steps = model_util.get_train_steps(num_train_examples)
        eval_steps = int(math.ceil(num_eval_examples / FLAGS.eval_batch_size))
        epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size))

        resnet.BATCH_NORM_DECAY = FLAGS.batch_norm_decay
        model = resnet.resnet_v1(
            resnet_depth=FLAGS.resnet_depth,
            width_multiplier=FLAGS.width_multiplier,
            cifar_stem=FLAGS.image_size <= 32)

        checkpoint_steps = (
            FLAGS.checkpoint_steps or (FLAGS.checkpoint_epochs * epoch_steps))
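        # An explicit --checkpoint_steps wins; otherwise checkpoints are
        # spaced every checkpoint_epochs worth of epoch_steps.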

        cluster = None
        if FLAGS.use_tpu and FLAGS.master is None:
            if FLAGS.tpu_name:
                cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
                    FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
            else:
                cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
                tf.config.experimental_connect_to_cluster(cluster)
                tf.tpu.experimental.initialize_tpu_system(cluster)

        # Use MirroredStrategy for multi-GPU training when not on TPU.
        use_mirrored = not FLAGS.use_tpu and FLAGS.use_multi_gpus
        strategy = tf.distribute.MirroredStrategy() if use_mirrored else None
        print("use_multi_gpus: {0}".format(FLAGS.use_multi_gpus))
        print("use MirroredStrategy: {0}".format(use_mirrored))
        default_eval_mode = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1
        sliced_eval_mode = tf.estimator.tpu.InputPipelineConfig.SLICED
        run_config = tf.estimator.tpu.RunConfig(
            tpu_config=tf.estimator.tpu.TPUConfig(
                iterations_per_loop=checkpoint_steps,
                eval_training_input_configuration=sliced_eval_mode
                if FLAGS.use_tpu else default_eval_mode),
            model_dir=FLAGS.model_dir,
            train_distribute=strategy, 
            eval_distribute=strategy,
            save_summary_steps=checkpoint_steps,
            save_checkpoints_steps=checkpoint_steps,
            keep_checkpoint_max=FLAGS.keep_checkpoint_max,
            master=FLAGS.master,
            cluster=cluster)
        estimator = tf.estimator.tpu.TPUEstimator(
            model_lib.build_model_fn(model, num_classes, num_train_examples, FLAGS.train_batch_size),
            config=run_config,
            train_batch_size=FLAGS.train_batch_size,
            eval_batch_size=FLAGS.eval_batch_size,
            use_tpu=FLAGS.use_tpu)
            

        # Save the flags for this experiment.
        with open(os.path.join(FLAGS.model_dir, 'experiment_flags.p'), "wb") as f:
            pickle.dump(FLAGS.flag_values_dict(), f)
        FLAGS.append_flags_into_file(os.path.join(FLAGS.model_dir, 'experiment_flags.txt'))


        # Train/Eval
        if FLAGS.mode == 'eval':
            for ckpt in tf.train.checkpoints_iterator(
                    run_config.model_dir, min_interval_secs=15):
                try:
                    result = perform_evaluation(
                        estimator=estimator,
                        input_fn=build_input_fn(builder, False),
                        eval_steps=eval_steps,
                        model=model,
                        num_classes=num_classes,
                        checkpoint_path=ckpt)
                except tf.errors.NotFoundError:
                    continue
                if result['global_step'] >= train_steps:
                    return
        else:
            estimator.train(
                build_input_fn(builder, True), max_steps=train_steps)
            if FLAGS.mode == 'train_then_eval':
                perform_evaluation(
                    estimator=estimator,
                    input_fn=build_input_fn(builder, False),
                    eval_steps=eval_steps,
                    model=model,
                    num_classes=num_classes)
        # Always export the hub module, whichever mode ran.
        create_module_from_checkpoints(argv)

        # Compute SSL metric:
        if FLAGS.compute_ssl_metric:
            compute_ssl_metric()
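
A minimal sketch of reading back the flag snapshot that Example #1 writes into model_dir (the directory below is hypothetical):

import os
import pickle

model_dir = "/tmp/simclr_experiment"  # hypothetical; use your FLAGS.model_dir
with open(os.path.join(model_dir, "experiment_flags.p"), "rb") as f:
    saved_flags = pickle.load(f)  # plain dict from FLAGS.flag_values_dict()
print(saved_flags.get("train_batch_size"))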
Example #2
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Imported here because mobilenet_v3 reads the flags defined above.
    from mobilenetv3.mobilenet_v3 import V3_LARGE_MINIMALISTIC, V3_SMALL_MINIMALISTIC, mobilenet_func

    # Enable training summary.
    if FLAGS.train_summary_steps > 0:
        tf.config.set_soft_device_placement(True)

    builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir)
    builder.download_and_prepare()
    num_train_examples = builder.info.splits[FLAGS.train_split].num_examples
    num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples
    num_classes = builder.info.features['label'].num_classes

    train_steps = model_util.get_train_steps(num_train_examples)
    eval_steps = int(math.ceil(num_eval_examples / FLAGS.eval_batch_size))
    epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size))

    if FLAGS.backbone == "resnet":
        resnet.BATCH_NORM_DECAY = FLAGS.batch_norm_decay
        model = resnet.resnet_v1(resnet_depth=FLAGS.resnet_depth,
                                 width_multiplier=FLAGS.width_multiplier,
                                 cifar_stem=FLAGS.image_size <= 32,
                                 conv1_stride1=(FLAGS.image_size == 112)
                                 or (FLAGS.image_size == 84))
    elif FLAGS.backbone == "mobilenet_v3_large_minimalistic":
        model = mobilenet_func(conv_defs=V3_LARGE_MINIMALISTIC,
                               depth_multiplier=FLAGS.width_multiplier)
    elif FLAGS.backbone == "mobilenet_v3_small_minimalistic":
        model = mobilenet_func(conv_defs=V3_SMALL_MINIMALISTIC,
                               depth_multiplier=FLAGS.width_multiplier)
    else:
        raise ValueError("Unknown backbone: " + FLAGS.backbone)

    checkpoint_steps = (FLAGS.checkpoint_steps
                        or (FLAGS.checkpoint_epochs * epoch_steps))

    cluster = None
    if FLAGS.use_tpu and FLAGS.master is None:
        if FLAGS.tpu_name:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        else:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(cluster)
            tf.tpu.experimental.initialize_tpu_system(cluster)

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))

    if FLAGS.use_fp16:
        # This works, but no loss rescaling
        config.graph_options.rewrite_options.auto_mixed_precision = 1
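        # Note: this graph rewrite casts eligible ops to float16 but does
        # not add loss scaling, so very small gradients may underflow.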

    # Use the SLICED eval input pipeline on TPU; PER_HOST_V1 otherwise.
    if FLAGS.use_tpu:
        eval_input_config = tf.estimator.tpu.InputPipelineConfig.SLICED
    else:
        eval_input_config = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1

    run_config = tf.estimator.tpu.RunConfig(
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=checkpoint_steps,
            eval_training_input_configuration=eval_input_config),
        model_dir=FLAGS.model_dir,
        save_summary_steps=checkpoint_steps,
        save_checkpoints_steps=checkpoint_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        master=FLAGS.master,
        cluster=cluster,
        session_config=config)
    estimator = tf.estimator.tpu.TPUEstimator(
        model_lib.build_model_fn(model, num_classes, num_train_examples),
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        use_tpu=FLAGS.use_tpu,
        model_dir=FLAGS.model_dir)

    if FLAGS.mode == 'eval':
        for ckpt in tf.train.checkpoints_iterator(run_config.model_dir,
                                                  min_interval_secs=15):
            try:
                result = perform_evaluation(estimator=estimator,
                                            input_fn=data_lib.build_input_fn(
                                                builder, False),
                                            eval_steps=eval_steps,
                                            model=model,
                                            num_classes=num_classes,
                                            checkpoint_path=ckpt)
            except tf.errors.NotFoundError:
                continue
            if result['global_step'] >= train_steps:
                return
    else:
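        # With save_steps this large, the hook effectively captures one
        # profile early in training rather than profiling periodically.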
        profile_hook = tf.estimator.ProfilerHook(save_steps=100000000,
                                                 save_secs=None,
                                                 output_dir=FLAGS.model_dir,
                                                 show_dataflow=True,
                                                 show_memory=True)
        estimator.train(data_lib.build_input_fn(builder, True),
                        max_steps=train_steps,
                        hooks=[profile_hook])
        if FLAGS.mode == 'train_then_eval':
            perform_evaluation(estimator=estimator,
                               input_fn=data_lib.build_input_fn(
                                   builder, False),
                               eval_steps=eval_steps,
                               model=model,
                               num_classes=num_classes)
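
The step arithmetic shared by all of these examples, traced with illustrative numbers (the counts and batch sizes below are assumptions, not values from the examples):

import math

num_train_examples, num_eval_examples = 50000, 10000  # illustrative
train_batch_size, eval_batch_size = 256, 256          # illustrative

eval_steps = int(math.ceil(num_eval_examples / eval_batch_size))  # 40
epoch_steps = int(round(num_train_examples / train_batch_size))   # 195
print(eval_steps, epoch_steps)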
Example #3
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Enable training summary.
    if FLAGS.train_summary_steps > 0:
        tf.config.set_soft_device_placement(True)

    builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir)
    builder.download_and_prepare()
    num_train_examples = builder.info.splits[FLAGS.train_split].num_examples
    num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples
    num_classes = builder.info.features['label'].num_classes

    train_steps = model_util.get_train_steps(num_train_examples)
    eval_steps = int(math.ceil(num_eval_examples / FLAGS.eval_batch_size))
    epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size))

    resnet.BATCH_NORM_DECAY = FLAGS.batch_norm_decay
    model = resnet.resnet_v1(resnet_depth=FLAGS.resnet_depth,
                             width_multiplier=FLAGS.width_multiplier,
                             cifar_stem=FLAGS.image_size <= 32)

    checkpoint_steps = (FLAGS.checkpoint_steps
                        or (FLAGS.checkpoint_epochs * epoch_steps))

    cluster = None
    if FLAGS.use_tpu and FLAGS.master is None:
        if FLAGS.tpu_name:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        else:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(cluster)
            tf.tpu.experimental.initialize_tpu_system(cluster)

    default_eval_mode = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1
    sliced_eval_mode = tf.estimator.tpu.InputPipelineConfig.SLICED
    run_config = tf.estimator.tpu.RunConfig(
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=checkpoint_steps,
            eval_training_input_configuration=sliced_eval_mode
            if FLAGS.use_tpu else default_eval_mode),
        model_dir=FLAGS.model_dir,
        save_summary_steps=checkpoint_steps,
        save_checkpoints_steps=checkpoint_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        master=FLAGS.master,
        cluster=cluster)
    estimator = tf.estimator.tpu.TPUEstimator(
        model_lib.build_model_fn(model, num_classes, num_train_examples),
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        use_tpu=FLAGS.use_tpu)

    if FLAGS.mode == 'eval':
        for ckpt in tf.train.checkpoints_iterator(run_config.model_dir,
                                                  min_interval_secs=15):
            try:
                result = perform_evaluation(estimator=estimator,
                                            input_fn=data_lib.build_input_fn(
                                                builder, False),
                                            eval_steps=eval_steps,
                                            model=model,
                                            num_classes=num_classes,
                                            checkpoint_path=ckpt)
            except tf.errors.NotFoundError:
                continue
            if result['global_step'] >= train_steps:
                return
    else:
        estimator.train(data_lib.build_input_fn(builder, True),
                        max_steps=train_steps)
        if FLAGS.mode == 'train_then_eval':
            perform_evaluation(estimator=estimator,
                               input_fn=data_lib.build_input_fn(
                                   builder, False),
                               eval_steps=eval_steps,
                               model=model,
                               num_classes=num_classes)
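
How the checkpoint_steps fallback used in these examples behaves, sketched with assumed flag values:

checkpoint_steps = 0    # i.e. --checkpoint_steps left unset
checkpoint_epochs = 4   # assumed flag value
epoch_steps = 195       # from the arithmetic above

# `x or y` yields y when x is falsy, so an explicit step count wins and
# the epoch-based spacing is only a fallback.
print(checkpoint_steps or checkpoint_epochs * epoch_steps)  # 780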
Example #4
def main(argv):
    # boostx TODO: a VS Code update (4-12-2020) converts
    # "--variable_schema='(?!global_step|(?:.*/|^)LARSOptimizer|head)'"
    # into "--variable_schema=\'(?!global_step|(?:.*/|^)LARSOptimizer|head)\'",
    # so the schema is set here instead.  //@audit workaround
    FLAGS.variable_schema = '(?!global_step|(?:.*/|^)LARSOptimizer|head)'

    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Enable training summary.
    if FLAGS.train_summary_steps > 0:
        tf.config.set_soft_device_placement(True)

#//@follow-up Estimator Evaluation (3)
    builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir)
    builder.download_and_prepare()
    num_train_examples = builder.info.splits[FLAGS.train_split].num_examples
    num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples
    num_classes = builder.info.features['label'].num_classes

    train_steps = model_util.get_train_steps(num_train_examples)
    eval_steps = int(math.ceil(num_eval_examples / FLAGS.eval_batch_size))
    epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size))

    resnet.BATCH_NORM_DECAY = FLAGS.batch_norm_decay
    model = resnet.resnet_v1(resnet_depth=FLAGS.resnet_depth,
                             width_multiplier=FLAGS.width_multiplier,
                             cifar_stem=FLAGS.image_size <= 32)

    checkpoint_steps = (FLAGS.checkpoint_steps
                        or (FLAGS.checkpoint_epochs * epoch_steps))

    master = FLAGS.master
    if FLAGS.use_tpu and master is None:
        if FLAGS.tpu_name:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        else:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(cluster)
        tf.tpu.experimental.initialize_tpu_system(cluster)
        master = cluster.master()

    run_config = tf.estimator.tpu.RunConfig(
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=checkpoint_steps),
        model_dir=FLAGS.model_dir,
        save_summary_steps=checkpoint_steps,
        save_checkpoints_steps=checkpoint_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        master=master)
    estimator = tf.estimator.tpu.TPUEstimator(  #//@follow-up Estimator Evaluation (4)
        model_lib.build_model_fn(
            model, num_classes,
            num_train_examples),  #//@follow-up Estimator Evaluation (5)
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        use_tpu=FLAGS.use_tpu)

    if FLAGS.mode == 'eval':
        for ckpt in tf.train.checkpoints_iterator(run_config.model_dir,
                                                  min_interval_secs=15):
            try:
                result = perform_evaluation(estimator=estimator,
                                            input_fn=data_lib.build_input_fn(
                                                builder, False),
                                            eval_steps=eval_steps,
                                            model=model,
                                            num_classes=num_classes,
                                            checkpoint_path=ckpt)
            except tf.errors.NotFoundError:
                continue
            if result['global_step'] >= train_steps:
                return
    else:
        estimator.train(data_lib.build_input_fn(builder, True),
                        max_steps=train_steps)
        if FLAGS.mode == 'train_then_eval':  #//@follow-up Estimator Evaluation (10)
            perform_evaluation(estimator=estimator,
                               input_fn=data_lib.build_input_fn(
                                   builder, False),
                               eval_steps=eval_steps,
                               model=model,
                               num_classes=num_classes)
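
A quick, self-contained check of the variable_schema regex from Example #4: the negative lookahead matches every variable name except global_step, LARSOptimizer slots, and head variables.

import re

schema = re.compile('(?!global_step|(?:.*/|^)LARSOptimizer|head)')
for name in ['global_step', 'base_model/conv1/kernel',
             'base_model/LARSOptimizer/momentum', 'head/dense/kernel']:
    print(name, bool(schema.match(name)))
# -> False, True, False, False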