Shardy Guide for JAX Users

Shardy is a new propagation system being introduced into the XLA stack, and below we want to introduce JAX users to:

  1. What has changed in JAX
  2. Why Shardy?
  3. Future plans

This is meant for JAX users who use jax.jit to run training/inference models across more than one GPU or TPU (batch parallelism, megatron, ZeRO, etc.). They would be using things like PartitionSpec and NamedSharding.

1. What has changed in JAX?

State of JAX before: GSPMD

Prior to Shardy, JAX users who partitioned their models across multiple devices used GSPMD behind the scenes.

GSPMD is the propagation+partitioning system that lives in the middle of the XLA pipeline. It operates on HLO - the IR that comes after StableHLO (the program you get after running jax.jit.lower).

JAX doesn't run GSPMD directly, but encodes instructions into the StableHLO IR for GSPMD to read later on.

But before we go any further, let's introduce our working example.

Make sure you have installed jax>=0.4.35.

pip install "jax>=0.4.35"

Imports and utilities
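
The snippets in this guide assume roughly the following imports (a minimal sketch of the setup cell; only the names used below are listed):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.experimental.shard_map import shard_map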

First, let's create our mesh.

mesh = Mesh(
    np.reshape(np.array(jax.devices()), (4, 2)),
    ('data', 'model'))

print(mesh.shape)
OrderedDict([('data', 4), ('model', 2)])

In/Out shardings

Let's look at what changed the most: how sharding attributes are encoded in the JAX program for the compiler to read.

Let's look at it through an example: an MLP-like model with no bias tensors and two layers (two matmuls).

def predict(x, w1, w2):
  x = jnp.tanh(x)
  z1 = jnp.einsum('ij,jk->ik', x, w1)
  z2 = jnp.einsum('ij,jk->ik', z1, w2)
  return jnp.sin(z2)

What we want to do here, sharding-wise, is:

  1. data parallelism on x
  2. tensor parallelism on w1 and w2 through the megatron sharding strategy (see the sketch after this list).
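
Concretely, the shardings we are aiming for look roughly like the following sketch (the *_spec names are just for illustration; the spec for w2 is what we expect propagation to produce rather than something we specify below):

x_spec  = PartitionSpec('data', None)    # data parallelism: batch dim sharded across 'data'
w1_spec = PartitionSpec(None, 'model')   # megatron column-parallel: output features on 'model'
w2_spec = PartitionSpec('model', None)   # megatron row-parallel: what propagation should infer for w2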

Now let's prepare the model for GSPMD sharding. Note that we will explicitly shard w1, but let GSPMD propagation shard w2.

def run_in_out_shardings():
  samples = jax.ShapeDtypeStruct((16, 128), jnp.float32, sharding=NamedSharding(mesh, PartitionSpec('data', None)))
  samples_sharding = NamedSharding(mesh, PartitionSpec('data', None))
  w1 = jax.ShapeDtypeStruct((128, 256), jnp.float32, sharding=NamedSharding(mesh, PartitionSpec(None, 'model')))
  w1_sharding = NamedSharding(mesh, PartitionSpec(None, 'model'))
  w2 = jax.ShapeDtypeStruct((256, 10), jnp.float32)
  w2_sharding = None

  print(jax.jit(predict, in_shardings=(samples_sharding, w1_sharding, w2_sharding)).lower(samples, w1, w2).as_text())

run_in_out_shardings()
module @jit_predict attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<16x128xf32> {mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"}, %arg1: tensor<128x256xf32> {mhlo.sharding = "{devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate}"}, %arg2: tensor<256x10xf32>) -> (tensor<16x10xf32> {jax.result_info = ""}) {
    %0 = stablehlo.tanh %arg0 : tensor<16x128xf32>
    %1 = stablehlo.dot_general %0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x128xf32>, tensor<128x256xf32>) -> tensor<16x256xf32>
    %2 = stablehlo.dot_general %1, %arg2, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x256xf32>, tensor<256x10xf32>) -> tensor<16x10xf32>
    %3 = stablehlo.sine %2 : tensor<16x10xf32>
    return %3 : tensor<16x10xf32>
  }
}

GSPMD's sharding annotations look like the following:

JAX sharding                                      | GSPMD sharding
--------------------------------------------------|-------------------------------------------------------
NamedSharding(mesh, PartitionSpec('data', None))  | {devices=[4,1,2]<=[8] last_tile_dim_replicate}
NamedSharding(mesh, PartitionSpec(None, 'model')) | {devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate}
None                                              | nothing

None means no sharding, as expected, since GSPMD will populate it during sharding propagation.

Notice how all the axis names go away? While there is a 1:1 correspondence between a NamedSharding and a GSPMD sharding, the GSPMD form is hard to read, and it only gets harder once the mesh has more axes.

Let's look at Shardy for comparison. To enable Shardy in JAX, simply enable the flag:

jax.config.update("jax_use_shardy_partitioner", True)
run_in_out_shardings()
jax.config.update("jax_use_shardy_partitioner", False)
module @jit_predict attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["data"=4, "model"=2]>
  func.func public @main(%arg0: tensor<16x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"data"}, {}]>}, %arg1: tensor<128x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"model"}]>}, %arg2: tensor<256x10xf32>) -> (tensor<16x10xf32> {jax.result_info = ""}) {
    %0 = stablehlo.tanh %arg0 : tensor<16x128xf32>
    %1 = stablehlo.dot_general %0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x128xf32>, tensor<128x256xf32>) -> tensor<16x256xf32>
    %2 = stablehlo.dot_general %1, %arg2, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x256xf32>, tensor<256x10xf32>) -> tensor<16x10xf32>
    %3 = stablehlo.sine %2 : tensor<16x10xf32>
    return %3 : tensor<16x10xf32>
  }
}

Now we have

JAX sharding                                      | Shardy sharding
--------------------------------------------------|----------------------------------------
NamedSharding(mesh, PartitionSpec('data', None))  | #sdy.sharding<@mesh, [{"data"}, {}]>
NamedSharding(mesh, PartitionSpec(None, 'model')) | #sdy.sharding<@mesh, [{}, {"model"}]>
None                                              | nothing

Shardy's representation is much closer to JAX's NamedSharding, so when you look at a file dump of your program after propagation, it is much easier to understand what is going on.

Note that instead of the total devices/axes living on each sharding, they live on a top-level @mesh value.
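
For reference, that top-level mesh is a direct encoding of the JAX mesh we created earlier:

# JAX:    Mesh(np.reshape(np.array(jax.devices()), (4, 2)), ('data', 'model'))
# Shardy: sdy.mesh @mesh = <["data"=4, "model"=2]>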

jax.lax.with_sharding_constraint

GSPMD currently lowers it to a custom call:

def run_with_sharding_constraint():
  x = jax.ShapeDtypeStruct((32, 64), jnp.float32)

  def f(x):
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, PartitionSpec('data', PartitionSpec.UNCONSTRAINED)))

  print(jax.jit(f).lower(x).as_text())

run_with_sharding_constraint()
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "unspecified_dims=[1]", mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}

But under Shardy it's an explicit op:

jax.config.update("jax_use_shardy_partitioner", True)
run_with_sharding_constraint()
jax.config.update("jax_use_shardy_partitioner", False)
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["data"=4, "model"=2]>
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = sdy.sharding_constraint %arg0 <@mesh, [{"data"}, {?}]> : tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}

Note that under GSPMD, UNCONSTRAINED shows up on the custom call as the op attribute backend_config = "unspecified_dims=[1]", while under Shardy it makes dim 1 be {?}. In Shardy, a dimension sharding without a ? is closed, meaning that dimension can't be sharded further; with a trailing ?, it can be sharded further. Refer to Sharding representation for more info.
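
As a rough illustration of the notation (these dimension shardings are written in isolation, not taken from the dump above):

# {"data"}     -- closed: this dim is sharded on "data" and propagation can't shard it further
# {"data", ?}  -- open: sharded on "data", and propagation may shard it further
# {?}          -- open and currently unsharded; this is what UNCONSTRAINED becomes
# {}           -- closed and unsharded, i.e. replicated along this dim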

jax.experimental.shard_map

Under GSPMD this is lowered to a few different custom calls, with various shard_map-specific attributes on the GSPMD shardings. (An axis that is auto, rather than manual, is free to be used by sharding constraints inside the body of the shard_map; in the example below, both mesh axes are manual.)

def run_shard_map():
  x = jax.ShapeDtypeStruct((32, 64), jnp.float32)

  def body(x):
    return jax.lax.all_gather(x, 'data', tiled=True)

  shmaped_f = shard_map(
        body,
        mesh=mesh,
        in_specs=(jax.sharding.PartitionSpec('data',),),
        out_specs=jax.sharding.PartitionSpec(),
        check_rep=False)

  print(jax.jit(shmaped_f).lower(x).as_text())

run_shard_map()
module @jit_body attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x64xf32>) -> tensor<8x64xf32>
    %2 = call @shmap_body(%1) : (tensor<8x64xf32>) -> tensor<32x64xf32>
    %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    return %4 : tensor<32x64xf32>
  }
  func.func private @shmap_body(%arg0: tensor<8x64xf32>) -> (tensor<32x64xf32> {jax.result_info = "[None, None]"}) {
    %0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<8x64xf32>) -> tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}

With the custom calls and GSPMD sharding, it's getting pretty confusing. Let's look at what Shardy gives:

jax.config.update("jax_use_shardy_partitioner", True)
run_shard_map()
jax.config.update("jax_use_shardy_partitioner", False)
module @jit_body attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["data"=4, "model"=2]>
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"data"}, {}]>] out_shardings=[<@mesh, [{}, {}]>] manual_axes={"data", "model"} (%arg1: tensor<8x64xf32>) {
      %1 = "stablehlo.all_gather"(%arg1) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<8x64xf32>) -> tensor<32x64xf32>
      sdy.return %1 : tensor<32x64xf32>
    } : (tensor<32x64xf32>) -> tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}

We now:

  • Have a single op called sdy.manual_computation which holds:
    • the in_specs
    • the out_specs
    • the body of the shard_map
    • the inverse of the auto axes which we call manual_axes

A lot easier to read!

jax.experimental.custom_partitioning

With GSPMD, to use custom_partitioning we define two routines, propagate_user_sharding and infer_sharding_from_operands, which may traverse the jaxpr to return the shardings for the operands and results. With Shardy, we instead provide a sharding_rule, an Einsum-like notation string that specifies the sharding rule. Here is an example where the routine we custom-partition implements a batched matrix multiplication.

We use a device mesh of shape (2M, M) to compute a matmul of the form (..., 4N, 2N) x (..., 2N, 4N). Notice that instead of hard-coding the device mesh and the matrix shapes, we introduce two parameters, M and N, for specifying them, so that you can change these values to fit your purpose.

We first perform the needed setup and define the partition routine as we would do with GSPMD.

from functools import partial
from jax.experimental.custom_partitioning import custom_partitioning, SdyShardingRule, BATCHING

jax.config.update("jax_use_shardy_partitioner", True)

def partition(mesh, arg_shapes, result_shape):
  arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes)
  result_sharding = result_shape.sharding
  rank=len(arg_shapes[0].shape)

  def lower_fn(x, y):
    axis_name = arg_shardings[1].spec[rank-2][0]
    i = jax.lax.axis_index(axis_name)
    z = jax.lax.psum(jax.lax.dynamic_slice_in_dim(
        jax.lax.dynamic_slice_in_dim(x, i * 0, N, axis=rank-2), i * N, N, axis=rank-1) @ y,
        (axis_name))
    return z

  return mesh, lower_fn, (result_sharding), arg_shardings

@partial(custom_partitioning)
def f(x, y):
  return jnp.matmul(x, y)

Then, we invoke the def_partition API. Note that instead of providing two callbacks for the parameters infer_sharding_from_operands and propagate_user_sharding, as we would with GSPMD, we provide a sharding_rule parameter: an Einsum-like notation string similar to the subscripts in jnp.einsum("...ij, ...jk->...ik", x, y), as if jnp.einsum were extended to support ... for representing leading batching dimensions. From the einops.rearrange string notation we borrow the space separator between factors (to allow multi-letter factor names) and the omission of a size for any factor that always represents a whole tensor dimension. We also support rank polymorphism by allowing a leading ... in each tensor's dimension representation to stand for any number of leading dimensions.

f.def_partition(
  infer_sharding_from_operands=None,
  propagate_user_sharding=None,
  partition=partition,
  sharding_rule="... i j, ... j k -> ... i k")

Alternatively, we can create an equivalent SdyShardingRule object for the sharding_rule parameter. See the Shardy documentation on sharding rules for more details.

f.def_partition(
  infer_sharding_from_operands=None,
  propagate_user_sharding=None,
  partition=partition,
  sharding_rule=SdyShardingRule(operand_mappings=((BATCHING, 'i', 'j'), (BATCHING, 'j', 'k')), result_mappings=((BATCHING, 'i', 'k'),)))

The sharding_rule parameter can also take a callback function generating either a string or an SdyShardingRule object.

def sharding_rule_producer(mesh, arg_shapes, result_shape):
  rank = len(arg_shapes[0].shape)
  leading_axes = ""
  for i in range(rank - 2):
    leading_axes += f" b{i}"
  return f"{leading_axes} i j, {leading_axes} j k -> {leading_axes} i k"


f.def_partition(
  partition=partition,
  sharding_rule=sharding_rule_producer)

Lastly, we create a mesh, define the input matrices x and y, run the jitted f, and compare the results produced by the unjitted and the jitted f.

N = 4
M = 2
num_devices = 2 * M * M

devices = np.array(list(jax.devices())[:num_devices])
if devices.size < num_devices:
  raise ValueError(f'Requires {num_devices} devices')
device_mesh = Mesh(devices.reshape((2 * M, M)), ('x', 'y'))

sharding_x = NamedSharding(device_mesh, PartitionSpec(None, None, 'x'))
sharding_y = NamedSharding(device_mesh, PartitionSpec(None, None, 'y'))
jitted_f = jax.jit(f, in_shardings=(sharding_x, sharding_y), out_shardings=sharding_x)

x = np.asarray(np.random.randint(0, 20, (2, 3, 4*N, 2*N)), dtype=np.float32)
y = np.asarray(np.random.randint(0, 20, (2, 3, 2*N, 4*N)), dtype=np.float32)

result = f(x, y)

with device_mesh:
  jitted_result = jitted_f(x, y)

for i in range(num_devices):
  j = (i // M) * N
  assert((np.asarray(jitted_result.addressable_shards[i].data) == result[:,:,j:j+N,:]).all())

Auto partitioners

In progress.

XLA_DUMP_TO

When specifying XLA_DUMP_TO, you will see an additional shardy/ directory containing various dumps of the StableHLO program. Most of them are currently only relevant to the Shardy team for debugging issues. The one to focus on when debugging is sdy_module_after_sdy_export.mlir, which is the module after propagation has finished on the StableHLO program.
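
For example, one common way to get such a dump is through XLA's dump flag (a sketch only; the path is an illustration, and your environment may plumb the flag differently):

import os

# Assumed setup: point XLA's dump output at a directory before anything is compiled.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_dump_to=/tmp/xla_dump"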

2. Why Shardy?

Readability

As seen above, it's much easier to read the shardings and shard_maps and understand how they match what is happening in the JAX code. Similarly, GSPMD propagation gives back HLO code, not the MLIR that both Shardy and jax.jit.lower return.

Interpretability

We are planning on exposing a feature we call "user priorities" (not in JAX yet!). It lets you attach a value telling Shardy how important a tensor dimension's sharding is relative to other constraints in the program.

Higher priorities are defined as lower values (the lowest being 0; think of it as a P0 priority task).

PartitionSpec(None, 'x', 'y', priorities=(None, 0, 1))

Here the sharding of dim 1 on x has a higher priority than that of dim 2 on y, so dim 1 will be propagated through the program first and then dim 2; any potential sharding conflicts are explicitly avoided by having x propagated first.

This can also be helpful for debugging models, since it lets you break your sharding strategy down into separate rounds of propagation in Shardy. For example:

  • Priority 0: data parallelism
  • Priority 1: megatron
  • Priority 2: ZeRO sharding

FAQs

Below is a list of questions you may have on various JAX features and capabilities.

JAX Sharding types

What about GSPMDSharding?

GSPMDSharding is closely tied to the C++/protobuf representation inside the XLA compiler. As such the type itself won't be supported.

What about PositionalSharding?

This won't be supported. Instead use a NamedSharding with device_ids.
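
A minimal sketch of such a replacement, assuming 8 devices as in the rest of this guide (the axis names and device ordering here are only illustrative):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Order the devices explicitly (as PositionalSharding let you do), then name the mesh axes.
devices = np.array(jax.devices()).reshape(4, 2)
sharding = NamedSharding(Mesh(devices, ('i', 'j')), PartitionSpec('i', 'j'))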

What about PmapSharding?

This won't be supported. Shardy is meant for jax.jit, not jax.pmap.

Propagation Questions

Section for questions about what you may see during propagation.

What are split axes in Shardy, aka "x":(2)2?

Refer to Axis splitting and sub-axes.
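
As a rough illustration (the axis size here is an assumption; see the linked doc for the precise definition):

# For a mesh axis "x" of size 4, Shardy can refer to sub-axes, written "x":(pre_size)size.
#   "x":(1)2  -- a sub-axis of size 2 at the start of "x" (pre-size 1)
#   "x":(2)2  -- the size-2 sub-axis that comes after the first one (pre-size 2)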