Example No. 1
 def __init__(self):
     self.name = ''
     self.data_format = General.data_format
     self._modules = Config()
     self._parameters = OrderedDict()
     self._weights_buffer = OrderedDict()
     self._init_configs()
Example No. 2
 def __init__(self, config=None):
     """Initialize."""
     self.is_multi_opt = False
     if config is not None:
         self.config = Config(config)
     raw_config = self.config.to_dict()
     raw_config.type = self.config.type
     map_dict = OptimMappingDict
     self.map_config = ConfigBackendMapping(
         map_dict.type_mapping_dict,
         map_dict.params_mapping_dict).backend_mapping(raw_config)
     self.optim_cls = ClassFactory.get_cls(ClassType.OPTIMIZER,
                                           self.map_config.type)
Example No. 3
 def __init__(self, **desc):
     """Initialize."""
     super(SimpleCnn, self).__init__()
     desc = Config(**desc)
     self.num_class = desc.num_class
     self.fp16 = desc.get('fp16', False)
     self.channels = desc.channels
     self.conv1 = ops.Conv2d(3, 32, padding=1, kernel_size=3)
     self.pool1 = ops.MaxPool2d(2, stride=2)
     self.blocks = self._blocks(self.channels, desc.blocks)
     self.pool2 = ops.MaxPool2d(2, stride=2)
     self.conv2 = ops.Conv2d(self.channels, 64, padding=1, kernel_size=3)
     self.global_conv = ops.Conv2d(64, 64, kernel_size=8, padding=0)
     self.view = ops.View()
     self.fc = ops.Linear(64, self.num_class)
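
The snippet above treats `desc` both as an object (`desc.num_class`) and as a dict (`desc.get('fp16', False)`). A minimal sketch of those access patterns, assuming the vega package is installed, that `Config` lives under `vega.common.config` (the import path mentioned in the docstring of Example No. 13 below), and that keyword construction works as in `Config(**desc)`:

# Hedged sketch of the Config access patterns used above; the import path and
# keyword construction are assumptions based on the surrounding examples.
from vega.common.config import Config

desc = Config(num_class=10, channels=32, blocks=2)
print(desc.num_class)             # attribute access, as in self.num_class = desc.num_class
print(desc.get('fp16', False))    # dict-style get with a default, as in desc.get('fp16', False)
print(desc['channels'])           # item access, as used on Config objects in later examples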
Example No. 4
 def module_arch_params(self):
     """Get Arch params."""
     return Config({
         k.split('.')[-1]: v
         for k, v in ArchParams._values.items()
         if '.'.join(k.split('.')[:-1]) == self.name
     })
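
The comprehension above keeps only the arch params whose dotted prefix equals `self.name` and strips that prefix from the keys. The same filtering on plain dicts, as a standalone sketch with hypothetical parameter names:

# Hedged sketch of the prefix filtering in module_arch_params, using plain dicts
# and made-up parameter names.
values = {
    'backbone.conv1.channels': 32,
    'backbone.conv1.kernel': 3,
    'head.fc.units': 10,
}
name = 'backbone.conv1'
arch_params = {
    k.split('.')[-1]: v
    for k, v in values.items()
    if '.'.join(k.split('.')[:-1]) == name
}
print(arch_params)   # {'channels': 32, 'kernel': 3}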
Example No. 5
def _set_config(args, step_name, step_type):
    """Fully train."""
    # general
    General.step_name = step_name
    if hasattr(args, "general"):
        General.from_dict(args.general)
    # pipeline
    PipelineConfig.steps = [step_name]
    # pipestep
    PipeStepConfig.type = step_type
    # model
    if hasattr(args, "model"):
        if hasattr(args.model, "model_desc"):
            args.model.model_desc = Config(args.model.model_desc)
        PipeStepConfig.model.from_dict(args.model)
    # dataset
    if hasattr(args, "dataset"):
        PipeStepConfig.dataset.from_dict(args.dataset)
    # trainer
    if hasattr(args, "trainer"):
        TrainerConfig.from_dict(args.trainer)
    # evaluator
    if hasattr(args, "evaluator"):
        # PipeStepConfig.evaluator._type_name = args.evaluator
        if "HostEvaluator" in args.evaluator:
            PipeStepConfig.evaluator_enable = True
            PipeStepConfig.evaluator.host_evaluator_enable = True
        if "DeviceEvaluator" in args.evaluator:
            PipeStepConfig.evaluator_enable = True
            PipeStepConfig.evaluator.device_evaluator_enable = True
Example No. 6
    def __init__(self, **cfg):
        """Initialize method."""
        cfg = Config(cfg)
        self.use_cuda = cfg.use_cuda
        self.use_distributed = cfg.use_distributed
        self.SR_lr = cfg.SR_lr
        self.cyc_lr = cfg.cyc_lr
        super(CycleSRModel, self).__init__(cfg)
        self.max_norm = cfg.grad_clip
        self.loss_names.append("G_SR")
        self.loss_names.append("SR")
        self.loss_SR = 0
        self.loss_G_SR = 0
        self.SR_lam = cfg.SR_lam
        self.cycleSR_lam = cfg.cycleSR_lam
        logging.info("Now we are using CycleGan with SR")

        self.G_SR = None
        self.HR = None
        self.LR = None
        self.SR = None
        # add model names
        self.model_names.append("SR")
        self.netSR = define_SR(cfg.VDSR, self.use_cuda, self.use_distributed)
        self.criterionSR = torch.nn.MSELoss().cuda()
        # initialize optimizers
        self.optimizer_SR = torch.optim.Adam(self.netSR.parameters(),
                                             lr=cfg.SR_lr,
                                             betas=(0.5, 0.999))
Example No. 7
 def __init__(self, config=None):
     """Initialize."""
      # use the class-level default config when none is given
     if config:
         self.config = Config(config)
         raw_config = deepcopy(self.config)
     else:
         self.config = LrScheduler.config
         raw_config = self.config.to_dict()
     raw_config.type = self.config.type
     map_dict = LrSchedulerMappingDict()
     self.map_config = ConfigBackendMapping(
         map_dict.type_mapping_dict,
         map_dict.params_mapping_dict).backend_mapping(raw_config)
     self._cls = ClassFactory.get_cls(ClassType.LR_SCHEDULER,
                                      self.map_config.type)
Example No. 8
def run_pipeline(load_special_lib_func=None):
    """Run pipeline."""
    args = _parse_args()
    _resume(args)
    _set_backend(args)
    _append_env()
    if load_special_lib_func:
        load_special_lib_func(args.config_file)
    config = Config(args.config_file)
    # load general
    if config.get("general"):
        General.from_dict(config.get("general"), skip_check=False)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(General.TF_CPP_MIN_LOG_LEVEL)
    if General.requires and not verify_requires(General.requires):
        return
    dict_args = vars(args)
    dict_args = _check_parse(dict_args)
    config = _modify_config(dict_args, config)
    _backup_config(args)
    _change_process_name()
    vega.run(config)
Example No. 9
def _set_backend(args):
    backend = args.backend
    device = args.device
    if backend:
        if args.backend in ["pytorch", "p"]:
            backend = "pytorch"
        elif args.backend in ["tensorflow", "t"]:
            backend = "tensorflow"
        elif args.backend in ["mindspore", "m"]:
            backend = "mindspore"
    else:
        config = Config(args.config_file)
        if "general" in config and "backend" in config["general"]:
            backend = config["general"]["backend"]
    if not device:
        config = Config(args.config_file)
        if "general" in config and "device_category" in config["general"]:
            device = config["general"]["device_category"]
    if backend:
        General.backend = backend
    if device:
        General.device_category = device
    vega.set_backend(General.backend, General.device_category)
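
The resolution order above is: an explicit command-line value wins, otherwise the `general` section of the config file is consulted, and single-letter aliases are expanded. A standalone sketch of that precedence (the function name and inputs are hypothetical):

# Hedged sketch of the precedence used in _set_backend: CLI flag > config file.
def resolve_backend(cli_backend, config):
    aliases = {"p": "pytorch", "t": "tensorflow", "m": "mindspore"}
    if cli_backend:
        return aliases.get(cli_backend, cli_backend)
    return config.get("general", {}).get("backend")

print(resolve_backend("t", {}))                                      # tensorflow
print(resolve_backend(None, {"general": {"backend": "mindspore"}}))  # mindspore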
Example No. 10
 def _create_examples(self, lines, set_type):
     """Create examples for the training, dev and test sets."""
     examples = []
     for (i, line) in enumerate(lines):
         if i == 0:
             continue
         guid = "%s-%s" % (set_type, i)
         text_a = line[3]
         text_b = line[4]
         label = None if set_type == "test" else line[0]
         examples.append(
             Config(
                 dict(guid=guid, text_a=text_a, text_b=text_b,
                      label=label)))
     return examples
Example No. 11
 def __new__(cls, *args, **kwargs):
     """Record params."""
     desc = {}
     params_sig = sig(cls.__init__).parameters
     param_names = list(params_sig.keys())
     if len(param_names) > len(args):
          # no dynamic parameters (e.g. connections); map positional args to their names
         for idx, arg in enumerate(args):
             arg_name = param_names[idx + 1]
             desc[arg_name] = arg
     if kwargs:
         desc.update(kwargs)
     instance = super(Serializable, cls).__new__(cls)
     instance.desc = Config(desc) if desc else {}
     return instance
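
`sig` above is `inspect.signature`; the loop pairs each positional argument with the matching `__init__` parameter name, skipping `self`. A runnable sketch of that pairing on a plain class (the class and its parameters are hypothetical):

# Hedged sketch: record positional args against __init__ parameter names
# via inspect.signature, as Serializable.__new__ does above.
from inspect import signature as sig

class Block:
    def __init__(self, in_channels, out_channels, stride=1):
        pass

args = (16, 32)
param_names = list(sig(Block.__init__).parameters.keys())   # ['self', 'in_channels', 'out_channels', 'stride']
desc = {param_names[idx + 1]: arg for idx, arg in enumerate(args)}   # idx + 1 skips 'self'
print(desc)   # {'in_channels': 16, 'out_channels': 32}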
Example No. 12
class LrScheduler(object):
    """Register and call LrScheduler class."""

    config = LrSchedulerConfig()

    def __init__(self, config=None):
        """Initialize."""
        # use the class-level default config when none is given
        if config:
            self.config = Config(config)
            raw_config = deepcopy(self.config)
        else:
            self.config = LrScheduler.config
            raw_config = self.config.to_dict()
        raw_config.type = self.config.type
        map_dict = LrSchedulerMappingDict()
        self.map_config = ConfigBackendMapping(
            map_dict.type_mapping_dict,
            map_dict.params_mapping_dict).backend_mapping(raw_config)
        self._cls = ClassFactory.get_cls(ClassType.LR_SCHEDULER,
                                         self.map_config.type)

    def __call__(self, optimizer=None, epochs=None, steps=None):
        """Call lr scheduler class."""
        params = self.map_config.get("params", {})
        logging.debug("Call LrScheduler. name={}, params={}".format(
            self._cls.__name__, params))

        setattr(self._cls, "by_epoch", True)
        if hasattr(self.config, "by_epoch"):
            setattr(self._cls, "by_epoch", self.config.by_epoch)

        try:
            if params:
                return self._cls(optimizer, **params)
            else:
                return self._cls(optimizer)
        except Exception as ex:
            logging.error(
                "Failed to call LrScheduler name={}, params={}".format(
                    self._cls.__name__, params))
            raise ex
Example No. 13
    def _decode_hps(hps):
        """Decode hps: `trainer.optim.lr : 0.1` to dict format.

        And convert it to a `vega.common.config.Config` object.
        This Config will be overridden in the Trainer or Dataset classes.
        The override priority is: input hps > user configuration > default configuration.
        :param hps: hyper params
        :return: dict
        """
        hps_dict = {}
        if hps is None:
            return None
        if isinstance(hps, tuple):
            return hps
        for hp_name, value in hps.items():
            hp_dict = {}
            for key in list(reversed(hp_name.split('.'))):
                if hp_dict:
                    hp_dict = {key: hp_dict}
                else:
                    hp_dict = {key: value}
            # update cfg with hps
            hps_dict = update_dict(hps_dict, hp_dict, [])
        return Config(hps_dict)
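
The inner loop rebuilds a nested dict from a dotted key by folding the parts from right to left, and `update_dict` merges each result into the accumulator. A self-contained sketch without the vega helpers; the recursive `merge` below is an assumption about what `update_dict` does:

# Hedged sketch of the dotted-key expansion in _decode_hps, on plain dicts.
def expand(hps):
    result = {}
    for hp_name, value in hps.items():
        hp_dict = {}
        for key in reversed(hp_name.split('.')):
            hp_dict = {key: hp_dict} if hp_dict else {key: value}
        merge(result, hp_dict)
    return result

def merge(dst, src):
    # minimal recursive merge standing in for vega's update_dict (assumption)
    for k, v in src.items():
        if isinstance(v, dict) and isinstance(dst.get(k), dict):
            merge(dst[k], v)
        else:
            dst[k] = v

print(expand({'trainer.optim.lr': 0.1, 'trainer.epochs': 50}))
# {'trainer': {'optim': {'lr': 0.1}, 'epochs': 50}}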
Example No. 14
 def _init_hps(self, hps=None):
     """Load hps from file."""
     # load config
     if hps is not None:
         pass
     elif self.config.hps_file is not None:
         desc_file = self.config.hps_file.replace("{local_base_path}",
                                                  self.local_base_path)
         hps = Config(desc_file)
         if "trainer" in hps:
             if "epochs" in hps["trainer"]:
                 hps["trainer"].pop("epochs")
             if "checkpoint_path" in hps["trainer"]:
                 hps["trainer"].pop("checkpoint_path")
     elif self.config.hps_folder is not None:
         folder = self.config.hps_folder.replace("{local_base_path}",
                                                 self.local_base_path)
         pattern = FileOps.join_path(folder, "hps_*.json")
         desc_file = glob.glob(pattern)[0]
         hps = Config(desc_file)
         if "trainer" in hps:
             if "epochs" in hps["trainer"]:
                 hps["trainer"].pop("epochs")
             if "checkpoint_path" in hps["trainer"]:
                 hps["trainer"].pop("checkpoint_path")
     # merge config
     if not self.hps:
         self.hps = hps
     elif hps:
         hps.from_dict(self.hps)
         self.hps = hps
     # set config
     if self.hps and self.hps.get('trainer'):
         self.config.from_dict(self.hps.get('trainer'))
         self.load_checkpoint = self.config.load_checkpoint
     self.epochs = self.config.epochs
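
The hps file is located by expanding the `{local_base_path}` placeholder and then taking the first `hps_*.json` match; a standalone sketch of that path resolution with `FileOps.join_path` replaced by `os.path.join` (paths are hypothetical):

# Hedged sketch of the hps file lookup in _init_hps; paths are made up.
import glob
import os

local_base_path = "/tmp/vega_workdir"
hps_folder = "{local_base_path}/hps".replace("{local_base_path}", local_base_path)
pattern = os.path.join(hps_folder, "hps_*.json")
matches = glob.glob(pattern)
desc_file = matches[0] if matches else None   # the original assumes at least one match
print(pattern, desc_file)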
Example No. 15
    def convert_examples_to_features(self, examples, label_list,
                                     max_seq_length, tokenizer):
        """Load a data file into a list of `InputBatch`s."""
        label_map = {label: i for i, label in enumerate(label_list)}
        features = []
        for (ex_index, example) in enumerate(examples):
            tokens_a = tokenizer.tokenize(example.text_a)

            tokens_b = None
            if example.text_b:
                tokens_b = tokenizer.tokenize(example.text_b)
                # Modifies `tokens_a` and `tokens_b` in place so that the total
                # length is less than the specified length.
                # Account for [CLS], [SEP], [SEP] with "- 3"
                _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
            else:
                # Account for [CLS] and [SEP] with "- 2"
                if len(tokens_a) > max_seq_length - 2:
                    tokens_a = tokens_a[:(max_seq_length - 2)]

            # The convention in BERT is:
            # (a) For sequence pairs:
            #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
            #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
            # (b) For single sequences:
            #  tokens:   [CLS] the dog is hairy . [SEP]
            #  type_ids: 0   0   0   0  0     0 0
            #
            # Where "type_ids" are used to indicate whether this is the first
            # sequence or the second sequence. The embedding vectors for `type=0` and
            # `type=1` were learned during pre-training and are added to the wordpiece
            # embedding vector (and position vector). This is not *strictly* necessary
            # since the [SEP] token unambiguously separates the sequences, but it makes
            # it easier for the model to learn the concept of sequences.
            #
            # For classification tasks, the first vector (corresponding to [CLS]) is
            #  used as the "sentence vector". Note that this only makes sense because
            # the entire model is fine-tuned.
            tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
            segment_ids = [0] * len(tokens)

            if tokens_b:
                tokens += tokens_b + ["[SEP]"]
                segment_ids += [1] * (len(tokens_b) + 1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding = [0] * (max_seq_length - len(input_ids))
            input_ids += padding
            input_mask += padding
            segment_ids += padding

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            label_id = label_map[example.label]
            if ex_index < 5:
                logging.info("*** Example ***")
                logging.info("guid: %s" % (example.guid))
                logging.info("tokens: %s" % " ".join([str(x) for x in tokens]))
                logging.info("input_ids: %s" %
                             " ".join([str(x) for x in input_ids]))
                logging.info("input_mask: %s" %
                             " ".join([str(x) for x in input_mask]))
                logging.info("segment_ids: %s" %
                             " ".join([str(x) for x in segment_ids]))
                logging.info("label: %s (id = %d)" % (example.label, label_id))

            features.append(
                Config(
                    dict(input_ids=input_ids,
                         input_mask=input_mask,
                         segment_ids=segment_ids,
                         label_id=label_id)))
        return features
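
For a single sentence the steps above reduce to: wrap the tokens in `[CLS]`/`[SEP]`, build the attention mask, and zero-pad everything to `max_seq_length`. A minimal sketch with a toy vocabulary standing in for the tokenizer:

# Hedged sketch of the single-sentence path above; the vocabulary is made up.
vocab = {"[PAD]": 0, "[CLS]": 101, "[SEP]": 102, "the": 5, "dog": 6, "is": 7, "hairy": 8}
max_seq_length = 8

tokens = ["[CLS]"] + ["the", "dog", "is", "hairy"] + ["[SEP]"]
segment_ids = [0] * len(tokens)
input_ids = [vocab[t] for t in tokens]
input_mask = [1] * len(input_ids)

padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
input_mask += padding
segment_ids += padding

assert len(input_ids) == len(input_mask) == len(segment_ids) == max_seq_length
print(input_ids)    # [101, 5, 6, 7, 8, 102, 0, 0]
print(input_mask)   # [1, 1, 1, 1, 1, 1, 0, 0]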
Example No. 16
def _parse_args(sections, desc):
    parser = argment_parser(desc)
    parser.add_argument("-backend",
                        "--general.backend",
                        default="pytorch",
                        type=str,
                        help="pytorch|tensorflow|mindspore")
    if "cluster" in sections:
        parser.add_argument("-devices_per_trainer",
                            "--general.worker.devices_per_trainer",
                            default=None,
                            type=int)
        parser.add_argument("-master_ip",
                            "--general.cluster.master_ip",
                            default=None,
                            type=str)
        parser.add_argument("-slaves",
                            "--general.cluster.slaves",
                            default=[],
                            action='store',
                            dest='general.cluster.slaves',
                            type=str,
                            nargs='*',
                            help="slave IP list")
    parser.add_argument("-dataset",
                        "--dataset.type",
                        required=True,
                        type=str,
                        help="dataset name.")
    parser.add_argument("-data_path",
                        "--dataset.common.data_path",
                        type=str,
                        help="dataset path.")
    parser.add_argument("-batch_size",
                        "--dataset.common.batch_size",
                        default=256,
                        type=int)
    if "model" in sections:
        parser.add_argument("-model_desc", "--model.model_desc", type=str)
        parser.add_argument("-model_file",
                            "--model.pretrained_model_file",
                            type=str)
    if "trainer" in sections:
        parser.add_argument("-epochs", "--trainer.epochs", type=int)
    if "fine_tune" in sections:
        parser.add_argument(
            "-task_type",
            "--task_type",
            default="classification",
            type=str,
            help="classification|detection|segmentation|super_resolution")
        parser.add_argument("-num_classes", "--trainer.num_classes", type=int)
    parser.add_argument(
        "-evaluator",
        "--evaluator",
        default=[],
        action='store',
        dest='evaluator',
        type=str,
        nargs='*',
        help="evaluator list, eg. -evaluator HostEvaluator DeviceEvaluator")
    args = vars(parser.parse_args())
    args = {key: value for key, value in args.items() if args[key]}
    tree = Config(build_tree(args))
    return tree
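
The dotted argparse destinations (for example `dataset.common.data_path`) are turned into a nested tree before being wrapped in `Config`. The `build_tree` below is a reconstruction of what that helper appears to do, written only from the dotted keys above, so treat it as an assumption rather than vega's actual implementation:

# Hedged sketch: expand dotted argparse dests into a nested dict.
def build_tree(flat):
    tree = {}
    for dotted, value in flat.items():
        node = tree
        *parents, leaf = dotted.split('.')
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree

args = {'general.backend': 'pytorch', 'dataset.type': 'Cifar10',
        'dataset.common.batch_size': 256}
print(build_tree(args))
# {'general': {'backend': 'pytorch'},
#  'dataset': {'type': 'Cifar10', 'common': {'batch_size': 256}}}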
Example No. 17
    def backend_mapping(self, config):
        """Map config to specific backend.

        :param config: original config from config file
        :type config: Config or dict
        :return: config after mapping to backend
        :rtype: Config
        """
        origin_config = Config(copy.deepcopy(config))
        type = origin_config.type

        if type not in self.type_mapping_dict:
            return config
        params = origin_config.get('params', {})
        backend_config = Config()
        backend_config.type = self.type_mapping_dict[type][self.backend_type]
        backend_config.params = Config()

        mapping_params = self.params_mapping_dict.get(type, {})
        for key, value in params.items():
            if key in mapping_params:
                mapping_key = mapping_params[key][self.backend_type]
            else:
                mapping_key = None
            if mapping_key is not None:
                if isinstance(value, dict) and 'type' in value:
                    backend_config.params[mapping_key] = self.backend_mapping(
                        value)
                else:
                    backend_config.params[mapping_key] = value

        return Config(backend_config)
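
The mapping walks the generic params, renames each key to its backend-specific counterpart, and recurses when a value is itself a typed sub-config. A compact sketch of the same renaming with toy mapping tables and plain dicts (all table contents are hypothetical):

# Hedged sketch of the key renaming in backend_mapping; the tables are made up.
type_mapping = {'Adam': {'torch': 'Adam', 'tf': 'AdamOptimizer'}}
params_mapping = {'Adam': {'lr': {'torch': 'lr', 'tf': 'learning_rate'}}}
backend = 'tf'

def map_config(config):
    if config['type'] not in type_mapping:
        return config
    mapped = {'type': type_mapping[config['type']][backend], 'params': {}}
    for key, value in config.get('params', {}).items():
        new_key = params_mapping[config['type']].get(key, {}).get(backend)
        if new_key is None:
            continue   # unmapped keys are dropped, as in the original
        if isinstance(value, dict) and 'type' in value:
            mapped['params'][new_key] = map_config(value)
        else:
            mapped['params'][new_key] = value
    return mapped

print(map_config({'type': 'Adam', 'params': {'lr': 0.001}}))
# {'type': 'AdamOptimizer', 'params': {'learning_rate': 0.001}}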
Example No. 18
class Optimizer(object):
    """Register and call Optimizer class."""

    config = OptimConfig()

    def __new__(cls, *args, **kwargs):
        """Create optimizer or multi-optimizer class."""
        if isinstance(cls.config.to_dict, list):
            t_cls = ClassFactory.get_cls(ClassType.OPTIMIZER,
                                         'MultiOptimizers')
            return super().__new__(t_cls)
        return super().__new__(cls)

    def __init__(self, config=None):
        """Initialize."""
        self.is_multi_opt = False
        if config is not None:
            self.config = Config(config)
        raw_config = self.config.to_dict()
        raw_config.type = self.config.type
        map_dict = OptimMappingDict
        self.map_config = ConfigBackendMapping(
            map_dict.type_mapping_dict,
            map_dict.params_mapping_dict).backend_mapping(raw_config)
        self.optim_cls = ClassFactory.get_cls(ClassType.OPTIMIZER,
                                              self.map_config.type)

    def __call__(self, model=None, distributed=False, **kwargs):
        """Call Optimizer class.

        :param model: model, used in torch case
        :param distributed: use distributed
        :return: optimizer
        """
        params = self.map_config.get("params", {})
        logging.debug("Call Optimizer. name={}, params={}".format(
            self.optim_cls.__name__, params))
        optimizer = None
        try:
            if vega.is_torch_backend():
                learnable_params = [
                    param for param in model.parameters()
                    if param.requires_grad
                ]
                optimizer = self.optim_cls(learnable_params, **params)
                if distributed:
                    optimizer = self.set_distributed(optimizer, model)
            elif vega.is_tf_backend():
                optimizer = dynamic_optimizer(self.optim_cls, **params)
            elif vega.is_ms_backend():
                if "dynamic_lr" in kwargs:
                    params.update({"learning_rate": kwargs["dynamic_lr"]})
                learnable_params = [
                    param for param in model.trainable_params()
                    if param.requires_grad
                ]
                optimizer = self.optim_cls(learnable_params, **params)
            return optimizer
        except Exception as ex:
            logging.error("Failed to call Optimizer name={}, params={}".format(
                self.optim_cls.__name__, params))
            raise ex

    @classmethod
    def set_distributed(cls, optimizer, model=None):
        """Set distributed optimizer."""
        if vega.is_torch_backend():
            optimizer = hvd.DistributedOptimizer(
                optimizer,
                named_parameters=model.named_parameters(),
                compression=hvd.Compression.none)
        elif vega.is_tf_backend():
            optim_class = hvd.DistributedOptimizer if vega.is_gpu_device(
            ) else NPUDistributedOptimizer
            optimizer = dynamic_distributed_optimizer(optim_class, optimizer)
        return optimizer
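
In the torch branch, only parameters with `requires_grad` set are handed to the optimizer class. A short sketch of that filtering, assuming PyTorch is installed (the model and hyperparameters are arbitrary):

# Hedged sketch of the torch branch of Optimizer.__call__: filter learnable params.
import torch

model = torch.nn.Linear(4, 2)
model.bias.requires_grad = False                 # freeze one parameter for illustration
learnable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(learnable_params, lr=0.01)
print(len(learnable_params))                     # 1 (only the weight remains)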
Example No. 19
 def set_arch_params(self, value):
     """Set Arch params."""
     ArchParams._values = Config(value)
Example No. 20
class Module(object):
    """Base Module to adapter tf Module."""
    def __init__(self):
        self.name = ''
        self.data_format = General.data_format
        self._modules = Config()
        self._parameters = OrderedDict()
        self._weights_buffer = OrderedDict()
        self._init_configs()

    def _init_configs(self):
        self._training = True
        self._trainable = True
        self.weight_file = None
        self.from_weight_type = None
        self._is_load_pretrained = False
        self.exclude_weight_prefix = None
        self._pre_hooks = OrderedDict()
        self._call_hooks = OrderedDict()

    def add_module(self, name, model):
        """Add models into self._models."""
        setattr(self, str(name), model)

    def build(self):
        """Build model or params."""
        pass

    def register_forward_pre_hook(self, hook):
        """Register pre hook."""
        self._pre_hooks[hook.__name__] = hook

    def register_forward_hook(self, hook):
        """Register call hook."""
        self._call_hooks[hook.__name__] = hook

    def named_modules(self):
        """Return names spaces."""
        self._apply_names()
        _modules = []
        for module in self.children():
            _modules.append((module.name, module))
            _modules.extend(module.named_modules())
        return _modules

    def named_children(self):
        """Return names children."""
        return [(name, module) for name, module in self._modules.items()]

    def children(self):
        """Get child models of current Module."""
        for model in self._modules.values():
            yield model

    def load_checkpoint(self, weight_file):
        """Load weight state dict from last checkpoint file."""
        if not weight_file:
            return
        logging.info("Load checkpoint form file ({}).".format(weight_file))
        # model_file = tf.train.latest_checkpoint(weight_file)
        reader = tf.train.NewCheckpointReader(weight_file)
        variables = reader.get_variable_to_shape_map()
        states = {v: reader.get_tensor(v) for v in variables}
        self.load_checkpoint_from_numpy(states)

    def load_checkpoint_from_numpy(self, states):
        """Load checkpoint from numpy."""
        states = self._exclude_checkpoint_by_prefix(states)
        for name, module in self.named_modules():
            child_state = [(k, v) for k, v in states.items()
                           if k.startswith(module.name + '/')]
            for k, v in child_state:
                module.set_weights(k, v)

    def _exclude_checkpoint_by_prefix(self, states):
        if self.exclude_weight_prefix:
            if not isinstance(self.exclude_weight_prefix, list):
                self.exclude_weight_prefix = [self.exclude_weight_prefix]
            for prefix in self.exclude_weight_prefix:
                states = {
                    k: v
                    for k, v in states.items() if not k.startswith(prefix)
                }
        return states

    def set_weights(self, name, value):
        """Set weights into weights buffer."""
        self._weights_buffer[name] = value

    @property
    def training(self):
        """Get training flag."""
        return self._training

    @training.setter
    def training(self, value):
        """Set training flag."""
        self._training = value
        for module in self.children():
            module.training = value

    def freeze(self):
        """Set training flag."""
        self._trainable = False
        for module in self.children():
            module.freeze()

    def __setattr__(self, key, value):
        """Set name to modules."""
        super().__setattr__(key, value)
        if isinstance(value, Module):
            self._modules[key] = value

    def set_parameters(self, name, value):
        """Set Parameters."""
        self._parameters[name] = value
        setattr(self, name, value)
        return self.name

    def get_weights(self, name=None):
        """Get weights by name."""
        if name is None:
            return self._weights_buffer
        else:
            return tf.get_default_graph().get_tensor_by_name(
                '{}:0'.format(name))

    def get_all_weights(self):
        """Get all weights."""
        all_weights = OrderedDict()
        for child in self.children():
            all_weights.update(child._weights_buffer)
            if isinstance(child, Module):
                all_weights.update(child.get_all_weights())
        return all_weights

    def get_weight_ops(self, name):
        """Get weight ops."""
        all_weight = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        weight_ops = [t for t in all_weight if not t.name.startswith(name)]
        return weight_ops

    def call(self, inputs, *args, **kwarg):
        """Call inputs."""
        output = inputs
        for model in self.children():
            output = model(output)
        return output

    def _apply_names(self, parent_name=''):
        """Apply names spaces."""
        for scope_name, module in self._modules.items():
            scope_name = '{}.{}'.format(
                parent_name, scope_name) if parent_name else scope_name
            module.name = module.name or scope_name + '/' + module.__class__.__name__
            module._apply_names(scope_name)

    def _apply_parameters(self):
        """Apply names spaces."""
        for name, params in self._parameters.items():
            setattr(
                self, name,
                tf.Variable(params,
                            name='{}.{}'.format(self.name, name)
                            if self.name else name))

    def __call__(self, inputs, *args, **kwargs):
        """Call call function."""
        self.build()
        self._apply_parameters()
        self._apply_names()
        for module in self.children():
            module._is_load_pretrained = True
        if self.training:
            for hook_name, hook in self._pre_hooks.items():
                inputs = hook(self, inputs) or inputs
        out = self.call(inputs, *args, **kwargs)
        if self.training:
            for hook_name, hook in self._call_hooks.items():
                out = hook(self, inputs, out) or out
        self._apply_weights()
        return out

    def _apply_weights(self):
        if not self._weights_buffer:
            return
        variables = tf.get_collection(tf.GraphKeys.VARIABLES)
        if isinstance(self, Conv2d):
            self._weights_buffer = {
                k.replace('/weights', '/kernel'): v
                for k, v in self._weights_buffer.items()
            }
        values = [(var, self._weights_buffer.get(var.name.replace(':0', '')))
                  for var in variables
                  if var.name.replace(':0', '') in self._weights_buffer]
        for v, weight in values:
            v._initializer_op = state_ops.assign(v, weight)
        self._weights_buffer.clear()

    def modules(self):
        """Get the current modules."""
        if self._modules.values():
            return self._modules.values()
        else:
            return [self]
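
Because `__setattr__` records every `Module`-typed attribute in `self._modules`, child modules are registered simply by assignment, and `children()`/`named_modules()` walk that registry. A minimal sketch of the registration pattern, detached from TensorFlow (the classes are hypothetical stand-ins):

# Hedged sketch of the attribute-based child registration used by Module above.
class TinyModule:
    def __init__(self):
        object.__setattr__(self, '_modules', {})   # bypass __setattr__ for the registry itself
        self.name = ''

    def __setattr__(self, key, value):
        object.__setattr__(self, key, value)
        if isinstance(value, TinyModule):
            self._modules[key] = value

    def children(self):
        yield from self._modules.values()

class Block(TinyModule):
    pass

root = TinyModule()
root.conv1 = Block()    # registered via __setattr__
root.pool1 = Block()
print(list(root._modules))   # ['conv1', 'pool1']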
Example No. 21
 def __init__(self, config=None):
     """Initialize."""
     self.is_multi_opt = True
     if config is not None:
         self.config = Config(config)
     self._opts = OrderedDict()