def receive_test_event(data: dict, context: dict) -> bool:
  """Entrypoint for Cloud Function.

  Args:
    data: dict containing base64-encoded proto message.
    context: dict containing event metadata.

  Returns:
    True if message should be ack-ed, else False.
  """
  logging.set_verbosity(logging.INFO)

  dataset = DATASET
  project = PROJECT or google.auth.default()[1]

  try:
    message_bytes = base64.b64decode(data['data'])
    event = metrics_pb2.TestCompletedEvent()
    event.ParseFromString(message_bytes)
  except Exception as e:
    logging.fatal(
        'Failed to parse PubSub message. Will ack message to prevent '
        'more crashes.', exc_info=e)
    return True

  alert_handler = alerts.AlertHandler(
      project, event.benchmark_id, event.debug_info, level='ERROR')
  logging.get_absl_logger().addHandler(alert_handler)

  metric_store = bigquery_client.BigQueryMetricStore(
      project=project,
      dataset=dataset,
  )

  try:
    logging.info('Processing test event: %s', str(event))
    job_row, metric_rows = process_proto_message(
        event, metric_store, context.event_id)
    metric_store.insert_status_and_metrics(job_row, metric_rows)
  except Exception as e:
    logging.fatal(
        'Encountered exception while attempting to process message.',
        exc_info=e)

  if alert_handler.has_errors:
    logging.info('Alerts: %s', str(alert_handler._records))
    if SEND_EMAIL_ALERTS:
      _send_email(project, *alert_handler.generate_email_content)
    else:
      logging.info('E-mail alerts disabled.')
  else:
    logging.info('No alerts found.')

  return True
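A minimal sketch of how the entrypoint above could be exercised locally. It assumes only what the function body already uses (a `metrics_pb2.TestCompletedEvent` with a `benchmark_id` field and a context object exposing `event_id`); the payload values are illustrative, and the call still reaches BigQuery unless the metric store is stubbed out.

# Hypothetical local smoke test for receive_test_event.
import base64
from types import SimpleNamespace

event = metrics_pb2.TestCompletedEvent(benchmark_id='my-benchmark')
# Pub/Sub delivers the serialized proto as base64 under the 'data' key.
payload = {'data': base64.b64encode(event.SerializeToString())}
context = SimpleNamespace(event_id='1234567890')  # stand-in for the event metadata
assert receive_test_event(payload, context)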
def main(_):
  redirect_logs = flags.FLAGS.redirect_logs
  cache_paths = rconst.Paths(
      data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)
  log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
  log_file = os.path.join(cache_paths.data_dir, log_file_name)
  if log_file.startswith("gs://") and redirect_logs:
    fallback_log_file = os.path.join(tempfile.gettempdir(), log_file_name)
    print("Unable to log to {}. Falling back to {}"
          .format(log_file, fallback_log_file))
    log_file = fallback_log_file

  # This server is generally run in a subprocess.
  if redirect_logs:
    print("Redirecting stdout and stderr to {}".format(log_file))
    log_stream = open(log_file, "wt")  # Note: not tf.gfile.Open().
    stdout = log_stream
    stderr = log_stream
  try:
    if redirect_logs:
      absl_logging.get_absl_logger().addHandler(
          hdlr=logging.StreamHandler(stream=stdout))
      sys.stdout = stdout
      sys.stderr = stderr
      print("Logs redirected.")
    try:
      log_msg("sys.argv: {}".format(" ".join(sys.argv)))

      if flags.FLAGS.seed is not None:
        np.random.seed(flags.FLAGS.seed)

      _generation_loop(
          num_workers=flags.FLAGS.num_workers,
          cache_paths=cache_paths,
          num_readers=flags.FLAGS.num_readers,
          num_neg=flags.FLAGS.num_neg,
          num_train_positives=flags.FLAGS.num_train_positives,
          num_items=flags.FLAGS.num_items,
          spillover=flags.FLAGS.spillover,
          epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
          train_batch_size=flags.FLAGS.train_batch_size,
          eval_batch_size=flags.FLAGS.eval_batch_size,
      )
    except KeyboardInterrupt:
      log_msg("KeyboardInterrupt registered.")
    except:
      traceback.print_exc()
      raise
  finally:
    log_msg("Shutting down generation subprocess.")
    sys.stdout.flush()
    sys.stderr.flush()
    if redirect_logs:
      log_stream.close()
def __init__(self, model, flags, metrics=metrics.Metrics()):
  self.model_wrapper = ModelWrapper(model)
  self.model_checkpoint_path = flags.save_dir
  self.lr = flags.learning_rate
  self.hidden_dim = flags.hidden_dim
  self.write_summary = flags.write_summary
  self.batch_size = flags.batch_size
  self.metrics = metrics
  logging.get_absl_logger().addHandler(logging_base.StreamHandler())
def __init__(self, model, flags):
  self.model_wrapper = ModelWrapper(model)
  self.model_checkpoint_path = os.path.join(flags.save_dir, flags.data_set)
  self.lr = flags.learning_rate
  self.write_summary = flags.write_summary
  self.rnn_dim = flags.rnn_dim
  self.dropout = flags.dropout_rate
  self.batch_size = flags.batch_size
  self.metric_class = HawkesMetrics()
  logging.get_absl_logger().addHandler(logging_base.StreamHandler())
def __init__(self, model, flags):
  self.model_wrapper = ModelWrapper(model)
  self.model_checkpoint_path = flags.save_dir
  self.lr = flags.learning_rate
  self.write_summary = flags.write_summary
  self.rnn_dim = flags.rnn_dim
  self.dropout = flags.dropout
  self.batch_size = flags.batch_size
  logging.get_absl_logger().addHandler(logging_base.StreamHandler())
def setup_logger(print_logs: bool, save_logs: bool, save_path: Path, run_id: str):
    native_logging.root.removeHandler(logging._absl_handler)
    logging._warn_preinit_stderr = False
    formatter = native_logging.Formatter(
        fmt='%(asctime)s %(message)s', datefmt='%Y-%d-%m %H:%M:%S')
    handlers = []
    if save_logs:
        write_mode = 'a' if save_path.exists() else 'w'
        save_path.mkdir(parents=True, exist_ok=True)
        log_file = save_path / f"{run_id}.log"
        stream = tf.io.gfile.GFile(str(log_file), write_mode)
        log_handler = native_logging.StreamHandler(stream)
        print('Saving logs in {}'.format(save_path))
        handlers.append(log_handler)
    if print_logs or not save_logs:
        log_handler = native_logging.StreamHandler(sys.stdout)
        handlers.append(log_handler)
    logger = logging.get_absl_logger()
    logger.propagate = False
    for log_handler in handlers:
        log_handler.setFormatter(formatter)
        log_handler.setLevel(logging.INFO)
        logger.addHandler(log_handler)
    return logger
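A hypothetical invocation of the setup_logger function above; the directory and run id are illustrative only, and the log file is written through tf.io.gfile as in the function body.

from pathlib import Path

logger = setup_logger(
    print_logs=True,
    save_logs=True,
    save_path=Path('/tmp/experiment_logs'),  # illustrative path
    run_id='run_001',
)
logger.info('logger configured')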
def test_logger_and_handler(self):
  absl_logger = std_logging.getLogger('absl')
  self.assertIs(absl_logger, logging.get_absl_logger())
  self.assertTrue(isinstance(absl_logger, logging.ABSLLogger))
  self.assertTrue(
      isinstance(logging.get_absl_handler().python_handler.formatter,
                 logging.PythonFormatter))
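A small sketch building on the identity the test above checks: because get_absl_logger() returns the shared 'absl' logger, a handler attached through the standard logging API also receives messages emitted via absl.logging. The buffer and message are illustrative.

import io
import logging as std_logging

from absl import logging

buf = io.StringIO()
std_logging.getLogger('absl').addHandler(std_logging.StreamHandler(buf))
logging.set_verbosity(logging.INFO)
logging.info('hello from absl')  # also written to buf
print(buf.getvalue())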
def setUp(self):
  self._logger = logging.get_absl_logger()
  self._handler = alerts.AlertHandler(
      project_id='my-project-id',
      benchmark_id='benchmark-id',
      debug_info=None,
  )
  self._logger.addHandler(self._handler)
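A possible follow-on test for the fixture above; it is a sketch that assumes AlertHandler buffers the records it receives and exposes a has_errors property, as the Cloud Function snippet earlier in this section suggests.

def test_error_record_is_captured(self):
  self._logger.error('something went wrong')
  self.assertTrue(self._handler.has_errors)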
def _test_register_frame_to_skip():
  """Test skipping frames for line number reporting."""

  def _getline():
    def _getline_inner():
      return logging.get_absl_logger().findCaller()[1]
    return _getline_inner()

  # Check register_frame_to_skip function to see if log frame skipping works.
  line1 = _getline()
  line2 = _getline()
  logging.get_absl_logger().register_frame_to_skip(__file__, '_getline')
  line3 = _getline()
  # Both should be line number of the _getline_inner() call.
  assert (line1 == line2), (line1, line2)
  # line3 should be a line number in this function.
  assert (line2 != line3), (line2, line3)
def test_find_log_dir_with_nothing(self):
  with mock.patch.object(os.path, 'exists'), \
       mock.patch.object(os.path, 'isdir'), \
       mock.patch.object(logging.get_absl_logger(), 'fatal') as mock_fatal:
    os.path.exists.return_value = False
    os.path.isdir.return_value = False
    log_dir = logging.find_log_dir()
    mock_fatal.assert_called()
    self.assertEqual(None, log_dir)
def test_call_only_default_args(self):
  expected_arg_two = 3

  @phase_descriptor.PhaseOptions()
  def phase(arg_one=1, arg_two=2):
    self.assertEqual(arg_one, 1)
    # We are changing the arg with the with_args statement when called.
    self.assertEqual(arg_two, expected_arg_two)

  self._test_state.running_phase_state = (
      test_state.PhaseState.from_descriptor(phase, self._test_state,
                                            logging.get_absl_logger()))
  phase.with_args(arg_two=expected_arg_two)(self._test_state)
def __init__(self, tag):
    """Create a context object for recording time.

    Args:
        tag (str): the summary tag for the time.
    """
    self._tag = tag
    caller = logging.get_absl_logger().findCaller()
    # token is a string of filename:lineno:tag
    token = caller[0] + ':' + str(caller[1]) + ':' + tag
    if token not in _contexts:
        _contexts[token] = {'time': 0., 'n': 0}
    self._counter = _contexts[token]
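A sketch of the context-manager half that would pair with the __init__ above; only __init__ is shown, so the method names and exact bookkeeping here are assumptions, grounded only in the {'time': ..., 'n': ...} counter structure.

def __enter__(self):
    # Assumes `import time` at module scope.
    self._start = time.time()
    return self

def __exit__(self, exc_type, exc_value, tb):
    # Accumulate elapsed time and call count under the filename:lineno:tag token.
    self._counter['time'] += time.time() - self._start
    self._counter['n'] += 1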
def warning_once(msg, *args):
    """Generate warning message once.

    Note that the current implementation resembles that of the ``log_every_n()``
    function in ``logging`` but reduces the calling stack by one to ensure that
    multiple warning-once messages generated at different places can be
    displayed correctly.

    Args:
        msg: str, the message to be logged.
        *args: The args to be substituted into the msg.
    """
    caller = logging.get_absl_logger().findCaller()
    count = logging._get_next_log_count_per_token(caller)
    logging.log_if(logging.WARNING, msg, not (count % (1 << 62)), *args)
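An illustrative call to warning_once above; the flag names are made up, and the message uses the usual %-style substitution.

for step in range(1000):
    warning_once('flag %s is deprecated, use %s instead', '--old_flag', '--new_flag')
# The warning is emitted only on the first iteration; the per-token counter
# suppresses the repeats.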
def test_call_only_default_args_and_plug(self):
  expected_arg_one = 5
  self._test_state.plug_manager.initialize_plugs([ExtraPlug])

  @plugs.plug(custom_plug=ExtraPlug)
  def phase(custom_plug, arg_one=1, arg_two=2):
    self.assertIsInstance(custom_plug, ExtraPlug)
    # We are changing the arg with the with_args statement when called.
    self.assertEqual(arg_one, expected_arg_one)
    self.assertEqual(arg_two, 2)

  self._test_state.running_phase_state = (
      test_state.PhaseState.from_descriptor(phase, self._test_state,
                                            logging.get_absl_logger()))
  phase.with_args(arg_one=expected_arg_one)(self._test_state)
def set_logging(is_debug, config):
  absl_logger = logging.get_absl_logger()
  # create formatter and add it to the handlers
  formatter = _logging.Formatter(
      "[ %(asctime)-15s %(levelname)s %(filename)15s:%(lineno)-4d "
      " %(process)-5d ] %(message)s")

  log_dir = config["solver"]["saver"]["model_path"]
  if not os.path.exists(log_dir):
    os.makedirs(log_dir)

  logging.get_absl_handler().use_absl_log_file(
      program_name='delta', log_dir=log_dir)

  fh = _logging.StreamHandler()
  fh.setLevel(_logging.NOTSET)
  fh.setFormatter(formatter)
  absl_logger.addHandler(fh)

  if is_debug:
    logging.set_verbosity(_logging.DEBUG)
  else:
    logging.set_verbosity(_logging.NOTSET)

  logging.info("Also save log file to directory: {}".format(log_dir))
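A hypothetical invocation of set_logging above; only the nested config key the function reads is populated, and the path is illustrative.

config = {"solver": {"saver": {"model_path": "/tmp/delta_logs"}}}
set_logging(is_debug=True, config=config)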
def train_step(self, data):
    X, y = data
    training = tf.constant(True)
    cis, cit, ans, ner, pos, qpre = X
    qit, qis = y
    target_vocab_size = self.vocab.get_vocab_size("target")
    unk_index = self.vocab.get_token_id(self.vocab._unk_token, "source")
    start_index = self.vocab.get_token_id(self.vocab._start_token, "source")
    end_index = self.vocab.get_token_id(self.vocab._end_token, "source")
    y_true = prep_y_true(cis, qit, qis, target_vocab_size, unk_index,
                         start_index, end_index)

    with tf.GradientTape() as tape:
        output_dict = self(X, y, training)
        # shape: ()
        loss = self.compiled_loss(y_true, output_dict['ypred'])

    if (self._train_counter % 10 == 0 and
            (logging.get_absl_logger().getEffectiveLevel() ==
             logging.converter.STANDARD_DEBUG)):
        samples = 3
        tf.py_function(self.debug, [
            cis[:samples], qit[:samples], ans[:samples], qpre[:samples],
            output_dict["attentive_weights"][:samples],
            output_dict["selective_weights"][:samples]
        ], [], name="Debug")

    gradients = tape.gradient(loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    self.compiled_metrics.update_state(y_true, output_dict['ypred'])
    self._train_counter.assign_add(1)
    return {m.name: m.result() for m in self.metrics}
import argparse
import logging

import torch
from absl import logging as absl_logging
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader

import pycuda.autoinit  # noqa: F401
import pycuda.driver as cuda
import tensorrt as trt
import transformers
from accelerate import Accelerator
from transformers import AutoTokenizer, EvalPrediction, default_data_collator, set_seed
from transformers.trainer_pt_utils import nested_concat, nested_truncate
from utils_qa import postprocess_qa_predictions

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

absl_logger = absl_logging.get_absl_logger()
absl_logger.setLevel(logging.WARNING)

logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser()

# Required parameters
parser.add_argument(
    "--onnx_model_path",
    default=None,
    type=str,
    required=True,
    help="Path to ONNX model: ",
)
def init_mllogger():
  global mllogger
  mllogger = mllog.MLLogger(logging.get_absl_logger(), jax.host_id())
def main(_):
  # get logger
  save_path = FLAGS.save_dir
  if FLAGS.save_logs:
    if not tf.gfile.Exists(os.path.join(save_path, 'train.log')):
      tf.gfile.MakeDirs(save_path)
      write_mode = 'w'
    else:
      write_mode = 'a'
    stream = tf.gfile.Open(os.path.join(save_path, 'train.log'), write_mode)
    log_handler = native_logging.StreamHandler(stream)
    print('Saving logs in {}'.format(save_path))
  else:
    log_handler = native_logging.StreamHandler(sys.stdout)
  formatter = native_logging.Formatter(
      '%(asctime)s %(levelname)-8s %(message)s')
  log_handler.setFormatter(formatter)
  log_handler.setLevel(logging.INFO)
  logger = logging.get_absl_logger()
  logger.addHandler(log_handler)

  # set up tf.summary
  train_log_dir = save_path + '/train'
  valid_log_dir = save_path + '/valid'
  train_summary_writer = tf.summary.create_file_writer(train_log_dir)
  valid_summary_writer = tf.summary.create_file_writer(valid_log_dir)

  # load data
  dataset_path = os.path.join(FLAGS.data_dir, FLAGS.dataset)
  dataset = DatasetClass(dataset_path, FLAGS.debug)
  sizes = dataset.get_shape()
  train_examples_reversed = dataset.get_examples('train')
  valid_examples = dataset.get_examples('valid')
  test_examples = dataset.get_examples('test')
  filters = dataset.get_filters()
  logging.info('\t Dataset shape: %s', (str(sizes)))

  # save config
  config_path = os.path.join(save_path, 'config.json')
  if FLAGS.save_logs and not tf.gfile.Exists(config_path):
    with tf.gfile.Open(config_path, 'w') as fjson:
      json.dump(train_utils.get_config_dict(CONFIG), fjson)

  # create and build model
  tf.keras.backend.set_floatx(FLAGS.dtype)
  model = getattr(models, FLAGS.model)(sizes, FLAGS)
  model.build(input_shape=(1, 3 + sum(FLAGS.node_batch_per_level)))
  trainable_params = train_utils.count_params(model)
  trainer = CFTrainer(sizes, FLAGS)
  logging.info('\t Total number of trainable parameters %s',
               (trainable_params))

  # restore or create checkpoint
  if FLAGS.save_model:
    ckpt = tf.train.Checkpoint(
        step=tf.Variable(0), optimizer=trainer.optimizer, net=model)
    manager = tf.train.CheckpointManager(ckpt, save_path, max_to_keep=1)
    if manager.latest_checkpoint:
      ckpt.restore(manager.latest_checkpoint)
      logging.info('\t Restored from %s', (manager.latest_checkpoint))
    else:
      logging.info('\t Initializing from scratch.')
  else:
    logging.info('\t Initializing from scratch.')

  # train model
  logging.info('\t Start training')
  early_stopping_counter = 0
  best_mr = None
  best_epoch = None
  best_weights = None
  if FLAGS.save_model:
    epoch = ckpt.step
  else:
    epoch = 0
  if int(epoch) < FLAGS.max_epochs:
    while int(epoch) < FLAGS.max_epochs:
      if FLAGS.save_model:
        epoch.assign_add(1)
      else:
        epoch += 1

      # Train step
      start = time.perf_counter()
      train_batch = train_examples_reversed.batch(FLAGS.batch_size)
      train_loss = trainer.train_step(model, train_batch).numpy()
      end = time.perf_counter()
      execution_time = (end - start)
      logging.info('\t Epoch %i | train loss: %.4f | total time: %.4f',
                   int(epoch), train_loss, execution_time)
      with train_summary_writer.as_default():
        tf.summary.scalar('loss', train_loss, step=epoch)

      if FLAGS.save_model and int(epoch) % FLAGS.checkpoint == 0:
        save_path = manager.save()
        logging.info('\t Saved checkpoint for epoch %i: %s', int(epoch),
                     save_path)

      if int(epoch) % FLAGS.valid == 0:
        # compute valid loss
        valid_batch = valid_examples.batch(FLAGS.batch_size)
        valid_loss = trainer.valid_step(model, valid_batch).numpy()
        logging.info('\t Epoch %i | average valid loss: %.4f', int(epoch),
                     valid_loss)
        with valid_summary_writer.as_default():
          tf.summary.scalar('loss', valid_loss, step=epoch)

        # compute validation metrics
        valid = train_utils.ranks_to_metrics_dict(
            model.eval(valid_examples, filters, batch_size=500))
        logging.info(train_utils.format_partial_metrics(valid, split='valid'))
        with valid_summary_writer.as_default():
          tf.summary.scalar('mrs', valid['MR'], step=epoch)
          tf.summary.scalar('mrrs', valid['MRR'], step=epoch)
          tf.summary.scalar('hits@[1]', valid['hits@[1,3,10]'][1], step=epoch)
          tf.summary.scalar('hits@[3]', valid['hits@[1,3,10]'][3], step=epoch)
          tf.summary.scalar(
              'hits@[10]', valid['hits@[1,3,10]'][10], step=epoch)

        # tree eval
        logging.info('\t Building tree...')
        model.build_tree()
        logging.info('\t Tree build finished.')
        stats = model.tree.stats()
        logging.info(
            train_utils.format_tree_stats(FLAGS.nodes_per_level, *stats))
        for k in FLAGS.top_k:
          logging.info('\t k is: %i', int(k))
          valid_tree = train_utils.ranks_to_metrics_dict(
              model.top_k_tree_eval(valid_examples, filters, k=k))
          logging.info(
              train_utils.format_partial_metrics(valid_tree, split='valid'))
          with valid_summary_writer.as_default():
            tf.summary.scalar(
                'mrs top {}'.format(k), valid_tree['MR'], step=epoch)
            tf.summary.scalar(
                'mrrs top {}'.format(k), valid_tree['MRR'], step=epoch)
            tf.summary.scalar(
                'hits@[1] top {}'.format(k),
                valid_tree['hits@[1,3,10]'][1],
                step=epoch)
            tf.summary.scalar(
                'hits@[3] top {}'.format(k),
                valid_tree['hits@[1,3,10]'][3],
                step=epoch)
            tf.summary.scalar(
                'hits@[10] top {}'.format(k),
                valid_tree['hits@[1,3,10]'][10],
                step=epoch)

        # early stopping
        valid_mr = valid['MR']
        if not best_mr or valid_mr < best_mr:
          best_mr = valid_mr
          early_stopping_counter = 0
          best_epoch = int(epoch)
          best_weights = copy.copy(model.get_weights())
        else:
          early_stopping_counter += 1
          if early_stopping_counter == FLAGS.patience:
            logging.info('\t Early stopping')
            break

    logging.info('\t Optimization finished')
    logging.info('\t Evaluating best model from epoch %s', best_epoch)
    model.set_weights(best_weights)
    if FLAGS.save_model:
      model.save_weights(os.path.join(save_path, 'best_model.ckpt'))

    # validation metrics
    valid = train_utils.ranks_to_metrics_dict(
        model.eval(valid_examples, filters, batch_size=500))
    logging.info(train_utils.format_partial_metrics(valid, split='valid'))

    # test metrics
    test = train_utils.ranks_to_metrics_dict(
        model.eval(test_examples, filters, batch_size=500))
    logging.info(train_utils.format_partial_metrics(test, split='test'))

    # tree eval
    logging.info('\t Building tree...')
    model.build_tree()
    logging.info('\t Tree build finished.')
    stats = model.tree.stats()
    logging.info(
        train_utils.format_tree_stats(FLAGS.nodes_per_level, *stats))
    for k in FLAGS.top_k:
      logging.info('\t k is: %i', int(k))
      # valid tree eval
      valid_tree = train_utils.ranks_to_metrics_dict(
          model.top_k_tree_eval(valid_examples, filters, k=k))
      logging.info(
          train_utils.format_partial_metrics(valid_tree, split='valid'))
      # test tree eval
      test_tree = train_utils.ranks_to_metrics_dict(
          model.top_k_tree_eval(test_examples, filters, k=k))
      logging.info(
          train_utils.format_partial_metrics(test_tree, split='tree'))
  else:
    logging.info('\t Training completed')
def test_get_absl_logger(self):
  self.assertIsInstance(logging.get_absl_logger(), logging.ABSLLogger)
# TODO: look into Python 3.10 installation issue with conda
SUPPORTED_PYTHON_VERSION: t.List[str] = ["3.7", "3.8", "3.9"]

NVIDIA_REPO_URL: str = (
    "https://developer.download.nvidia.com/compute/cuda/repos/{}/x86_64")
NVIDIA_ML_REPO_URL: str = (
    "https://developer.download.nvidia.com/compute/machine-learning/repos/{}/x86_64"
)

HTTP_RETRY_ATTEMPTS: int = 2
HTTP_RETRY_WAIT_SECS: int = 20

# set up some defaults
logging.get_absl_handler().setFormatter(ColoredFormatter())
log = logging.get_absl_logger()

if not os.path.exists("./.env"):
    log.warning(f"Make sure to create .env file at {os.getcwd()}")
else:
    load_dotenv()

if os.geteuid() == 0:
    # We only use docker_client when running as root.
    log.debug("Creating an instance of DockerClient")
    docker_client = DockerClient(
        base_url="unix://var/run/docker.sock", timeout=3600, tls=True)


class LogsMixin(object):
def _getline_inner():
  return logging.get_absl_logger().findCaller()[1]
def init_mllogger():
  global mllogger
  mllogger = mllog.MLLogger(
      logging.get_absl_logger(), jax.host_id(), full=FLAGS.mlperf_logs)
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import os
import matplotlib.pyplot as plt
from tensorflow.keras.utils import to_categorical
from absl import flags, logging
from tqdm.auto import tqdm
from PIL import Image

from tf_rlib import datasets

FLAGS = flags.FLAGS
LOGGER = logging.get_absl_logger()


class Cifar10Semi(datasets.Dataset):
    """Semi-supervised CIFAR-10 dataset.

    Using tfrecords can speed up loading by roughly 5%~15% compared to numpy,
    because the tfrecords format uses only about 50% of the space of numpy.
    """
    def __init__(self, labels_persentage):
        super(Cifar10Semi, self).__init__()
        self.labels_persentage = labels_persentage
        self.dtype = np.float16 if FLAGS.amp else np.float32
        save_path = '/ws_data/tmp/cifar10_labels{}'.format(labels_persentage)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        if FLAGS.amp:
            self.train_file = os.path.join(save_path, 'train16.tfrecords')
            self.valid_file = os.path.join(save_path, 'valid16.tfrecords')
        else:
            self.train_file = os.path.join(save_path, 'train32.tfrecords')
            self.valid_file = os.path.join(save_path, 'valid32.tfrecords')
def main(_):
  # get logger
  if FLAGS.save_logs:
    if not os.path.exists(os.path.join(FLAGS.save_dir, 'train.log')):
      os.makedirs(FLAGS.save_dir)
      write_mode = 'w'
    else:
      write_mode = 'a'
    stream = open(os.path.join(FLAGS.save_dir, 'train.log'), write_mode)
    log_handler = native_logging.StreamHandler(stream)
    print('Saving logs in {}'.format(FLAGS.save_dir))
  else:
    log_handler = native_logging.StreamHandler(sys.stdout)
  formatter = native_logging.Formatter(
      '%(asctime)s %(levelname)-8s %(message)s')
  log_handler.setFormatter(formatter)
  log_handler.setLevel(logging.INFO)
  logger = logging.get_absl_logger()
  logger.addHandler(log_handler)

  # load data
  dataset_path = os.path.join(FLAGS.data_dir, FLAGS.dataset)
  dataset = DatasetFn(dataset_path, FLAGS.debug)
  sizes = dataset.get_shape()
  train_examples_reversed = dataset.get_examples('train')
  valid_examples = dataset.get_examples('valid')
  test_examples = dataset.get_examples('test')
  filters = dataset.get_filters()
  logging.info('\t Dataset shape: %s', (str(sizes)))

  # save config
  config_path = os.path.join(FLAGS.save_dir, 'config.json')
  if FLAGS.save_logs:
    with open(config_path, 'w') as fjson:
      json.dump(train_utils.get_config_dict(), fjson)

  # create and build model
  tf.keras.backend.set_floatx(FLAGS.dtype)
  model = getattr(models, FLAGS.model)(sizes, FLAGS)
  model.build(input_shape=(1, 3))
  trainable_params = train_utils.count_params(model)
  trainer = KGTrainer(sizes, FLAGS)
  logging.info('\t Total number of trainable parameters %s',
               (trainable_params))

  # restore or create checkpoint
  if FLAGS.save_model:
    ckpt = tf.train.Checkpoint(
        step=tf.Variable(0), optimizer=trainer.optimizer, net=model)
    manager = tf.train.CheckpointManager(ckpt, FLAGS.save_dir, max_to_keep=1)
    if manager.latest_checkpoint:
      ckpt.restore(manager.latest_checkpoint)
      logging.info('\t Restored from %s', (manager.latest_checkpoint))
    else:
      logging.info('\t Initializing from scratch.')
  else:
    logging.info('\t Initializing from scratch.')

  # train model
  logging.info('\t Start training')
  early_stopping_counter = 0
  best_mrr = None
  best_epoch = None
  best_weights = None
  if FLAGS.save_model:
    epoch = ckpt.step
  else:
    epoch = 0
  if int(epoch) < FLAGS.max_epochs:
    while int(epoch) < FLAGS.max_epochs:
      if FLAGS.save_model:
        epoch.assign_add(1)
      else:
        epoch += 1

      # Train step
      start = time.perf_counter()
      train_batch = train_examples_reversed.batch(FLAGS.batch_size)
      train_loss = trainer.train_step(model, train_batch).numpy()
      end = time.perf_counter()
      execution_time = (end - start)
      logging.info('\t Epoch %i | train loss: %.4f | total time: %.4f',
                   int(epoch), train_loss, execution_time)

      if FLAGS.save_model and int(epoch) % FLAGS.checkpoint == 0:
        save_path = manager.save()
        logging.info('\t Saved checkpoint for epoch %i: %s', int(epoch),
                     save_path)

      if int(epoch) % FLAGS.valid == 0:
        # compute valid loss
        valid_batch = valid_examples.batch(FLAGS.batch_size)
        valid_loss = trainer.valid_step(model, valid_batch).numpy()
        logging.info('\t Epoch %i | average valid loss: %.4f', int(epoch),
                     valid_loss)

        # compute validation metrics
        valid = train_utils.avg_both(*model.eval(valid_examples, filters))
        logging.info(train_utils.format_metrics(valid, split='valid'))

        # early stopping
        valid_mrr = valid['MRR']
        if not best_mrr or valid_mrr > best_mrr:
          best_mrr = valid_mrr
          early_stopping_counter = 0
          best_epoch = int(epoch)
          best_weights = copy.copy(model.get_weights())
        else:
          early_stopping_counter += 1
          if early_stopping_counter == FLAGS.patience:
            logging.info('\t Early stopping')
            break

    logging.info('\t Optimization finished')
    logging.info('\t Evaluating best model from epoch %s', best_epoch)
    model.set_weights(best_weights)
    if FLAGS.save_model:
      model.save_weights(os.path.join(FLAGS.save_dir, 'best_model.ckpt'))

    # validation metrics
    valid = train_utils.avg_both(*model.eval(valid_examples, filters))
    logging.info(train_utils.format_metrics(valid, split='valid'))

    # test metrics
    test = train_utils.avg_both(*model.eval(test_examples, filters))
    logging.info(train_utils.format_metrics(test, split='test'))
  else:
    logging.info('\t Training completed')