def test_lengths_arg_exists():
    """Check that ``lengths_arg_exists`` spots a ``lengths`` parameter.

    A callable without a ``lengths`` parameter must be reported as not
    taking lengths; one that declares ``lengths`` must be reported as
    taking it.
    """
    from speechbrain.utils.callchains import lengths_arg_exists

    def increment(x):
        # Only a data argument: no ``lengths`` parameter.
        return x + 1

    def add_lengths(x, lengths):
        # Explicitly declares ``lengths``.
        return x + lengths

    assert not lengths_arg_exists(increment)
    assert lengths_arg_exists(add_lengths)
def append(self, *args, **kwargs):
    """Append a layer and record whether its ``forward`` accepts ``lengths``.

    All arguments are forwarded unchanged to the parent class's ``append``.
    Afterwards, the most recently added entry (the last one yielded by
    ``self.values()``) is inspected with ``lengths_arg_exists``. The
    boolean-ish result is pushed onto ``self.takes_lengths``, which stays
    index-aligned with the layers in the chain.
    """
    super().append(*args, **kwargs)

    # The freshly appended layer is the final entry of the container.
    # NOTE(review): this materializes every value on each append (O(n));
    # a cheaper "last value" lookup depends on the parent's values() type.
    all_layers = list(self.values())
    newest_forward = all_layers[-1].forward

    accepts_lengths = lengths_arg_exists(newest_forward)
    self.takes_lengths.append(accepts_lengths)