def test_count_parameters_empty(self):
    """Check `count_parameters` on a module before and after adding a variable."""
    empty_module = snt.Module()
    snt.allow_empty_variables(empty_module)

    # With nothing attached to the module the count must be zero.
    num_params = parameter_overview.count_parameters(empty_module)
    self.assertEqual(0, num_params)

    # Attaching one two-element variable must raise the count to two.
    empty_module.var = tf.Variable([0, 1])
    num_params = parameter_overview.count_parameters(empty_module)
    self.assertEqual(2, num_params)
def test_get_parameter_overview_empty(self):
    """Check that the overview equals EMPTY_PARAMETER_OVERVIEW without variables."""
    module = snt.Module()
    snt.allow_empty_variables(module)

    # A bare module has no variables at all.
    overview = parameter_overview.get_parameter_overview(module)
    self.assertEqual(EMPTY_PARAMETER_OVERVIEW, overview)

    # Attaching a Conv2D does not create its variables yet; that only happens
    # in the first forward pass, so the overview must still be the empty one.
    module.conv = snt.Conv2D(output_channels=2, kernel_shape=3)
    overview = parameter_overview.get_parameter_overview(module)
    self.assertEqual(EMPTY_PARAMETER_OVERVIEW, overview)
def to_sonnet_module(transformation: types.TensorTransformation) -> snt.Module:
    """Return `transformation` as a Sonnet Module, wrapping it if needed.

    An argument that already is a `snt.Module` is returned unchanged. Anything
    else is placed inside a one-element `snt.Sequential`, which is then passed
    through `snt.allow_empty_variables`.
    """
    if not isinstance(transformation, snt.Module):
        # snt.Sequential turns any tensor transformation into a snt.Module.
        sequential = snt.Sequential([transformation])
        # Permit the resulting module to return an empty variable tuple.
        return snt.allow_empty_variables(sequential)
    return transformation
# build_agent_model: assemble and return a dict of the agent's model components.
#
# The dict is created with an "agent_net" entry holding
# RecurrentQNet(agent_size, n_actions, batch_size). A "kpt_encoder" entry is
# then chosen by `kpt_encoder_type`:
#   "cnn" -> KptConvEncoder(kpt_cnn_channels, agent_size)
#   "gnn" -> KptGnnEncoder() passed through snt.allow_empty_variables
# No other value of `kpt_encoder_type` is handled by the visible conditions,
# so in that case no "kpt_encoder" entry is added.
#
# NOTE(review): the line breaks/indentation of this definition are not
# visible here, so it cannot be determined whether the "node_enc"
# (NodeEncoder(output_dim=agent_size)) and "pos_net"
# (PositionalEncoder(kpt_cnn_channels)) entries are added only in the "gnn"
# branch or unconditionally after the if/elif - confirm against the
# original formatting before refactoring.
def build_agent_model(n_actions, agent_size, batch_size, kpt_encoder_type, kpt_cnn_channels): # agent model_dict = { "agent_net": RecurrentQNet(agent_size, n_actions, batch_size) } # keypoint encoder if kpt_encoder_type == "cnn": model_dict["kpt_encoder"] = KptConvEncoder(kpt_cnn_channels, agent_size) elif kpt_encoder_type == "gnn": model_dict["kpt_encoder"] = snt.allow_empty_variables(KptGnnEncoder()) model_dict["node_enc"] = NodeEncoder(output_dim=agent_size) model_dict["pos_net"] = PositionalEncoder(kpt_cnn_channels) return model_dict
def to_sonnet_module(transformation: types.TensorValuedCallable) -> snt.Module:
    """Return `transformation` as a Sonnet Module, wrapping it if needed.

    Args:
      transformation: A Callable that accepts one or more (nested) Tensors and
        produces one or more (nested) Tensors.

    Returns:
      `transformation` itself when it already is a `snt.Module`; otherwise a
      `TransformationWrapper` around it, passed through
      `snt.allow_empty_variables`.
    """
    if not isinstance(transformation, snt.Module):
        wrapper = TransformationWrapper(transformation)
        # Permit the wrapper module to return an empty variable tuple.
        return snt.allow_empty_variables(wrapper)
    return transformation
# Script section: set up an MS-ASL data loader and an I3D model, then print
# debugging information about the data and the TensorFlow global variables.
#
# Steps, in the order they appear:
#   1. Define the constants _IMAGE_SIZE, _BATCH_SIZE, _NUM_CLASSES,
#      _LEARNING_RATE, _NUM_EPOCHS and _MOMENTUM.
#   2. Build `train_generator` (MSASLDataLoader) from
#      ANNOTATION_FILE_PATH_TRAIN and FRAMES_DIR_PATH, then print the value of
#      get_data_dim() and the loader's `batch_size` attribute.
#   3. Inside tf.name_scope('RGB'), create i3d.InceptionI3d with
#      num_classes=_NUM_CLASSES and pass it to snt.allow_empty_variables.
#   4. Read (X_train, y_train) from train_generator[0], print
#      train_generator[1], print y_train and its shape, cast X_train to
#      tf.float32 and print its shape.
#   5. Iterate tf.compat.v1.global_variables(), printing variable names and
#      comparing the first '/'-separated name component with 'RGB'.
#
# NOTE(review): the MSASLDataLoader call passes the literals 1 and 224 rather
# than _BATCH_SIZE / _IMAGE_SIZE, and only _NUM_CLASSES is referenced in the
# code visible here - presumably the literals should be the constants; confirm
# the meaning of the third positional argument first.
# NOTE(review): the 'DATA LEN (number of batches)' message prints
# train_generator.batch_size, which by its name looks like a batch size rather
# than a batch count - confirm which was intended.
# NOTE(review): line breaks/indentation are not visible here, so the extent of
# the `with` body and the nesting of the statements after the `for` and `if`
# headers cannot be determined - check the original layout before editing.
_IMAGE_SIZE = 224 _BATCH_SIZE = 1 _NUM_CLASSES = 100 # Learning rate for the sonnet optimizer _LEARNING_RATE = 0.01 _NUM_EPOCHS = 200 _MOMENTUM = 0.9 train_generator = MSASLDataLoader(ANNOTATION_FILE_PATH_TRAIN, FRAMES_DIR_PATH, 1, height=224, width=224, color_mode='rgb', shuffle=True, frames_threshold=28) data_shape = train_generator.get_data_dim() print('DATA SHAPE (frames per sample, height, width, color_channels)' + str(data_shape)) print('DATA LEN (number of batches)' + str(train_generator.batch_size)) with tf.name_scope('RGB'): i3d_model = i3d.InceptionI3d(num_classes=_NUM_CLASSES, spatial_squeeze=True, final_endpoint='Logits') snt.allow_empty_variables(i3d_model) X_train, y_train = train_generator[0] print(train_generator[1]) print("Y_Train: ") print(y_train.shape) print(y_train) X_train = tf.cast(X_train, tf.float32) print('X SHAPE ' + str(X_train.shape)) print("=====================VARIABLES===================") for variable in tf.compat.v1.global_variables(): print(variable.name) if variable.name.split('/')[0] == 'RGB': print(variable.name) print(tf.compat.v1.get_default_graph().get_name_scope()) print("-------------------------------------------------")