def receive_test_event(data: dict, context) -> bool:
    """Entrypoint for the Cloud Function.

    Args:
      data: dict containing the base64-encoded proto message.
      context: event metadata; exposes attributes such as `event_id`.

    Returns:
      True if the message should be ack-ed, else False.
    """
    logging.set_verbosity(logging.INFO)

    dataset = DATASET
    project = PROJECT or google.auth.default()[1]

    try:
        message_bytes = base64.b64decode(data['data'])
        event = metrics_pb2.TestCompletedEvent()
        event.ParseFromString(message_bytes)
    except Exception as e:
        logging.fatal(
            'Failed to parse PubSub message. Will ack message to prevent '
            'more crashes.',
            exc_info=e)
        return True

    alert_handler = alerts.AlertHandler(project,
                                        event.benchmark_id,
                                        event.debug_info,
                                        level='ERROR')
    logging.get_absl_logger().addHandler(alert_handler)

    metric_store = bigquery_client.BigQueryMetricStore(
        project=project,
        dataset=dataset,
    )
    try:
        logging.info('Processing test event: %s', str(event))
        job_row, metric_rows = process_proto_message(event, metric_store,
                                                     context.event_id)
        metric_store.insert_status_and_metrics(job_row, metric_rows)
    except Exception as e:
        logging.fatal(
            'Encountered exception while attempting to process message.',
            exc_info=e)

    if alert_handler.has_errors:
        logging.info('Alerts: %s', str(alert_handler._records))
        if SEND_EMAIL_ALERTS:
            _send_email(project, *alert_handler.generate_email_content)
        else:
            logging.info('E-mail alerts disabled.')
    else:
        logging.info('No alerts found.')

    return True
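
The alerting above hinges on one pattern: attach a custom `logging.Handler` to the absl logger so that every `logging.error(...)` emitted while the event is processed is captured for later inspection. A minimal sketch of such a record-collecting handler, assuming only the standard library and absl-py; the class below is a hypothetical stand-in, and the real `alerts.AlertHandler` may differ:

import logging as std_logging

from absl import logging


class CollectingHandler(std_logging.Handler):
    """Collects log records at or above a threshold for later inspection."""

    def __init__(self, level=std_logging.ERROR):
        super().__init__(level=level)
        self.records = []

    def emit(self, record):
        # Called by the logging framework for every record that passes `level`.
        self.records.append(record)

    @property
    def has_errors(self):
        return bool(self.records)


handler = CollectingHandler()
logging.get_absl_logger().addHandler(handler)
logging.error('Processing failed: %s', 'details')
assert handler.has_errors
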
Example #3
def main(_):
    redirect_logs = flags.FLAGS.redirect_logs
    cache_paths = rconst.Paths(data_dir=flags.FLAGS.data_dir,
                               cache_id=flags.FLAGS.cache_id)

    log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
    log_file = os.path.join(cache_paths.data_dir, log_file_name)
    if log_file.startswith("gs://") and redirect_logs:
        fallback_log_file = os.path.join(tempfile.gettempdir(), log_file_name)
        print("Unable to log to {}. Falling back to {}".format(
            log_file, fallback_log_file))
        log_file = fallback_log_file

    # This server is generally run in a subprocess.
    if redirect_logs:
        print("Redirecting stdout and stderr to {}".format(log_file))
        log_stream = open(log_file, "wt")  # Note: not tf.gfile.Open().
        stdout = log_stream
        stderr = log_stream
    try:
        if redirect_logs:
            absl_logging.get_absl_logger().addHandler(
                hdlr=logging.StreamHandler(stream=stdout))
            sys.stdout = stdout
            sys.stderr = stderr
            print("Logs redirected.")
        try:
            log_msg("sys.argv: {}".format(" ".join(sys.argv)))

            if flags.FLAGS.seed is not None:
                np.random.seed(flags.FLAGS.seed)

            _generation_loop(
                num_workers=flags.FLAGS.num_workers,
                cache_paths=cache_paths,
                num_readers=flags.FLAGS.num_readers,
                num_neg=flags.FLAGS.num_neg,
                num_train_positives=flags.FLAGS.num_train_positives,
                num_items=flags.FLAGS.num_items,
                spillover=flags.FLAGS.spillover,
                epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
                train_batch_size=flags.FLAGS.train_batch_size,
                eval_batch_size=flags.FLAGS.eval_batch_size,
            )
        except KeyboardInterrupt:
            log_msg("KeyboardInterrupt registered.")
        except:
            traceback.print_exc()
            raise
    finally:
        log_msg("Shutting down generation subprocess.")
        sys.stdout.flush()
        sys.stderr.flush()
        if redirect_logs:
            log_stream.close()
Example #4
    def __init__(self, model, flags, metrics=metrics.Metrics()):
        self.model_wrapper = ModelWrapper(model)
        self.model_checkpoint_path = flags.save_dir
        self.lr = flags.learning_rate
        self.hidden_dim = flags.hidden_dim
        self.write_summary = flags.write_summary
        self.batch_size = flags.batch_size
        self.metrics = metrics

        logging.get_absl_logger().addHandler(logging_base.StreamHandler())

        return
    def __init__(self,
                 model,
                 flags):
        self.model_wrapper = ModelWrapper(model)
        self.model_checkpoint_path = os.path.join(flags.save_dir, flags.data_set)
        self.lr = flags.learning_rate
        self.write_summary = flags.write_summary
        self.rnn_dim = flags.rnn_dim
        self.dropout = flags.dropout_rate
        self.batch_size = flags.batch_size
        self.metric_class = HawkesMetrics()

        logging.get_absl_logger().addHandler(logging_base.StreamHandler())
    def __init__(self,
                 model,
                 flags):
        self.model_wrapper = ModelWrapper(model)
        self.model_checkpoint_path = flags.save_dir
        self.lr = flags.learning_rate
        self.write_summary = flags.write_summary
        self.rnn_dim = flags.rnn_dim
        self.dropout = flags.dropout
        self.batch_size = flags.batch_size

        logging.get_absl_logger().addHandler(logging_base.StreamHandler())

        return
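
Each of these constructors attaches a fresh StreamHandler to the absl logger, so building several trainer objects in one process duplicates every log line. A minimal guard against that, assuming `logging_base` is the standard `logging` module (as the aliasing above suggests); the helper name is illustrative:

import logging as logging_base

from absl import logging


def add_stream_handler_once(stream=None):
    """Attach a StreamHandler to the absl logger only if none is present yet."""
    logger = logging.get_absl_logger()
    if not any(isinstance(h, logging_base.StreamHandler) for h in logger.handlers):
        logger.addHandler(logging_base.StreamHandler(stream))
    return logger
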
Example #7
def setup_logger(print_logs: bool, save_logs: bool, save_path: Path,
                 run_id: str):
    native_logging.root.removeHandler(logging._absl_handler)
    logging._warn_preinit_stderr = False
    formatter = native_logging.Formatter(fmt='%(asctime)s %(message)s',
                                         datefmt='%Y-%m-%d %H:%M:%S')
    handlers = []
    if save_logs:
        write_mode = 'a' if save_path.exists() else 'w'
        save_path.mkdir(parents=True, exist_ok=True)
        log_file = save_path / f"{run_id}.log"
        stream = tf.io.gfile.GFile(str(log_file), write_mode)
        log_handler = native_logging.StreamHandler(stream)
        print('Saving logs in {}'.format(save_path))
        handlers.append(log_handler)
    if print_logs or not save_logs:
        log_handler = native_logging.StreamHandler(sys.stdout)
        handlers.append(log_handler)
    logger = logging.get_absl_logger()
    logger.propagate = False
    for log_handler in handlers:
        log_handler.setFormatter(formatter)
        log_handler.setLevel(logging.INFO)
        logger.addHandler(log_handler)
    return logger
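
A short usage sketch of the helper above; the `Path`, run id, and flag values are illustrative, not taken from the original source:

from pathlib import Path

logger = setup_logger(print_logs=True, save_logs=False,
                      save_path=Path('/tmp/example-logs'), run_id='run-001')
logger.info('logger configured')
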
Example #8
    def test_logger_and_handler(self):
        absl_logger = std_logging.getLogger('absl')
        self.assertIs(absl_logger, logging.get_absl_logger())
        self.assertTrue(isinstance(absl_logger, logging.ABSLLogger))
        self.assertTrue(
            isinstance(logging.get_absl_handler().python_handler.formatter,
                       logging.PythonFormatter))
    def setUp(self):
        self._logger = logging.get_absl_logger()
        self._handler = alerts.AlertHandler(
            project_id='my-project-id',
            benchmark_id='benchmark-id',
            debug_info=None,
        )
        self._logger.addHandler(self._handler)
Example #10
def _test_register_frame_to_skip():
    """Test skipping frames for line number reporting."""
    def _getline():
        def _getline_inner():
            return logging.get_absl_logger().findCaller()[1]

        return _getline_inner()

    # Check register_frame_to_skip function to see if log frame skipping works.
    line1 = _getline()
    line2 = _getline()
    logging.get_absl_logger().register_frame_to_skip(__file__, '_getline')
    line3 = _getline()
    # Both should be the line number of the _getline_inner() call.
    assert (line1 == line2), (line1, line2)
    # line3 should be a line number in this function.
    assert (line2 != line3), (line2, line3)
Example #11
    def test_find_log_dir_with_nothing(self):
        with mock.patch.object(os.path, 'exists'), \
                mock.patch.object(os.path, 'isdir'), \
                mock.patch.object(logging.get_absl_logger(), 'fatal') as mock_fatal:
            os.path.exists.return_value = False
            os.path.isdir.return_value = False
            log_dir = logging.find_log_dir()
            mock_fatal.assert_called()
            self.assertEqual(None, log_dir)
Example #12
    def test_call_only_default_args(self):
        expected_arg_two = 3

        @phase_descriptor.PhaseOptions()
        def phase(arg_one=1, arg_two=2):
            self.assertEqual(arg_one, 1)
            # We are changing the arg with the with_args statement when called.
            self.assertEqual(arg_two, expected_arg_two)

        self._test_state.running_phase_state = (
            test_state.PhaseState.from_descriptor(phase, self._test_state,
                                                  logging.get_absl_logger()))
        phase.with_args(arg_two=expected_arg_two)(self._test_state)
Example #13
    def __init__(self, tag):
        """Create a context object for recording time.

        Args:
            tag (str): the summary tag for the time.
        """
        self._tag = tag
        caller = logging.get_absl_logger().findCaller()
        # token is a string of filename:lineno:tag
        token = caller[0] + ':' + str(caller[1]) + ':' + tag
        if token not in _contexts:
            _contexts[token] = {'time': 0., 'n': 0}
        self._counter = _contexts[token]
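
Only the constructor is shown above; presumably the object is used as a context manager that accumulates wall-clock time per call site. A hedged, self-contained sketch of that idea follows; the class name and the `__enter__`/`__exit__` bodies are assumptions, not the original implementation:

import time

from absl import logging

_contexts = {}  # maps "filename:lineno:tag" -> {'time': float, 'n': int}


class record_time(object):
    """Accumulates wall-clock time per call site, keyed by filename:lineno:tag."""

    def __init__(self, tag):
        self._tag = tag
        caller = logging.get_absl_logger().findCaller()
        # token is a string of filename:lineno:tag
        token = caller[0] + ':' + str(caller[1]) + ':' + tag
        if token not in _contexts:
            _contexts[token] = {'time': 0., 'n': 0}
        self._counter = _contexts[token]

    def __enter__(self):
        self._start = time.time()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self._counter['time'] += time.time() - self._start
        self._counter['n'] += 1
        return False  # do not suppress exceptions


# Usage: each call site gets its own accumulator.
with record_time('train_step'):
    time.sleep(0.01)
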
Example #14
def warning_once(msg, *args):
    """Generate warning message once.

    Note that the current implementation resembles that of the ``log_every_n()```
    function in ``logging`` but reduces the calling stack by one to ensure
    the multiple warning once messages generated at difference places can be
    displayed correctly.

    Args:
        msg: str, the message to be logged.
        *args: The args to be substitued into the msg.
    """
    caller = logging.get_absl_logger().findCaller()
    count = logging._get_next_log_count_per_token(caller)
    logging.log_if(logging.WARNING, msg, not (count % (1 << 62)), *args)
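
A brief usage sketch of the function above; the message text and loop are illustrative:

for device in ['gpu:0', 'gpu:0', 'gpu:0']:
    warning_once('Falling back to %s', device)  # emitted only once for this call site
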
Example #15
    def test_call_only_default_args_and_plug(self):
        expected_arg_one = 5
        self._test_state.plug_manager.initialize_plugs([ExtraPlug])

        @plugs.plug(custom_plug=ExtraPlug)
        def phase(custom_plug, arg_one=1, arg_two=2):
            self.assertIsInstance(custom_plug, ExtraPlug)
            # We are changing the arg with the with_args statement when called.
            self.assertEqual(arg_one, expected_arg_one)
            self.assertEqual(arg_two, 2)

        self._test_state.running_phase_state = (
            test_state.PhaseState.from_descriptor(phase, self._test_state,
                                                  logging.get_absl_logger()))
        phase.with_args(arg_one=expected_arg_one)(self._test_state)
Example #16
def set_logging(is_debug, config):
    absl_logger = logging.get_absl_logger()
    # create formatter and add it to the handlers
    formatter = _logging.Formatter("[ %(asctime)-15s %(levelname)s %(filename)15s:%(lineno)-4d " \
                " %(process)-5d ]  %(message)s")

    log_dir = config["solver"]["saver"]["model_path"]
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    logging.get_absl_handler().use_absl_log_file(program_name='delta',
                                                 log_dir=log_dir)

    fh = _logging.StreamHandler()
    fh.setLevel(_logging.NOTSET)
    fh.setFormatter(formatter)
    absl_logger.addHandler(fh)

    if is_debug:
        logging.set_verbosity(_logging.DEBUG)
    else:
        logging.set_verbosity(_logging.NOTSET)

    logging.info("Also save log file to directory: {}".format(log_dir))
    def train_step(self, data):
        X, y = data
        training = tf.constant(True)
        cis, cit, ans, ner, pos, qpre = X
        qit, qis = y

        target_vocab_size = self.vocab.get_vocab_size("target")
        unk_index = self.vocab.get_token_id(self.vocab._unk_token, "source")
        start_index = self.vocab.get_token_id(self.vocab._start_token,
                                              "source")
        end_index = self.vocab.get_token_id(self.vocab._end_token, "source")

        y_true = prep_y_true(cis, qit, qis, target_vocab_size, unk_index,
                             start_index, end_index)
        with tf.GradientTape() as tape:
            output_dict = self(X, y, training)
            # shape: ()
            loss = self.compiled_loss(y_true, output_dict['ypred'])
            if (self._train_counter % 10 == 0
                    and (logging.get_absl_logger().getEffectiveLevel()
                         == logging.converter.STANDARD_DEBUG)):
                samples = 3
                tf.py_function(self.debug, [
                    cis[:samples], qit[:samples], ans[:samples],
                    qpre[:samples], output_dict["attentive_weights"][:samples],
                    output_dict["selective_weights"][:samples]
                ], [],
                               name="Debug")
            gradients = tape.gradient(loss, self.trainable_variables)

            self.optimizer.apply_gradients(
                zip(gradients, self.trainable_variables))

        self.compiled_metrics.update_state(y_true, output_dict['ypred'])
        self._train_counter.assign_add(1)
        return {m.name: m.result() for m in self.metrics}
import argparse
import logging

import torch
from absl import logging as absl_logging
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader

import pycuda.autoinit  # noqa: F401
import pycuda.driver as cuda
import tensorrt as trt
import transformers
from accelerate import Accelerator
from transformers import AutoTokenizer, EvalPrediction, default_data_collator, set_seed
from transformers.trainer_pt_utils import nested_concat, nested_truncate
from utils_qa import postprocess_qa_predictions

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
absl_logger = absl_logging.get_absl_logger()
absl_logger.setLevel(logging.WARNING)

logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser()

# Required parameters
parser.add_argument(
    "--onnx_model_path",
    default=None,
    type=str,
    required=True,
    help="Path to ONNX model: ",
)
def init_mllogger():
    global mllogger
    mllogger = mllog.MLLogger(logging.get_absl_logger(), jax.host_id())
def main(_):
    # get logger
    save_path = FLAGS.save_dir
    if FLAGS.save_logs:
        if not tf.gfile.Exists(os.path.join(save_path, 'train.log')):
            tf.gfile.MakeDirs(save_path)
            write_mode = 'w'
        else:
            write_mode = 'a'
        stream = tf.gfile.Open(os.path.join(save_path, 'train.log'),
                               write_mode)
        log_handler = native_logging.StreamHandler(stream)
        print('Saving logs in {}'.format(save_path))
    else:
        log_handler = native_logging.StreamHandler(sys.stdout)
    formatter = native_logging.Formatter(
        '%(asctime)s %(levelname)-8s %(message)s')
    log_handler.setFormatter(formatter)
    log_handler.setLevel(logging.INFO)
    logger = logging.get_absl_logger()
    logger.addHandler(log_handler)

    # set up tf.summary
    train_log_dir = save_path + '/train'
    valid_log_dir = save_path + '/valid'
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    valid_summary_writer = tf.summary.create_file_writer(valid_log_dir)

    # load data
    dataset_path = os.path.join(FLAGS.data_dir, FLAGS.dataset)
    dataset = DatasetClass(dataset_path, FLAGS.debug)
    sizes = dataset.get_shape()
    train_examples_reversed = dataset.get_examples('train')
    valid_examples = dataset.get_examples('valid')
    test_examples = dataset.get_examples('test')
    filters = dataset.get_filters()
    logging.info('\t Dataset shape: %s', (str(sizes)))

    # save config
    config_path = os.path.join(save_path, 'config.json')
    if FLAGS.save_logs and not tf.gfile.Exists(config_path):
        with tf.gfile.Open(config_path, 'w') as fjson:
            json.dump(train_utils.get_config_dict(CONFIG), fjson)

    # create and build model
    tf.keras.backend.set_floatx(FLAGS.dtype)
    model = getattr(models, FLAGS.model)(sizes, FLAGS)
    model.build(input_shape=(1, 3 + sum(FLAGS.node_batch_per_level)))
    trainable_params = train_utils.count_params(model)
    trainer = CFTrainer(sizes, FLAGS)
    logging.info('\t Total number of trainable parameters %s',
                 (trainable_params))

    # restore or create checkpoint
    if FLAGS.save_model:
        ckpt = tf.train.Checkpoint(step=tf.Variable(0),
                                   optimizer=trainer.optimizer,
                                   net=model)
        manager = tf.train.CheckpointManager(ckpt, save_path, max_to_keep=1)
        if manager.latest_checkpoint:
            ckpt.restore(manager.latest_checkpoint)
            logging.info('\t Restored from %s', (manager.latest_checkpoint))
        else:
            logging.info('\t Initializing from scratch.')
    else:
        logging.info('\t Initializing from scratch.')

    # train model
    logging.info('\t Start training')
    early_stopping_counter = 0
    best_mr = None
    best_epoch = None
    best_weights = None
    if FLAGS.save_model:
        epoch = ckpt.step
    else:
        epoch = 0

    if int(epoch) < FLAGS.max_epochs:
        while int(epoch) < FLAGS.max_epochs:
            if FLAGS.save_model:
                epoch.assign_add(1)
            else:
                epoch += 1

            # Train step
            start = time.perf_counter()
            train_batch = train_examples_reversed.batch(FLAGS.batch_size)
            train_loss = trainer.train_step(model, train_batch).numpy()
            end = time.perf_counter()
            execution_time = (end - start)
            logging.info('\t Epoch %i | train loss: %.4f | total time: %.4f',
                         int(epoch), train_loss, execution_time)
            with train_summary_writer.as_default():
                tf.summary.scalar('loss', train_loss, step=epoch)

            if FLAGS.save_model and int(epoch) % FLAGS.checkpoint == 0:
                ckpt_path = manager.save()
                logging.info('\t Saved checkpoint for epoch %i: %s',
                             int(epoch), ckpt_path)

            if int(epoch) % FLAGS.valid == 0:
                # compute valid loss
                valid_batch = valid_examples.batch(FLAGS.batch_size)
                valid_loss = trainer.valid_step(model, valid_batch).numpy()
                logging.info('\t Epoch %i | average valid loss: %.4f',
                             int(epoch), valid_loss)
                with valid_summary_writer.as_default():
                    tf.summary.scalar('loss', valid_loss, step=epoch)

                # compute validation metrics
                valid = train_utils.ranks_to_metrics_dict(
                    model.eval(valid_examples, filters, batch_size=500))
                logging.info(
                    train_utils.format_partial_metrics(valid, split='valid'))
                with valid_summary_writer.as_default():
                    tf.summary.scalar('mrs', valid['MR'], step=epoch)
                    tf.summary.scalar('mrrs', valid['MRR'], step=epoch)
                    tf.summary.scalar('hits@[1]',
                                      valid['hits@[1,3,10]'][1],
                                      step=epoch)
                    tf.summary.scalar('hits@[3]',
                                      valid['hits@[1,3,10]'][3],
                                      step=epoch)
                    tf.summary.scalar('hits@[10]',
                                      valid['hits@[1,3,10]'][10],
                                      step=epoch)

                # tree eval
                logging.info('\t Building tree...')
                model.build_tree()
                logging.info('\t Tree build finished.')
                stats = model.tree.stats()
                logging.info(
                    train_utils.format_tree_stats(FLAGS.nodes_per_level,
                                                  *stats))
                for k in FLAGS.top_k:
                    logging.info('\t k is: %i', int(k))
                    valid_tree = train_utils.ranks_to_metrics_dict(
                        model.top_k_tree_eval(valid_examples, filters, k=k))
                    logging.info(
                        train_utils.format_partial_metrics(valid_tree,
                                                           split='valid'))
                    with valid_summary_writer.as_default():
                        tf.summary.scalar('mrs top {}'.format(k),
                                          valid_tree['MR'],
                                          step=epoch)
                        tf.summary.scalar('mrrs top {}'.format(k),
                                          valid_tree['MRR'],
                                          step=epoch)
                        tf.summary.scalar('hits@[1] top {}'.format(k),
                                          valid_tree['hits@[1,3,10]'][1],
                                          step=epoch)
                        tf.summary.scalar('hits@[3] top {}'.format(k),
                                          valid_tree['hits@[1,3,10]'][3],
                                          step=epoch)
                        tf.summary.scalar('hits@[10] top {}'.format(k),
                                          valid_tree['hits@[1,3,10]'][10],
                                          step=epoch)

                # early stopping
                valid_mr = valid['MR']
                if not best_mr or valid_mr < best_mr:
                    best_mr = valid_mr
                    early_stopping_counter = 0
                    best_epoch = int(epoch)
                    best_weights = copy.copy(model.get_weights())
                else:
                    early_stopping_counter += 1
                    if early_stopping_counter == FLAGS.patience:
                        logging.info('\t Early stopping')
                        break

        logging.info('\t Optimization finished')
        logging.info('\t Evaluating best model from epoch %s', best_epoch)
        model.set_weights(best_weights)
        if FLAGS.save_model:
            model.save_weights(os.path.join(save_path, 'best_model.ckpt'))

        # validation metrics
        valid = train_utils.ranks_to_metrics_dict(
            model.eval(valid_examples, filters, batch_size=500))
        logging.info(train_utils.format_partial_metrics(valid, split='valid'))

        # test metrics
        test = train_utils.ranks_to_metrics_dict(
            model.eval(test_examples, filters, batch_size=500))
        logging.info(train_utils.format_partial_metrics(test, split='test'))

        # tree eval
        logging.info('\t Building tree...')
        model.build_tree()
        logging.info('\t Tree build finished.')
        stats = model.tree.stats()
        logging.info(
            train_utils.format_tree_stats(FLAGS.nodes_per_level, *stats))

        for k in FLAGS.top_k:
            logging.info('\t k is: %i', int(k))

            # valid tree eval
            valid_tree = train_utils.ranks_to_metrics_dict(
                model.top_k_tree_eval(valid_examples, filters, k=k))
            logging.info(
                train_utils.format_partial_metrics(valid_tree, split='valid'))
            # test tree eval
            test_tree = train_utils.ranks_to_metrics_dict(
                model.top_k_tree_eval(test_examples, filters, k=k))
            logging.info(
                train_utils.format_partial_metrics(test_tree, split='test'))

    else:
        logging.info('\t Training completed')
Example #21
    def test_get_absl_logger(self):
        self.assertIsInstance(
            logging.get_absl_logger(), logging.ABSLLogger)
Example #22
# TODO: look into Python 3.10 installation issue with conda
SUPPORTED_PYTHON_VERSION: t.List[str] = ["3.7", "3.8", "3.9"]

NVIDIA_REPO_URL: str = (
    "https://developer.download.nvidia.com/compute/cuda/repos/{}/x86_64")
NVIDIA_ML_REPO_URL: str = (
    "https://developer.download.nvidia.com/compute/machine-learning/repos/{}/x86_64"
)

HTTP_RETRY_ATTEMPTS: int = 2
HTTP_RETRY_WAIT_SECS: int = 20

# set up some defaults
logging.get_absl_handler().setFormatter(ColoredFormatter())
log = logging.get_absl_logger()

if not os.path.exists("./.env"):
    log.warning(f"Make sure to create .env file at {os.getcwd()}")
else:
    load_dotenv()

if os.geteuid() == 0:
    # We only use docker_client when running as root.
    log.debug("Creating an instance of DockerClient")
    docker_client = DockerClient(base_url="unix://var/run/docker.sock",
                                 timeout=3600,
                                 tls=True)


class LogsMixin(object):
Example #23
def _getline_inner():
    return logging.get_absl_logger().findCaller()[1]
Example #24
def init_mllogger():
    global mllogger
    mllogger = mllog.MLLogger(logging.get_absl_logger(),
                              jax.host_id(),
                              full=FLAGS.mlperf_logs)
Example #25
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import os
import matplotlib.pyplot as plt
from tensorflow.keras.utils import to_categorical
from absl import flags, logging
from tqdm.auto import tqdm
from PIL import Image
from tf_rlib import datasets

FLAGS = flags.FLAGS
LOGGER = logging.get_absl_logger()


class Cifar10Semi(datasets.Dataset):
    """ use tfrecords could speed-up 5%~15% comparing to numpy. because tfrecords format only use 50% space than numpy.
    """
    def __init__(self, labels_persentage):
        super(Cifar10Semi, self).__init__()
        self.labels_persentage = labels_persentage
        self.dtype = np.float16 if FLAGS.amp else np.float32
        save_path = '/ws_data/tmp/cifar10_labels{}'.format(labels_persentage)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        if FLAGS.amp:
            self.train_file = os.path.join(save_path, 'train16.tfrecords')
            self.valid_file = os.path.join(save_path, 'valid16.tfrecords')
        else:
            self.train_file = os.path.join(save_path, 'train32.tfrecords')
            self.valid_file = os.path.join(save_path, 'valid32.tfrecords')
Example #26
def main(_):
    # get logger
    if FLAGS.save_logs:
        if not os.path.exists(os.path.join(FLAGS.save_dir, 'train.log')):
            os.makedirs(FLAGS.save_dir, exist_ok=True)
            write_mode = 'w'
        else:
            write_mode = 'a'
        stream = open(os.path.join(FLAGS.save_dir, 'train.log'), write_mode)
        log_handler = native_logging.StreamHandler(stream)
        print('Saving logs in {}'.format(FLAGS.save_dir))
    else:
        log_handler = native_logging.StreamHandler(sys.stdout)
    formatter = native_logging.Formatter(
        '%(asctime)s %(levelname)-8s %(message)s')
    log_handler.setFormatter(formatter)
    log_handler.setLevel(logging.INFO)
    logger = logging.get_absl_logger()
    logger.addHandler(log_handler)

    # load data
    dataset_path = os.path.join(FLAGS.data_dir, FLAGS.dataset)
    dataset = DatasetFn(dataset_path, FLAGS.debug)
    sizes = dataset.get_shape()
    train_examples_reversed = dataset.get_examples('train')
    valid_examples = dataset.get_examples('valid')
    test_examples = dataset.get_examples('test')
    filters = dataset.get_filters()
    logging.info('\t Dataset shape: %s', (str(sizes)))

    # save config
    config_path = os.path.join(FLAGS.save_dir, 'config.json')
    if FLAGS.save_logs:
        with open(config_path, 'w') as fjson:
            json.dump(train_utils.get_config_dict(), fjson)

    # create and build model
    tf.keras.backend.set_floatx(FLAGS.dtype)
    model = getattr(models, FLAGS.model)(sizes, FLAGS)
    model.build(input_shape=(1, 3))
    trainable_params = train_utils.count_params(model)
    trainer = KGTrainer(sizes, FLAGS)
    logging.info('\t Total number of trainable parameters %s',
                 (trainable_params))

    # restore or create checkpoint
    if FLAGS.save_model:
        ckpt = tf.train.Checkpoint(step=tf.Variable(0),
                                   optimizer=trainer.optimizer,
                                   net=model)
        manager = tf.train.CheckpointManager(ckpt,
                                             FLAGS.save_dir,
                                             max_to_keep=1)
        if manager.latest_checkpoint:
            ckpt.restore(manager.latest_checkpoint)
            logging.info('\t Restored from %s', (manager.latest_checkpoint))
        else:
            logging.info('\t Initializing from scratch.')
    else:
        logging.info('\t Initializing from scratch.')

    # train model
    logging.info('\t Start training')
    early_stopping_counter = 0
    best_mrr = None
    best_epoch = None
    best_weights = None
    if FLAGS.save_model:
        epoch = ckpt.step
    else:
        epoch = 0

    if int(epoch) < FLAGS.max_epochs:
        while int(epoch) < FLAGS.max_epochs:
            if FLAGS.save_model:
                epoch.assign_add(1)
            else:
                epoch += 1

            # Train step
            start = time.perf_counter()
            train_batch = train_examples_reversed.batch(FLAGS.batch_size)
            train_loss = trainer.train_step(model, train_batch).numpy()
            end = time.perf_counter()
            execution_time = (end - start)
            logging.info('\t Epoch %i | train loss: %.4f | total time: %.4f',
                         int(epoch), train_loss, execution_time)

            if FLAGS.save_model and int(epoch) % FLAGS.checkpoint == 0:
                save_path = manager.save()
                logging.info('\t Saved checkpoint for epoch %i: %s',
                             int(epoch), save_path)

            if int(epoch) % FLAGS.valid == 0:
                # compute valid loss
                valid_batch = valid_examples.batch(FLAGS.batch_size)
                valid_loss = trainer.valid_step(model, valid_batch).numpy()
                logging.info('\t Epoch %i | average valid loss: %.4f',
                             int(epoch), valid_loss)

                # compute validation metrics
                valid = train_utils.avg_both(
                    *model.eval(valid_examples, filters))
                logging.info(train_utils.format_metrics(valid, split='valid'))

                # early stopping
                valid_mrr = valid['MRR']
                if not best_mrr or valid_mrr > best_mrr:
                    best_mrr = valid_mrr
                    early_stopping_counter = 0
                    best_epoch = int(epoch)
                    best_weights = copy.copy(model.get_weights())
                else:
                    early_stopping_counter += 1
                    if early_stopping_counter == FLAGS.patience:
                        logging.info('\t Early stopping')
                        break

        logging.info('\t Optimization finished')
        logging.info('\t Evaluating best model from epoch %s', best_epoch)
        model.set_weights(best_weights)
        if FLAGS.save_model:
            model.save_weights(os.path.join(FLAGS.save_dir, 'best_model.ckpt'))

        # validation metrics
        valid = train_utils.avg_both(*model.eval(valid_examples, filters))
        logging.info(train_utils.format_metrics(valid, split='valid'))

        # test metrics
        test = train_utils.avg_both(*model.eval(test_examples, filters))
        logging.info(train_utils.format_metrics(test, split='test'))
    else:
        logging.info('\t Training completed')