641 lines
23 KiB
Python
641 lines
23 KiB
Python
|
|
# mypy: allow-untyped-defs
|
||
|
|
# pyre-strict
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import heapq
|
||
|
|
import operator
|
||
|
|
import sys
|
||
|
|
from collections import defaultdict
|
||
|
|
from typing import Dict, List, Set, TYPE_CHECKING
|
||
|
|
|
||
|
|
import torch
|
||
|
|
|
||
|
|
from . import config, ir
|
||
|
|
from .dependencies import WeakDep
|
||
|
|
from .utils import (
|
||
|
|
contains_collective,
|
||
|
|
contains_wait,
|
||
|
|
find_recursive_deps_of_node,
|
||
|
|
find_recursive_users_of_node,
|
||
|
|
is_collective,
|
||
|
|
is_fallback_op,
|
||
|
|
is_wait,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from .scheduler import BaseSchedulerNode
|
||
|
|
|
||
|
|
|
||
|
|
def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
    """
    Reorder `snodes` so that wait nodes are greedily placed as late as possible.

    Delegates to `_schedule_for_comm` with only the `sink_waits` objective
    switched on.
    """
    reordered = _schedule_for_comm(
        snodes,
        raise_comms=False,
        sink_waits=True,
        reorder_for_overlap=False,
    )
    return reordered
|
||
|
|
|
||
|
|
|
||
|
|
def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
    """
    Reorder `snodes` so that comm nodes are greedily placed as early as possible.

    Delegates to `_schedule_for_comm` with only the `raise_comms` objective
    switched on.
    """
    reordered = _schedule_for_comm(
        snodes,
        raise_comms=True,
        sink_waits=False,
        reorder_for_overlap=False,
    )
    return reordered
|
||
|
|
|
||
|
|
|
||
|
|
def reorder_compute_for_overlap(
    snodes: List[BaseSchedulerNode],
) -> List[BaseSchedulerNode]:
    """
    Reorder compute nodes so that they overlap with in-flight comm nodes.

    The overall scheduling procedure is:

    Step 1: With comm N already scheduled, schedule every compute node that is
        required for comm N + 1 but does not depend on comm N, so that it runs
        at the same time as comm N.
    Step 2: If those compute nodes are enough to overlap comm N, we're done.
        Otherwise, look elsewhere for compute that can overlap with comm N,
        prioritizing compute nodes that are needed sooner.
    Step 3: Schedule the compute nodes that depend on comm N and are required
        for comm N + 1.
    Step 4: Schedule comm N + 1.

    The steps are repeated for each subsequent comm node.

    Delegates to `_schedule_for_comm` with all three objectives switched on.
    """
    reordered = _schedule_for_comm(
        snodes,
        raise_comms=True,
        sink_waits=True,
        reorder_for_overlap=True,
    )
    return reordered
|
||
|
|
|
||
|
|
|
||
|
|
def _schedule_for_comm(
    snodes: List[BaseSchedulerNode],
    raise_comms: bool,
    sink_waits: bool,
    reorder_for_overlap: bool,
) -> List[BaseSchedulerNode]:
    """
    Schedule `snodes` for various comm optimization objectives.

    Args:
        snodes: the nodes to be scheduled.
        raise_comms: whether to greedily schedule collectives as early as possible
        sink_waits: whether to greedily schedule waits as late as possible
        reorder_for_overlap: whether to reorder compute nodes to
            optimize for compute/communication overlapping.

    Returns:
        The new schedule order.

    Raises:
        AssertionError: if some node still has unmet dependencies once the
            ready queue has been drained.

    Some notes on the synergy between different options:
        - `raise_comms` provides more overlapping opportunities for `reorder_for_overlap`.
        - When both `raise_comms` and `sink_waits` are `True`, `raise_comms` is prioritized.
    """
    # We assign each node a tuple of scores (score_0, score_1, score_2),
    # decreasing in importance, with a lower value indicating a higher ranking:
    #
    # - score_0: the lowest comm_idx among the comm nodes that the node blocks.
    # If a node doesn't block any comm nodes, its score_0 is set to
    # sys.maxsize. This score ensures that comm nodes get scheduled as early as
    # possible.
    # - score_1: 1 if the node is a wait node, 0 otherwise. This score ensures
    # that wait nodes are deferred as late as possible.
    # - score_2: the index of the node in the original topological order. This
    # score provides stability in case of ties.
    #
    # When only raise_comms is True, only score_0 and score_2 are considered.
    # When only sink_waits is True, only score_1 and score_2 are considered.
    # When neither is True, the original order is yielded.
    name_to_fused_node = {}
    scores_0, scores_1, scores_2 = {}, {}, {}
    for idx, snode in enumerate(snodes):
        # Every operation inside a (possibly fused) node, and the node's own
        # name, resolve to the node itself.
        for op_name in snode.get_operation_names():
            name_to_fused_node[op_name] = snode
        node_name = snode.get_name()
        name_to_fused_node[node_name] = snode

        scores_0[node_name] = sys.maxsize
        scores_1[node_name] = 0
        scores_2[node_name] = idx

    comm_idx = 0
    for snode in snodes:
        if raise_comms and contains_collective(snode):
            scores_0[snode.get_name()] = comm_idx
            # Every ancestor blocks this comm, so it inherits the comm's index
            # unless it already blocks an earlier comm.
            for anc in snode.ancestors:
                anc_fused_name = name_to_fused_node[anc].get_name()
                scores_0[anc_fused_name] = min(scores_0[anc_fused_name], comm_idx)
            comm_idx += 1
        elif sink_waits and contains_wait(snode):
            scores_1[snode.get_name()] = 1

    class Runnable:
        """Heap entry pairing a ready node with its (score_0, score_1, score_2)."""

        def __init__(self, snode) -> None:
            self.snode = snode
            name = next(iter(snode.get_operation_names()))
            fused_name = name_to_fused_node[name].get_name()
            self.score = (
                scores_0[fused_name],
                scores_1[fused_name],
                scores_2[fused_name],
            )

        def __lt__(self, other):
            return self.score < other.score

    unmet_deps: Dict[BaseSchedulerNode, Set[str]] = {
        snode: {dep.name for dep in snode.unmet_dependencies} for snode in snodes
    }

    ready: List[Runnable] = []
    buffer_users: Dict[str, Set[BaseSchedulerNode]] = defaultdict(set)
    snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes}

    for snode, deps in unmet_deps.items():
        if not deps:
            heapq.heappush(ready, Runnable(snode))
        for dep in deps:
            buffer_users[dep].add(snode)

    scheduled = []

    def schedule(snode):
        """
        Schedules `snode` and put all unblocked nodes onto the ready queue.
        """
        scheduled.append(snode)
        for buf_name in snode.get_buffer_names():
            # A distinct loop name is used so the `snode` parameter is not
            # rebound while its buffers are being iterated.
            for user in buffer_users[buf_name]:
                unmet_deps[user].remove(buf_name)
                if not unmet_deps[user]:
                    heapq.heappush(ready, Runnable(user))

    def get_overlapping_candidate():
        """
        Return the next node in the ready queue that's neither a collective or
        a wait.
        """
        candidates = [
            x
            for x in ready
            if not contains_collective(x.snode) and not contains_wait(x.snode)
        ]
        if not candidates:
            return None
        return min(candidates, key=lambda x: x.score)

    def schedule_collective_for_overlap(snode):
        """
        Schedules collective node `snode`, along with one or more compute nodes
        to overlap with it. The strategy is described in the comment of
        `reorder_compute_for_overlap`.
        """
        assert contains_collective(snode)
        schedule(snode)

        collective_cost = snode_to_cost[snode]
        while (
            collective_cost > 0
            and (candidate := get_overlapping_candidate()) is not None
        ):
            ready.remove(candidate)
            schedule(candidate.snode)
            collective_cost -= snode_to_cost[candidate.snode]
        # `ready.remove` above can break the heap invariant; restore it before
        # the main loop pops from the queue again.
        heapq.heapify(ready)

    while ready:
        snode = heapq.heappop(ready).snode
        if reorder_for_overlap and contains_collective(snode):
            schedule_collective_for_overlap(snode)
        else:
            schedule(snode)

    for deps in unmet_deps.values():
        assert not deps, (
            "Detected unscheduled nodes. "
            f"Nodes with unmet dependencies: {unmet_deps}"
        )
    return scheduled
|
||
|
|
|
||
|
|
|
||
|
|
def decide_global_ordering_of_comms(
    nodes: List[BaseSchedulerNode], name_to_buf, name_to_fused_node
) -> List[BaseSchedulerNode]:
    """
    Pin the global order of comm nodes to the order found in the input graph
    (which might differ from the ordering of the eager mode program).

    Each comm node receives a `WeakDep` on every buffer of the comm node that
    precedes it. When any node is an FSDP `all_gather_copy_in` / `chunk_cat`
    fallback op, `enforce_comm_ordering_for_fsdp` is applied to `nodes` first.

    Returns the (possibly FSDP-regrouped) node list.

    TODO: Come up with a better approach
    """
    # If FSDP2 is used, we apply FSDP-specific passes.
    uses_fsdp = any(
        is_fallback_op(
            candidate.node,
            {
                torch.ops.fsdp.all_gather_copy_in.default,
                torch.ops.fsdp.chunk_cat.default,
            },
        )
        for candidate in nodes
    )
    if uses_fsdp:
        nodes = enforce_comm_ordering_for_fsdp(nodes, name_to_buf, name_to_fused_node)

    comm_nodes = [node for node in nodes if contains_collective(node)]

    # Walk adjacent (previous, current) comm pairs and make the previous comm a
    # `WeakDep` dependency of the current one.
    for prev_comm, cur_comm in zip(comm_nodes, comm_nodes[1:]):
        mutating_buf = next(iter(cur_comm.get_buffer_names()))
        for prev_buf in prev_comm.get_buffer_names():
            cur_comm.add_fake_dep(WeakDep(prev_buf, mutating_buf=mutating_buf))

    return nodes
|
||
|
|
|
||
|
|
|
||
|
|
def estimate_op_runtime(snode: BaseSchedulerNode) -> float:
    """
    Returns estimated op runtime in nanoseconds (ns)

    Uses the node's own estimate when `config.estimate_op_runtime` is the
    string "default"; otherwise the config value must be a callable and is
    invoked with `snode`.
    """
    estimator = config.estimate_op_runtime
    if estimator == "default":
        return snode.get_estimated_runtime()

    assert callable(estimator)
    return estimator(snode)
|
||
|
|
|
||
|
|
|
||
|
|
def node_summary(snode):
    """
    Build a one-line description of `snode.node` for overlap logging.

    The summary is made of the IR node's class name, the python kernel name
    for `ir.ExternKernelOut` nodes, the layout size/stride when the node has a
    layout exposing both, and the node's name (empty if it has none).
    """
    ir_node = snode.node

    kernel_part = (
        f" ({ir_node.python_kernel_name})"
        if isinstance(ir_node, ir.ExternKernelOut)
        else ""
    )

    # Only report tensor info when the layout carries both size and stride.
    layout = getattr(ir_node, "layout", None)
    has_tensor_info = hasattr(layout, "size") and hasattr(layout, "stride")
    tensor_part = (
        f" (size={layout.size}, stride={layout.stride})" if has_tensor_info else ""
    )

    name_part = getattr(ir_node, "name", "")
    return f"{ir_node.__class__.__name__}{kernel_part}{tensor_part} ({name_part})"
|
||
|
|
|
||
|
|
|
||
|
|
def visualize_overlap(order):
    """
    Log a textual view of how compute overlaps with comms in `order`.

    Nodes logged with a leading "| " run between a collective and its wait
    (overlapped compute). Collectives and compute outside such a window add
    their estimated runtime to the total, which is logged at the end in ms.

    Raises:
        AssertionError: on a wait with no collective in flight, or on a
            collective that starts while another one is still in flight.
    """
    exposed_runtime: float = 0.0
    in_flight_comm = None
    for snode in order:
        starts_comm = contains_collective(snode)

        if in_flight_comm is not None:
            if starts_comm:
                raise AssertionError(
                    "Found two collectives running at the same time. "
                    "`visualize_overlap` needs to be updated to handle this case"
                )
            if is_wait(snode.node):  # end of this comm op
                overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
                in_flight_comm = None
            else:  # overlapped compute op
                overlap_log.debug(f"| {node_summary(snode)}")  # noqa: G004
            continue

        # No collective is in flight from here on.
        if not starts_comm and is_wait(snode.node):
            raise AssertionError(
                "Wait is not expected when there is no collective running"
            )
        # Both a new collective and an exposed compute op count toward the total.
        exposed_runtime += estimate_op_runtime(snode)
        if starts_comm:
            in_flight_comm = snode.node
        overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004

    overlap_log.debug(
        f"Est. runtime (ms): {exposed_runtime / 1000 / 1000}"  # noqa: G004
    )
|
||
|
|
|
||
|
|
|
||
|
|
def reorder_compute_and_comm_for_overlap(
    snodes: List[BaseSchedulerNode],
) -> List[BaseSchedulerNode]:
    """
    Run every pass listed in `config.reorder_for_compute_comm_overlap_passes`
    over `snodes`, in order, and return the resulting schedule.

    A pass given as a string that names a global of this module is resolved to
    that global; any other entry is called as-is. On rank 0 the overlap is
    visualized before and after each pass, and a visualization failure is only
    logged.
    """

    def _visualize_best_effort(current_order):
        # Visualization must never break compilation: log the error and move on.
        try:
            visualize_overlap(current_order)
        except Exception as err:
            overlap_log.debug(str(err))

    order = snodes
    for configured_pass in config.reorder_for_compute_comm_overlap_passes:
        reorder_pass = configured_pass
        if isinstance(configured_pass, str) and configured_pass in globals():
            reorder_pass = globals()[configured_pass]  # it is a builtin pass

        if torch.distributed.get_rank() == 0:
            overlap_log.debug(
                f"==== Visualize overlap before reordering pass {reorder_pass} ===="  # noqa: G004
            )
            _visualize_best_effort(order)

        order = reorder_pass(order)  # type: ignore[operator]

        if torch.distributed.get_rank() == 0:
            overlap_log.debug(
                f"==== Visualize overlap after reordering pass {reorder_pass} ===="  # noqa: G004
            )
            _visualize_best_effort(order)
    return order
|
||
|
|
|
||
|
|
|
||
|
|
def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
    """
    Rewrite FSDP all-gather patterns in `graph` to use the out-variant collective.

    A functional `all_gather_into_tensor` fed by item 0 of
    `fsdp.all_gather_copy_in` is replaced by `all_gather_into_tensor_out`,
    which writes into item 1 of the same `all_gather_copy_in` call.

    The function returns without touching `graph` when the FSDP collectives
    module cannot be imported, `torch.distributed` is unavailable, or the
    required `_c10d_functional` ops cannot be resolved.

    Args:
        graph: the FX graph to rewrite in place.
    """
    try:
        import torch.distributed._composable.fsdp._fsdp_collectives

        assert torch.distributed.is_available()
        # Assert existence of these ops
        assert (
            torch.ops._c10d_functional.all_gather_into_tensor
            and torch.ops._c10d_functional.all_gather_into_tensor_out
        )
    except (ImportError, AttributeError, AssertionError):
        # Prerequisites are missing: leave the graph as it is.
        return

    from .pattern_matcher import (
        CallFunction,
        KeywordArg,
        Match,
        PatternMatcherPass,
        register_graph_pattern,
    )

    """
    all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
    getitem = all_gather_copy_in[0];
    (getitem_1 = all_gather_copy_in[1];) # optional

    all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, ...);

    ->

    all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
    getitem = all_gather_copy_in[0];
    getitem_1 = all_gather_copy_in[1];

    all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor_out.default(getitem, ..., out=getitem_1);
    """

    def remove_unused_getitem(g):
        # Remove `getitem_X = all_gather_copy_in[1]` which is never used.
        # NOTE(review): the node's users are not checked before erasing;
        # presumably these getitem nodes have no users at this point -- confirm.
        # Iterate over a snapshot because nodes are erased during the loop.
        node_list = list(g.nodes)
        for n in node_list:
            if (
                n.target == operator.getitem
                and n.args[0].target is torch.ops.fsdp.all_gather_copy_in.default
                and n.args[1] == 1
            ):
                g.erase_node(n)

    graph_pass = PatternMatcherPass()

    # Pattern: all_gather_into_tensor(getitem(all_gather_copy_in(...), item_idx), ...).
    # `extra_check` restricts matches to item 0 of `all_gather_copy_in`.
    @register_graph_pattern(
        CallFunction(
            torch.ops._c10d_functional.all_gather_into_tensor.default,
            CallFunction(
                operator.getitem,
                CallFunction(
                    torch.ops.fsdp.all_gather_copy_in.default,
                    KeywordArg("all_gather_inputs"),
                    KeywordArg("inp_split_sizes"),
                    KeywordArg("all_gather_input_numel"),
                    KeywordArg("world_size"),
                    KeywordArg("rank"),
                    KeywordArg("dtype"),
                    KeywordArg("device"),
                ),
                KeywordArg("item_idx"),
            ),
            KeywordArg("group_size"),
            KeywordArg("group_name"),
        ),
        pass_dict=graph_pass,
        extra_check=lambda match: match.kwargs["item_idx"] == 0,
    )
    def reinplace_all_gather(match: Match, *args, **kwargs):
        # Replacement handler: re-emits the copy-in and routes the all-gather
        # output into item 1 of the copy-in result.
        def repl(
            *args,
        ):
            # The last two positional args are the collective's group_size and
            # group_name; everything before them is forwarded to copy-in.
            copy_in_args = args[:-2]
            group_size = args[-2]
            group_name = args[-1]
            all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(
                *copy_in_args
            )
            getitem = all_gather_copy_in[0]
            getitem_1 = all_gather_copy_in[1]
            all_gather_into_tensor = (
                torch.ops._c10d_functional.all_gather_into_tensor_out.default(
                    getitem, group_size, group_name, out=getitem_1
                )
            )
            return all_gather_into_tensor

        # Argument order must line up with how `repl` slices `args`.
        match.replace_by_example(
            repl,
            [
                kwargs["all_gather_inputs"],
                kwargs["inp_split_sizes"],
                kwargs["all_gather_input_numel"],
                kwargs["world_size"],
                kwargs["rank"],
                kwargs["dtype"],
                kwargs["device"],
                kwargs["group_size"],
                kwargs["group_name"],
            ],
        )

    # Existing `all_gather_copy_in[1]` getitem nodes are erased before the
    # pattern pass runs; `repl` emits its own `[1]` access.
    remove_unused_getitem(graph)
    graph_pass.apply(graph)  # type: ignore[arg-type]
|
||
|
|
|
||
|
|
|
||
|
|
def get_op_idx(snode):
|
||
|
|
assert not isinstance(
|
||
|
|
snode,
|
||
|
|
(
|
||
|
|
torch._inductor.scheduler.FusedSchedulerNode,
|
||
|
|
torch._inductor.scheduler.GroupedSchedulerNode,
|
||
|
|
),
|
||
|
|
)
|
||
|
|
return int(snode.get_name()[2:])
|
||
|
|
|
||
|
|
|
||
|
|
def enforce_comm_ordering_for_fsdp(
    snodes: List[torch._inductor.scheduler.BaseSchedulerNode],
    name_to_buf: Dict[str, torch._inductor.scheduler.SchedulerBuffer],
    name_to_fused_node: Dict[str, BaseSchedulerNode],
) -> List[torch._inductor.scheduler.BaseSchedulerNode]:
    """
    Group FSDP all-gather / reduce-scatter code blocks and chain them in order.

    For every all-gather (`all_gather_into_tensor_out` with an
    `all_gather_copy_in` ancestor) and every reduce-scatter (`chunk_cat`
    fallback op), the related nodes are split at the first wait kernel into
    two `GroupedSchedulerNode`s: everything before the wait, and the wait plus
    everything after it. Each "before" group then gets `WeakDep`s on the
    outputs of the previous "wait" group of the same kind.

    Args:
        snodes: scheduler nodes in their current order.
        name_to_buf: forwarded to the recursive deps/users helpers.
        name_to_fused_node: maps a node name to its scheduler node; also
            forwarded to the recursive deps/users helpers.

    Returns:
        A new node list in which grouped members are replaced by their group
        node, each group appearing once at the position of its first member.

    Raises:
        AssertionError: if no group was created, or if a related-node list has
            no wait kernel at index 1 or later.
    """
    from . import scheduler

    new_order: list[BaseSchedulerNode] = []
    scheduled = set()
    ag_exists = False
    rs_exists = False
    # "before wait" group node -> matching "wait and after" group node.
    # Dict insertion order is relied on below to chain consecutive groups.
    ag_grouped_node_to_wait_grouped_node = {}
    rs_grouped_node_to_wait_grouped_node = {}
    # Name of every grouped member (and of each group itself) -> group node.
    snode_name_to_final_snode = {}

    def _create_group_node(snodes_to_group):
        # Wrap the nodes in one GroupedSchedulerNode and record the mapping
        # used later to substitute members when building the new order.
        group_node = scheduler.GroupedSchedulerNode.create(snodes_to_group)
        for snode in snodes_to_group:
            snode_name_to_final_snode[snode.get_name()] = group_node
        snode_name_to_final_snode[group_node.get_name()] = group_node
        return group_node

    # Create grouped nodes for specific sets of ops
    for snode in snodes:
        # Case 1: Handle AllGather
        if is_collective(
            snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor_out.default
        ) and any(
            is_fallback_op(
                name_to_fused_node[x].node, torch.ops.fsdp.all_gather_copy_in.default
            )
            for x in snode.ancestors
        ):
            ag_exists = True
            ag_snode = snode
            ag_related_snode_set: set[scheduler.BaseSchedulerNode] = set()

            # Find the "cast + copy_in + getitem + all_gather" code block
            find_recursive_deps_of_node(
                ag_snode,
                ag_related_snode_set,
                name_to_buf,
                name_to_fused_node,
            )

            # Find the "all_gather + all_gather_wait_tensor + copy_out + set_" code block
            allowed_ops = {
                torch.ops._c10d_functional.all_gather_into_tensor_out.default,
                torch.ops._c10d_functional.wait_tensor.default,
                torch.ops.fsdp.split_with_sizes_copy.default,
                torch.ops.aten.set_.source_Tensor,
            }
            # The callback is True for any node that is neither a nop kernel
            # nor an extern kernel whose op is in `allowed_ops`.
            # NOTE(review): presumably a True result excludes the node from
            # the traversal -- confirm against find_recursive_users_of_node.
            find_recursive_users_of_node(
                ag_snode,
                ag_related_snode_set,
                name_to_buf,
                name_to_fused_node,
                criteria_cb=lambda x: not (
                    isinstance(x, scheduler.NopKernelSchedulerNode)
                    or (
                        isinstance(x, scheduler.ExternKernelSchedulerNode)
                        and x.node.op_overload in allowed_ops  # type: ignore[union-attr]
                    )
                ),
            )

            # sort nodes by original operation order
            ag_related_snodes = sorted(
                ag_related_snode_set, key=lambda x: get_op_idx(x)
            )

            # In the "reuse layer" case, some ops in the 2nd all-gather code block could also
            # depend on ops in the 1st all-gather code block, and we don't want to group them together.
            # The block is cut just before the second `split_with_sizes_copy`.
            end_idx_of_current_ag_block = len(ag_related_snodes)
            copy_out_count = 0
            for i in range(len(ag_related_snodes)):
                cur_snode = ag_related_snodes[i]
                if is_fallback_op(
                    cur_snode.node, torch.ops.fsdp.split_with_sizes_copy.default
                ):
                    copy_out_count += 1
                if copy_out_count > 1:
                    end_idx_of_current_ag_block = i
                    break

            ag_related_snodes = ag_related_snodes[:end_idx_of_current_ag_block]

            # Group "cast + copy_in + getitem + all_gather" into one GroupedSchedulerNode
            # The search starts at index 1, so the first group is never empty.
            wait_node_idx = None
            for i in range(len(ag_related_snodes) - 1):
                if isinstance(ag_related_snodes[i + 1].node, ir._WaitKernel):
                    wait_node_idx = i + 1
                    break
            assert wait_node_idx is not None
            ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx])

            # Group "all_gather_wait_tensor + copy_out + set_" into one GroupedSchedulerNode
            ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:])

            ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node

        # Case 2: Handle ReduceScatter
        elif is_fallback_op(snode.node, torch.ops.fsdp.chunk_cat.default):
            rs_exists = True
            rs_snode = snode

            # Find the "reduce_scatter copy-in + reduce_scatter comm + reduce_scatter wait" code block
            rs_related_snode_set: set[scheduler.BaseSchedulerNode] = set()
            find_recursive_users_of_node(
                rs_snode,
                rs_related_snode_set,
                name_to_buf,
                name_to_fused_node,
            )

            # sort nodes by original operation order
            rs_related_snodes = sorted(
                rs_related_snode_set, key=lambda x: get_op_idx(x)
            )

            # Group "reduce_scatter copy-in + reduce_scatter comm" into one GroupedSchedulerNode
            # As above, the wait is searched from index 1 onwards.
            wait_node_idx = None
            for i in range(len(rs_related_snodes) - 1):
                if isinstance(rs_related_snodes[i + 1].node, ir._WaitKernel):
                    wait_node_idx = i + 1
                    break
            assert wait_node_idx is not None
            rs_group_node = _create_group_node(rs_related_snodes[:wait_node_idx])

            # Group "reduce_scatter wait + related output nodes" into one GroupedSchedulerNode
            rs_wait_group_node = _create_group_node(rs_related_snodes[wait_node_idx:])

            rs_grouped_node_to_wait_grouped_node[rs_group_node] = rs_wait_group_node

    # Sanity checks: at least one group must have been created.
    assert len(snode_name_to_final_snode) > 0
    if ag_exists:
        assert len(ag_grouped_node_to_wait_grouped_node) > 0
    if rs_exists:
        assert len(rs_grouped_node_to_wait_grouped_node) > 0

    # Build the new node schedule, taking GroupedSchedulerNode into account
    for snode in snodes:
        if snode.get_name() in snode_name_to_final_snode:
            snode = snode_name_to_final_snode[snode.get_name()]
        # A group node is emitted once, at the position of its first member.
        if snode in scheduled:
            continue
        new_order.append(snode)
        scheduled.add(snode)

    # Enforce AllGather ordering: previous AllGather's "wait then copy_out" group node must run
    # before next AllGather's "copy_in then AG" group node
    prev_ag_wait = None
    for ag_group_node, wait_group_node in ag_grouped_node_to_wait_grouped_node.items():
        if prev_ag_wait is not None:
            mutating_buf = next(iter(ag_group_node.get_buffer_names()))
            for o in prev_ag_wait.get_outputs():
                ag_group_node.add_fake_dep(
                    WeakDep(o.get_name(), mutating_buf=mutating_buf)
                )
        prev_ag_wait = wait_group_node

    # Enforce ReduceScatter ordering: previous ReduceScatter's "wait" group node must run
    # before next ReduceScatter's "copy_in then RS" group node
    prev_rs_wait = None
    for rs_group_node, wait_group_node in rs_grouped_node_to_wait_grouped_node.items():
        if prev_rs_wait is not None:
            mutating_buf = next(iter(rs_group_node.get_buffer_names()))
            for o in prev_rs_wait.get_outputs():
                rs_group_node.add_fake_dep(
                    WeakDep(o.get_name(), mutating_buf=mutating_buf)
                )
        prev_rs_wait = wait_group_node

    return new_order  # type: ignore[return-value]
|