std::extents: unrolling loops
Code: bit-bcast/extents
Let’s look at std::extents
from <mdspan>
again. In particular its
operator==
. Any two extents objects are comparable. If they have
different ranks, then they’re always considered different. Otherwise, they are
considered equal if and only if E1[i] == E2[i]
for all i
.
In pseudo-code extents can be written as:
template<size_t... Extents>
struct extents {
constexpr static const array S{Extents...};
// This is std::extents::extent
constexpr int
operator[](int r) {
// Can't use `if constexpr` because `r` is a runtime value.
return (S[r] == -1) ? D[k[r]] : S[r];
}
constexpr static const array<size_t, rank> k;
array<int, dyn_rank> D;
};
Here S
are the static extents and we’ll use E[i]
as short-hand for the
extents::extent(i)
. Note that k
is pre-computed (somehow) at compile-time.
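One way the "somehow" could work is a prefix count over the static extents: k[r] is the number of dynamic extents before position r. The following is only a sketch of the pseudo-code above, not the libstdc++ implementation; the names `extents_sketch` and `dyn` are assumptions made for illustration.

```cpp
#include <array>
#include <cstddef>

// Assumed sentinel for "only known at runtime" (the -1 in the pseudo-code).
inline constexpr std::size_t dyn = static_cast<std::size_t>(-1);

template <std::size_t... Extents>
struct extents_sketch {
    static constexpr std::size_t rank = sizeof...(Extents);
    // Number of dynamic extents, i.e. the length of D.
    static constexpr std::size_t dyn_rank = (std::size_t(Extents == dyn) + ... + 0);

    // S: the static extents, with `dyn` marking the runtime ones.
    static constexpr std::array<std::size_t, rank> S{Extents...};

    // k[r]: how many dynamic extents precede position r. If S[r] is
    // dynamic, this is the slot of D that stores it.
    static constexpr std::array<std::size_t, rank> k = [] {
        std::array<std::size_t, rank> out{};
        std::size_t seen = 0;
        for (std::size_t r = 0; r < rank; ++r) {
            out[r] = seen;
            if (S[r] == dyn) ++seen;
        }
        return out;
    }();

    // D: one slot per dynamic extent.
    std::array<int, dyn_rank> D{};

    constexpr int operator[](std::size_t r) const {
        return S[r] == dyn ? D[k[r]] : static_cast<int>(S[r]);
    }
};
```

For `extents_sketch<1, dyn, 3, dyn>` this gives k = {0, 0, 1, 1}; only the entries at dynamic positions are ever read.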
… and we’re ready to begin asking questions. This one’s interesting to
answer:
When computing
all(E1[i] == E2[i] for i in range(rank))
can the optimizer skip comparisons[1] if it knows that both E1[i]
and E2[i]
are static extents?[2]
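The part of the question about mismatches can be phrased as a compile-time predicate: is there a position where both extents are static and differ? A minimal sketch, assuming the hypothetical names `dyn` and `static_mismatch` (neither is part of `<mdspan>`):

```cpp
#include <array>
#include <cstddef>

// Assumed sentinel for a dynamic extent.
inline constexpr std::size_t dyn = static_cast<std::size_t>(-1);

// True if some position holds two static extents that differ. In that case
// the comparison is false regardless of the values of the dynamic extents.
template <std::size_t N>
constexpr bool static_mismatch(const std::array<std::size_t, N>& s1,
                               const std::array<std::size_t, N>& s2) {
    for (std::size_t i = 0; i < N; ++i) {
        if (s1[i] != dyn && s2[i] != dyn && s1[i] != s2[i]) return true;
    }
    return false;
}

// The static extents used for same4 further down: <1, dyn, 3> vs <dyn, 2, 0>.
inline constexpr std::array<std::size_t, 3> lhs{1, dyn, 3};
inline constexpr std::array<std::size_t, 3> rhs{dyn, 2, 0};
static_assert(static_mismatch(lhs, rhs));  // 3 != 0 at the last position
```

Everything here is decided at compile time, which is the answer we'd like the optimizer to reach on its own.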
A first unpublished implementation in libstdc++ looked similar to:
template<size_t... Ints>
friend constexpr bool
operator==(const extents& lhs, const extents<Ints...>& rhs)
{
if constexpr (rank != rhs.rank) {
return false;
} else {
for (size_t i = 0; i < rank; ++i) {
if (lhs[i] != rhs[i]) {
return false;
}
}
return true;
}
}
Note that the trip count is known at compile time and the order in which the extents are checked doesn’t matter. This loop could be unrolled.
Let’s try and get a feeling for what the optimizer will do:
#include <mdspan>
extern "C" {
bool same1(const extents<1, 2, 3>& e1,
const extents<1, 2, 3>& e2)
{ return e1 == e2; }
bool same2(const extents<0, 2, 3>& e1,
const extents<1, 2, 3>& e2)
{ return e1 == e2; }
bool same3(const extents<0, dyn, 3>& e1,
const extents<1, 2, 3>& e2)
{ return e1 == e2; }
bool same4(const extents<1, dyn, 3>& e1,
const extents<dyn, 2, 0>& e2)
{ return e1 == e2; }
} // extern "C"
The generated code with -O2
, after eliminating filler code for
alignment, is:
0000000000000000 <same1>:
0: mov eax,0x1
5: ret
0000000000000010 <same2>:
10: xor eax,eax
12: ret
0000000000000020 <same3>:
20: xor eax,eax
22: ret
Good! That’s very nice. Let’s glance at the next one:
0000000000000030 <same4>:
30: mov r10,rdi
33: mov r9,rsi
36: lea r8,[rip+0x0] # 3d <same4+0xd>
3d: xor eax,eax
3f: lea rdi,[rip+0x0] # 46 <same4+0x16>
46: mov rdx,QWORD PTR [r8+rax*8]
4a: mov ecx,edx
4c: cmp rdx,0xffffffffffffffff
50: jne 61 <same4+0x31>
52: lea rdx,[rip+0x0] # 59 <same4+0x29>
59: mov rdx,QWORD PTR [rdx+rax*8]
5d: mov ecx,DWORD PTR [r10+rdx*4]
61: mov rdx,QWORD PTR [rdi+rax*8]
65: mov esi,edx
67: cmp rdx,0xffffffffffffffff
6b: jne 7c <same4+0x4c>
6d: lea rdx,[rip+0x0] # 74 <same4+0x44>
74: mov rdx,QWORD PTR [rdx+rax*8]
78: mov esi,DWORD PTR [r9+rdx*4]
7c: cmp esi,ecx
7e: jne 90 <same4+0x60>
80: add rax,0x1
84: cmp rax,0x3
88: jne 46 <same4+0x16>
8a: mov eax,0x1
8f: ret
90: xor eax,eax
92: ret
This is exactly what we’re worried about, because the last static extent is a
mismatch (3 vs. 0). Hence, it’s supposed to just return false
. Naturally, while -O2
is a common optimization level, it’s not the highest one, maybe -O3
is
better? Let’s try:
0000000000000030 <same4>:
30: xor eax,eax
32: ret
Wow, nice! Sadly, or fortunately, this will be the theme throughout the post:
on -O2
one can find easy improvements; on -O3
the optimizer sees right
through our little house of cards and produces optimal code[3].
Okay, let’s try a non-trivial case:
bool same5(const extents<1, 2, 3>& e1,
const extents<1, dyn, 3>& e2)
{ return e1 == e2; }
This time we get (-O2
):
00000000000000a0 <same5>:
a0: mov r9,rsi
a3: xor eax,eax
a5: lea r8,[rip+0x0] # ac <same5+0xc>
ac: lea rdi,[rip+0x0] # b3 <same5+0x13>
b3: mov rdx,QWORD PTR [r8+rax*8]
b7: cmp rdx,0xffffffffffffffff
bb: je c1 <same5+0x21>
c1: mov rcx,QWORD PTR [rdi+rax*8]
c5: mov esi,ecx
c7: cmp rcx,0xffffffffffffffff
cb: jne dc <same5+0x3c>
cd: lea rcx,[rip+0x0] # d4 <same5+0x34>
d4: mov rcx,QWORD PTR [rcx+rax*8]
d8: mov esi,DWORD PTR [r9+rcx*4]
dc: cmp edx,esi
de: jne f0 <same5+0x50>
e0: add rax,0x1
e4: cmp rax,0x3
e8: jne b3 <same5+0x13>
ea: mov eax,0x1
ef: ret
f0: xor eax,eax
f2: ret
Since this case is non-trivial, we expect to see some assembly; but is it reasonable? Let’s transcribe it to pseudo code:
for(i = 0; i != 3; ++i)
    e2 = S2[i]
    if S2[i] == -1:
        e2 = D2[k2[i]]
    if S1[i] != e2:
        return false
return true
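The same shape can be written as a small C++ function to play with. This is a hand transcription under the assumption that the tables are S1 = {1, 2, 3}, S2 = {1, dyn, 3} and k2 = {0, 0, 1}; the name `same5_loop` is made up, and the comments point at the instructions each line is believed to correspond to.

```cpp
#include <cstddef>

constexpr std::size_t dyn = static_cast<std::size_t>(-1);
constexpr std::size_t S1[3] = {1, 2, 3};    // static extents of e1
constexpr std::size_t S2[3] = {1, dyn, 3};  // static extents of e2
constexpr std::size_t k2[3] = {0, 0, 1};    // slot in D2 for each position

// D2 points at the dynamic extents of e2 (a single int in this case).
bool same5_loop(const int* D2) {
    for (std::size_t i = 0; i != 3; ++i) {
        int e2 = static_cast<int>(S2[i]);   // mov esi,ecx
        if (S2[i] == dyn)                   // cmp rcx,-1 ; jne
            e2 = D2[k2[i]];                 // the two dependent loads
        if (static_cast<int>(S1[i]) != e2)  // cmp edx,esi ; jne
            return false;
    }
    return true;
}
```

Only the iteration i == 1 can change the result; the other two compare constants.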
How does one guess? First, there’s the sequence with a backwards jump:
e0: add rax,0x1
e4: cmp rax,0x3
e8: jne b3 <same5+0x13>
this smells like a loop. Next, we should track the loads,
but before we blindly trust that rdi
and rsi
are the first and second
pointer arguments of same5
we check… and see:
a0: mov r9,rsi
ac: lea rdi,[rip+0x0] # b3 <same5+0x13>
c5: mov esi,ecx
Annoying, but fine. Let’s find something it’s loading:
c1: mov rcx,QWORD PTR [rdi+rax*8]
loads 8 bytes from unknown_offset + i*8
, so it’s certainly not D
because
int
s are 4 bytes each. The offset is unknown because it’s reading a static
variable (and those are only given a location in executables/shared libraries,
but not object files), so that’s likely one of the static arrays
S1
or S2
.
There’s also
d4: mov rcx,QWORD PTR [rcx+rax*8]
d8: mov esi,DWORD PTR [r9+rcx*4]
we know that r9
is the second argument passed to the function, i.e. the
reference/pointer e2
. Therefore, this is the indirect load D2[k2[i]]
.
Then, there’s
c7: cmp rcx,0xffffffffffffffff
cb: jne dc <same5+0x3c>
this checks if rcx
, i.e. S?[i]
, is equal to -1
. If not it jumps forwards.
That’s likely an if
-condition. Now, slowly one can see the rest.
What’s interesting is that:
- it’s optimized away the indirection for E1: because all its extents are static, there’s no need to emit code that can handle D1[k1[i]];
- it’s not eliminated the loop;
- it’s not eliminated the trivial iterations at the beginning and end of the loop.
Considering all we want to do is:
all(E1[i] == E2[i] for i in range(n))
the amount of code seems excessive (usually rank <= 3; almost always <= 8,
because m**k
just grows too fast for k >= 4
). This is easily confirmed by
recompiling with -O3
:
0000000000000040 <same5>:
40: cmp DWORD PTR [rsi],0x2
43: sete al
46: ret
which checks if D2[0] == 2
and then copies the flag (from cmp
) to the return
register (with sete
).
Purely out of curiosity, what happens if we write loop-less code? How? Probably some variant of pack expansion. Maybe something like this:
template<size_t... OtherExtents>
requires (V == Version::v3)
friend constexpr bool
operator==(const extents& lhs, const extents<V, OtherExtents...>& rhs)
{
auto impl = [&]<size_t... Is>(std::index_sequence<Is...>) {
return ((lhs[Is] == rhs[Is]) && ...);
};
return impl(std::make_index_sequence<rank>());
}
It’s a bit clumsy, because everything is stuffed into a lambda for the sole purpose of deducing the loop indices, but otherwise it’s a reasonably flexible pattern to create a compile-time for-loop.
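The pattern can be pulled out of `operator==` and tried in isolation. Below is a sketch with made-up names (`all_of_indices`, `arrays_equal`); it passes each index as a `std::integral_constant`, so the callback sees it as a compile-time value. It relies on C++20 templated lambdas, like the snippet above.

```cpp
#include <array>
#include <cstddef>
#include <type_traits>
#include <utility>

// Calls pred(integral_constant<size_t, I>{}) for I = 0, ..., N-1 and ANDs
// the results. The fold over && short-circuits like the hand-written loop.
template <std::size_t N, typename Pred>
constexpr bool all_of_indices(Pred pred) {
    return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
        return (pred(std::integral_constant<std::size_t, Is>{}) && ...);
    }(std::make_index_sequence<N>{});
}

// Example use: element-wise comparison without a runtime loop.
constexpr bool arrays_equal(const std::array<int, 3>& a,
                            const std::array<int, 3>& b) {
    return all_of_indices<3>([&](auto i) { return a[i] == b[i]; });
}

static_assert(arrays_equal({1, 2, 3}, {1, 2, 3}));
static_assert(!arrays_equal({1, 2, 3}, {1, 0, 3}));
```

Hiding the lambda in a helper like this keeps the call site short, at the cost of one more template to instantiate.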
Time to compile all examples again (with -O2
):
0000000000000000 <same1>:
0: mov eax,0x1
5: ret
0000000000000010 <same2>:
10: xor eax,eax
12: ret
0000000000000020 <same3>:
20: xor eax,eax
22: ret
0000000000000030 <same4>:
30: xor eax,eax
32: ret
0000000000000040 <same5>:
40: cmp DWORD PTR [rsi],0x2
43: sete al
46: ret
Nice, that’s the same as with -O3
. That’s nice enough to warrant one more
example:
bool same6(const extents<dyn, 2, 3, dyn>& e1,
const extents<dyn, dyn, 3, 4>& e2)
{ return e1 == e2; }
It’s longer, has no mismatching static extents and the dynamic extents don’t line up nicely. The generated code is:
0000000000000050 <same6>:
50: mov edx,DWORD PTR [rsi]
52: xor eax,eax
54: cmp DWORD PTR [rdi],edx
56: je 60 <same6+0x10>
58: ret
59: nop DWORD PTR [rax+0x0]
60: cmp DWORD PTR [rsi+0x4],0x2
64: jne 58 <same6+0x8>
66: cmp DWORD PTR [rdi+0x4],0x4
6a: sete al
6d: ret
Recall rdi
is e1
and rsi
is e2
. So we see: not only has it eliminated
the trivial comparison 3 == 3
, it’s also removed the indirection D[k[i]]
because i
is a compile time constant. In pseudo code:
if (D1[0] != D2[0]) return false
if ( 2 != D2[1]) return false
return D1[1] == 4
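To convince oneself that these three comparisons are all that's needed, the reduced form can be checked against the full element-wise comparison. This is a sketch with hypothetical function names, assuming the layout from the pseudo-code struct: D1 = {e1[0], e1[3]} and D2 = {e2[0], e2[1]}.

```cpp
#include <array>

// What the generated code does: two dynamic-vs-dynamic or dynamic-vs-static
// checks, then the last one; 3 == 3 has disappeared entirely.
bool same6_reduced(const int* D1, const int* D2) {
    if (D1[0] != D2[0]) return false;  // cmp DWORD PTR [rdi],edx
    if (2 != D2[1]) return false;      // cmp DWORD PTR [rsi+0x4],0x2
    return D1[1] == 4;                 // cmp DWORD PTR [rdi+0x4],0x4
}

// Reference: expand both extents fully and compare all four entries.
bool same6_reference(const int* D1, const int* D2) {
    const std::array<int, 4> e1{D1[0], 2, 3, D1[1]};  // <dyn, 2, 3, dyn>
    const std::array<int, 4> e2{D2[0], D2[1], 3, 4};  // <dyn, dyn, 3, 4>
    return e1 == e2;
}
```

Both functions agree for any choice of dynamic extents, since the only position they treat differently is the constant comparison 3 == 3.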
… and that’s it :-)
(The code is compiled with GCC on Linux x86_64
.)
[1] If two static extents are equal, there’s no need to waste cycles at runtime to check that they are equal. Just as interesting, if two static extents are different, we can just return false.
[2] The loop is reasonably generic in that it’s very short and knowing the value of the iteration variable unlocks new optimizations (e.g. if one knows the value of i at compile time then E[i] is much simpler to compute and might be cheaper too). Now it’s interesting to see if the compiler is willing to unroll this loop for us; or if we’re expected to “help”. If it doesn’t unroll, can it at least see when there’s a mismatch in static extents?
[3] Therefore, if you don’t care about optimizing on anything other than -O3, this might be the time to stop reading or keep reading but ignore the “optimizations”. It’s still impressive to see how radically the optimizer eliminates all the generic code and only leaves exactly what’s needed =)