def create_tf_parameters(self, name='test_tf_param'):
    """Build a ParametersWithTensorflowVariable fixture backed by two TF variables.

    Two variables of shape [3, 4] are created under the variable scope
    ``name``: ``var_1`` (float32) and ``var_2`` (bool). They are handed,
    in that order, to a ParametersWithTensorflowVariable together with:

    * a DictConfig built from ``Foo.required_key_dict`` with
      ``var1=1`` and ``var2=0.01``,
    * one plain parameter ``var3='sss'``,
    * a scalar int32 placeholder registered under ``var1``.

    :param name: variable-scope name; also used as the parameter object's name.
    :return: tuple ``(param, locals())``. The second item is this function's
        local namespace: ``self``, ``name``, ``a``, ``b``, ``conf``, ``param``.
    """
    # NOTE: locals() is returned, so the local names below are effectively
    # part of the return value -- do not rename them or add new locals.
    with tf.variable_scope(name):
        # Creation order matters for graph naming: var_1 first, then var_2.
        a = tf.get_variable(name='var_1', shape=[3, 4], dtype=tf.float32)
        b = tf.get_variable(name='var_2', shape=[3, 4], dtype=tf.bool)
    conf = DictConfig(
        config_dict={'var1': 1, 'var2': 0.01},
        required_key_dict=Foo.required_key_dict,
    )
    param = ParametersWithTensorflowVariable(
        name=name,
        tf_var_list=[a, b],
        rest_parameters={'var3': 'sss'},
        source_config=conf,
        require_snapshot=True,
        to_ph_parameter_dict={
            'var1': tf.placeholder(dtype=tf.int32, shape=()),
        },
    )
    return param, locals()
def create_ph(self, name):
    """Build a PlaceholderInput fixture around a one-variable parameter set.

    A float32 variable ``var_1`` of shape [3, 4] is created under the
    variable scope ``name``. It is wrapped in a
    ParametersWithTensorflowVariable together with a DictConfig built from
    ``Foo.required_key_dict`` (``var1=1``, ``var2=0.01``), one plain
    parameter ``var3='sss'``, and a scalar int32 placeholder registered
    under ``var1``. ``param.init()`` is called, and the initialised
    parameters are then wrapped in a PlaceholderInput with ``inputs=None``.

    :param name: variable-scope name; also used as the parameter object's name.
    :return: tuple ``(a, locals())`` where ``a`` is the PlaceholderInput.
        The second item is this function's local namespace: ``self``,
        ``name``, ``a``, ``conf``, ``param``.
    """
    # NOTE: locals() is returned, so local names are effectively part of the
    # return value. That is also why `a` is deliberately rebound below: at
    # return time locals()['a'] is the PlaceholderInput, not the TF variable.
    with tf.variable_scope(name):
        a = tf.get_variable(name='var_1', shape=[3, 4], dtype=tf.float32)
    conf = DictConfig(
        config_dict={'var1': 1, 'var2': 0.01},
        required_key_dict=Foo.required_key_dict,
    )
    param = ParametersWithTensorflowVariable(
        name=name,
        tf_var_list=[a],
        rest_parameters={'var3': 'sss'},
        source_config=conf,
        require_snapshot=True,
        to_ph_parameter_dict={
            'var1': tf.placeholder(dtype=tf.int32, shape=()),
        },
    )
    param.init()
    a = PlaceholderInput(inputs=None, parameters=param)
    return a, locals()
def create_dict_config(self):
    """Build a DictConfig fixture tagged with ``cls_name='Foo'``.

    The config pairs ``Foo.required_key_dict`` with the values
    ``var1=1`` and ``var2=0.1``.

    :return: tuple ``(a, locals())`` where ``a`` is the DictConfig. The
        second item is this function's local namespace: ``self`` and ``a``.
    """
    # NOTE: locals() is returned, so the local name `a` is part of the result.
    # NOTE(review): var2 is 0.1 here but 0.01 in the TF-parameter fixtures --
    # presumably intentional; confirm before unifying.
    a = DictConfig(
        cls_name='Foo',
        config_dict={'var1': 1, 'var2': 0.1},
        required_key_dict=Foo.required_key_dict,
    )
    return a, locals()