TL;DR: Read the section on __CUDA_ARCH__ in the CUDA C++ Programming Guide before using the macro, so that you’re aware of the cases where it’s problematic.


In the last year or so, I’ve become quite interested in low-level GPU programming for performance optimization of neural networks. Diving into CUDA programming has been a fun (and challenging!) way to explore the magic behind what makes neural nets run fast on GPUs.[1] I tend to do a lot of Python programming, so stepping down into the depths of low-level GPU programming has been, well… illuminating. If there’s one thing I’ve noticed, it’s that learning to write high-performance CUDA code is like learning to program in hard mode. In my experience, one does not simply “learn” CUDA; one merely becomes less incompetent over time. The devil tends to be in the details when you write CUDA code, and the details are far too easy to miss if you’re not careful.

In this post, I want to highlight a specific detail of CUDA C/C++ that can easily lead to mistakes: the __CUDA_ARCH__ macro. It’s caused me a fair bit of confusion in the last few days, so this post acts as an overview of my findings on how to use it while avoiding some of its pitfalls.

What is the __CUDA_ARCH__ macro?

The __CUDA_ARCH__ macro is used to write code that behaves differently depending on your GPU architecture. NVIDIA tends to release a new GPU architecture every two years or so, and each one may provide functionality that previous generations did not, so it’s quite useful to be able to write architecture-specific code.

Inside device code (code that runs on the GPU), the __CUDA_ARCH__ macro expands to a number that identifies the compute capability you’re compiling for: compute capability X.Y becomes XY0. For example, when compiling for an NVIDIA A100, which has compute capability 8.0 (sm80), __CUDA_ARCH__ expands to 800. For a GeForce RTX 4090, which has compute capability 8.9 (sm89), __CUDA_ARCH__ expands to 890. We can use the value of __CUDA_ARCH__ to conditionally include code that only works on specific GPU architectures.
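To make the numbering concrete, here is a minimal sketch. The helper cuda_arch_value is hypothetical (it is not part of CUDA); it only mirrors how a compute capability maps to the value of __CUDA_ARCH__.

```cpp
// Hypothetical helper, not part of the CUDA toolkit: mirrors the value that
// __CUDA_ARCH__ takes when device code is compiled for compute capability
// major.minor.
constexpr int cuda_arch_value(int major, int minor) {
    return major * 100 + minor * 10;
}

// Checked at compile time, so a wrong mapping would fail the build.
static_assert(cuda_arch_value(8, 0) == 800, "A100, sm80");
static_assert(cuda_arch_value(8, 9) == 890, "GeForce RTX 4090, sm89");
static_assert(cuda_arch_value(7, 5) == 750, "sm75");
```

Since the value is a plain integer, comparisons such as __CUDA_ARCH__ >= 800 work directly in preprocessor conditions.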

As a real-world example, here’s a snippet from Flash Attention that sets certain traits depending on whether the GPU architecture is at least Ampere (sm80):

template<...>
struct Flash_kernel_traits {

#if defined(__CUDA_ARCH__) &&  __CUDA_ARCH__ >= 800
    using Element = elem_type;
    static constexpr bool Has_cp_async = true;
#else
    using Element = cutlass::half_t;
    static constexpr bool Has_cp_async = false;
#endif

};

Note the use of Has_cp_async: the cp.async instruction (an asynchronous copy from global to shared memory) requires sm80 or higher, so the Has_cp_async boolean lets the rest of the code select the correct instructions at compile time.
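To illustrate how a compile-time boolean like this can steer code selection, here is a small sketch. It is not Flash Attention’s actual code: KernelTraitsSketch, AsyncCopyPolicy and DefaultCopyPolicy are hypothetical names, and the architecture is passed as a template parameter (instead of reading __CUDA_ARCH__) purely so that the example can be checked with any C++ compiler.

```cpp
#include <type_traits>

// Hypothetical policy types standing in for "use cp.async" and "use ordinary loads".
struct AsyncCopyPolicy   { static constexpr bool uses_cp_async = true;  };
struct DefaultCopyPolicy { static constexpr bool uses_cp_async = false; };

// Hypothetical traits struct: Arch uses the same numbering as __CUDA_ARCH__.
template <int Arch>
struct KernelTraitsSketch {
    static constexpr bool Has_cp_async = Arch >= 800;
    // The boolean picks a type at compile time; no runtime branch is involved.
    using CopyPolicy = typename std::conditional<
        Has_cp_async, AsyncCopyPolicy, DefaultCopyPolicy>::type;
};

static_assert(std::is_same<KernelTraitsSketch<800>::CopyPolicy, AsyncCopyPolicy>::value,
              "sm80 selects the cp.async path");
static_assert(std::is_same<KernelTraitsSketch<750>::CopyPolicy, DefaultCopyPolicy>::value,
              "sm75 falls back to ordinary copies");
```

Code that consumes the traits only ever refers to CopyPolicy, so the architecture check stays in one place.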

The problem: undefined behavior

Now, what I want to highlight is the cases where using __CUDA_ARCH__ is problematic and needs to be avoided. In particular, there are situations where using it leads to undefined behavior. Luckily, the CUDA C++ Programming Guide has a section that lists the four situations where this happens.[2] These are:[3]

  1. Setting type signatures for __global__ functions and/or __device__ and __constant__ variables based on __CUDA_ARCH__.
  2. Instantiating function templates for __global__ functions based on __CUDA_ARCH__.
  3. In separate compilation, using __CUDA_ARCH__ to conditionally define a function or variable with external linkage.
  4. In separate compilation, using __CUDA_ARCH__ in headers such that different objects could contain different behavior.

For example, the following code block violates rule no. 1:

#if !defined(__CUDA_ARCH__)
typedef int mytype;
#else
typedef double mytype;
#endif

__global__ void foo(mytype in, // problem: foo's type depends on __CUDA_ARCH__
                    mytype *ptr)
{
  *ptr = in;
}
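For contrast, here is a sketch of a variant that stays within rule no. 1, assuming the goal is only to vary behavior by architecture. The type is the same in every compilation pass, and the architecture check moves into a function body. The names SKETCH_HOST_DEVICE and adjust are hypothetical, and the doubling is just a placeholder for architecture-specific work. The __CUDACC__ guards are only there so the snippet also builds without nvcc.

```cpp
// Common idiom: let the same source build with nvcc or with a plain C++ compiler.
#if defined(__CUDACC__)
#define SKETCH_HOST_DEVICE __host__ __device__
#else
#define SKETCH_HOST_DEVICE
#endif

// One type in every compilation pass, so foo's signature never changes.
typedef int mytype;

// Architecture-specific behavior lives inside a function body instead.
SKETCH_HOST_DEVICE inline mytype adjust(mytype in) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return in * 2;  // device pass targeting sm80 or newer (placeholder logic)
#else
    return in;      // host pass, or device pass for an older architecture
#endif
}

#if defined(__CUDACC__)
// The kernel itself needs nvcc; its signature no longer depends on __CUDA_ARCH__.
__global__ void foo(mytype in, mytype *ptr)
{
  *ptr = adjust(in);
}
#endif
```

Called from host code, adjust takes the fallback branch, because __CUDA_ARCH__ is not defined there.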

What exactly happens when you write code that violates any of these four rules? The docs provide the answer:

The compiler does not guarantee that a diagnostic will be generated for the unsupported uses of __CUDA_ARCH__ described above.

In other words, your code can now exhibit arbitrary behavior, including crashing, producing incorrect results, or even behaving as expected by coincidence. What’s worse, the compiler won’t even tell you that something is wrong.

Someone from the GPU Mode Discord channel said it best:

This is probably the scariest kind of thing you can have in C/C++: UB NDR (Undefined Behaviour, No Diagnostic Required)

Scary stuff. If you take away nothing else from the current post, let it be this: read the aforementioned section of the CUDA C++ programming guide before you find yourself using __CUDA_ARCH__ in any kind of non-trivial way.

Case study: torchao FP6 kernel

To give you an impression of how things can go wrong, let’s look at a case study. In recent months, I have been contributing to torchao, a library for PyTorch-native quantization and sparsity for training and inference. Specifically, I have been focusing on torchao’s integration with FP6, a high-performance CUDA kernel for 6-bit quantization. Most of the details aren’t important; I just want to highlight a small part of the code. Here’s a simplified version of what the structure of the FP6 code looked like at some point:

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800

template<int EXPONENT, int MANTISSA>
__global__ void fpx_kernel(const uint4* weights, const half* scales,
                           const half* in_feats, half* out_feats,
                           const size_t M, const size_t N, const size_t K)
{
    // CUDA code here ...
}


torch::Tensor fp6_forward_cuda(
    torch::Tensor   _in_feats,
    torch::Tensor   _weights,
    torch::Tensor   _scales)
{
    // Setup code here ...
    
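    // Call the CUDA kernel. Launch parameters are: grid, block,
    // dynamic shared memory in bytes (0 in this simplified version), stream.
    fpx_kernel<3, 2><<<grid_dim, block_dim, 0, stream>>>(
        weights, scales, in_feats, out_feats, M, N, K);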
    
    return out_feats;
}

#endif

The idea here is that the FP6 kernel only supports sm80 and higher, so we use __CUDA_ARCH__ to only include the FP6 code when __CUDA_ARCH__ >= 800 or when __CUDA_ARCH__ is undefined (which is the case for host code). This avoids compilation problems when compiling for architectures older than sm80.

The code above worked fine on sm80 GPUs and higher. However, when we called the above kernel on an sm75 GPU, we noticed that it behaved inconsistently. Sometimes, the code would error out saying that the kernel wasn’t defined (which was the expected outcome). Other times, however, it would actually run without errors while producing garbage output.[4]

If we refer to the four rules I mentioned earlier from the CUDA C++ Programming Guide, it appears that we are violating rule no. 2. To quote the docs:

If a __global__ function template is instantiated and launched from the host, then the function template must be instantiated with the same template arguments irrespective of whether __CUDA_ARCH__ is defined and regardless of the value of __CUDA_ARCH__.

Note that during compilation, nvcc makes a distinction between device code (code that runs on the GPU) and host code (code that runs on the CPU). Host code is forwarded to a standard C++ compiler (like gcc or cl), similar to any regular C++ program, and __CUDA_ARCH__ is not defined during that pass. The device code, on the other hand, is compiled by NVIDIA’s own device compiler, once for each target architecture, with __CUDA_ARCH__ set accordingly. Afterwards, the device code is embedded into the host object file as a “fat binary”.

In the case of the FP6 code shown above, it would appear that the following is happening when we compile for an sm75 GPU:

  1. For the compilation of host code, the #if directive will evaluate to true (since !defined(__CUDA_ARCH__) is true for host code), which means the fpx_kernel<3, 2> function template is instantiated.
  2. For the compilation of device code, the #if directive will evaluate to false (since __CUDA_ARCH__ < 800), which means the fpx_kernel<3, 2> function template is not instantiated.
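The two cases above can be modeled in a few lines. This is only a model of the preprocessor condition, not something nvcc provides: kernel_is_compiled is a hypothetical name, and the value 0 is an assumption of this sketch that stands for “__CUDA_ARCH__ is not defined”.

```cpp
// Model of the guard `#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800`.
// Assumption of this sketch: cuda_arch == 0 stands for "macro not defined",
// which is the situation in the host pass.
constexpr bool kernel_is_compiled(int cuda_arch) {
    return cuda_arch == 0 || cuda_arch >= 800;
}

static_assert(kernel_is_compiled(0),    "host pass: fpx_kernel is visible");
static_assert(!kernel_is_compiled(750), "sm75 device pass: fpx_kernel is missing");
static_assert(kernel_is_compiled(800),  "sm80 device pass: fpx_kernel is visible");
```

The first two assertions are the mismatch in question: when targeting sm75, the host pass and the device pass disagree on whether the kernel exists.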

This means there is a problematic divergence between the host and device code: the host code includes a __global__ function symbol that the device code does not. What happens as a result is undefined behavior, i.e. anything can happen. This explained the inconsistent behavior of the kernel, which went away after addressing this issue.[5]
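As a rough sketch of one way to address this (not the actual torchao code), the architecture check can move inside the kernel body, so that every compilation pass sees the same kernel symbol and the same template instantiation. A runtime check on the host then refuses to launch on unsupported GPUs, and the launch is followed by an error check. The names fpx_kernel_sketch, launch_fpx_sketch, device_is_supported and kMinArch are hypothetical, and the element-wise copy is a placeholder for the real kernel. The __CUDACC__ guard only exists so the host-side helper also builds without nvcc.

```cpp
#include <cstddef>

// Minimum architecture for the kernel below, in __CUDA_ARCH__ units.
constexpr int kMinArch = 800;

// Host-side helper: does a GPU with this compute capability meet the minimum?
constexpr bool device_is_supported(int cc_major, int cc_minor) {
    return cc_major * 100 + cc_minor * 10 >= kMinArch;
}

#if defined(__CUDACC__)  // everything below needs nvcc
#include <cuda_runtime.h>

// The kernel is declared in every pass; only its body depends on __CUDA_ARCH__.
template <int EXPONENT, int MANTISSA>
__global__ void fpx_kernel_sketch(const float* in, float* out, size_t n)
{
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    size_t i = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
    if (i < n) out[i] = in[i];  // placeholder for the real sm80+ implementation
#endif
    // For older architectures the body is empty, but the symbol still exists.
}

inline cudaError_t launch_fpx_sketch(const float* in, float* out, size_t n,
                                     cudaStream_t stream)
{
    if (n == 0) return cudaSuccess;  // nothing to do, avoid an empty grid

    int device = 0;
    cudaError_t err = cudaGetDevice(&device);
    if (err != cudaSuccess) return err;

    cudaDeviceProp prop;
    err = cudaGetDeviceProperties(&prop, device);
    if (err != cudaSuccess) return err;

    // Refuse to launch the empty kernel on GPUs older than sm80.
    if (!device_is_supported(prop.major, prop.minor)) return cudaErrorNotSupported;

    const unsigned threads = 256;
    const unsigned blocks = static_cast<unsigned>((n + threads - 1) / threads);
    fpx_kernel_sketch<3, 2><<<blocks, threads, 0, stream>>>(in, out, n);

    // Launch errors are not reported by the launch itself, so query them here.
    return cudaGetLastError();
}
#endif
```

With this structure, an sm75 GPU gets a clear error from the host-side check instead of depending on whatever the mismatched binary happens to do.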


  1. Shoutout to the GPU Mode Discord community (formerly known as CUDA Mode), which has some of the best lectures on CUDA I’ve seen, as well as being a great community of CUDA hackers from which I’ve learned a lot. 

  2. Thanks to Thien Tran for pointing this section out to me. 

  3. I’m omitting some of the details here for brevity. The CUDA C++ programming guide contains all the relevant details. 

  4. This actually happened in conjunction with another bug where the program failed silently because it did not check for errors after a kernel call. Always check for errors after calling a CUDA kernel! Here’s a good primer on proper CUDA error checking: https://leimao.github.io/blog/Proper-CUDA-Error-Checking/ 

  5. A possible fix is to use the #if directive inside the fpx_kernel function body instead of outside of it, such that the function symbols are the same for device and host code. For example, the function body can be left empty in the case of an unsupported GPU architecture.