Blog


Callista employee Erik Lupander

Go SIMD part 5: Bitwise operation performance optimizations

// Erik Lupander

Recently, a proposal for adding low-level SIMD support to Go was “Accepted” and added to Go 1.26 as a GOEXPERIMENT. In the last part, I applied a more genuine and fully SIMD-ish approach to my dear ray-sphere intersections, producing results almost 4x faster than the corresponding scalar code. In this part, I’ll try to take advantage of the AVX2-compatible ToBits() method on mask types such as Mask32x8, now available in the archsimd package in Go 1.26.

1. Introduction

In previous iterations of the new simd/archsimd package, working with masks was somewhat cumbersome since every time one wanted to run an if-statement to check if bits were set, one had to do something like:

myMask := someVec.Greater(otherVec)
if myMask.ToInt32x8().IsZero() {
	// Do something if no elements of someVec were greater than their
	// corresponding element in otherVec
}

Not only was the above a bit counter-intuitive, since checking for “all zeroes” often meant calling the “opposite” method of the one that would naturally express the intent; it also required calling ToInt32x8() on the mask and then IsZero() to get from the MaskNNxN type to a boolean value.

Assembly-wise, the internals of ToInt32x8 are opaque, while IsZero relies on an underlying VPTEST instruction. Now that an AVX2-compatible ToBits() uint8 is available on the mask types in Go 1.26, we can cross over to scalar code with fewer instructions and less code.

For masks of 8-element SIMD types such as Int32x8, a call to ToBits() returns a single uint8, one bit per lane, on which we can perform all sorts of standard bitwise comparisons and operations.
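To make the lane-to-bit mapping concrete, here is a plain-Go sketch of what ToBits() hands back after an 8-lane Greater comparison. It deliberately avoids archsimd so it runs on any machine; greaterBits is a hypothetical helper, and it assumes lane i maps to bit i of the uint8, which is the convention the index-finding code later in this post relies on.

```go
package main

import "fmt"

// greaterBits is a hypothetical scalar stand-in for a.Greater(b).ToBits()
// on 8-lane vectors: bit i of the result is set when a[i] > b[i].
func greaterBits(a, b [8]int32) uint8 {
	var mask uint8
	for i := 0; i < 8; i++ {
		if a[i] > b[i] {
			mask |= 1 << i // lane i -> bit i
		}
	}
	return mask
}

func main() {
	a := [8]int32{5, 1, 9, 0, 3, 7, 2, 8}
	b := [8]int32{4, 4, 4, 4, 4, 4, 4, 4}
	fmt.Printf("%08b\n", greaterBits(a, b)) // 10100101
}
```

With these inputs lanes 0, 2, 5 and 7 compare greater, so lane 0 ends up as the rightmost (least significant) bit of the printed mask.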

2. Using ToBits()

Note: type archsimd.Mask32x8 did have a ToBits() method even in the very first iterations of archsimd, but it was implemented using AVX512 instructions, which meant it wasn’t usable on my Mac, nor on many other common commodity CPUs.

As of Go 1.26, ToBits() is available using plain AVX/AVX2 for supported mask types, implemented using the VPMOVMSKB instruction (or variants, depending on the mask element size).

This means we no longer need to convert from Mask32x8 to Int32x8 and then call IsZero(). Instead, the Go snippet from the previous section can now be written as:

myMask := someVec.Greater(otherVec)
if myMask.ToBits() == 0x00 { // no bits were set
	// Do something...
}

This also allows for much better expressiveness using old-fashioned bit masks. Checking whether all elements in someVec were greater than their counterparts in otherVec is now just a standard comparison:

myMask := someVec.Greater(otherVec)
if myMask.ToBits() == 0xFF { // all bits were set
	// Do something...
}

If we want to check if any of the 4 least significant bits are set (for example), we can now write:

if myMask.ToBits()&0b1111 != 0 { // at least one bit was set
    // Do something...
}

The point I’m trying to make is that instead of either using IsZero() or converting the mask to Int32x8 and storing it to an [8]int32 for scalar inspection, we can now express ourselves more simply.
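The three checks above boil down to one-line predicates on the uint8. The sketch below uses hypothetical helper names and plain uint8 values standing in for the result of ToBits() on a Mask32x8, so it runs without GOEXPERIMENT or AVX2 hardware.

```go
package main

import "fmt"

// noneSet reports whether no lane satisfied the comparison.
func noneSet(mask uint8) bool { return mask == 0x00 }

// allSet reports whether all 8 lanes satisfied the comparison.
func allSet(mask uint8) bool { return mask == 0xFF }

// anyLowSet reports whether at least one of lanes 0-3 satisfied the comparison.
func anyLowSet(mask uint8) bool { return mask&0b1111 != 0 }

func main() {
	for _, m := range []uint8{0x00, 0xFF, 0b00100000, 0b00000100} {
		fmt.Printf("%08b none=%t all=%t anyLow=%t\n", m, noneSet(m), allSet(m), anyLowSet(m))
	}
}
```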

What about performance? Can we use the new ToBits() to speed up our ray-sphere intersection testing code?

3. Faster ray-sphere intersection testing

I’ll break this section into a few different parts.

3.1 Exit-early if statements

Throughout the ray/sphere intersection code, we have a number of “exit early” if-statements, where we check whether all (or no) elements fulfill some criterion. If the criterion is fulfilled, no intersection can take place for the 8 spheres in the current batch, and we can continue the for-loop with the next batch of spheres.

// If no elements in tcas are greater than or equal to zero...
if tcas.GreaterEqual(zeroes).ToInt32x8().IsZero() {
    continue
}

With ToBits(), we can express this using both less code and more naturally:

// If all tcas are less than zero...
if tcas.Less(zeroes).ToBits() == 0xFF {
    continue
}

In each iteration (processing a batch of 8 spheres, one per SIMD lane) there are four “exit early” if-statements with continue.

Replacing .ToInt32x8().IsZero() with ToBits() yields no significant performance improvement (or regression); the big win here is the simplification of the code.
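The shape of such a batched loop with an exit-early check can be sketched in plain Go. Here lessZeroBits is a hypothetical scalar stand-in for tcas.Less(zeroes).ToBits(), and candidateBatches is an illustrative function rather than the renderer’s actual code; each [8]float32 plays the role of one batch of 8 spheres.

```go
package main

import "fmt"

// lessZeroBits is a hypothetical scalar stand-in for
// tcas.Less(zeroes).ToBits(): bit i is set when lane i is below zero.
func lessZeroBits(tcas [8]float32) uint8 {
	var mask uint8
	for i, t := range tcas {
		if t < 0 {
			mask |= 1 << i
		}
	}
	return mask
}

// candidateBatches returns the indices of the batches that survive the
// exit-early check, i.e. batches where at least one tca is >= 0.
func candidateBatches(batches [][8]float32) []int {
	var out []int
	for b, tcas := range batches {
		if lessZeroBits(tcas) == 0xFF { // all 8 tca values negative: skip batch
			continue
		}
		out = append(out, b)
	}
	return out
}

func main() {
	batches := [][8]float32{
		{-1, -2, -3, -4, -5, -6, -7, -8}, // skipped
		{-1, 2, -3, -4, -5, -6, -7, -8},  // kept: lane 1 is >= 0
	}
	fmt.Println(candidateBatches(batches)) // [1]
}
```

One subtlety when flipping GreaterEqual into Less on float lanes, assuming the comparisons follow Go’s usual float semantics: a NaN compares false for both, so a NaN lane lets the old IsZero() form skip the batch but keeps the batch alive in the == 0xFF form. Without NaNs the two checks are equivalent.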

3.2 Finding the index of the lowest t

In the last part, a rather big portion of the time spent in the SIMD intersection code was not related to actual intersection testing; it went into figuring out which index (within the current batch of spheres) held the lowest t (distance from camera to intersection point in 3D space). E.g. given [4,3,7,1,3,7,5,6], how can the computer figure out that the 1 at index 3 is the lowest?
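As a point of reference, answering that question in scalar code is a simple linear scan. The sketch below (argminScalar is a hypothetical name, not part of the renderer) returns 3 for the example above. The SIMD code has to arrive at the same answer without looping over the lanes one by one, and a scalar reference like this is handy for cross-checking it in tests.

```go
package main

import "fmt"

// argminScalar returns the index of the first occurrence of the lowest t.
func argminScalar(ts [8]float32) int {
	best := 0
	for i := 1; i < len(ts); i++ {
		if ts[i] < ts[best] {
			best = i
		}
	}
	return best
}

func main() {
	fmt.Println(argminScalar([8]float32{4, 3, 7, 1, 3, 7, 5, 6})) // 3
}
```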

There are actually two optimizations:

  1. Previously, whenever the t in the current batch was the lowest we’d seen so far (compared to the variable defined outside of the for-loop), we resolved its index in the current batch right away and stored that.
  2. More importantly, the resolveCurrentIndex function, which was an abomination of SIMD code including masks, XORs and permutations, can now be written as simple scalar code thanks to ToBits().

The abomination code:

func resolveCurrentIndex(maskLo archsimd.Int32x4, maskHi archsimd.Int32x4, offset, currentIndex int) int {
	for i := range 4 {
		// firstElemSetMask is declared elsewhere as [-1,0,0,0]
		// (-1 in two's complement sets all bits to 1)

		// If XOR produces [0,0,0,0] the masks are equal.
		if maskLo.Xor(firstElemSetMask).IsZero() {
			// If the lower mask is equal, return the iteration index i + the offset (e.g. which batch of 8).
			return offset + i
		}
		if maskHi.Xor(firstElemSetMask).IsZero() {
			return offset + 4 + i
		}
		// Rotate every element one step to the left, e.g. [0,0,1,0] => [0,1,0,0];
		// next time [0,1,0,0] becomes [1,0,0,0] and so on.
		maskLo = maskLo.PermuteScalars(1, 2, 3, 0)
		maskHi = maskHi.PermuteScalars(1, 2, 3, 0)
	}
	// If no element was set, return the index passed to this func.
	return currentIndex
}

The code above used XORs and PermuteScalars in a loop (which can also be unrolled), counting the number of rotations needed until the lo or hi mask matched the [-1,0,0,0] pattern.

This consumed a significant portion of the total time. The benchmark figures below show the durations before any ToBits() or other optimizations, both with and without the index-finding code, three runs per variant:

WITH code to find index
Benchmark16IntersectSpheresSIMD-16    	43165770	        26.03 ns/op
Benchmark16IntersectSpheresSIMD-16    	44923888	        24.42 ns/op
Benchmark16IntersectSpheresSIMD-16    	42984493	        25.69 ns/op

WITHOUT code to find index
Benchmark16IntersectSpheresSIMD-16    	61557524	        18.19 ns/op
Benchmark16IntersectSpheresSIMD-16    	60667825	        17.90 ns/op
Benchmark16IntersectSpheresSIMD-16    	59046730	        18.08 ns/op

As seen, the three runs average out to ~25 ns per intersection test against 16 spheres including the index-finding code, and ~18 ns without it. Naturally, the intersection testing code breaks when we don’t figure out which sphere was intersected; the point is that finding the index consumes roughly 29% of the total time spent in the IntersectSpheresSIMD function.
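Spelled out from the benchmark figures above, the averages and the share of time spent on index-finding are:

```latex
\bar{t}_{\text{with}} = \frac{26.03 + 24.42 + 25.69}{3} \approx 25.38\ \text{ns}, \qquad
\bar{t}_{\text{without}} = \frac{18.19 + 17.90 + 18.08}{3} \approx 18.06\ \text{ns}

\frac{\bar{t}_{\text{with}} - \bar{t}_{\text{without}}}{\bar{t}_{\text{with}}}
  \approx \frac{7.32}{25.38} \approx 0.29
```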

3.3 Finding index using ToBits()

So, I fixed two things.

3.3.1 Store mask per iteration, skip finding index

First, I realized we don’t have to figure out the index on every iteration that produces a new lowest t. We can simply save the mask identifying the lowest t for that iteration to a variable currentMask uint8, defined just outside the for-loop.

OLD:

maskLo := currentMin.Equal(lo).ToInt32x4() // [0,0,1,0]
maskHi := currentMin.Equal(hi).ToInt32x4() // [0,0,0,0]

NEW:

hiBits := currentMin.Equal(hi).ToBits()
loBits := currentMin.Equal(lo).ToBits()
currentMask = (hiBits << 4) | loBits // currentMask defined outside for-loop
currentBatch = i // keeps track of which batch that has produced the currently lowest t
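The nibble packing can be checked in plain Go. In the sketch below, equalBits4 is a hypothetical scalar stand-in for currentMin.Equal(half).ToBits() on a 4-lane half, assuming currentMin holds the batch minimum broadcast to every lane and that lane i maps to bit i. For the batch [4,3,7,1,3,7,5,6], the minimum 1 sits in lane 3 of the low half, so currentMask becomes 0b00001000.

```go
package main

import "fmt"

// equalBits4 is a hypothetical scalar stand-in for
// currentMin.Equal(half).ToBits() on a 4-lane vector:
// bit i is set when half[i] equals the broadcast minimum.
func equalBits4(currentMin float32, half [4]float32) uint8 {
	var m uint8
	for i, v := range half {
		if v == currentMin {
			m |= 1 << i
		}
	}
	return m
}

// packMask places the high half's 4 bits above the low half's 4 bits,
// mirroring currentMask = (hiBits << 4) | loBits.
func packMask(loBits, hiBits uint8) uint8 {
	return hiBits<<4 | loBits
}

func main() {
	lo := [4]float32{4, 3, 7, 1}
	hi := [4]float32{3, 7, 5, 6}
	var currentMin float32 = 1
	currentMask := packMask(equalBits4(currentMin, lo), equalBits4(currentMin, hi))
	fmt.Printf("%08b\n", currentMask) // 00001000
}
```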

Then, once the for-loop finishes, it’s time to use some plain bitwise operations to find out which bit of currentMask is set. I implemented this in a standalone findCurrentIndex function:

func findCurrentIndex(currentMask uint8, currentBatch int) int {

	// Check if any of the low bits are set.
	if currentMask&0b1111 != 0 {

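		// Low bits are set: switch on the low nibble only, so that a tied
		// lane in the high nibble can't push us into the default case.
		switch currentMask & 0b1111 {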
		case 0b00000001:
			return currentBatch
		case 0b00000010:
			return currentBatch + 1
		case 0b00000100:
			return currentBatch + 2
		default:
			return currentBatch + 3
		}
	} else {
		// Low bits ARE NOT set, therefore, check high bits
		switch currentMask >> 4 { // Shift the high bits into the lower half and compare.
		case 0b00000001:
			return currentBatch + 4
		case 0b00000010:
			return currentBatch + 5
		case 0b00000100:
			return currentBatch + 6
		default:
			return currentBatch + 7
		}
	}
}

The code above first determines whether the set bit is in the lower or higher half, and then uses a plain switch that will find the index of the set bit in one to three attempts.

A constant-time implementation is also possible: first determine which nibble of the byte contains the 1, then which half of that nibble, and finally which of the 2 remaining bits is the 1. That variant always uses 3 if-statements, but it seems to be about on par performance-wise with the (IMHO) more readable version above, so I’m sticking with that.
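For comparison, the three-if variant described above might look like the sketch below. It is an illustrative reconstruction rather than the benchmarked code: findIndexThreeIfs is a hypothetical name, and like findCurrentIndex it assumes exactly one bit of the mask is set.

```go
package main

import "fmt"

// findIndexThreeIfs locates the single set bit in mask using exactly three
// comparisons: which nibble, which pair within it, which bit within the pair.
func findIndexThreeIfs(mask uint8, currentBatch int) int {
	idx := 0
	if mask&0x0F == 0 { // the 1 is in the high nibble
		idx += 4
		mask >>= 4
	}
	if mask&0x03 == 0 { // the 1 is in the upper pair of the nibble
		idx += 2
		mask >>= 2
	}
	if mask&0x01 == 0 { // the 1 is the upper bit of the pair
		idx++
	}
	return currentBatch + idx
}

func main() {
	for lane := 0; lane < 8; lane++ {
		mask := uint8(1) << lane
		fmt.Printf("%08b -> %d\n", mask, findIndexThreeIfs(mask, 0))
	}
}
```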

3.4 Using bits.TrailingZeros8

It turns out that instead of the switch/case bonanza of my findCurrentIndex function above, there’s a convenient function in the standard library’s math/bits package that provides the index-finding functionality of findCurrentIndex, albeit in a slightly different manner: bits.TrailingZeros8.

TrailingZeros8 accepts a single uint8 and returns the number of trailing zero bits, i.e. the number of zeros below (to the right of) the lowest set bit. E.g. given 0b01000000, TrailingZeros8 returns 6. Since bit i of the mask corresponds to lane i (bit 0 is lane 0, just as in findCurrentIndex), that count is directly the index of the set bit within the batch.

Instead of running currentIndex := findCurrentIndex(currentMask, currentBatch) after the for-loop, the following line suffices:

currentIndex := currentBatch + bits.TrailingZeros8(currentMask)

Performance-wise, there’s no consistently measurable difference between findCurrentIndex and bits.TrailingZeros8 in the context of my ray/sphere intersection function, so I’ll probably scrap findCurrentIndex and rely on the standard library instead.
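A quick way to convince oneself that bits.TrailingZeros8 yields the lane index directly is to walk every single-bit mask. In this sketch indexFromMask is a hypothetical helper, and currentBatch is assumed to be the element offset of the batch (the same assumption the currentBatch + N arithmetic in findCurrentIndex makes).

```go
package main

import (
	"fmt"
	"math/bits"
)

// indexFromMask maps a one-hot lane mask to a sphere index. Bit i of
// currentMask corresponds to lane i, so the trailing-zero count is the lane.
// Note: a zero mask makes TrailingZeros8 return 8.
func indexFromMask(currentMask uint8, currentBatch int) int {
	return currentBatch + bits.TrailingZeros8(currentMask)
}

func main() {
	currentBatch := 8 // hypothetical: the second batch of 8 spheres
	for lane := 0; lane < 8; lane++ {
		currentMask := uint8(1) << lane
		fmt.Printf("mask %08b -> index %d\n", currentMask, indexFromMask(currentMask, currentBatch))
	}
}
```

Since bits.TrailingZeros8(0) returns 8, a mask that can be zero (no intersection in any batch) needs to be handled before computing the index.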

4. Benchmark results!

So, back to the overall picture: do these fixes and simplifications improve performance compared to the ~25 ns/iteration measured over three runs earlier?

Benchmark16IntersectSpheresSIMD-16                  	65791402	        17.81 ns/op
Benchmark16IntersectSpheresSIMD-16                  	65042857	        17.85 ns/op
Benchmark16IntersectSpheresSIMD-16                  	56433920	        18.55 ns/op

The results average out to about 18 ns per iteration. This is somewhat surprising, since ~18 ns is what we got when we removed the earlier “find index of bit set in mask” code entirely. However, we also improved the overall algorithmic efficiency somewhat by only figuring out the index of the closest intersected sphere once, after all batches have been processed, using scalar bitwise operations.

4.1 Full image render

Looking at a full 640x480 render of the 16-sphere image, we also see an (albeit smaller) improvement:

Plain Go:   164.483787ms
SIMD:       38.361375ms
SIMD new:   35.160145ms

As stated before, I’ve only applied the index-finding optimization to the sphere intersection code. The code that intersects planes and that casts shadow rays hasn’t been optimized.

4.2 Scaling the number of spheres

As a little bonus, I went a bit bananas and rendered an image that draws 512 smaller spheres instead of just 16:

[Image: render with 512 spheres]

Tested in isolation, the IntersectSpheres functions (scalar and SIMD) taking 512 spheres (64 batches of 8) perform as follows:

Benchmark512IntersectSpheres-16        	  627394	      2596 ns/op
Benchmark512IntersectSpheresSIMD-16    	 5630026	       214.4 ns/op

In isolation, that’s an approximate 12x speed-up. Rendering the full image produces somewhat surprising results: the old pre-SIMD renderer needs 12.5 seconds to render the image at 1024x768, while the optimized SIMD renderer needs 250.65 milliseconds, roughly 50 times faster!
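Using the reported figures, the two speed-ups work out to:

```latex
\frac{2596\ \text{ns}}{214.4\ \text{ns}} \approx 12.1
\qquad
\frac{12.5\ \text{s}}{250.65\ \text{ms}} = \frac{12500}{250.65} \approx 49.9
```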

It seems this particular code scales better the more spheres there are to intersect, though there may be some apples-and-oranges in play here, given that some SIMD optimizations are present for plane intersections and that I turned off shadow rays, since the image looked quite crap with them enabled.

5. Final words

This was something of a “bonus episode” I wrote while on a bus trip, based on some earlier experimentation with ToBits() I did when Go 1.26 RC2 dropped a few weeks ago.

I still want to spend some time adding support for things like 3D models and SIMD-accelerated BVH trees; BVH trees with 8 children per node might make for an interesting challenge using 8 SIMD lanes. Further ideas include path tracing, soft shadows, textures, live rendering with a GUI, transparency and reflections. We’ll see.

Thanks for reading,

Until next time!

Thanks for reading Callista’s blog.