Example 1
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  # Initialize the Flax model on a dummy MNIST-shaped batch (1, 28, 28, 1).
  init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
  initial_params = self._model.init(
      {'params': rng}, init_val, train=True)['params']
  # Record the shape of every parameter leaf.
  self._param_shapes = jax.tree_map(
      lambda x: spec.ShapeTuple(x.shape), initial_params)
  # Replicate across local devices; this workload has no mutable model state.
  return jax_utils.replicate(initial_params), None
Example 2
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  model = models.CNNLSTM()
  # Record the shape of every named parameter.
  self._param_shapes = {
      k: spec.ShapeTuple(v.shape) for k, v in model.named_parameters()
  }
  # Use DataParallel when more than one GPU is visible.
  if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
  model.to(DEVICE)
  return model, None
Example 3
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  model_cls = getattr(models, 'ResNet50')
  model = model_cls(num_classes=1000, dtype=jnp.float32)
  self._model = model
  # self.initialized returns the parameters and the mutable model state
  # (e.g. batch statistics) produced by Flax initialization.
  params, model_state = self.initialized(rng, model)
  self._param_shapes = jax.tree_map(
      lambda x: spec.ShapeTuple(x.shape), params)
  # Replicate both across local devices for pmap-based training.
  model_state = jax_utils.replicate(model_state)
  params = jax_utils.replicate(params)
  return params, model_state
Example 4
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  # Seed PyTorch's global RNG from the first element of the JAX-style key.
  torch.random.manual_seed(rng[0])
  model = Transformer()
  self._param_shapes = {
      k: spec.ShapeTuple(v.shape) for k, v in model.named_parameters()
  }
  if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
  model.to(DEVICE)
  return model, None
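
The two PyTorch examples (2 and 4) share one pattern: seed torch's global RNG from the JAX-style key, record per-parameter shapes, wrap the model in DataParallel when several GPUs are visible, and move it to the target device. Below is a minimal, self-contained sketch of that pattern; the toy torch.nn.Linear model and the plain-tuple shapes are stand-ins for the workload-specific models and spec.ShapeTuple, not part of the original code.

import numpy as np
import torch

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def init_toy_model(rng: np.ndarray):
  # Seed PyTorch's global RNG from the first element of the key.
  torch.random.manual_seed(int(rng[0]))
  model = torch.nn.Linear(8, 2)  # stand-in for CNNLSTM() / Transformer()
  # Record the shape of every named parameter, as the workloads above do.
  param_shapes = {k: tuple(v.shape) for k, v in model.named_parameters()}
  if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
  model.to(DEVICE)
  return model, param_shapes

model, shapes = init_toy_model(np.array([42, 0], dtype=np.uint32))
print(shapes)  # {'weight': (2, 8), 'bias': (2,)}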
Example 5
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  model_cls = getattr(models, 'ResNet18')
  model = model_cls(num_classes=10, dtype=jnp.float32)
  self._model = model
  input_shape = (1, 32, 32, 3)
  variables = jax.jit(model.init)(
      {'params': rng}, jnp.ones(input_shape, model.dtype))
  # FrozenDict.pop returns (remaining collections, popped value), so
  # model_state holds the non-parameter collections (e.g. batch_stats).
  model_state, params = variables.pop('params')

  self._param_shapes = jax.tree_map(
      lambda x: spec.ShapeTuple(x.shape), params)
  model_state = jax_utils.replicate(model_state)
  params = jax_utils.replicate(params)
  return params, model_state
Example 6
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  if self._init_graphs is None:
    raise ValueError(
        'This should not happen, workload.build_input_queue() should be '
        'called before workload.init_model_fn()!')
  # Split the key into separate streams for parameter init and dropout.
  rng, params_rng, dropout_rng = jax.random.split(rng, 3)
  init_fn = jax.jit(functools.partial(self._model.init, train=False))
  params = init_fn(
      {'params': params_rng, 'dropout': dropout_rng}, self._init_graphs)
  params = params['params']
  self._param_shapes = jax.tree_map(
      lambda x: spec.ShapeTuple(x.shape), params)
  return jax_utils.replicate(params), None
Example 7
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  rng, init_rng = jax.random.split(rng)
  # Initialize on a small fake batch of input and target token sequences.
  init_fake_batch_size = 2
  input_shape = (init_fake_batch_size, 256)
  target_shape = (init_fake_batch_size, 256)

  initial_variables = jax.jit(models.Transformer(self._eval_config).init)(
      init_rng,
      jnp.ones(input_shape, jnp.float32),
      jnp.ones(target_shape, jnp.float32))

  initial_params = initial_variables['params']
  self._param_shapes = jax.tree_map(
      lambda x: spec.ShapeTuple(x.shape), initial_params)
  return jax_utils.replicate(initial_params), None
Example 8
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  rng, init_rng = jax.random.split(rng)
  init_fake_batch_size = 2
  # Input width: dense features plus one column per categorical vocabulary.
  input_size = _NUM_DENSE_FEATURES + len(_VOCAB_SIZES)
  input_shape = (init_fake_batch_size, input_size)
  target_shape = (init_fake_batch_size, input_size)

  initial_variables = jax.jit(self._flax_module.init)(
      init_rng,
      jnp.ones(input_shape, jnp.float32),
      jnp.ones(target_shape, jnp.float32))

  initial_params = initial_variables['params']
  self._param_shapes = jax.tree_map(
      lambda x: spec.ShapeTuple(x.shape), initial_params)
  return jax_utils.replicate(initial_params), None
Example 9
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  rng, params_rng, dropout_rng = jax.random.split(rng, 3)
  init_fn = jax.jit(functools.partial(self._model.init, train=False))
  # Trace the model with a dummy single-node, single-edge graph.
  fake_batch = jraph.GraphsTuple(
      n_node=jnp.asarray([1]),
      n_edge=jnp.asarray([1]),
      nodes=jnp.ones((1, 3)),
      edges=jnp.ones((1, 7)),
      globals=jnp.zeros((1, self._num_outputs)),
      senders=jnp.asarray([0]),
      receivers=jnp.asarray([0]))
  params = init_fn(
      {'params': params_rng, 'dropout': dropout_rng}, fake_batch)
  params = params['params']
  self._param_shapes = jax.tree_map(
      lambda x: spec.ShapeTuple(x.shape), params)
  return jax_utils.replicate(params), None
Example 10
def param_shapes(self):
  # Derive parameter shapes by initializing the model with a fixed key.
  init_params, _ = self.init_model_fn(jax.random.PRNGKey(0))
  return jax.tree_map(lambda x: spec.ShapeTuple(x.shape), init_params)
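
Every JAX example above follows the same recipe: initialize the Flax module on a dummy batch, pull out the 'params' collection, record the leaf shapes, and replicate the result across local devices with flax.jax_utils.replicate. Below is a minimal, self-contained sketch of that recipe; the toy MLP and the plain-tuple shapes are stand-ins for the workload models and spec.ShapeTuple, not part of the original code.

import jax
import jax.numpy as jnp
from flax import jax_utils
from flax import linen as nn

class ToyMLP(nn.Module):
  """Stand-in for the workload models above."""

  @nn.compact
  def __call__(self, x):
    x = nn.relu(nn.Dense(16)(x))
    return nn.Dense(10)(x)

def init_toy_model(rng):
  model = ToyMLP()
  init_val = jnp.ones((1, 28 * 28), jnp.float32)  # dummy batch
  params = jax.jit(model.init)({'params': rng}, init_val)['params']
  # Record per-leaf shapes, mirroring the self._param_shapes bookkeeping.
  param_shapes = jax.tree_util.tree_map(lambda x: x.shape, params)
  # Replicate parameters across local devices for pmap-style training.
  return jax_utils.replicate(params), param_shapes

replicated_params, shapes = init_toy_model(jax.random.PRNGKey(0))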