# `std::extents`: reducing indirection
Let’s look at `std::extents` from `<mdspan>`. In essence, `std::extents` is a variation of

```cpp
std::array<int, n> exts{1, ..., n};
```

that handles compile-time constant extents and is used to store the shape (or extents) of a multi-dimensional array. For example,

```cpp
std::extents<int, 5, 7, dyn> exts(11);
```

defines an extents object `[5, 7, 11]`: the first two extents are known at compile time and are called static extents; the third is only known at runtime and is called a dynamic extent. Unlike `std::array`, `std::extents` only stores the dynamic extents. The number of extents is called the rank of the extents.
We need to fix some notation:

- `S[i]` is the i-th static extent (if the i-th extent is dynamic then `S[i] == -1`),
- `E[i]` is the i-th extent (static or dynamic),
- `D` is the array of dynamic extents,
- and `k[i]` is the index of the i-th extent in `D` (which is meaningful only if the i-th extent is dynamic).
Let’s look at `extents::extent(size_t r)` from `<mdspan>`; in our notation, `exts.extent(i) == E[i]`. Its original implementation in libstdc++ was:
```cpp
template<typename _IndexType, array _Extents>
class _ExtentsStorage
{
  constexpr _IndexType
  _M_extent(size_t __r) const noexcept
  {
    auto __se = _Extents[__r];
    if (__se == dynamic_extent)
      return _M_dyn_exts[_S_dynamic_index[__r]];
    else
      return __se;
  }

  // ...
};
```
Here, `_Extents` is simply the array of static extents passed to `std::extents`, e.g. for `std::extents<int, 1, dyn, 3, 4>` we have `_Extents == std::array{1, -1, 3, 4}`; `_M_dyn_exts` is the array `D` and `_S_dynamic_index` is `k` in our notation. Since `_S_dynamic_index` only depends on `_Extents`, it can be computed at compile time. (Right, `dyn` is short for `std::dynamic_extent`.)
The performance question is:

> If `__r` is known at compile time, does the compiler eliminate the branching¹ and the indirection²?
Only one way to find out: try it. For example, by multiplying a dynamic and a static extent:

```cpp
int prod(const std::extents<int, 3, dyn>& exts)
{ return exts.extent(0) * exts.extent(1); }
```
Compile it with `-O2` and disassemble:

```asm
0000000000000000 <prod>:
   0: mov eax,DWORD PTR [rdi]
   2: lea eax,[rax+rax*2]
   5: ret
```
(`mov` copies from the right to the left; `lea` is kind of a fused multiply-add for integers. `eax` and `rax` refer to the same register, same for `?dx`. `rdi` refers to the first pointer argument of the function `prod`, i.e. `exts`.)

It seems `d + 2*d` with `d = D[0]` is the fast way of implementing `3*D[0]`.
So, the answer is: yes. Or “sometimes”, if you want to be more cautious. Note that not only did it eliminate the branching, it also eliminated the indirection `D[k[i]]`.
Let’s try again:

```cpp
int prod2(const std::extents<int, 3, dyn>& exts,
          const std::array<int, 2>& a)
{ return exts.extent(0) * a[0] + exts.extent(1) * a[1]; }
```
which results in:

```asm
0000000000000010 <prod2>:
  10: mov eax,DWORD PTR [rsi+0x4]
  13: mov edx,DWORD PTR [rsi]
  15: imul eax,DWORD PTR [rdi]
  18: lea edx,[rdx+rdx*2]
  1b: add eax,edx
  1d: ret
```
(`imul` and `add` are integer multiplication and addition; they write back to the left operand. `rsi` refers to the second pointer argument, i.e. `a`.)
Again, it eliminated the branching and indirection. So that seems to work, but what about a similar, yet harder question:

> If `S[i] != -1` happens to be true for all `i`, does the compiler eliminate the branching?
Let’s adjust the test problem a little:

```cpp
int prod3(const std::extents<int, 3, 5, 7>& exts,
          const std::array<int, 3>& a)
{
  int ret = 0;
  for (size_t i = 0; i < exts.rank(); ++i)
    ret += exts.extent(i) * a[i];

  return ret;
}
```
The loop has a trip count that’s easily known at compile time. It’s less easy to see at compile time that `S[i] != dyn` (hidden inside `exts.extent(i)`) is always true. Here, the compiler flags matter, but on `-O2` the generated code is:
```asm
0000000000000020 <prod3>:
  20: xor eax,eax
  22: xor ecx,ecx
  24: mov rdx,QWORD PTR [rax*8+0x0]
  2c: cmp rdx,0xffffffffffffffff
  30: je 36 <prod3+0x16>
  36: imul edx,DWORD PTR [rsi+rax*4]
  3a: add rax,0x1
  3e: add ecx,edx
  40: cmp rax,0x3
  44: jne 24 <prod3+0x4>
  46: mov eax,ecx
  48: ret
```
(`cmp` subtracts its two operands and sets the zero flag if the difference is zero, i.e. if they’re equal. `je` checks that flag and jumps to the indicated line if the flag is set; otherwise it “falls through” and continues with the next line.)
Notice lines `2c` and `30`. Clearly, this is the check `S[i] == -1`. What’s interesting is that there’s no code to handle the case when `S[i] == -1` is true (because it never is). If `S[i] == -1` were true, it would jump to line `36`; otherwise it falls through and also ends up on line `36`. Therefore, it has no way of handling the two cases differently. However, it has eliminated neither the check nor the jump. The picture changes when passing `-O3`:
```asm
0000000000000020 <prod3>:
  20: mov eax,DWORD PTR [rsi]
  22: mov edx,DWORD PTR [rsi+0x4]
  25: lea eax,[rax+rax*2]
  28: lea edx,[rdx+rdx*4]
  2b: add eax,edx
  2d: mov edx,DWORD PTR [rsi+0x8]
  30: lea eax,[rax+rdx*8]
  33: sub eax,edx
  35: ret
```
The generated code makes sense: no comparison or jump, and no loading of static extents. It simply computes `3*a[0] + 5*a[1] + 7*a[2]` as

```
a0 + 2*a0 + 4*a1 + a1 + 8*a2 - a2
```

with `a0 = a[0]`, `a1 = a[1]` and `a2 = a[2]` (the `mov`s). Naturally, there might be an even faster sequence of instructions, but this one doesn’t contain any superfluous instructions.
Let’s see if we can make it easier on the optimizer and get the same behaviour on `-O2`:

```cpp
template<typename _IndexType, array _Extents>
class _ExtentsStorage
{
  static constexpr bool
  _S_is_dynamic(size_t __r) noexcept
  {
    if constexpr (__all_static<_Extents>())
      return false;
    else
      return _Extents[__r] == dynamic_extent;
  }

  constexpr _IndexType
  _M_extent(size_t __r) const noexcept
  {
    if (_S_is_dynamic(__r))
      return _M_dyn_exts[_S_dynamic_index(__r)];
    else
      return _S_static_extent(__r);
  }

  // ...
};
```
The point is to see what happens if the condition is made more obviously always `true` or `false`. Let’s recompile again with `-O2`:
```asm
0000000000000020 <prod3>:
  20: mov eax,DWORD PTR [rsi]
  22: mov edx,DWORD PTR [rsi+0x4]
  25: lea eax,[rax+rax*2]
  28: lea edx,[rdx+rdx*4]
  2b: add eax,edx
  2d: mov edx,DWORD PTR [rsi+0x8]
  30: lea eax,[rax+rdx*8]
  33: sub eax,edx
  35: ret
```
Yay! No more pointless comparisons and jumps. Okay, one last time. Let’s look at the following:
```cpp
int prod4(const std::extents<int, 3, 5, 7, 11>& exts,
          const std::array<int, 4>& a)
{
  int ret = 0;
  for (size_t i = 0; i < exts.rank(); ++i)
    ret += exts.extent(i) * a[i];

  return ret;
}
```
It’s different from before in that the array is exactly four elements long, which just happens to be 128 bits. Let’s also compile this example with `-O2`. First, compile the version without the optimization and disassemble. What we see is essentially unchanged:
```asm
0000000000000050 <prod4>:
  50: xor eax,eax
  52: xor ecx,ecx
  54: mov rdx,QWORD PTR [rax*8+0x0]
  5c: cmp rdx,0xffffffffffffffff
  60: je 66 <prod4+0x16>
  66: imul edx,DWORD PTR [rsi+rax*4]
  6a: add rax,0x1
  6e: add ecx,edx
  70: cmp rax,0x4
  74: jne 54 <prod4+0x4>
  76: mov eax,ecx
  78: ret
```
Let’s compile (`-O2`) against the optimized version and disassemble:

```asm
0000000000000040 <prod4>:
  40: movdqu xmm1,XMMWORD PTR [rsi]
  44: movdqa xmm2,XMMWORD PTR [rip+0x0]
  4c: movdqa xmm0,xmm1
  50: psrlq xmm1,0x20
  55: pmuludq xmm0,xmm2
  59: psrlq xmm2,0x20
  5e: pmuludq xmm1,xmm2
  62: pshufd xmm0,xmm0,0x8
  67: pshufd xmm1,xmm1,0x8
  6c: punpckldq xmm0,xmm1
  70: movdqa xmm1,xmm0
  74: psrldq xmm1,0x8
  79: paddd xmm0,xmm1
  7d: movdqa xmm1,xmm0
  81: psrldq xmm1,0x4
  86: paddd xmm0,xmm1
  8a: movd eax,xmm0
  8e: ret
```
Meaning, it unlocks SIMD vectorization on `-O2`. Overall, it’s impressive to see how well the compiler eliminates all the convoluted code we write and even figures out rather non-trivial properties like `S[i] == -1` for all `i`.

(The code is compiled with GCC on Linux x86_64.)
¹ Branching refers to the two branches of the if-condition. The cost of a branch is measured in cycles, i.e. it’s extremely low. However, in tight loops, due to how modern CPUs work, it’s distinctly not cheap. At the compiler level, the cost of branching is that it might prevent SIMD vectorization, because a single SIMD instruction can’t take different branches for different lanes. At the CPU level, branching means the CPU can’t know several cycles in advance what the next instruction will be. CPUs heavily utilize pipelining, i.e. whenever possible they start the next instruction before the current one is done; but if the CPU doesn’t know what the next instruction is, it can’t be scheduled. Therefore, modern CPUs have branch prediction, which guesses which branch will be taken and then speculatively executes that branch. This reduces the cost of branching back down to a few cycles. However, if the prediction is wrong, the cost is very high, because the CPU needs to “stop and forget” any speculative values. As a result, branches in small functions that might appear in a tight loop make programmers uneasy. ↩︎
² Why is the question interesting? If all extents are static, then even the most naive custom implementation would not use any storage for the dynamic extents, and it would not have an if-condition. Similarly, if all extents are dynamic, there’s no need for the if-condition or the indirection `k[i]`; one can simply return `D[i]`. The concern is that because `extents` supports a mix of dynamic and static extents, one has to pay a cost at runtime, even in trivial cases where there’s no need for the additional complexity. ↩︎