metal-kernel

Par pytorch · pytorch

Rédige des kernels Metal/MPS pour les opérateurs PyTorch. À utiliser lors de l'ajout du support du device MPS aux opérateurs, de l'implémentation de shaders Metal, ou du portage de kernels CUDA vers Apple Silicon. Couvre le dispatch dans `native_functions.yaml`, les opérateurs côté hôte et l'implémentation des kernels Metal.

npx skills add https://github.com/pytorch/pytorch --skill metal-kernel

Guide d'écriture de kernels Metal

Cette skill vous guide dans l'implémentation de kernels Metal pour les opérateurs PyTorch sur Apple Silicon.

Important : L'objectif de cette skill est d'utiliser les capacités natives de Metal via l'infrastructure c10/metal/, PAS MPSGraph. Les kernels Metal natifs offrent un meilleur contrôle, une meilleure performance et une meilleure maintenabilité.

Aperçu

Il y a deux workflows couverts par cette skill :

  1. Ajouter un nouveau support MPS - Implémenter un nouvel opérateur à partir de zéro
  2. Migrer depuis MPSGraph - Convertir les opérateurs existants basés sur MPSGraph en Metal natif

Les deux workflows impliquent :

  1. Mettre à jour la dispatch dans aten/src/ATen/native/native_functions.yaml
  2. Écrire le kernel Metal dans aten/src/ATen/native/mps/kernels/
  3. Implémenter le stub côté host dans aten/src/ATen/native/mps/operations/

Étape 1 : Mettre à jour native_functions.yaml

Emplacement : aten/src/ATen/native/native_functions.yaml

Pour les nouveaux opérateurs

Trouvez l'entrée de l'opérateur et ajoutez la dispatch MPS :

# Implémentation simple spécifique à MPS
- func: my_op(Tensor self) -> Tensor
  dispatch:
    CPU: my_op_cpu
    CUDA: my_op_cuda
    MPS: my_op_mps

# Implémentation partagée entre les appareils (préféré pour les kernels structurés)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA, MPS: my_op_out

# Kernel structuré (préféré pour les nouveaux opérateurs)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: my_op_out

Pour migrer depuis MPSGraph

Lors de la migration d'un opérateur existant de MPSGraph vers Metal natif, consolidez l'entrée de dispatch :

# AVANT (basé sur MPSGraph, dispatch séparé)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: atan2_out
    MPS: atan2_out_mps  # Implémentation MPS séparée

# APRÈS (Metal natif, dispatch partagée via stub)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: atan2_out  # MPS utilise maintenant le même mécanisme de stub

Changement clé : Remplacez MPS: my_op_out_mps en ajoutant MPS à la ligne de dispatch partagée (ex. CPU, CUDA, MPS: my_op_out).

Conventions de nommage pour la dispatch :

  • MPS: function_name_mps - Implémentation spécifique à MPS (ancien pattern MPSGraph)
  • CPU, CUDA, MPS: function_name - Implémentation de stub partagée (pattern Metal natif)

Étape 2 : Implémenter le kernel Metal

Emplacement : aten/src/ATen/native/mps/kernels/

Pattern de kernel unaire

// MyKernel.metal
#include <c10/metal/indexing.h>
#include <c10/metal/utils.h>
#include <metal_stdlib>

using namespace metal;
using namespace c10::metal;

// Définir le functor d'opération
struct my_op_functor {
  template <typename T>
  inline T operator()(const T x) {
    return /* votre opération */;
  }
};

// Enregistrer pour les types supportés
REGISTER_UNARY_OP(my_op, float, float);
REGISTER_UNARY_OP(my_op, half, half);
REGISTER_UNARY_OP(my_op, bfloat, bfloat);

Pattern de kernel binaire

struct my_binary_functor {
  template <typename T>
  inline T operator()(const T a, const T b) {
    return /* votre opération */;
  }
};

REGISTER_BINARY_OP(my_binary, float, float);
REGISTER_BINARY_OP(my_binary, half, half);

Macros d'enregistrement de type pour les kernels binaires

Pour les opérations binaires, utilisez les macros de commodité définies dans BinaryKernel.metal :

// Types à virgule flottante uniquement (float, half, bfloat)
REGISTER_FLOAT_BINARY_OP(my_op);

// Types intégraux avec sortie flottante (pour des opérateurs mathématiques comme atan2, copysign)
// Enregistre : long->float, int->float, short->float, uchar->float, char->float, bool->float
REGISTER_INT2FLOAT_BINARY_OP(my_op);

// Types intégraux avec sortie du même type (pour les opérations bit à bit/logiques)
// Enregistre : long, int, short, uchar, char, bool
REGISTER_INTEGER_BINARY_OP(my_op);

// À virgule flottante avec précision opmath (pour les opérations nécessitant une plus haute précision)
REGISTER_OPMATH_FLOAT_BINARY_OP(my_op);

Patterns courants :

  • Fonctions mathématiques (atan2, copysign, logaddexp) : Utilisez à la fois REGISTER_FLOAT_BINARY_OP et REGISTER_INT2FLOAT_BINARY_OP
  • Opérateurs de comparaison/logiques (maximum, minimum) : Utilisez à la fois REGISTER_FLOAT_BINARY_OP et REGISTER_INTEGER_BINARY_OP
  • Opérations arithmétiques (add, sub, mul) : Utilisez à la fois REGISTER_FLOAT_BINARY_OP et REGISTER_INTEGER_BINARY_OP

Exemple pour atan2 (supporte à la fois les entrées flottantes et entières) :

struct atan2_functor {
  template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
  inline T operator()(const T a, const T b) {
    return static_cast<T>(precise::atan2(float(a), float(b)));
  }
  template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
  inline float operator()(const T a, const T b) {
    return precise::atan2(float(a), float(b));
  }
};

REGISTER_FLOAT_BINARY_OP(atan2);
REGISTER_INT2FLOAT_BINARY_OP(atan2);

Avec paramètre scalaire

struct my_alpha_functor {
  template <typename T>
  inline T operator()(const T a, const T b, const T alpha) {
    return a + c10::metal::mul(alpha, b);
  }
};

REGISTER_UNARY_ALPHA_OP(my_alpha, float, float, float);
REGISTER_UNARY_ALPHA_OP(my_alpha, half, half, half);

Functor spécialisé par type

struct special_functor {
  // Types à virgule flottante
  template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
  inline T operator()(const T x) {
    return precise::exp(x);  // Utiliser les mathématiques précises
  }

  // Types intégraux
  template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
  inline float operator()(const T x) {
    return precise::exp(float(x));
  }

  // Types complexes (float2 pour cfloat, half2 pour chalf)
  template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
  inline T operator()(const T x) {
    // x.x = réel, x.y = imaginaire
    return T(/* réel */, /* imag */);
  }
};

Remarque sur les types complexes : Les nombres complexes en Metal sont représentés comme des types vecteur :

  • c10::complex<float> correspond à float2 (x = réel, y = imaginaire)
  • c10::complex<half> correspond à half2

Utilisez is_complex_v<T> pour se spécialiser sur les types complexes dans les functors.

Utilitaires c10/metal disponibles

utils.h :

  • opmath_t<T> - Type mathématique opérationnel (half->float)
  • accum_t<T> - Type d'accumulation pour les réductions
  • max(), min() avec propagation NaN

special_math.h :

  • precise::exp(), precise::log(), precise::sqrt()
  • precise::sin(), precise::cos(), precise::tan()
  • erf(), erfc(), erfinv()

indexing.h :

  • REGISTER_UNARY_OP(name, in_type, out_type)
  • REGISTER_BINARY_OP(name, in_type, out_type)
  • REGISTER_UNARY_ALPHA_OP(name, in_type, alpha_type, out_type)

Étape 3 : Implémenter le stub côté host

Emplacement : aten/src/ATen/native/mps/operations/

Choisissez ou créez un fichier approprié selon le type d'opération :

  • UnaryKernel.mm - Opérations à entrée unique via dispatch de stub
  • BinaryKernel.mm - Opérations à deux entrées via dispatch de stub
  • UnaryOps.mm / BinaryOps.mm - Implémentations MPSGraph héritées (à titre de référence)
  • ReduceOps.mm - Réductions (sum, mean, max, etc.)
  • Créer un nouveau fichier pour des catégories d'opérations distinctes

Pattern d'enregistrement de stub (Préféré pour Metal natif)

Pour les kernels structurés qui utilisent le pattern TensorIterator :

// Dans BinaryKernel.mm (ou fichier approprié)

static void my_op_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_binary_kernel(iter, "my_op");  // "my_op" correspond au nom du functor dans .metal
}

// Enregistrer le stub MPS - ceci se connecte au système de dispatch
REGISTER_DISPATCH(my_op_stub, &my_op_mps_kernel)

Pour les opérations unaires :

static void my_unary_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_unary_kernel(iter, "my_unary");
}

REGISTER_DISPATCH(my_unary_stub, &my_unary_mps_kernel)

Migration : Supprimer l'ancienne implémentation MPSGraph

Lors de la migration depuis MPSGraph, supprimez également l'ancienne implémentation :

  1. Supprimer de BinaryOps.mm (ou UnaryOps.mm) :

    • Supprimer l'implémentation TORCH_IMPL_FUNC(my_op_out_mps)
    • Supprimer l'en-tête correspondant #include <ATen/ops/my_op_native.h>
  2. Ajouter à BinaryKernel.mm (ou UnaryKernel.mm) :

    • Ajouter la fonction de kernel statique
    • Ajouter l'appel REGISTER_DISPATCH

Étape 4 : Compiler

Après avoir effectué les modifications, compilez pour vérifier que tout se construit correctement :

cd build && ninja torch_cpu

Tests

Le support basique des opérateurs est déjà testé par test_output_match dans test/test_mps.py. Après avoir implémenté un opérateur, activez les tests en supprimant les défaillances attendues :

1. Supprimer de common_mps.py

Emplacement : torch/testing/_internal/common_mps.py

Trouvez et supprimez l'opérateur des listes de skip/xfail :

# Supprimer les entrées comme :
MPS_XFAILLIST = {
    "my_op": ...,  # Supprimer cette ligne
}

MPS_SKIPLIST = {
    "my_op": ...,  # Supprimer cette ligne
}

2. Supprimer des décorateurs OpInfo

Emplacement : torch/testing/_internal/common_methods_invocations.py (ou fichiers associés)

Supprimez les décorateurs spécifiques à MPS de l'OpInfo :

OpInfo(
    "my_op",
    # Supprimer les décorateurs comme :
    # decorators=[skipMPS, expectedFailureMPS("reason")],
    ...
)

3. Exécuter les tests pour vérifier

# Exécuter le test de l'opérateur spécifique
python test/test_mps.py -k test_output_match_my_op

# Ou exécuter la suite de tests MPS complète
python test/test_mps.py

Déboguer les kernels Metal avec torch.mps.compile_shader

Utilisez torch.mps.compile_shader pour compiler en JIT et tester les kernels Metal individuels isolément. C'est très utile pour déboguer les pipelines multi-kernels où vous devez vérifier chaque étape indépendamment.

Utilisation basique

import torch

source = '''
#include <metal_stdlib>
using namespace metal;

kernel void my_kernel(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    uint tid [[thread_position_in_grid]]) {
  output[tid] = input[tid] * 2.0;
}
'''

lib = torch.mps.compile_shader(source)

inp = torch.tensor([1.0, 2.0, 3.0], device='mps')
out = torch.zeros(3, device='mps')
lib.my_kernel(inp, out, threads=[3, 1, 1], group_size=[3, 1, 1])
torch.mps.synchronize()
print(out)  # tensor([2., 4., 6.], device='mps:0')

Sémantique de dispatch

compile_shader utilise la sémantique dispatchThreads (identique à mtl_dispatch1DJob dans PyTorch) :

  • threads=[N, 1, 1] — nombre total de threads (PAS de threadgroups)
  • group_size=[G, 1, 1] — threads par threadgroup

Ceci diffère de l'API dispatchThreadgroups utilisée par certains codes côté host. Pour correspondre à dispatchThreadgroups:MTLSizeMake(num_tgs, num_slices, 1) threadsPerThreadgroup:MTLSizeMake(TG_SIZE, 1, 1) :

# Appel compile_shader équivalent :
lib.kernel(args...,
    threads=[num_tgs * TG_SIZE, num_slices, 1],
    group_size=[TG_SIZE, 1, 1])

Paramètres de buffer constant

Passez les constantes scalaires sous forme de tenseurs à un seul élément :

slice_size = torch.tensor([1024], dtype=torch.int32, device='mps')
lib.my_kernel(data, output, slice_size, threads=[1024, 1, 1], group_size=[256, 1, 1])

Stratégie de débogage pour les pipelines multi-kernels

Lorsqu'un pipeline de kernels (ex. histogram → prefix_sum → scatter) produit de mauvais résultats, testez chaque kernel individuellement et vérifiez sa sortie par rapport à une référence Python/NumPy :

# 1. Exécuter le kernel GPU
lib.histogram(keys, hist, ..., threads=[N, 1, 1], group_size=[256, 1, 1])
torch.mps.synchronize()

# 2. Calculer la référence en Python
ref_hist = compute_histogram_cpu(keys.cpu().numpy(), ...)

# 3. Comparer
assert np.array_equal(hist.cpu().numpy(), ref_hist), "Histogram mismatch!"

Cela isole quel kernel du pipeline est cassé, plutôt que de déboguer l'ensemble du pipeline à la fois.

Pièges courants

  • Mauvais nombre de threadsthreads est le nombre total de threads, pas de threadgroups. Pour 5 threadgroups de 256, utilisez threads=[1280, 1, 1].
  • Mémoire de threadgroupcompile_shader ne supporte pas directement les paramètres [[threadgroup(N)]]. Si votre kernel a besoin de mémoire de threadgroup, restructurez pour utiliser des tableaux threadgroup déclarés à l'intérieur du corps du kernel.

Liste de vérification

  • [ ] Dispatch MPS ajoutée à native_functions.yaml
  • [ ] Kernel Metal implémenté dans kernels/
  • [ ] Opérateur côté host implémenté dans operations/
  • [ ] Gère les tenseurs vides
  • [ ] Gère les tenseurs non contigus
  • [ ] Supporte les types d'données requis (float32, float16, bfloat16, et souvent des types complexes via float2/half2)
  • [ ] Défaillances attendues supprimées de torch/testing/_internal/common_mps.py
  • [ ] Décorateurs skip/xfail supprimés de l'OpInfo (le cas échéant)

Skills similaires