Commit fa8bbbe1 authored by Moritz Ibing

additional changes for latent filtering

parent 4db61e52
import torch
class CheckSequenceLenghtTransform():
    # TODO: make these maps actual properties of the embedding class or decouple them from this module
@@ -66,7 +69,24 @@ class CheckSequenceLenghtTransform():
            conv_fac = self.convolution_factor[i]
            dep_level = i + 1 - sub_diff
            sum_sequence_length += len(dep[dep == dep_level]) // conv_fac
            if sub_diff == 0:
                # no substitution: one vector per block of conv_fac tokens at this depth level
                num_vectors = torch.sum(torch.from_numpy(dep) == dep_level) // conv_fac
            elif sub_diff == 1:
                # single substitution: count blocks of conv_fac tokens one level up that contain a mixed token (value 2)
                val_1 = torch.from_numpy(val)[torch.from_numpy(dep) == (dep_level - 1)]
                num_vectors = (val_1.view(-1, conv_fac) == 2).max(dim=-1)[0].sum()
            elif sub_diff == 2:
                # double substitution: flag blocks of 8 tokens one level up that contain a mixed token,
                # scatter the flags onto the mixed tokens two levels up, then count blocks of conv_fac tokens with a flag
                val_1 = torch.from_numpy(val)[torch.from_numpy(dep) == (dep_level - 1)]
                val_2 = torch.from_numpy(val)[torch.from_numpy(dep) == (dep_level - 2)]
                mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0]
                mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
                mask_2[val_2 == 2] = mask_1
                mask_2 = mask_2.view(-1, conv_fac).max(dim=-1)[0]
                num_vectors = mask_2.sum()
            else:
                print("ERROR: substitution factors bigger than 2 are not implemented")
                return None
            sum_sequence_length += num_vectors
            if sum_sequence_length > self.num_positions:
                return None
......
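A minimal standalone sketch (not part of the commit) of the sub_diff == 2 counting above, using made-up toy tensors and a hypothetical conv_fac = 4; .any(dim=-1) stands in for the diff's .max(dim=-1)[0] reduction over boolean values:

import torch

conv_fac = 4                                     # hypothetical convolution factor
val_2 = torch.tensor([1, 2, 1, 1, 2, 1, 1, 1])   # toy grandparent level: two mixed tokens (value 2)
val_1 = torch.tensor([2, 1, 1, 1, 1, 1, 1, 1,    # 8 children of the first mixed token
                      1, 1, 1, 1, 1, 1, 1, 1])   # 8 children of the second mixed token

# flag each block of 8 child tokens that itself contains a mixed token
mask_1 = (val_1.view(-1, 8) == 2).any(dim=-1)
# scatter those flags back onto the grandparent level; only its mixed tokens have children
mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
mask_2[val_2 == 2] = mask_1
# a grandparent block of conv_fac tokens contributes one vector if any of its tokens is flagged
num_vectors = mask_2.view(-1, conv_fac).any(dim=-1).sum()
print(num_vectors.item())                        # -> 1 (only the first block survives)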
@@ -97,9 +97,6 @@ class DoubleSubstitutionEmbedding(nn.Module):
            dep_0[i, :len_0[i]] = depth[i, len_2[i] + len_1[i]:len_2[i] + len_1[i] + len_0[i]]
            pos_0[i, :len_0[i]] = position[i, len_2[i] + len_1[i]:len_2[i] + len_1[i] + len_0[i]]
        # precompute padding mask
        self.mask = padding_mask(val_2[:, ::self.conv_size], device=value.device)  # [N, S'_2, E]
        # convolve embedded tokens of the last layer
        y_0 = self.convolution_0(x_0)  # [N, S'_0, E // 4]
        # substitute all mixed token embeddings of the second-last layer with token embeddings of the last layer
@@ -121,8 +118,13 @@ class DoubleSubstitutionEmbedding(nn.Module):
        len_out = torch.max(torch.sum(mask_2, dim=-1)).item()
        x_masked = torch.zeros(batch_size, len_out, embedding.shape[2], dtype=torch.float, device=value.device)
        val_masked = torch.zeros((batch_size, len_out), dtype=torch.long, device=value.device)
        for i in range(batch_size):
            idx = mask_2[i].nonzero().squeeze(-1)  # positions of sample i that survive the filtering
            x_masked[i, :len(idx)] = x_out[i, idx]
            val_masked[i, :len(idx)] = val_2[:, ::self.conv_size][i, idx]
        # precompute padding mask
        self.mask = padding_mask(val_masked, device=value.device)  # [N, S'_2, E]
        return x_masked
......
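For reference, a toy sketch (not from the repository) of the per-sample gather pattern in the hunk above: each sample keeps only the positions selected by its boolean mask, left-aligned in a zero-padded tensor sized to the longest surviving sequence in the batch. x_out and mask_2 are fabricated stand-ins for the diff's variables:

import torch

batch_size, seq_len, emb_dim = 2, 5, 3
x_out = torch.randn(batch_size, seq_len, emb_dim)            # toy embeddings
mask_2 = torch.tensor([[1, 0, 1, 0, 0],
                       [1, 1, 1, 0, 1]], dtype=torch.bool)   # which positions survive per sample

len_out = torch.max(torch.sum(mask_2, dim=-1)).item()        # longest filtered sequence, here 4
x_masked = torch.zeros(batch_size, len_out, emb_dim)
for i in range(batch_size):
    idx = mask_2[i].nonzero().squeeze(-1)                    # surviving positions of sample i
    x_masked[i, :len(idx)] = x_out[i, idx]                   # left-align; the tail stays zero padding

print(x_masked.shape)                                        # torch.Size([2, 4, 3])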