Commit 9398fdc0 authored by Gregor Kobsik

remove token clipping handling in sampler

The limit is no longer handled by the Transformer, so it needs to be enforced by the sampler itself.
parent 8275ad9f
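Since the Transformer no longer truncates the sequence on its own, the sampler has to stop requesting tokens once the budget is reached. A minimal sketch of that idea (not the repository's code; `max_tokens`, `sample_layer`, and `num_layers` are hypothetical names):

# Sketch: enforce the token limit in the sampler loop instead of relying on
# the Transformer to clip its input. All names here are placeholders.
def sample(sample_layer, max_tokens=4096, num_layers=8):
    val, dep, pos = [], [], []
    for _ in range(num_layers):
        # stop before the next layer would exceed the model's capacity
        if sum(len(v) for v in val) >= max_tokens:
            break
        layer_val, layer_dep, layer_pos = sample_layer(val, dep, pos)
        val += [layer_val]
        dep += [layer_dep]
        pos += [layer_pos]
    return val, dep, pos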
@@ -102,9 +102,6 @@ class EncoderDecoderSampler():
                 **kwargs,
             )
-            if len(layer_val) != len(layer_dep):
-                break  # reached maximum number of tokens which can be generated
             # append sampled tokens to sequence
             val += [layer_val]
             dep += [layer_dep]
@@ -120,12 +120,9 @@ class EncoderMultiDecoderSampler():
             nxt_val = self.generators[idx]([nxt_val], [nxt_dep], [nxt_pos], **kwargs)
             # append sampled tokens to sequence
-            val += [nxt_val[:len(nxt_val)]]
-            dep += [nxt_dep[:len(nxt_val)]]
-            pos += [nxt_pos[:len(nxt_val)]]
-            if len(nxt_val) != len(nxt_dep):
-                break  # reached maximum number of tokens which can be generated
+            val += [nxt_val]
+            dep += [nxt_dep]
+            pos += [nxt_pos]
             # prepare sequence to update memory
             if self.head[idx] == 'substitution':
@@ -73,9 +73,9 @@ class EncoderOnlySampler:
             )
             # append sampled tokens to current sequence
-            val += [next_val[:len(next_val)]]
-            dep += [next_dep[:len(next_val)]]
-            pos += [next_pos[:len(next_val)]]
+            val += [next_val]
+            dep += [next_dep]
+            pos += [next_pos]
             if len(next_val) != len(next_dep):
                 break  # reached maximum number of tokens which can be generated
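For context, the removed slices guarded the case where the Transformer returned fewer values than were queried: `next_val[:len(next_val)]` is a no-op, but the matching slices on `next_dep` and `next_pos` kept all three lists aligned to the shortened output. A small illustration with made-up values:

next_val = [3, 1]          # only 2 of 4 queried tokens came back
next_dep = [2, 2, 2, 2]
next_pos = [0, 1, 2, 3]

val, dep, pos = [], [], []
val += [next_val[:len(next_val)]]  # no-op: [3, 1]
dep += [next_dep[:len(next_val)]]  # truncated to [2, 2]
pos += [next_pos[:len(next_val)]]  # truncated to [0, 1]
assert len(val[-1]) == len(dep[-1]) == len(pos[-1])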
@@ -50,10 +50,6 @@ class BasicGenerator:
             # retrieve only logits for current index
             sampled_token_logits = logits[sampled_idx + token_idx:sampled_idx + token_idx + self.num_tokens]
-            # check transformer token capacity
-            if len(sampled_token_logits) == 0:
-                return val[-1][:token_idx]  # reached maximum number of tokens
             # compute token probabilities from logits
             probs = torch.nn.functional.softmax(sampled_token_logits / temperature, dim=-1)  # [t, V]
             probs[:, 0] = 0  # 'padding' token
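The kept lines show the sampling step that follows in the generators: a temperature-scaled softmax over the vocabulary, with index 0 ('padding') masked out before tokens are drawn. A self-contained sketch of that pattern (the renormalisation and the multinomial draw are assumptions, not necessarily how the generator continues):

import torch

def sample_tokens(sampled_token_logits, temperature=1.0):
    # [t, V] logits -> per-position distribution over the vocabulary
    probs = torch.nn.functional.softmax(sampled_token_logits / temperature, dim=-1)
    probs[:, 0] = 0  # never sample the 'padding' token
    probs = probs / probs.sum(dim=-1, keepdim=True)  # renormalise after masking
    return torch.multinomial(probs, num_samples=1).squeeze(-1)  # [t]

tokens = sample_tokens(torch.randn(5, 9), temperature=0.8)  # 5 positions, vocab of 9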
@@ -65,10 +65,6 @@ class DoubleSubstitutionGenerator():
             # retrieve only logits for tokens which were actually sampled
             sampled_token_logits = logits[sampled_idx + token_idx:sampled_idx + token_idx + num_sampled]
-            # check transformer token capacity
-            if len(sampled_token_logits) != num_sampled:
-                return val[-1][:token_idx]  # reached maximum number of tokens
             # compute token probabilities from logits
             probs = torch.nn.functional.softmax(sampled_token_logits / temperature, dim=-1)  # [t, V]
             probs[:, 0] = 0  # 'padding' token
@@ -49,10 +49,6 @@ class SubstitutionGenerator():
             # retrieve only logits for tokens which were actually sampled
             sampled_token_logits = logits[sampled_idx + token_idx:sampled_idx + token_idx + num_sampled]
-            # check transformer token capacity
-            if len(sampled_token_logits) != num_sampled:
-                return val[-1][:token_idx]  # reached maximum number of tokens
             # compute token probabilities from logits
             probs = torch.nn.functional.softmax(sampled_token_logits / temperature, dim=-1)  # [t, V]
             probs[:, 0] = 0  # 'padding' token
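The removed capacity checks in all three generators covered the same situation: when the Transformer clipped its output, fewer logits came back than tokens were requested, and the generator returned the partially completed layer instead of sampling past the end. A rough illustration with hypothetical shapes and values:

import torch

num_sampled, token_idx = 4, 6
val = [[1, 2, 3, 1, 2, 3, 0, 0]]             # current layer, partially filled
logits = torch.randn(2, 9)                   # Transformer returned only 2 positions
sampled_token_logits = logits[:num_sampled]  # slicing past the end yields just 2 rows

if len(sampled_token_logits) != num_sampled:
    partial_layer = val[-1][:token_idx]      # old behaviour: keep what was sampled so far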