std::extents: unrolling loops

Code: bit-bcast/extents

Let’s look at std::extents from <mdspan> again, in particular its operator==. Any two extents objects are comparable. If they have different ranks, they’re always considered different. Otherwise, they are equal if and only if E1[i] == E2[i] for all i.
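
For concreteness, here are a couple of sanity checks one could write against the standard type (spelled std::extents with an explicit index type; the listings below drop it):

#include <mdspan>
#include <span>  // std::dynamic_extent

// Different ranks: never equal.
static_assert(std::extents<int, 2, 3>{} != std::extents<int, 2>{});
// Same rank and every extent matches: equal.
static_assert(std::extents<int, 2, 3>{} == std::extents<int, 2, 3>{});
// A dynamic extent is compared by its value, here set to 3 at construction.
static_assert(std::extents<int, 2, std::dynamic_extent>{3}
              == std::extents<int, 2, 3>{});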

In pseudo-code extents can be written as:

template<size_t... Extents>
struct extents {
  // The static extents; a dynamic extent is stored as -1.
  constexpr static const array S{Extents...};

  // This is std::extents::extent
  constexpr int
  operator[](int r) {
    // Can't use `if constexpr` because `r` is not a constant expression.
    return (S[r] == -1) ? D[k[r]] : S[r];
  }

  // Maps a rank index to its position in D; pre-computed at compile time.
  constexpr static const array<size_t, rank> k;
  // Runtime values of the dynamic extents.
  array<int, dyn_rank> D;
};
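
How could k be pre-computed? One possibility (a minimal sketch for illustration, not the actual library code; make_k and dyn below are made-up names) is to count, for each rank index, how many dynamic extents precede it; that count is the position in D:

#include <array>
#include <cstddef>

// As in the pseudo-code, size_t(-1) marks a dynamic extent.
inline constexpr std::size_t dyn = std::size_t(-1);

// Sketch only: k[i] is the number of dynamic extents before position i,
// i.e. where the value of the i-th extent is stored in D if it is dynamic.
template<std::size_t... Extents>
constexpr auto make_k() {
  constexpr std::array<std::size_t, sizeof...(Extents)> S{Extents...};
  std::array<std::size_t, sizeof...(Extents)> k{};
  std::size_t dyn_count = 0;
  for (std::size_t i = 0; i < k.size(); ++i) {
    k[i] = dyn_count;
    if (S[i] == dyn)
      ++dyn_count;
  }
  return k;
}

// extents<dyn, 2, dyn>: the two dynamic extents live in D[0] and D[1].
static_assert(make_k<dyn, 2, dyn>() == std::array<std::size_t, 3>{0, 1, 1});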

Here S are the static extents and we’ll use E[i] as shorthand for extents::extent(i). Note that k is pre-computed at compile time (roughly as sketched above). … and we’re ready to begin asking questions. This one’s interesting to answer:

When computing all(E1[i] == E2[i] for i in range(rank)), can the optimizer skip comparisons [1] if it knows that both E1[i] and E2[i] are static extents? [2]

A first unpublished implementation in libstdc++ looked similar to:

template<size_t... Ints>
friend constexpr bool
operator==(const extents& lhs, const extents<Ints...>& rhs)
{
    if constexpr (rank != rhs.rank) {
        return false;
    } else {
        for (size_t i = 0; i < rank; ++i) {
            if (lhs[i] != rhs[i]) {
                return false;
            }
        }
        return true;
    }
}

Note that the trip count is known at compile time and the order in which the extents are checked doesn’t matter. This loop could be unrolled.
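
For rank 3 (assuming the ranks already match), unrolling by hand amounts to three comparisons with compile-time constant indices. An illustrative sketch, not the library code (the helper name is made up):

// Illustrative only: with the loop gone, every index is a compile-time
// constant, so a comparison between two static extents can be folded away.
template<class E1, class E2>
constexpr bool equal_rank3(const E1& lhs, const E2& rhs)
{
    return lhs[0] == rhs[0]
        && lhs[1] == rhs[1]
        && lhs[2] == rhs[2];
}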

Let’s try and get a feeling for what the optimizer will do:

#include <mdspan>
extern "C" {

bool same1(const extents<1, 2, 3>& e1,
           const extents<1, 2, 3>& e2)
{ return e1 == e2; }

bool same2(const extents<0, 2, 3>& e1,
           const extents<1, 2, 3>& e2)
{ return e1 == e2; }

bool same3(const extents<0, dyn, 3>& e1,
           const extents<1, 2, 3>& e2)
{ return e1 == e2; }

bool same4(const extents<1, dyn, 3>& e1,
           const extents<dyn, 2, 0>& e2)
{ return e1 == e2; }

} // extern "C"

The generated code with -O2, after eliminating filler code for alignment, is:

0000000000000000 <same1>:
   0:   mov    eax,0x1
   5:   ret

0000000000000010 <same2>:
  10:   xor    eax,eax
  12:   ret

0000000000000020 <same3>:
  20:   xor    eax,eax
  22:   ret

Good! That’s very nice. Let’s glance at the next one:

0000000000000030 <same4>:
  30:   mov    r10,rdi
  33:   mov    r9,rsi
  36:   lea    r8,[rip+0x0]        # 3d <same4+0xd>
  3d:   xor    eax,eax
  3f:   lea    rdi,[rip+0x0]        # 46 <same4+0x16>
  46:   mov    rdx,QWORD PTR [r8+rax*8]
  4a:   mov    ecx,edx
  4c:   cmp    rdx,0xffffffffffffffff
  50:   jne    61 <same4+0x31>
  52:   lea    rdx,[rip+0x0]        # 59 <same4+0x29>
  59:   mov    rdx,QWORD PTR [rdx+rax*8]
  5d:   mov    ecx,DWORD PTR [r10+rdx*4]
  61:   mov    rdx,QWORD PTR [rdi+rax*8]
  65:   mov    esi,edx
  67:   cmp    rdx,0xffffffffffffffff
  6b:   jne    7c <same4+0x4c>
  6d:   lea    rdx,[rip+0x0]        # 74 <same4+0x44>
  74:   mov    rdx,QWORD PTR [rdx+rax*8]
  78:   mov    esi,DWORD PTR [r9+rdx*4]
  7c:   cmp    esi,ecx
  7e:   jne    90 <same4+0x60>
  80:   add    rax,0x1
  84:   cmp    rax,0x3
  88:   jne    46 <same4+0x16>
  8a:   mov    eax,0x1
  8f:   ret
  90:   xor    eax,eax
  92:   ret

This is exactly what we were worried about: the last static extent is a mismatch, so this function is supposed to just return false. Naturally, while -O2 is a common optimization level, it’s not the highest one; maybe -O3 does better? Let’s try:

0000000000000030 <same4>:
  30:   xor    eax,eax
  32:   ret

Wow, nice! Sadly, or fortunately, this will be the theme throughout the post: at -O2 one can find easy improvements, while at -O3 the optimizer sees right through our little house of cards and produces optimal code [3].

Okay, let’s try a non-trivial case:

bool same5(const extents<1, 2, 3>& e1,
           const extents<1, dyn, 3>& e2)
{ return e1 == e2; }

This time we get (-O2):

00000000000000a0 <same5>:
  a0:   mov    r9,rsi
  a3:   xor    eax,eax
  a5:   lea    r8,[rip+0x0]        # ac <same5+0xc>
  ac:   lea    rdi,[rip+0x0]        # b3 <same5+0x13>
  b3:   mov    rdx,QWORD PTR [r8+rax*8]
  b7:   cmp    rdx,0xffffffffffffffff
  bb:   je     c1 <same5+0x21>
  c1:   mov    rcx,QWORD PTR [rdi+rax*8]
  c5:   mov    esi,ecx
  c7:   cmp    rcx,0xffffffffffffffff
  cb:   jne    dc <same5+0x3c>
  cd:   lea    rcx,[rip+0x0]        # d4 <same5+0x34>
  d4:   mov    rcx,QWORD PTR [rcx+rax*8]
  d8:   mov    esi,DWORD PTR [r9+rcx*4]
  dc:   cmp    edx,esi
  de:   jne    f0 <same5+0x50>
  e0:   add    rax,0x1
  e4:   cmp    rax,0x3
  e8:   jne    b3 <same5+0x13>
  ea:   mov    eax,0x1
  ef:   ret
  f0:   xor    eax,eax
  f2:   ret

Since this case is non-trivial, we expect to see some assembly; but is it reasonable? Back to the topic at hand: let’s transcribe it to pseudo-code:

for (i = 0; i != 3; ++i):
    e2 = S2[i]
    if S2[i] == -1:
        e2 = D2[k2[i]]
    if S1[i] != e2:
        return false
return true

How does one guess? First, there’s the sequence with a backwards jump:

  e0:   add    rax,0x1
  e4:   cmp    rax,0x3
  e8:   jne    b3 <same5+0x13>

this smells like a loop. Next, we should track the loads. Before blindly trusting that rdi and rsi are still the first and second pointer arguments of same5, we check… and see:

  a0:   mov    r9,rsi
  ac:   lea    rdi,[rip+0x0]        # b3 <same5+0x13>
  c5:   mov    esi,ecx

Annoying, but fine. Let’s find something it’s loading:

  c1:   mov    rcx,QWORD PTR [rdi+rax*8]

loads 8 bytes from unknown_offset + i*8, so it’s certainly not D, because ints are only 4 bytes each. The offset is unknown because it’s reading a static variable (those are only given a location in executables/shared libraries, not in object files), so it’s likely one of the static arrays S1 or S2.

There’s also

  d4:   mov    rcx,QWORD PTR [rcx+rax*8]
  d8:   mov    esi,DWORD PTR [r9+rcx*4]

we know that r9 is the second argument passed to the function, i.e. the reference/pointer e2. Therefore, this is the indirect load D2[k2[i]].
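
That reading leans on the layout from the pseudo-code above: D is the only non-static data member, so a reference to the extents object is, address-wise, a pointer to its D array. A tiny stand-in struct (an assumption for illustration; the real class lives in the linked repository) makes the point:

#include <cstddef>  // offsetof

// Stand-in for the layout of extents<1, dyn, 3>: one dynamic extent, so the
// only per-object state is a single int. A pointer to the object is then also
// a pointer to D2, which is what [r9 + k2[i]*4] reads.
struct extents_1_dyn_3_like {
  int D[1];
};
static_assert(sizeof(extents_1_dyn_3_like) == sizeof(int));
static_assert(offsetof(extents_1_dyn_3_like, D) == 0);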

Then, there’s

  c7:   cmp    rcx,0xffffffffffffffff
  cb:   jne    dc <same5+0x3c>

this checks whether rcx, i.e. S?[i], is equal to -1, and jumps forward if it isn’t. That’s likely an if-condition. From here the rest slowly falls into place.

What’s interesting is that:

  • it has optimized away the indirection for E1: because all of its extents are static, there’s no need to emit code that can handle D1[k1[i]];

  • it has not eliminated the loop;

  • it has not eliminated the trivial iterations at the beginning and end of the loop.

Considering all we want to do is:

    all(E1[i] == E2[i] for i in range(n))

the amount of code seems excessive (usually rank <= 3; almost always <= 8, because m**k just grows too fast for k >= 4). This is easily confirmed by recompiling with -O3:

0000000000000040 <same5>:
  40:   cmp    DWORD PTR [rsi],0x2
  43:   sete   al
  46:   ret

which checks whether e2’s only dynamic extent, D2[0], equals 2 and then copies the flag (from cmp) into the return register (with sete).

Purely out of curiosity, what happens if we write loop-less code? How? Probably some variant of pack expansion. Maybe something like this:

template<size_t... OtherExtents>
requires (V == Version::v3)
friend constexpr bool
operator==(const extents& lhs, const extents<V, OtherExtents...>& rhs)
{
    auto impl = [&]<size_t... Is>(std::index_sequence<Is...>) {
        return ((lhs[Is] == rhs[Is]) && ...);
    };
    return impl(std::make_index_sequence<rank>());
}

It’s a bit clumsy, because everything is stuffed into a lambda for the sole purpose of deducing the loop indices, but otherwise it’s a reasonably flexible pattern to create a compile-time for-loop.
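
As an aside, the same lambda-plus-index_sequence trick can be packaged into a small reusable helper; a generic sketch with made-up names, not code from the post’s repository:

#include <cstddef>
#include <type_traits>
#include <utility>

// Calls f with integral_constant<size_t, 0> ... integral_constant<size_t, N - 1>,
// so the body sees every index as a compile-time constant.
template<std::size_t N, class F>
constexpr void for_each_index(F&& f)
{
    [&]<std::size_t... Is>(std::index_sequence<Is...>) {
        (f(std::integral_constant<std::size_t, Is>{}), ...);
    }(std::make_index_sequence<N>{});
}

// Usage: sums the indices 0, 1, 2, 3 at compile time.
constexpr int sum_0_to_3()
{
    int sum = 0;
    for_each_index<4>([&](auto i) { sum += i; });
    return sum;
}
static_assert(sum_0_to_3() == 6);

Plugging rank in for N and the element comparison into the body gives the same unrolled comparisons as the fold expression above, minus the short-circuiting of &&.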

Time to compile all examples again (with -O2):

0000000000000000 <same1>:
   0:   mov    eax,0x1
   5:   ret

0000000000000010 <same2>:
  10:   xor    eax,eax
  12:   ret

0000000000000020 <same3>:
  20:   xor    eax,eax
  22:   ret

0000000000000030 <same4>:
  30:   xor    eax,eax
  32:   ret

0000000000000040 <same5>:
  40:   cmp    DWORD PTR [rsi],0x2
  43:   sete   al
  46:   ret

Nice, that’s the same as with -O3. Good enough to warrant one more example:

bool same6(const extents<dyn,   2, 3, dyn>& e1,
           const extents<dyn, dyn, 3, 4>& e2)
{ return e1 == e2; }

It’s longer, has no mismatching static extents, and the dynamic extents don’t line up nicely. The generated code is:

0000000000000050 <same6>:
  50:   mov    edx,DWORD PTR [rsi]
  52:   xor    eax,eax
  54:   cmp    DWORD PTR [rdi],edx
  56:   je     60 <same6+0x10>
  58:   ret
  59:   nop    DWORD PTR [rax+0x0]
  60:   cmp    DWORD PTR [rsi+0x4],0x2
  64:   jne    58 <same6+0x8>
  66:   cmp    DWORD PTR [rdi+0x4],0x4
  6a:   sete   al
  6d:   ret

Recall that rdi is e1 and rsi is e2. So we see: not only has it eliminated the trivial comparison 3 == 3, it has also removed the indirection D[k[i]], because i is a compile-time constant. In pseudo-code:

    if (D1[0] != D2[0]) return false
    if (    2 != D2[1]) return false
    return D1[1] == 4

… and that’s it :-)

(The code is compiled with GCC on Linux x86_64.)


  1. If two static extents are equal, there’s no need to waste cycles at runtime checking that they are equal. Just as interesting: if two static extents are different, the comparison can just return false.↩︎

  2. The loop is reasonably generic: it’s very short, and knowing the value of the iteration variable unlocks new optimizations (e.g. if one knows the value of i at compile time, then E[i] is much simpler to compute and might be cheaper too). Now it’s interesting to see whether the compiler is willing to unroll this loop for us, or whether we’re expected to “help”. If it doesn’t unroll, can it at least see when there’s a mismatch in static extents? ↩︎

  3. Therefore, if you don’t care about anything other than -O3, this might be the time to stop reading, or to keep reading but ignore the “optimizations”. It’s still impressive to see how radically the optimizer eliminates all the generic code and leaves only exactly what’s needed =) ↩︎