Beispiel #1
0
def configurable_module(module):
    """Expose a flax ``nn.Module`` subclass as a gin-configurable factory.

    The returned function forwards its keyword arguments to
    ``module.partial`` and returns the result, so values bound through gin
    become the keyword arguments of that ``partial`` call.

    Args:
      module: A subclass of ``nn.Module`` (the class itself, not an instance).

    Returns:
      The result of ``gin.configurable`` applied to a wrapper whose
      ``__name__`` is ``module.__name__``.

    Raises:
      ValueError: If ``module`` is not a subclass of ``nn.Module``, including
        when it is not a class at all.
    """
    # issubclass() raises TypeError for non-class arguments (instances,
    # functions), so check isinstance(..., type) first to surface the
    # documented ValueError for every invalid input.
    if not (isinstance(module, type) and issubclass(module, nn.Module)):
        raise ValueError(
            "this decorator can only be used on flax.nn.Module class.")

    def wrapper(**kwargs):
        return module.partial(**kwargs)

    # gin.configurable picks up the wrapper's __name__, so reuse the module's
    # name; copy its docstring too so the wrapper is self-describing.
    wrapper.__name__ = module.__name__
    wrapper.__doc__ = module.__doc__

    return gin.configurable(wrapper)
Beispiel #2
0
def register_module_with_gin(module, module_name=None):
    """
    Replace every callable attribute of ``module`` with a gin configurable.

    This adds gin configurability to an existing codebase without decorating
    each definition with ``@gin.configurable``. Every name returned by
    ``dir(module)`` whose value is callable is rebound on ``module`` to the
    result of ``gin.configurable``.

    Args:
        module: The module object whose callables should be registered.
        module_name: gin module namespace to register under; defaults to
            ``module.__name__``.
    """
    namespace = module_name if module_name is not None else module.__name__

    for name in dir(module):
        candidate = getattr(module, name)
        if not callable(candidate):
            continue
        setattr(module, name, gin.configurable(candidate, module=namespace))
Beispiel #3
0
import gin

from tensorflow.keras import applications
from tensorflow.compat.v1.keras.layers import BatchNormalization

from thin.models import resnet

def _keras_application(fn):
    """Register a tf.keras.applications constructor with gin."""
    return gin.configurable(fn, module='tf.keras.applications')


def _thin_model(fn):
    """Register a thin.models constructor with gin."""
    return gin.configurable(fn, module='thin.models')


# EfficientNet family, registered under the 'tf.keras.applications' gin module.
EfficientNetB0 = _keras_application(applications.EfficientNetB0)
EfficientNetB1 = _keras_application(applications.EfficientNetB1)
EfficientNetB2 = _keras_application(applications.EfficientNetB2)
EfficientNetB3 = _keras_application(applications.EfficientNetB3)
EfficientNetB4 = _keras_application(applications.EfficientNetB4)
EfficientNetB5 = _keras_application(applications.EfficientNetB5)
EfficientNetB6 = _keras_application(applications.EfficientNetB6)
EfficientNetB7 = _keras_application(applications.EfficientNetB7)

# Project ResNet / ResUNet variants, registered under the 'thin.models' gin
# module.
ResNet18 = _thin_model(resnet.ResNet18)
ResNet34 = _thin_model(resnet.ResNet34)
ResNet50 = _thin_model(resnet.ResNet50)
ResUNet18 = _thin_model(resnet.ResUNet18)
ResUNet34 = _thin_model(resnet.ResUNet34)
ResUNet50 = _thin_model(resnet.ResUNet50)
Beispiel #4
0
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Module with all the global configurable models for training."""

from ddsp.training.models.autoencoder import Autoencoder
from ddsp.training.models.vae import VAE
from ddsp.training.models.inverse_synthesis import InverseSynthesis
from ddsp.training.models.midi_autoencoder import MidiAutoencoder
from ddsp.training.models.midi_autoencoder import ZMidiAutoencoder
from ddsp.training.models.model import Model
import gin


def _configurable(cls):
  """Return `cls` registered with gin under this module's `__name__`."""
  return gin.configurable(cls, module=__name__)


# Rebind each imported model class to its gin-configurable version.
Autoencoder = _configurable(Autoencoder)
VAE = _configurable(VAE)
InverseSynthesis = _configurable(InverseSynthesis)
MidiAutoencoder = _configurable(MidiAutoencoder)
ZMidiAutoencoder = _configurable(ZMidiAutoencoder)



@gin.configurable
def get_model(model=gin.REQUIRED):
  """Gin configurable function get a 'global' model for use in ddsp_run.py.

  Convenience for using the same model in train(), evaluate(), and sample().
  Args:
Beispiel #5
0
# limitations under the License.

# Lint as: python3
"""Tests for ddsp.dags.py."""

from absl.testing import parameterized
from ddsp import dags
import gin
import tensorflow as tf

# Make dense layers configurable for this test.
# Registers Keras' Dense layer with gin under the explicit name
# 'tf.keras.layers.Dense' so gin config strings can bind its arguments.
gin.external_configurable(tf.keras.layers.Dense, 'tf.keras.layers.Dense')

# Make dag_layers configurable for this test.
# Interactive mode is entered before registering DAGLayer, presumably so the
# registration does not fail if DAGLayer is already registered (e.g. on
# re-import) — TODO confirm. NOTE(review): this is gin-wide state and is
# never exited here, so it likely stays on for later tests in the process.
gin.enter_interactive_mode()
# The return value is discarded; this line is executed only for the
# registration side effect of gin.configurable.
gin.configurable(dags.DAGLayer)


class DAGLayerTest(parameterized.TestCase, tf.test.TestCase):
    def setUp(self):
        """Create some dummy input data for the chain."""
        super().setUp()
        # Create inputs.
        self.n_batch = 4
        self.x_dims = 5
        self.z_dims = 2
        self.x = tf.ones([self.n_batch, self.x_dims])
        self.inputs = {'test_data': self.x}
        self.gin_config_kwarg_modules = f"""
    import ddsp
    dags.run_dag.verbose = True