Commit 1ab9c9c9 authored by Moritz Ibing

Added autoregressive heads and sampling

parent 8f14f930
Showing with 631 additions and 24 deletions
@@ -2,10 +2,10 @@ import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
+from .convolution_head_A import ConvolutionHeadA, ConvolutionHeadAutoregressive
+from .double_substitution_head import DoubleSubstitutionHead, DoubleSubstitutionHeadAutoRegressive
from .linear_head import LinearHead
-from .convolution_head_A import ConvolutionHeadA
-from .substitution_head import SubstitutionHead
-from .double_substitution_head import DoubleSubstitutionHead
+from .substitution_head import SubstitutionHead, SubstitutionHeadAutoregressive
class CompositeHeadA(nn.Module):
@@ -151,3 +151,58 @@ class CompositeHeadA(nn.Module):
# pad embedding sequence
return pad_sequence(out, batch_first=True, padding_value=0.0)
class CompositeHeadAutoregressiveA(CompositeHeadA):
def __init__(self, spatial_encoding, num_vocab, embed_dim, resolution, spatial_dim, **_):
""" Performs a transformation from transformer latent space into target value logits.
Uses a different heads for each depth layer, possibly increasing the overall sequence lenght.
Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
Args:
num_vocab: Number of different target token values (exclusive padding token '0').
embded_dim: Dimension of the latent embedding space of the transformer.
resolution: Spatial resolution of sequence encoding.
spatial_dim: Spatial dimension (2D/3D) of the sequence data.
"""
        # initialize nn.Module directly; all submodules of the parent are rebuilt below
        nn.Module.__init__(self)
kwargs = {
"spatial_encoding": spatial_encoding,
"num_vocab": num_vocab,
"embed_dim": embed_dim,
"spatial_dim": spatial_dim,
}
modules = []
if resolution >= 2:
modules += [LinearHead(**kwargs)]
if resolution >= 4:
modules += [LinearHead(**kwargs)]
if resolution >= 8:
modules += [LinearHead(**kwargs)]
if resolution >= 16:
modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=2 ** (spatial_dim - 1))]
if resolution >= 32:
modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=2 ** spatial_dim)]
if resolution >= 64:
modules += [SubstitutionHeadAutoregressive(**kwargs, conv_size=2 ** spatial_dim)]
if resolution >= 128:
modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=2 ** spatial_dim)]
if resolution >= 256:
modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=2 ** spatial_dim)]
        # per-depth output heads
self.heads = nn.ModuleList(modules)
self.reduction_factor = {
1: 1,
2: 1,
3: 1,
4: 2 ** (spatial_dim - 1),
5: 2 ** spatial_dim,
6: 2 ** spatial_dim, # Note: 'substitution'
7: 2 ** spatial_dim, # Note: 'double_substitution'
8: 2 ** spatial_dim, # Note: 'double_substitution'
}
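For orientation, the head stack this constructor yields for one concrete configuration (an illustrative sketch; the mapping follows the resolution checks above, parameter values are assumptions):

# For an octree (spatial_dim=3) at resolution 32, self.heads contains, per depth layer:
#   depth 1-3: LinearHead                                  (1 latent -> 1 token)
#   depth 4:   ConvolutionHeadAutoregressive(conv_size=4)  (1 latent -> 4 tokens)
#   depth 5:   ConvolutionHeadAutoregressive(conv_size=8)  (1 latent -> 8 tokens)
# matching reduction_factor = {1: 1, 2: 1, 3: 1, 4: 4, 5: 8, ...}.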
import torch.nn as nn
-from ..utils import Deconvolution, Linear
+from ..utils import Deconvolution, BlockConvolution, Linear
class ConvolutionHeadA(nn.Module):
@@ -42,3 +42,51 @@ class ConvolutionHeadA(nn.Module):
# compute logits for each token
return self.linear(x) # [N, T, V]
class ConvolutionHeadAutoregressive(nn.Module):
def __init__(self, spatial_encoding, num_vocab, embed_dim, spatial_dim, conv_size, **_):
""" Performs a convolutional transformation from transformer latent space into target value logits.
Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
Args:
num_vocab: Number of different target token values (exclusive padding token '0').
embded_dim: Dimension of the latent embedding space of the transformer.
spatial_dim: Spatial dimension (2D/3D) of the sequence data.
conv_size: Convolution kernel size and stride.
"""
super(ConvolutionHeadAutoregressive, self).__init__()
self.conv_size = conv_size
self.deconvolution = Deconvolution(embed_dim, embed_dim, conv_size)
self.convolution = BlockConvolution(embed_dim, embed_dim, conv_size)
self.linear = Linear(embed_dim, num_vocab)
self.spatial_encoding = spatial_encoding
self.value_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
def forward(self, x, value, depth, pos):
""" Transforms the output of the transformer target value logits.
Args:
x: Output of the transformer, the latent vector [N, T', E].
value: Target value token sequence [N, T].
depth: Target depth token sequence [N, T].
pos: Target position token sequence [N, T, A].
Return
Logits of target value sequence.
"""
# deconvolute the latent space - create new tokens
x = self.deconvolution(x) # [N, T, E]
        # embed the ground-truth target value tokens (teacher forcing)
        emb = self.value_embedding(value)
        # add spatial encoding if available
        if self.spatial_encoding is not None:
            emb = emb + self.spatial_encoding(pos)
        # blockwise causal convolution: position j only aggregates target tokens 0..j-1 of its block
        emb = self.convolution(emb)
        x = x + emb
# compute logits for each token
return self.linear(x) # [N, T, V]
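A minimal shape check for this head (a sketch, not part of the commit; it assumes Deconvolution upsamples the sequence length by conv_size, and all sizes are illustrative):

import torch
# vocabulary of 3 values (1, 2, 3); token 0 is the padding index
head = ConvolutionHeadAutoregressive(
    spatial_encoding=None, num_vocab=3, embed_dim=8, spatial_dim=2, conv_size=4,
)
x = torch.randn(1, 5, 8)                 # [N, T'=5, E=8] transformer latents
value = torch.randint(1, 4, (1, 20))     # [N, T=20] target tokens, T = conv_size * T'
logits = head(x, value, depth=None, pos=None)  # pos is unused when spatial_encoding is None
print(logits.shape)                      # expected: torch.Size([1, 20, 3])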
import torch
import torch.nn as nn
-from ..utils import Deconvolution, Linear
+from ..utils import Deconvolution, BlockConvolution, Linear
class DoubleSubstitutionHead(nn.Module):
@@ -92,3 +92,120 @@ class DoubleSubstitutionHead(nn.Module):
# compute logits of generated tokens
return self.linear(y_0) # [N, T, V]
class DoubleSubstitutionHeadAutoRegressive(nn.Module):
def __init__(self, spatial_encoding, num_vocab, embed_dim, spatial_dim, conv_size, **_):
""" Performs a twice a substitution transformation from transformer latent space into target value logits.
Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
Args:
num_vocab: Number of different target token values (exclusive padding token '0').
embded_dim: Dimension of the latent embedding space of the transformer.
spatial_dim: Spatial dimension (2D/3D) of the sequence data.
conv_size: Convolution kernel size and stride.
"""
        super(DoubleSubstitutionHeadAutoRegressive, self).__init__()
        self.embed_dim = embed_dim
        self.conv_size = conv_size
        # deconvolutions
        self.deconvolution_2 = Deconvolution(embed_dim, embed_dim, conv_size)
        self.deconvolution_1 = Deconvolution(embed_dim, embed_dim, conv_size)
        self.deconvolution_0 = Deconvolution(embed_dim, embed_dim, conv_size)
        # blockwise causal convolutions
        self.convolution_2 = BlockConvolution(embed_dim, embed_dim, conv_size)
        self.convolution_1 = BlockConvolution(embed_dim, embed_dim, conv_size)
        self.convolution_0 = BlockConvolution(embed_dim, embed_dim, conv_size)
        self.spatial_encoding = spatial_encoding
        self.value_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
        # head
        self.linear = Linear(embed_dim, num_vocab)
def forward(self, x, value, depth, pos):
""" Transforms the output of the transformer target value logits.
Transforms one token of the latent vector into multiple tokens of the target vector through de-convolutional
operations. In the case of a quadtree one token is responsible for up to 16 target tokens. In the case of a
octree one token is responsible for up to 64 target tokens. Only tokens, which correspond to a mixed target
value token in the penultimate layer are transformed into target sequence tokens.
Args:
x: Output of the transformer, the latent vector [N, T'', E].
value: Value token sequence, with penultimate and last layer.
depth: Depth token sequence, with penultimate and last layer.
pos: Position token sequence, with penultimate and last layer.
Return
Logits of target value sequence.
"""
batch_size = value.shape[0]
max_depth = torch.max(depth)
len_0 = torch.sum(depth == (max_depth), dim=1)
len_1 = torch.sum(depth == (max_depth - 1), dim=1)
len_2 = torch.sum(depth == (max_depth - 2), dim=1)
# create intermediate list to hold values
val_1 = torch.zeros((batch_size, torch.max(len_1)), device=value.device)
val_2 = torch.zeros((batch_size, torch.max(len_2)), device=value.device)
        # split the input into the third-last (2) and second-last (1) layers
        for i in range(batch_size):
            val_2[i, :len_2[i]] = value[i, :len_2[i]]
            val_1[i, :len_1[i]] = value[i, len_2[i]:len_2[i] + len_1[i]]
        # count mixed ('2') tokens per sample in each layer
        mix_1 = torch.sum(val_1 == 2, dim=1)
        mix_2 = torch.sum(val_2 == 2, dim=1)
assert ((depth[:, -len_0:] == max_depth).all())
emb_0 = self.value_embedding(value[:, -len_0:])
        # add spatial encoding if available
if self.spatial_encoding is not None:
emb_0 = emb_0 + self.spatial_encoding(pos[:, -len_0:])
emb_0 = self.convolution_0(emb_0)
emb_1 = torch.zeros((batch_size, torch.max(len_1), self.embed_dim), dtype=torch.float, device=value.device)
        # substitute each mixed token embedding of the penultimate layer with the aggregated embedding of its children in the last layer
emb_1[val_1 == 2] = emb_0[:, (self.conv_size - 1)::self.conv_size] # [N, T1, C]
emb_1 = self.convolution_1(emb_1)
emb_2 = torch.zeros((batch_size, torch.max(len_2), self.embed_dim), dtype=torch.float, device=value.device)
        # substitute each mixed token embedding of the third-last layer with the aggregated embedding of its children in the penultimate layer
emb_2[val_2 == 2] = emb_1[:, (self.conv_size - 1)::self.conv_size] # [N, T1, C]
emb_2 = self.convolution_2(emb_2)
# create intermediate list to hold vectors
x_0 = torch.zeros((batch_size, torch.max(mix_1), self.embed_dim), device=value.device)
x_1 = torch.zeros((batch_size, torch.max(mix_2), self.embed_dim), device=value.device)
        # deconvolute the latent space - sequence length equals the number of tokens in the third-last layer
y_2 = self.deconvolution_2(x)
y_2 = y_2 + emb_2[:, :y_2.shape[1]]
        # select only latent vectors that correspond to mixed tokens in the third-last layer
for i in range(batch_size):
mix_2_mask_i = (val_2[i] == 2)[:len(y_2[i])] # handle overflow/clipped values in the embedding
x_1[i, :torch.sum(mix_2_mask_i)] = y_2[i, mix_2_mask_i] # [N, T', C]
# deconvolute the latent space - sequence length equals number of tokens in the penultimate layer
y_1 = self.deconvolution_1(x_1)
y_1 = y_1 + emb_1[:, :y_1.shape[1]]
        # select only latent vectors that correspond to mixed tokens in the penultimate layer
for i in range(batch_size):
mix_1_mask_i = (val_1[i] == 2)[:len(y_1[i])] # handle overflow/clipped values in the embedding
x_0[i, :torch.sum(mix_1_mask_i)] = y_1[i, mix_1_mask_i] # [N, T', C]
# deconvolute the intermediate latent space - create new tokens in latent space for each mixed token
y_0 = self.deconvolution_0(x_0) # [N, T, C]
y_0 = y_0 + emb_0[:, :y_0.shape[1]] # [N, T, C]
        # add spatial encoding if available
if self.spatial_encoding is not None:
len_last = torch.sum(depth == max_depth, dim=1)
assert((depth[:, -len_last:] == max_depth).all())
y_0 = y_0 + self.spatial_encoding(pos[:, -len_last:])
# compute logits of generated tokens
return self.linear(y_0) # [N, T, V]
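A toy trace of the double substitution (illustrative only, assuming conv_size=2 and batch size 1; real octree heads use conv_size=8):

# depth:  [6, 6, 7, 7, 8, 8, 8, 8]   third-last, penultimate, last layer
# value:  [2, 1, 2, 2, a, b, c, d]   '2' marks mixed tokens
# val_2 = [2, 1]   -> one mixed token in the third-last layer
# val_1 = [2, 2]   -> its two children in the penultimate layer, both mixed
# emb_0: causal-conv embeddings of the last-layer tokens [a, b, c, d]
# emb_1[val_1 == 2] = emb_0[:, 1::2]  last aggregated embedding of each block
# emb_2[val_2 == 2] = emb_1[:, 1::2]
# y_2 = deconv_2(x) + emb_2           two latents, one per third-last token
# x_1 = y_2[val_2 == 2]               keep the single mixed latent
# y_1 = deconv_1(x_1) + emb_1         two latents, one per penultimate token
# x_0 = y_1[val_1 == 2]               keep the two mixed latents
# y_0 = deconv_0(x_0) + emb_0         four latents -> logits for [a, b, c, d]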
@@ -2,7 +2,7 @@ import torch.nn as nn
from modules.utils import PositionalEncodingLearned, PositionalEncodingLearnedLookAhead, \
PositionalEncodingLearnedLookAheadSplit
-from .composite_head_A import CompositeHeadA
+from .composite_head_A import CompositeHeadA, CompositeHeadAutoregressiveA
from .composite_head_B import CompositeHeadB
from .composite_head_C import CompositeHeadC
from .convolution_head_A import ConvolutionHeadA
@@ -67,6 +67,8 @@ def _create_head(name, positional_encoding, num_vocab, embed_dim, resolution, sp
return DoubleSubstitutionHead(**kwargs)
elif name in ('composite', 'composite_A'):
return CompositeHeadA(**kwargs)
+    elif name in ('composite_autoregressive_A',):
+        return CompositeHeadAutoregressiveA(**kwargs)
elif name in ('composite_B'):
return CompositeHeadB(**kwargs)
elif name in ('composite_C'):
import torch
import torch.nn as nn
-from ..utils import Deconvolution, Linear
+from ..utils import BlockConvolution, Deconvolution, Linear
class SubstitutionHead(nn.Module):
@@ -76,3 +76,97 @@ class SubstitutionHead(nn.Module):
# compute logits of generated tokens
return self.linear(y_0) # [N, T, V]
class SubstitutionHeadAutoregressive(nn.Module):
def __init__(self, spatial_encoding, num_vocab, embed_dim, spatial_dim, conv_size, **_):
""" Performs a substitution transformation from transformer latent space into target value logits.
Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
Args:
num_vocab: Number of different target token values (exclusive padding token '0').
embded_dim: Dimension of the latent embedding space of the transformer.
spatial_dim: Spatial dimension (2D/3D) of the sequence data.
conv_size: Convolution kernel size and stride.
"""
super(SubstitutionHeadAutoregressive, self).__init__()
self.embed_dim = embed_dim
self.conv_size = conv_size
self.deconvolution_1 = Deconvolution(embed_dim, embed_dim, conv_size)
self.deconvolution_0 = Deconvolution(embed_dim, embed_dim, conv_size)
self.convolution_1 = BlockConvolution(embed_dim, embed_dim, conv_size)
self.convolution_0 = BlockConvolution(embed_dim, embed_dim, conv_size)
self.linear = Linear(embed_dim, num_vocab)
self.spatial_encoding = spatial_encoding
self.value_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
def forward(self, x, value, depth, pos):
""" Transforms the output of the transformer target value logits.
Transforms one token of the latent vector into multiple tokens of the target vector through de-convolutional
operations. In the case of a quadtree one token is responsible for up to 16 target tokens. In the case of a
octree one token is responsible for up to 64 target tokens. Only tokens, which correspond to a mixed target
value token in the penultimate layer are transformed into target sequence tokens.
Args:
x: Output of the transformer, the latent vector [N, T'', E].
value: Value token sequence, with penultimate and last layer.
depth: Depth token sequence, with penultimate and last layer.
pos: Position token sequence, with penultimate and last layer.
Return
Logits of target value sequence.
"""
batch_size = value.shape[0]
max_depth = torch.max(depth)
len_0 = torch.sum(depth == max_depth, dim=1)
len_1 = torch.sum(depth == (max_depth - 1), dim=1)
# create intermediate list to hold values
val_1 = torch.zeros((batch_size, torch.max(len_1)), device=value.device)
        # split off the second-last (1) layer from the front of the sequence
for i in range(batch_size):
val_1[i, :len_1[i]] = value[i, :len_1[i]]
# compute the number of mixed tokens in mask
mix_1 = torch.sum(val_1 == 2, dim=1)
# create intermediate list to hold vectors
assert ((depth[:, -len_0:] == max_depth).all())
emb_0 = self.value_embedding(value[:, -len_0:])
        # add spatial encoding if available
if self.spatial_encoding is not None:
emb_0 = emb_0 + self.spatial_encoding(pos[:, -len_0:])
emb_0 = self.convolution_0(emb_0)
emb_1 = torch.zeros((batch_size, torch.max(len_1), self.embed_dim), dtype=torch.float, device=value.device)
        # substitute all mixed token embeddings of the penultimate layer with token embeddings of the last layer
emb_1[val_1 == 2] = emb_0[:, (self.conv_size - 1)::self.conv_size] # [N, T1, C]
emb_1 = self.convolution_1(emb_1)
x_0 = torch.zeros((batch_size, torch.max(mix_1), self.embed_dim), device=value.device)
# assert(y_1.shape == emb_1.shape)
# deconvolute the latent space - sequence length equals number of tokens in the penultimate layer
y_1 = self.deconvolution_1(x)
y_1 = y_1 + emb_1[:, :y_1.shape[1]]
# select only latent vectors, which correspond to mixed tokens in the penultimate layer
for i in range(batch_size):
mix_1_mask_i = (val_1[i] == 2)[:len(y_1[i])] # handle overflow/clipped values in the embedding
x_0[i, :torch.sum(mix_1_mask_i)] = y_1[i, mix_1_mask_i] # [N, T', C]
# deconvolute the intermediate latent space - create new tokens in latent space for each mixed token
# assert(y_0.shape == emb_0.shape)
y_0 = self.deconvolution_0(x_0)
y_0 = y_0 + emb_0[:, :y_0.shape[1]] # [N, T, C]
# compute logits of generated tokens
return self.linear(y_0) # [N, T, V]
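A toy trace of the single substitution (illustrative only, assuming conv_size=2 and batch size 1):

# depth:  [7, 7, 8, 8]   penultimate layer, then last layer
# value:  [2, 1, a, b]   the single '2' marks a mixed penultimate token
# val_1 = [2, 1]
# emb_0: causal-conv embeddings of the last-layer tokens [a, b]
# emb_1[val_1 == 2] = emb_0[:, 1::2]  aggregated child embedding -> mixed slot
# y_1 = deconv_1(x) + emb_1           one latent per penultimate token
# x_0 = y_1[val_1 == 2]               keep only the mixed latent
# y_0 = deconv_0(x_0) + emb_0         two latents -> logits for [a, b]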
from .convolution import Convolution
+from .block_convolution import BlockConvolution
from .deconvolution import Deconvolution
from .embedding import Embedding, PositionalEncodingLearned, PositionalEncodingLearnedLookAhead, \
PositionalEncodingLearnedLookAheadSplit
@@ -11,5 +12,6 @@ __all__ = [
"PositionalEncodingLearnedLookAheadSplit",
"Linear",
"Convolution",
"BlockConvolution",
"Deconvolution",
]
import torch
import torch.nn as nn
class BlockConvolution(nn.Module):
def __init__(self, source_dim, target_dim, block_size):
""" Performs masked blockwise convolution on an input sequence.
The mask is always an upper right triangle matrix with zeros on the diagonal.
Args:
source_dim: Defines the embedding dimension of the input sequence.
target_dim: Defines the embedding dimension of the output sequence.
block_size: Defines the size of the block over which we convolute.
"""
super(BlockConvolution, self).__init__()
self.block_size = block_size
self.convolutions = nn.ModuleList([
nn.Conv1d(source_dim, target_dim, (i + 1,), block_size, bias=True) for i in range(block_size-1)
])
    def forward(self, seq_vector):
        """ Convolve tokens blockwise; the sequence length is preserved.
        Args:
            seq_vector: Sequence vector with elements of the shape [N, S, E].
        Return:
            Sequence vector with the same length and target embedding dimension [N, S, E'].
        """
        batch_size, seq_len, _ = seq_vector.shape
        out = torch.zeros(batch_size, seq_len, self.target_dim, dtype=seq_vector.dtype, device=seq_vector.device)
        # offset i+1 of each block is filled by the convolution over the block's first i+1 tokens;
        # offset 0 of each block stays zero (no past tokens available)
        for i, conv in enumerate(self.convolutions):
            out[:, 1 + i::self.block_size] = conv(seq_vector.transpose(1, 2)).transpose(1, 2)
        return out
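A quick sanity check of the blockwise causal structure (a sketch, not part of the commit; it assumes BlockConvolution is importable as above):

import torch
torch.manual_seed(0)
conv = BlockConvolution(source_dim=8, target_dim=8, block_size=4)
a = torch.randn(1, 8, 8)   # [N, S, E], two blocks of four tokens
b = a.clone()
b[0, 3] += 1.0             # perturb the last token of the first block
# the last token of a block feeds no output, and each offset sees only earlier offsets
assert torch.allclose(conv(a), conv(b))
assert (conv(a)[:, 0::4] == 0).all()  # offset 0 of each block receives no input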
@@ -63,3 +63,66 @@ class BasicGenerator:
val[-1][token_idx + i] = torch.multinomial(probs[i], num_samples=1)[0]
return val[-1]
class BasicGeneratorAutoRegressive:
def __init__(self, compute_logits_fn, num_tokens=1, **_):
""" Create token generator instance which samples 'num_tokens' in one pass.
Args:
compute_logits_fn: Pointer to function, which computes logits of given sequence.
num_tokens: Defines the number of sampled tokens in each step.
"""
self.compute_logits = compute_logits_fn
self.num_tokens = num_tokens
def __call__(self, val, dep, pos, memory=None, idx=0, temperature=1.0, slice_sequence=True, cls=None, **_):
""" Sample autoregressively current value token sequence and return updated value sequence.
Args:
val: Value token sequence of current layer.
dep: Depth token sequence of current layer.
pos: Position token sequence of current layer.
memory: Latent sequence vector of the previous layer.
idx: Currently sampled transformer layer index.
temperature: Defines the randomness of the samples.
cls: class label for conditional generation.
Return:
Sampled token sequence with values of the current layer.
"""
# compute indices
start_idx = 0
stop_idx = len(val[-1])
sampled_idx = len(torch.cat(val[:-1])) if len(val) > 1 else 0
# sample tokens autoregressively
for token_idx in trange(start_idx, stop_idx, self.num_tokens, leave=False, desc="Tokens"):
for block_idx in range(self.num_tokens):
                # concatenate layers and slice the sequence up to the current block for speed-up
seq = (
torch.cat(val)[:sampled_idx + token_idx + self.num_tokens].unsqueeze(0),
torch.cat(dep)[:sampled_idx + token_idx + self.num_tokens].unsqueeze(0),
torch.cat(pos)[:sampled_idx + token_idx + self.num_tokens].unsqueeze(0),
)
logits = self.compute_logits(seq, memory, idx, cls)[0]
                # retrieve only the logits for the current index
sampled_token_logits = logits[
sampled_idx + token_idx + block_idx:sampled_idx + token_idx + block_idx + 1]
# check transformer token capacity
if len(sampled_token_logits) == 0:
return val[-1][:token_idx] # reached maximum number of tokens
# compute token probabilities from logits
probs = torch.nn.functional.softmax(sampled_token_logits / temperature, dim=-1) # [t, V]
probs[:, 0] = 0 # 'padding' token
assert(len(probs) == 1)
# sample next sequence token
val[-1][token_idx + block_idx] = torch.multinomial(probs[0], num_samples=1)[0]
return val[-1]
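A minimal usage sketch (the dummy logit function and all sizes are assumptions for illustration, not part of the commit; the module's existing torch/trange imports are assumed):

import torch
def dummy_logits(seq, memory, idx, cls):
    # uniform logits over a vocabulary of 3 values plus the padding token 0
    values = seq[0]                            # [N, T]
    return torch.ones(values.shape[0], values.shape[1], 4)

gen = BasicGeneratorAutoRegressive(dummy_logits, num_tokens=1)
val = [torch.zeros(8, dtype=torch.long)]       # one layer, 8 tokens to fill
dep = [torch.ones(8, dtype=torch.long)]
pos = [torch.zeros(8, 3, dtype=torch.long)]
sampled = gen(val, dep, pos)                   # fills val[-1] one token at a time
print(sampled)                                 # eight values drawn from {1, 2, 3}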
import torch
-from .basic_generator import BasicGenerator
-from .substitution_generator import SubstitutionGenerator
-from .double_substitution_generator import DoubleSubstitutionGenerator
+from .basic_generator import BasicGenerator, BasicGeneratorAutoRegressive
+from .double_substitution_generator import DoubleSubstitutionGenerator, DoubleSubstitutionGeneratorAutoregressive
+from .substitution_generator import SubstitutionGenerator, SubstitutionGeneratorAutoregressive
-class CompositeGenerator():
+class CompositeGenerator:
def __init__(self, compute_logits_fn, num_tokens=[1], **_):
""" Create token generator instance for a 'basic' head.
@@ -44,3 +44,44 @@ class CompositeGenerator():
generator = DoubleSubstitutionGenerator(self.compute_logits_fn, num_tokens)
# sample a single layer
return generator(val, dep, pos, memory, layer_idx, temperature, cls=cls)
class CompositeGeneratorAutoregressive:
def __init__(self, compute_logits_fn, num_tokens=[1], **_):
""" Create token generator instance for a 'basic' head.
Args:
compute_logits_fn: Pointer to function, which computes logits of given sequence.
num_tokens: Defines the number of sampled tokens in each step.
"""
self.compute_logits_fn = compute_logits_fn
self.num_tokens_list = num_tokens
def __call__(self, val, dep, pos, memory=None, layer_idx=0, temperature=1.0, cls=None, **_):
""" Sample autoregressively current value token sequence and return sampled value sequence.
Args:
val: Value token sequence of previous and current layers as a list.
dep: Depth token sequence of previous and current layers as a list.
pos: Position token sequence of previous and current layers as a list.
memory: Latent sequence vector of the previous layer.
layer_idx: Currently sampled layer index.
temperature: Defines the randomness of the samples.
cls: class label for conditional generation.
Return:
Sampled token sequence with values of the current layer.
"""
# get the currently sampled depth
cur_depth = torch.max(dep[-1])
        # get the number of sampled tokens according to the depth
num_tokens = self.num_tokens_list[cur_depth - 1]
# create a generator according to layer depth
if cur_depth < 6:
generator = BasicGeneratorAutoRegressive(self.compute_logits_fn, num_tokens)
elif cur_depth == 6: # 'substitution'
generator = SubstitutionGeneratorAutoregressive(self.compute_logits_fn, num_tokens)
else: # 'double_substitution'
generator = DoubleSubstitutionGeneratorAutoregressive(self.compute_logits_fn, num_tokens)
# sample a single layer
return generator(val, dep, pos, memory, layer_idx, temperature, cls=cls)
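Spelled out for one concrete configuration (illustrative; the num_tokens values are the ones passed by _create_token_generator below for spatial_dim=3):

# num_tokens_list = [1, 1, 1, 4, 8, 8, 8, 8]
#   depth 1-5 -> BasicGeneratorAutoRegressive               (num_tokens 1, 1, 1, 4, 8)
#   depth 6   -> SubstitutionGeneratorAutoregressive        (num_tokens 8)
#   depth 7+  -> DoubleSubstitutionGeneratorAutoregressive  (num_tokens 8)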
@@ -3,7 +3,7 @@ import torch
from tqdm.auto import trange
-class DoubleSubstitutionGenerator():
+class DoubleSubstitutionGenerator:
def __init__(self, compute_logits_fn, num_tokens=8, **_):
""" Create token generator instance which samples 'num_tokens' in one pass.
@@ -79,3 +79,83 @@ class DoubleSubstitutionGenerator():
token_idx += num_sampled
return val[-1]
class DoubleSubstitutionGeneratorAutoregressive:
def __init__(self, compute_logits_fn, num_tokens=8, **_):
""" Create token generator instance which samples 'num_tokens' in one pass.
Args:
compute_logits_fn: Pointer to function, which computes logits of given sequence.
num_tokens: Defines the number of sampled tokens in each step.
"""
self.compute_logits = compute_logits_fn
self.num_tokens = num_tokens
self.kernel_size = num_tokens
def __call__(self, val, dep, pos, memory=None, idx=0, temperature=1.0, cls=None, **_):
""" Sample autoregressively current value token sequence and return updated value sequence.
Note: Needs at least, the third-, second- and last layer sequence.
Args:
val: Array of value token sequence layers in ascending order.
dep: Array of depth token sequence layers in ascending order.
pos: Array of position token sequence layers in ascending order.
memory: Latent sequence vector of the previous layer.
idx: Currently sampled transformer layer index.
temperature: Defines the randomness of the samples.
cls: class label for conditional generation.
Return:
Sampled token sequence with values of the current layer.
"""
# compute indices
token_idx = 0
start_idx = 0
second_last_idx = 0
stop_idx = len(val[-3])
# hack to distinguish between 'encoder_only' and 'encoder_multi_decoder'
sampled_idx = len(torch.cat(val[:-1])) if len(val) > 3 else 0
# sample tokens autoregressively
for third_last_idx in trange(start_idx, stop_idx, self.kernel_size, leave=False, desc="Tokens"):
            # compute the number of mixed tokens in the current third-last layer window
            num_third_last = torch.sum(val[-3][third_last_idx:third_last_idx + self.kernel_size] == 2)
            if num_third_last == 0:
                continue  # skip, if no tokens will be sampled - speed up
            # compute the number of mixed tokens in the corresponding second-last layer window
            num_second_last = torch.sum(
                val[-2][second_last_idx:second_last_idx + self.kernel_size * num_third_last] == 2
            )
            # advance past the consumed second-last layer window, even if nothing is sampled from it
            second_last_idx += self.kernel_size * num_third_last
            if num_second_last == 0:
                continue  # skip, if no tokens will be sampled - speed up
            # compute the number of tokens, which will be sampled
            num_sampled = num_second_last * self.kernel_size
for block_idx in range(num_sampled.item()):
# concat and pack token sequences to compute logits
seq = (torch.cat(val).unsqueeze(0), torch.cat(dep).unsqueeze(0), torch.cat(pos).unsqueeze(0))
logits = self.compute_logits(seq, memory, idx, cls)[0]
                # retrieve only the logits for the current index
sampled_token_logits = logits[
sampled_idx + token_idx + block_idx:sampled_idx + token_idx + block_idx + 1]
# check transformer token capacity
if len(sampled_token_logits) == 0:
return val[-1][:token_idx] # reached maximum number of tokens
# compute token probabilities from logits
probs = torch.nn.functional.softmax(sampled_token_logits / temperature, dim=-1) # [t, V]
probs[:, 0] = 0 # 'padding' token
assert (len(probs) == 1)
# sample next sequence token
val[-1][token_idx + block_idx] = torch.multinomial(probs[0], num_samples=1)[0]
token_idx += num_sampled
return val[-1]
@@ -63,3 +63,69 @@ class SubstitutionGenerator():
token_idx += num_sampled
return val[-1]
class SubstitutionGeneratorAutoregressive:
def __init__(self, compute_logits_fn, num_tokens=8, **_):
""" Create token generator instance which samples 'num_tokens' in one pass.
Args:
compute_logits_fn: Pointer to function, which computes logits of given sequence.
num_tokens: Defines the number of sampled tokens in each step.
"""
self.compute_logits = compute_logits_fn
self.num_tokens = num_tokens
self.kernel_size = num_tokens
def __call__(self, val, dep, pos, memory=None, idx=0, temperature=1.0, cls=None, **_):
""" Sample autoregressively current value token sequence and return updated value sequence.
Args:
val: Value token sequence of currently sampled layer.
dep: Depth token sequence of currently sampled layer.
pos: Position token sequence of currently sampled layer.
memory: Latent sequence vector of the previous layer.
idx: Currently sampled transformer layer index.
temperature: Defines the randomness of the samples.
Return:
Sampled token sequence with values of the current layer.
"""
# compute indices
token_idx = 0
start_idx = 0
stop_idx = len(val[-2])
sampled_idx = len(torch.cat(val[:-1])) if len(val) > 2 else 0
# sample tokens autoregressively
for prev_idx in trange(start_idx, stop_idx, self.kernel_size, leave=False, desc="Tokens"):
# compute number of tokens which can be sampled
num_sampled = torch.sum(val[-2][prev_idx:prev_idx + self.kernel_size] == 2) * self.num_tokens
if num_sampled == 0:
continue # 'skip' if no tokens will be sampled - speed up
for block_idx in range(num_sampled.item()):
# concat and pack token sequences to compute logits
seq = (torch.cat(val).unsqueeze(0), torch.cat(dep).unsqueeze(0), torch.cat(pos).unsqueeze(0))
logits = self.compute_logits(seq, memory, idx, cls)[0]
                # retrieve only the logits for the current index
sampled_token_logits = logits[
sampled_idx + token_idx + block_idx:sampled_idx + token_idx + block_idx + 1]
# check transformer token capacity
if len(sampled_token_logits) == 0:
return val[-1][:token_idx] # reached maximum number of tokens
# compute token probabilities from logits
probs = torch.nn.functional.softmax(sampled_token_logits / temperature, dim=-1) # [t, V]
probs[:, 0] = 0 # 'padding' token
assert(len(probs) == 1)
# sample next sequence token
val[-1][token_idx + block_idx] = torch.multinomial(probs[0], num_samples=1)[0]
token_idx += num_sampled
return val[-1]
from .basic_generator import BasicGenerator
from .substitution_generator import SubstitutionGenerator
from .double_substitution_generator import DoubleSubstitutionGenerator
-from .composite_generator import CompositeGenerator
+from .composite_generator import CompositeGenerator, CompositeGeneratorAutoregressive
def _create_token_generator(head, model, spatial_dim):
@@ -30,6 +30,9 @@ def _create_token_generator(head, model, spatial_dim):
if head in ('composite', 'composite_A', 'composite_B'):
size = 2**spatial_dim
return CompositeGenerator(model.compute_logits, [1, 1, 1, size // 2, size, size, size, size])
+    if head in ('composite_autoregressive_A',):
+        size = 2**spatial_dim
+        return CompositeGeneratorAutoregressive(model.compute_logits, [1, 1, 1, size // 2, size, size, size, size])
if head in ('composite_B'):
size = 2**spatial_dim
return CompositeGenerator(model.compute_logits, [1, 1, 1, size // 4, size, size, size, size])