Example #1
def per_host_sum_pmap(in_tree):
    """Execute psum on in_tree"s leaves over one device per host."""
    host2devices = collections.defaultdict(list)
    for d in jax.devices():
        host2devices[d.process_index].append(d)
    devices = [host2devices[k][0] for k in host2devices]
    host_psum = jax.pmap(lambda x: jax.lax.psum(x, "i"), "i", devices=devices)

    def pre_pmap(xs):
        return jax.tree_map(lambda x: jnp.broadcast_to(x, (1, ) + x.shape), xs)

    def post_pmap(xs):
        return jax.tree_map(lambda x: x[0], xs)

    return post_pmap(host_psum(pre_pmap(in_tree)))
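A minimal usage sketch (assuming a multi-process JAX program where every process holds partial results; the metric names below are illustrative):

import jax.numpy as jnp

# Each host contributes its local partial sums; every host receives the same totals.
partial_metrics = {"loss_sum": jnp.asarray(12.5), "example_count": jnp.asarray(128.0)}
global_metrics = per_host_sum_pmap(partial_metrics)
mean_loss = global_metrics["loss_sum"] / global_metrics["example_count"]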
Example #2
    def test_with_kwargs(self, fake_pmap, fake_jit):
        with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
            num_devices = len(jax.devices())

            @functools.partial(jax.pmap, axis_size=num_devices)
            @jax.jit
            def foo(x, y):
                return (x * 2) + y

            # pmap over all available devices
            inputs = jnp.array([1, 2])
            inputs = jnp.broadcast_to(inputs, (num_devices, ) + inputs.shape)
            expected = jnp.broadcast_to(jnp.array([3, 6]), (num_devices, 2))

            asserts.assert_tree_all_close(foo(x=inputs, y=inputs), expected)
Example #3
 def testCollectivesWithTreesOfDifferentDtypes(self):
   n = len(jax.devices())
   x = {'a': onp.arange(1 * n * n, 2 * n * n, dtype=onp.float32).reshape([n, n]),
        'b': onp.arange(2 * n * n, 3 * n * n, dtype=onp.int32).reshape([n, n]),
        'c': onp.arange(4 * n * n, 5 * n * n, dtype=onp.float32).reshape([n, n]),
        'd': onp.arange(6 * n * n, 7 * n * n, dtype=onp.int32).reshape([n, n])}
   tree_f = lambda f: partial(tree_util.tree_map, f)
   jax_f = lambda p: pmap(lambda x: p(x, 'i'), 'i')
   onp_f = lambda p: tree_f(lambda x: onp.broadcast_to(p(x, 0), x.shape))
   assert_allclose = partial(tree_util.tree_multimap,
                             partial(self.assertAllClose, check_dtypes=False))
   assert_allclose(jax_f(lax.pmax)(x), onp_f(onp.max)(x))
   assert_allclose(jax_f(lax.pmin)(x), onp_f(onp.min)(x))
   assert_allclose(jax_f(lax.psum)(x), onp_f(onp.sum)(x))
   assert_allclose(jax_f(lax.pmean)(x), onp_f(onp.mean)(x))
Example #4
def multi_device_put(x, devices=None, reuse=True):
  """Memory efficient multi-device replication in JAX.

  Args:
    x: jax DeviceArray or numpy ndarray to be replicated.
    devices: a jax.devices() list or subset thereof of devices to
      replicate onto.  Should match the list passed to any pmaps
      ingesting the replicated array.
    reuse: bool. If x is a DeviceArray whether to reuse its backing
      device_buffer in the resulting ShardedDeviceArray.

  Returns:
    A ShardedDeviceArray with dtype = x.dtype and shape =
    (n_devices,) + x.shape that's backed by replica
    device_buffers on each device.
  """
  # Convert _FilledConstants that don't have device_buffer, etc.
  if type(x) != jax.xla.DeviceArray:  # pylint: disable=unidiomatic-typecheck
    x = np.array(x)
  if not devices:
    devices = jax.devices()
  n_devices = len(devices)
  x_aval = jax.xla.abstractify(x)
  broadcast_x_aval = jax.abstract_arrays.ShapedArray(
      (n_devices,) + x_aval.shape,
      x_aval.dtype)
  if reuse:
    other_device_ordinals = [dv.id for dv in jax.devices()
                             if dv != x.device_buffer.device()]
    broadcast_buffers = ([x.device_buffer,] +
                         [jax.xla.xc.Buffer.from_pyval(x, device=i)
                          for i in other_device_ordinals])
  else:
    broadcast_buffers = [jax.xla.xc.Buffer.from_pyval(x, device=i)
                         for i in range(n_devices)]
  return jax.pxla.ShardedDeviceArray(broadcast_x_aval, broadcast_buffers)
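A usage sketch (this assumes the older JAX release whose jax.xla / jax.pxla internals the helper relies on; the array shape is illustrative):

import jax.numpy as jnp

# Replicate device-resident weights onto every local device; with reuse=True the
# buffer already holding `weights` is reused instead of being copied again.
weights = jnp.zeros((1024, 1024), dtype=jnp.float32)
replicated = multi_device_put(weights)   # shape (n_devices, 1024, 1024)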
Example #5
    def __post_init__(self):
        super().__post_init__()

        n_chains_per_device = self.n_chains_per_rank // len(jax.devices())

        _sampler = MetropolisSampler(
            self.hilbert,
            n_chains_per_rank=n_chains_per_device,
            rule=self.rule,
            n_sweeps=self.n_sweeps,
            reset_chains=self.reset_chains,
            machine_pow=self.machine_pow,
        )

        object.__setattr__(self, "_sampler_device", _sampler)
Example #6
def _to_dev(x, dev):
    if dev is not None:
        if 'cpu' in dev or 'gpu' in dev:
            dev_split = dev.split(':')
            dev_str = dev_split[0]
            if len(dev_split) > 1:
                idx = int(dev_split[1])
            else:
                idx = 0
            x = _jax.device_put(x, _jax.devices(dev_str)[idx])
        else:
            raise Exception(
                'Invalid device specified, must be in the form [ "cpu:idx" | "gpu:idx" ]'
            )
    return x
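For example (assuming _jax is the aliased jax import used by the snippet above):

import jax as _jax
import jax.numpy as jnp

x = jnp.ones((2, 3))
x_cpu = _to_dev(x, 'cpu:0')   # placed on the first CPU device
# _to_dev(x, 'gpu:0') would fail on a machine without a GPU backend.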
Example #7
    def test_unordered_print_with_xmap(self):
        def f(x):
            debug_print("{}", x, ordered=False)

        f = maps.xmap(f,
                      in_axes=['a'],
                      out_axes=None,
                      backend='cpu',
                      axis_resources={'a': 'dev'})
        with maps.Mesh(np.array(jax.devices(backend='cpu')), ['dev']):
            with capture_stdout() as output:
                f(jnp.arange(40))
                jax.effects_barrier()
            lines = [f"{i}\n" for i in range(40)]
            self._assertLinesEqual(output(), "".join(lines))
Example #8
    def __pre_init__(self,
                     *args,
                     n_chains=None,
                     n_chains_per_device=None,
                     **kwargs):
        """
        Constructs a Metropolis Sampler.

        Args:
            hilbert: The hilbert space to sample
            rule: A `MetropolisRule` to generate random transitions from a given state as
                    well as uniform random states.
            n_sweeps: The number of exchanges that compose a single sweep.
                    If None, sweep_size is equal to the number of degrees of freedom being sampled
                    (the size of the input vector s to the machine).
            reset_chains: If False, the state configuration is not reset when reset() is called.
            n_chains: The total number of Markov Chains to be run in parallel on the available devices.
                This will be rounded up to the nearest multiple of `len(jax.devices())`.
            n_chains_per_device: The number of chains to be run in parallel on one device.
                Cannot be specified if n_chains is also specified.
            machine_pow: The power to which the machine should be exponentiated to generate the pdf (default = 2).
            dtype: The dtype of the states sampled (default = np.float32).
        """

        if n_chains is not None and n_chains_per_device is not None:
            raise ValueError(
                "Cannot specify both n_chains and n_chains_per_device")
        elif n_chains is None and n_chains_per_device is None:
            n_chains = 16

        n_devices = len(jax.devices())

        # If chains is specified, round it
        if n_chains is not None:
            n_chains_per_device = int(max(np.ceil(n_chains / n_devices), 1))

        if n_chains is not None and n_chains != n_chains_per_device * n_devices:
            import warnings

            warnings.warn(
                f"Using {n_chains_per_device*n_devices} chains "
                f"({n_chains_per_device} chains on each of {n_devices} devices).",
                category=UserWarning,
            )

        kwargs["n_chains"] = n_chains_per_device * n_devices
        print("he")
        return super().__pre_init__(*args, **kwargs)
Example #9
def local_replica_groups(inner_group_size: int) -> List[List[int]]:
  """Constructs local nearest-neighbor rings given the JAX device assignment.

  For inner_group_size=8, each inner group is a tray with replica order:

  0/1 2/3
  7/6 5/4

  Args:
    inner_group_size: Number of replicas in each group.

  Returns:
    A list of replica id groups.
  """
  world_size = jax.device_count()
  outer_group_size, ragged = divmod(world_size, inner_group_size)
  assert not ragged, 'inner group size must evenly divide global device count'
  # the last device should have maximal x and y coordinate
  def bounds_from_last_device(device):
    x, y, z = device.coords
    return (x + 1) * (device.core_on_chip + 1), (y + 1) * (z + 1)
  global_x, _ = bounds_from_last_device(jax.devices()[-1])
  per_host_x, per_host_y = bounds_from_last_device(jax.local_devices(0)[-1])
  assert inner_group_size in [2 ** i for i in range(1, 15)], \
      'inner group size must be a power of two'
  if inner_group_size <= 4:
    # inner group is Nx1 (core, chip, 2x1)
    inner_x, inner_y = inner_group_size, 1
    inner_perm = range(inner_group_size)
  else:
    if inner_group_size <= global_x * 2:
      # inner group is Nx2 (2x2 tray, 4x2 DF pod host, row of hosts)
      inner_x, inner_y = inner_group_size // 2, 2
    else:
      # inner group covers the full x dimension and must be >2 in y
      inner_x, inner_y = global_x, inner_group_size // global_x
    p = np.arange(inner_group_size)
    per_group_hosts_x = 1 if inner_x < per_host_x else inner_x // per_host_x
    p = p.reshape(inner_y // per_host_y, per_group_hosts_x,
                  per_host_y, inner_x // per_group_hosts_x)
    p = p.transpose(0, 2, 1, 3)
    p = p.reshape(inner_y // 2, 2, inner_x)
    p[:, 1, :] = p[:, 1, ::-1]
    inner_perm = p.reshape(-1)

  inner_replica_groups = [[o * inner_group_size + i for i in inner_perm]
                          for o in range(outer_group_size)]
  return inner_replica_groups
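A usage sketch (TPU-specific, since the helper reads device.coords; the grouped reduction is illustrative):

import functools
import jax
import jax.numpy as jnp

# Restrict psum to each nearest-neighbour ring of 8 replicas.
groups = local_replica_groups(inner_group_size=8)

@functools.partial(jax.pmap, axis_name='i')
def grouped_sum(x):
  return jax.lax.psum(x, 'i', axis_index_groups=groups)

out = grouped_sum(jnp.ones((jax.device_count(), 4)))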
Example #10
def create_hybrid_device_mesh(mesh_shape: Sequence[int],
                              dcn_mesh_shape: Sequence[int],
                              devices: Optional[Sequence[Any]] = None,
                              *,
                              process_is_granule: bool = False) -> np.ndarray:
    """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism.

  Args:
    mesh_shape: shape of the logical mesh for the faster/inner network, ordered
      by increasing network intensity, e.g. [replica, data, mdl] where mdl has
      the most network communication requirements.
    dcn_mesh_shape: shape of the logical mesh for the slower/outer network,
      in the same order as mesh_shape.
    devices: optionally, the devices to construct a mesh for. Defaults to
      jax.devices().
    process_is_granule: if True, this function will treat processes as the units
      of the slower/outer network. Otherwise it will look for slice_index
      attributes on devices and use slices as the units. Enabling this is meant
      as a fallback for platforms (e.g., GPU) that don't set slice_index.

  Returns:
    A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its shape
    that can be fed into jax.experimental.maps.Mesh for hybrid parallelism.
  """
    if devices is None:
        devices = jax.devices()
    attr = 'process_index' if process_is_granule else 'slice_index'
    assert hasattr(devices[0], attr)
    granule_id, granules = 0, []
    while True:
        granule = [dev for dev in devices if getattr(dev, attr) == granule_id]
        if granule:
            granules.append(granule)
            granule_id += 1
        else:
            break
    if np.prod(dcn_mesh_shape) != len(granules):
        raise ValueError(
            'Number of slices must equal the product of dcn_mesh_shape')
    per_granule_meshes = [
        create_device_mesh(mesh_shape, granule) for granule in granules
    ]
    # TODO(jekbradbury): handle non-uniform DCN topologies
    granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape)
    blocks = np.vectorize(lambda i: per_granule_meshes[i],
                          otypes=[object])(granule_mesh)
    device_mesh = np.block(blocks.tolist())
    return device_mesh
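A usage sketch (assuming two DCN-connected slices of eight devices each; the mesh shapes and axis names are illustrative):

# 8-way model parallelism inside each slice, 2-way data parallelism across slices.
devices = create_hybrid_device_mesh(mesh_shape=[1, 8], dcn_mesh_shape=[2, 1])
mesh = maps.Mesh(devices, ('data', 'model'))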
Example #11
def _flatten_HeteroGraphIndex(gidx):
    adj, _ = gidx.adjacency_matrix(
        etype=0,
        transpose=False,
        ctx=jax.devices('cpu')[0],
    )

    srctype, dsttype = gidx.metagraph.find_edge(0)
    num_src = gidx.number_of_nodes(srctype)
    num_dst = gidx.number_of_nodes(dsttype)

    idx = adj.index

    u = idx[0, :]
    v = idx[1, :]
    return ((u, v, num_src, num_dst), None)
Example #12
    def test_pjit_inherits_effects(self):
        if jax.default_backend() not in {'gpu', 'tpu'}:
            raise unittest.SkipTest("pjit only supports GPU and TPU backends")

        def f(x):
            effect_p.bind(effect='foo')
            effect_p.bind(effect='bar')
            return x

        f = pjit.pjit(f,
                      in_axis_resources=pjit.PartitionSpec('x'),
                      out_axis_resources=pjit.PartitionSpec('x'))
        with self.assertRaisesRegex(NotImplementedError,
                                    'Effects not supported'):
            with maps.Mesh(np.array(jax.devices()), ['x']):
                jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
Example #13
  def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    if gpu and jax.default_backend() == "cpu":
      raise unittest.SkipTest("Skipping GPU test case on CPU")
    device = jax.devices("gpu" if gpu else "cpu")[0]
    x = jax.device_put(np, device)
    dlpack = jax.dlpack.to_dlpack(x, take_ownership=take_ownership)
    self.assertEqual(take_ownership, x.device_buffer.is_deleted())
    y = jax.dlpack.from_dlpack(dlpack)
    self.assertEqual(y.device(), device)
    self.assertAllClose(np.astype(x.dtype), y)

    self.assertRaisesRegex(RuntimeError,
                           "DLPack tensor may be consumed at most once",
                           lambda: jax.dlpack.from_dlpack(dlpack))
Example #14
  def test_fake_pmap(self, context, patch, expected_type):
    # We test whether the function has been pmapped by inspecting the type of
    # the function output, if it is a sharded array type then the function has
    # been pmapped
    with context(patch):
      num_devices = len(jax.devices())

      @functools.partial(jax.pmap, axis_size=num_devices)
      def foo(x):
        return x * 2

      # pmap over all available devices
      x = jnp.array([1, 2])
      x = jnp.broadcast_to(x, (num_devices,) + x.shape)
      output = foo(x)
      self.assertEqual(type(output), expected_type)
Example #15
def add_rng_to_examples(
    example_iter,
    base_rng):
  """Add an RNG to each example.

  Args:
    example_iter: Iterator over examples.
    base_rng: RNG to seed with.

  Yields:
    Examples that are tuples (orig_example, rng)
  """
  base_rng = jax.device_put(base_rng, jax.devices("cpu")[0])
  for i, item in enumerate(example_iter):
    rng = jax.random.fold_in(base_rng, i)
    yield dataclasses.replace(item, example=(item.example, rng))
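A usage sketch (the Item dataclass is hypothetical; the function only needs items that dataclasses.replace can update through an `example` field):

import dataclasses
import jax

@dataclasses.dataclass
class Item:
  example: int   # hypothetical payload field

items = (Item(example=i) for i in range(3))
for item in add_rng_to_examples(items, jax.random.PRNGKey(0)):
  original, rng = item.example   # each example is now paired with its own RNG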
Example #16
def _jax_gsddmm(gidx, op, X, Y, lhs_target, rhs_target):
    # out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
    # return out
    if gidx.number_of_etypes() != 1:
        from .base import DGLError
        raise DGLError("We only support gsddmm on graph with one edge type")
    use_lhs = op != 'copy_rhs'
    use_rhs = op != 'copy_lhs'
    expand_lhs, expand_rhs = False, False
    # deal with scalar features.
    if use_lhs:
        if X.ndim == 1:
            X = jnp.expand_dims(X, -1)
            expand_lhs = True
    if use_rhs:
        if Y.ndim == 1:
            Y = jnp.expand_dims(Y, -1)
            expand_rhs = True

    a, _ = gidx.adjacency_matrix(0, False, jax.devices('cpu')[0])
    dst_idxs, src_idxs = a.index
    edge_idxs = jnp.arange(dst_idxs.shape[0])
    idxs_mapping = {
        'u': src_idxs,
        'e': edge_idxs,
        'v': dst_idxs,
        'src': src_idxs,
        'edge': edge_idxs,
        'dst': dst_idxs,
    }

    if X is not None:
        _X = jnp.take(X, idxs_mapping[lhs_target], axis=0)
    else:
        _X = None

    if Y is not None:
        _Y = jnp.take(Y, idxs_mapping[rhs_target], axis=0)
    else:
        _Y = None

    Z = OPS[op](_X, _Y)

    if (expand_lhs or not use_lhs) and (expand_rhs or not use_rhs):
        Z = jnp.expand_dims(Z, -1)

    return Z
Example #17
    def test_pmap_basic(self):
        if len(jax.devices()) < 2:
            raise unittest.SkipTest("requires at least 2 devices")

        @jax.pmap
        def f(x1, x2):
            y1 = jnp.sin(x1)
            y2 = jnp.sin(x2)
            return y1 + y2

        xs = jnp.array([0., 2.])
        err, _ = checkify.checkify(f)(xs, xs)
        self.assertIs(err.get(), None)

        ys = jnp.array([3., jnp.inf])
        err, _ = checkify.checkify(f)(xs, ys)
        self.assertStartsWith(err.get(), 'nan generated by primitive sin')
Example #18
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    logging.info('JAX devices: %s', jax.devices())

    grids = np.arange(-256, 257) * 0.08
    external_potential = utils.get_atomic_chain_potential(
        grids=grids,
        locations=np.array([-0.8, 0.8]),
        nuclear_charges=np.array([1., 1.]),
        interaction_fn=utils.exponential_coulomb)

    density, total_eigen_energies, _ = scf.solve_noninteracting_system(
        external_potential, num_electrons=FLAGS.num_electrons, grids=grids)
    logging.info('density: %s', density)
    logging.info('total energy: %f', total_eigen_energies)
Example #19
  def wrapper(*args: pytypes.ArrayTree, **kwargs: pytypes.ArrayTree):
    if kwargs and (in_axes != 0 or static_argnums):
      raise ValueError("Do not use kwargs with `in_axes` or `static_argnums` "
                       "in pmapped function.")
    devices_ = list(devices or jax.devices(backend))
    n_devices_ = n_devices or len(devices_)
    devices_ = devices_[:n_devices_]
    if len(devices_) != n_devices_:
      raise ValueError("Number of available devices is less than required for "
                       f"test ({len(devices_)} < {n_devices_})")

    bcast_fn = lambda x: jnp.broadcast_to(x, (n_devices_,) + jnp.array(x).shape)
    if broadcast_args_to_devices:
      args = [
          tree_map(bcast_fn, arg) if idx not in static_argnums else arg
          for idx, arg in enumerate(args)
      ]
      kwargs = tree_map(bcast_fn, kwargs)
    else:
      # Pmappable axes size must be equal to number of devices.
      in_axes_ = in_axes if isinstance(in_axes,
                                       (tuple, list)) else [in_axes] * len(args)
      is_pmappable_arg = [
          idx not in static_argnums and in_axes_[idx] is not None
          for idx in range(len(args))
      ]
      for is_pmappable, arg in zip(is_pmappable_arg, args):
        if not is_pmappable:
          continue
        if not all(x.shape[0] == n_devices_ for x in jax.tree_leaves(arg)):
          shapes = tree_map(jnp.shape, arg)
          raise ValueError(
              f"Pmappable arg axes size must be equal to number of devices, "
              f"got: {shapes} (expected the first dim to be {n_devices_}). "
              "Consider setting `broadcast_args_to_devices=True`.")

    res = jax.pmap(
        fn,
        axis_name=axis_name,
        devices=devices_,
        in_axes=in_axes,
        static_broadcasted_argnums=static_argnums,
        backend=backend)(*args, **kwargs)

    return reduce_fn(res)
Example #20
def localfoe(signal, frame_size=16384, frame_step=5000, sps=1, fitkind=None, degree=2,
             method=lambda x: foe_mpowfftmax(x)[0]):
    '''
    resolution = samplerate / N / 4 / sps (linear interp.)
    '''
    cpus = jax.devices('cpu')

    # [BUG]: polyfit is buggy in GPU
    y = device_put(signal, cpus[0])
    dims = y.shape[-1]
    fo_local = xop.framescaninterp(y, method, frame_size, frame_step, sps)
    if fitkind is not None:
        if fitkind.lower() == 'poly':
            fo_T = jnp.tile(jnp.arange(fo_local.shape[0])[:, None], (1, dims))
            fo_local = exp.polyfitval(fo_T, fo_local, degree)
        else:
            raise ValueError('invalid fitting method')
    return fo_local
Example #21
    def load_pickle(inname, metric=None, device=jax.devices()[0]):
        loaded_som = pickle.load(open(inname, 'rb'))
        if metric is None:
            print(
                'WARNING : loading a JAX SOM object without its metric results in the cdist metric being used '
                'because jitted functions can not be pickled. If you want to use another metric, please '
                'specify it in the SOM.load_pickle function')
            metric = jax_cdist
        loaded_som.device = device
        loaded_som.metric = metric

        # We need to manually choose which arrays are to be put, otherwise every numpy item ends up a jax object.
        loaded_som.centroids = jax.device_put(loaded_som.centroids,
                                              device=device)
        loaded_som.locations = jax.device_put(loaded_som.locations,
                                              device=device)
        loaded_som.distance_mat = jax.device_put(loaded_som.distance_mat,
                                                 device=device)
        return loaded_som
Example #22
def main(argv):
    del argv

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX host: %d / %d", jax.host_id(), jax.host_count())
    logging.info("JAX devices: %r", jax.devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Borg task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f"host_id: {jax.host_id()}, host_count: {jax.host_count()}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, "workdir")

    state = train_and_evaluate(FLAGS.config, FLAGS.workdir)
    del state
Example #23
    def learner(
        self,
        random_key: networks_lib.PRNGKey,
        replay: reverb.Client,
        counter: counting.Counter,
    ):
        """The Learning part of the agent."""

        iterator = self._builder.make_dataset_iterator(replay)

        dummy_seed = 1
        environment_spec = (self._environment_spec
                            or specs.make_environment_spec(
                                self._environment_factory(dummy_seed)))

        # Creates the networks to optimize (online) and target networks.
        networks = self._network_factory(environment_spec)

        if self._prefetch_size > 1:
            # When working with single GPU we should prefetch to device for
            # efficiency. If running on TPU this isn't necessary as the computation
            # and input placement can be done automatically. For multi-gpu currently
            # the best solution is to pre-fetch to host although this may change in
            # the future.
            device = jax.devices()[0] if self._device_prefetch else None
            iterator = utils.prefetch(iterator,
                                      buffer_size=self._prefetch_size,
                                      device=device)
        else:
            logging.info('Not prefetching the iterator.')

        counter = counting.Counter(counter, 'learner')
        learner = self._builder.make_learner(random_key, networks, iterator,
                                             replay, counter)

        return savers.CheckpointingRunner(
            learner,
            key='learner',
            subdirectory='learner',
            time_delta_minutes=5,
            directory=self._checkpointing_config.directory,
            add_uid=self._checkpointing_config.add_uid,
            max_to_keep=self._checkpointing_config.max_to_keep)
Example #24
def main():
    args = parser.parse_args()
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()), flush=True)

    def _try_validate(args):
        res = None
        batch_size = args.batch_size
        while res is None:
            try:
                print(f'Setting validation batch size to {batch_size}')
                args.batch_size = batch_size
                res = validate(args)
            except RuntimeError as e:
                if batch_size <= 1:
                    print("Validation failed with no ability to reduce batch size. Exiting.")
                    raise e
                batch_size = max(batch_size // 2, 1)
                print("Validation failed, reducing batch size by 50%")
        return res

    if get_model_cfg(args.model) is not None:
        _try_validate(args)
    else:
        models = list_models(pretrained=True)
        if args.model != 'all':
            models = fnmatch.filter(models, args.model)
        if not models:
            print(f'ERROR: No models found to validate with pattern {args.model}.')
            exit(1)

        print('Validating:', ', '.join(models))
        results = []
        start_batch_size = args.batch_size
        for m in models:
            args.batch_size = start_batch_size  # reset in case reduced for retry
            args.model = m
            res = _try_validate(args)
            res.update(dict(model=m))
            results.append(res)
        print('Results:')
        for r in results:
            print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")
Example #25
def train(model, g, feats, y_true, train_idx, optimizer):
    g = g.to(jax.devices()[0])

    @jax.jit
    def loss_fn(param, y_true=y_true):
        out = model.apply(param, g, feats)[train_idx]
        y_true = y_true[train_idx].flatten()

        y_true = jax.nn.one_hot(y_true, 40)
        loss = jnp.mean(-out * y_true)
        return loss

    # grad = jax.jacfwd(loss_fn)(optimizer.target)
    # loss = loss_fn(optimizer.target)

    loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)

    optimizer = optimizer.apply_gradient(grad)
    return optimizer, loss
Example #26
  def _pjit(inp):
    if isinstance(inp, GlobalDeviceArray):
      if inp.is_fully_replicated:
        return inp.local_data(0).to_py()
      global_mesh = inp.mesh
      in_axis_resources = FROM_GDA
    else:
      # DA/SDA/np.array will be sharded based on global_mesh.local_mesh.
      # Shape of local_mesh will always be (1, local_device_count())
      devices = np.array(jax.devices()).reshape(jax.process_count(),
                                                jax.local_device_count())
      global_mesh = maps.Mesh(devices, ('processes', 'local_devices'))
      in_axis_resources = P('processes')
      if inp.ndim == 0 or not tiled:
        inp = np.expand_dims(inp, axis=0)

    with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
      out = pjit(lambda x: x, in_axis_resources=in_axis_resources,
                 out_axis_resources=None)(inp)
    return out.local_data(0).to_py()
Example #27
def init_fn(master_rng, data, init_fn, optimizer):
    out_rng, init_rng = jax.random.split(master_rng)

    # copy the same initial params to each accelerator
    init_rng = jnp.broadcast_to(init_rng, (jax.local_device_count(),) + init_rng.shape)
    params = jax.pmap(init_fn)(init_rng, data)

    cpu_device = jax.devices("cpu")[0]

    # place optimizer state on CPU
    cpu_params = jax.tree_map(lambda x: jax.device_put(x[0], device=cpu_device), params)
    opt_state = optimizer.init(cpu_params)

    return dict(
        step=np.array(0),
        rng=out_rng,
        opt_state=opt_state,
        grad_acc=jax.tree_map(jnp.zeros_like, params),
        grad_count=np.array(0),
        params=params)
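A usage sketch (assuming an optax-style optimizer, since the function only calls optimizer.init; model_init and the batch shape are placeholders):

import optax
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
data_batch = jnp.zeros((n_dev, 8, 32))   # one shard per local device
model_init = lambda rng, x: {'w': jax.random.normal(rng, (32, 4))}
state = init_fn(jax.random.PRNGKey(0), data_batch, model_init, optax.sgd(1e-3))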
Example #28
def generate_graph(idtype, grad=False):
    '''
    s, d, eid
    0, 1, 0
    1, 9, 1
    0, 2, 2
    2, 9, 3
    0, 3, 4
    3, 9, 5
    0, 4, 6
    4, 9, 7
    0, 5, 8
    5, 9, 9
    0, 6, 10
    6, 9, 11
    0, 7, 12
    7, 9, 13
    0, 8, 14
    8, 9, 15
    9, 0, 16
    '''
    u = F.tensor([0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 9])
    v = F.tensor([1, 9, 2, 9, 3, 9, 4, 9, 5, 9, 6, 9, 7, 9, 8, 9, 0])
    g = dgl.graph((u, v), idtype=idtype)
    assert g.device == F.ctx()
    ncol = F.randn((10, D))
    ecol = F.randn((17, D))
    if grad:
        ncol = F.attach_grad(ncol)
        ecol = F.attach_grad(ecol)

    g.ndata['h'] = ncol
    g.edata['w'] = ecol
    g.set_n_initializer(dgl.init.zero_initializer)
    g.set_e_initializer(dgl.init.zero_initializer)

    if dgl.backend.backend_name == "jax":
        import jax
        g = g.to(jax.devices()[0], )

    return g
Example #29
    def test_with_partial(self, fake_pmap, fake_jit):
        with fake.fake_pmap_and_jit(fake_pmap, fake_jit):
            num_devices = len(jax.devices())

            # Testing a common use-case where non-parallel arguments are partially
            # applied before pmapping
            def foo(x, y, flag):
                return (x * 2) + y if flag else (x + y)

            foo = functools.partial(foo, flag=True)

            foo = jax.pmap(foo, axis_size=num_devices)
            foo = jax.jit(foo)

            # pmap over all available devices
            inputs = jnp.array([1, 2])
            inputs = jnp.broadcast_to(inputs, (num_devices, ) + inputs.shape)
            expected = jnp.broadcast_to(jnp.array([3, 6]), (num_devices, 2))

            asserts.assert_tree_all_close(foo(inputs, inputs), expected)
            asserts.assert_tree_all_close(foo(x=inputs, y=inputs), expected)
Example #30
  def test_pmap_and_jit(self, context, fake_pmap, fake_jit, expected_type,
                        expected_execution_count):
    python_execution_count = 0
    with context(fake_pmap, fake_jit):
      num_devices = len(jax.devices())
      @functools.partial(jax.pmap, axis_size=num_devices)
      @jax.jit
      def foo(x):
        nonlocal python_execution_count
        python_execution_count += 1
        return x * 2

      # pmap over all available devices
      inputs = jnp.array([1, 2])
      inputs = jnp.broadcast_to(inputs, (num_devices,) + inputs.shape)
      output = foo(inputs)
      self.assertEqual(type(output), expected_type)
      self.assertEqual(python_execution_count, 1)

      foo(inputs)
      self.assertEqual(python_execution_count, expected_execution_count)