Example #1
0
class Module(object):
    """Base Module to adapt TensorFlow (graph-mode) layers.

    Assigning a ``Module`` instance as an attribute registers it as a child
    in ``self._modules`` (see ``__setattr__``). Checkpoint values are staged
    in ``self._weights_buffer`` and turned into variable initializers by
    ``_apply_weights`` when the module is called.
    """

    def __init__(self):
        self.name = ''
        self.data_format = General.data_format
        self._modules = Config()
        self._parameters = OrderedDict()
        self._weights_buffer = OrderedDict()
        self._init_configs()

    def _init_configs(self):
        """Reset training and weight-loading flags to their defaults."""
        self._training = True
        self._trainable = True
        self.weight_file = None
        self.from_weight_type = None
        self._is_load_pretrained = False
        self._is_adaptive_weight = False
        # A single prefix or a list of prefixes; checkpoint keys starting with
        # any of them are dropped by ``_exclude_checkpoint_by_prefix``.
        self.exclude_weight_prefix = None

    def add_module(self, name, model):
        """Register ``model`` as a child module under attribute ``str(name)``."""
        setattr(self, str(name), model)

    def build(self):
        """Build model or params; a no-op hook for subclasses to override."""
        pass

    def named_modules(self):
        """Return ``(name, module)`` pairs for all descendants, depth-first.

        Names are assigned first via ``_apply_names``. ``self`` itself is not
        included in the result.
        """
        self._apply_names()
        _modules = []
        for module in self.children():
            _modules.append((module.name, module))
            _modules.extend(module.named_modules())
        return _modules

    def named_children(self):
        """Return ``(attribute_name, module)`` pairs for direct children."""
        return [(name, module) for name, module in self._modules.items()]

    def children(self):
        """Yield the direct child modules of this Module."""
        for model in self._modules.values():
            yield model

    def load_checkpoint(self, weight_file):
        """Read a TF checkpoint file and stage its values into child buffers.

        :param weight_file: checkpoint path; a falsy value makes this a no-op.
        """
        if not weight_file:
            return
        logging.info("Load checkpoint from file (%s).", weight_file)
        reader = tf.train.NewCheckpointReader(weight_file)
        variables = reader.get_variable_to_shape_map()
        states = {v: reader.get_tensor(v) for v in variables}
        self.load_checkpoint_from_numpy(states)

    def load_checkpoint_from_numpy(self, states):
        """Stage checkpoint values into the weight buffers of descendants.

        :param states: mapping of variable name -> value (presumably numpy
            arrays, given the method name -- TODO confirm). After prefix
            exclusion, each key is handed to every descendant whose
            ``name + '/'`` is a prefix of the key.
        """
        states = self._exclude_checkpoint_by_prefix(states)
        for _, module in self.named_modules():
            prefix = module.name + '/'
            for key, value in states.items():
                if key.startswith(prefix):
                    module.set_weights(key, value)

    def _exclude_checkpoint_by_prefix(self, states):
        """Return ``states`` without keys starting with an excluded prefix.

        ``self.exclude_weight_prefix`` is normalized to a list in place. When
        it is falsy, ``states`` is returned unchanged (same object).
        """
        if not self.exclude_weight_prefix:
            return states
        if not isinstance(self.exclude_weight_prefix, list):
            self.exclude_weight_prefix = [self.exclude_weight_prefix]
        prefixes = self.exclude_weight_prefix
        # Single pass over ``states`` instead of rebuilding the dict per prefix.
        return {k: v for k, v in states.items()
                if not any(k.startswith(p) for p in prefixes)}

    def set_weights(self, name, value):
        """Stage ``value`` under ``name`` in this module's weights buffer."""
        self._weights_buffer[name] = value

    @property
    def training(self):
        """Get training flag."""
        return self._training

    @training.setter
    def training(self, value):
        """Set training flag on this module and, recursively, its children."""
        self._training = value
        for module in self.children():
            module.training = value

    @property
    def is_adaptive_weight(self):
        """Get _is_adaptive_weight flag."""
        return self._is_adaptive_weight

    @is_adaptive_weight.setter
    def is_adaptive_weight(self, value):
        """Set _is_adaptive_weight flag on this module and its descendants."""
        self._is_adaptive_weight = value
        for module in self.children():
            module.is_adaptive_weight = value

    def freeze(self):
        """Mark this module and all its descendants as not trainable."""
        self._trainable = False
        for module in self.children():
            module.freeze()

    def __setattr__(self, key, value):
        """Set an attribute, registering ``Module`` values as child modules."""
        super().__setattr__(key, value)
        if isinstance(value, Module):
            self._modules[key] = value

    def set_parameters(self, name, value):
        """Register a raw parameter value and set it as attribute ``name``.

        ``_apply_parameters`` later replaces the attribute with a
        ``tf.Variable`` built from the registered value.

        :return: ``self.name``
        """
        self._parameters[name] = value
        setattr(self, name, value)
        return self.name

    def get_weights(self, name=None):
        """Get weights.

        If the weights buffer is non-empty, the whole buffer is returned and
        ``name`` is ignored. Otherwise the graph tensor ``'<name>:0'`` is
        returned.
        """
        # NOTE(review): returning the whole buffer even when ``name`` is given
        # looks surprising; kept as-is because callers may rely on it.
        if self._weights_buffer:
            return self._weights_buffer
        return tf.get_default_graph().get_tensor_by_name('{}:0'.format(name))

    def get_all_weights(self):
        """Collect buffered weights from all descendants (not ``self``)."""
        all_weights = OrderedDict()
        for child in self.children():
            all_weights.update(child._weights_buffer)
            if isinstance(child, Module):
                all_weights.update(child.get_all_weights())
        return all_weights

    def get_weight_ops(self, name):
        """Return trainable variables whose names do NOT start with ``name``."""
        # NOTE(review): this is an exclusion filter despite the method name.
        all_weight = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        weight_ops = [t for t in all_weight if not t.name.startswith(name)]
        return weight_ops

    def call(self, inputs, *args, **kwarg):
        """Feed ``inputs`` through the children sequentially."""
        output = inputs
        for model in self.children():
            output = model(output)
        return output

    def adaptive_weight(self, inputs):
        """Return extra buffered weights derived from ``inputs`` (none here)."""
        return {}

    def _apply_names(self, parent_name=''):
        """Name unnamed descendants ``'<dotted.scope>/<ClassName>'``."""
        for scope_name, module in self._modules.items():
            scope_name = '{}.{}'.format(parent_name, scope_name) if parent_name else scope_name
            # Existing names are kept; only empty names are filled in.
            module.name = module.name or scope_name + '/' + module.__class__.__name__
            module._apply_names(scope_name)

    def _apply_parameters(self):
        """Replace each registered parameter attribute with a ``tf.Variable``.

        A new ``tf.Variable`` is constructed per parameter on every call.
        """
        for name, params in self._parameters.items():
            setattr(self, name, tf.Variable(params, name='{}.{}'.format(self.name, name) if self.name else name))

    def __call__(self, inputs, *args, **kwargs):
        """Build, name and run the module, then apply any staged weights."""
        self.build()
        self._apply_parameters()
        self._apply_names()
        for module in self.children():
            module._is_load_pretrained = True
        out = self.call(inputs, *args, **kwargs)
        self._apply_weights(inputs)
        return out

    def _apply_weights(self, inputs):
        """Turn buffered weights into initializers of matching graph variables.

        Buffer keys are matched against global variable names with ``':0'``
        stripped. The buffer is cleared afterwards.
        """
        if not self._weights_buffer:
            return
        # GLOBAL_VARIABLES is the non-deprecated name of the VARIABLES key.
        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        if self.is_adaptive_weight:
            self._weights_buffer.update(self.adaptive_weight(inputs))
        for var in variables:
            key = var.name.replace(':0', '')
            if key in self._weights_buffer:
                # Overrides a private TF attribute, presumably so that running
                # the variable's initializer assigns the staged value -- confirm
                # against the TF version in use.
                var._initializer_op = state_ops.assign(var, self._weights_buffer[key])
        self._weights_buffer.clear()

    def modules(self):
        """Return the direct children, or ``[self]`` when there are none."""
        if self._modules.values():
            return self._modules.values()
        else:
            return [self]
Example #2
0
class Module(object):
    """Base Module to adapt TensorFlow (graph-mode) layers.

    Assigning a ``Module`` instance as an attribute registers it as a child
    in ``self._modules``. When ``enable_scope_name`` is set, the attribute
    name is also recorded as the child's ``parent_scope_name``, from which a
    dotted ``_scope_name`` is derived whenever the child is accessed.
    """

    data_format = 'channels_first'

    def __init__(self):
        self.parent_scope_name = ''
        self._scope_name = ''
        self._modules = Config()
        self._training = True
        self.enable_scope_name = enable_scope_name
        self.data_format = General.data_format
        self.pretrained_model_file = None
        self._is_load_pretrained = False
        self.load_pretrained_type = None
        self._trainable = True
        self.pretrained_prefix = None

    def add_module(self, name, model):
        """Register ``model`` as a child module under attribute ``str(name)``."""
        setattr(self, str(name), model)

    def named_modules(self):
        """Return ``(scope_name, module)`` pairs for all descendants, depth-first.

        ``self`` itself is not included.
        """
        _names_modules = []
        for model in self.children():
            if isinstance(model, Module):
                _names_modules.append((model._scope_name, model))
                child_modules = model.named_modules()
                _names_modules.extend(child_modules)
        return _names_modules

    def named_children(self):
        """Return ``(attribute_name, module)`` pairs for direct children."""
        return [(name, module) for name, module in self._modules.items()]

    def children(self):
        """Yield direct children, refreshing each child's dotted ``_scope_name``."""
        for model in self._modules.values():
            if isinstance(model, Module):
                model._scope_name = "{}.{}".format(
                    self._scope_name, model.parent_scope_name) if self._scope_name else model.parent_scope_name
            yield model

    def pretrained(self, pretrained_model_file=None):
        """Load pretrained weights.

        :param pretrained_model_file: checkpoint path; defaults to
            ``self.pretrained_model_file``.
        :return: for ``load_pretrained_type == 'pytorch'``, whatever
            ``assign_pytorch_weights`` returns; otherwise an empty list
            (also when already loaded or when no checkpoint path is set).
        """
        if self._is_load_pretrained:
            return []
        checkpoint_path = pretrained_model_file or self.pretrained_model_file
        if not checkpoint_path:
            # Same empty-list result as the already-loaded branch, so callers
            # can always iterate/extend the return value.
            return []
        pretrained_prefix = self.pretrained_prefix or {self._scope_name: self._scope_name}
        assign_vars = []
        if self.load_pretrained_type == 'pytorch':
            assign_vars = assign_pytorch_weights(checkpoint_path, pretrained_prefix)
        else:
            tf.train.init_from_checkpoint(checkpoint_path, pretrained_prefix)
        self._is_load_pretrained = True
        return assign_vars

    @property
    def training(self):
        """Get training flag."""
        return self._training

    @training.setter
    def training(self, value):
        """Set training flag on this module and, recursively, its children."""
        self._training = value
        for module in self.children():
            module.training = value

    @property
    def freeze(self):
        """Get freeze flag (True when this module is not trainable)."""
        # Derived from ``_trainable``, which the setter maintains; reading
        # ``self.freeze`` here would recurse forever.
        return not self._trainable

    @freeze.setter
    def freeze(self, value):
        """Set freeze flag on this module and, recursively, its children."""
        self._trainable = not value
        for module in self.children():
            module.freeze = value

    def __setattr__(self, key, value):
        """Set an attribute and register ``Module`` values as children.

        ``object.__setattr__`` honours data descriptors, so property setters
        such as ``training`` and ``freeze`` actually run. A raw
        ``self.__dict__[key] = value`` write is shadowed by the property on
        read and never reaches the setter. ``object`` is called explicitly
        (not ``super()``) so other bases' ``__setattr__`` stay bypassed.
        """
        object.__setattr__(self, key, value)
        if isinstance(value, Module):
            if self.enable_scope_name:
                value.parent_scope_name = key
            self._modules[key] = value

    def __getattribute__(self, name):
        """Get attribute ``name``, refreshing ``_scope_name`` of Module values."""
        value = object.__getattribute__(self, name)
        if isinstance(value, Module) and self.enable_scope_name:
            value._scope_name = "{}.{}".format(
                self._scope_name, value.parent_scope_name) if self._scope_name else value.parent_scope_name
        return value

    def set_parameters(self, name, value):
        """Set attribute ``name`` to a variable initialized from ``value``.

        The variable is obtained with ``tf.get_variable`` in the root variable
        scope with ``tf.AUTO_REUSE``, so an existing variable of that name is
        reused.
        """
        with tf.variable_scope('', reuse=tf.AUTO_REUSE):
            setattr(self, name, tf.get_variable(name, initializer=value))

    def get_weights(self, name):
        """Get the default-graph tensor named ``'<name>:0'``."""
        return tf.get_default_graph().get_tensor_by_name('{}:0'.format(name))

    def get_weight_ops(self, name):
        """Return trainable variables whose names do NOT start with ``name``."""
        # NOTE(review): this is an exclusion filter despite the method name.
        all_weight = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        weight_ops = [t for t in all_weight if not t.name.startswith(name)]
        return weight_ops

    def call(self, inputs, *args, **kwarg):
        """Feed ``inputs`` through the children sequentially."""
        output = inputs
        for model in self.children():
            output = model(output)
        return output

    def __call__(self, inputs, *args, **kwargs):
        """Delegate to ``call``."""
        return self.call(inputs, *args, **kwargs)

    def modules(self):
        """Return the direct children, or ``[self]`` when there are none."""
        if self._modules.values():
            return self._modules.values()
        else:
            return [self]