class Module(object):
    """Base Module to adapter tf Module."""

    def __init__(self):
        # Hierarchical name assigned lazily by _apply_names(); '' until then.
        self.name = ''
        self.data_format = General.data_format
        # Child modules, keyed by the attribute name they were assigned to
        # (populated via __setattr__ below).
        self._modules = Config()
        # name -> raw parameter value; turned into tf.Variable in _apply_parameters().
        self._parameters = OrderedDict()
        # name -> numpy weight; flushed into graph variables in _apply_weights().
        self._weights_buffer = OrderedDict()
        self._init_configs()

    def _init_configs(self):
        # Default runtime flags; kept out of __init__ so subclasses can re-run them.
        self._training = True
        self._trainable = True
        self.weight_file = None
        self.from_weight_type = None
        self._is_load_pretrained = False
        self._is_adaptive_weight = False
        # Optional str or list of str; checkpoint keys with these prefixes are skipped.
        self.exclude_weight_prefix = None

    def add_module(self, name, model):
        """Add models into self._models.

        setattr triggers __setattr__, which also registers the model
        in self._modules when it is a Module instance.
        """
        setattr(self, str(name), model)

    def build(self):
        """Build model or params. Hook for subclasses; default does nothing."""
        pass

    def named_modules(self):
        """Return names spaces.

        Returns a flat list of (name, module) pairs for all descendants,
        in depth-first order. Names are resolved first via _apply_names().
        """
        self._apply_names()
        _modules = []
        for module in self.children():
            _modules.append((module.name, module))
            _modules.extend(module.named_modules())
        return _modules

    def named_children(self):
        """Return names children as (attribute_name, module) pairs."""
        return [(name, module) for name, module in self._modules.items()]

    def children(self):
        """Get child models of current Module (direct children only)."""
        for model in self._modules.values():
            yield model

    def load_checkpoint(self, weight_file):
        """Load weight state dict from last checkpoint file.

        Reads every variable from a TF checkpoint into a {name: ndarray}
        dict and dispatches to load_checkpoint_from_numpy(). No-op when
        weight_file is falsy.
        """
        if not weight_file:
            return
        logging.info("Load checkpoint form file ({}).".format(weight_file))
        # model_file = tf.train.latest_checkpoint(weight_file)
        reader = tf.train.NewCheckpointReader(weight_file)
        variables = reader.get_variable_to_shape_map()
        states = {v: reader.get_tensor(v) for v in variables}
        self.load_checkpoint_from_numpy(states)

    def load_checkpoint_from_numpy(self, states):
        """Load checkpoint from numpy.

        Routes each state entry to the child module whose name prefixes
        the checkpoint key ('<module.name>/...'); weights land in that
        module's _weights_buffer via set_weights().
        """
        states = self._exclude_checkpoint_by_prefix(states)
        for name, module in self.named_modules():
            child_state = [(k, v) for k, v in states.items() if k.startswith(module.name + '/')]
            for k, v in child_state:
                module.set_weights(k, v)

    def _exclude_checkpoint_by_prefix(self, states):
        # Drop checkpoint entries whose key starts with any configured prefix.
        # NOTE(review): normalizing exclude_weight_prefix to a list mutates
        # instance state as a side effect of loading.
        if self.exclude_weight_prefix:
            if not isinstance(self.exclude_weight_prefix, list):
                self.exclude_weight_prefix = [self.exclude_weight_prefix]
            for prefix in self.exclude_weight_prefix:
                states = {k: v for k, v in states.items() if not k.startswith(prefix)}
        return states

    def set_weights(self, name, value):
        """Set weights into weights buffer (applied to the graph later)."""
        self._weights_buffer[name] = value

    @property
    def training(self):
        """Get training flag."""
        return self._training

    @training.setter
    def training(self, value):
        """Set training flag, propagating recursively to all children."""
        self._training = value
        for module in self.children():
            module.training = value

    @property
    def is_adaptive_weight(self):
        """Get _is_adaptive_weight flag."""
        return self._is_adaptive_weight

    @is_adaptive_weight.setter
    def is_adaptive_weight(self, value):
        """Set _is_adaptive_weight flag, propagating recursively to all children."""
        self._is_adaptive_weight = value
        for module in self.children():
            module.is_adaptive_weight = value

    def freeze(self):
        """Set training flag.

        Marks this module and every descendant as non-trainable.
        """
        self._trainable = False
        for module in self.children():
            module.freeze()

    def __setattr__(self, key, value):
        """Set name to modules.

        Besides normal attribute assignment, registers Module-valued
        attributes in self._modules so children()/named_modules() see them.
        """
        super().__setattr__(key, value)
        if isinstance(value, Module):
            self._modules[key] = value

    def set_parameters(self, name, value):
        """Set Parameters.

        Records the raw value in _parameters and exposes it as an
        attribute; returns this module's name.
        """
        self._parameters[name] = value
        setattr(self, name, value)
        return self.name

    def get_weights(self, name=None):
        """Get weights by name.

        NOTE(review): when the weights buffer is non-empty the whole buffer
        is returned and `name` is ignored; otherwise the tensor named
        '<name>:0' is fetched from the default graph.
        """
        if self._weights_buffer:
            return self._weights_buffer
        return tf.get_default_graph().get_tensor_by_name('{}:0'.format(name))

    def get_all_weights(self):
        """Get all weights.

        Merges the weight buffers of all descendants (depth-first) into a
        single OrderedDict; this module's own buffer is not included.
        """
        all_weights = OrderedDict()
        for child in self.children():
            all_weights.update(child._weights_buffer)
            if isinstance(child, Module):
                all_weights.update(child.get_all_weights())
        return all_weights

    def get_weight_ops(self, name):
        """Get weight ops.

        Returns trainable variables whose names do NOT start with `name`.
        """
        all_weight = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        weight_ops = [t for t in all_weight if not t.name.startswith(name)]
        return weight_ops

    def call(self, inputs, *args, **kwarg):
        """Call inputs.

        Default forward pass: pipe `inputs` sequentially through each
        child module. Subclasses override for non-sequential graphs.
        """
        output = inputs
        for model in self.children():
            output = model(output)
        return output

    def adaptive_weight(self, inputs):
        """Adaptive weight. Hook for subclasses; default contributes nothing."""
        return {}

    def _apply_names(self, parent_name=''):
        """Apply names spaces.

        Assigns each child a hierarchical name of the form
        '<dotted.scope>/<ClassName>' unless it already has one, then recurses.
        """
        for scope_name, module in self._modules.items():
            scope_name = '{}.{}'.format(parent_name, scope_name) if parent_name else scope_name
            module.name = module.name or scope_name + '/' + module.__class__.__name__
            module._apply_names(scope_name)

    def _apply_parameters(self):
        """Apply names spaces.

        Materializes each recorded parameter as a tf.Variable, namespaced
        by this module's name when present.
        """
        for name, params in self._parameters.items():
            setattr(self, name, tf.Variable(params, name='{}.{}'.format(self.name, name) if self.name else name))

    def __call__(self, inputs, *args, **kwargs):
        """Call call function.

        Builds the module, resolves parameters and names, marks children
        as pretrained-loaded, runs the forward pass, then flushes any
        buffered weights into the graph.
        """
        self.build()
        self._apply_parameters()
        self._apply_names()
        for module in self.children():
            module._is_load_pretrained = True
        out = self.call(inputs, *args, **kwargs)
        self._apply_weights(inputs)
        return out

    def _apply_weights(self, inputs):
        # Assign buffered numpy weights onto matching graph variables
        # (matching on the variable name minus the ':0' suffix), then clear
        # the buffer so weights are applied at most once.
        if not self._weights_buffer:
            return
        variables = tf.get_collection(tf.GraphKeys.VARIABLES)
        if self.is_adaptive_weight:
            self._weights_buffer.update(self.adaptive_weight(inputs))
        values = [(var, self._weights_buffer.get(var.name.replace(':0', ''))) for var in variables if var.name.replace(':0', '') in self._weights_buffer]
        for v, weight in values:
            # Overriding _initializer_op makes the assignment take effect when
            # the variable is (re)initialized rather than running it eagerly.
            v._initializer_op = state_ops.assign(v, weight)
        self._weights_buffer.clear()

    def modules(self):
        """Get the current modules (children if any, else this module itself)."""
        if self._modules.values():
            return self._modules.values()
        else:
            return [self]
class Module(object):
    """Base Module to adapter tf Module."""

    # Class-level default; overwritten per-instance from General in __init__.
    data_format = 'channels_first'

    def __init__(self):
        # Scope bookkeeping: parent_scope_name is the attribute name this
        # module was assigned to; _scope_name is the resolved dotted path.
        self.parent_scope_name = ''
        self._scope_name = ''
        # Child modules, keyed by attribute name (populated via __setattr__).
        self._modules = Config()
        self._training = True
        self.enable_scope_name = enable_scope_name
        self.data_format = General.data_format
        self.pretrained_model_file = None
        self._is_load_pretrained = False
        # 'pytorch' selects weight conversion; anything else uses TF checkpoints.
        self.load_pretrained_type = None
        self._trainable = True
        # Optional {checkpoint_scope: graph_scope} mapping for init_from_checkpoint.
        self.pretrained_prefix = None

    def add_module(self, name, model):
        """Add models into self._models.

        setattr triggers __setattr__, which registers Module values
        in self._modules as well.
        """
        setattr(self, str(name), model)

    def named_modules(self):
        """Return names spaces.

        Flat depth-first list of (_scope_name, module) pairs for all
        Module descendants.
        """
        _names_modules = []
        for model in self.children():
            if isinstance(model, Module):
                _names_modules.append((model._scope_name, model))
                child_modules = model.named_modules()
                _names_modules.extend(child_modules)
        return _names_modules

    def named_children(self):
        """Return names children as (attribute_name, module) pairs."""
        return [(name, module) for name, module in self._modules.items()]

    def children(self):
        """Get child models of current Module.

        Refreshes each child's _scope_name from this module's scope
        before yielding it.
        """
        for model in self._modules.values():
            if isinstance(model, Module):
                model._scope_name = "{}.{}".format(
                    self._scope_name, model.parent_scope_name) if self._scope_name else model.parent_scope_name
            yield model

    def pretrained(self, pretrained_model_file=None):
        """Load Pretrained weights.

        :param pretrained_model_file: optional checkpoint path overriding
            self.pretrained_model_file.
        :return: list of assign ops (non-empty only for pytorch-sourced
            weights; TF checkpoints are wired via init_from_checkpoint).
        """
        if self._is_load_pretrained:
            return []
        assign_vars = []
        checkpoint_path = pretrained_model_file or self.pretrained_model_file
        if not checkpoint_path:
            # Fix: return an empty list (was bare `return`/None) so callers
            # can always iterate the result.
            return []
        pretrained_prefix = self.pretrained_prefix or {self._scope_name: self._scope_name}
        if self.load_pretrained_type == 'pytorch':
            assign_vars = assign_pytorch_weights(checkpoint_path, pretrained_prefix)
        else:
            tf.train.init_from_checkpoint(checkpoint_path, pretrained_prefix)
        self._is_load_pretrained = True
        return assign_vars

    @property
    def training(self):
        """Get training flag."""
        return self._training

    @training.setter
    def training(self, value):
        """Set training flag, propagating recursively to all children."""
        self._training = value
        for module in self.children():
            module.training = value

    @property
    def freeze(self):
        """Get frozen state (True when the module is not trainable)."""
        # Fix: the original `return self.freeze` re-entered this property and
        # recursed infinitely; derive the value from _trainable instead,
        # mirroring the setter (`self._trainable = not value`).
        return not self._trainable

    @freeze.setter
    def freeze(self, value):
        """Set frozen state, propagating recursively to all children."""
        self._trainable = not value
        for module in self.children():
            module.freeze = value

    def __setattr__(self, key, value):
        """Set name to modules.

        Writes directly to __dict__ and additionally registers Module
        values in self._modules, recording their parent scope name when
        scope naming is enabled.
        """
        self.__dict__[key] = value
        if isinstance(value, Module):
            if self.enable_scope_name:
                value.parent_scope_name = key
            self._modules[key] = value

    def __getattribute__(self, name):
        """Get modules by name, refreshing a Module's scope name on access."""
        value = object.__getattribute__(self, name)
        if isinstance(value, Module) and self.enable_scope_name:
            value._scope_name = "{}.{}".format(
                self._scope_name, value.parent_scope_name) if self._scope_name else value.parent_scope_name
        return value

    def set_parameters(self, name, value):
        """Set Parameters as a reusable tf variable exposed as an attribute."""
        with tf.variable_scope('', reuse=tf.AUTO_REUSE):
            setattr(self, name, tf.get_variable(name, initializer=value))

    def get_weights(self, name):
        """Get weights by name from the default graph ('<name>:0')."""
        return tf.get_default_graph().get_tensor_by_name('{}:0'.format(name))

    def get_weight_ops(self, name):
        """Get weight ops.

        Returns trainable variables whose names do NOT start with `name`.
        """
        all_weight = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        weight_ops = [t for t in all_weight if not t.name.startswith(name)]
        return weight_ops

    def call(self, inputs, *args, **kwarg):
        """Call inputs.

        Default forward pass: pipe `inputs` sequentially through each
        child module. Subclasses override for non-sequential graphs.
        """
        output = inputs
        for model in self.children():
            output = model(output)
        return output

    def __call__(self, inputs, *args, **kwargs):
        """Call call function."""
        return self.call(inputs, *args, **kwargs)

    def modules(self):
        """Get the current modules (children if any, else this module itself)."""
        if self._modules.values():
            return self._modules.values()
        else:
            return [self]