Example #1
 def get_mfcc(self, waveform: torch.Tensor):
     melkwargs = {
         "win_length": self.win_length,
         "hop_length": self.hop_length,
         "n_fft": self.win_length,
         "n_mels": self.n_mels,
     }
     coefs = MFCC(self.sample_rate, n_mfcc=30, melkwargs=melkwargs)(waveform)
     means = coefs.mean(dim=2)
     stds = coefs.std(dim=2)
     return torch.cat((means, stds), dim=1)
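For context, here is a minimal, self-contained sketch of the same pattern with assumed values standing in for the `self.*` attributes: MFCC returns a `(channel, n_mfcc, time)` tensor, and pooling the mean and standard deviation over time gives one fixed-length feature vector per clip.

import torch
from torchaudio.transforms import MFCC

# Assumed settings; the method above reads these from its owning object.
sample_rate, win_length, hop_length, n_mels = 16000, 400, 160, 64
waveform = torch.randn(1, sample_rate)  # one second of dummy audio

mfcc = MFCC(
    sample_rate=sample_rate,
    n_mfcc=30,
    melkwargs={"win_length": win_length, "hop_length": hop_length,
               "n_fft": win_length, "n_mels": n_mels},
)
coefs = mfcc(waveform)                                    # (1, 30, time)
features = torch.cat((coefs.mean(dim=2), coefs.std(dim=2)), dim=1)
print(features.shape)                                     # torch.Size([1, 60])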
Example #2
 def create_spectro(self, item:AudioItem):
     if self.config.mfcc: 
         mel = MFCC(sample_rate=item.sr, n_mfcc=self.config.sg_cfg.n_mfcc, melkwargs=self.config.sg_cfg.mel_args())(item.sig)
     else:
         mel = MelSpectrogram(**(self.config.sg_cfg.mel_args()))(item.sig)
         if self.config.sg_cfg.to_db_scale: 
             mel = AmplitudeToDB(top_db=self.config.sg_cfg.top_db)(mel)
     mel = mel.detach()
     if self.config.standardize: 
         mel = standardize(mel)
     if self.config.delta: 
         mel = torch.cat([torch.stack([m,torchdelta(m),torchdelta(m, order=2)]) for m in mel]) 
     return mel
Example #3
    def __init__(self, sample_rate=16000, win_ms=25, hop_ms=10, n_freq=201, n_mels=40, n_mfcc=13, feat_list=None, eps=1e-10, **kwargs):
        super(OnlinePreprocessor, self).__init__()
        # save preprocessing arguments
        self._sample_rate = sample_rate
        self._win_ms = win_ms
        self._hop_ms = hop_ms
        self._n_freq = n_freq
        self._n_mels = n_mels
        self._n_mfcc = n_mfcc

        win = round(win_ms * sample_rate / 1000)
        hop = round(hop_ms * sample_rate / 1000)
        n_fft = (n_freq - 1) * 2
        self._win_args = {'n_fft': n_fft, 'hop_length': hop, 'win_length': win}
        self.register_buffer('_window', torch.hann_window(win))
        
        self._stft_args = {'center': True, 'pad_mode': 'reflect', 'normalized': False, 'onesided': True}
        # stft_args: same default values as torchaudio.transforms.Spectrogram & librosa.core.spectrum._spectrogram
        self._stft = partial(torch.stft, **self._win_args, **self._stft_args)
        self._magphase = partial(torchaudio.functional.magphase, power=2)
        self._melscale = MelScale(sample_rate=sample_rate, n_mels=n_mels)
        self._mfcc_trans = MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc, log_mels=True, melkwargs=self._win_args)
        self._istft = partial(torch.istft, **self._win_args, **self._stft_args)
        
        self.feat_list = feat_list
        self.register_buffer('_pseudo_wavs', torch.randn(N_SAMPLED_PSEUDO_WAV, sample_rate))
        self.eps = eps
Example #4
def build(
    pre_process_step_cfg: pre_process_step_pb2.PreProcessStep,
) -> Tuple[Union[MFCC, Standardize], Stage]:
    """Returns tuple of ``(preprocessing callable, stage)``.

    Args:
        pre_process_step_cfg: A ``PreProcessStep`` protobuf object containing
            the config for the desired preprocessing step.

    Returns:
        A tuple of ``(preprocessing callable, stage)``.

    Raises:
        :py:class:`ValueError`: On invalid configuration.
    """
    step_type = pre_process_step_cfg.WhichOneof("pre_process_step")
    if step_type == "mfcc":
        step = MFCC(
            n_mfcc=pre_process_step_cfg.mfcc.n_mfcc,
            melkwargs={
                "win_length": pre_process_step_cfg.mfcc.win_length,
                "hop_length": pre_process_step_cfg.mfcc.hop_length,
            },
        )
    elif step_type == "standardize":
        step = Standardize()
    elif step_type == "context_frames":
        step = AddContextFrames(
            n_context=pre_process_step_cfg.context_frames.n_context)
    else:
        raise ValueError(f"unknown pre_process_step '{step_type}'")

    return step, Stage(pre_process_step_cfg.stage)
Example #5
 def __init__(self):
     sample_rate = 16000
     self.mfcc = MFCC(sample_rate=sample_rate,
                      n_mfcc=40,
                      melkwargs={
                          'win_length': int(0.025 * sample_rate),
                          'hop_length': int(0.010 * sample_rate),
                          'n_fft': int(0.025 * sample_rate)
                      })
Example #6
    def __init__(self, root: str, training: bool=True, max_length: int=2500):
        self.data = []
        self.labels = []
        self.transform = MFCC(sample_rate=8000)

        for filename in os.listdir(root):
            info = re.split(r'[_.]', filename)
            if (training and int(info[2]) > 4) or (not training and int(info[2]) < 5):
                filepath = root + filename
                input_audio = self.transform(load_wav(filepath)[0])[0, :, :max_length]
                if input_audio.shape[1] < max_length:
                    input_audio = torch.cat([input_audio, torch.zeros((40, max_length - input_audio.shape[1]))], dim=1)
                self.data.append(input_audio)
                self.labels.append(int(info[0]))
Example #7
def test_dataset_transforms_multiple():
    """Note: This test may fail if not on Linux."""
    n_mfcc = 13
    files = glob.glob('lib/test/data/v1.0.10/*.wav')
    x_original, y = TorchFSDD(files)[0]
    x_trans, _ = TorchFSDD(files,
                           transforms=Compose([
                               TrimSilence(threshold=0.05),
                               MFCC(sample_rate=8e3, n_mfcc=n_mfcc)
                           ]))[0]
    assert isinstance(x_trans, torch.Tensor)
    assert x_trans.ndim == 2
    assert x_trans.shape == (n_mfcc, 9)
    assert y == 7
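For reference, torchaudio's MFCC keeps all leading dimensions, so a 1-D waveform produces a 2-D `(n_mfcc, time)` output, which is what the assertions above rely on. A quick check with dummy audio (not the FSDD files):

import torch
from torchaudio.transforms import MFCC

x = torch.randn(4000)                       # ~0.5 s of dummy 8 kHz audio
out = MFCC(sample_rate=8000, n_mfcc=13)(x)
print(out.shape)                            # (13, n_frames)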
Example #8
    def __init__(
            self,
            in_size_in_seconds=SAMPLE_LENGTH_IN_SECONDS,
            sr=SAMPLE_RATE,
            n_mfcc=30,
            n_fft=1024,
            hop_length=256,
            n_mels=128,
            rnn_dim=512,
            z_size=16):
        """
        Construct an instance of ZEncoder

        Args:
            in_size_in_seconds (float, optional): The length of the input in
                seconds. Defaults to SAMPLE_LENGTH_IN_SECONDS.
            sr (int, optional): Sample rate. Defaults to SAMPLE_RATE.
            n_mfcc (int, optional): Number of MFCCs. Defaults to 30.
            n_fft (int, optional): FFT size. Defaults to 1024.
            hop_length (int, optional): FFT hop length. Defaults to 256.
            n_mels (int, optional): Number of mel bands. Defaults to 128.
            rnn_dim (int, optional): Number of RNN states. Defaults to 512.
            z_size (int, optional): Size of latent dimension. Defaults to 16.
        """
        super().__init__()
        self.sr = sr
        self.in_size = sr * in_size_in_seconds
        self.mfcc = MFCC(
            sample_rate=sr,
            n_mfcc=n_mfcc,
            log_mels=True,
            melkwargs={
                "n_fft": n_fft,
                "n_mels": n_mels,
                "hop_length": hop_length,
                "f_min": 20.0,
                "f_max": 8000.0
            })

        self.time_dim = int(self.in_size // hop_length)
        self.norm = nn.LayerNorm((n_mfcc, self.time_dim))

        self.gru = nn.GRU(input_size=n_mfcc, hidden_size=rnn_dim)

        self.linear = nn.Linear(rnn_dim, z_size)
Example #9
    def __init__(self):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.seq_size = 20
        self.batch_size = 32
        sample_rate = 16000
        model_path = "model/model_lstm.pt"

        self.model = torch.load(model_path).to(self.device)
        self.model.eval()

        self.mfcc_ft = MFCC(sample_rate=sample_rate, n_mfcc=40,
                            melkwargs={'win_length': int(0.025 * sample_rate),
                                       'hop_length': int(0.010 * sample_rate),
                                       'n_fft': int(0.025 * sample_rate)}).to(self.device)
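The millisecond-to-sample arithmetic used in the two constructors above works out as follows (a quick check, not part of the original code):

sample_rate = 16000
win_length = int(0.025 * sample_rate)   # 400 samples = 25 ms window
hop_length = int(0.010 * sample_rate)   # 160 samples = 10 ms hop
n_fft = win_length                      # FFT size equals the window length here
print(win_length, hop_length, n_fft)    # 400 160 400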
Example #10
    def open(self, item) -> AudioItem:
        p = Path(item)
        if self.path is not None and str(self.path) not in str(item): p = self.path/item
        if not p.exists(): 
            raise FileNotFoundError(f"Neither '{item}' nor '{p}' could be found")
        if not str(p).lower().endswith(AUDIO_EXTENSIONS): raise Exception("Invalid audio file")

        cfg = self.config
        if cfg.use_spectro:
            cache_dir = self.path / cfg.cache_dir
            folder = md5(str(asdict(cfg))+str(asdict(cfg.sg_cfg)))
            fname = f"{md5(str(p))}-{p.name}.pt"
            image_path = cache_dir/(f"{folder}/{fname}")
            if cfg.cache and not cfg.force_cache and image_path.exists():
                mel = torch.load(image_path).squeeze()
                start, end = None, None
                if cfg.duration and cfg._processed:
                    mel, start, end = tfm_crop_time(mel, cfg._sr, cfg.duration, cfg.sg_cfg.hop)
                return AudioItem(spectro=mel, path=item, max_to_pad=cfg.max_to_pad, start=start, end=end)

        signal, samplerate = torchaudio.load(str(p))
        if(cfg._sr is not None and samplerate != cfg._sr):
            raise ValueError(f'''Multiple sample rates detected. Sample rate {samplerate} of file {str(p)} 
                                does not match config sample rate {cfg._sr} 
                                this means your dataset has multiple different sample rates, 
                                please choose one and set resample_to to that value''')

        if cfg.max_to_pad or cfg.segment_size:
            pad_len = cfg.max_to_pad if cfg.max_to_pad is not None else cfg.segment_size
            signal = PadTrim(max_len=int(pad_len/1000*samplerate))(signal)

        mel = None
        start, end = None, None  # ensure these exist even when cfg.use_spectro is False
        if cfg.use_spectro:
            if cfg.mfcc: mel = MFCC(sr=samplerate, n_mfcc=cfg.sg_cfg.n_mfcc, melkwargs=asdict(cfg.sg_cfg))(signal.reshape(1,-1))
            else:
                mel = MelSpectrogram(**(cfg.sg_cfg.mel_args()))(signal.reshape(1, -1))
                if cfg.sg_cfg.to_db_scale: mel = SpectrogramToDB(top_db=cfg.sg_cfg.top_db)(mel)
            mel = mel.squeeze().permute(1, 0)
            if cfg.standardize: mel = standardize(mel)
            if cfg.delta: mel = torch.stack([mel, torchdelta(mel), torchdelta(mel, order=2)]) 
            else: mel = mel.expand(3,-1,-1)
            if cfg.cache:
                os.makedirs(image_path.parent, exist_ok=True)
                torch.save(mel, image_path)
            if cfg.duration and cfg._processed: 
                mel, start, end = tfm_crop_time(mel, cfg._sr, cfg.duration, cfg.sg_cfg.hop)
        return AudioItem(sig=signal.squeeze(), sr=samplerate, spectro=mel, path=item, start=start, end=end)
Example #11
    def __init__(self,
                 args,
                 pad_idx,
                 cls_idx,
                 sep_idx,
                 bert_args,
                 device='cpu'):
        self.device = device
        self.only_audio = args.only_audio
        self.only_text = args.only_text
        self.use_both = not (self.only_audio or self.only_text)

        # audio properties
        self.max_len_audio = args.max_len_audio
        self.n_mfcc = args.n_mfcc
        self.n_fft_size = args.n_fft_size
        self.sample_rate = args.sample_rate
        self.resample_rate = args.resample_rate

        # text properties
        self.max_len_bert = bert_args.max_len
        self.pad_idx = pad_idx
        self.cls_idx = cls_idx
        self.sep_idx = sep_idx

        # audio feature extractor
        if not self.only_text:
            self.audio2mfcc = MFCC(sample_rate=self.resample_rate,
                                   n_mfcc=self.n_mfcc,
                                   log_mels=False,
                                   melkwargs={
                                       'n_fft': self.n_fft_size
                                   }).to(self.device)

        # text feature extractor
        if not self.only_audio:
            self.bert = load_bert(args.bert_path, self.device)
            self.bert.eval()
            self.bert.zero_grad()
Example #12
def build_transform(feature_type, feature_size, n_fft=512, win_length=400,
                    hop_length=200, delta=False, cmvn=False, downsample=1,
                    T_mask=0, T_num_mask=0, F_mask=0, F_num_mask=0,
                    pad_to_divisible=True):
    feature_args = {
        'n_fft': n_fft,
        'win_length': win_length,
        'hop_length': hop_length,
        # 'f_min': 20,
        # 'f_max': 5800,
    }
    transform = []
    input_size = feature_size
    if feature_type == 'mfcc':
        transform.append(MFCC(
            n_mfcc=feature_size, log_mels=True, melkwargs=feature_args))
    if feature_type == 'melspec':
        transform.append(MelSpectrogram(
            n_mels=feature_size, **feature_args))
    if feature_type == 'logfbank':
        transform.append(FilterbankFeatures(
            n_filt=feature_size, **feature_args))
    if delta:
        transform.append(CatDeltas())
        input_size = input_size * 3
    # if cmvn:
    #     transform.append(CMVN())
    if downsample > 1:
        transform.append(Downsample(downsample, pad_to_divisible))
        input_size = input_size * downsample
    transform_test = torch.nn.Sequential(*transform)

    if T_mask > 0 and T_num_mask > 0:
        transform.append(TimeMasking(T_mask, T_num_mask))
    if F_mask > 0 and F_num_mask > 0:
        transform.append(FrequencyMasking(F_mask, F_num_mask))
    transform_train = torch.nn.Sequential(*transform)

    return transform_train, transform_test, input_size
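A hypothetical call, assuming the project-specific transforms used above (`FilterbankFeatures`, `CatDeltas`, `Downsample`, `TimeMasking`, `FrequencyMasking`) are importable alongside `build_transform`:

transform_train, transform_test, input_size = build_transform(
    'mfcc', feature_size=40, delta=True,
    T_mask=30, T_num_mask=2, F_mask=10, F_num_mask=2)
# input_size == 120 here (40 MFCCs stacked with their first and second deltas);
# only transform_train contains the time/frequency masking augmentations.
# features = transform_train(waveform)  # waveform: (channel, time) float tensor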
Example #13
    def __init__(self,
                 root: str,
                 training: bool = True,
                 frequency: int = 16000,
                 max_length: int = 280,
                 transform=None,
                 return_length: bool = False):
        self.data = []
        self.return_length = return_length
        if transform is None:
            self.transform = MFCC(frequency)
        else:
            self.transform = transform

        self.training = training
        self.filenames = []
        self.max_length = max_length
        if frequency != 16000:
            self.resampler = Resample(orig_freq=16000, new_freq=frequency)

        if training:
            df_labels = pd.read_csv(root + "train_label.csv")
            root = root + "Train/"
            self.labels = []
        else:
            root = root + "Public_Test/"

        for filename in os.listdir(root):
            if filename.endswith(".wav"):
                self.filenames.append(filename)
                input_audio, sample_rate = load_wav(root + filename)
                if frequency != 16000:
                    input_audio = self.resampler(input_audio)

                self.data.append(input_audio)
                if training:
                    self.labels.append(
                        df_labels.loc[df_labels["File"] == filename,
                                      "Label"].values.item())
Example #14
 def create_spectro(self, item:AudioItem):
     if self.config.mfcc: 
         mel = MFCC(sample_rate=item.sr, n_mfcc=self.config.sg_cfg.n_mfcc, melkwargs=self.config.sg_cfg.mel_args())(item.sig)
     else:
         if self.config.sg_cfg.custom_spectro is not None:
             mel = self.config.sg_cfg.custom_spectro(item.sig)
         else:
             if self.config.sg_cfg.n_mels > 0:
                 c = self.config.sg_cfg
                 mel = librosa.feature.melspectrogram(y=np.array(item.sig[0,:]), sr=item.sr, fmax=c.f_max, fmin=c.f_min, **c.mel_args())
                 mel = torch.from_numpy(mel)
                 mel.unsqueeze_(0)
             else:
                 mel = Spectrogram(**(self.config.sg_cfg.spectro_args()))(item.sig)
         if self.config.sg_cfg.to_db_scale: 
             mel = AmplitudeToDB(top_db=self.config.sg_cfg.top_db)(mel)
     mel = mel.detach()
     if self.config.standardize: 
         mel = standardize(mel)
     if self.config.delta: 
         mel = torch.cat([torch.stack([m,torchdelta(m),torchdelta(m, order=2)]) for m in mel]) 
     return mel
Example #15
    def __init__(
        self,
        base_channels: int,
        out_channels: int = 64,
        input_ulaw: bool = True,
        input_rate: int = 16000,
        mfcc_rate: int = 100,
        version: int = 1,
    ):
        super().__init__()
        self.base_channels = base_channels
        self.out_channels = out_channels
        self.input_ulaw = input_ulaw
        self.input_rate = input_rate
        self.mfcc_rate = mfcc_rate
        self.mid_channels = base_channels * 12
        self.version = version

        assert mfcc_rate % 2 == 0, "must be able to downsample MFCCs once"
        assert input_rate % mfcc_rate == 0, "must evenly downsample input sequences"

        from torchaudio.transforms import MFCC

        if version == 2:
            n_fft = round(400 * input_rate / 16000)
        else:
            n_fft = (input_rate // self.mfcc_rate) * 2
        self.mfcc = MFCC(
            sample_rate=input_rate,
            n_mfcc=13,
            log_mels=version == 1,
            melkwargs=dict(
                n_fft=n_fft,
                hop_length=input_rate // self.mfcc_rate,
                n_mels=40 if version == 1 else 80,
                normalized=version == 2,
            ),
        )

        self.blocks = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv1d(13 * 3, self.mid_channels, 3, padding=1),
                    nn.GELU(),
                ),
                ResConv(self.mid_channels, self.mid_channels, 3, padding=1),
                nn.Sequential(
                    nn.Conv1d(
                        self.mid_channels, self.mid_channels, 4, stride=2, padding=1
                    ),
                    nn.GELU(),
                ),
                *[
                    ResConv(self.mid_channels, self.mid_channels, 3, padding=1)
                    for _ in range(2)
                ],
                *[ResConv(self.mid_channels, self.mid_channels, 1) for _ in range(4)],
                nn.Conv1d(self.mid_channels, self.out_channels, 1),
            ]
        )
        # Zero output so that by default we don't affect the
        # behavior of downstream models.
        for p in self.blocks[-1].parameters():
            with torch.no_grad():
                p.zero_()
Example #16
    def _get_mfcc(self, arr, sample_rate=22000):
        # `window` is assumed to be defined elsewhere in the original module;
        # it is used here as the number of MFCC coefficients (n_mfcc).
        mfcc_tensor = MFCC(sample_rate, n_mfcc=window)
        return mfcc_tensor(arr)
Example #17
from torchaudio.transforms import MFCC
from torchvision.transforms import Compose
from nntoolbox.metrics import *
from torch.optim import Adam
from src.utils import *
from src.models import *


batch_size = 128
frequency = 16000
# lr = 5e-4


transform_train = Compose(
    [
        CropCenter(40000),
        Noise(),
        MFCC(sample_rate=frequency, n_mfcc=30),
        TimePad(216)
    ]
)

transform_val = Compose(
    [
        MFCC(sample_rate=frequency, n_mfcc=30),
        TimePad(216)
    ]
)

train_val_dataset = ERCAoTData("data/", True)
train_size = int(0.8 * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size
train_data, val_data = random_split_before_transform(
Example #18
 def __init__(self, sr, n_mfcc=40):
     self.sr = sr
     self.n_mfcc = n_mfcc
     self._mfcc = MFCC(sr, n_mfcc=n_mfcc, log_mels=True)
Example #19
    def open(self, item) -> AudioItem:
        p = Path(item)
        if self.path is not None and str(self.path) not in str(item):
            p = self.path / item
        if not p.exists():
            raise FileNotFoundError(
                f"Neither '{item}' nor '{p}' could be found")
        if not str(p).lower().endswith(AUDIO_EXTENSIONS):
            raise Exception("Invalid audio file")

        cfg = self.config
        if cfg.use_spectro:
            folder = md5(str(asdict(cfg)) + str(asdict(cfg.sg_cfg)))
            fname = f"{md5(str(p))}-{p.name}.pt"
            image_path = cfg.cache_dir / (f"{folder}/{fname}")
            if cfg.cache and not cfg.force_cache and image_path.exists():
                mel = torch.load(image_path).squeeze()
                start, end = None, None
                if cfg.duration and cfg._processed:
                    mel, start, end = tfm_crop_time(mel, cfg._sr, cfg.duration,
                                                    cfg.sg_cfg.hop,
                                                    cfg.pad_mode)
                return AudioItem(spectro=mel,
                                 path=item,
                                 max_to_pad=cfg.max_to_pad,
                                 start=start,
                                 end=end)

        sig, sr = torchaudio.load(str(p))
        if (cfg._sr is not None and sr != cfg._sr):
            raise ValueError(
                f'''Multiple sample rates detected. Sample rate {sr} of file {str(p)} 
                                does not match config sample rate {cfg._sr} 
                                this means your dataset has multiple different sample rates, 
                                please choose one and set resample_to to that value'''
            )
        if (sig.shape[0] > 1):
            if not cfg.downmix:
                warnings.warn(
                    f'''Audio file {p} has {sig.shape[0]} channels, automatically downmixing to mono, 
                                set AudioConfig.downmix=True to remove warnings'''
                )
            sig = DownmixMono(channels_first=True)(sig)
        if cfg.max_to_pad or cfg.segment_size:
            pad_len = cfg.max_to_pad if cfg.max_to_pad is not None else cfg.segment_size
            sig = tfm_padtrim_signal(sig,
                                     int(pad_len / 1000 * sr),
                                     pad_mode="zeros")

        mel = None
        start, end = None, None  # ensure these exist even when cfg.use_spectro is False
        if cfg.use_spectro:
            if cfg.mfcc:
                mel = MFCC(sr=sr,
                           n_mfcc=cfg.sg_cfg.n_mfcc,
                           melkwargs=cfg.sg_cfg.mel_args())(sig)
            else:
                mel = MelSpectrogram(**(cfg.sg_cfg.mel_args()))(sig)
                if cfg.sg_cfg.to_db_scale:
                    mel = SpectrogramToDB(top_db=cfg.sg_cfg.top_db)(mel)
            mel = mel.squeeze().permute(1, 0).flip(0)
            if cfg.standardize: mel = standardize(mel)
            if cfg.delta:
                mel = torch.stack(
                    [mel, torchdelta(mel),
                     torchdelta(mel, order=2)])
            else:
                mel = mel.expand(3, -1, -1)
            if cfg.cache:
                os.makedirs(image_path.parent, exist_ok=True)
                torch.save(mel, image_path)
                _record_cache_contents(cfg, [image_path])
            if cfg.duration and cfg._processed:
                mel, start, end = tfm_crop_time(mel, cfg._sr, cfg.duration,
                                                cfg.sg_cfg.hop, cfg.pad_mode)
        return AudioItem(sig=sig.squeeze(),
                         sr=sr,
                         spectro=mel,
                         path=item,
                         start=start,
                         end=end)
Example #20
 def __init__(self, sr: int, sg_cfg: SpectrogramConfig):
     self.sg_cfg = sg_cfg
     self.spec = Spectrogram(**sg_cfg.spec_args)
     self.to_mel = MelScale(sample_rate=sr, **sg_cfg.mel_args)
     self.mfcc = MFCC(sample_rate=sr, **sg_cfg.mfcc_args)
     self.to_db = AmplitudeToDB(top_db=sg_cfg.top_db)
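The building blocks above can be chained by hand; a minimal sketch with assumed settings standing in for `sg_cfg`:

import torch
from torchaudio.transforms import Spectrogram, MelScale, AmplitudeToDB

sr, n_fft, hop, n_mels, top_db = 16000, 1024, 256, 64, 80.0
sig = torch.randn(1, sr)                                   # one second of dummy audio

spec = Spectrogram(n_fft=n_fft, hop_length=hop)(sig)       # power spectrogram
mel = MelScale(n_mels=n_mels, sample_rate=sr, n_stft=n_fft // 2 + 1)(spec)
mel_db = AmplitudeToDB(top_db=top_db)(mel)                 # log-mel in dB
print(spec.shape, mel.shape, mel_db.shape)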
Example #21
from torchaudio.transforms import MFCC
from torchvision.transforms import Compose
from nntoolbox.learner import SupervisedLearner
from nntoolbox.callbacks import *
from nntoolbox.metrics import *
from torch.optim import Adam
from src.utils import *
from src.models import *

batch_size = 128
frequency = 16000
lr = 0.001

transform_train = Compose(
    [RandomCropCenter(30000),
     MFCC(sample_rate=frequency),
     TimePad(280)])

transform_val = Compose([MFCC(sample_rate=frequency), TimePad(280)])

train_val_dataset = ERCDataRaw("data/", True)
train_size = int(0.8 * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size
train_data, val_data = random_split_before_transform(
    train_val_dataset,
    lengths=[train_size, val_size],
    transforms=[transform_train, transform_val])

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size)
Example #22
from torchaudio.transforms import MFCC
from torchvision.transforms import Compose
from nntoolbox.learner import SupervisedLearner
from nntoolbox.callbacks import *
from nntoolbox.metrics import *
from torch.optim import Adam
from src.utils import *
from src.models import *

print("Running Nhat's script")

batch_size = 128
frequency = 16000
lr = 0.001

transform_train = Compose([
    # RandomCropCenter(45000),
    MFCC(sample_rate=frequency, n_mfcc=30),
    TimePad(216)
])

transform_val = Compose([MFCC(sample_rate=frequency, n_mfcc=30), TimePad(216)])

for i in range(2):
    print('===== Run {} ===='.format(i))

    model = ICModel()
    optimizer = Adam(model.parameters(), lr=lr)

    train_val_dataset = ERCDataRaw("data/", True)
    train_size = int(0.8 * len(train_val_dataset))
    val_size = len(train_val_dataset) - train_size
    train_data, val_data = stratified_random_split(