コード例 #1
0
    def build(cls, params, sub_cls=None, controller=None):
        """Build one ``py_data.DataLoader`` per entry of ``<prefix>_belongs``.

        Three parallel sequences are read from ``params`` with ``try_get_attr``
        (the fallback value passed to it is shown in parentheses):
        ``<prefix>_splits`` (``('train',)``), ``<prefix>_shuffles``
        (``(True,)``) and ``<prefix>_belongs`` (``('train',)``). For position
        ``i`` a dataset of the class named by ``<prefix>_cls`` is built with
        ``split=splits[i]``. That dataset is then wrapped in a DataLoader with
        ``shuffle=shuffles[i]`` and the dataset's own ``collate_fn``.

        Args:
            params: configuration object the settings are read from.
            sub_cls: unused; the dataset class is always resolved from
                ``<prefix>_cls``.
            controller: if not None, assigned to every dataset's
                ``controller`` attribute.

        Returns:
            dict mapping each ``belongs`` entry to its DataLoader. A repeated
            entry keeps the loader built last.

        Raises:
            ValueError: if ``<prefix>_cls`` is not set.
            IndexError: if ``splits`` or ``shuffles`` is shorter than
                ``belongs``.
        """
        prefix = cls.prefix_name()
        splits = try_get_attr(params, f'{prefix}_splits', ('train', ))
        shuffles = try_get_attr(params, f'{prefix}_shuffles', (True, ))
        belongs = try_get_attr(params, f'{prefix}_belongs', ('train', ))

        loader_kwargs = load_func_kwargs(params, py_data.DataLoader.__init__,
                                         prefix)
        init_kwargs = load_func_kwargs(params, cls.__init__, prefix)
        cls_name = getattr(params, f'{prefix}_cls', None)
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently pass ``None`` to ``load_cls``.
        if cls_name is None:
            raise ValueError(
                f'`{prefix}_cls` must name the dataset class to build')
        sub_cls = cls.load_cls(cls_name)

        loaders = {}
        for idx, belong in enumerate(belongs):
            init_kwargs['split'] = splits[idx]
            dataset = sub_cls(**init_kwargs)
            if controller is not None:
                dataset.controller = controller
            # Per-loader settings override anything collected from params
            # (same precedence as dict.update), without mutating the shared
            # loader_kwargs between iterations.
            per_loader_kwargs = {
                **loader_kwargs,
                'shuffle': shuffles[idx],
                'dataset': dataset,
                'collate_fn': dataset.collate_fn,
            }
            loaders[belong] = py_data.DataLoader(**per_loader_kwargs)
        return loaders
コード例 #2
0
ファイル: experiment.py プロジェクト: code4paper/ra_gcn
 def build(cls, params, sub_cls=None, controller=None):
     """Build ``cls`` from ``params``, forwarding ``SummaryWriter`` options.

     ``params`` is first converted with ``to_namespace``. Keyword arguments
     for ``SummaryWriter.__init__`` are collected with the class prefix and
     passed to the constructor as ``writer_kwargs``. The remaining
     constructor arguments are collected for ``cls.__init__``.

     Args:
         params: configuration accepted by ``to_namespace``.
         sub_cls: unused.
         controller: optional; when not None the new module is registered
             with it via ``controller.register_module``.

     Returns:
         The constructed module.
     """
     params = to_namespace(params)
     prefix = cls.prefix_name()
     writer_kwargs = load_func_kwargs(params, SummaryWriter.__init__, prefix)
     init_kwargs = load_func_kwargs(params, cls.__init__, prefix)
     init_kwargs.update({'writer_kwargs': writer_kwargs})
     module = cls(**init_kwargs)
     # ``controller`` defaults to None; dereferencing it unconditionally
     # would fail with AttributeError, so only register when one is given.
     if controller is not None:
         controller.register_module(module)
     return module
コード例 #3
0
 def build(cls, params, sub_cls=None, controller=None):
     """Build ``cls`` around the optimizer class named by ``<prefix>_type``.

     The name is looked up as an attribute of the ``optim`` module. Keyword
     arguments for that optimizer's ``__init__`` are collected from
     ``params`` with the class prefix and handed to ``cls`` as
     ``optim_kwargs``, alongside the arguments collected for
     ``cls.__init__``.

     Args:
         params: configuration object the settings are read from.
         sub_cls: unused.
         controller: unused.

     Returns:
         The constructed ``cls`` instance.

     Raises:
         ValueError: if ``<prefix>_type`` is unset, or does not name an
             attribute of ``optim``.
     """
     prefix = cls.prefix_name()
     opt_type = try_get_attr(params, f'{prefix}_type', None, check=False)
     # Check before the lookup: getattr() with a None attribute name raises
     # an unhelpful TypeError rather than reporting the missing setting.
     if opt_type is None:
         raise ValueError(f'missing optimizer type: set `{prefix}_type`')
     opt_cls = getattr(optim, opt_type, None)
     if opt_cls is None:
         raise ValueError(
             f'unknown optimizer type {opt_type!r} (from `{prefix}_type`)')
     optim_kwargs = load_func_kwargs(params, opt_cls.__init__, prefix)
     init_kwargs = load_func_kwargs(params, cls.__init__, prefix)
     init_kwargs.update({'optim_kwargs': optim_kwargs})
     return cls(**init_kwargs)
コード例 #4
0
ファイル: base.py プロジェクト: code4paper/ra_gcn
 def build(cls, params, sub_cls=None, controller=None):
     """Construct the configured layer stack and build ``cls`` around it.

     Layer class names are read from ``<prefix>_layer_names`` and resolved
     with ``Layer.load_cls``. For every name in the mapping returned by
     ``collect_layer_params``, a per-layer sequence is read from
     ``<prefix>_layer_<name>s``; lookups that return None are dropped. Layer
     ``i`` receives entry ``i`` of each remaining sequence. Those values are
     passed through ``load_func_kwargs`` against the layer's ``__init__``,
     presumably to keep only accepted parameters. The built layers are stored
     under ``<prefix>_layers`` and construction finishes in ``default_build``.

     Args:
         params: configuration object the settings are read from.
         sub_cls: unused.
         controller: forwarded to ``default_build``.
     """
     prefix = cls.prefix_name()
     names = try_get_attr(params, f'{prefix}_layer_names')
     layer_types = list(map(Layer.load_cls, names))
     param_table = cls.collect_layer_params(names)

     # Per-argument sequences, indexed by layer position below.
     per_layer_values = {}
     for arg_name in param_table.keys():
         values = try_get_attr(params, f'{prefix}_layer_{arg_name}s')
         if values is not None:
             per_layer_values[arg_name] = values

     layers = []
     for position, layer_type in enumerate(layer_types):
         candidate = {
             arg_name: values[position]
             for arg_name, values in per_layer_values.items()
         }
         accepted = load_func_kwargs(candidate, layer_type.__init__)
         layers.append(layer_type(**accepted))

     build_kwargs = load_func_params(params, cls.__init__, prefix)
     build_kwargs[f'{prefix}_layers'] = layers
     return cls.default_build(build_kwargs, controller=controller)
コード例 #5
0
 def default_build(cls, params, controller=None):
     """Instantiate ``cls`` from ``params`` and optionally attach a controller.

     Constructor keyword arguments are gathered with
     ``load_func_kwargs(params, cls.__init__, cls.prefix_name())``. When
     ``controller`` is not None it is assigned to ``module.controller``.

     NOTE: the module is not registered with the controller here; a
     ``controller.register_module(module)`` call existed only as
     commented-out code in this builder.
     """
     module = cls(**load_func_kwargs(params, cls.__init__, cls.prefix_name()))
     if controller is None:
         return module
     module.controller = controller
     return module
コード例 #6
0
 def forward(self, img_feats, q_feats):
     """Feed the features through ``self.layers`` in order.

     For each layer, ``{'img_feats': ..., 'q_feats': ...}`` is passed through
     ``load_func_kwargs`` against ``layer.forward``, presumably keeping only
     the arguments that layer accepts. The layer's output replaces
     ``img_feats`` for the next layer. ``q_feats`` is offered unchanged to
     every layer.

     Returns:
         The ``img_feats`` produced by the last layer, or the input
         ``img_feats`` if there are no layers.
     """
     for layer in self.layers:
         candidates = {'img_feats': img_feats, 'q_feats': q_feats}
         img_feats = layer(**load_func_kwargs(candidates, layer.forward))
     return img_feats
コード例 #7
0
ファイル: base.py プロジェクト: code4paper/ra_gcn
 def init_args(cls, params, sub_cls=None):
     """Fill in logger-group settings on ``params``.

     After ``cls.default_init_args(params)``, calls ``try_set_attr`` for
     ``<prefix>_name`` (``'logger_group'``) and for ``<prefix>_logger_dir``
     (``<root_dir>/loggers/<proj_name>``). If ``load_cls`` resolves a logger
     class from ``<prefix>_logger_cls``, two more steps run. First, that
     class's own ``init_args`` is called. Then the keyword arguments
     collected for its ``__init__`` are stored on ``params`` as
     ``<prefix>_logger_kwargs``.

     Args:
         params: configuration object that is updated in place.
         sub_cls: unused.
     """
     cls.default_init_args(params)
     prefix = cls.prefix_name()
     try_set_attr(params, f'{prefix}_name', 'logger_group')
     logger_dir = to_path(params.root_dir).joinpath(
         f'loggers/{params.proj_name}')
     try_set_attr(params, f'{prefix}_logger_dir', logger_dir)

     logger_cls_name = try_get_attr(params,
                                    f'{prefix}_logger_cls',
                                    check=False)
     logger_cls = cls.load_cls(logger_cls_name)
     if logger_cls is None:
         return
     logger_cls.init_args(params)
     logger_kwargs = load_func_kwargs(params, logger_cls.__init__, prefix)
     setattr(params, f'{prefix}_logger_kwargs', logger_kwargs)
コード例 #8
0
ファイル: base.py プロジェクト: code4paper/ra_gcn
 def get_input(cls, sample):
     """Return the keyword arguments for ``cls.forward`` taken from ``sample``.

     Delegates to ``load_func_kwargs(sample, cls.forward)``, presumably
     keeping only the entries that match ``forward``'s parameters.
     """
     return load_func_kwargs(sample, cls.forward)