Commit 15a2ff10 authored by Moritz Ibing

Quick fix for the positional encoding in the linear head, as it requires a different dimension

parent dc1565de
@@ -26,7 +26,11 @@ class LinearHead(nn.Module):
         else:
             self.linear = Linear(embed_dim, num_vocab)
-        self.spatial_encoding = spatial_encoding
+        # quick fix, as this is the only place where the output should be embed_dim instead of head_dim
+        self.spatial_encoding = nn.Sequential(
+            spatial_encoding,
+            nn.Linear(head_dim, embed_dim)
+        )

     def forward(self, x, value, depth, pos):
         """ Transforms the output of the transformer into target value logits.
...
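
For context, here is a minimal runnable sketch of what the fix does, assuming PyTorch and a positional-encoding module that emits head_dim-sized vectors. The names PositionalEncodingSketch and LinearHeadSketch are hypothetical stand-ins, and the forward signature is simplified from the real forward(self, x, value, depth, pos); only the nn.Sequential wrapping mirrors the actual change.

    import torch
    import torch.nn as nn


    class PositionalEncodingSketch(nn.Module):
        """Hypothetical stand-in: embeds integer positions into head_dim vectors."""

        def __init__(self, head_dim: int, max_pos: int = 1024):
            super().__init__()
            self.embedding = nn.Embedding(max_pos, head_dim)

        def forward(self, pos: torch.Tensor) -> torch.Tensor:
            return self.embedding(pos)


    class LinearHeadSketch(nn.Module):
        def __init__(self, num_vocab: int, embed_dim: int, head_dim: int,
                     spatial_encoding: nn.Module):
            super().__init__()
            self.linear = nn.Linear(embed_dim, num_vocab)
            # The quick fix: the encoding produces head_dim features, but this
            # head adds them to embed_dim-sized transformer outputs, so an
            # extra learned projection maps head_dim -> embed_dim.
            self.spatial_encoding = nn.Sequential(
                spatial_encoding,
                nn.Linear(head_dim, embed_dim),
            )

        def forward(self, x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
            # Add the projected positional encoding to the transformer output,
            # then map to vocabulary logits.
            x = x + self.spatial_encoding(pos)
            return self.linear(x)


    # Usage: batch of 2 sequences, length 8, embed_dim 256, head_dim 32.
    enc = PositionalEncodingSketch(head_dim=32)
    head = LinearHeadSketch(num_vocab=100, embed_dim=256, head_dim=32,
                            spatial_encoding=enc)
    logits = head(torch.randn(2, 8, 256), torch.arange(8).expand(2, 8))
    print(logits.shape)  # torch.Size([2, 8, 100])

Wrapping the encoding in nn.Sequential keeps every call site unchanged: self.spatial_encoding(pos) still works as before but now returns embed_dim-sized vectors, at the cost of one extra learned projection that exists only in this head.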