def __init__( self, hiddens, layer_fn=nn.Linear, bias=True, norm_fn=None, activation_fn=nn.ReLU, dropout=None, layer_order=None, residual=False ): super().__init__() # hack to prevent cycle imports from catalyst.contrib.registry import Registry layer_fn = Registry.name2nn(layer_fn) activation_fn = Registry.name2nn(activation_fn) norm_fn = Registry.name2nn(norm_fn) dropout = Registry.name2nn(dropout) layer_order = layer_order or ["layer", "norm", "drop", "act"] if isinstance(dropout, float): dropout_fn = lambda: nn.Dropout(dropout) else: dropout_fn = dropout def _layer_fn(f_in, f_out, bias): return layer_fn(f_in, f_out, bias=bias) def _normalize_fn(f_in, f_out, bias): return norm_fn(f_out) if norm_fn is not None else None def _dropout_fn(f_in, f_out, bias): return dropout_fn() if dropout_fn is not None else None def _activation_fn(f_in, f_out, bias): return activation_fn() if activation_fn is not None else None name2fn = { "layer": _layer_fn, "norm": _normalize_fn, "drop": _dropout_fn, "act": _activation_fn, } net = [] for i, (f_in, f_out) in enumerate(pairwise(hiddens)): block = [] for key in layer_order: fn = name2fn[key](f_in, f_out, bias) if fn is not None: block.append((f"{key}", fn)) block = torch.nn.Sequential(OrderedDict(block)) if residual: block = ResidualWrapper(net=block) net.append((f"block_{i}", block)) self.net = torch.nn.Sequential(OrderedDict(net))
def prepare_for_trainer(cls, config):
    """Create the algorithm and return kwargs for the off-policy trainer.

    Pops ``agent`` / ``n_critics`` keys from the nested sub-configs;
    note that ``config.copy()`` is shallow, so those nested dicts are
    mutated in place.

    :param config: full experiment configuration dict
    :return: dict of trainer constructor kwargs
    """
    # hack to prevent cycle dependencies
    from catalyst.contrib.registry import Registry

    config_copy = config.copy()
    shared = config_copy["shared"]

    actor_state_shape = (shared["history_len"], shared["state_size"])
    actor_action_size = shared["action_size"]
    n_step = shared["n_step"]
    gamma = shared["gamma"]
    history_len = shared["history_len"]
    trainer_state_shape = (shared["state_size"], )
    trainer_action_shape = (shared["action_size"], )

    actor_fn = config_copy["actor"].pop("agent", None)
    actor = Registry.get_agent(
        agent=actor_fn,
        state_shape=actor_state_shape,
        action_size=actor_action_size,
        **config_copy["actor"]
    )

    critic_fn = config_copy["critic"].pop("agent", None)
    critic = Registry.get_agent(
        agent=critic_fn,
        state_shape=actor_state_shape,
        action_size=actor_action_size,
        **config_copy["critic"]
    )

    # the main critic counts toward n_critics, so build n_critics - 1 extras
    n_critics = config_copy["algorithm"].pop("n_critics", 2)
    critics = [
        Registry.get_agent(
            agent=critic_fn,
            state_shape=actor_state_shape,
            action_size=actor_action_size,
            **config_copy["critic"]
        )
        for _ in range(n_critics - 1)
    ]

    algorithm = cls(
        **config_copy["algorithm"],
        actor=actor,
        critic=critic,
        critics=critics,
        n_step=n_step,
        gamma=gamma
    )

    return {
        "algorithm": algorithm,
        "state_shape": trainer_state_shape,
        "action_shape": trainer_action_shape,
        "n_step": n_step,
        "gamma": gamma,
        "history_len": history_len,
    }
def _init(self, critics, action_noise_std=0.2, action_noise_clip=0.5, values_range=(-10., 10.), critic_distribution=None, **kwargs): super()._init(**kwargs) # hack to prevent cycle dependencies from catalyst.contrib.registry import Registry self.n_atoms = self.critic.out_features self._loss_fn = self._base_loss self.action_noise_std = action_noise_std self.action_noise_clip = action_noise_clip critics = [x.to(self._device) for x in critics] critics_optimizer = [ Registry.get_optimizer(x, **self.critic_optimizer_params) for x in critics ] critics_scheduler = [ Registry.get_scheduler(x, **self.critic_scheduler_params) for x in critics_optimizer ] target_critics = [copy.deepcopy(x).to(self._device) for x in critics] self.critics = [self.critic] + critics self.critics_optimizer = [self.critic_optimizer] + critics_optimizer self.critics_scheduler = [self.critic_scheduler] + critics_scheduler self.target_critics = [self.target_critic] + target_critics if critic_distribution == "quantile": tau_min = 1 / (2 * self.n_atoms) tau_max = 1 - tau_min tau = torch.linspace(start=tau_min, end=tau_max, steps=self.n_atoms) self.tau = self._to_tensor(tau) self._loss_fn = self._quantile_loss elif critic_distribution == "categorical": self.v_min, self.v_max = values_range self.delta_z = (self.v_max - self.v_min) / (self.n_atoms - 1) z = torch.linspace(start=self.v_min, end=self.v_max, steps=self.n_atoms) self.z = self._to_tensor(z) self._loss_fn = self._categorical_loss
def prepare_callbacks(*,
                      mode: str,
                      stage: str = None,
                      resume: str = None,
                      out_prefix: str = None,
                      **kwargs) -> Dict[str, Callback]:
    """
    Runner callbacks method to handle different runs logic.

    :param mode: train/infer
    :param stage: training stage name
    :param resume: path to checkpoint (used for checkpoint callback)
    :param out_prefix: output prefix (used for inference callbacks)
    :param **kwargs: callbacks params
    :return: OrderedDict with callbacks
    """
    callbacks = OrderedDict(
        (name, Registry.get_callback(**params))
        for name, params in kwargs.items()
    )

    # @TODO: remove hack
    for callback in callbacks.values():
        if resume is not None and hasattr(callback, "resume"):
            callback.resume = resume
        if out_prefix is not None and hasattr(callback, "out_prefix"):
            callback.out_prefix = out_prefix

    return callbacks
def __init__(self,
             action_size,
             layer_fn,
             activation_fn=nn.ReLU,
             squashing_fn=nn.Tanh,
             norm_fn=None,
             bias=False):
    """Real NVP policy: two coupling layers plus a squashing output layer."""
    super().__init__()
    # hack to prevent cycle imports
    from catalyst.contrib.registry import Registry
    activation_fn = Registry.name2nn(activation_fn)

    self.action_size = action_size

    # both coupling layers share everything except parity
    shared_kwargs = dict(
        action_size=action_size,
        layer_fn=layer_fn,
        activation_fn=activation_fn,
        norm_fn=None,
        bias=bias,
    )
    self.coupling1 = CouplingLayer(parity="odd", **shared_kwargs)
    self.coupling2 = CouplingLayer(parity="even", **shared_kwargs)
    self.squashing_layer = SquashingLayer(squashing_fn)
def __init__(self, squashing_fn=nn.Tanh):
    """
    Layer that squashes samples from some distribution to be bounded.
    """
    super().__init__()
    # hack to prevent cycle imports
    from catalyst.contrib.registry import Registry
    squashing_fn = Registry.name2nn(squashing_fn)
    self.squashing_fn = squashing_fn()
def main(args, unknown_args):
    """Entry point: build the RL algorithm and trainer from config, run it."""
    args, config = parse_args_uargs(args, unknown_args, dump_config=True)

    os.makedirs(args.logdir, exist_ok=True)
    save_config(config=config, logdir=args.logdir)
    if args.expdir is not None:
        modules = prepare_modules(  # noqa: F841
            expdir=args.expdir, dump_dir=args.logdir)

    algorithm = Registry.get_fn("algorithm", args.algorithm)
    algorithm_kwargs = algorithm.prepare_for_trainer(config)

    redis_config = config.get("redis", {})
    redis_server = StrictRedis(port=redis_config.get("port", 12000))
    redis_prefix = redis_config.get("prefix", "")

    pprint(config["trainer"])
    pprint(algorithm_kwargs)

    trainer = Trainer(
        **config["trainer"],
        **algorithm_kwargs,
        logdir=args.logdir,
        redis_server=redis_server,
        redis_prefix=redis_prefix
    )
    pprint(trainer)

    # make sure child processes die with the trainer
    def on_exit():
        for child in trainer.get_processes():
            child.terminate()

    atexit.register(on_exit)

    trainer.run()
def prepare_for_sampler(cls, config):
    """Build the actor and return kwargs for the sampler process."""
    # hack to prevent cycle dependencies
    from catalyst.contrib.registry import Registry

    config_copy = config.copy()
    shared = config_copy["shared"]

    state_shape = (shared["history_len"], shared["state_size"])
    action_size = shared["action_size"]

    agent_fn = config_copy["actor"].pop("agent", None)
    actor = Registry.get_agent(
        agent=agent_fn,
        state_shape=state_shape,
        action_size=action_size,
        **config_copy["actor"]
    )

    return {"actor": actor, "history_len": shared["history_len"]}
def get_optimizer(self, stage: str, model) -> _Optimizer:
    """Create the optimizer for ``stage`` from the stage config."""
    params = self.stages_config[stage].get("optimizer_params", {})
    fp16 = isinstance(model, Fp16Wrap)
    return Registry.get_optimizer(model, **params, fp16=fp16)
def __init__(self, in_features, activation_fn="Tanh"):
    """1x1-conv spatial attention head followed by an activation."""
    super().__init__()
    # hack to prevent cycle imports
    from catalyst.contrib.registry import Registry
    activation_fn = Registry.name2nn(activation_fn)

    conv = nn.Conv2d(
        in_features, 1, kernel_size=1, stride=1, padding=0, bias=False
    )
    self.attn = nn.Sequential(conv, activation_fn())
def prepare_model_stuff(*,
                        model,
                        criterion_params=None,
                        optimizer_params=None,
                        scheduler_params=None):
    """Create the (criterion, optimizer, scheduler) triple for ``model``."""
    fp16 = isinstance(model, Fp16Wrap)

    criterion = Registry.get_criterion(**(criterion_params or {}))
    optimizer = Registry.get_optimizer(
        model, **(optimizer_params or {}), fp16=fp16)
    scheduler = Registry.get_scheduler(
        optimizer, **(scheduler_params or {}))

    return criterion, optimizer, scheduler
def get_callbacks(self, stage: str) -> "List[Callback]":
    """Instantiate the callbacks configured for ``stage``."""
    params = self.stages_config[stage].get("callbacks_params", {})
    return [Registry.get_callback(**value) for value in params.values()]
def main(args, unknown_args):
    """Entry point: patch config from the env, build trainer, run it."""
    args, config = parse_args_uargs(args, unknown_args, dump_config=True)

    os.makedirs(args.logdir, exist_ok=True)
    save_config(config=config, logdir=args.logdir)
    if args.expdir is not None:
        modules = prepare_modules(  # noqa: F841
            expdir=args.expdir, dump_dir=args.logdir)

    algorithm = Registry.get_fn("algorithm", args.algorithm)

    if args.environment is not None:
        # @TODO: remove this hack
        # come on, just refactor whole rl
        environment_fn = Registry.get_fn("environment", args.environment)
        env = environment_fn(**config["env"])
        config["shared"]["observation_size"] = env.observation_shape[0]
        config["shared"]["action_size"] = env.action_shape[0]
        del env

    algorithm_kwargs = algorithm.prepare_for_trainer(config)

    redis_config = config.get("redis", {})
    redis_server = StrictRedis(port=redis_config.get("port", 12000))
    redis_prefix = redis_config.get("prefix", "")

    pprint(config["trainer"])
    pprint(algorithm_kwargs)

    trainer = Trainer(
        **config["trainer"],
        **algorithm_kwargs,
        logdir=args.logdir,
        redis_server=redis_server,
        redis_prefix=redis_prefix
    )
    pprint(trainer)

    # make sure child processes die with the trainer
    def on_exit():
        for child in trainer.get_processes():
            child.terminate()

    atexit.register(on_exit)

    trainer.run()
def __init__(self, arch="resnet34", pretrained=True, frozen=True, pooling=None, pooling_kwargs=None, cut_layers=2): super().__init__() # hack to prevent cycle imports from catalyst.contrib.registry import Registry resnet = torchvision.models.__dict__[arch](pretrained=pretrained) modules = list(resnet.children())[:-cut_layers] # delete last layers if frozen: for module in modules: for param in module.parameters(): param.requires_grad = False if pooling is not None: pooling_kwargs = pooling_kwargs or {} pooling_layer_fn = Registry.name2nn(pooling) pooling_layer = pooling_layer_fn( in_features=resnet.fc.in_features, **pooling_kwargs) \ if "attn" in pooling.lower() \ else pooling_layer_fn(**pooling_kwargs) modules += [pooling_layer] out_features = pooling_layer.out_features( in_features=resnet.fc.in_features) else: out_features = resnet.fc.in_features flatten = Registry.name2nn("Flatten") modules += [flatten()] self.out_features = out_features self.encoder = nn.Sequential(*modules)
def __init__(self,
             grad_clip_params: Dict = None,
             fp16_grad_scale: float = 128.0,
             accumulation_steps: int = 1,
             optimizer_key: str = None,
             loss_key: str = None):
    """Optimizer callback state: grad clipping, fp16 scaling, accumulation.

    :param grad_clip_params: kwargs for the registry's grad-clip factory
    :param fp16_grad_scale: loss scale used in fp16 mode
    :param accumulation_steps: gradient accumulation steps per update
    :param optimizer_key: which optimizer to step (None = default)
    :param loss_key: which loss to backprop (None = default)
    """
    # hack to prevent cycle imports
    from catalyst.contrib.registry import Registry

    self.grad_clip_fn = \
        Registry.get_grad_clip_fn(**(grad_clip_params or {}))

    self.fp16 = False
    self.fp16_grad_scale = fp16_grad_scale
    self.accumulation_steps = accumulation_steps
    self.optimizer_key = optimizer_key
    self.loss_key = loss_key

    self._optimizer_wd = 0
    self._accumulation_counter = 0
def main(args, unknown_args):
    """Inference entry point: build model, loaders, callbacks, run infer."""
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seeds(args.seed)

    modules = prepare_modules(expdir=args.expdir)

    model = Registry.get_model(**config["model_params"])
    datasource = modules["data"].DataSource()

    data_params = config.get("data_params", {}) or {}
    loaders = datasource.prepare_loaders(
        mode="infer",
        n_workers=args.workers,
        batch_size=args.batch_size,
        **data_params
    )

    runner = modules["model"].ModelRunner(model=model)

    callbacks_params = config.get("callbacks_params", {}) or {}
    callbacks = runner.prepare_callbacks(
        mode="infer",
        resume=args.resume,
        out_prefix=args.out_prefix,
        **callbacks_params
    )

    runner.infer(loaders=loaders, callbacks=callbacks, verbose=args.verbose)
def main(args, unknown_args):
    """Training entry point: resolve logdir, build model, run all stages."""
    args, config = parse_args_uargs(args, unknown_args, dump_config=True)
    set_global_seeds(args.seed)

    assert args.baselogdir is not None or args.logdir is not None

    # derive logdir from baselogdir + experiment name when not given
    if args.logdir is None:
        modules_ = prepare_modules(expdir=args.expdir)
        logdir = modules_["model"].prepare_logdir(config=config)
        args.logdir = str(pathlib.Path(args.baselogdir).joinpath(logdir))

    os.makedirs(args.logdir, exist_ok=True)
    save_config(config=config, logdir=args.logdir)

    modules = prepare_modules(expdir=args.expdir, dump_dir=args.logdir)

    model = Registry.get_model(**config["model_params"])
    datasource = modules["data"].DataSource()
    runner = modules["model"].ModelRunner(model=model)
    runner.train_stages(
        datasource=datasource,
        args=args,
        stages_config=config["stages"],
        verbose=args.verbose
    )
def __init__( self, input_size=224, width_mult=1., pretrained=True, pooling=None, pooling_kwargs=None, ): super().__init__() # hack to prevent cycle imports from catalyst.contrib.registry import Registry net = MobileNetV2(input_size=input_size, width_mult=width_mult, pretrained=pretrained) self.encoder = list(net.encoder.children()) if pooling is not None: pooling_kwargs = pooling_kwargs or {} pooling_layer_fn = Registry.name2nn(pooling) pooling_layer = pooling_layer_fn( in_features=self.last_channel, **pooling_kwargs) \ if "attn" in pooling.lower() \ else pooling_layer_fn(**pooling_kwargs) self.encoder.append(pooling_layer) out_features = pooling_layer.out_features( in_features=net.output_channel) else: out_features = net.output_channel self.out_features = out_features # make it nn.Sequential self.encoder = nn.Sequential(*self.encoder) self._initialize_weights()
def get_model(self, stage: str) -> _Model:
    """Build the model and apply stage-specific pre/post-processing."""
    model = Registry.get_model(**self._config["model_params"])
    model = self._preprocess_model_for_stage(stage, model)
    return self._postprocess_model_for_stage(stage, model)
def main(args, unknown_args):
    """Spawn sampler processes: visualization, inference and training.

    Bugfix: visualization samplers are now started with ``vis=True`` —
    previously they were created with ``vis=False``, which made them
    identical to noise-free inference samplers and turned ``--vis``
    into a no-op.
    """
    args, config = parse_args_uargs(args, unknown_args)

    os.makedirs(args.logdir, exist_ok=True)
    save_config(config=config, logdir=args.logdir)

    if args.expdir is not None:
        modules = prepare_modules(  # noqa: F841
            expdir=args.expdir, dump_dir=args.logdir)

    algorithm = Registry.get_fn("algorithm", args.algorithm)
    environment = Registry.get_fn("environment", args.environment)

    processes = []
    sampler_id = 0

    # make sure sampler processes die with the launcher
    def on_exit():
        for p in processes:
            p.terminate()

    atexit.register(on_exit)

    params = dict(logdir=args.logdir,
                  algorithm=algorithm,
                  environment=environment,
                  config=config,
                  resume=args.resume,
                  redis=args.redis)

    def start_sampler(**sampler_params):
        # spawn one sampler subprocess with a unique sequential id
        nonlocal sampler_id
        p = mp.Process(
            target=run_sampler,
            kwargs=dict(**params, **sampler_params, id=sampler_id))
        p.start()
        processes.append(p)
        sampler_id += 1

    if args.debug:
        # run a single sampler in-process for debugging
        params_ = dict(
            vis=False,
            infer=False,
            action_noise=0.5,
            param_noise=0.5,
            action_noise_prob=args.action_noise_prob,
            param_noise_prob=args.param_noise_prob,
            id=sampler_id,
        )
        run_sampler(**params, **params_)

    # visualization samplers — no exploration noise
    for _ in range(args.vis):
        start_sampler(
            vis=True,  # bugfix: was vis=False, making --vis a no-op
            infer=False,
            action_noise_prob=0,
            param_noise_prob=0)

    # pure inference samplers — no exploration noise
    for _ in range(args.infer):
        start_sampler(
            vis=False,
            infer=True,
            action_noise_prob=0,
            param_noise_prob=0)

    # training samplers with linearly increasing noise magnitudes
    for i in range(1, args.train + 1):
        action_noise = args.max_action_noise * i / args.train \
            if args.max_action_noise is not None \
            else None
        param_noise = args.max_param_noise * i / args.train \
            if args.max_param_noise is not None \
            else None
        start_sampler(
            vis=False,
            infer=False,
            action_noise=action_noise,
            param_noise=param_noise,
            action_noise_prob=args.action_noise_prob,
            param_noise_prob=args.param_noise_prob)

    for p in processes:
        p.join()
def __init__(self,
             actor,
             critic,
             gamma,
             n_step,
             actor_optimizer_params,
             critic_optimizer_params,
             actor_grad_clip_params=None,
             critic_grad_clip_params=None,
             actor_loss_params=None,
             critic_loss_params=None,
             actor_scheduler_params=None,
             critic_scheduler_params=None,
             resume=None,
             load_optimizer=True,
             actor_tau=1.0,
             critic_tau=1.0,
             min_action=-1.0,
             max_action=1.0,
             **kwargs):
    """Base actor-critic algorithm setup.

    :param actor: actor network
    :param critic: critic network
    :param gamma: discount factor
    :param n_step: n-step return horizon
    :param actor_optimizer_params: kwargs for the actor optimizer
    :param critic_optimizer_params: kwargs for the critic optimizer
    :param actor_grad_clip_params: kwargs for actor gradient clipping
    :param critic_grad_clip_params: kwargs for critic gradient clipping
    :param actor_loss_params: kwargs for the actor criterion
    :param critic_loss_params: kwargs for the critic criterion
    :param actor_scheduler_params: kwargs for the actor LR scheduler
    :param critic_scheduler_params: kwargs for the critic LR scheduler
    :param resume: optional checkpoint path to restore from
    :param load_optimizer: restore optimizer state from the checkpoint
    :param actor_tau: soft-update coefficient for the target actor
    :param critic_tau: soft-update coefficient for the target critic
    :param min_action: lower action bound
    :param max_action: upper action bound
    :param kwargs: forwarded to the subclass hook ``_init``
    """
    # hack to prevent cycle dependencies
    from catalyst.contrib.registry import Registry
    self._device = UtilsFactory.prepare_device()

    self.actor = actor.to(self._device)
    self.critic = critic.to(self._device)
    # target networks start as exact copies; updated via tau soft-updates
    self.target_actor = copy.deepcopy(actor).to(self._device)
    self.target_critic = copy.deepcopy(critic).to(self._device)

    self.actor_optimizer = Registry.get_optimizer(
        self.actor, **actor_optimizer_params)
    self.critic_optimizer = Registry.get_optimizer(
        self.critic, **critic_optimizer_params)
    # params kept so subclasses can build optimizers for extra critics
    self.actor_optimizer_params = actor_optimizer_params
    self.critic_optimizer_params = critic_optimizer_params

    actor_scheduler_params = actor_scheduler_params or {}
    critic_scheduler_params = critic_scheduler_params or {}
    self.actor_scheduler = Registry.get_scheduler(
        self.actor_optimizer, **actor_scheduler_params)
    self.critic_scheduler = Registry.get_scheduler(
        self.critic_optimizer, **critic_scheduler_params)
    self.actor_scheduler_params = actor_scheduler_params
    self.critic_scheduler_params = critic_scheduler_params

    self.n_step = n_step
    self.gamma = gamma

    actor_grad_clip_params = actor_grad_clip_params or {}
    critic_grad_clip_params = critic_grad_clip_params or {}
    self.actor_grad_clip_fn = Registry.get_grad_clip_fn(
        **actor_grad_clip_params)
    self.critic_grad_clip_fn = Registry.get_grad_clip_fn(
        **critic_grad_clip_params)
    self.actor_grad_clip_params = actor_grad_clip_params
    self.critic_grad_clip_params = critic_grad_clip_params

    self.actor_criterion = Registry.get_criterion(
        **(actor_loss_params or {}))
    self.critic_criterion = Registry.get_criterion(
        **(critic_loss_params or {}))
    self.actor_loss_params = actor_loss_params
    self.critic_loss_params = critic_loss_params

    self.actor_tau = actor_tau
    self.critic_tau = critic_tau

    self.min_action = min_action
    self.max_action = max_action

    # subclass-specific setup (extra critics, distributional losses, ...)
    self._init(**kwargs)

    if resume is not None:
        self.load_checkpoint(resume, load_optimizer=load_optimizer)
def _init(
    self,
    critics,
    reward_scale=1.0,
    values_range=(-10., 10.),
    critic_distribution=None,
    **kwargs
):
    """
    Parameters
    ----------
    reward_scale: float,
        THE MOST IMPORTANT HYPERPARAMETER
        which controls the ratio
        between maximizing rewards and acting as randomly as possible
    use_regularization: bool,
        whether to use l2 regularization on policy network outputs,
        regularization can not be used with RealNVPActor
    mu_and_sigma_reg: float,
        coefficient for l2 regularization on mu and log_sigma
    policy_grad_estimator: str,
        "reinforce": may be used with arbitrary explicit policy
        "reparametrization_trick": may be used with reparametrizable policy,
        e.g. Gaussian, normalizing flow (Real NVP).
    """
    super()._init(**kwargs)
    # hack to prevent cycle dependencies
    from catalyst.contrib.registry import Registry
    # number of output atoms of the critic's value distribution
    self.n_atoms = self.critic.out_features
    self._loss_fn = self._base_loss
    self.reward_scale = reward_scale
    # @TODO: policy regularization

    critics = [x.to(self._device) for x in critics]
    # reuse the main critic's optimizer/scheduler params for the ensemble
    critics_optimizer = [
        Registry.get_optimizer(x, **self.critic_optimizer_params)
        for x in critics
    ]
    critics_scheduler = [
        Registry.get_scheduler(x, **self.critic_scheduler_params)
        for x in critics_optimizer
    ]
    target_critics = [copy.deepcopy(x).to(self._device) for x in critics]

    # prepend the base-class critic so index 0 is always the "main" one
    self.critics = [self.critic] + critics
    self.critics_optimizer = [self.critic_optimizer] + critics_optimizer
    self.critics_scheduler = [self.critic_scheduler] + critics_scheduler
    self.target_critics = [self.target_critic] + target_critics

    if critic_distribution == "quantile":
        # midpoints of n_atoms equal-probability bins in (0, 1)
        tau_min = 1 / (2 * self.n_atoms)
        tau_max = 1 - tau_min
        tau = torch.linspace(
            start=tau_min, end=tau_max, steps=self.n_atoms
        )
        self.tau = self._to_tensor(tau)
        self._loss_fn = self._quantile_loss
    elif critic_distribution == "categorical":
        # fixed support grid z with spacing delta_z (C51-style)
        self.v_min, self.v_max = values_range
        self.delta_z = (self.v_max - self.v_min) / (self.n_atoms - 1)
        z = torch.linspace(
            start=self.v_min, end=self.v_max, steps=self.n_atoms
        )
        self.z = self._to_tensor(z)
        self._loss_fn = self._categorical_loss
def __init__(self,
             action_size,
             layer_fn,
             activation_fn=nn.ReLU,
             norm_fn=None,
             bias=True,
             parity="odd"):
    """
    Conditional affine coupling layer used in Real NVP Bijector.
    Original paper: https://arxiv.org/abs/1605.08803
    Adaptation to RL: https://arxiv.org/abs/1804.02808

    Important notes
    ---------------
    1. State embeddings are supposed to have size (action_size * 2).
    2. Scale and translation networks used in the Real NVP Bijector
    both have one hidden layer of (action_size) (activation_fn) units.
    3. Parity ("odd" or "even") determines which part of the input
    is being copied and which is being transformed.
    """
    super().__init__()
    # hack to prevent cycle imports
    from catalyst.contrib.registry import Registry

    layer_fn = Registry.name2nn(layer_fn)
    activation_fn = Registry.name2nn(activation_fn)
    norm_fn = Registry.name2nn(norm_fn)

    self.parity = parity
    # split the action vector: copy_size elements are passed through,
    # the remaining (action_size - copy_size) are transformed
    if self.parity == "odd":
        self.copy_size = action_size // 2
    else:
        self.copy_size = action_size - action_size // 2

    # prenet input = state embedding (action_size * 2) + copied part
    self.scale_prenet = SequentialNet(
        hiddens=[action_size * 2 + self.copy_size, action_size],
        layer_fn=layer_fn,
        activation_fn=activation_fn,
        norm_fn=None,
        bias=bias)
    # final scale projection has no activation (raw affine output)
    self.scale_net = SequentialNet(
        hiddens=[action_size, action_size - self.copy_size],
        layer_fn=layer_fn,
        activation_fn=None,
        norm_fn=None,
        bias=True)
    self.translation_prenet = SequentialNet(
        hiddens=[action_size * 2 + self.copy_size, action_size],
        layer_fn=layer_fn,
        activation_fn=activation_fn,
        norm_fn=None,
        bias=bias)
    self.translation_net = SequentialNet(
        hiddens=[action_size, action_size - self.copy_size],
        layer_fn=layer_fn,
        activation_fn=None,
        norm_fn=None,
        bias=True)

    # hidden layers get activation-aware init, output layers a small init
    inner_init = create_optimal_inner_init(nonlinearity=activation_fn)
    self.scale_prenet.apply(inner_init)
    self.scale_net.apply(outer_init)
    self.translation_prenet.apply(inner_init)
    self.translation_net.apply(outer_init)
def create_from_params(cls,
                       state_shape,
                       action_size,
                       observation_hiddens=None,
                       head_hiddens=None,
                       layer_fn=nn.Linear,
                       activation_fn=nn.ReLU,
                       dropout=None,
                       norm_fn=None,
                       bias=True,
                       layer_order=None,
                       residual=False,
                       out_activation=None,
                       observation_aggregation=None,
                       lama_poolings=None,
                       policy_type=None,
                       squashing_fn=nn.Tanh,
                       **kwargs):
    """Factory: build an actor as observation → aggregation → main → head
    (→ optional policy) networks.

    :param state_shape: int or tuple, e.g. (history_len, obs_shape)
    :param action_size: size of the action vector
    :param observation_hiddens: hidden widths of the observation net
    :param head_hiddens: hidden widths of the main/head nets
    :param observation_aggregation: None or "lama_obs" (LamaPooling)
    :param policy_type: None, "gauss" or "real_nvp"; stochastic policies
        double the head output size (mu and sigma)
    :param kwargs: must be empty — guards against config typos
    """
    assert len(kwargs) == 0
    # hack to prevent cycle imports
    from catalyst.contrib.registry import Registry

    observation_hiddens = observation_hiddens or []
    head_hiddens = head_hiddens or []

    layer_fn = Registry.name2nn(layer_fn)
    activation_fn = Registry.name2nn(activation_fn)
    norm_fn = Registry.name2nn(norm_fn)
    out_activation = Registry.name2nn(out_activation)
    inner_init = create_optimal_inner_init(nonlinearity=activation_fn)

    if isinstance(state_shape, int):
        state_shape = (state_shape, )

    if len(state_shape) in [1, 2]:
        # linear case: one observation or several one
        # state_shape like [history_len, obs_shape]
        # @TODO: handle lama/rnn correctly
        if not observation_aggregation:
            # flatten the whole state into one vector
            observation_size = reduce(lambda x, y: x * y, state_shape)
        else:
            # aggregation consumes the history axis; flatten the rest
            observation_size = reduce(lambda x, y: x * y, state_shape[1:])

        if len(observation_hiddens) > 0:
            observation_net = SequentialNet(
                hiddens=[observation_size] + observation_hiddens,
                layer_fn=layer_fn,
                dropout=dropout,
                activation_fn=activation_fn,
                norm_fn=norm_fn,
                bias=bias,
                layer_order=layer_order,
                residual=residual)
            observation_net.apply(inner_init)
            obs_out = observation_hiddens[-1]
        else:
            observation_net = None
            obs_out = observation_size
    elif len(state_shape) in [3, 4]:
        # cnn case: one image or several one @TODO
        raise NotImplementedError
    else:
        raise NotImplementedError

    assert obs_out

    if observation_aggregation == "lama_obs":
        aggregation_net = LamaPooling(
            features_in=obs_out, poolings=lama_poolings)
        aggregation_out = aggregation_net.features_out
    else:
        aggregation_net = None
        aggregation_out = obs_out

    main_net = SequentialNet(
        hiddens=[aggregation_out] + head_hiddens,
        layer_fn=layer_fn,
        dropout=dropout,
        activation_fn=activation_fn,
        norm_fn=norm_fn,
        bias=bias,
        layer_order=layer_order,
        residual=residual)
    main_net.apply(inner_init)

    # @TODO: place for memory network

    if policy_type == "gauss":
        # stochastic policy: head outputs mu and sigma
        head_size = action_size * 2
        policy_net = GaussPolicy(squashing_fn)
    elif policy_type == "real_nvp":
        head_size = action_size * 2
        policy_net = RealNVPPolicy(
            action_size=action_size,
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            squashing_fn=squashing_fn,
            norm_fn=None,
            bias=bias)
    else:
        # deterministic policy: head outputs the action directly
        head_size = action_size
        policy_net = None

    head_net = SequentialNet(
        hiddens=[head_hiddens[-1], head_size],
        layer_fn=nn.Linear,
        activation_fn=out_activation,
        norm_fn=None,
        bias=True)
    head_net.apply(outer_init)

    actor_net = cls(
        observation_net=observation_net,
        aggregation_net=aggregation_net,
        main_net=main_net,
        head_net=head_net,
        policy_net=policy_net)

    return actor_net
# flake8: noqa from catalyst.contrib.registry import Registry from catalyst.dl.experiments.runner import SupervisedRunner as Runner from .experiment import Experiment from .model import SimpleNet Registry.model(SimpleNet)
def get_criterion(self, stage: str) -> _Criterion:
    """Create the loss criterion for ``stage`` from the stage config."""
    params = self.stages_config[stage].get("criterion_params", {})
    return Registry.get_criterion(**params)
def get_scheduler(self, stage: str, optimizer) -> _Scheduler:
    """Create the LR scheduler for ``stage`` from the stage config."""
    params = self.stages_config[stage].get("scheduler_params", {})
    return Registry.get_scheduler(optimizer, **params)