def copy_variable_ref_to_graph(input_graph, output_graph, var_ref, init_value, scope=''):
    """Recreate `var_ref` from `input_graph` as a new variable in `output_graph`.

    The new variable's name is `var_ref.name` truncated at its first ':'.
    When `scope` is non-empty, that name is prefixed with `scope + '/'`.

    The new variable is registered in the same collections that contain
    `var_ref` in `input_graph`. GLOBAL_VARIABLES and TRAINABLE_VARIABLES keep
    their names. Any other collection name is prefixed with `scope + '/'`
    when `scope` is non-empty.

    Args:
        input_graph: Graph that currently owns `var_ref`.
        output_graph: Graph in which the new variable is created.
        var_ref: The variable being copied.
        init_value: Initial value for the new variable. Its `.shape` is
            applied to the new variable.
        scope: Optional name-scope prefix ('' means no prefix).

    Returns:
        The newly created variable.
    """
    full_name = var_ref.name
    if scope == '':
        new_name = full_name[:full_name.index(':')]
    else:
        new_name = scope + '/' + full_name[:full_name.index(':')]

    # The global/trainable collections are never renamed, so the copy stays
    # in the standard collections of the target graph.
    unscoped_keys = (ops.GraphKeys.GLOBAL_VARIABLES,
                     ops.GraphKeys.TRAINABLE_VARIABLES)
    collections = [
        key if (key in unscoped_keys or scope == '') else scope + '/' + key
        for key, members in input_graph._collections.items()
        if var_ref in members
    ]

    is_trainable = var_ref in input_graph.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES)

    with output_graph.as_default():
        copied = Variable(init_value,
                          is_trainable,
                          name=new_name,
                          collections=collections,
                          validate_shape=False)
        # validate_shape=False skips the shape check at construction;
        # the static shape is then pinned explicitly from init_value.
        copied.set_shape(init_value.shape)
    return copied
def add_variable_to_graph(output_graph, var_name, init_value, trainable=True,
                          collections=(), scope=''):
    """Create a new variable named `var_name` in `output_graph`.

    Args:
        output_graph: Graph in which the variable is created.
        var_name: Base name of the variable. When `scope` is non-empty, the
            final name is `scope + '/' + var_name`.
        init_value: Initial value of the variable. Its `.shape` is applied
            to the new variable.
        trainable: Forwarded to `Variable` as its trainable flag.
        collections: Collection keys forwarded to `Variable`.
            - A list or tuple is copied into a fresh list first.
            - Any other value (e.g. None) is forwarded unchanged.
            - The default is an empty collection. NOTE(review): with `Variable`,
              an empty list presumably means "not added to GLOBAL_VARIABLES";
              confirm against the TF version in use.
        scope: Optional name-scope prefix ('' means no prefix).

    Returns:
        The newly created variable.
    """
    new_name = var_name if scope == '' else scope + '/' + var_name

    # The old default was a single shared `[]` object reused across calls.
    # Use an immutable default instead, and hand `Variable` a private copy
    # so neither the caller's list nor later calls can be affected.
    if isinstance(collections, (list, tuple)):
        collections = list(collections)

    with output_graph.as_default():
        new_var = Variable(init_value,
                           trainable,
                           name=new_name,
                           collections=collections,
                           validate_shape=False)
        # Shape validation is skipped at construction; pin the static shape
        # explicitly from init_value.
        new_var.set_shape(init_value.shape)
    return new_var