def build(cls, params, sub_cls=None, controller=None):
    """Build one ``py_data.DataLoader`` per entry of ``<prefix>_belongs``.

    ``<prefix>`` is ``cls.prefix_name()``. The sequences ``<prefix>_splits``,
    ``<prefix>_shuffles`` and ``<prefix>_belongs`` are read from ``params``
    (defaults ``('train', )``, ``(True, )`` and ``('train', )``) and indexed
    in parallel: entry ``i`` of ``belongs`` uses ``splits[i]`` and
    ``shuffles[i]``.

    Args:
        params: Configuration object carrying the ``<prefix>_*`` attributes.
        sub_cls: Ignored. The dataset class is always resolved from
            ``<prefix>_cls`` through ``cls.load_cls``.
        controller: Optional object stored on every created dataset as
            ``dataset.controller``.

    Returns:
        dict: Maps each ``belongs`` entry to its loader. A repeated entry
        overwrites the loader stored for the earlier one.

    Raises:
        ValueError: If ``params`` has no ``<prefix>_cls`` attribute or it is
            ``None``.
        IndexError: If ``splits`` or ``shuffles`` is shorter than ``belongs``
            (for indexable sequences such as tuples or lists).
    """
    prefix = cls.prefix_name()
    splits = try_get_attr(params, f'{prefix}_splits', ('train', ))
    shuffles = try_get_attr(params, f'{prefix}_shuffles', (True, ))
    belongs = try_get_attr(params, f'{prefix}_belongs', ('train', ))
    loader_kwargs = load_func_kwargs(params, py_data.DataLoader.__init__,
                                     prefix)
    init_kwargs = load_func_kwargs(params, cls.__init__, prefix)

    cls_name = getattr(params, f'{prefix}_cls', None)
    # Explicit raise instead of ``assert``: assertions are stripped under
    # ``python -O`` and a missing class name must never reach ``load_cls``.
    if cls_name is None:
        raise ValueError(
            f"'{prefix}_cls' must be set in params to build data loaders")
    sub_cls = cls.load_cls(cls_name)

    loaders = {}
    for idx, belong in enumerate(belongs):
        # Both kwargs dicts are reused across iterations; only the
        # per-loader entries are overwritten each time.
        init_kwargs['split'] = splits[idx]
        dataset = sub_cls(**init_kwargs)
        if controller is not None:
            dataset.controller = controller
        loader_kwargs.update({
            'shuffle': shuffles[idx],
            'dataset': dataset,
            'collate_fn': dataset.collate_fn,
        })
        loaders[belong] = py_data.DataLoader(**loader_kwargs)
    return loaders
def build(cls, params, sub_cls=None, controller=None):
    """Instantiate ``cls`` with ``SummaryWriter`` keyword arguments attached.

    ``params`` is first converted with ``to_namespace``. Keyword arguments
    for ``SummaryWriter.__init__`` and for ``cls.__init__`` are then
    collected from it using the ``cls.prefix_name()`` prefix, and the writer
    arguments are handed to the constructor as ``writer_kwargs``.

    Args:
        params: Configuration accepted by ``to_namespace``.
        sub_cls: Unused.
        controller: Optional object; when given, the new module is passed to
            ``controller.register_module``.

    Returns:
        The newly created ``cls`` instance.
    """
    params = to_namespace(params)
    prefix = cls.prefix_name()
    writer_kwargs = load_func_kwargs(params, SummaryWriter.__init__, prefix)
    init_kwargs = load_func_kwargs(params, cls.__init__, prefix)
    init_kwargs['writer_kwargs'] = writer_kwargs
    module = cls(**init_kwargs)
    # ``controller`` defaults to None, so registration has to be optional;
    # calling ``None.register_module`` would raise AttributeError.
    if controller is not None:
        controller.register_module(module)
    return module
def build(cls, params, sub_cls=None, controller=None):
    """Instantiate ``cls`` for the optimizer class named in ``params``.

    The optimizer class is looked up by name as an attribute of ``optim``
    using the value of ``<prefix>_type``, where ``<prefix>`` is
    ``cls.prefix_name()``. Keyword arguments for that class's ``__init__``
    are collected from ``params`` and passed to ``cls`` as ``optim_kwargs``.

    Args:
        params: Configuration object carrying the ``<prefix>_*`` attributes.
        sub_cls: Unused.
        controller: Unused.

    Returns:
        The newly created ``cls`` instance.

    Raises:
        ValueError: If ``<prefix>_type`` is unset, or names nothing
            available on ``optim``.
    """
    prefix = cls.prefix_name()
    opt_type = try_get_attr(params, f'{prefix}_type', None, check=False)
    # Validate explicitly: ``getattr(optim, None, None)`` fails with an
    # unhelpful TypeError, and an ``assert`` would be stripped under
    # ``python -O``.
    if opt_type is None:
        raise ValueError(
            f"'{prefix}_type' must be set in params to build an optimizer")
    opt_cls = getattr(optim, opt_type, None)
    if opt_cls is None:
        raise ValueError(
            f"unknown optimizer type {opt_type!r} given by '{prefix}_type'")

    optim_kwargs = load_func_kwargs(params, opt_cls.__init__, prefix)
    init_kwargs = load_func_kwargs(params, cls.__init__, prefix)
    init_kwargs['optim_kwargs'] = optim_kwargs
    return cls(**init_kwargs)
def build(cls, params, sub_cls=None, controller=None):
    """Create the configured layers, then build ``cls`` around them.

    Layer class names come from ``<prefix>_layer_names`` (``<prefix>`` is
    ``cls.prefix_name()``) and are resolved with ``Layer.load_cls``. For
    every name reported by ``cls.collect_layer_params``, the attribute
    ``<prefix>_layer_<name>s`` is read from ``params``; values that are not
    ``None`` are indexed by layer position to form that layer's arguments.

    Args:
        params: Configuration object carrying the ``<prefix>_*`` attributes.
        sub_cls: Unused.
        controller: Forwarded to ``cls.default_build``.

    Returns:
        Whatever ``cls.default_build`` returns for the assembled arguments.
    """
    prefix = cls.prefix_name()
    names = try_get_attr(params, f'{prefix}_layer_names')
    classes = [Layer.load_cls(name) for name in names]

    # Keep only the per-layer argument sequences that are actually set.
    per_layer_values = {}
    for arg_name in cls.collect_layer_params(names).keys():
        values = try_get_attr(params, f'{prefix}_layer_{arg_name}s')
        if values is not None:
            per_layer_values[arg_name] = values

    layers = []
    for position, layer_cls in enumerate(classes):
        # Pick this layer's slice out of every argument sequence.
        raw_kwargs = {
            arg_name: values[position]
            for arg_name, values in per_layer_values.items()
        }
        layers.append(
            layer_cls(**load_func_kwargs(raw_kwargs, layer_cls.__init__)))

    # NOTE(review): this calls load_func_params, not load_func_kwargs as
    # used for the layers above -- confirm that is intended.
    kwargs = load_func_params(params, cls.__init__, prefix)
    kwargs[f'{prefix}_layers'] = layers
    return cls.default_build(kwargs, controller=controller)
def default_build(cls, params, controller=None):
    """Instantiate ``cls`` from ``params`` and optionally attach a controller.

    Args:
        params: Source of the constructor arguments, gathered with
            ``load_func_kwargs`` for ``cls.__init__`` under the
            ``cls.prefix_name()`` prefix.
        controller: Optional object stored as ``instance.controller``.

    Returns:
        The newly created ``cls`` instance.
    """
    ctor_kwargs = load_func_kwargs(params, cls.__init__, cls.prefix_name())
    instance = cls(**ctor_kwargs)
    if controller is None:
        return instance
    # NOTE: the instance is only linked to the controller here; it is not
    # passed to ``controller.register_module``.
    instance.controller = controller
    return instance
def forward(self, img_feats, q_feats):
    """Pass ``img_feats`` through every layer in ``self.layers`` in order.

    For each layer, the current image features and ``q_feats`` are offered
    under the names ``img_feats`` and ``q_feats``. ``load_func_kwargs``
    decides which of them are passed on, based on ``layer.forward``
    (presumably by matching its parameter names -- confirm against the
    helper). The layer's output becomes the image features for the next
    layer.

    Args:
        img_feats: Initial image features.
        q_feats: Features offered unchanged to every layer.

    Returns:
        The output of the last layer, or ``img_feats`` unchanged when
        ``self.layers`` is empty.
    """
    feats = img_feats
    for layer in self.layers:
        candidates = dict(img_feats=feats, q_feats=q_feats)
        feats = layer(**load_func_kwargs(candidates, layer.forward))
    return feats
def init_args(cls, params, sub_cls=None):
    """Fill in logger-group settings on ``params``.

    After ``cls.default_init_args``, this applies ``try_set_attr`` for
    ``<prefix>_name`` (``'logger_group'``) and ``<prefix>_logger_dir``
    (``<root_dir>/loggers/<proj_name>``), where ``<prefix>`` is
    ``cls.prefix_name()``. If a logger class can be loaded from
    ``<prefix>_logger_cls``, its own ``init_args`` is run and the keyword
    arguments for its constructor are stored as ``<prefix>_logger_kwargs``.

    Args:
        params: Configuration object, modified in place; must provide
            ``root_dir`` and ``proj_name``.
        sub_cls: Unused.
    """
    cls.default_init_args(params)
    prefix = cls.prefix_name()

    try_set_attr(params, f'{prefix}_name', 'logger_group')
    logger_dir = to_path(params.root_dir).joinpath(
        f'loggers/{params.proj_name}')
    try_set_attr(params, f'{prefix}_logger_dir', logger_dir)

    logger_cls_name = try_get_attr(params, f'{prefix}_logger_cls',
                                   check=False)
    logger_cls = cls.load_cls(logger_cls_name)
    if logger_cls is None:
        return

    # Let the logger class set up its own arguments before they are
    # collected for its constructor.
    logger_cls.init_args(params)
    logger_kwargs = load_func_kwargs(params, logger_cls.__init__, prefix)
    setattr(params, f'{prefix}_logger_kwargs', logger_kwargs)
def get_input(cls, sample):
    """Return the keyword arguments for ``cls.forward`` taken from ``sample``.

    Args:
        sample: Source object handed to ``load_func_kwargs``.

    Returns:
        The result of ``load_func_kwargs(sample, cls.forward)``.
    """
    return load_func_kwargs(sample, cls.forward)