def configurable_module(module):
    """Wrap a ``flax.nn.Module`` subclass in a gin-configurable factory.

    The returned factory accepts keyword arguments only and forwards them to
    ``module.partial(**kwargs)``. Before the factory is passed to
    ``gin.configurable``, its ``__name__`` is set to the class's name, so gin's
    name-based registration refers to the class name rather than "wrapper".

    Args:
        module: A class deriving from ``flax.nn.Module``.

    Returns:
        The object returned by ``gin.configurable`` for the factory.

    Raises:
        ValueError: If ``module`` is not a class, or is a class that does not
            derive from ``flax.nn.Module``.
    """
    # issubclass() raises TypeError when its first argument is not a class.
    # Checking for a class first (short-circuit) makes every invalid input
    # produce the documented ValueError instead.
    if not (isinstance(module, type) and issubclass(module, nn.Module)):
        raise ValueError(
            "this decorator can only be used on flax.nn.Module class.")

    def wrapper(**kwargs):
        return module.partial(**kwargs)

    # gin derives the configurable's name from __name__ by default.
    wrapper.__name__ = module.__name__
    return gin.configurable(wrapper)
def register_module_with_gin(module, module_name=None):
    """Make every callable attribute of ``module`` gin-configurable, in place.

    A convenient way to add gin configurability to a codebase without
    explicitly decorating each function or class with ``@gin.configurable``.
    Every callable name reported by ``dir(module)`` is replaced by its
    ``gin.configurable`` wrapper. This includes names the module merely
    imports from elsewhere.

    Args:
        module: The module object to scan and update.
        module_name: Gin module name to register the callables under. When
            ``None`` (the default), ``module.__name__`` is used.
    """
    if module_name is None:
        module_name = module.__name__

    for attr_name in dir(module):
        candidate = getattr(module, attr_name)
        if not callable(candidate):
            continue
        configured = gin.configurable(candidate, module=module_name)
        setattr(module, attr_name, configured)
import gin
from tensorflow.keras import applications
from tensorflow.compat.v1.keras.layers import BatchNormalization
from thin.models import resnet


def _keras_app(builder):
    """Register a tf.keras.applications builder with gin."""
    return gin.configurable(builder, module='tf.keras.applications')


def _thin_model(builder):
    """Register a thin.models builder with gin."""
    return gin.configurable(builder, module='thin.models')


# EfficientNet backbones, exposed to gin under 'tf.keras.applications'.
EfficientNetB0 = _keras_app(applications.EfficientNetB0)
EfficientNetB1 = _keras_app(applications.EfficientNetB1)
EfficientNetB2 = _keras_app(applications.EfficientNetB2)
EfficientNetB3 = _keras_app(applications.EfficientNetB3)
EfficientNetB4 = _keras_app(applications.EfficientNetB4)
EfficientNetB5 = _keras_app(applications.EfficientNetB5)
EfficientNetB6 = _keras_app(applications.EfficientNetB6)
EfficientNetB7 = _keras_app(applications.EfficientNetB7)

# ResNet / ResUNet builders from thin.models.resnet, exposed under 'thin.models'.
ResNet18 = _thin_model(resnet.ResNet18)
ResNet34 = _thin_model(resnet.ResNet34)
ResNet50 = _thin_model(resnet.ResNet50)
ResUNet18 = _thin_model(resnet.ResUNet18)
ResUNet34 = _thin_model(resnet.ResUNet34)
ResUNet50 = _thin_model(resnet.ResUNet50)
# See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Module with all the global configurable models for training.""" from ddsp.training.models.autoencoder import Autoencoder from ddsp.training.models.vae import VAE from ddsp.training.models.inverse_synthesis import InverseSynthesis from ddsp.training.models.midi_autoencoder import MidiAutoencoder from ddsp.training.models.midi_autoencoder import ZMidiAutoencoder from ddsp.training.models.model import Model import gin _configurable = lambda cls: gin.configurable(cls, module=__name__) Autoencoder = _configurable(Autoencoder) VAE = _configurable(VAE) InverseSynthesis = _configurable(InverseSynthesis) MidiAutoencoder = _configurable(MidiAutoencoder) ZMidiAutoencoder = _configurable(ZMidiAutoencoder) @gin.configurable def get_model(model=gin.REQUIRED): """Gin configurable function get a 'global' model for use in ddsp_run.py. Convenience for using the same model in train(), evaluate(), and sample(). Args:
# limitations under the License. # Lint as: python3 """Tests for ddsp.dags.py.""" from absl.testing import parameterized from ddsp import dags import gin import tensorflow as tf # Make dense layers configurable for this test. gin.external_configurable(tf.keras.layers.Dense, 'tf.keras.layers.Dense') # Make dag_layers configurable for this test. gin.enter_interactive_mode() gin.configurable(dags.DAGLayer) class DAGLayerTest(parameterized.TestCase, tf.test.TestCase): def setUp(self): """Create some dummy input data for the chain.""" super().setUp() # Create inputs. self.n_batch = 4 self.x_dims = 5 self.z_dims = 2 self.x = tf.ones([self.n_batch, self.x_dims]) self.inputs = {'test_data': self.x} self.gin_config_kwarg_modules = f""" import ddsp dags.run_dag.verbose = True