# Example #1
0
import time
import logging
import numpy as np
from baseline.utils import export, optional_params, register, listify
import math


# Public API of this module. It is handed to `export`; presumably the
# resulting `exporter` decorator appends decorated names to it (confirm
# against baseline.utils.export).
__all__ = list()
exporter = export(__all__)

# Shared logger for the 'baseline' package namespace.
logger = logging.getLogger('baseline')
# Module-level registry dict, initially empty. NOTE(review): presumably
# filled by a registration helper elsewhere in this file — confirm.
BASELINE_LR_SCHEDULERS = dict()


@exporter
class LearningRateScheduler(object):
    """Base learning-rate scheduler holding a single base rate.

    The base rate is stored on ``self.lr``. ``str()`` reports it under the
    label ``eta``.
    """

    def __init__(self, **kwargs):
        """Read the base learning rate from keyword arguments.

        :param kwargs: ``lr`` sets the base rate. ``eta`` is an alias that is
            consulted only when ``lr`` is absent. Defaults to ``1.0`` when
            neither key is present. Any other keys are ignored here.
        """
        if 'lr' in kwargs:
            base_rate = kwargs['lr']
        else:
            base_rate = kwargs.get('eta', 1.0)
        self.lr = base_rate

    @staticmethod
    def _identity(x):
        """Return ``x`` unchanged."""
        return x

    def __str__(self):
        # Render as "<ClassName>(eta=<lr>)". The stored ``lr`` is shown
        # under the ``eta`` label.
        template = "{0}(eta={1})"
        return template.format(type(self).__name__, self.lr)


@exporter
class WarmupLearningRateScheduler(LearningRateScheduler):
    def __init__(self, warmup_steps=16000, **kwargs):
# Example #2
0
import os
import shutil
import datetime
from tensorflow.python.framework.errors_impl import NotFoundError
import mead.exporters
from mead.exporters import register_exporter
from baseline.tf.embeddings import *
from baseline.tf.tfy import transition_mask
from baseline.utils import (export, read_json, ls_props, Offsets, write_json)
from collections import namedtuple

# Module-level string constants. Where they are consumed lies outside this
# excerpt.
ASSET_FILE_NAME = 'model.assets'
FIELD_NAME = 'text/tokens'

# Public API of this module. It is handed to `export`; presumably the
# resulting `exporter` decorator appends decorated names to it (confirm
# against baseline.utils.export).
__all__ = list()
exporter = export(__all__)

# Immutable two-field record type with fields ``classes`` and ``scores``.
SignatureOutput = namedtuple("SignatureOutput", ["classes", "scores"])


@exporter
class TensorFlowExporter(mead.exporters.Exporter):
    def __init__(self, task):
        """Store ``task`` by delegating construction to the parent class.

        :param task: forwarded positionally to the next class in the MRO
            (``mead.exporters.Exporter`` for this class's direct base).
        """
        parent = super(TensorFlowExporter, self)
        parent.__init__(task)

    def _run(self, sess, basename):
        """Do nothing and return ``None``.

        NOTE(review): this appears to be a placeholder that task-specific
        subclasses override to perform the export using ``sess`` and
        ``basename`` — confirm against the subclasses.
        """

    def _restore_checkpoint(self, sess, basename):
        saver = tf.train.Saver()
        sess.run(tf.tables_initializer())