import time
import logging
import numpy as np
from baseline.utils import export, optional_params, register, listify
import math

# Public API list for this module; the `exporter` decorator below is built from it.
__all__ = []
# NOTE(review): `export` lives in baseline.utils; presumably the returned
# decorator appends each decorated name to `__all__` — confirm there.
exporter = export(__all__)

# Module-level scheduler registry. Nothing is added to it in this chunk;
# presumably it is populated (e.g. via `register`) elsewhere in the file — verify.
BASELINE_LR_SCHEDULERS = {}
logger = logging.getLogger('baseline')


@exporter
class LearningRateScheduler(object):
    """Base class for learning-rate schedulers.

    Stores the base learning rate as ``self.lr``.
    """

    def __init__(self, **kwargs):
        """Store the base learning rate.

        :param kwargs: ``lr`` is used if present, otherwise ``eta``,
            otherwise 1.0. No other keys are read here.
        """
        # `lr` takes precedence over `eta`; 1.0 is the fallback when neither is given.
        self.lr = kwargs.get('lr', kwargs.get('eta', 1.0))

    @staticmethod
    def _identity(x):
        """Return ``x`` unchanged."""
        return x

    def __str__(self) -> str:
        """Return a summary of the form ``ClassName(eta=<lr>)``."""
        return "{}(eta={})".format(type(self).__name__, self.lr)


@exporter
class WarmupLearningRateScheduler(LearningRateScheduler):
    """Scheduler subclass that accepts a ``warmup_steps`` count (default 16000)."""
    # NOTE(review): this chunk ends at the signature below; the body of
    # __init__ and the rest of the class are not visible here.
    def __init__(self, warmup_steps=16000, **kwargs):
import os
import shutil
import datetime
from tensorflow.python.framework.errors_impl import NotFoundError
import mead.exporters
from mead.exporters import register_exporter
# NOTE(review): wildcard import. No explicit `import tensorflow as tf` appears
# in this chunk, so the `tf` name used below presumably arrives through this
# module — confirm.
from baseline.tf.embeddings import *
from baseline.tf.tfy import transition_mask
from baseline.utils import (export, read_json, ls_props, Offsets, write_json)
from collections import namedtuple

# String constants; their consumers are not visible in this chunk.
FIELD_NAME = 'text/tokens'
ASSET_FILE_NAME = 'model.assets'

# Public API list for this module, filled through the `exporter` decorator
# (presumably `export` appends decorated names — see baseline.utils).
__all__ = []
exporter = export(__all__)

# Record type with `classes` and `scores` fields; presumably the outputs of an
# exported classification signature — confirm against its users.
SignatureOutput = namedtuple("SignatureOutput", ("classes", "scores"))


@exporter
class TensorFlowExporter(mead.exporters.Exporter):
    """TensorFlow-specific exporter base derived from ``mead.exporters.Exporter``."""

    def __init__(self, task):
        """Forward ``task`` unchanged to the base ``Exporter`` constructor."""
        super(TensorFlowExporter, self).__init__(task)

    def _run(self, sess, basename):
        """No-op in this class; presumably overridden by concrete exporters — confirm."""
        pass

    def _restore_checkpoint(self, sess, basename):
        # NOTE(review): this chunk ends after the two statements below; the
        # rest of the method (including any use of `saver`) is not visible here.
        # Build a tf.train.Saver with default arguments.
        saver = tf.train.Saver()
        # Run TensorFlow's table-initializer op in `sess`.
        sess.run(tf.tables_initializer())