diff --git a/sample/layer_sampler/encoder_only_sampler.py b/sample/layer_sampler/encoder_only_sampler.py index 0894b03a64f7cde5c11cb9253f5fa3aca2645518..ca56c5f48073bd864ee493b1465a30e7a3f51d51 100644 --- a/sample/layer_sampler/encoder_only_sampler.py +++ b/sample/layer_sampler/encoder_only_sampler.py @@ -77,4 +77,7 @@ class EncoderOnlySampler: dep += [next_dep] pos += [next_pos] + if torch.sum(next_val == 2) == 0: + break # early-out, no mixed tokens sampled + return postprocess(val, target_resolution, self.spatial_dim) diff --git a/sample/layer_sampler/recurrent_sampler.py b/sample/layer_sampler/recurrent_sampler.py index 45aa9bd5efb0ffff7898928aeab612ff86fab69c..0d2e1be09986b17ee3ebe4e5537b9aa8cd9b00bd 100644 --- a/sample/layer_sampler/recurrent_sampler.py +++ b/sample/layer_sampler/recurrent_sampler.py @@ -89,5 +89,8 @@ class RecurrentSampler: dep += [next_dep] pos += [next_pos] + if torch.sum(next_val == 2) == 0: + break # early-out, no mixed tokens sampled + # transform the sampled octree sequence back into a regular-grid voxel array and return return postprocess(val, target_resolution, self.spatial_dim)