Commit e85530c1 authored by Gregor Kobsik

renamed token generators

removed all non-autoregressive generators
reduced number of possible heads
adjusted naming of variables with recurrent versions
parent c1847c76
@@ -77,7 +77,4 @@ class EncoderOnlySampler:
dep += [next_dep]
pos += [next_pos]
if len(next_val) != len(next_dep):
break # reached maximum number of tokens which can be generated
return postprocess(val, target_resolution, self.spatial_dim)
@@ -12,10 +12,10 @@ class BasicGenerator:
num_tokens: Defines the number of sampled tokens in each step.
"""
self.compute_logits = compute_logits_fn
self.num_tokens = num_tokens
self.kernel_size = num_tokens
def __call__(self, val, dep, pos, memory=None, idx=0, temperature=1.0, slice_sequence=True, cls=None, **_):
""" Sample autoregressively current value token sequence and return updated value sequence.
""" Sample autoregressive current value token sequence and return updated value sequence.
Args:
val: Value token sequence of current layer.
@@ -30,77 +30,18 @@ class BasicGenerator:
Sampled token sequence with values of the current layer.
"""
# compute indices
start_idx = 0
stop_idx = len(val[-1])
token_idx = 0
sampled_idx = len(torch.cat(val[:-1])) if len(val) > 1 else 0
# sample tokens autoregressively
for token_idx in trange(start_idx, stop_idx, self.num_tokens, leave=False, desc="Tokens"):
# sample tokens autoregressively
for _ in trange(len(val[-1]) // self.kernel_size, leave=False, desc="Tokens"):
for block_idx in range(self.kernel_size):
# concat layers and slice sequence for speed_up
seq = (
torch.cat(val)[:sampled_idx + token_idx + self.num_tokens].unsqueeze(0),
torch.cat(dep)[:sampled_idx + token_idx + self.num_tokens].unsqueeze(0),
torch.cat(pos)[:sampled_idx + token_idx + self.num_tokens].unsqueeze(0),
)
# compute logits
logits = self.compute_logits(seq, memory, idx, cls)[0]
# retrieve only logits for the current index
sampled_token_logits = logits[sampled_idx + token_idx:sampled_idx + token_idx + self.num_tokens]
# compute token probabilities from logits
sampled_token_logits[:, 0] = -float("Inf") # 'padding' token
probs = torch.nn.functional.softmax(sampled_token_logits / temperature, dim=-1) # [t, V]
# sample next sequence token
for i in range(len(probs)):
val[-1][token_idx + i] = torch.multinomial(probs[i], num_samples=1)[0]
return val[-1]
class BasicGeneratorAutoRegressive:
def __init__(self, compute_logits_fn, num_tokens=1, **_):
""" Create token generator instance which samples 'num_tokens' in one pass.
Args:
compute_logits_fn: Pointer to function, which computes logits of given sequence.
num_tokens: Defines the number of sampled tokens in each step.
"""
self.compute_logits = compute_logits_fn
self.num_tokens = num_tokens
def __call__(self, val, dep, pos, memory=None, idx=0, temperature=1.0, slice_sequence=True, cls=None, **_):
""" Sample autoregressively current value token sequence and return updated value sequence.
Args:
val: Value token sequence of current layer.
dep: Depth token sequence of current layer.
pos: Position token sequence of current layer.
memory: Latent sequence vector of the previous layer.
idx: Currently sampled transformer layer index.
temperature: Defines the randomness of the samples.
cls: class label for conditional generation.
Return:
Sampled token sequence with values of the current layer.
"""
# compute indices
start_idx = 0
stop_idx = len(val[-1])
sampled_idx = len(torch.cat(val[:-1])) if len(val) > 1 else 0
# sample tokens autoregressively
for token_idx in trange(start_idx, stop_idx, self.num_tokens, leave=False, desc="Tokens"):
for block_idx in range(self.num_tokens):
# concat layers and slice sequence for speed_up
seq = (
torch.cat(val)[:sampled_idx + token_idx + self.num_tokens].unsqueeze(0),
torch.cat(dep)[:sampled_idx + token_idx + self.num_tokens].unsqueeze(0),
torch.cat(pos)[:sampled_idx + token_idx + self.num_tokens].unsqueeze(0),
torch.cat(val)[:sampled_idx + token_idx + self.kernel_size].unsqueeze(0),
torch.cat(dep)[:sampled_idx + token_idx + self.kernel_size].unsqueeze(0),
torch.cat(pos)[:sampled_idx + token_idx + self.kernel_size].unsqueeze(0),
)
logits = self.compute_logits(seq, memory, idx, cls)[0]
@@ -115,4 +56,7 @@ class BasicGeneratorAutoRegressive:
# sample next sequence token
val[-1][token_idx + block_idx] = torch.multinomial(probs, num_samples=1)[0]
# update indices
token_idx += self.kernel_size
return val[-1]
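All generator variants in this file share the same per-token sampling step: mask out the 'padding' class at index 0, apply a temperature-scaled softmax, and draw one token with torch.multinomial. The following is a minimal, self-contained sketch of that step; the helper name and vocabulary size are illustrative and not part of the repository.

import torch

def sample_token(logits, temperature=1.0):
    # logits: 1-D tensor of vocabulary scores; index 0 is assumed to be the 'padding' token
    logits = logits.clone()
    logits[0] = -float("Inf")  # never sample the padding token
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)[0]

# usage: draw one token from a dummy 5-class logit vector
token = sample_token(torch.randn(5), temperature=0.8)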
import torch
from .basic_generator import BasicGenerator, BasicGeneratorAutoRegressive
from .double_substitution_generator import DoubleSubstitutionGenerator, DoubleSubstitutionGeneratorAutoregressive
from .substitution_generator import SubstitutionGenerator, SubstitutionGeneratorAutoregressive
from .basic_generator import BasicGenerator
from .double_substitution_generator import DoubleSubstitutionGenerator
from .substitution_generator import SubstitutionGenerator
class CompositeGenerator:
@@ -44,44 +44,3 @@ class CompositeGenerator:
generator = DoubleSubstitutionGenerator(self.compute_logits_fn, num_tokens)
# sample a single layer
return generator(val, dep, pos, memory, layer_idx, temperature, cls=cls)
class CompositeGeneratorAutoregressive:
def __init__(self, compute_logits_fn, num_tokens=[1], **_):
""" Create token generator instance for a 'basic' head.
Args:
compute_logits_fn: Pointer to function, which computes logits of given sequence.
num_tokens: Defines the number of sampled tokens in each step.
"""
self.compute_logits_fn = compute_logits_fn
self.num_tokens_list = num_tokens
def __call__(self, val, dep, pos, memory=None, layer_idx=0, temperature=1.0, cls=None, **_):
""" Sample autoregressively current value token sequence and return sampled value sequence.
Args:
val: Value token sequence of previous and current layers as a list.
dep: Depth token sequence of previous and current layers as a list.
pos: Position token sequence of previous and current layers as a list.
memory: Latent sequence vector of the previous layer.
layer_idx: Currently sampled layer index.
temperature: Defines the randomness of the samples.
cls: class label for conditional generation.
Return:
Sampled token sequence with values of the current layer.
"""
# get the currently sampled depth
cur_depth = torch.max(dep[-1])
# get number of sampled tokens according to depth
num_tokens = self.num_tokens_list[cur_depth - 1]
# create a generator according to layer depth
if cur_depth < 6:
generator = BasicGeneratorAutoRegressive(self.compute_logits_fn, num_tokens)
elif cur_depth == 6: # 'substitution'
generator = SubstitutionGeneratorAutoregressive(self.compute_logits_fn, num_tokens)
else: # 'double_substitution'
generator = DoubleSubstitutionGeneratorAutoregressive(self.compute_logits_fn, num_tokens)
# sample a single layer
return generator(val, dep, pos, memory, layer_idx, temperature, cls=cls)
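The composite generator picks a per-layer generator purely from the currently sampled octree depth. A hedged illustration of that dispatch rule with made-up values, mirroring the thresholds in the code above:

import torch

num_tokens_list = [1, 1, 1, 4, 8, 8, 8, 8]   # illustrative 'composite_A'-style block sizes
dep_last = torch.tensor([7, 7, 7, 7])        # stand-in for dep[-1] at sampling time
cur_depth = int(torch.max(dep_last))         # currently sampled depth
num_tokens = num_tokens_list[cur_depth - 1]  # block size for this depth

if cur_depth < 6:
    chosen = "BasicGenerator"
elif cur_depth == 6:
    chosen = "SubstitutionGenerator"          # 'substitution'
else:
    chosen = "DoubleSubstitutionGenerator"    # 'double_substitution'
print(chosen, num_tokens)                     # DoubleSubstitutionGenerator 8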
@@ -16,7 +16,7 @@ class DoubleSubstitutionGenerator:
self.kernel_size = num_tokens
def __call__(self, val, dep, pos, memory=None, idx=0, temperature=1.0, cls=None, **_):
""" Sample autoregressively current value token sequence and return updated value sequence.
""" Sample autoregressive current value token sequence and return updated value sequence.
Note: Needs at least the third-last, second-last and last layer sequences.
@@ -34,103 +34,17 @@ class DoubleSubstitutionGenerator:
"""
# compute indices
token_idx = 0
start_idx = 0
second_last_idx = 0
stop_idx = len(val[-3])
third_last_idx = 0
# hack to distinguish between 'encoder_only' and 'encoder_multi_decoder'
sampled_idx = len(torch.cat(val[:-1])) if len(val) > 3 else 0
# sample tokens autoregressively
for third_last_idx in trange(start_idx, stop_idx, self.kernel_size, leave=False, desc="Tokens"):
# compute number of mixed tokens in third last layer
num_third_last = torch.sum(val[-3][third_last_idx:third_last_idx + self.kernel_size] == 2)
if num_third_last == 0:
continue # skip, if no tokens will be sampled - speed up
# compute number of mixed tokens in second last layer
num_second_last = torch.sum(
val[-2][second_last_idx:second_last_idx + self.kernel_size * num_third_last] == 2
)
if num_second_last == 0:
continue # skip, if no tokens will be sampled - speed up
# compute number of tokens, which will be sampled
second_last_idx += num_second_last
num_sampled = num_second_last * self.kernel_size
# concat and pack token sequences to compute logits
seq = (torch.cat(val).unsqueeze(0), torch.cat(dep).unsqueeze(0), torch.cat(pos).unsqueeze(0))
logits = self.compute_logits(seq, memory, idx, cls)[0]
# retrieve only logits for tokens which were actually sampled
sampled_token_logits = logits[sampled_idx + token_idx:sampled_idx + token_idx + num_sampled]
# compute token probabilities from logits
sampled_token_logits[:, 0] = -float("Inf") # 'padding' token
probs = torch.nn.functional.softmax(sampled_token_logits / temperature, dim=-1) # [t, V]
# sample next sequence token
for i in range(len(probs)):
val[-1][token_idx + i] = torch.multinomial(probs[i], num_samples=1)[0]
token_idx += num_sampled
return val[-1]
class DoubleSubstitutionGeneratorAutoregressive:
def __init__(self, compute_logits_fn, num_tokens=8, **_):
""" Create token generator instance which samples 'num_tokens' in one pass.
Args:
compute_logits_fn: Pointer to function, which computes logits of given sequence.
num_tokens: Defines the number of sampled tokens in each step.
"""
self.compute_logits = compute_logits_fn
self.num_tokens = num_tokens
self.kernel_size = num_tokens
def __call__(self, val, dep, pos, memory=None, idx=0, temperature=1.0, cls=None, **_):
""" Sample autoregressively current value token sequence and return updated value sequence.
Note: Needs at least the third-last, second-last and last layer sequences.
Args:
val: Array of value token sequence layers in ascending order.
dep: Array of depth token sequence layers in ascending order.
pos: Array of position token sequence layers in ascending order.
memory: Latent sequence vector of the previous layer.
idx: Currently sampled transformer layer index.
temperature: Defines the randomness of the samples.
cls: class label for conditional generation.
Return:
Sampled token sequence with values of the current layer.
"""
# compute indices
token_idx = 0
start_idx = 0
second_last_idx = 0
stop_idx = len(val[-3])
# hack to distinguish between 'encoder_only' and 'encoder_multi_decoder'
sampled_idx = len(torch.cat(val[:-1])) if len(val) > 3 else 0
# sample tokens autoregressively
for third_last_idx in trange(start_idx, stop_idx, self.kernel_size, leave=False, desc="Tokens"):
# compute number of mixed tokens in third last layer
num_third_last = torch.sum(val[-3][third_last_idx:third_last_idx + self.kernel_size] == 2)
if num_third_last == 0:
continue # skip, if no tokens will be sampled - speed up
# compute number of mixed tokens in second last layer
num_second_last = torch.sum(
val[-2][second_last_idx:second_last_idx + self.kernel_size * num_third_last] == 2
)
if num_second_last == 0:
continue # skip, if no tokens will be sampled - speed up
# compute number of tokens, which will be sampled
second_last_idx += num_second_last
num_sampled = num_second_last * self.kernel_size
# sample tokens autoregressively
for _ in trange(len(val[-3]) // self.kernel_size, leave=False, desc="Tokens"):
# compute number of mixed tokens in the third- and second-last layers and the number of tokens to sample
mix_third_last = torch.sum(val[-3][third_last_idx:third_last_idx + self.kernel_size] == 2)
mix_second_last = torch.sum(val[-2][second_last_idx:second_last_idx + mix_third_last * 8] == 2)
num_sampled = mix_second_last * 8
for block_idx in range(num_sampled.item()):
# concat and pack token sequences to compute logits
@@ -146,6 +60,10 @@ class DoubleSubstitutionGeneratorAutoregressive:
# sample next sequence token
val[-1][token_idx + block_idx] = torch.multinomial(probs, num_samples=1)[0]
# update indices
third_last_idx += self.kernel_size
second_last_idx += mix_third_last * 8
token_idx += num_sampled
return val[-1]
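The index bookkeeping above hinges on the 'mixed' value 2: each mixed node in the third-last layer contributes kernel_size children to the second-last layer, and each mixed child there contributes another eight tokens to the layer being sampled. A small self-contained sketch with made-up layer contents:

import torch

kernel_size = 8
# one kernel_size block of the third-last layer (2 marks a 'mixed' node)
third_last = torch.tensor([2, 1, 2, 3, 1, 1, 2, 1])
# its children in the second-last layer, eight per mixed parent
second_last = torch.tensor([2, 1, 1, 2, 1, 1, 1, 1,
                            1, 2, 1, 1, 1, 1, 1, 1,
                            1, 1, 1, 1, 1, 1, 1, 1])

mix_third_last = torch.sum(third_last[:kernel_size] == 2)           # 3 mixed parents
mix_second_last = torch.sum(second_last[:mix_third_last * 8] == 2)  # 3 mixed children
num_sampled = mix_second_last * 8                                   # 24 tokens in the last layer
print(int(mix_third_last), int(mix_second_last), int(num_sampled))  # 3 3 24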
@@ -3,7 +3,7 @@ import torch
from tqdm.auto import trange
class SubstitutionGenerator():
class SubstitutionGenerator:
def __init__(self, compute_logits_fn, num_tokens=8, **_):
""" Create token generator instance which samples 'num_tokens' in one pass.
@@ -12,11 +12,10 @@ class SubstitutionGenerator():
num_tokens: Defines the number of sampled tokens in each step.
"""
self.compute_logits = compute_logits_fn
self.num_tokens = num_tokens
self.kernel_size = num_tokens
def __call__(self, val, dep, pos, memory=None, idx=0, temperature=1.0, cls=None, **_):
""" Sample autoregressively current value token sequence and return updated value sequence.
""" Sample autoregressive current value token sequence and return updated value sequence.
Args:
val: Value token sequence of currently sampled layer.
@@ -31,74 +30,14 @@ class SubstitutionGenerator():
"""
# compute indices
token_idx = 0
start_idx = 0
stop_idx = len(val[-2])
second_last_idx = 0
sampled_idx = len(torch.cat(val[:-1])) if len(val) > 2 else 0
# sample tokens autoregressively
for prev_idx in trange(start_idx, stop_idx, self.kernel_size, leave=False, desc="Tokens"):
# sample tokens autoregressively
for _ in trange(len(val[-2]) // self.kernel_size, leave=False, desc="Tokens"):
# compute number of tokens which can be sampled
num_sampled = torch.sum(val[-2][prev_idx:prev_idx + self.kernel_size] == 2) * self.num_tokens
if num_sampled == 0:
continue # 'skip' if no tokens will be sampled - speed up
# concat and pack token sequences to compute logits
seq = (torch.cat(val).unsqueeze(0), torch.cat(dep).unsqueeze(0), torch.cat(pos).unsqueeze(0))
logits = self.compute_logits(seq, memory, idx, cls)[0]
# retrieve only logits for tokens which were actually sampled
sampled_token_logits = logits[sampled_idx + token_idx:sampled_idx + token_idx + num_sampled]
# compute token probabilities from logits
sampled_token_logits[:, 0] = -float("Inf") # 'padding' token
probs = torch.nn.functional.softmax(sampled_token_logits / temperature, dim=-1) # [t, V]
# sample next sequence token
for i in range(len(probs)):
val[-1][token_idx + i] = torch.multinomial(probs[i], num_samples=1)[0]
token_idx += num_sampled
return val[-1]
class SubstitutionGeneratorAutoregressive:
def __init__(self, compute_logits_fn, num_tokens=8, **_):
""" Create token generator instance which samples 'num_tokens' in one pass.
Args:
compute_logits_fn: Pointer to function, which computes logits of given sequence.
num_tokens: Defines the number of sampled tokens in each step.
"""
self.compute_logits = compute_logits_fn
self.num_tokens = num_tokens
self.kernel_size = num_tokens
def __call__(self, val, dep, pos, memory=None, idx=0, temperature=1.0, cls=None, **_):
""" Sample autoregressively current value token sequence and return updated value sequence.
Args:
val: Value token sequence of currently sampled layer.
dep: Depth token sequence of currently sampled layer.
pos: Position token sequence of currently sampled layer.
memory: Latent sequence vector of the previous layer.
idx: Currently sampled transformer layer index.
temperature: Defines the randomness of the samples.
Return:
Sampled token sequence with values of the current layer.
"""
# compute indices
token_idx = 0
start_idx = 0
stop_idx = len(val[-2])
sampled_idx = len(torch.cat(val[:-1])) if len(val) > 2 else 0
# sample tokens autoregressively
for prev_idx in trange(start_idx, stop_idx, self.kernel_size, leave=False, desc="Tokens"):
# compute number of tokens which can be sampled
num_sampled = torch.sum(val[-2][prev_idx:prev_idx + self.kernel_size] == 2) * self.num_tokens
if num_sampled == 0:
continue # 'skip' if no tokens will be sampled - speed up
mix_second_last = torch.sum(val[-2][second_last_idx:second_last_idx + self.kernel_size] == 2)
num_sampled = mix_second_last * 8
for block_idx in range(num_sampled.item()):
# concat and pack token sequences to compute logits
@@ -116,6 +55,8 @@ class SubstitutionGeneratorAutoregressive:
# sample next sequence token
val[-1][token_idx + block_idx] = torch.multinomial(probs, num_samples=1)[0]
# update indices
second_last_idx += self.kernel_size
token_idx += num_sampled
return val[-1]
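The single-substitution case applies the same idea one level up: each mixed (== 2) token in the previous layer is replaced by num_tokens children in the current layer, so the number of tokens to sample per block is a simple count times the kernel size. A short illustrative sketch with made-up values:

import torch

num_tokens = 8                                       # 2**spatial_dim for octrees
prev_block = torch.tensor([1, 2, 3, 2, 1, 1, 2, 1])  # one kernel_size block of val[-2]
num_sampled = torch.sum(prev_block == 2) * num_tokens
print(int(num_sampled))                              # 3 mixed tokens -> 24 children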
from .basic_generator import BasicGenerator
from .substitution_generator import SubstitutionGenerator
from .double_substitution_generator import DoubleSubstitutionGenerator
from .composite_generator import CompositeGenerator, CompositeGeneratorAutoregressive
from .composite_generator import CompositeGenerator
def _create_token_generator(head, model, spatial_dim):
@@ -17,25 +14,18 @@ def _create_token_generator(head, model, spatial_dim):
Return:
Token generator initialised with specified parameters.
"""
if head in ('generative_basic', 'basic', 'linear'):
return BasicGenerator(model.compute_logits)
if head in ('half_conv', 'half_conv_A'):
return BasicGenerator(model.compute_logits, 2**(spatial_dim - 1))
if head in ('single_conv', 'single_conv_A'):
return BasicGenerator(model.compute_logits, 2**spatial_dim)
if head in ('substitution'):
return SubstitutionGenerator(model.compute_logits, 2**spatial_dim)
if head in ('double_substitution'):
return DoubleSubstitutionGenerator(model.compute_logits, 2**spatial_dim)
if head in ('composite', 'composite_A', 'composite_B'):
size = 2**spatial_dim
return CompositeGenerator(model.compute_logits, [1, 1, 1, size // 2, size, size, size, size])
if head in ('composite_autoregressive_A'):
size = 2**spatial_dim
return CompositeGeneratorAutoregressive(model.compute_logits, [1, 1, 1, size // 2, size, size, size, size])
kwargs = {
'compute_logits_fn': model.compute_logits,
}
if head in ('composite_A',):
return CompositeGenerator(num_tokens=[1, 1, 1, 4, 8, 8, 8, 8], **kwargs)
if head in ('composite_B',):
size = 2**spatial_dim
return CompositeGenerator(model.compute_logits, [1, 1, 1, size // 4, size, size, size, size])
return CompositeGenerator(num_tokens=[1, 1, 1, 1, 8, 8], **kwargs)
if head in ('composite_C',):
return CompositeGenerator(num_tokens=[1, 1, 2, 4, 8, 4], **kwargs)
if head in ('composite_D',):
return CompositeGenerator(num_tokens=[1, 1, 4, 8, 4, 8, 4, 8], **kwargs)
raise ValueError(f"ERROR: {head} token generator not implemented.")
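A hypothetical call site for the factory above; DummyModel is a stand-in for any module that exposes a compute_logits(sequence, memory, idx, cls) method, and spatial_dim=3 corresponds to octrees:

class DummyModel:
    def compute_logits(self, seq, memory, idx, cls):
        raise NotImplementedError  # placeholder for the real transformer head

generator = _create_token_generator('composite_A', DummyModel(), spatial_dim=3)
# yields a CompositeGenerator with num_tokens=[1, 1, 1, 4, 8, 8, 8, 8]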