def _resolve_tpu_name(configured_name):
    """Return the TPU name to use; the ``TPU_NAME`` env var overrides config.

    Args:
        configured_name: TPU name from the config; empty or ``None`` means
            "not configured".

    Returns:
        ``os.environ['TPU_NAME']`` when it is set, otherwise ``configured_name``.

    Raises:
        AssertionError: if neither ``configured_name`` nor ``TPU_NAME`` is set.
    """
    env_name = os.environ.get('TPU_NAME')

    # `not` also treats None as "not configured" (the old `== ''` check let
    # None through to Client(None)).
    if not configured_name:
        if env_name is None:
            raise AssertionError(
                'Failed to fetch TPU name, please set ENV VAR `TPU_NAME` or specify TPU name in config '  # noqa: E501
            )
        logging.warning(
            'Using {} as TPU name from ENV VAR `TPU_NAME`'.format(env_name))
        return env_name

    if env_name is not None:
        logging.warning(
            'Changed TPU name from {} to {} (overided with ENV VAR `TPU_NAME`)'  # noqa: E501
            .format(configured_name, env_name))
        return env_name

    return configured_name


def get_strategy(params):
    """Build a ``tf.distribute`` strategy for the requested device type.

    Args:
        params: config object exposing ``type`` (``'gpu'``, ``'cpu'``,
            ``'multi_gpu'`` or ``'tpu'``) and, for TPUs, ``name``.

    Returns:
        ``OneDeviceStrategy`` on ``/gpu:0`` or ``/cpu:0``, a
        ``MirroredStrategy`` for ``'multi_gpu'``, or a ``TPUStrategy``.

    Raises:
        AssertionError: ``'tpu'`` requested but no TPU name is available from
            the config or the ``TPU_NAME`` environment variable.
        ValueError: ``params.type`` is not one of the supported values.
    """
    if params.type == 'gpu':
        logging.info('Creating GPU strategy')
        return tf.distribute.OneDeviceStrategy(device='/gpu:0')

    if params.type == 'cpu':
        logging.info('Creating CPU strategy')
        return tf.distribute.OneDeviceStrategy(device='/cpu:0')

    if params.type == 'multi_gpu':
        logging.info('Creating Multi GPU strategy')
        return tf.distribute.MirroredStrategy()

    if params.type == 'tpu':
        logging.info('Creating TPU strategy')

        tpu_name = _resolve_tpu_name(params.name)

        logging.info(
            'Configuring TPU: {} with correct tensorflow version'.format(
                tpu_name))

        client = Client(tpu_name)
        # Match the TPU runtime's TensorFlow version to the local one.
        client.configure_tpu_version(tf.__version__, restart_type='ifNeeded')
        # A version change may restart the runtime; wait until it is healthy
        # before connecting to it.
        client.wait_for_healthy()

        logging.info(
            'Done Configuring TPU: {} with tensorflow version: {}'.format(
                tpu_name, tf.__version__))

        resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(
            tpu_name)

        return tf.distribute.TPUStrategy(resolver)

    raise ValueError('Unsupported strategy requested')
示例#2
0
def set_colab_runtime_version(version=tf.__version__):
    """Switch the Colab TPU runtime to TensorFlow ``version`` and wait for it.

    Args:
        version: runtime version to request. Defaults to the local
            ``tf.__version__``, evaluated once when this function is defined.

    Raises:
        RuntimeError: if ``COLAB_TPU_ADDR`` is unset or empty (e.g. when not
            running on a Colab TPU runtime).
    """
    tpu_address = os.environ.get("COLAB_TPU_ADDR")
    # Fail with a clear message instead of `"grpc://" + None` -> TypeError.
    if not tpu_address:
        raise RuntimeError(
            "COLAB_TPU_ADDR is not set; is this a Colab TPU runtime?")

    # noinspection PyUnresolvedReferences
    from cloud_tpu_client import Client
    client = Client(tpu="grpc://" + tpu_address)
    client.configure_tpu_version(version, restart_type='ifNeeded')
    # Block until the (possibly restarted) runtime reports healthy again.
    client.wait_for_healthy()
示例#3
0
    def connect_tpu(self, tpu_name):
        """Connect to TPU ``tpu_name`` and register a TPUStrategy for it.

        Records the name, brings the TPU runtime to ``self.tpu_version``
        (waiting for health before and after), builds a resolver whose job
        name is the TPU name, (re)connects TensorFlow to the cluster formed by
        every resolver registered so far, then initializes this TPU and stores
        its strategy and topology.

        Args:
            tpu_name: name of the TPU node; also used as the resolver's
                ``job_name`` and as the key into ``self.tpu_resolvers``,
                ``self.strategies`` and ``self.topologies``.

        NOTE(review): ``tpu_name`` is appended to ``self.tpus`` before any
        connection step runs, so a failure later in this method leaves it
        recorded. Calling this twice with the same name appends a duplicate
        to ``self.tpus`` (inflating ``total_tpu_count``) while the dicts are
        simply overwritten. Confirm whether callers rely on either behavior.
        """
        self.tpus.append(tpu_name)

        c = Client(tpu_name)
        # Wait for the node, change its TF runtime version if needed, then
        # wait again since the version change may restart it.
        c.wait_for_healthy(interval=5)
        c.configure_tpu_version(self.tpu_version, restart_type='ifNeeded')
        c.wait_for_healthy(interval=5)

        # job_name=tpu_name gives each TPU its own job in the combined cluster.
        tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=tpu_name, job_name=tpu_name)
        self.tpu_resolvers[tpu_name] = tpu_resolver

        if len(self.tpu_resolvers) == 1:
            # First TPU: connect directly to its resolver.
            tf.config.experimental_connect_to_cluster(tpu_resolver)
        else:
            # Later TPUs: reconnect to the union of all registered resolvers.
            # TODO: might want to check health of nodes in self.tpus before redefining the cluster
            tf.config.experimental_connect_to_cluster(
                tf.distribute.cluster_resolver.UnionResolver(
                    *self.tpu_resolvers.values()))
        # Only the newly added TPU is initialized here.
        topology = tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
        self.strategies[tpu_name] = tf.distribute.TPUStrategy(tpu_resolver)

        self.topologies[tpu_name] = topology
        # Number of entries in self.tpus (see NOTE above about duplicates).
        self.total_tpu_count = len(self.tpus)
示例#4
0
import os

parser = argparse.ArgumentParser()
parser.add_argument('--target-version',
                    type=str,
                    required=True,
                    help='target TPU Runtime version')

args = parser.parse_args()

# TPU_CONFIG is a JSON document describing the TPU node (its
# 'tpu_node_name' key is read below); without it there is no TPU to switch.
tpu_config_env = os.environ.get('TPU_CONFIG')

if not tpu_config_env:
    # `tf.logging` was removed in TensorFlow 2.x; the compat.v1 alias exists
    # in late 1.x releases and in 2.x.
    tf.compat.v1.logging.info('Missing TPU_CONFIG, use CPU/GPU for training.')
    # Equivalent to sys.exit() (status 0) without relying on the `exit`
    # helper injected by the `site` module (absent under `python -S`).
    raise SystemExit()

tpu_node = json.loads(tpu_config_env)

c = Client(tpu_node['tpu_node_name'])
# Wait for the TPU node to be up and Healthy with default runtime
c.wait_for_healthy()
# Change runtime to target_version
c.configure_tpu_version(args.target_version, restart_type='ifNeeded')
c.wait_for_healthy()

# Get the TPU_IP_ADDRESS
# This will be used in the subsequent steps to set XRT_TPUC_CONFIG
endpoints = c.network_endpoints()
for endpoint in endpoints:
    print(endpoint['ipAddress'])
def _configure_logging(tf):
    """Set absl/TensorFlow verbosity and route absl output to stdout.

    ``FLAGS.verbosity_level`` selects the absl level. Unknown values,
    including ``'VERBOSE'``, fall back to INFO. ``'VERBOSE'`` additionally
    enables TensorFlow device-placement and autograph logging.

    Args:
        tf: the ``tensorflow`` module, imported lazily by ``main``.
    """
    # masking error related to cache
    logger.getLogger('googleapiclient.discovery_cache').setLevel(logger.ERROR)

    levels = {
        'DEBUG': logging.DEBUG,
        'INFO': logging.INFO,
        'WARNING': logging.WARNING,
        'ERROR': logging.ERROR,
        'FATAL': logging.FATAL,
    }
    logging.set_verbosity(levels.get(FLAGS.verbosity_level, logging.INFO))

    # set level of verbosity for Tensorflow
    if FLAGS.verbosity_level == 'VERBOSE':
        tf.debugging.set_log_device_placement(True)
        tf.autograph.set_verbosity(10, alsologtostdout=False)

    absl_handler = logging.get_absl_handler()
    absl_handler.setFormatter(logger.Formatter("[%(levelname)s] %(message)s"))
    absl_handler.python_handler.stream = sys.stdout
    logging.set_stderrthreshold(logging.WARNING)

    # Handler diagnostics go through the logger at DEBUG level.
    for handler in logger.getLogger().handlers:
        logging.debug('root logger handler: {} (class {})'.format(
            handler, handler.__class__))
    logging.debug('logging.get_verbosity(): {}'.format(
        logging.get_verbosity()))


def _log_flags():
    """Log every flag value: custom flags first, then absl's built-in ones."""
    abseil_flags = {
        'logtostderr', 'alsologtostderr', 'log_dir', 'v', 'verbosity',
        'stderrthreshold', 'showprefixforinfo', 'run_with_pdb',
        'pdb_post_mortem', 'run_with_profiling', 'profile_file',
        'use_cprofile_for_profiling', 'only_check_args', 'flagfile', 'undefok'
    }
    flag_names = list(FLAGS)
    logging.info('-- Custom flags:')
    for name in flag_names:
        if name not in abseil_flags:
            logging.info('custom flags: {:40} with value: {:50}'.format(
                name, str(FLAGS[name].value)))
    logging.info('\n-- Abseil flags:')
    for name in flag_names:
        if name in abseil_flags:
            logging.info('abseil flags: {:40} with value: {:50}'.format(
                name, str(FLAGS[name].value)))


def _log_environment(tf):
    """Log library versions, flags and (at DEBUG level) the process environment.

    Args:
        tf: the ``tensorflow`` module, imported lazily by ``main``.
    """
    if os.environ.get('LOG_FILE_TO_WRITE') is not None:
        logging.info('os.environ[LOG_FILE_TO_WRITE]: {}'.format(
            os.environ['LOG_FILE_TO_WRITE']))

    logging.info(tf.__version__)
    logging.info(tf.keras.__version__)
    logging.info(list(FLAGS))
    logging.debug('flags: \n {}'.format(FLAGS))
    # NOTE(review): dumps every environment variable at DEBUG level; make sure
    # no secrets are exposed through env vars in environments using DEBUG.
    logging.debug('env variables: \n{}'.format(os.environ))
    logging.debug('current dir: {}'.format(os.getcwd()))
    logging.debug('__package__: {}'.format(__package__))
    logging.debug('__name__: {}'.format(__name__))
    logging.debug('__file__: {}'.format(__file__))


def _get_hp_metric_name():
    """Return the hyper-parameter-tuning metric name, or ``'NotDefined'``.

    When ``CLOUD_ML_HP_METRIC_TAG`` is set, this also sets
    ``FLAGS.is_hyperparameter_tuning = True`` and logs the related
    ``CLOUD_ML_*`` variables.
    """
    metric_tag = os.environ.get('CLOUD_ML_HP_METRIC_TAG')
    if metric_tag is None:
        return 'NotDefined'

    logging.info('this is a hyper parameters job !')

    # setup the hp flag
    FLAGS.is_hyperparameter_tuning = True
    logging.info('FLAGS.is_hyperparameter_tuning: {}'.format(
        FLAGS.is_hyperparameter_tuning))

    logging.info('os.environ[CLOUD_ML_HP_METRIC_TAG]: {}'.format(metric_tag))
    # .get(): these variables are only logged, so a missing one must not
    # abort the job with a KeyError.
    logging.info('os.environ[CLOUD_ML_HP_METRIC_FILE]: {}'.format(
        os.environ.get('CLOUD_ML_HP_METRIC_FILE')))
    logging.info('os.environ[CLOUD_ML_TRIAL_ID]: {}'.format(
        os.environ.get('CLOUD_ML_TRIAL_ID')))

    logging.info('metric accuracy name: {}'.format(metric_tag))
    return metric_tag


def _create_strategy(tf):
    """Return a TPUStrategy when ``FLAGS.use_tpu`` is set, else MirroredStrategy.

    On TPU, first matches the TPU runtime's TensorFlow version to the local
    one, then connects to and initializes the TPU system. This must happen
    before any TensorFlow ops are created.

    Args:
        tf: the ``tensorflow`` module, imported lazily by ``main``.
    """
    if not FLAGS.use_tpu:
        return tf.distribute.MirroredStrategy()

    logging.info(
        'setting up TPU: check that TensorFlow version is the same on the VM and on the TPU cluster'
    )
    client_tpu = Client()
    client_tpu.configure_tpu_version(tf.__version__, restart_type='ifNeeded')
    # A version change may restart the TPU runtime; wait until it is healthy.
    client_tpu.wait_for_healthy()

    logging.info('setting up TPU: cluster resolver')
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    logging.info('setting up TPU: \n {}'.format(tpu_cluster_resolver))
    logging.info('running on TPU: \n {}'.format(
        tpu_cluster_resolver.cluster_spec().as_dict()['worker']))
    tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
    tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
    return tf.distribute.experimental.TPUStrategy(tpu_cluster_resolver)


def _resolve_pretrained_model_dir():
    """Return the local directory holding the pretrained model.

    * ``FLAGS.pretrained_model_dir`` empty: ``'.'`` (the model is fetched from
      the internet into the working directory).
    * a ``gs://bucket/blob`` path: the blob is downloaded under ``'.'`` and
      ``'./<blob>'`` is returned.
    * any other path: returned unchanged.
    """
    if not FLAGS.pretrained_model_dir:
        return '.'

    logging.info('downloading pretrained model!')
    match = re.search(r'gs://(.*?)/(.*)', FLAGS.pretrained_model_dir)
    if match is None:
        return FLAGS.pretrained_model_dir

    bucket_name, blob_name = match.group(1), match.group(2)
    local_path = '.'
    mu.download_blob(bucket_name, blob_name, local_path)
    return local_path + '/' + blob_name


def main(argv):
    """Train and evaluate the BERT classifier described by ``FLAGS``.

    Args:
        argv: positional command-line arguments (presumably supplied by
            ``absl.app.run``); unused, since configuration comes from ``FLAGS``.
    """
    import tensorflow as tf  # need to be here to have the env variables defined
    tf.get_logger().propagate = False

    _configure_logging(tf)
    _log_flags()
    _log_environment(tf)

    # only for HP tuning!
    metric_accuracy = _get_hp_metric_name()

    if os.environ.get('TF_CONFIG') is not None:
        logging.info('os.environ[TF_CONFIG]: {}'.format(
            os.environ['TF_CONFIG']))
    else:
        logging.error('os.environ[TF_CONFIG] doesn\'t exist !')

    strategy = _create_strategy(tf)
    logging.info('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    # choose language's model and tokenizer
    MODELS = [(TFBertModel, BertTokenizer, 'bert-base-multilingual-uncased')]
    model_index = 0  # BERT
    pretrained_weights = MODELS[model_index][2]  # i.e. 'bert-base-multilingual-uncased'

    pretrained_model_dir = _resolve_pretrained_model_dir()

    # some check
    logging.info('Batch size:            {:6}/{:6}'.format(
        FLAGS.batch_size_train, FLAGS.batch_size_eval))
    logging.info('Step per epoch:        {:6}/{:6}'.format(
        FLAGS.steps_per_epoch_train, FLAGS.steps_per_epoch_eval))
    logging.info('Total number of batch: {:6}/{:6}'.format(
        FLAGS.steps_per_epoch_train * (FLAGS.epochs + 1),
        FLAGS.steps_per_epoch_eval * 1))

    #  read TFRecords files, shuffle, map and batch size
    train_dataset = tf_bert.build_dataset(FLAGS.input_train_tfrecords,
                                          FLAGS.batch_size_train, 2048)
    valid_dataset = tf_bert.build_dataset(FLAGS.input_eval_tfrecords,
                                          FLAGS.batch_size_eval, 2048)

    # set repeat
    train_dataset = train_dataset.repeat(FLAGS.epochs + 1)
    valid_dataset = valid_dataset.repeat(2)

    # reset all variables used by Keras
    tf.keras.backend.clear_session()

    # create and compile the Keras model in the context of strategy.scope
    with strategy.scope():
        logging.debug('pretrained_model_dir={}'.format(pretrained_model_dir))
        model = tf_bert.create_model(pretrained_weights,
                                     pretrained_model_dir=pretrained_model_dir,
                                     num_labels=FLAGS.num_classes,
                                     learning_rate=FLAGS.learning_rate,
                                     epsilon=FLAGS.epsilon)
    # train the model
    tf_bert.train_and_evaluate(model,
                               num_epochs=FLAGS.epochs,
                               steps_per_epoch=FLAGS.steps_per_epoch_train,
                               train_data=train_dataset,
                               validation_steps=FLAGS.steps_per_epoch_eval,
                               eval_data=valid_dataset,
                               output_dir=FLAGS.output_dir,
                               n_steps_history=FLAGS.n_steps_history,
                               FLAGS=FLAGS,
                               decay_type=FLAGS.decay_type,
                               learning_rate=FLAGS.learning_rate,
                               s=FLAGS.decay_learning_rate,
                               n_batch_decay=FLAGS.n_batch_decay,
                               metric_accuracy=metric_accuracy)
import tensorflow as tf
from cloud_tpu_client import Client
import logging
# from tensorflow.python.eager import context

# tf.get_logger().setLevel(logging.INFO)
# Log the device each op is placed on.
tf.debugging.set_log_device_placement(True)


@tf.function
def red_sum(a, b, c, d):
    """Return the total of all elements of ``a``, ``b``, ``c`` and ``d``.

    Each input is reduced to a scalar with ``tf.reduce_sum``; the four partial
    sums are then added left to right.
    """
    total = tf.reduce_sum(a)
    # Python loop over a tuple: unrolled while tracing, so the graph has the
    # same reduce/add sequence as a single chained expression.
    for tensor in (b, c, d):
        total = total + tf.reduce_sum(tensor)
    return total


# TPU node used by this benchmark (single source of truth for its name).
TPU_NODE_NAME = "kindiana-nettest1"

c = Client(TPU_NODE_NAME)
c.configure_tpu_version(tf.__version__, restart_type='ifNeeded')
# The version change may restart the TPU runtime; wait until it is healthy
# before connecting to it.
c.wait_for_healthy()
# job_name 'tpu0' must match the '/job:tpu0/...' device strings below.
resolver1 = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu=TPU_NODE_NAME, job_name='tpu0')
tf.config.experimental_connect_to_cluster(resolver1)
tf.tpu.experimental.initialize_tpu_system(resolver1)

# Create the tensors before benchmarking
# looks like ~2GB tensors are the biggest you can send
# Each tensor below has 256*1024*1024 elements of the Python-int fill value.
with tf.device('/job:tpu0/replica:0/task:0/device:TPU:0'):
    tpu0_0 = tf.Variable(tf.fill([256, 1024, 1024], 1))
    tpu0_1 = tf.Variable(tf.fill([256, 1024, 1024], 1))
    tpu0_2 = tf.Variable(tf.fill([256, 1024, 1024], 1))
    tpu0_3 = tf.Variable(tf.fill([256, 1024, 1024], 1))
with tf.device('/job:tpu0/replica:0/task:0/device:CPU:0'):
    tpu0_cpu1 = tf.Variable(tf.fill([256, 1024, 1024], 1))
#  Copyright 2018 Google LLC
#
# * Licensed under the Apache License, Version 2.0 (the "License");
# * you may not use this file except in compliance with the License.
# * You may obtain a copy of the License at
# *
# *      http://www.apache.org/licenses/LICENSE-2.0
# *
# * Unless required by applicable law or agreed to in writing, software
# * distributed under the License is distributed on an "AS IS" BASIS,
# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# * See the License for the specific language governing permissions and
# * limitations under the License.
# 
# Use with "pytorch-nightly" TPU version only
from cloud_tpu_client import Client
import argparse
# Switch the named TPU node to the requested runtime version, then block
# until the node reports healthy.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--tpu-name',
    type=str,
    required=True,
    help='Name of the TPU Instance',
)
parser.add_argument(
    '--target-version',
    type=str,
    required=True,
    help='Target TPU Runtime version',
)
args = parser.parse_args()

c = Client(args.tpu_name)
# No restart_type argument: the Client's default restart policy applies.
c.configure_tpu_version(args.target_version)
c.wait_for_healthy()
示例#8
0
def get_colab_runtime_version():
    """Return the runtime version reported by the Colab TPU.

    Raises:
        RuntimeError: if ``COLAB_TPU_ADDR`` is unset or empty (e.g. when not
            running on a Colab TPU runtime).
    """
    tpu_address = os.environ.get("COLAB_TPU_ADDR")
    # Fail with a clear message instead of `"grpc://" + None` -> TypeError.
    if not tpu_address:
        raise RuntimeError(
            "COLAB_TPU_ADDR is not set; is this a Colab TPU runtime?")

    # noinspection PyUnresolvedReferences
    from cloud_tpu_client import Client
    client = Client(tpu="grpc://" + tpu_address)
    return client.runtime_version()
示例#9
0
    parser.add_argument('--saved_weights',
                        type=str,
                        help='Load the pre-trained weights')

    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        help='batch_size')

    in_args = parser.parse_args()

    return in_args


# Match the TPU runtime's TensorFlow version to the local one, then wait for
# the (possibly restarted) runtime to be healthy before connecting.
tpu_client = Client()
tpu_client.configure_tpu_version(tf.__version__, restart_type='ifNeeded')
tpu_client.wait_for_healthy()
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError as err:
    # RuntimeError rather than BaseException, so ordinary `except Exception`
    # handlers see it; the original detection error is kept as the cause.
    raise RuntimeError('ERROR: Not connected to a TPU runtime;') from err

tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)

args = get_input_args()
iden = 'efficientnetb2'  # @param ['efficientnetb2','densenet201','resnet50v2']
DATA_DIM = 75  # @param
IMG_DIM = 256  # @param
NB_CHANNEL = 3  # @param