diff --git a/sample/layer_sampler/encoder_only_sampler.py b/sample/layer_sampler/encoder_only_sampler.py
index 0894b03a64f7cde5c11cb9253f5fa3aca2645518..ca56c5f48073bd864ee493b1465a30e7a3f51d51 100644
--- a/sample/layer_sampler/encoder_only_sampler.py
+++ b/sample/layer_sampler/encoder_only_sampler.py
@@ -77,4 +77,7 @@ class EncoderOnlySampler:
                 dep += [next_dep]
                 pos += [next_pos]
 
+                if torch.sum(next_val == 2) == 0:
+                    break  # early-out, no mixed tokens sampled
+
         return postprocess(val, target_resolution, self.spatial_dim)
diff --git a/sample/layer_sampler/recurrent_sampler.py b/sample/layer_sampler/recurrent_sampler.py
index 45aa9bd5efb0ffff7898928aeab612ff86fab69c..0d2e1be09986b17ee3ebe4e5537b9aa8cd9b00bd 100644
--- a/sample/layer_sampler/recurrent_sampler.py
+++ b/sample/layer_sampler/recurrent_sampler.py
@@ -89,5 +89,8 @@ class RecurrentSampler:
                 dep += [next_dep]
                 pos += [next_pos]
 
+                if torch.sum(next_val == 2) == 0:
+                    break  # early-out, no mixed tokens sampled
+
         # transform the sampled octree sequence back into a regular-grid voxel array and return
         return postprocess(val, target_resolution, self.spatial_dim)