Commit ffcab467 authored by Gregor Kobsik

updated sliding window transformer

- compatible with PyTorch 1.13
parent c9a3c74e
@@ -101,36 +101,85 @@ class SlidingWindowEncoderLayer(nn.TransformerEncoderLayer):
             activation=activation,
         )
         self.window_size = window_size
-        self.self_attn = LocalAttention(
+        self.self_attn = _LocalAttention(
             dim=d_model,
             window_size=window_size,
             causal=True,
             look_backward=1,
             look_forward=0,
             dropout=dropout,
             autopad=True,
             exact_windowsize=True,
+            batch_first=False,
         )
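+        # batch_first=False matches the (seq_len, batch, embed) layout that
+        # nn.TransformerEncoderLayer feeds to self_attn by default; the
+        # _LocalAttention wrapper below transposes to the batch-first layout
+        # that LocalAttention expects.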
-    def forward(self, src, src_mask, src_key_padding_mask=None):
-        r"""Pass the input through the encoder layer.
-
-        Args:
-            src: the sequence to the encoder layer (required).
-            src_mask: the mask for the src sequence (unused).
-            src_key_padding_mask: the mask for the src keys per batch (unused).
-
-        Shape:
-            see the docs in Transformer class.
-        """
-        src = src.transpose(0, 1)
-        b, t, e = src.shape
-        remainder = (self.window_size - t % self.window_size) % self.window_size
-        src = torch.cat([src, torch.zeros(b, remainder, e, device=src.device)], dim=1)
-        src2 = self.self_attn(src, src, src, input_mask=None)
-        src = src + self.dropout1(src2)
-        src = self.norm1(src)
-        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
-        src = src + self.dropout2(src2)
-        src = self.norm2(src)
-        return src.transpose(0, 1)[:t]
+    # Superseded: with this override commented out, PyTorch 1.13's
+    # nn.TransformerEncoderLayer.forward runs instead, and the manual window
+    # padding below is handled by LocalAttention's autopad=True.
+    # def forward(self, src, src_mask, src_key_padding_mask=None):
+    #     r"""Pass the input through the encoder layer.
+    #
+    #     Args:
+    #         src: the sequence to the encoder layer (required).
+    #         src_mask: the mask for the src sequence (unused).
+    #         src_key_padding_mask: the mask for the src keys per batch (unused).
+    #
+    #     Shape:
+    #         see the docs in Transformer class.
+    #     """
+    #     src = src.transpose(0, 1)
+    #     b, t, e = src.shape
+    #     remainder = (self.window_size - t % self.window_size) % self.window_size
+    #     src = torch.cat([src, torch.zeros(b, remainder, e, device=src.device)], dim=1)
+    #     src2 = self.self_attn(src, src, src, input_mask=None)
+    #     src = src + self.dropout1(src2)
+    #     src = self.norm1(src)
+    #     src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+    #     src = src + self.dropout2(src2)
+    #     src = self.norm2(src)
+    #     return src.transpose(0, 1)[:t]
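+
+
+# Thin adapter that lets lucidrains' LocalAttention stand in for
+# nn.MultiheadAttention inside nn.TransformerEncoderLayer. PyTorch 1.13's
+# layer code probes `batch_first`, `_qkv_same_embed_dim` and `num_heads` on
+# `self.self_attn` during its fast-path checks, and calls the module with
+# MultiheadAttention-style keyword arguments, none of which a bare
+# LocalAttention provides.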
+class _LocalAttention(LocalAttention):
+    def __init__(
+        self,
+        window_size,
+        causal=False,
+        look_backward=1,
+        look_forward=None,
+        dropout=0.,
+        shared_qk=False,
+        rel_pos_emb_config=None,
+        dim=None,
+        autopad=False,
+        exact_windowsize=False,
+        batch_first=False,
+    ):
+        super().__init__(
+            window_size,
+            causal=causal,
+            look_backward=look_backward,
+            look_forward=look_forward,
+            dropout=dropout,
+            shared_qk=shared_qk,
+            rel_pos_emb_config=rel_pos_emb_config,
+            dim=dim,
+            autopad=autopad,
+            exact_windowsize=exact_windowsize,
+        )
+        self.batch_first = batch_first
+        # Attributes read by nn.TransformerEncoderLayer's fast-path checks;
+        # num_heads=1 is odd, which also keeps the sparsity fast path disabled.
+        self._qkv_same_embed_dim = True
+        self.num_heads = 1
+
+    def forward(
+        self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True
+    ):
+        # Mirror nn.MultiheadAttention's call signature; the mask and weight
+        # arguments are accepted but unused, since the sliding window itself
+        # defines the attention pattern.
+        if not self.batch_first:
+            query = query.transpose(0, 1)
+            key = key.transpose(0, 1)
+            value = value.transpose(0, 1)
+        out = super().forward(query, key, value)
+        if not self.batch_first:
+            out = out.transpose(0, 1)
+        # Return an (output, weights) pair like nn.MultiheadAttention does:
+        # TransformerEncoderLayer._sa_block takes element [0] of the result.
+        return out, None
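
For reference, a minimal smoke test of the updated layer might look like the sketch below. The full SlidingWindowEncoderLayer.__init__ is outside this hunk, so the constructor arguments shown (d_model, nhead, window_size) and their values are assumptions, as is an installed local-attention package providing LocalAttention.

import torch

# Hypothetical arguments; the real __init__ signature is not shown in this hunk.
layer = SlidingWindowEncoderLayer(d_model=64, nhead=1, window_size=8)
layer.eval()  # eval mode exercises the 1.13 fast-path checks that probe self_attn

src = torch.randn(100, 2, 64)  # (seq_len, batch, d_model), since batch_first=False
with torch.no_grad():
    out = layer(src)  # nn.TransformerEncoderLayer.forward now drives _LocalAttention

assert out.shape == src.shape  # autopad=True pads to a window multiple internally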