How to tile matrix multiplication (2023) (alvinwan.com)
59 points by pbd 4 days ago | 4 comments
epistasis 37 minutes ago [-]
When thinking about block matrix multiplication, it's always fun to revisit Strassen's algorithm, which runs in roughly O(n^2.807) instead of the O(n^3) of the standard algorithm.

Normal block multiplication works like:

    [ A11  A12 ] [ B11  B12 ] = [ A11*B11 + A12*B21  A11*B12 + A12*B22 ] = [ C11  C12 ]
    [ A21  A22 ] [ B21  B22 ]   [ A21*B11 + A22*B21  A21*B12 + A22*B22 ]   [ C21  C22 ]
This takes 8 matrix multiplications on the sub-blocks. But by cleverly defining only 7 matrix products of block sums and differences, like:

    M3 = A11 * (B12 - B22)
you can build all four C blocks from just additions and subtractions of those 7 products.
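
For anyone who wants to see all seven products at once, here is a minimal sketch of a single level of Strassen on a 2x2 block split, using the product definitions from the Wikipedia article linked below. It is pure Python on lists of lists and assumes square matrices of even size; the helper names (`add`, `sub`, `matmul`, `split`, `strassen_one_level`) are made up for illustration, and the sub-products use the plain triple loop rather than recursing.

```python
# One level of Strassen on a 2x2 block split (illustrative sketch, not tuned).
# Assumes square matrices with even side length, stored as lists of lists.

def add(X, Y):
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

def sub(X, Y):
    return [[x - y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

def matmul(X, Y):
    # Plain O(n^3) product, used here for the 7 sub-block multiplications.
    cols = list(zip(*Y))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in X]

def split(M):
    # Return the four quadrants M11, M12, M21, M22.
    h = len(M) // 2
    return ([r[:h] for r in M[:h]], [r[h:] for r in M[:h]],
            [r[:h] for r in M[h:]], [r[h:] for r in M[h:]])

def strassen_one_level(A, B):
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)

    # The 7 block products (instead of 8).
    M1 = matmul(add(A11, A22), add(B11, B22))
    M2 = matmul(add(A21, A22), B11)
    M3 = matmul(A11, sub(B12, B22))
    M4 = matmul(A22, sub(B21, B11))
    M5 = matmul(add(A11, A12), B22)
    M6 = matmul(sub(A21, A11), add(B11, B12))
    M7 = matmul(sub(A12, A22), add(B21, B22))

    # Recombine using only block additions and subtractions.
    C11 = add(sub(add(M1, M4), M5), M7)
    C12 = add(M3, M5)
    C21 = add(M2, M4)
    C22 = add(add(sub(M1, M2), M3), M6)

    top = [left + right for left, right in zip(C11, C12)]
    bottom = [left + right for left, right in zip(C21, C22)]
    return top + bottom
```

With integer inputs the result matches `matmul(A, B)` exactly; with floats the extra additions and subtractions are where rounding differences can creep in.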

https://en.wikipedia.org/wiki/Strassen_algorithm

As far as I know this is not useful in the major GPU libraries for saving bandwidth, but I have never bothered to spend the time to figure out why. It must have something to do with the ratio of bandwidth to FLOPs, which is way past my knowledge of GPUs.

adgjlsfhk1 25 minutes ago [-]
The tricky parts with Strassen are that it requires some fairly large changes to your looping strategy, and that it decreases accuracy. It also only helps once you are compute-bound rather than bandwidth-bound, and GPUs have lots of compute.
GolDDranks 3 hours ago [-]
There is something off with the explanation.

At first, there are 16 fetches per row x column, 1024 in total. Then, it is observed that an input row needs to be fetched only once per output row, reducing the count to 8 fetches per row, plus 8 per row x column, 8 * 8 + 8 * 64 = 576 in total. This requires the same 16 numbers to be kept in registers.
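
To make the first two counts easy to check, here is a small sketch that tallies fetches with explicit loops. It assumes 8x8 matrices and counts one fetch per input number loaded; the variable names are only for illustration.

```python
N = 8  # 8x8 matrices, as in the counts above

# Naive: every output element fetches a full row of A and a full column of B.
naive = 0
for i in range(N):
    for j in range(N):
        naive += N + N  # 8 row values + 8 column values

# Row reuse: fetch the row of A once per output row, keep it in registers,
# then fetch one column of B for each output element in that row.
row_reuse = 0
for i in range(N):
    row_reuse += N      # row of A, fetched once
    for j in range(N):
        row_reuse += N  # column of B, fetched per output element

print(naive, row_reuse)  # 1024 576
```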

But then it is claimed that by doing one quadrant at a time, all that is needed is 64 fetches per quadrant or 256 fetches in total. But that assumes we can keep 4 rows and 4 columns, 8 numbers per row or column = 64 numbers in registers! If we can only keep 16 numbers like above, each row of the quadrant is going to take 40 fetches, and we get 160 fetches per quadrant or 640 fetches in total, a pessimization from 576 fetches!
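
Here is the same kind of tally for the quadrant scheme, as a rough sketch under the assumptions stated above: 8x8 matrices, 4x4 quadrants, and either 64 or 16 numbers held in registers. The function name `quadrant_fetches` is invented for this example.

```python
N = 8       # matrix side
Q = N // 2  # quadrant side (4)

def quadrant_fetches(registers):
    """Fetches needed to compute one 4x4 output quadrant."""
    if registers >= 2 * Q * N:
        # 64 registers: 4 full rows of A and 4 full columns of B stay resident.
        return Q * N + Q * N
    # 16 registers: hold one row of A, stream the 4 columns of B past it,
    # and repeat for each of the 4 rows in the quadrant.
    per_row = N + Q * N  # 8 + 32 = 40
    return Q * per_row

total_64 = 4 * quadrant_fetches(64)
total_16 = 4 * quadrant_fetches(16)

print(quadrant_fetches(64), total_64)  # 64 256
print(quadrant_fetches(16), total_16)  # 160 640
```

So the 256 figure only holds with 64 registers; with 16, the quadrant order costs 640 fetches against 576 for plain row reuse.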

slwvx 7 hours ago [-]
See https://en.wikipedia.org/wiki/Block_matrix#Multiplication