Shardy is a new propagation system being introduced into the XLA stack, and below we want to introduce JAX users to:
- What has changed in JAX
- Why Shardy?
- Future plans
This is meant for JAX users who use jax.jit for running training/inference models across more than 1 GPU or TPU (batch parallelism, megatron, ZeRO, etc). They would be using things like PartitionSpecs and NamedShardings.
1. What has changed in JAX?
State of JAX before: GSPMD
Prior to Shardy, JAX users who partitioned their models across multiple devices used GSPMD behind the scenes.
GSPMD is the propagation+partitioning system that lives in the middle of the XLA pipeline. It operates on HLO, the IR that comes after StableHLO (the program you get after running jax.jit.lower).
JAX doesn't run GSPMD directly, but encodes instructions into the StableHLO IR for GSPMD to read later on.
But before we go any further, let's introduce our working example.
Make sure you have installed jax>=0.4.35.
pip install jax==0.4.35
Imports and utilities
import os
# make sure our code runs on 8 devices
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
import jax
import numpy as np
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.experimental.shard_map import shard_map
First, let's create our mesh.
mesh = Mesh(
np.reshape(np.array(jax.devices()), (4, 2)),
('data', 'model'))
print(mesh.shape)
OrderedDict([('data', 4), ('model', 2)])
In/Out shardings
Let's look at what changed the most: how sharding attributes are encoded in the JAX program for the compiler to read.
Let's look at it through an example: an MLP-like model with two layers (two matmuls) and no bias tensors.
def predict(x, w1, w2):
  x = jnp.tanh(x)
  z1 = jnp.einsum('ij,jk->ik', x, w1)
  z2 = jnp.einsum('ij,jk->ik', z1, w2)
  return jnp.sin(z2)
What we will want to do here, sharding-wise, is:
- data parallelism on x
- tensor parallelism on w1 and w2 through the megatron sharding strategy
Now let's prepare the model for GSPMD sharding. Note that we will explicitly shard w1, but let GSPMD propagation shard w2.
def run_in_out_shardings():
  samples = jax.ShapeDtypeStruct((16, 128), jnp.float32, sharding=NamedSharding(mesh, PartitionSpec('data', None)))
  samples_sharding = NamedSharding(mesh, PartitionSpec('data', None))
  w1 = jax.ShapeDtypeStruct((128, 256), jnp.float32, sharding=NamedSharding(mesh, PartitionSpec(None, 'model')))
  w1_sharding = NamedSharding(mesh, PartitionSpec(None, 'model'))
  w2 = jax.ShapeDtypeStruct((256, 10), jnp.float32)
  w2_sharding = None
  print(jax.jit(predict, in_shardings=(samples_sharding, w1_sharding, w2_sharding)).lower(samples, w1, w2).as_text())
run_in_out_shardings()
module @jit_predict attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<16x128xf32> {mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"}, %arg1: tensor<128x256xf32> {mhlo.sharding = "{devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate}"}, %arg2: tensor<256x10xf32>) -> (tensor<16x10xf32> {jax.result_info = ""}) {
    %0 = stablehlo.tanh %arg0 : tensor<16x128xf32>
    %1 = stablehlo.dot_general %0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x128xf32>, tensor<128x256xf32>) -> tensor<16x256xf32>
    %2 = stablehlo.dot_general %1, %arg2, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x256xf32>, tensor<256x10xf32>) -> tensor<16x10xf32>
    %3 = stablehlo.sine %2 : tensor<16x10xf32>
    return %3 : tensor<16x10xf32>
  }
}
GSPMD's sharding annotations look like the following:
JAX sharding | GSPMD sharding
---|---
NamedSharding(mesh, PartitionSpec('data', None)) | {devices=[4,1,2]<=[8] last_tile_dim_replicate}
NamedSharding(mesh, PartitionSpec(None, 'model')) | {devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate}
None | nothing
None means no sharding, as expected, since GSPMD will populate it during sharding propagation.
Notice how all the axis names go away? While there is a 1:1 correspondence between a NamedSharding and a GSPMD sharding, the GSPMD string can be difficult to read, and it only gets harder once you introduce more axis names.
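If you do need to read one of these strings, the device list can be reconstructed by hand. Below is an informal sketch of our own (not an official API) of how to decode the two iota-style annotations above using numpy.

# Informal sketch: decoding the GSPMD sharding strings above by hand.
# "<=[8]" means the device list is iota over the 8 devices; "<=[4,2]T(1,0)" means
# reshape that iota to [4,2], transpose it, then flatten. The leading "devices=[...]"
# is the tile shape, and "last_tile_dim_replicate" marks the last axis as replication.
print(np.arange(8).reshape(4, 1, 2))                  # {devices=[4,1,2]<=[8] last_tile_dim_replicate}
print(np.arange(8).reshape(4, 2).T.reshape(1, 2, 4))  # {devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate}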
Let's look at Shardy for comparison. To enable Shardy in JAX, simply enable the flag:
jax.config.update("jax_use_shardy_partitioner", True)
run_in_out_shardings()
jax.config.update("jax_use_shardy_partitioner", False)
module @jit_predict attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["data"=4, "model"=2]>
  func.func public @main(%arg0: tensor<16x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"data"}, {}]>}, %arg1: tensor<128x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"model"}]>}, %arg2: tensor<256x10xf32>) -> (tensor<16x10xf32> {jax.result_info = ""}) {
    %0 = stablehlo.tanh %arg0 : tensor<16x128xf32>
    %1 = stablehlo.dot_general %0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x128xf32>, tensor<128x256xf32>) -> tensor<16x256xf32>
    %2 = stablehlo.dot_general %1, %arg2, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x256xf32>, tensor<256x10xf32>) -> tensor<16x10xf32>
    %3 = stablehlo.sine %2 : tensor<16x10xf32>
    return %3 : tensor<16x10xf32>
  }
}
Now we have
JAX sharding | Shardy sharding
---|---
NamedSharding(mesh, PartitionSpec('data', None)) | #sdy.sharding<@mesh, [{"data"}, {}]>
NamedSharding(mesh, PartitionSpec(None, 'model')) | #sdy.sharding<@mesh, [{}, {"model"}]>
None | nothing
Shardy's representation is a lot closer to what JAX NamedShardings look like. So when looking at a file dump of your program after propagation, it will be a lot easier to understand what is going on, since the correspondence to JAX is much closer.
Note that instead of the total devices/axes living on each sharding, they live on a top-level @mesh symbol.
jax.lax.with_sharding_constraint
GSPMD currently lowers jax.lax.with_sharding_constraint to a custom call:
def run_with_sharding_constraint():
  x = jax.ShapeDtypeStruct((32, 64), jnp.float32)
  def f(x):
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, PartitionSpec('data', PartitionSpec.UNCONSTRAINED)))
  print(jax.jit(f).lower(x).as_text())
run_with_sharding_constraint()
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "unspecified_dims=[1]", mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}
But under Shardy it's an explicit op:
jax.config.update("jax_use_shardy_partitioner", True)
run_with_sharding_constraint()
jax.config.update("jax_use_shardy_partitioner", False)
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["data"=4, "model"=2]>
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = sdy.sharding_constraint %arg0 <@mesh, [{"data"}, {?}]> : tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}
Note that under GSPMD, UNCONSTRAINED makes the custom call carry the op attribute backend_config = "unspecified_dims=[1]", while under Shardy it makes dim 1 be {?}. In Shardy, a dimension sharding without a ? is closed, meaning that dimension can't be further sharded; with a trailing ?, it can be further sharded. Refer to Sharding representation for more info on the sharding representation.
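For contrast, here is a small sketch of our own (run_closed_constraint is our own helper name, not part of the example above) using None instead of UNCONSTRAINED; dim 1 should then lower to a closed {} rather than {?}.

def run_closed_constraint():
  x = jax.ShapeDtypeStruct((32, 64), jnp.float32)
  def f(x):
    # None (instead of UNCONSTRAINED) on dim 1 gives a closed dimension sharding.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, PartitionSpec('data', None)))
  print(jax.jit(f).lower(x).as_text())

jax.config.update("jax_use_shardy_partitioner", True)
run_closed_constraint()  # expect something like: sdy.sharding_constraint %arg0 <@mesh, [{"data"}, {}]>
jax.config.update("jax_use_shardy_partitioner", False)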
jax.experimental.shard_map
Under GSPMD, a shard_map lowers to a few different custom calls with various shard_map-specific attributes on the GSPMD sharding. Let's look at an example where all mesh axes are manual, i.e. no axis is left auto (free to be used for sharding constraints inside the body of the shard_map).
def run_shard_map():
  x = jax.ShapeDtypeStruct((32, 64), jnp.float32)
  def body(x):
    return jax.lax.all_gather(x, 'data', tiled=True)
  shmaped_f = shard_map(
      body,
      mesh=mesh,
      in_specs=(jax.sharding.PartitionSpec('data',),),
      out_specs=jax.sharding.PartitionSpec(),
      check_rep=False)
  print(jax.jit(shmaped_f).lower(x).as_text())
run_shard_map()
module @jit_body attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x64xf32>) -> tensor<8x64xf32>
    %2 = call @shmap_body(%1) : (tensor<8x64xf32>) -> tensor<32x64xf32>
    %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    return %4 : tensor<32x64xf32>
  }
  func.func private @shmap_body(%arg0: tensor<8x64xf32>) -> (tensor<32x64xf32> {jax.result_info = "[None, None]"}) {
    %0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<8x64xf32>) -> tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}
With the custom calls and GSPMD sharding, it's getting pretty confusing. Let's look at what Shardy gives:
jax.config.update("jax_use_shardy_partitioner", True)
run_shard_map()
jax.config.update("jax_use_shardy_partitioner", False)
module @jit_body attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["data"=4, "model"=2]>
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"data"}, {}]>] out_shardings=[<@mesh, [{}, {}]>] manual_axes={"data", "model"} (%arg1: tensor<8x64xf32>) {
      %1 = "stablehlo.all_gather"(%arg1) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<8x64xf32>) -> tensor<32x64xf32>
      sdy.return %1 : tensor<32x64xf32>
    } : (tensor<32x64xf32>) -> tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}
We now have a single op called sdy.manual_computation which holds:
- the in_specs
- the out_specs
- the body of the shard_map
- the inverse of the auto axes, which we call manual_axes
A lot easier to read!
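As a quick usage sketch of our own (reusing the mesh defined above), the same shard_map can also be run on concrete data under Shardy; the tiled all_gather over 'data' should simply reproduce the full array.

def run_shard_map_concrete():
  x = jnp.arange(32 * 64, dtype=jnp.float32).reshape(32, 64)
  shmaped_f = shard_map(
      lambda a: jax.lax.all_gather(a, 'data', tiled=True),
      mesh=mesh,
      in_specs=(PartitionSpec('data'),),
      out_specs=PartitionSpec(),
      check_rep=False)
  out = jax.jit(shmaped_f)(x)
  # Gathering the 'data'-sharded input back together should give the original array.
  assert (out == x).all()

jax.config.update("jax_use_shardy_partitioner", True)
run_shard_map_concrete()
jax.config.update("jax_use_shardy_partitioner", False)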
jax.experimental.custom_partitioning
With GSPMD, to use custom_partitioning we define two routines, propagate_user_sharding and infer_sharding_from_operands, that may traverse a jaxpr to return the shardings for the operands and results. With Shardy, we instead provide a sharding_rule, an Einsum-like notation string that specifies the sharding rule. Here is an example where the routine we custom-partition implements a batch matrix multiplication.
We use a (2M, M) device array to compute a matmul of the form (..., 4N, 2N) x (..., 2N, 4N). Notice that instead of hard-coding the device array and the matrix shapes, we introduce two parameters, M and N, to specify the shapes of the matrices and the device array, so that you can change these values to fit your purpose.
We first perform the needed setup and define the partition routine as we would do with GSPMD.
from functools import partial
from jax.experimental.custom_partitioning import custom_partitioning, SdyShardingRule, BATCHING
jax.config.update("jax_use_shardy_partitioner", True)
def partition(mesh, arg_shapes, result_shape):
  arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes)
  result_sharding = result_shape.sharding
  rank = len(arg_shapes[0].shape)
  def lower_fn(x, y):
    # The contracting dimension of y is sharded; slice the matching block of x,
    # do the local matmul, and psum the partial products over that mesh axis.
    axis_name = arg_shardings[1].spec[rank-2][0]
    i = jax.lax.axis_index(axis_name)
    z = jax.lax.psum(jax.lax.dynamic_slice_in_dim(
        jax.lax.dynamic_slice_in_dim(x, i * 0, N, axis=rank-2), i * N, N, axis=rank-1) @ y,
        (axis_name))
    return z
  return mesh, lower_fn, (result_sharding), arg_shardings

@partial(custom_partitioning)
def f(x, y):
  return jnp.matmul(x, y)
Then, we invoke the def_partition API. Note that instead of providing the two callbacks infer_sharding_from_operands and propagate_user_sharding as we would with GSPMD, we provide a sharding_rule parameter: an Einsum-like notation string similar to the subscripts in jnp.einsum("...ij, ...jk->...ik", x, y), if jnp.einsum were extended to use ... for representing leading batching dimensions. We borrow from the einops.rearrange string notation: we use a space separator between factors (to allow multi-letter factor names) and do not specify the value of a factor that always corresponds to a whole tensor dimension. We also support rank polymorphism by allowing a leading ... in each tensor's dimension list to represent any number of leading dimensions.
f.def_partition(
infer_sharding_from_operands=None,
propagate_user_sharding=None,
partition=partition,
sharding_rule="... i j, ... j k -> ... i k")
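To further illustrate the notation, here are a couple of rule strings of our own (illustrative only, not part of the example above), showing the space-separated, multi-letter factor names and the leading batching dimensions.

# Illustrative rule strings (ours, not used by the example above):
#   elementwise add of two same-shaped tensors:  "... d, ... d -> ... d"
#   matmul with multi-letter factor names:       "batch m k, batch k n -> batch m n"
# Factors are separated by spaces, so names like "batch" are allowed; a leading "..."
# stands for any number of leading batching dimensions.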
Alternatively, we can create an equivalent SdyShardingRule object for the sharding_rule parameter. See the Shardy documentation on sharding rules for more details.
f.def_partition(
infer_sharding_from_operands=None,
propagate_user_sharding=None,
partition=partition,
sharding_rule=SdyShardingRule(operand_mappings=((BATCHING, 'i', 'j'), (BATCHING, 'j', 'k')), result_mappings=((BATCHING, 'i', 'k'),)))
The sharding_rule parameter can also take a callback function that generates either a string or an SdyShardingRule object.
def sharding_rule_producer(mesh, arg_shapes, result_shape):
  rank = len(arg_shapes[0].shape)
  leading_axes = ""
  for i in range(rank - 2):
    leading_axes += f" b{i}"
  return f"{leading_axes} i j, {leading_axes} j k -> {leading_axes} i k"
f.def_partition(
partition=partition,
sharding_rule=sharding_rule_producer)
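As a purely illustrative sanity check (our own addition), calling the producer directly on rank-4 operand shapes yields a rule string with two leading batch factors (the leading and doubled spaces are harmless):

# Illustrative only: the producer ignores `mesh` and `result_shape` here.
print(sharding_rule_producer(
    None,
    (jax.ShapeDtypeStruct((2, 3, 16, 8), jnp.float32),
     jax.ShapeDtypeStruct((2, 3, 8, 16), jnp.float32)),
    None))
# -> " b0 b1 i j,  b0 b1 j k ->  b0 b1 i k"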
Lastly, we create a mesh, define the input matrices x and y, run the jitted f, and compare the results produced by the unjitted and the jitted f.
N = 4
M = 2
num_devices = 2 * M * M
devices = np.array(list(jax.devices())[:num_devices])
if devices.size < num_devices:
  raise ValueError(f'Requires {num_devices} devices')
device_mesh = Mesh(devices.reshape((2 * M, M)), ('x', 'y'))
sharding_x = NamedSharding(device_mesh, PartitionSpec(None, None, 'x'))
sharding_y = NamedSharding(device_mesh, PartitionSpec(None, None, 'y'))
jitted_f = jax.jit(f, in_shardings=(sharding_x, sharding_y), out_shardings=sharding_x)
x = np.asarray(np.random.randint(0, 20, (2, 3, 4*N, 2*N)), dtype=np.float32)
y = np.asarray(np.random.randint(0, 20, (2, 3, 2*N, 4*N)), dtype=np.float32)
result = f(x, y)
with device_mesh:
  jitted_result = jitted_f(x, y)
for i in range(num_devices):
  j = (i // M) * N
  assert((np.asarray(jitted_result.addressable_shards[i].data) == result[:,:,j:j+N,:]).all())
Auto partitioners
In progress.
XLA_DUMP_TO
When specifying XLA_DUMP_TO, you will see an additional shardy/ directory containing various dumps of the StableHLO program. Most of them are currently only relevant to the Shardy team for debugging issues. The one to focus on when debugging is sdy_module_after_sdy_export.mlir, which is the module after propagation has finished on the StableHLO program.
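A sketch of one common way to produce such a dump, using the standard --xla_dump_to option via the XLA_FLAGS environment variable (assuming this is the mechanism referred to above; the dump path is an arbitrary choice of ours):

# Sketch: enable XLA dumping before JAX initializes its backends.
import os
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + ' --xla_dump_to=/tmp/xla_dump'
# After running a jitted computation with Shardy enabled, look under the dump
# directory for a shardy/ subdirectory containing sdy_module_after_sdy_export.mlir.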
2. Why Shardy?
Readability
As seen above, it's much easier to read the shardings and shard_maps and understand how they match what is happening in the JAX code. Similarly, GSPMD propagation gives back HLO code, not the MLIR that both Shardy and jax.jit.lower return.
Interpretability
We are planning on exposing a feature we call "user priorities" (not in JAX yet!). It allows you to attach a value telling Shardy how important a tensor's dimension sharding is over other constraints in the program.
Higher priorities are defined as lower values (the lowest being 0; think of it as a p0 priority task).
PartitionSpec(None, 'x', 'y', priorities=(None, 0, 1))
Here the sharding of dim 1 on x has a higher priority than that of dim 2 on y, meaning dim 1 will be propagated through the program first and then dim 2, so any potential sharding conflicts are explicitly avoided by having x propagated first.
This can also be helpful for debugging models, by letting you break your sharding strategy down into separate rounds of propagation in Shardy. For example:
- Priority 0: data parallelism
- Priority 1: megatron
- Priority 2: ZeRO sharding
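Purely as a hypothetical sketch (this API is not in JAX yet and the final spelling may differ; the syntax follows the example above), those rounds could be expressed along these lines:

# Hypothetical only -- priorities are not yet available in JAX.
#   x:  PartitionSpec('data', None,   priorities=(0, None))   # round 0: data parallelism
#   w1: PartitionSpec(None,  'model', priorities=(None, 1))   # round 1: megatron
#   w2: PartitionSpec('data', None,   priorities=(2, None))   # round 2: ZeRO-style sharding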
FAQS
Below is a list of questions you may have on various JAX features and capabilities.
JAX Sharding types
What about GSPMDSharding?
GSPMDSharding is closely tied to the C++/protobuf representation inside the XLA compiler. As such, the type itself won't be supported.
What about PositionalSharding?
This won't be supported. Instead, use a NamedSharding with device_ids.
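One possible replacement pattern (a sketch of our own, not an official recipe): build a Mesh over whatever explicit device order you need and shard with a NamedSharding over it.

# Sketch: reproduce an explicit device ordering with a Mesh + NamedSharding.
devs = np.array(jax.devices()[:8])[::-1].reshape(4, 2)  # any explicit device order
custom_mesh = Mesh(devs, ('a', 'b'))
s = NamedSharding(custom_mesh, PartitionSpec('a', 'b'))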
What about PmapSharding?
This won't be supported. Shardy is meant for jax.jit, not jax.pmap.
Propagation Questions
Section for questions about what you may see during propagation.
What are split axes in Shardy, aka "x":(2)2?
Refer to Axis splitting and sub-axes.