Fixing Llama 2

So Llama 2 was just released, and it looks really good. Separately, I just got a new M2 Mac to replace my Intel Mac. I'm keen to try Llama 2 on it, but it doesn't seem to work! There are quite a few issues, summed up in this repository. The repository owner fixed a few of the bugs, but in its current state the chat and completion examples still yield nothing but question marks.

I want to see if I can fix this. I already have some experience dealing with crappy ML frameworks from when I ported Stable Diffusion to iOS ahead of the official port (see footnote 1). First, the "size too large" console output for the ANE transpose operation seems like a prime target.

error: 'anec.transpose' op Height/Width dimensions must be less than 16384

However, looking at the inputs and outputs of the ColumnParallelLinear layer that triggers this operation, we find that the input was already corrupted by NaNs before this step. Lacking any other logs to go on, we plunge down the rabbit hole of going back up, function by function, to find the first place where NaNs were introduced.
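The function-by-function search can be partly automated with forward hooks. The following is a minimal sketch, not something from the repository: `find_first_nan_module` is a hypothetical helper that runs one forward pass and reports the first submodule whose output contains NaN. It only sees module outputs, so a NaN created inline inside a `forward` method (as turns out to be the case here) would show up at the first submodule that consumes it, which still narrows the search considerably.

```python
import torch
import torch.nn as nn


def find_first_nan_module(model: nn.Module, *inputs):
    """Run one forward pass and return the name of the first submodule
    whose output contains NaN, or None if no output does.

    Hypothetical debugging helper, not part of the Llama code base.
    """
    first_bad = []
    handles = []

    def make_hook(name):
        def hook(module, args, output):
            # Only record the earliest offender; later NaNs are just fallout.
            if first_bad:
                return
            outs = output if isinstance(output, (tuple, list)) else (output,)
            for out in outs:
                if torch.is_tensor(out) and torch.isnan(out).any():
                    first_bad.append(name)
                    return
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        # Always detach the hooks so the model is left unchanged.
        for handle in handles:
            handle.remove()
    return first_bad[0] if first_bad else None
```

Hooks fire as each submodule finishes, so inner modules report before the containers that wrap them, and the name returned is the innermost, earliest source the hooks can observe.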

This brings us to the forward method in the transformer module.

# ...
mask = None
if seqlen > 1:
    mask = torch.full(
        (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
    )
    mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

for layer in self.layers:
    h = layer(h, start_pos, freqs_cis, mask)
# ...

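A quick glance at the triu documentation reveals that it returns the upper triangular part of the input matrix, with the rest filled with zeros. Since the input here is filled with negative infinity, the mask should hold -inf above the diagonal and zeros everywhere else. It is then added to the attention scores inside each layer, so that every position can only attend to itself and earlier positions.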
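To make the expected result concrete, here is a minimal sketch that builds the same mask on the CPU, where triu behaves as documented. The values `seqlen = 4` and `start_pos = 0` are picked purely for illustration.

```python
import torch

seqlen, start_pos = 4, 0  # illustrative values, not taken from the model

# Same construction as the model code, but pinned to the CPU.
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device="cpu")
mask = torch.triu(mask, diagonal=start_pos + 1)

print(mask[0, 0])

# On and below the diagonal the mask is 0.0 (visible positions are unchanged);
# strictly above the diagonal it is -inf (future positions are masked out).
visible = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))
assert not torch.isnan(mask).any()
assert (mask[0, 0][visible] == 0).all()
assert torch.isinf(mask[0, 0][~visible]).all()
```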

Unfortunately, a glance at our debugger shows that's not the case! The bottom triangle is, in fact, filled with NaNs. Some googling turns up a GitHub issue about this behavior, which has been designated 'low priority', leaving poor souls like us to deal with it ourselves.

I don't really know PyTorch that well, so my solution will be "don't do that": just take all the NaNs and replace them with zeros.

# ...
mask = None
if seqlen > 1:
    mask = torch.full(
        (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
    )
    mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
    # Workaround: triu leaves NaN instead of 0 below the diagonal here, so zero those entries out.
    mask = mask.masked_fill(torch.isnan(mask), 0.0)

for layer in self.layers:
    h = layer(h, start_pos, freqs_cis, mask)
# ...

And this works! The test output is now correct, and the completion example works as well.
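Another option, instead of cleaning up NaNs after the fact, would be to never pass infinities through triu at all. The sketch below is an assumption on my part rather than what the repository does: `causal_mask` is a hypothetical helper that derives the blocked positions from index comparisons and only then fills in -inf. I have only checked it against the CPU result, not on MPS.

```python
import torch


def causal_mask(seqlen: int, start_pos: int, device="cpu", dtype=torch.float32):
    """Additive attention mask: 0.0 where attention is allowed, -inf elsewhere.

    Hypothetical replacement for the torch.full + torch.triu construction.
    """
    rows = torch.arange(seqlen, device=device).unsqueeze(1)  # shape (seqlen, 1)
    cols = torch.arange(seqlen, device=device).unsqueeze(0)  # shape (1, seqlen)
    # triu(diagonal=k) keeps entries with col - row >= k; here k = start_pos + 1.
    blocked = cols > rows + start_pos
    mask = torch.zeros((1, 1, seqlen, seqlen), device=device, dtype=dtype)
    return mask.masked_fill(blocked, float("-inf"))
```

In the forward method this would be called as `causal_mask(seqlen, start_pos, device=tokens.device, dtype=h.dtype)`, which also makes the `type_as(h)` conversion unnecessary.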

Footnotes

  1. The short project that inspired the creation of this blog, even if it didn't end up making it onto the blog. Perhaps one day it will!