def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  """Initialize the Flax model on a fake 28x28x1 batch and replicate params.

  Args:
    rng: PRNG key used for parameter initialization.

  Returns:
    A `(replicated_params, None)` tuple; this workload keeps no model state.
  """
  fake_input = jnp.ones((1, 28, 28, 1), jnp.float32)
  variables = self._model.init({'params': rng}, fake_input, train=True)
  params = variables['params']
  # Cache per-parameter shapes (pre-replication) for workload introspection.
  self._param_shapes = jax.tree_map(
      lambda leaf: spec.ShapeTuple(leaf.shape), params)
  return jax_utils.replicate(params), None
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  """Build the CNN-LSTM model and move it to the target device.

  Args:
    rng: workload RandomState; its first element seeds PyTorch so that
      weight initialization is reproducible.

  Returns:
    A `(model, None)` tuple; this workload keeps no separate model state.
  """
  # Seed the torch RNG from the workload rng so weight init is deterministic.
  # The sibling PyTorch workload (Transformer) already does this; previously
  # this workload silently ignored `rng`.
  torch.random.manual_seed(rng[0])
  model = models.CNNLSTM()
  # Record parameter shapes before any DataParallel wrapping renames them.
  self._param_shapes = {
      k: spec.ShapeTuple(v.shape) for k, v in model.named_parameters()
  }
  if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
  model.to(DEVICE)
  return model, None
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  """Construct ResNet-50, initialize it, and replicate across devices.

  Args:
    rng: PRNG key forwarded to `self.initialized`.

  Returns:
    `(replicated_params, replicated_model_state)`.
  """
  model = getattr(models, 'ResNet50')(num_classes=1000, dtype=jnp.float32)
  self._model = model
  params, model_state = self.initialized(rng, model)
  # Shapes are cached before replication, so they have no device axis.
  self._param_shapes = jax.tree_map(
      lambda leaf: spec.ShapeTuple(leaf.shape), params)
  model_state = jax_utils.replicate(model_state)
  params = jax_utils.replicate(params)
  return params, model_state
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  """Seed torch from `rng`, build the Transformer, and place it on DEVICE.

  Args:
    rng: workload RandomState; its first element seeds PyTorch.

  Returns:
    A `(model, None)` tuple; this workload keeps no separate model state.
  """
  torch.random.manual_seed(rng[0])
  model = Transformer()
  # Record parameter shapes before any DataParallel wrapping renames them.
  shapes = {}
  for name, param in model.named_parameters():
    shapes[name] = spec.ShapeTuple(param.shape)
  self._param_shapes = shapes
  if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
  model.to(DEVICE)
  return model, None
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  """Initialize ResNet-18 on a fake 32x32x3 batch and replicate state.

  Args:
    rng: PRNG key used for parameter initialization.

  Returns:
    `(replicated_params, replicated_model_state)`.
  """
  model = getattr(models, 'ResNet18')(num_classes=10, dtype=jnp.float32)
  self._model = model
  fake_input = jnp.ones((1, 32, 32, 3), model.dtype)
  variables = jax.jit(model.init)({'params': rng}, fake_input)
  # FrozenDict.pop yields (remaining collections, popped value): the
  # remainder is the mutable model state, the popped entry the params.
  model_state, params = variables.pop('params')
  self._param_shapes = jax.tree_map(
      lambda leaf: spec.ShapeTuple(leaf.shape), params)
  model_state = jax_utils.replicate(model_state)
  params = jax_utils.replicate(params)
  return params, model_state
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  """Initialize the model from the cached batch of input graphs.

  Requires `build_input_queue()` to have run first so that
  `self._init_graphs` holds a representative batch.

  Args:
    rng: PRNG key, split into params and dropout streams.

  Returns:
    A `(replicated_params, None)` tuple; this workload keeps no model state.

  Raises:
    ValueError: if `self._init_graphs` has not been populated yet.
  """
  if self._init_graphs is None:
    raise ValueError(
        'This should not happen, workload.build_input_queue() should be '
        'called before workload.init_model_fn()!')
  rng, params_rng, dropout_rng = jax.random.split(rng, 3)
  init_fn = jax.jit(functools.partial(self._model.init, train=False))
  rngs = {'params': params_rng, 'dropout': dropout_rng}
  variables = init_fn(rngs, self._init_graphs)
  params = variables['params']
  self._param_shapes = jax.tree_map(
      lambda leaf: spec.ShapeTuple(leaf.shape), params)
  return jax_utils.replicate(params), None
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  """Initialize the Transformer on fake length-256 batches and replicate.

  Args:
    rng: PRNG key; split so initialization uses a derived key.

  Returns:
    A `(replicated_params, None)` tuple; this workload keeps no model state.
  """
  rng, init_rng = jax.random.split(rng)
  fake_batch_size = 2
  fake_inputs = jnp.ones((fake_batch_size, 256), jnp.float32)
  fake_targets = jnp.ones((fake_batch_size, 256), jnp.float32)
  initial_variables = jax.jit(models.Transformer(self._eval_config).init)(
      init_rng, fake_inputs, fake_targets)
  initial_params = initial_variables['params']
  self._param_shapes = jax.tree_map(
      lambda leaf: spec.ShapeTuple(leaf.shape), initial_params)
  return jax_utils.replicate(initial_params), None
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  """Initialize the Flax module on fake dense+categorical batches.

  Args:
    rng: PRNG key; split so initialization uses a derived key.

  Returns:
    A `(replicated_params, None)` tuple; this workload keeps no model state.
  """
  rng, init_rng = jax.random.split(rng)
  fake_batch_size = 2
  num_features = _NUM_DENSE_FEATURES + len(_VOCAB_SIZES)
  fake_inputs = jnp.ones((fake_batch_size, num_features), jnp.float32)
  # NOTE(review): the fake targets share the input's feature width here
  # rather than being per-example labels — confirm this matches what the
  # module's __call__ expects.
  fake_targets = jnp.ones((fake_batch_size, num_features), jnp.float32)
  initial_variables = jax.jit(self._flax_module.init)(
      init_rng, fake_inputs, fake_targets)
  initial_params = initial_variables['params']
  self._param_shapes = jax.tree_map(
      lambda leaf: spec.ShapeTuple(leaf.shape), initial_params)
  return jax_utils.replicate(initial_params), None
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  """Initialize the GNN with a minimal one-node, one-edge fake graph.

  Args:
    rng: PRNG key, split into params and dropout streams.

  Returns:
    A `(replicated_params, None)` tuple; this workload keeps no model state.
  """
  rng, params_rng, dropout_rng = jax.random.split(rng, 3)
  init_fn = jax.jit(functools.partial(self._model.init, train=False))
  # Single self-loop graph: one node (3 features), one edge (7 features),
  # globals sized to the output head.
  fake_graph = jraph.GraphsTuple(
      n_node=jnp.asarray([1]),
      n_edge=jnp.asarray([1]),
      nodes=jnp.ones((1, 3)),
      edges=jnp.ones((1, 7)),
      globals=jnp.zeros((1, self._num_outputs)),
      senders=jnp.asarray([0]),
      receivers=jnp.asarray([0]))
  rngs = {'params': params_rng, 'dropout': dropout_rng}
  variables = init_fn(rngs, fake_graph)
  params = variables['params']
  self._param_shapes = jax.tree_map(
      lambda leaf: spec.ShapeTuple(leaf.shape), params)
  return jax_utils.replicate(params), None
def param_shapes(self):
  """Return the tree of per-parameter `spec.ShapeTuple`s.

  Previously this re-ran the full model init on every call and then mapped
  shapes over the *replicated* params, so the returned shapes carried a
  leading device axis — disagreeing with the pre-replication shapes that
  `init_model_fn` caches in `self._param_shapes`. Return the cached,
  unreplicated shapes instead, initializing lazily only if needed.

  Returns:
    The pytree of `spec.ShapeTuple` objects cached by `init_model_fn`.
  """
  if getattr(self, '_param_shapes', None) is None:
    # init_model_fn populates self._param_shapes before replication.
    self.init_model_fn(jax.random.PRNGKey(0))
  return self._param_shapes