Guide d'écriture de kernels Metal
Cette skill vous guide dans l'implémentation de kernels Metal pour les opérateurs PyTorch sur Apple Silicon.
Important : L'objectif de cette skill est d'utiliser les capacités natives de Metal via l'infrastructure c10/metal/, PAS MPSGraph. Les kernels Metal natifs offrent un meilleur contrôle, une meilleure performance et une meilleure maintenabilité.
Aperçu
Il y a deux workflows couverts par cette skill :
- Ajouter un nouveau support MPS - Implémenter un nouvel opérateur à partir de zéro
- Migrer depuis MPSGraph - Convertir les opérateurs existants basés sur MPSGraph en Metal natif
Les deux workflows impliquent :
- Mettre à jour la dispatch dans
aten/src/ATen/native/native_functions.yaml - Écrire le kernel Metal dans
aten/src/ATen/native/mps/kernels/ - Implémenter le stub côté host dans
aten/src/ATen/native/mps/operations/
Étape 1 : Mettre à jour native_functions.yaml
Emplacement : aten/src/ATen/native/native_functions.yaml
Pour les nouveaux opérateurs
Trouvez l'entrée de l'opérateur et ajoutez la dispatch MPS :
# Implémentation simple spécifique à MPS
- func: my_op(Tensor self) -> Tensor
dispatch:
CPU: my_op_cpu
CUDA: my_op_cuda
MPS: my_op_mps
# Implémentation partagée entre les appareils (préféré pour les kernels structurés)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA, MPS: my_op_out
# Kernel structuré (préféré pour les nouveaux opérateurs)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA, MPS: my_op_out
Pour migrer depuis MPSGraph
Lors de la migration d'un opérateur existant de MPSGraph vers Metal natif, consolidez l'entrée de dispatch :
# AVANT (basé sur MPSGraph, dispatch séparé)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: atan2_out
MPS: atan2_out_mps # Implémentation MPS séparée
# APRÈS (Metal natif, dispatch partagée via stub)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA, MPS: atan2_out # MPS utilise maintenant le même mécanisme de stub
Changement clé : Remplacez MPS: my_op_out_mps en ajoutant MPS à la ligne de dispatch partagée (ex. CPU, CUDA, MPS: my_op_out).
Conventions de nommage pour la dispatch :
MPS: function_name_mps- Implémentation spécifique à MPS (ancien pattern MPSGraph)CPU, CUDA, MPS: function_name- Implémentation de stub partagée (pattern Metal natif)
Étape 2 : Implémenter le kernel Metal
Emplacement : aten/src/ATen/native/mps/kernels/
Pattern de kernel unaire
// MyKernel.metal
#include <c10/metal/indexing.h>
#include <c10/metal/utils.h>
#include <metal_stdlib>
using namespace metal;
using namespace c10::metal;
// Définir le functor d'opération
struct my_op_functor {
template <typename T>
inline T operator()(const T x) {
return /* votre opération */;
}
};
// Enregistrer pour les types supportés
REGISTER_UNARY_OP(my_op, float, float);
REGISTER_UNARY_OP(my_op, half, half);
REGISTER_UNARY_OP(my_op, bfloat, bfloat);
Pattern de kernel binaire
struct my_binary_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return /* votre opération */;
}
};
REGISTER_BINARY_OP(my_binary, float, float);
REGISTER_BINARY_OP(my_binary, half, half);
Macros d'enregistrement de type pour les kernels binaires
Pour les opérations binaires, utilisez les macros de commodité définies dans BinaryKernel.metal :
// Types à virgule flottante uniquement (float, half, bfloat)
REGISTER_FLOAT_BINARY_OP(my_op);
// Types intégraux avec sortie flottante (pour des opérateurs mathématiques comme atan2, copysign)
// Enregistre : long->float, int->float, short->float, uchar->float, char->float, bool->float
REGISTER_INT2FLOAT_BINARY_OP(my_op);
// Types intégraux avec sortie du même type (pour les opérations bit à bit/logiques)
// Enregistre : long, int, short, uchar, char, bool
REGISTER_INTEGER_BINARY_OP(my_op);
// À virgule flottante avec précision opmath (pour les opérations nécessitant une plus haute précision)
REGISTER_OPMATH_FLOAT_BINARY_OP(my_op);
Patterns courants :
- Fonctions mathématiques (atan2, copysign, logaddexp) : Utilisez à la fois
REGISTER_FLOAT_BINARY_OPetREGISTER_INT2FLOAT_BINARY_OP - Opérateurs de comparaison/logiques (maximum, minimum) : Utilisez à la fois
REGISTER_FLOAT_BINARY_OPetREGISTER_INTEGER_BINARY_OP - Opérations arithmétiques (add, sub, mul) : Utilisez à la fois
REGISTER_FLOAT_BINARY_OPetREGISTER_INTEGER_BINARY_OP
Exemple pour atan2 (supporte à la fois les entrées flottantes et entières) :
struct atan2_functor {
template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
inline T operator()(const T a, const T b) {
return static_cast<T>(precise::atan2(float(a), float(b)));
}
template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
inline float operator()(const T a, const T b) {
return precise::atan2(float(a), float(b));
}
};
REGISTER_FLOAT_BINARY_OP(atan2);
REGISTER_INT2FLOAT_BINARY_OP(atan2);
Avec paramètre scalaire
struct my_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
return a + c10::metal::mul(alpha, b);
}
};
REGISTER_UNARY_ALPHA_OP(my_alpha, float, float, float);
REGISTER_UNARY_ALPHA_OP(my_alpha, half, half, half);
Functor spécialisé par type
struct special_functor {
// Types à virgule flottante
template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
inline T operator()(const T x) {
return precise::exp(x); // Utiliser les mathématiques précises
}
// Types intégraux
template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
inline float operator()(const T x) {
return precise::exp(float(x));
}
// Types complexes (float2 pour cfloat, half2 pour chalf)
template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
inline T operator()(const T x) {
// x.x = réel, x.y = imaginaire
return T(/* réel */, /* imag */);
}
};
Remarque sur les types complexes : Les nombres complexes en Metal sont représentés comme des types vecteur :
c10::complex<float>correspond àfloat2(x = réel, y = imaginaire)c10::complex<half>correspond àhalf2
Utilisez is_complex_v<T> pour se spécialiser sur les types complexes dans les functors.
Utilitaires c10/metal disponibles
utils.h :
opmath_t<T>- Type mathématique opérationnel (half->float)accum_t<T>- Type d'accumulation pour les réductionsmax(),min()avec propagation NaN
special_math.h :
precise::exp(),precise::log(),precise::sqrt()precise::sin(),precise::cos(),precise::tan()erf(),erfc(),erfinv()
indexing.h :
REGISTER_UNARY_OP(name, in_type, out_type)REGISTER_BINARY_OP(name, in_type, out_type)REGISTER_UNARY_ALPHA_OP(name, in_type, alpha_type, out_type)
Étape 3 : Implémenter le stub côté host
Emplacement : aten/src/ATen/native/mps/operations/
Choisissez ou créez un fichier approprié selon le type d'opération :
UnaryKernel.mm- Opérations à entrée unique via dispatch de stubBinaryKernel.mm- Opérations à deux entrées via dispatch de stubUnaryOps.mm/BinaryOps.mm- Implémentations MPSGraph héritées (à titre de référence)ReduceOps.mm- Réductions (sum, mean, max, etc.)- Créer un nouveau fichier pour des catégories d'opérations distinctes
Pattern d'enregistrement de stub (Préféré pour Metal natif)
Pour les kernels structurés qui utilisent le pattern TensorIterator :
// Dans BinaryKernel.mm (ou fichier approprié)
static void my_op_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "my_op"); // "my_op" correspond au nom du functor dans .metal
}
// Enregistrer le stub MPS - ceci se connecte au système de dispatch
REGISTER_DISPATCH(my_op_stub, &my_op_mps_kernel)
Pour les opérations unaires :
static void my_unary_mps_kernel(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "my_unary");
}
REGISTER_DISPATCH(my_unary_stub, &my_unary_mps_kernel)
Migration : Supprimer l'ancienne implémentation MPSGraph
Lors de la migration depuis MPSGraph, supprimez également l'ancienne implémentation :
-
Supprimer de BinaryOps.mm (ou UnaryOps.mm) :
- Supprimer l'implémentation
TORCH_IMPL_FUNC(my_op_out_mps) - Supprimer l'en-tête correspondant
#include <ATen/ops/my_op_native.h>
- Supprimer l'implémentation
-
Ajouter à BinaryKernel.mm (ou UnaryKernel.mm) :
- Ajouter la fonction de kernel statique
- Ajouter l'appel
REGISTER_DISPATCH
Étape 4 : Compiler
Après avoir effectué les modifications, compilez pour vérifier que tout se construit correctement :
cd build && ninja torch_cpu
Tests
Le support basique des opérateurs est déjà testé par test_output_match dans test/test_mps.py. Après avoir implémenté un opérateur, activez les tests en supprimant les défaillances attendues :
1. Supprimer de common_mps.py
Emplacement : torch/testing/_internal/common_mps.py
Trouvez et supprimez l'opérateur des listes de skip/xfail :
# Supprimer les entrées comme :
MPS_XFAILLIST = {
"my_op": ..., # Supprimer cette ligne
}
MPS_SKIPLIST = {
"my_op": ..., # Supprimer cette ligne
}
2. Supprimer des décorateurs OpInfo
Emplacement : torch/testing/_internal/common_methods_invocations.py (ou fichiers associés)
Supprimez les décorateurs spécifiques à MPS de l'OpInfo :
OpInfo(
"my_op",
# Supprimer les décorateurs comme :
# decorators=[skipMPS, expectedFailureMPS("reason")],
...
)
3. Exécuter les tests pour vérifier
# Exécuter le test de l'opérateur spécifique
python test/test_mps.py -k test_output_match_my_op
# Ou exécuter la suite de tests MPS complète
python test/test_mps.py
Déboguer les kernels Metal avec torch.mps.compile_shader
Utilisez torch.mps.compile_shader pour compiler en JIT et tester les kernels Metal individuels isolément. C'est très utile pour déboguer les pipelines multi-kernels où vous devez vérifier chaque étape indépendamment.
Utilisation basique
import torch
source = '''
#include <metal_stdlib>
using namespace metal;
kernel void my_kernel(
const device float* input [[buffer(0)]],
device float* output [[buffer(1)]],
uint tid [[thread_position_in_grid]]) {
output[tid] = input[tid] * 2.0;
}
'''
lib = torch.mps.compile_shader(source)
inp = torch.tensor([1.0, 2.0, 3.0], device='mps')
out = torch.zeros(3, device='mps')
lib.my_kernel(inp, out, threads=[3, 1, 1], group_size=[3, 1, 1])
torch.mps.synchronize()
print(out) # tensor([2., 4., 6.], device='mps:0')
Sémantique de dispatch
compile_shader utilise la sémantique dispatchThreads (identique à mtl_dispatch1DJob dans PyTorch) :
threads=[N, 1, 1]— nombre total de threads (PAS de threadgroups)group_size=[G, 1, 1]— threads par threadgroup
Ceci diffère de l'API dispatchThreadgroups utilisée par certains codes côté host. Pour correspondre à dispatchThreadgroups:MTLSizeMake(num_tgs, num_slices, 1) threadsPerThreadgroup:MTLSizeMake(TG_SIZE, 1, 1) :
# Appel compile_shader équivalent :
lib.kernel(args...,
threads=[num_tgs * TG_SIZE, num_slices, 1],
group_size=[TG_SIZE, 1, 1])
Paramètres de buffer constant
Passez les constantes scalaires sous forme de tenseurs à un seul élément :
slice_size = torch.tensor([1024], dtype=torch.int32, device='mps')
lib.my_kernel(data, output, slice_size, threads=[1024, 1, 1], group_size=[256, 1, 1])
Stratégie de débogage pour les pipelines multi-kernels
Lorsqu'un pipeline de kernels (ex. histogram → prefix_sum → scatter) produit de mauvais résultats, testez chaque kernel individuellement et vérifiez sa sortie par rapport à une référence Python/NumPy :
# 1. Exécuter le kernel GPU
lib.histogram(keys, hist, ..., threads=[N, 1, 1], group_size=[256, 1, 1])
torch.mps.synchronize()
# 2. Calculer la référence en Python
ref_hist = compute_histogram_cpu(keys.cpu().numpy(), ...)
# 3. Comparer
assert np.array_equal(hist.cpu().numpy(), ref_hist), "Histogram mismatch!"
Cela isole quel kernel du pipeline est cassé, plutôt que de déboguer l'ensemble du pipeline à la fois.
Pièges courants
- Mauvais nombre de
threads—threadsest le nombre total de threads, pas de threadgroups. Pour 5 threadgroups de 256, utilisezthreads=[1280, 1, 1]. - Mémoire de threadgroup —
compile_shaderne supporte pas directement les paramètres[[threadgroup(N)]]. Si votre kernel a besoin de mémoire de threadgroup, restructurez pour utiliser des tableauxthreadgroupdéclarés à l'intérieur du corps du kernel.
Liste de vérification
- [ ] Dispatch MPS ajoutée à
native_functions.yaml - [ ] Kernel Metal implémenté dans
kernels/ - [ ] Opérateur côté host implémenté dans
operations/ - [ ] Gère les tenseurs vides
- [ ] Gère les tenseurs non contigus
- [ ] Supporte les types d'données requis (float32, float16, bfloat16, et souvent des types complexes via float2/half2)
- [ ] Défaillances attendues supprimées de
torch/testing/_internal/common_mps.py - [ ] Décorateurs skip/xfail supprimés de l'OpInfo (le cas échéant)