Фон
Цель представления сегментирования — указать, как сегментируется тензор относительно набора доступных устройств.
Представление шардинга может быть:
- Вручную указывается пользователем как ограничения сегментирования на входах, выходах или промежуточных звеньях.
- Трансформируется за операцию в процессе распространения шардинга.
Обзор
Базовая структура
Логическая сетка — это многомерное представление устройств, определяемое списком имен и размеров осей.
Предлагаемое представление сегментирования привязано к определенной логической сетке по ее имени и может ссылаться только на имена осей из этой сетки. Шардинг тензора определяет, по каким осям (конкретной логической сетки) сегментируется каждое измерение тензора, в порядке от большего к меньшему. Тензор копируется вдоль всех остальных осей сетки.
Давайте рассмотрим представление сегментирования с помощью простого тензора ранга 2 и устройств 4.
Сначала мы преобразуем 4 устройства [0, 1, 2, 3]
в двумерный массив [[0, 1], [2, 3]]
чтобы создать сетку с двумя осями:
@mesh_xy = <["x"=2, "y"=2]>
Затем мы можем сегментировать следующий тензор ранга 2 [[a, b], [c, d]]
следующим образом:
Другие ключевые компоненты
- Открытые/закрытые измерения — размеры могут быть открытыми — их можно дополнительно сегментировать по доступным осям; или закрытые – фиксированы и не могут быть изменены.
- Явно реплицированные оси — все оси, которые не используются для сегментирования измерения, реплицируются неявно, но при сегментировании могут быть указаны оси, которые реплицируются явно и поэтому не могут быть использованы для сегментирования измерения в дальнейшем.
- Разделение осей и подоси — (полную) ось сетки можно разделить на несколько подосей, которые можно индивидуально использовать для сегментирования измерения или явно реплицировать.
- Несколько логических сеток — разные сегменты могут быть привязаны к разным логическим сеткам, которые могут иметь разные оси или даже разный порядок идентификаторов логических устройств.
- Приоритеты — для поэтапного разделения программы приоритеты можно прикрепить к сегментам измерений, которые определяют, в каком порядке ограничения сегментирования каждого измерения будут распространяться по всему модулю.
- Делимость измерения : измерение может быть сегментировано по осям, произведение размеров которых не делит размер измерения.
Детальный проект
В этом разделе мы раскрываем базовую структуру и каждый ключевой компонент.
Базовая структура
Шардинги измерений сообщают нам для каждого измерения тензора, по каким осям (или подосям ) он сегментируется от главного к второстепенному. Все остальные оси, которые не сегментируют измерение, реплицируются неявно (или реплицируются явно ).
Мы начнем с простого примера и будем расширять его по мере описания дополнительных функций.
@mesh_xy = <["x"=2, "y"=4, "z"=2]>
// The 1st tensor dimension is sharded along axis "x" and the 2nd tensor dimension is
// sharded along axis "z" then further along axis "y". The local shape of this tensor (i.e. the shape on a single device), would be tensor<2x1xf32>.
sharding<@mesh_xy, [{"x"}, {"z", "y"}]> : tensor<4x8xf32>
Инварианты
- Количество сегментов измерения должно соответствовать рангу тензора.
- Все имена осей должны существовать в сетке, на которую ссылаются.
- Оси или подоси могут появляться в представлении сегментирования только один раз (каждая из них либо сегментирует измерение, либо явно реплицируется ).
Открытые/закрытые размеры
Каждое измерение тензора может быть открытым или закрытым.
Открыть
Открытое измерение открыто для распространения с целью дальнейшего сегментирования его по дополнительным осям, т. е. указанное сегментирование измерения не обязательно должно быть окончательным сегментированием этого измерения. Это похоже (но не совсем то же самое) на unspecified_dims
GSPMD.
Если измерение открыто, мы добавляем ?
по осям, по которым уже сегментировано измерение (см. пример ниже).
Закрыто
Закрытое измерение — это измерение, которое недоступно для распространения с целью добавления дальнейшего сегментирования, т. е. указанное сегментирование измерения является окончательным сегментированием этого измерения, и его нельзя изменить. Распространенным примером использования этого является то, что GSPMD (обычно) не изменяет аргументы ввода/вывода модуля или как с помощью jax.jit
пользователь, указанный in_shardings
, является статичным - они не могут измениться.
Мы можем расширить приведенный выше пример, добавив в него открытое и закрытое измерение.
@mesh_xy = <["x"=2, "y"=4, "z"=2]>
// The 1st dimension is closed, therefore it can't be further sharded and {"x"}
// will remain its sharding. The 2nd dimension is open, and can therefore be
// further sharded during propagation, e.g. by "y".
sharding<@mesh_xy, [{"x"}, {"z", ?}]> : tensor<4x8xf32>
Явно реплицированные оси
Явный набор осей, на которых реплицируется тензор. Хотя можно определить, что тензор, не сегментированный по оси, неявно реплицируется на нем (как сегодня jax.sharding.PartitionSpec
), его явное использование гарантирует, что распространение не сможет использовать эти оси для дальнейшего сегментирования открытого измерения с помощью этих осей. При неявной репликации тензор можно дополнительно разбить. Но при явной репликации ничто не может разделить тензор по этой оси.
Порядок реплицируемых осей не влияет на то, как хранятся данные тензора. Но исключительно для обеспечения единообразия оси будут храниться в том порядке, в котором они указаны в сетке верхнего уровня. Например, если сетка:
@mesh_xy = <["c"=2, "a"=2, "b"=2]>
И мы хотим, чтобы оси "a"
и "c"
были явно реплицированы, порядок должен быть таким:
replicated={"c", "a"}
Мы можем расширить приведенный выше пример, чтобы иметь явно реплицированную ось.
@mesh_xyz = <["x"=2, "y"=4, "z"=2]>
// Since "y" is explicitly replicated, it can't be used to shard the 2nd
// dimension that is open. However, "z" is implicitly replicated so it can be
// used to shard that dimension. The local shape of this tensor (i.e. the
// shape on a single device), would be tensor<2x8xf32>.
sharding<@mesh_xyz, [{"x"}, {?}], replicated={"y"}> : tensor<4x8xf32>
Разделение осей и подоси
Логическая сетка из n
осей создается путем преобразования одномерного массива устройств в n-мерный массив, где каждое измерение образует ось с определяемым пользователем именем.
Тот же процесс можно выполнить в компиляторе, чтобы разбить ось размера k
на m
подосей, изменив форму сетки из [...,k,...]
в [...,k1,...,km,...]
.
Мотивация
Чтобы понять мотивацию разделения осей, мы рассмотрим следующий пример:
@mesh_x = <["x"=4]>
%arg0 : tensor<8xf32> {sdy.sharding=<@mesh_x, [{"x"}]>}
%0 = reshape %arg0 : (tensor<8xf32>) -> tensor<2x4xf32>
Мы хотим сегментировать результат изменения формы таким образом, чтобы избежать обмена данными (т. е. сохранить данные там, где они есть). Поскольку размер "x"
больше, чем 1-е измерение результата, нам нужно разделить ось на две подоси "x.0"
и "x.1"
размером 2 каждая и разделить 1-е измерение на "x.0"
и 2-е измерение на "x.1"
.
Функция Шардинг ввода/вывода
Вполне возможно, что во время распространения вход или выход основной функции будет сегментирован вдоль подоси. Это может быть проблемой для некоторых фреймворков, где мы не можем выразить такие сегменты для возврата пользователю (например, в JAX мы не можем выразить подоси с помощью jax.sharding.NamedSharding
).
У нас есть несколько вариантов решения таких случаев:
- Разрешите и верните сегментирование в другом формате (например,
jax.sharding.PositionalSharding
вместоjax.sharding.NamedSharding
в JAX). - Запретить и собрать все подоси, которые сегментируют ввод/вывод.
В настоящее время мы разрешаем подоси на входах/выходах в конвейере распространения. Дайте нам знать, если вам нужен способ отключить это.
Представительство
Точно так же, как мы можем ссылаться на определенные полные оси сетки по их имени, мы можем ссылаться на определенные подоси по их размеру и произведению всех размеров подосей (с тем же именем оси) слева от них (которые являются основными для них).
Чтобы извлечь конкретную подось размера k
из полной оси "x"
размера n
, мы фактически изменяем размер n
(в сетке) на [m, k, n/(m*k)]
и используем второе измерение в качестве подоси. Таким образом, подось может быть задана двумя числами, m
и k
, и для обозначения подосей мы используем следующее краткое обозначение: "x":(m)k
.
m>=1
— это предварительный размер этой подоси (m
должен быть делителемn
). Предварительный размер — это произведение всех размеров подоси слева от этой подоси (которые являются основными для нее) (если он равен 1, это означает, что их нет, если больше 1, это соответствует одной или нескольким подосям).k>1
— фактический размер этой подоси (k
должен быть делителемn
).n/(m*k)
— размер сообщения . Это произведение всех размеров подоси справа от этой подоси (которые являются второстепенными по отношению к ней) (если равно 1, это означает, что их нет, если больше 1, это соответствует одной или нескольким подосям).
Однако количество других подосей не имеет значения при использовании конкретной подоси "x":(m)k
, и нет необходимости ссылаться на любую другую подось при сегментировании тензора, если она не сегментирует измерение или явно реплицируется.
Возвращаясь к примеру из раздела «Мотивация» , мы можем сегментировать результат следующим образом:
@mesh_x = <["x"=4]>
%arg0 : tensor<8xf32> {sdy.sharding=<@mesh_x, [{"x"}]>}
%0 = reshape %arg0 {sdy.sharding_per_value=<[<@mesh_x, [{"x":(1)2}, {"x":(2)2}]>]>}
: (tensor<8xf32>) -> tensor<2x4xf32>
Вот еще один пример разделенной оси, в которой используются только некоторые из ее подосей.
@mesh_xyz = <["x"=2, "y"=8, "z"=2]>
// Axis "y" is effectively split into 3 sub-axes denoted as
// "y":(1)2, "y":(2)2, "y":(4)2
// in order, but only "y":(2)2 is used, to shard the 2nd dimension. The local
// shape of this tensor (i.e. the shape on a single device), would be
// tensor<2x4xf32>.
sharding<@mesh_xyz, [{"x"}, {"y":(2)2}]> : tensor<4x8xf32>
Аналогично, следующие два шардинга семантически эквивалентны. Мы можем думать о mesh_xy
как о разделении mesh_full
.
@mesh_full = <"devices"=8>
@mesh_xy = <"x"=4, "y"=2>
sharding<@mesh_xy, [{"x"},{ "y"}]> : tensor<4x4xf32>
sharding<@mesh_full, [{"devices":(1)4}, {"devices":(4)2}]> : tensor<4x4xf32>
Явно реплицированные подоси
Помимо подосей, используемых для сегментирования измерений, они также могут быть помечены как явно реплицированные. Мы допускаем это в представлении, поскольку подоси ведут себя так же, как полные оси, т. е. когда вы сегментируете измерение вдоль подоси оси "x"
, другие подоси "x"
реплицируются неявно и, следовательно, могут быть реплицированы явно, чтобы указать, что подось должна оставаться реплицированной и не может использоваться для сегментирования измерения.
Например:
@mesh_xyz = <["x"=2, "y"=8, "z"=2]>
// Sub-axis "y":(1)2 is explicitly replicated and "y":(4)2 is implicitly replicated.
sharding<@mesh_xyz, [{"x"}, {"y":(2)2}], replicated={"y":(1)2}> : tensor<4x8xf32>
Реплицированные подоси одной и той же полной оси следует располагать в порядке возрастания их предварительного размера, например:
replicated={"y":(4)2, "x", "y":(1)2} ~> replicated={"x", "y":(1)2, "y":(4)2}
Инварианты
Подоси, на которые ссылаются в тензорном сегментировании, не должны перекрываться, например
"x":(1)4
и"x":(2)4
перекрываются.Подоси, на которые ссылаются в тензорном сегментировании, должны быть как можно больше, т. е. если сегментирование измерения имеет две соседние по порядку подоси A и B или подоси A и B явно реплицируются, они не должны быть последовательными, например
"x":(1)2
и"x":(2)4
поскольку их можно заменить одним"x":(1)8
.
Несколько логических сеток
Одна логическая сетка представляет собой многомерное представление устройств. Нам может потребоваться несколько представлений устройств для представления наших сегментов, особенно для произвольных назначений устройств.
Например, jax.sharding.PositionalSharding
не имеет одной общей логической сетки . GSPMD в настоящее время поддерживает это с помощью HloSharding, где представление может представлять собой упорядоченный список устройств и размеров измерений, но это невозможно представить с помощью разделения осей, описанного выше.
Мы преодолеваем это ограничение и обрабатываем существующие крайние случаи, определяя несколько логических сеток на верхнем уровне программы. Каждая сетка может иметь разное количество осей с разными именами, а также свое произвольное назначение для одного и того же набора устройств, т.е. каждая сетка относится к одному и тому же набору устройств (по их уникальному логическому идентификатору), но с произвольным порядком, аналогично представлению GSPMD.
Каждое представление сегментирования связано с определенной логической сеткой, поэтому оно будет ссылаться только на оси этой сетки.
Тензор, назначенный одной логической сетке, может использоваться операцией, назначенной другой сетке, путем наивного изменения тензора в соответствии с целевой сеткой. В GSPMD это обычно делается для разрешения конфликтующих сеток.
Пользователи могут указать несколько сеток с разными именованными осями (например, через jax.sharding.NamedSharding
), которые имеют одинаковый порядок устройств. Рассмотрим этот пример: <@mesh_0, "b">
идентичен <@mesh_1, "z">
:
@mesh_0 = {<["a"=4, "b"=2]>, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]}
@mesh_1 = {<["x"=2, "y"=2, "z"=2]>, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]}
Приоритеты
Приоритет — это способ определения приоритета определенных решений по секционированию и распространению над другими, а также позволяет выполнять постепенное секционирование программы.
Приоритеты — это значения, прикрепленные к некоторым или всем измерениям представления сегментирования (реплицируемые оси не имеют приоритетов).
Например:
@mesh_xy = <["w"=6, "x"=2, "y"=4, "z"=2]>
// |-> y is implicitly p0
%arg4 : sharding<@mesh_xy, [{"x"}p1, {"y"}, {"z",?}p2], replicated={} }>
Приоритеты дают пользователям более детальный контроль над распространением, например, сначала пакетный параллелизм, затем мегатрон и, наконец, сегментирование ZeRO . Это дает надежные гарантии относительно того, что секционировано, и обеспечивает лучшую отладку за счет более детальных стратегий сегментирования (можно увидеть, как программа работает только с мегатроном в изоляции).
Мы разрешаем присвоить приоритет каждому сегментированию измерения (по умолчанию 0), что означает, что все сегменты с приоритетом <i
будут распространяться на всю программу раньше сегментов с приоритетом i
.
Даже если сегментирование имеет открытое измерение с более низким приоритетом, например, {"z",?}p2
, оно не будет переопределено другим тензорным сегментированием с более высоким приоритетом во время распространения. Однако такое открытое измерение может быть дополнительно сегментировано после распространения всех сегментов с более высоким приоритетом.
Другими словами, приоритеты НЕ связаны с тем, какое сегментирование измерений важнее другого - это порядок, в котором отдельные группы сегментирования измерений должны распространяться на всю программу, и то, как следует разрешать конфликты на промежуточных, неаннотированных тензорах.
Инварианты
Приоритеты начинаются с 0 (наивысший приоритет) и увеличиваются (чтобы пользователи могли легко добавлять и удалять приоритеты, мы допускаем промежутки между приоритетами, например, используются p0 и p2, а p1 — нет).
Пустой сегмент закрытого измерения (т. е.
{}
) не должен иметь приоритета, поскольку это не будет иметь никакого эффекта.
Делимость сегментирования измерений
Размерность размера d
может быть сегментирована по осям, произведение размеров которых равно n
, так что d
не делится на n
(что на практике потребовало бы заполнения измерения).
Например:
@mesh_xy = <["x"=8, "y"=2, "z"=3]>
sharding<@mesh_xy, [{"x"}, {"y"}, {"z"}]> : tensor<7x3x8xf32>
Грамматика
Каждая логическая сетка определяется следующим образом:
@mesh_name = <mesh_axis_1,...,mesh_axis_n>
mesh_axis ::= axis_name=axis_size
axis_name ::= str
axis_size ::= int
Представление шардинга будет иметь следующую структуру для тензора ранга r:
sharding<@mesh_name, dim_shardings, replicated=replicated_axes}
mesh_name ::= str
dim_shardings ::= [dim_sharding_1,...,dim_sharding_r]
replicated_axes ::= {axis_1,...,axis_m}
dim_sharding ::=
{axis_1,...,axis_k} | // closed dimension
{axis_1,...,axis_k,?} // open dimension
axis ::=
axis_name | // a full axis
sub_axis // a sub axis
axis_name ::= str
sub_axis ::= axis_name:(pre_size)size
pre_size ::= int
size ::= int
Фон
Цель представления сегментирования — указать, как сегментируется тензор относительно набора доступных устройств.
Представление шардинга может быть:
- Вручную указывается пользователем как ограничения сегментирования на входах, выходах или промежуточных звеньях.
- Трансформируется за операцию в процессе распространения шардинга.
Обзор
Базовая структура
Логическая сетка — это многомерное представление устройств, определяемое списком имен и размеров осей.
Предлагаемое представление сегментирования привязано к определенной логической сетке по ее имени и может ссылаться только на имена осей из этой сетки. Шардинг тензора определяет, по каким осям (конкретной логической сетки) сегментируется каждое измерение тензора, в порядке от большего к меньшему. Тензор копируется вдоль всех остальных осей сетки.
Давайте рассмотрим представление сегментирования с помощью простого тензора ранга 2 и устройств 4.
Сначала мы преобразуем 4 устройства [0, 1, 2, 3]
в двумерный массив [[0, 1], [2, 3]]
чтобы создать сетку с двумя осями:
@mesh_xy = <["x"=2, "y"=2]>
Затем мы можем сегментировать следующий тензор ранга 2 [[a, b], [c, d]]
следующим образом:
Другие ключевые компоненты
- Открытые/закрытые измерения — размеры могут быть открытыми — их можно дополнительно сегментировать по доступным осям; или закрытые – фиксированы и не могут быть изменены.
- Явно реплицированные оси — все оси, которые не используются для сегментирования измерения, реплицируются неявно, но при сегментировании могут быть указаны оси, которые реплицируются явно и поэтому не могут быть использованы для сегментирования измерения в дальнейшем.
- Разделение осей и подоси — (полную) ось сетки можно разделить на несколько подосей, которые можно индивидуально использовать для сегментирования измерения или явно реплицировать.
- Несколько логических сеток — разные сегменты могут быть привязаны к разным логическим сеткам, которые могут иметь разные оси или даже разный порядок идентификаторов логических устройств.
- Приоритеты — для поэтапного разделения программы приоритеты можно прикрепить к сегментам измерений, которые определяют, в каком порядке ограничения сегментирования каждого измерения будут распространяться по всему модулю.
- Делимость измерения : измерение может быть сегментировано по осям, произведение размеров которых не делит размер измерения.
Детальный проект
В этом разделе мы раскрываем базовую структуру и каждый ключевой компонент.
Базовая структура
Шардинги измерений сообщают нам для каждого измерения тензора, по каким осям (или подосям ) он сегментируется от главного к второстепенному. Все остальные оси, которые не сегментируют измерение, реплицируются неявно (или реплицируются явно ).
Мы начнем с простого примера и будем расширять его по мере описания дополнительных функций.
@mesh_xy = <["x"=2, "y"=4, "z"=2]>
// The 1st tensor dimension is sharded along axis "x" and the 2nd tensor dimension is
// sharded along axis "z" then further along axis "y". The local shape of this tensor (i.e. the shape on a single device), would be tensor<2x1xf32>.
sharding<@mesh_xy, [{"x"}, {"z", "y"}]> : tensor<4x8xf32>
Инварианты
- Количество сегментов измерения должно соответствовать рангу тензора.
- Все имена осей должны существовать в сетке, на которую ссылаются.
- Оси или подоси могут появляться в представлении сегментирования только один раз (каждая из них либо сегментирует измерение, либо явно реплицируется ).
Открытые/закрытые размеры
Каждое измерение тензора может быть открытым или закрытым.
Открыть
Открытое измерение открыто для распространения с целью дальнейшего сегментирования его по дополнительным осям, т. е. указанное сегментирование измерения не обязательно должно быть окончательным сегментированием этого измерения. Это похоже (но не совсем то же самое) на unspecified_dims
GSPMD.
Если измерение открыто, мы добавляем ?
по осям, по которым уже сегментировано измерение (см. пример ниже).
Закрыто
Закрытое измерение — это измерение, которое недоступно для распространения с целью добавления дальнейшего сегментирования, т. е. указанное сегментирование измерения является окончательным сегментированием этого измерения, и его нельзя изменить. Распространенным примером использования этого является то, что GSPMD (обычно) не изменяет аргументы ввода/вывода модуля или как с помощью jax.jit
пользователь, указанный in_shardings
, является статичным - они не могут измениться.
Мы можем расширить приведенный выше пример, добавив в него открытое и закрытое измерение.
@mesh_xy = <["x"=2, "y"=4, "z"=2]>
// The 1st dimension is closed, therefore it can't be further sharded and {"x"}
// will remain its sharding. The 2nd dimension is open, and can therefore be
// further sharded during propagation, e.g. by "y".
sharding<@mesh_xy, [{"x"}, {"z", ?}]> : tensor<4x8xf32>
Явно реплицированные оси
Явный набор осей, на которых реплицируется тензор. Хотя можно определить, что тензор, не сегментированный по оси, неявно реплицируется на нем (как сегодня jax.sharding.PartitionSpec
), его явное использование гарантирует, что распространение не сможет использовать эти оси для дальнейшего сегментирования открытого измерения с помощью этих осей. При неявной репликации тензор можно дополнительно разбить. Но при явной репликации ничто не может разделить тензор по этой оси.
Порядок реплицируемых осей не влияет на то, как хранятся данные тензора. Но исключительно для обеспечения единообразия оси будут храниться в том порядке, в котором они указаны в сетке верхнего уровня. Например, если сетка:
@mesh_xy = <["c"=2, "a"=2, "b"=2]>
И мы хотим, чтобы оси "a"
и "c"
были явно реплицированы, порядок должен быть таким:
replicated={"c", "a"}
Мы можем расширить приведенный выше пример, чтобы иметь явно реплицированную ось.
@mesh_xyz = <["x"=2, "y"=4, "z"=2]>
// Since "y" is explicitly replicated, it can't be used to shard the 2nd
// dimension that is open. However, "z" is implicitly replicated so it can be
// used to shard that dimension. The local shape of this tensor (i.e. the
// shape on a single device), would be tensor<2x8xf32>.
sharding<@mesh_xyz, [{"x"}, {?}], replicated={"y"}> : tensor<4x8xf32>
Разделение осей и подоси
Логическая сетка из n
осей создается путем преобразования одномерного массива устройств в n-мерный массив, где каждое измерение образует ось с определяемым пользователем именем.
Тот же процесс можно выполнить в компиляторе, чтобы разбить ось размера k
на m
подосей, изменив форму сетки из [...,k,...]
в [...,k1,...,km,...]
.
Мотивация
Чтобы понять мотивацию разделения осей, мы рассмотрим следующий пример:
@mesh_x = <["x"=4]>
%arg0 : tensor<8xf32> {sdy.sharding=<@mesh_x, [{"x"}]>}
%0 = reshape %arg0 : (tensor<8xf32>) -> tensor<2x4xf32>
Мы хотим сегментировать результат изменения формы таким образом, чтобы избежать обмена данными (т. е. сохранить данные там, где они есть). Поскольку размер "x"
больше, чем 1-е измерение результата, нам нужно разделить ось на две подоси "x.0"
и "x.1"
размером 2 каждая и разделить 1-е измерение на "x.0"
и 2-е измерение на "x.1"
.
Функция Шардинг ввода/вывода
Вполне возможно, что во время распространения вход или выход основной функции будет сегментирован вдоль подоси. Это может быть проблемой для некоторых фреймворков, где мы не можем выразить такие сегменты для возврата пользователю (например, в JAX мы не можем выразить подоси с помощью jax.sharding.NamedSharding
).
У нас есть несколько вариантов решения таких случаев:
- Разрешите и верните сегментирование в другом формате (например,
jax.sharding.PositionalSharding
вместоjax.sharding.NamedSharding
в JAX). - Запретить и собрать все подоси, которые сегментируют ввод/вывод.
В настоящее время мы разрешаем подоси на входах/выходах в конвейере распространения. Дайте нам знать, если вам нужен способ отключить это.
Представительство
Точно так же, как мы можем ссылаться на определенные полные оси сетки по их имени, мы можем ссылаться на определенные подоси по их размеру и произведению всех размеров подосей (с тем же именем оси) слева от них (которые являются основными для них).
Чтобы извлечь конкретную подось размера k
из полной оси "x"
размера n
, мы фактически изменяем размер n
(в сетке) на [m, k, n/(m*k)]
и используем второе измерение в качестве подоси. Таким образом, подось может быть задана двумя числами, m
и k
, и для обозначения подосей мы используем следующее краткое обозначение: "x":(m)k
.
m>=1
— это предварительный размер этой подоси (m
должен быть делителемn
). Предварительный размер — это произведение всех размеров подоси слева от этой подоси (которые являются основными для нее) (если он равен 1, это означает, что их нет, если больше 1, это соответствует одной или нескольким подосям).k>1
— фактический размер этой подоси (k
должен быть делителемn
).n/(m*k)
— размер сообщения . Это произведение всех размеров подоси справа от этой подоси (которые являются второстепенными по отношению к ней) (если равно 1, это означает, что их нет, если больше 1, это соответствует одной или нескольким подосям).
Однако количество других подосей не имеет значения при использовании конкретной подоси "x":(m)k
, и нет необходимости ссылаться на любую другую подось при сегментировании тензора, если она не сегментирует измерение или явно реплицируется.
Возвращаясь к примеру из раздела «Мотивация» , мы можем сегментировать результат следующим образом:
@mesh_x = <["x"=4]>
%arg0 : tensor<8xf32> {sdy.sharding=<@mesh_x, [{"x"}]>}
%0 = reshape %arg0 {sdy.sharding_per_value=<[<@mesh_x, [{"x":(1)2}, {"x":(2)2}]>]>}
: (tensor<8xf32>) -> tensor<2x4xf32>
Вот еще один пример разделенной оси, где используются только некоторые из ее подосей.
@mesh_xyz = <["x"=2, "y"=8, "z"=2]>
// Axis "y" is effectively split into 3 sub-axes denoted as
// "y":(1)2, "y":(2)2, "y":(4)2
// in order, but only "y":(2)2 is used, to shard the 2nd dimension. The local
// shape of this tensor (i.e. the shape on a single device), would be
// tensor<2x4xf32>.
sharding<@mesh_xyz, [{"x"}, {"y":(2)2}]> : tensor<4x8xf32>
Аналогично, следующие два шардинга семантически эквивалентны. Мы можем думать о mesh_xy
как о разделении mesh_full
.
@mesh_full = <"devices"=8>
@mesh_xy = <"x"=4, "y"=2>
sharding<@mesh_xy, [{"x"},{ "y"}]> : tensor<4x4xf32>
sharding<@mesh_full, [{"devices":(1)4}, {"devices":(4)2}]> : tensor<4x4xf32>
Явно реплицированные подоси
Помимо подосей, используемых для сегментирования измерений, они также могут быть помечены как явно реплицированные. Мы допускаем это в представлении, поскольку подоси ведут себя так же, как полные оси, т. е. когда вы сегментируете измерение вдоль подоси оси "x"
, другие подоси "x"
реплицируются неявно и, следовательно, могут быть реплицированы явно, чтобы указать, что подось должна оставаться реплицированной и не может использоваться для сегментирования измерения.
Например:
@mesh_xyz = <["x"=2, "y"=8, "z"=2]>
// Sub-axis "y":(1)2 is explicitly replicated and "y":(4)2 is implicitly replicated.
sharding<@mesh_xyz, [{"x"}, {"y":(2)2}], replicated={"y":(1)2}> : tensor<4x8xf32>
Реплицированные подоси одной и той же полной оси следует располагать в порядке возрастания их предварительного размера, например:
replicated={"y":(4)2, "x", "y":(1)2} ~> replicated={"x", "y":(1)2, "y":(4)2}
Инварианты
Подоси, на которые ссылаются в тензорном сегментировании, не должны перекрываться, например
"x":(1)4
и"x":(2)4
перекрываются.Подоси, на которые ссылаются в тензорном сегментировании, должны быть как можно больше, т. е. если сегментирование измерения имеет две соседние по порядку подоси A и B или подоси A и B явно реплицируются, они не должны быть последовательными, например
"x":(1)2
и"x":(2)4
поскольку их можно заменить одним"x":(1)8
.
Несколько логических сеток
Одна логическая сетка представляет собой многомерное представление устройств. Нам может потребоваться несколько представлений устройств для представления наших сегментов, особенно для произвольных назначений устройств.
Например, jax.sharding.PositionalSharding
не имеет одной общей логической сетки . GSPMD в настоящее время поддерживает это с помощью HloSharding, где представление может представлять собой упорядоченный список устройств и размеров измерений, но это невозможно представить с помощью разделения осей, описанного выше.
Мы преодолеваем это ограничение и обрабатываем существующие крайние случаи, определяя несколько логических сеток на верхнем уровне программы. Каждая сетка может иметь разное количество осей с разными именами, а также свое произвольное назначение для одного и того же набора устройств, т.е. каждая сетка относится к одному и тому же набору устройств (по их уникальному логическому идентификатору), но с произвольным порядком, аналогично представлению GSPMD.
Каждое представление сегментирования связано с определенной логической сеткой, поэтому оно будет ссылаться только на оси этой сетки.
Тензор, назначенный одной логической сетке, может использоваться операцией, назначенной другой сетке, путем наивного изменения тензора в соответствии с целевой сеткой. В GSPMD это обычно делается для разрешения конфликтующих сеток.
Пользователи могут указать несколько сеток с разными именованными осями (например, через jax.sharding.NamedSharding
), которые имеют одинаковый порядок устройств. Рассмотрим этот пример: <@mesh_0, "b">
идентичен <@mesh_1, "z">
:
@mesh_0 = {<["a"=4, "b"=2]>, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]}
@mesh_1 = {<["x"=2, "y"=2, "z"=2]>, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]}
Приоритеты
Приоритет — это способ определения приоритета определенных решений по секционированию и распространению над другими, а также позволяет выполнять постепенное секционирование программы.
Приоритеты — это значения, прикрепленные к некоторым или всем измерениям представления сегментирования (реплицируемые оси не имеют приоритетов).
Например:
@mesh_xy = <["w"=6, "x"=2, "y"=4, "z"=2]>
// |-> y is implicitly p0
%arg4 : sharding<@mesh_xy, [{"x"}p1, {"y"}, {"z",?}p2], replicated={} }>
Приоритеты дают пользователям более детальный контроль над распространением, например, сначала пакетный параллелизм, затем мегатрон и, наконец, сегментирование ZeRO . Это дает надежные гарантии относительно того, что секционировано, и обеспечивает лучшую отладку за счет более детальных стратегий сегментирования (можно увидеть, как программа работает только с мегатроном в изоляции).
Мы разрешаем присвоить приоритет каждому сегментированию измерения (по умолчанию 0), что означает, что все сегменты с приоритетом <i
будут распространяться на всю программу раньше сегментов с приоритетом i
.
Даже если сегментирование имеет открытое измерение с более низким приоритетом, например, {"z",?}p2
, оно не будет переопределено другим тензорным сегментированием с более высоким приоритетом во время распространения. Однако такое открытое измерение может быть дополнительно сегментировано после распространения всех сегментов с более высоким приоритетом.
Другими словами, приоритеты НЕ связаны с тем, какое сегментирование измерений важнее другого - это порядок, в котором отдельные группы сегментирования измерений должны распространяться на всю программу, и то, как следует разрешать конфликты на промежуточных, неаннотированных тензорах.
Инварианты
Приоритеты начинаются с 0 (наивысший приоритет) и увеличиваются (чтобы пользователи могли легко добавлять и удалять приоритеты, мы допускаем промежутки между приоритетами, например, используются p0 и p2, а p1 — нет).
Пустой сегмент закрытого измерения (т. е.
{}
) не должен иметь приоритета, поскольку это не будет иметь никакого эффекта.
Делимость сегментирования измерений
Размерность размера d
может быть сегментирована по осям, произведение размеров которых равно n
, так что d
не делится на n
(что на практике потребовало бы заполнения измерения).
Например:
@mesh_xy = <["x"=8, "y"=2, "z"=3]>
sharding<@mesh_xy, [{"x"}, {"y"}, {"z"}]> : tensor<7x3x8xf32>
Грамматика
Каждая логическая сетка определяется следующим образом:
@mesh_name = <mesh_axis_1,...,mesh_axis_n>
mesh_axis ::= axis_name=axis_size
axis_name ::= str
axis_size ::= int
Представление шардинга будет иметь следующую структуру для тензора ранга r:
sharding<@mesh_name, dim_shardings, replicated=replicated_axes}
mesh_name ::= str
dim_shardings ::= [dim_sharding_1,...,dim_sharding_r]
replicated_axes ::= {axis_1,...,axis_m}
dim_sharding ::=
{axis_1,...,axis_k} | // closed dimension
{axis_1,...,axis_k,?} // open dimension
axis ::=
axis_name | // a full axis
sub_axis // a sub axis
axis_name ::= str
sub_axis ::= axis_name:(pre_size)size
pre_size ::= int
size ::= int
Фон
Цель представления сегментирования — указать, как сегментируется тензор относительно набора доступных устройств.
Представление шардинга может быть:
- Вручную указывается пользователем как ограничения сегментирования на входах, выходах или промежуточных звеньях.
- Трансформируется за операцию в процессе распространения шардинга.
Обзор
Базовая структура
Логическая сетка — это многомерное представление устройств, определяемое списком имен и размеров осей.
Предлагаемое представление сегментирования привязано к определенной логической сетке по ее имени и может ссылаться только на имена осей из этой сетки. Шардинг тензора определяет, по каким осям (конкретной логической сетки) сегментируется каждое измерение тензора, в порядке от большего к меньшему. Тензор копируется вдоль всех остальных осей сетки.
Давайте рассмотрим представление сегментирования с помощью простого тензора ранга 2 и устройств 4.
Сначала мы преобразуем 4 устройства [0, 1, 2, 3]
в двумерный массив [[0, 1], [2, 3]]
чтобы создать сетку с двумя осями:
@mesh_xy = <["x"=2, "y"=2]>
Затем мы можем сегментировать следующий тензор ранга 2 [[a, b], [c, d]]
следующим образом:
Другие ключевые компоненты
- Открытые/закрытые измерения — размеры могут быть открытыми — их можно дополнительно сегментировать по доступным осям; или закрытые – фиксированы и не могут быть изменены.
- Явно реплицированные оси — все оси, которые не используются для сегментирования измерения, реплицируются неявно, но при сегментировании могут быть указаны оси, которые реплицируются явно и поэтому не могут быть использованы для сегментирования измерения в дальнейшем.
- Разделение осей и подоси — (полную) ось сетки можно разделить на несколько подосей, которые можно индивидуально использовать для сегментирования измерения или явно реплицировать.
- Несколько логических сеток — разные сегменты могут быть привязаны к разным логическим сеткам, которые могут иметь разные оси или даже разный порядок идентификаторов логических устройств.
- Приоритеты — для поэтапного разделения программы приоритеты можно прикрепить к сегментам измерений, которые определяют, в каком порядке ограничения сегментирования каждого измерения будут распространяться по всему модулю.
- Делимость измерения : измерение может быть сегментировано по осям, произведение размеров которых не делит размер измерения.
Детальный проект
В этом разделе мы раскрываем базовую структуру и каждый ключевой компонент.
Базовая структура
Шардинги измерений сообщают нам для каждого измерения тензора, по каким осям (или подосям ) он сегментируется от главного к второстепенному. Все остальные оси, которые не сегментируют измерение, реплицируются неявно (или реплицируются явно ).
Мы начнем с простого примера и будем расширять его по мере описания дополнительных функций.
@mesh_xy = <["x"=2, "y"=4, "z"=2]>
// The 1st tensor dimension is sharded along axis "x" and the 2nd tensor dimension is
// sharded along axis "z" then further along axis "y". The local shape of this tensor (i.e. the shape on a single device), would be tensor<2x1xf32>.
sharding<@mesh_xy, [{"x"}, {"z", "y"}]> : tensor<4x8xf32>
Инварианты
- Количество сегментов измерения должно соответствовать рангу тензора.
- Все имена осей должны существовать в сетке, на которую ссылаются.
- Оси или подоси могут появляться в представлении сегментирования только один раз (каждая из них либо сегментирует измерение, либо явно реплицируется ).
Открытые/закрытые размеры
Каждое измерение тензора может быть открытым или закрытым.
Открыть
Открытое измерение открыто для распространения с целью дальнейшего сегментирования его по дополнительным осям, т. е. указанное сегментирование измерения не обязательно должно быть окончательным сегментированием этого измерения. Это похоже (но не совсем то же самое) на unspecified_dims
GSPMD.
Если измерение открыто, мы добавляем ?
по осям, по которым уже сегментировано измерение (см. пример ниже).
Закрыто
Закрытое измерение — это измерение, которое недоступно для распространения с целью добавления дальнейшего сегментирования, т. е. указанное сегментирование измерения является окончательным сегментированием этого измерения, и его нельзя изменить. Распространенным примером использования этого является то, что GSPMD (обычно) не изменяет аргументы ввода/вывода модуля или как с помощью jax.jit
пользователь, указанный in_shardings
, является статичным - они не могут измениться.
Мы можем расширить приведенный выше пример, добавив в него открытое и закрытое измерение.
@mesh_xy = <["x"=2, "y"=4, "z"=2]>
// The 1st dimension is closed, therefore it can't be further sharded and {"x"}
// will remain its sharding. The 2nd dimension is open, and can therefore be
// further sharded during propagation, e.g. by "y".
sharding<@mesh_xy, [{"x"}, {"z", ?}]> : tensor<4x8xf32>
Явно реплицированные оси
Явный набор осей, на которых реплицируется тензор. Хотя можно определить, что тензор, не сегментированный по оси, неявно реплицируется на нем (как сегодня jax.sharding.PartitionSpec
), его явное использование гарантирует, что распространение не сможет использовать эти оси для дальнейшего сегментирования открытого измерения с помощью этих осей. При неявной репликации тензор можно дополнительно разбить. Но при явной репликации ничто не может разделить тензор по этой оси.
Порядок реплицируемых осей не влияет на то, как хранятся данные тензора. Но исключительно для обеспечения единообразия оси будут храниться в том порядке, в котором они указаны в сетке верхнего уровня. Например, если сетка:
@mesh_xy = <["c"=2, "a"=2, "b"=2]>
И мы хотим, чтобы оси "a"
и "c"
были явно реплицированы, порядок должен быть таким:
replicated={"c", "a"}
Мы можем расширить приведенный выше пример, чтобы иметь явно реплицированную ось.
@mesh_xyz = <["x"=2, "y"=4, "z"=2]>
// Since "y" is explicitly replicated, it can't be used to shard the 2nd
// dimension that is open. However, "z" is implicitly replicated so it can be
// used to shard that dimension. The local shape of this tensor (i.e. the
// shape on a single device), would be tensor<2x8xf32>.
sharding<@mesh_xyz, [{"x"}, {?}], replicated={"y"}> : tensor<4x8xf32>
Разделение осей и подоси
Логическая сетка из n
осей создается путем преобразования одномерного массива устройств в n-мерный массив, где каждое измерение образует ось с определяемым пользователем именем.
Тот же процесс можно выполнить в компиляторе, чтобы разбить ось размера k
на m
подосей, изменив форму сетки из [...,k,...]
в [...,k1,...,km,...]
.
Мотивация
Чтобы понять мотивацию разделения осей, мы рассмотрим следующий пример:
@mesh_x = <["x"=4]>
%arg0 : tensor<8xf32> {sdy.sharding=<@mesh_x, [{"x"}]>}
%0 = reshape %arg0 : (tensor<8xf32>) -> tensor<2x4xf32>
Мы хотим сегментировать результат изменения формы таким образом, чтобы избежать обмена данными (т. е. сохранить данные там, где они есть). Поскольку размер "x"
больше, чем 1-е измерение результата, нам нужно разделить ось на две подоси "x.0"
и "x.1"
размером 2 каждая и разделить 1-е измерение на "x.0"
и 2-е измерение на "x.1"
.
Функция Шардинг ввода/вывода
Вполне возможно, что во время распространения вход или выход основной функции будет сегментирован вдоль подоси. Это может быть проблемой для некоторых фреймворков, где мы не можем выразить такие наборы, чтобы вернуть пользователю (например, в JAX мы не можем выразить субсины с jax.sharding.NamedSharding
).
У нас есть несколько вариантов борьбы с такими случаями:
- Разрешить и вернуть шардинг в другом формате (например,
jax.sharding.PositionalSharding
вместоjax.sharding.NamedSharding
in jax). - Отбросьте, и все-грамотные подзажают, которые нарушают вход/вывод.
В настоящее время мы разрешаем подоги на входах/выходах в трубопроводе распространения. Дайте нам знать, если вы хотите отключить это.
Представительство
Точно так же, как мы можем ссылаться на конкретные полные оси от сетки по их названию, мы можем ссылаться на конкретные подоши по их размеру и продукту всех подраздел (от одного и того же названия оси) слева (которые являются для них основными).
Чтобы извлечь определенную подосную оси размера k
из полной оси "x"
размера n
, мы эффективно изменяем размер n
(в сетке) в [m, k, n/(m*k)]
и используем 2-е измерение в качестве подзасоси. Таким образом, субсида может быть указана двумя числами: m
и k
, и мы используем следующие краткие нотации для обозначения подосов: "x":(m)k
.
m>=1
является предварительным размером этой подоси (m
должен быть делителемn
). Предварительный размер-это продукт всех субсисных размеров слева от (которые являются основными для) этой подосной оси (если равна 1, это означает, что нет, если больше, чем 1, он соответствует одним или множественным подосам).k>1
является фактическим размером этой подоси (k
должен быть делителемn
).n/(m*k)
-это пост-размер . Это продукт всех субсисных размеров справа от (которые незначительны) этой подосной оси (если равный 1, это означает, что нет, если нет, если больше 1, он соответствует одним или множественным подосам).
Тем не менее, количество других подосов не имеет значения при использовании конкретной подоси "x":(m)k
, и любая другая подосная ось не нужно ссылаться в тензоре, если оно не будет нарушено или явно воспроизводится.
Возвращаясь к примеру в разделе «Мотивация» , мы можем отложить результат следующим образом:
@mesh_x = <["x"=4]>
%arg0 : tensor<8xf32> {sdy.sharding=<@mesh_x, [{"x"}]>}
%0 = reshape %arg0 {sdy.sharding_per_value=<[<@mesh_x, [{"x":(1)2}, {"x":(2)2}]>]>}
: (tensor<8xf32>) -> tensor<2x4xf32>
Вот еще один пример разделенной оси, где используются только некоторые из ее подосов.
@mesh_xyz = <["x"=2, "y"=8, "z"=2]>
// Axis "y" is effectively split into 3 sub-axes denoted as
// "y":(1)2, "y":(2)2, "y":(4)2
// in order, but only "y":(2)2 is used, to shard the 2nd dimension. The local
// shape of this tensor (i.e. the shape on a single device), would be
// tensor<2x4xf32>.
sharding<@mesh_xyz, [{"x"}, {"y":(2)2}]> : tensor<4x8xf32>
Точно так же следующие два посадки семантически эквивалентны. Мы можем думать о mesh_xy
как расщепление mesh_full
.
@mesh_full = <"devices"=8>
@mesh_xy = <"x"=4, "y"=2>
sharding<@mesh_xy, [{"x"},{ "y"}]> : tensor<4x4xf32>
sharding<@mesh_full, [{"devices":(1)4}, {"devices":(4)2}]> : tensor<4x4xf32>
Явно воспроизведенные подоги
В дополнение к тем, что подзы, используемые для измерения, они также могут быть отмечены как явные воспроизведения. Мы разрешаем это в представлении, потому что под оси ведут себя так же, как и полные оси, то есть, когда вы нарушаете измерение вдоль оси оси "x"
, другие подосные оси "x"
неявно воспроизведены и, следовательно, могут быть явно воспроизведены, чтобы указать, что субсида должна оставаться воспроизведенной и не может использоваться для измерения измерения.
Например:
@mesh_xyz = <["x"=2, "y"=8, "z"=2]>
// Sub-axis "y":(1)2 is explicitly replicated and "y":(4)2 is implicitly replicated.
sharding<@mesh_xyz, [{"x"}, {"y":(2)2}], replicated={"y":(1)2}> : tensor<4x8xf32>
Реплицированная подосная ось той же полной оси должна быть заказана в порядке увеличения по их предварительному размеру, например:
replicated={"y":(4)2, "x", "y":(1)2} ~> replicated={"x", "y":(1)2, "y":(4)2}
Инварианты
Субсосы, на которые ссылается на тензор, не должны перекрываться, например
"x":(1)4
и"x":(2)4
перекрытие.Подрачные оси, на которые ссылаются в тензоре, должны быть максимально большими, как можно более большими, то есть, если размерный шардинг имеет две соседние подзасы a и b в порядке, или субсины A и B явно воспроизведены, они не должны быть последовательными, например,
"x":(1)2
и"x":(2)4
поскольку они могут быть заменены на один"x":(1)8
.
Несколько логических сетей
Одна логическая сетка-это многомерный вид устройств. Нам может понадобиться несколько представлений устройств, чтобы представлять наши наборы, особенно для работоспособных заданий.
Например, jax.sharding.PositionalSharding
не имеет ни одной общей логической сетки . В настоящее время GSPMD подтверждает, что с помощью Hlosharding, где представление может быть упорядоченным списком устройств и размеров измерений, но это не может быть представлено с расщеплением оси выше.
Мы преодолеваем это ограничение и обрабатываем существующие угловые случаи, определяя несколько логических сетей на верхнем уровне программы. Каждая сетка может иметь различное количество осей с разными именами, а также свое собственное произвольное назначение для одного и того же набора устройств, то есть каждая сетка относится к одному и тому же набору устройств (по их уникальному логическому идентификатору), но с произвольным порядком, аналогичным представлению GSPMD.
Каждое представление о шардинге связано с определенной логической сеткой, поэтому оно будет ссылаться только на осги из этой сетки.
Тензор, который назначен одной логической сетке, может быть использован операционным операцией, которая назначена другой сетке, наивно изменив тензор в соответствии с сеткой назначения. В GSPMD это то, что обычно делается, чтобы разрешить конфликтующие сетки.
Пользователи могут указать несколько сетей с различными именованными осями (например, через jax.sharding.NamedSharding
), которые имеют одинаковый порядок устройств. Рассмотрим этот пример, <@mesh_0, "b">
идентичен <@mesh_1, "z">
:
@mesh_0 = {<["a"=4, "b"=2]>, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]}
@mesh_1 = {<["x"=2, "y"=2, "z"=2]>, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]}
Приоритеты
Приоритет - это способ расставить приоритеты в определенных решениях по разделению и распространению по сравнению с другими, и позволяет постепенно разделять программу.
Приоритеты - это значения, прикрепленные к некоторым или всеми размерами представления шардинга (реплицированные оси не имеют приоритетов).
Например:
@mesh_xy = <["w"=6, "x"=2, "y"=4, "z"=2]>
// |-> y is implicitly p0
%arg4 : sharding<@mesh_xy, [{"x"}p1, {"y"}, {"z",?}p2], replicated={} }>
Приоритеты дают пользователям более тонкий контроль над распространением, например, пакетный параллелизм сначала, затем Megatron и, наконец, нулевой шарнинг. Это допускает сильные гарантии относительно того, что разделено, и обеспечивает лучшую отладку, имея более мелкозернистые стратегии шардинга (может увидеть, как программа заботится о только мегатроне в изоляции).
Мы допускаем приоритет приоритета к каждому размерному шарду (0 по умолчанию), что указывает на то, что все посадки с приоритетом <i
буду распространяться на всю программу перед посадками с приоритетом i
.
Даже если шарнинг имеет открытый измерение с более низким приоритетом, например, {"z",?}p2
, оно не будет переопределено еще одним тензором с более высоким приоритетом во время распространения. Тем не менее, такое открытое измерение может быть дополнительно нарушено после того, как все более высокие приоритетные нарушения были распространены.
Другими словами, приоритеты не связаны с тем, какие измерения оскорбления являются более важным, чем другое - это порядок, в котором различные группы измерений должны распространяться на всю программу, и как следует разрешать конфликты в промежуточных, необеспеченных тензорах.
Инварианты
Приоритеты начинаются с 0 (самый высокий приоритет) и увеличиваются (чтобы пользователи могли легко добавлять и удалять приоритеты, мы разрешаем разрывы между приоритетами, например, P0 и P2 используются, но P1 не).
Пустое замкнутое измерение (то есть
{}
), не должно иметь приоритета, так как это не будет иметь никакого эффекта.
Размерное распределение раздвоения
Возможно, чтобы размер размера d
было охвачено вдоль оси, продукт n
, такой, что d
не делится на n
(что на практике потребует, чтобы измерение было мягким).
Например:
@mesh_xy = <["x"=8, "y"=2, "z"=3]>
sharding<@mesh_xy, [{"x"}, {"y"}, {"z"}]> : tensor<7x3x8xf32>
Грамматика
Каждая логическая сетка определяется следующим образом:
@mesh_name = <mesh_axis_1,...,mesh_axis_n>
mesh_axis ::= axis_name=axis_size
axis_name ::= str
axis_size ::= int
Представление о шардинге будет иметь следующую структуру для тензора ранга R:
sharding<@mesh_name, dim_shardings, replicated=replicated_axes}
mesh_name ::= str
dim_shardings ::= [dim_sharding_1,...,dim_sharding_r]
replicated_axes ::= {axis_1,...,axis_m}
dim_sharding ::=
{axis_1,...,axis_k} | // closed dimension
{axis_1,...,axis_k,?} // open dimension
axis ::=
axis_name | // a full axis
sub_axis // a sub axis
axis_name ::= str
sub_axis ::= axis_name:(pre_size)size
pre_size ::= int
size ::= int