std::extents: reducing indirection


Let’s look at std::extents from <mdspan>. In essence, std::extents is a variation of

std::array<int, n> exts{1, ..., n};

that handles compile-time constant extents and is used to store the shape (or extents) of a multi-dimensional array. For example:

std::extents<int, 5, 7, dyn> exts(11);

defines an extents object [5, 7, 11]. The first two extents are known at compile time and are called static extents; the third one is only known at runtime and is called a dynamic extent. Unlike std::array, std::extents only stores the dynamic extents. The number of extents is called the rank of the extents.
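To make “only stores the dynamic extents” concrete, here is a minimal toy sketch. It is not the libstdc++ or standard implementation: `toy_extents` and `dyn` are illustrative names. The static extents live in the type, and the only data member is an array holding the dynamic extents.

```cpp
#include <array>
#include <cstddef>

// Illustrative stand-in for std::dynamic_extent.
inline constexpr std::size_t dyn = static_cast<std::size_t>(-1);

// Hypothetical toy type: the static extents are template arguments,
// only the dynamic extents are stored.
template <typename IndexType, std::size_t... Exts>
class toy_extents {
  static constexpr std::array<std::size_t, sizeof...(Exts)> S{Exts...};
  static constexpr std::size_t n_dyn = ((Exts == dyn) + ... + 0);
  std::array<IndexType, n_dyn> D{};  // the only non-static data member

public:
  template <typename... Ds>
  constexpr explicit toy_extents(Ds... ds) : D{static_cast<IndexType>(ds)...} {}

  static constexpr std::size_t rank() { return sizeof...(Exts); }
  static constexpr std::size_t rank_dynamic() { return n_dyn; }

  constexpr IndexType extent(std::size_t r) const {
    if (S[r] != dyn) return static_cast<IndexType>(S[r]);
    std::size_t k = 0;  // number of dynamic extents before r
    for (std::size_t i = 0; i < r; ++i) k += (S[i] == dyn);
    return D[k];
  }
};
```

With this sketch, `toy_extents<int, 5, 7, dyn> exts(11);` has extents [5, 7, 11] and occupies the space of a single int.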

We need to fix some notation:

  • S[i] is the i-th static extent (if the i-th extent is dynamic then S[i] == -1),
  • E[i] is the i-th extent (static or dynamic),
  • D[k] is the array of dynamic extents,
  • and k[i] is the index of the i-th extent in D (which is meaningful only if the i-th extent is dynamic).
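For the running example std::extents<int, 5, 7, dyn> exts(11), the notation can be written out by hand. The plain constexpr arrays below only restate the definitions; the names S, D, k and E mirror the notation and are not library code.

```cpp
#include <array>
#include <cstddef>

// Notation for std::extents<int, 5, 7, dyn> exts(11), written out by hand.
constexpr long dyn = -1;                          // marks a dynamic extent in S
constexpr std::array<long, 3> S{5, 7, dyn};       // static extents
constexpr std::array<long, 1> D{11};              // dynamic extents (runtime values)
constexpr std::array<std::size_t, 3> k{0, 0, 0};  // only k[2] is meaningful here

// E[i] expressed through S, D and k.
constexpr long E(std::size_t i) {
  return S[i] == dyn ? D[k[i]] : S[i];
}
```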

Let’s look at extents::extent(size_t r) from <mdspan>; in our notation, exts.extent(i) == E[i]. Its original implementation in libstdc++ was:

template<typename _IndexType, array _Extents>
class _ExtentsStorage
{
  constexpr _IndexType
  _M_extent(size_t __r) const noexcept
  {
    auto __se = _Extents[__r];
    if (__se == dynamic_extent)
      return _M_dyn_exts[_S_dynamic_index[__r]];
    else
      return __se;
  }
  // ...
};

Here, _Extents is simply the array of static extents passed to std::extents; e.g. for std::extents<int, 1, dyn, 3, 4>, _Extents == std::array{1, -1, 3, 4}, where -1 stands for size_t(-1). _M_dyn_exts is the array D and _S_dynamic_index is k in our notation. Since _S_dynamic_index only depends on _Extents, it can be computed at compile time. (Right, dyn is short for std::dynamic_extent.)
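A sketch of how such an index table can be computed at compile time: k[i] is the number of dynamic extents strictly before position i, which for a dynamic extent is exactly its position in D. The helper name `dynamic_index` and the `dyn` constant are assumptions for illustration; libstdc++’s actual code may differ.

```cpp
#include <array>
#include <cstddef>

// Stand-in for std::dynamic_extent.
inline constexpr std::size_t dyn = static_cast<std::size_t>(-1);

// Hypothetical helper: k[i] = number of dynamic extents before position i.
template <std::size_t N>
constexpr std::array<std::size_t, N>
dynamic_index(const std::array<std::size_t, N>& S) {
  std::array<std::size_t, N> k{};
  std::size_t count = 0;
  for (std::size_t i = 0; i < N; ++i) {
    k[i] = count;
    if (S[i] == dyn) ++count;
  }
  return k;
}

// Evaluated entirely at compile time, like _S_dynamic_index.
constexpr std::array<std::size_t, 4> S{1, dyn, 3, 4};
constexpr auto k = dynamic_index(S);
static_assert(k[1] == 0);  // the dyn at position 1 is D[0]
```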

The performance question is:

If __r is known at compile time, does the compiler eliminate the branching¹ and indirection?²

Only one way to find out: try it. For example, by multiplying a static and a dynamic extent:

int prod(const std::extents<int, 3, dyn>& exts)
{ return exts.extent(0) * exts.extent(1); }

Compile it with -O2, and disassemble:

0000000000000000 <prod>:
   0:  mov    eax,DWORD PTR [rdi]
   2:  lea    eax,[rax+rax*2]
   5:  ret

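(mov copies from the right operand to the left one; lea computes an address expression such as rax + rax*2, so it acts, roughly, as a fused multiply-add for integers. eax is the lower 32 bits of rax, and likewise edx of rdx. rdi holds the first pointer argument of prod, i.e. the address of exts. Since exts only stores D, [rdi] is D[0].)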

It seems d + 2*d with d = D[0] is a fast way of computing 3*D[0]. So, the answer is: yes. Or “sometimes”, if you want to be more cautious. Note that not only did it eliminate the branching, it also eliminated the indirection D[k[i]].

Let’s try again:

int prod2(const std::extents<int, 3, dyn>& exts,
          const std::array<int, 2>& a)
{ return exts.extent(0) * a[0] + exts.extent(1) * a[1]; }

which results in:

0000000000000010 <prod2>:
  10:  mov    eax,DWORD PTR [rsi+0x4]
  13:  mov    edx,DWORD PTR [rsi]
  15:  imul   eax,DWORD PTR [rdi]
  18:  lea    edx,[rdx+rdx*2]
  1b:  add    eax,edx
  1d:  ret

(imul and add are integer multiplication and addition; they write the result back to the left operand. rsi holds the second pointer argument, i.e. the address of a.)

Again, it eliminated the branching and the indirection. So that seems to work, but what about a similar but harder question:

If S[i] != -1 happens to be true for all i, does the compiler eliminate the branching?

Let’s adjust the test problem a little:

int prod3(const std::extents<int, 3, 5, 7>& exts,
          const std::array<int, 3>& a)
{
  int ret = 0;
  for(size_t i = 0; i < exts.rank(); ++i)
    ret += exts.extent(i) * a[i];
  return ret;
}

The loop’s trip count is easily known at compile time. It’s less easy to see at compile time that S[i] != dyn (hidden inside exts.extent(i)) is always true. Here, the compiler flags matter; at -O2 the generated code is:

0000000000000020 <prod3>:
  20:  xor    eax,eax
  22:  xor    ecx,ecx
  24:  mov    rdx,QWORD PTR [rax*8+0x0]
  2c:  cmp    rdx,0xffffffffffffffff
  30:  je     36 <prod3+0x16>
  36:  imul   edx,DWORD PTR [rsi+rax*4]
  3a:  add    rax,0x1
  3e:  add    ecx,edx
  40:  cmp    rax,0x3
  44:  jne    24 <prod3+0x4>
  46:  mov    eax,ecx
  48:  ret

(cmp subtracts its operands and sets flags accordingly; in particular, it sets the zero flag if they are equal. je checks that flag and jumps to the indicated address if it is set; otherwise it “falls through” and continues with the next instruction.)

Notice the cmp and je at offsets 2c and 30. Clearly, this is the check S[i] == -1 (the mov at offset 24 loads S[i] from the static array; the 0x0 is a placeholder that the linker fills in). What’s interesting is that there’s no code to handle the case when S[i] == -1 is true (because it never is). If S[i] == -1 were true, it would jump to offset 36; otherwise it falls through and also ends up at offset 36. Therefore, it has no way of handling the two cases differently. However, it hasn’t eliminated either the check or the jump. The picture changes when passing -O3:

0000000000000020 <prod3>:
  20:  mov    eax,DWORD PTR [rsi]
  22:  mov    edx,DWORD PTR [rsi+0x4]
  25:  lea    eax,[rax+rax*2]
  28:  lea    edx,[rdx+rdx*4]
  2b:  add    eax,edx
  2d:  mov    edx,DWORD PTR [rsi+0x8]
  30:  lea    eax,[rax+rdx*8]
  33:  sub    eax,edx
  35:  ret

The generated code makes sense: no comparison or jump and no loading of static extents. It simply computes 3*a[0] + 5*a[1] + 7*a[2] as follows:

a0 + 2*a0 + 4*a1 + a1 + 8*a2 - a2

with a0 = a[0], a1 = a[1] and a2 = a[2] (the movs). Naturally, there might be an even faster sequence of instructions, but this doesn’t contain any superfluous instructions.
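The arithmetic behind the -O3 listing can be checked by transcribing each instruction into C++ and comparing against the direct formula. The function names are made up, and unsigned 32-bit arithmetic is used so that overflow wraps the way it does in the registers (with int, overflow would be undefined behaviour in C++).

```cpp
#include <cstdint>

// The formula prod3 computes for extents [3, 5, 7].
constexpr std::uint32_t prod3_direct(std::uint32_t a0, std::uint32_t a1, std::uint32_t a2) {
  return 3 * a0 + 5 * a1 + 7 * a2;
}

// Instruction-by-instruction transcription of the -O3 code.
constexpr std::uint32_t prod3_lea(std::uint32_t a0, std::uint32_t a1, std::uint32_t a2) {
  std::uint32_t eax = a0 + a0 * 2;  // lea eax,[rax+rax*2]  -> 3*a0
  std::uint32_t edx = a1 + a1 * 4;  // lea edx,[rdx+rdx*4]  -> 5*a1
  eax += edx;                       // add eax,edx
  eax = eax + a2 * 8;               // lea eax,[rax+rdx*8]  -> + 8*a2
  eax -= a2;                        // sub eax,edx          -> - a2
  return eax;
}

static_assert(prod3_lea(1, 2, 3) == prod3_direct(1, 2, 3));
```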

Let’s see if we can make it easier on the optimizer and get the same behaviour on -O2:

template<typename _IndexType, array _Extents>
class _ExtentsStorage
{
  static constexpr bool
  _S_is_dynamic(size_t __r) noexcept
  {
    if constexpr (__all_static<_Extents>())
      return false;
    else
      return _Extents[__r] == dynamic_extent;
  }

  constexpr _IndexType
  _M_extent(size_t __r) const noexcept
  {
    if (_S_is_dynamic(__r))
      return _M_dyn_exts[_S_dynamic_index(__r)];
    else
      return _S_static_extent(__r);
  }
  // ...
};
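The post doesn’t show __all_static, _S_static_extent or the function form of _S_dynamic_index. Below is a self-contained sketch of the same idea under assumed names (`toy_storage`, `all_static`, `dyn`). It is only meant to show how `if constexpr` hands the optimizer a condition that is the constant `false` whenever every extent is static.

```cpp
#include <array>
#include <cstddef>

// Stand-in for std::dynamic_extent.
inline constexpr std::size_t dyn = static_cast<std::size_t>(-1);

// Hypothetical helper: true if no entry of S is dynamic.
template <std::size_t N>
constexpr bool all_static(const std::array<std::size_t, N>& S) {
  for (std::size_t s : S)
    if (s == dyn) return false;
  return true;
}

// Hypothetical storage class: Exts are the static extents, D the dynamic ones.
template <typename IndexType, std::size_t... Exts>
class toy_storage {
  static constexpr std::array<std::size_t, sizeof...(Exts)> S{Exts...};
  static constexpr std::size_t n_dyn = ((Exts == dyn) + ... + 0);
  std::array<IndexType, n_dyn> D{};

  static constexpr bool is_dynamic([[maybe_unused]] std::size_t r) {
    if constexpr (all_static(S))
      return false;  // a constant: the caller's dynamic branch is dead code
    else
      return S[r] == dyn;
  }

  // k[r]: number of dynamic extents before r.
  static constexpr std::size_t dynamic_index(std::size_t r) {
    std::size_t k = 0;
    for (std::size_t i = 0; i < r; ++i) k += (S[i] == dyn);
    return k;
  }

public:
  template <typename... Ds>
  constexpr explicit toy_storage(Ds... ds) : D{static_cast<IndexType>(ds)...} {}

  constexpr IndexType extent(std::size_t r) const {
    if (is_dynamic(r))
      return D[dynamic_index(r)];
    else
      return static_cast<IndexType>(S[r]);
  }
};
```

The design choice mirrors the libstdc++ change: the general code path stays intact, and the all-static case is short-circuited by a compile-time constant rather than left for the optimizer to discover.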

The point is to make it obvious to the optimizer that the condition is always false when all extents are static. Let’s recompile with -O2:

0000000000000020 <prod3>:
  20:  mov    eax,DWORD PTR [rsi]
  22:  mov    edx,DWORD PTR [rsi+0x4]
  25:  lea    eax,[rax+rax*2]
  28:  lea    edx,[rdx+rdx*4]
  2b:  add    eax,edx
  2d:  mov    edx,DWORD PTR [rsi+0x8]
  30:  lea    eax,[rax+rdx*8]
  33:  sub    eax,edx
  35:  ret

Yay! No more pointless comparisons and jumps. Okay, one last time. Let’s look at the following:

int prod4(const std::extents<int, 3, 5, 7, 11>& exts,
          const std::array<int, 4>& a)
{
  int ret = 0;
  for(size_t i = 0; i < exts.rank(); ++i)
    ret += exts.extent(i) * a[i];
  return ret;
}

It’s different from before in that the array is exactly four ints long, which happens to be 128 bits, the width of an SSE register. Let’s also compile this example with -O2. First, compiling against the version without the optimization and disassembling, we see essentially the same code as before:

0000000000000050 <prod4>:
  50:  xor    eax,eax
  52:  xor    ecx,ecx
  54:  mov    rdx,QWORD PTR [rax*8+0x0]
  5c:  cmp    rdx,0xffffffffffffffff
  60:  je     66 <prod4+0x16>
  66:  imul   edx,DWORD PTR [rsi+rax*4]
  6a:  add    rax,0x1
  6e:  add    ecx,edx
  70:  cmp    rax,0x4
  74:  jne    54 <prod4+0x4>
  76:  mov    eax,ecx
  78:  ret

Let’s compile (-O2) against the optimized version and disassemble:

0000000000000040 <prod4>:
  40:  movdqu xmm1,XMMWORD PTR [rsi]
  44:  movdqa xmm2,XMMWORD PTR [rip+0x0]
  4c:  movdqa xmm0,xmm1
  50:  psrlq  xmm1,0x20
  55:  pmuludq xmm0,xmm2
  59:  psrlq  xmm2,0x20
  5e:  pmuludq xmm1,xmm2
  62:  pshufd xmm0,xmm0,0x8
  67:  pshufd xmm1,xmm1,0x8
  6c:  punpckldq xmm0,xmm1
  70:  movdqa xmm1,xmm0
  74:  psrldq xmm1,0x8
  79:  paddd  xmm0,xmm1
  7d:  movdqa xmm1,xmm0
  81:  psrldq xmm1,0x4
  86:  paddd  xmm0,xmm1
  8a:  movd   eax,xmm0
  8e:  ret

In other words, the change unlocks SIMD vectorization at -O2. Overall, it’s impressive to see how well the compiler eliminates all the convoluted code we write and even figures out rather non-trivial properties like S[i] != -1 for all i.
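To read the vectorized listing, it can help to write the same sequence with SSE2 intrinsics. The sketch below is a hand transcription, not compiler output. The function names are hypothetical; it assumes an x86 target with SSE2 and falls back to scalar code otherwise. It also uses uint32_t instead of int so that wrap-around is well defined. SSE2 has no 32-bit lane-wise multiply, so the even lanes and the odd lanes are multiplied separately with pmuludq (`_mm_mul_epu32`), and the low halves of the 64-bit products are interleaved back together.

```cpp
#include <array>
#include <cstdint>
#if defined(__SSE2__)
#include <emmintrin.h>
#endif

// Scalar reference: 3*a[0] + 5*a[1] + 7*a[2] + 11*a[3], wrapping modulo 2^32.
inline std::uint32_t prod4_scalar(const std::array<std::uint32_t, 4>& a) {
  return 3 * a[0] + 5 * a[1] + 7 * a[2] + 11 * a[3];
}

#if defined(__SSE2__)
// Hand transcription of the -O2 SIMD listing for prod4.
inline std::uint32_t prod4_sse2(const std::array<std::uint32_t, 4>& a) {
  __m128i x = _mm_loadu_si128(reinterpret_cast<const __m128i*>(a.data()));  // movdqu
  __m128i s = _mm_setr_epi32(3, 5, 7, 11);  // the constant loaded by movdqa
  __m128i even = _mm_mul_epu32(x, s);       // pmuludq: lanes 0 and 2
  __m128i odd = _mm_mul_epu32(_mm_srli_epi64(x, 32),   // psrlq: lanes 1, 3 moved down
                              _mm_srli_epi64(s, 32));  // pmuludq: lanes 1 and 3
  even = _mm_shuffle_epi32(even, 0x8);  // pshufd: low halves -> [3*a0, 7*a2, ...]
  odd = _mm_shuffle_epi32(odd, 0x8);    // pshufd: low halves -> [5*a1, 11*a3, ...]
  __m128i p = _mm_unpacklo_epi32(even, odd);   // punpckldq: [3a0, 5a1, 7a2, 11a3]
  p = _mm_add_epi32(p, _mm_srli_si128(p, 8));  // psrldq + paddd: add upper half
  p = _mm_add_epi32(p, _mm_srli_si128(p, 4));  // psrldq + paddd: add neighbour
  return static_cast<std::uint32_t>(_mm_cvtsi128_si32(p));  // movd
}
#else
inline std::uint32_t prod4_sse2(const std::array<std::uint32_t, 4>& a) {
  return prod4_scalar(a);  // no SSE2 available: use the scalar version
}
#endif
```

The low 32 bits of a product are the same for signed and unsigned multiplication, which is why the unsigned pmuludq is a valid building block for the int code in prod4.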

(The code is compiled with GCC on Linux x86_64.)


  1. Branching refers to the two branches of the if-condition. The cost of a single branch is measured in cycles, i.e. it is extremely low. However, in tight loops, due to how modern CPUs work, branches are distinctly not cheap. At the compiler level, a branch might prevent SIMD vectorization: a SIMD instruction applies the same operation to every lane, so the lanes can’t each follow a different branch unless the compiler manages to turn the branch into a select. At the CPU level, a branch means the CPU can’t know several cycles in advance what the next instruction will be. CPUs rely heavily on pipelining, i.e. whenever possible they start the next instruction before the current one is done; but if the CPU doesn’t know what the next instruction is, it can’t schedule it. Therefore, modern CPUs have branch prediction, which guesses which branch will be taken and then speculatively executes that branch. This brings the cost of branching back down to a few cycles. However, if the guess is wrong, the cost is very high, because the CPU needs to “stop and forget” any speculatively computed values. As a result, branches in small functions that might appear in a tight loop make programmers uneasy.

  2. Why is the question interesting? If all extents are static, then even the most naive custom implementation would use no storage for the dynamic extents and would have no if-condition. Similarly, if all extents are dynamic, there’s no need for the if-condition or the indirection k[i]; one can simply return D[i]. The concern is that because std::extents supports a mix of dynamic and static extents, one has to pay a cost at runtime, even in trivial cases that don’t need the additional complexity.
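For comparison, the two “naive” special cases from the second footnote could look roughly like this (hypothetical names, not library code). The all-static type has no data members and no branch. The all-dynamic type returns D[r] directly, without k or an if-condition.

```cpp
#include <array>
#include <cstddef>

// All extents static: nothing to store, extent(r) is a table lookup.
template <typename IndexType, std::size_t... Exts>
struct all_static_extents {
  static constexpr IndexType extent(std::size_t r) {
    constexpr std::array<IndexType, sizeof...(Exts)> S{static_cast<IndexType>(Exts)...};
    return S[r];
  }
};

// All extents dynamic: store them all; extent(r) is simply D[r].
template <typename IndexType, std::size_t Rank>
struct all_dynamic_extents {
  std::array<IndexType, Rank> D;
  constexpr IndexType extent(std::size_t r) const { return D[r]; }
};
```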