Пример #1
0
 def __init__(self, dataset_path, mode, spec_augment=False):
     """Load a pickled dataset and configure the optional SpecAugment transform.

     Args:
       dataset_path: path to a pickle file; the unpickled object is stored as
           ``self.dataset``. Only load files you trust -- unpickling can
           execute arbitrary code.
       mode: stored unchanged as ``self.mode``.
       spec_augment: if True, ``self.transform`` applies
           ``spec_augment_pytorch.spec_augment`` with
           ``frequency_masking_para=60`` and ``time_masking_para=41``;
           otherwise ``self.transform`` returns its input unchanged.
     """
     # Context manager closes the file even if unpickling raises; the
     # original left the handle open until garbage collection.
     with open(dataset_path, 'rb') as dataset_file:
         self.dataset = pickle.load(dataset_file)
     self.spec_augment = spec_augment
     self.mode = mode
     if spec_augment:
         self.transform = lambda x: spec_augment_pytorch.spec_augment(
             x, frequency_masking_para=60, time_masking_para=41)
     else:
         self.transform = lambda x: x
Пример #2
0
 def transform_feature(self, feature):
     """Apply SpecAugment masking to a dB-scaled feature array.

     The array is converted from dB to power and given a leading axis of
     size 1. It is wrapped as a torch tensor and masked by ``spec_augment``
     using ``self.freq_mask`` / ``self.time_mask`` as the masking parameters.
     Finally it is reshaped back to its original shape and converted back
     to dB.
     """
     original_shape = feature.shape
     # Masking happens on power values, not on the dB values passed in.
     power = mylibrosa.db_to_power(feature)
     batched = torch.from_numpy(power.reshape(-1, *original_shape))
     masked = spec_augment(
         batched,
         frequency_masking_para=self.freq_mask,
         time_masking_para=self.time_mask,
     )
     restored = masked.numpy().reshape(original_shape)
     return mylibrosa.power_to_db(restored)
Пример #3
0
def _first_emotion_for(labels_df, wav_file_name):
    """Return the emotion of the first label row matching a wav file, or None.

    A row matches when its ``wav_file`` value contains the file name without
    its extension. The test is a plain substring test, not a regex, and
    missing values never match.
    """
    stem = os.path.splitext(wav_file_name)[0]
    is_match = labels_df['wav_file'].str.contains(stem, regex=False, na=False)
    matches = labels_df.loc[is_match, 'emotion']
    return matches.iloc[0] if len(matches) else None


def specAugmented_melfeature_extraction(session, audio_dir):
    """Extract SpecAugmented mel features for one IEMOCAP session and save them.

    Each file in ``{audio_dir}Session{session}/wav/`` goes through these steps:
    - It is loaded at 44100 Hz.
    - A 256-bin mel spectrogram (hop_length=128, fmax=8000) is computed.
    - The spectrogram is passed through ``spec_augment_pytorch.spec_augment``
      and averaged over axis 0.

    These vectors become DataFrame rows; missing entries are filled with 0.
    Each row is labelled with the emotion of the matching entry in
    ``df_iemocap_{session}.csv``. Rows labelled neu/ang/hap/sad/exc are
    written to
    ``input/mel_features/SpecAugmented_features/audio_features_{session}.csv``.

    Args:
      session: session identifier used in the label CSV name, the wav
          directory and the output file name.
      audio_dir: string prefix concatenated directly with ``Session{n}/wav/``
          (so it should end with a path separator).
    """
    from SpecAugment import spec_augment_pytorch
    labels_df = pd.read_csv(config.DF_IEMOCAP +
                            'df_iemocap_{}.csv'.format(session))
    sr = 44100  # sample rate
    wav_file_path = '{}Session{}/wav/'.format(audio_dir, session)
    orig_wav_files = os.listdir(wav_file_path)

    # Collect rows in a plain list; growing a DataFrame row by row with
    # .loc is slow (quadratic in the number of files).
    features = []
    for orig_wav_file in tqdm(orig_wav_files):
        orig_wav_vector, sample_rate = librosa.load(
            wav_file_path + orig_wav_file, sr=sr)
        melspect = librosa.feature.melspectrogram(y=orig_wav_vector,
                                                  sr=sample_rate,
                                                  n_mels=256,
                                                  hop_length=128,
                                                  fmax=8000)
        warped_masked_spectrogram = np.mean(
            spec_augment_pytorch.spec_augment(mel_spectrogram=melspect),
            axis=0)
        features.append(warped_masked_spectrogram)
    # Rows of unequal length leave missing (NaN) cells; replace them with 0.
    audio_features = pd.DataFrame(features).fillna(0)

    # Exactly one label per file, in the same order as the feature rows.
    # Appending one label per *matching CSV row* (the previous approach)
    # shifted every later label whenever a file matched zero or several rows.
    audio_features['emotions'] = [
        _first_emotion_for(labels_df, orig_wav_file)
        for orig_wav_file in tqdm(orig_wav_files)
    ]
    audio_feature_subset = audio_features[audio_features['emotions'].isin(
        ['neu', 'ang', 'hap', 'sad', 'exc'])]

    output_dir = 'input/mel_features/SpecAugmented_features'
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)
    audio_feature_subset.to_csv(
        os.path.join(output_dir, 'audio_features_{}.csv'.format(session)),
        index=False)
Пример #4
0
def _logmel_dir_name(prefix, frames_per_second, mel_bins):
    """Return the feature-configuration directory name shared by all paths."""
    return '{}logmel_{}frames_{}melbins'.format(prefix, frames_per_second,
                                                mel_bins)


def _move_batch_to_gpu(batch_data_dict, feature_key, cuda):
    """Pass the ``feature_key`` and 'target' entries through move_data_to_gpu, in place."""
    for key in batch_data_dict.keys():
        if key in [feature_key, 'target']:
            batch_data_dict[key] = move_data_to_gpu(batch_data_dict[key],
                                                    cuda)


# Per-channel suffixes, in the order the data generator yields batch dicts.
# The same suffix appears in four places for each channel:
# - the feature/scalar directory ('features_left', 'scalars_left');
# - the DataGenerator keyword (feature_hdf5_path_left, scalar_left);
# - the batch feature key ('feature_left');
# - the model input keyword (data_left).
_CHANNEL_SUFFIXES = ('', '_left', '_right', '_side', '_harmonic',
                     '_percussive')

# Keyword names of the model inputs, in channel order.
_MODEL_INPUT_NAMES = tuple('data' + suffix for suffix in _CHANNEL_SUFFIXES)


def train(args):
    '''Training. Model will be saved after several iterations.

    Every batch is used for two optimizer steps: one on mixup data
    (alpha=0.3) and one on SpecAugment-masked data. Evaluation runs every
    200 iterations, including iteration 0. At every 200th iteration after
    the first, a checkpoint is saved and the learning rate is multiplied by
    0.92. Training stops after iteration 25000.

    Args:
      args: parsed command-line namespace with the attributes:
        dataset_dir: string, directory of dataset
        workspace: string, directory of workspace
        subtask: 'a' | 'b' | 'c', corresponds to 3 subtasks in DCASE2019 Task1
        data_type: 'development' | 'evaluation'
        holdout_fold: '1' | 'none', set 1 for development and none for training
            on all data without validation
        model_type: string, e.g. 'Cnn_9layers_AvgPooling'
        batch_size: int
        cuda: bool
        mini_data: bool, set True for debugging on a small part of data
        filename: string, sub-directory used under checkpoints/statistics/logs
        mode: string, sub-directory used under logs

    Raises:
      ValueError: if args.subtask is not 'a', 'b' or 'c'.
    '''

    # Arguments & parameters
    dataset_dir = args.dataset_dir
    workspace = args.workspace
    subtask = args.subtask
    data_type = args.data_type
    holdout_fold = args.holdout_fold
    model_type = args.model_type
    batch_size = args.batch_size
    cuda = args.cuda and torch.cuda.is_available()
    mini_data = args.mini_data
    filename = args.filename

    # Fail fast: any other value used to leave `model`/`loss_func` unbound
    # and crash much later with an UnboundLocalError.
    if subtask not in ('a', 'b', 'c'):
        raise ValueError(
            "subtask must be 'a', 'b' or 'c', got {!r}".format(subtask))

    mel_bins = config.mel_bins
    frames_per_second = config.frames_per_second
    reduce_lr = True

    sources_to_evaluate = get_sources(subtask)
    in_domain_classes_num = len(config.labels) - 1

    # Paths
    prefix = 'minidata_' if mini_data else ''
    sub_dir = get_subdir(subtask, data_type)
    logmel_dir = _logmel_dir_name(prefix, frames_per_second, mel_bins)
    h5_name = '{}.h5'.format(sub_dir)

    train_csv = os.path.join(dataset_dir, sub_dir, 'evaluation_setup',
                             'fold1_train.csv')
    validate_csv = os.path.join(dataset_dir, sub_dir, 'evaluation_setup',
                                'fold1_evaluate.csv')

    feature_hdf5_paths = [
        os.path.join(workspace, 'features' + suffix, logmel_dir, h5_name)
        for suffix in _CHANNEL_SUFFIXES
    ]
    scalar_paths = [
        os.path.join(workspace, 'scalars' + suffix, logmel_dir, h5_name)
        for suffix in _CHANNEL_SUFFIXES
    ]

    # Directory components shared by checkpoints, statistics and logs.
    run_dirs = (logmel_dir, sub_dir, 'holdout_fold={}'.format(holdout_fold),
                model_type)

    checkpoints_dir = os.path.join(workspace, 'checkpoints', filename,
                                   *run_dirs)
    create_folder(checkpoints_dir)

    validate_statistics_path = os.path.join(
        os.path.join(workspace, 'statistics', filename, *run_dirs),
        'validate_statistics.pickle')
    create_folder(os.path.dirname(validate_statistics_path))

    logs_dir = os.path.join(workspace, 'logs', filename, args.mode, *run_dirs)
    create_logging(logs_dir, 'w')
    logging.info(args)

    # Load scalars (one per channel, same order as _CHANNEL_SUFFIXES)
    scalars = [load_scalar(path) for path in scalar_paths]

    # Model
    # NOTE(review): eval() executes model_type as Python code. This is only
    # acceptable because the value comes from the operator's command line;
    # a whitelist of model classes would be safer.
    Model = eval(model_type)

    if subtask in ('a', 'b'):
        model = Model(in_domain_classes_num, activation='logsoftmax')
        loss_func = nll_loss
    else:  # subtask == 'c'; other values were rejected above
        model = Model(in_domain_classes_num, activation='sigmoid')
        loss_func = F.binary_cross_entropy

    if cuda:
        model.cuda()

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=1e-3,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    # Data generator: keyword names are the base name plus the channel
    # suffix, e.g. feature_hdf5_path_left=..., scalar_left=...
    generator_kwargs = {}
    for suffix, feature_path, scalar in zip(_CHANNEL_SUFFIXES,
                                            feature_hdf5_paths, scalars):
        generator_kwargs['feature_hdf5_path' + suffix] = feature_path
        generator_kwargs['scalar' + suffix] = scalar
    data_generator = DataGenerator(train_csv=train_csv,
                                   validate_csv=validate_csv,
                                   batch_size=batch_size,
                                   **generator_kwargs)

    # Evaluator
    evaluator = Evaluator(model=model,
                          data_generator=data_generator,
                          subtask=subtask,
                          cuda=cuda)

    # Statistics
    validate_statistics_container = StatisticsContainer(
        validate_statistics_path)

    train_bgn_time = time.time()
    iteration = 0

    # Train on mini batches; the generator yields one batch dict per channel.
    for (batch_data_dict, batch_data_dict_left, batch_data_dict_right,
         batch_data_dict_side, batch_data_dict_harmonic,
         batch_data_dict_percussive) in data_generator.generate_train():

        channel_batches = (batch_data_dict, batch_data_dict_left,
                           batch_data_dict_right, batch_data_dict_side,
                           batch_data_dict_harmonic,
                           batch_data_dict_percussive)

        # Evaluate
        if iteration % 200 == 0:
            logging.info('------------------------------------')
            logging.info('Iteration: {}'.format(iteration))

            train_fin_time = time.time()

            for source in sources_to_evaluate:
                # Return value is not used; the call is kept in case
                # Evaluator.evaluate logs or has other side effects
                # (TODO confirm and drop if it does not).
                evaluator.evaluate(data_type='train',
                                   source=source,
                                   max_iteration=None,
                                   verbose=False)

            for source in sources_to_evaluate:
                validate_statistics = evaluator.evaluate(data_type='validate',
                                                         source=source,
                                                         max_iteration=None,
                                                         verbose=False)
                validate_statistics_container.append_and_dump(
                    iteration, source, validate_statistics)

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info('Train time: {:.3f} s, validate time: {:.3f} s'
                         ''.format(train_time, validate_time))

            train_bgn_time = time.time()

        # Save model, then reduce learning rate
        if iteration % 200 == 0 and iteration > 0:
            checkpoint = {
                'iteration': iteration,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))
            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

            if reduce_lr:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.92

        # Move data to GPU
        for suffix, batch in zip(_CHANNEL_SUFFIXES, channel_batches):
            _move_batch_to_gpu(batch, 'feature' + suffix, cuda)

        # Step 1: train on mixup data
        model.train()
        (data, data_left, data_right, data_side, data_harmonic,
         data_percussive, target_a, target_b, lam) = mixup_data(
             x1=batch_data_dict['feature'],
             x2=batch_data_dict_left['feature_left'],
             x3=batch_data_dict_right['feature_right'],
             x4=batch_data_dict_side['feature_side'],
             x5=batch_data_dict_harmonic['feature_harmonic'],
             x6=batch_data_dict_percussive['feature_percussive'],
             y=batch_data_dict['target'],
             alpha=0.3)
        batch_output = model(data=data,
                             data_left=data_left,
                             data_right=data_right,
                             data_side=data_side,
                             data_harmonic=data_harmonic,
                             data_percussive=data_percussive)
        loss = mixup_criterion(loss_func, batch_output, target_a, target_b,
                               lam)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Step 2: train on SpecAugment data (channels masked in the same
        # order as before, so the random masks are drawn in the same order)
        model.train()
        augmented = [
            spec_augment_pytorch.spec_augment(batch['feature' + suffix],
                                              using_frequency_masking=True,
                                              using_time_masking=True)
            for suffix, batch in zip(_CHANNEL_SUFFIXES, channel_batches)
        ]
        batch_output = model(**dict(zip(_MODEL_INPUT_NAMES, augmented)))
        loss = loss_func(batch_output, batch_data_dict['target'])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Stop learning
        if iteration == 25000:
            break

        iteration += 1
Пример #5
0
    # Demo: compute a mel spectrogram for one audio file and visualise it raw,
    # after the TensorFlow SpecAugment, and after the PyTorch SpecAugment.
    # `audio_path` is defined outside this snippet. librosa.load is called
    # without `sr=`, so librosa's default sample rate is used.
    audio, sampling_rate = librosa.load(audio_path)
    mel_spectrogram = librosa.feature.melspectrogram(y=audio,
                                                     sr=sampling_rate,
                                                     n_mels=256,
                                                     hop_length=128,
                                                     fmax=8000)

    # Show the raw mel spectrogram
    spec_augment_tensorflow.visualization_spectrogram(
        mel_spectrogram=mel_spectrogram, title="Raw Mel Spectrogram")

    # Calculate SpecAugment, TensorFlow version (library default parameters)
    warped_masked_spectrogram = spec_augment_tensorflow.spec_augment(
        mel_spectrogram=mel_spectrogram)
    # print(warped_masked_spectrogram)

    # Show the time-warped & masked spectrogram
    spec_augment_tensorflow.visualization_spectrogram(
        mel_spectrogram=warped_masked_spectrogram,
        title="tensorflow Warped & Masked Mel Spectrogram")

    # Calculate SpecAugment, PyTorch version (library default parameters).
    # The same librosa output is passed in. NOTE(review): confirm the
    # PyTorch implementation accepts this array rather than requiring a
    # torch tensor.
    warped_masked_spectrogram = spec_augment_pytorch.spec_augment(
        mel_spectrogram=mel_spectrogram)
    print(warped_masked_spectrogram)

    # Show the time-warped & masked spectrogram; the TensorFlow module's
    # plotting helper is reused for the PyTorch result.
    spec_augment_tensorflow.visualization_spectrogram(
        mel_spectrogram=warped_masked_spectrogram,
        title="pytorch Warped & Masked Mel Spectrogram")
Пример #6
0
def train(args, i):
    '''Training on fold ``i``. Model will be saved after several iterations.

    Every iteration trains once per audio slice: ``config.audio_num`` slices
    taken along axis 1 of ``batch_data_dict['feature']``. The training
    regime cycles with ``iteration % 3``:
    - 0: mixup (alpha=0.2);
    - 1: original features;
    - 2: SpecAugment-masked features.

    Evaluation runs every 100 iterations from iteration 1000 on. Every 100
    iterations after the first, a checkpoint is saved and the learning rate
    is multiplied by 0.9. Training stops after iteration 4500.

    Args:
      args: parsed command-line namespace with the attributes:
        workspace: string, directory of workspace
        holdout_fold: '1' | 'none', set 1 for development and none for training
            on all data without validation
        model_type: string, e.g. 'Cnn_9layers_AvgPooling'
        batch_size: int
        cuda: bool
        mini_data: bool, set True for debugging on a small part of data
        filename: string, sub-directory used under checkpoints/statistics/logs
        mode: string, sub-directory used under logs
      i: fold number; selects ``fold{i}_train.csv`` and ``fold{i}_test.csv``
          in ``sys.path[0]``.
    '''

    # Arguments & parameters
    workspace = args.workspace
    holdout_fold = args.holdout_fold
    model_type = args.model_type
    batch_size = args.batch_size
    cuda = args.cuda and torch.cuda.is_available()
    mini_data = args.mini_data
    filename = args.filename
    audio_num = config.audio_num
    mel_bins = config.mel_bins
    frames_per_second = config.frames_per_second
    reduce_lr = True
    in_domain_classes_num = len(config.labels)

    # Paths
    prefix = 'minidata_' if mini_data else ''

    train_csv = os.path.join(sys.path[0], 'fold' + str(i) + '_train.csv')
    validate_csv = os.path.join(sys.path[0], 'fold' + str(i) + '_test.csv')

    feature_hdf5_path = os.path.join(
        workspace, 'features',
        '{}logmel_{}frames_{}melbins.h5'.format(prefix, frames_per_second,
                                                mel_bins))

    # NOTE(review): two path problems are left unchanged because other code
    # may load checkpoints from these paths.
    # - The checkpoint directory name ends in '.h5' (copied from the feature
    #   file name), while the statistics/logs directories do not.
    # - None of these paths include the fold number ``i``, so calling train()
    #   for several folds with the same args overwrites earlier checkpoints
    #   and statistics.
    checkpoints_dir = os.path.join(
        workspace, 'checkpoints', filename,
        '{}logmel_{}frames_{}melbins.h5'.format(prefix, frames_per_second,
                                                mel_bins),
        'holdout_fold={}'.format(holdout_fold), model_type)
    create_folder(checkpoints_dir)

    validate_statistics_path = os.path.join(
        workspace, 'statistics', filename,
        '{}logmel_{}frames_{}melbins'.format(prefix, frames_per_second,
                                             mel_bins),
        'holdout_fold={}'.format(holdout_fold), model_type,
        'validate_statistics.pickle')
    create_folder(os.path.dirname(validate_statistics_path))

    logs_dir = os.path.join(
        workspace, 'logs', filename, args.mode,
        '{}logmel_{}frames_{}melbins'.format(prefix, frames_per_second,
                                             mel_bins),
        'holdout_fold={}'.format(holdout_fold), model_type)
    create_logging(logs_dir, 'w')
    logging.info(args)

    if cuda:
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')

    # Model
    # NOTE(review): eval() executes model_type as Python code. This is only
    # acceptable because the value comes from the operator's command line;
    # a whitelist of model classes would be safer.
    Model = eval(model_type)

    model = Model(in_domain_classes_num, activation='logsoftmax')
    loss_func = nll_loss

    if cuda:
        model.cuda()

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=1e-4,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    # Data generator
    data_generator = DataGenerator(feature_hdf5_path=feature_hdf5_path,
                                   train_csv=train_csv,
                                   validate_csv=validate_csv,
                                   holdout_fold=holdout_fold,
                                   batch_size=batch_size)

    # Evaluator
    evaluator = Evaluator(model=model,
                          data_generator=data_generator,
                          cuda=cuda)

    # Statistics
    validate_statistics_container = StatisticsContainer(
        validate_statistics_path)

    train_bgn_time = time.time()
    iteration = 0

    # Train on mini batches
    for batch_data_dict in data_generator.generate_train():

        # Evaluate
        if iteration % 100 == 0 and iteration >= 1000:
            logging.info('------------------------------------')
            logging.info('Iteration: {}'.format(iteration))

            train_fin_time = time.time()

            # Return value is not used; the call is kept in case
            # Evaluator.evaluate logs or has other side effects
            # (TODO confirm and drop if it does not).
            evaluator.evaluate(data_type='train',
                               iteration=iteration,
                               max_iteration=None,
                               verbose=False)

            if holdout_fold != 'none':
                validate_statistics = evaluator.evaluate(data_type='validate',
                                                         iteration=iteration,
                                                         max_iteration=None,
                                                         verbose=False)
                validate_statistics_container.append_and_dump(
                    iteration, validate_statistics)

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info('Train time: {:.3f} s, validate time: {:.3f} s'
                         ''.format(train_time, validate_time))

            train_bgn_time = time.time()

        # Save model, then reduce learning rate
        if iteration % 100 == 0 and iteration > 0:
            checkpoint = {
                'iteration': iteration,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))
            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

            if reduce_lr:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.9

        # Move data to GPU
        for key in batch_data_dict.keys():
            if key in ['feature', 'target']:
                batch_data_dict[key] = move_data_to_gpu(
                    batch_data_dict[key], cuda)

        # Train: one optimizer step per audio slice. The loop variable is
        # named audio_idx so it no longer shadows the fold parameter ``i``.
        model.train()
        regime = iteration % 3
        for audio_idx in range(audio_num):
            features = batch_data_dict['feature'][:, audio_idx, :, :, :]
            if regime == 0:
                # Mixup
                data, target_a, target_b, lam = mixup_data(
                    x=features, y=batch_data_dict['target'], alpha=0.2)
                loss = mixup_criterion(loss_func, model(data), target_a,
                                       target_b, lam)
            elif regime == 1:
                # Original features
                loss = loss_func(model(features), batch_data_dict['target'])
            else:
                # SpecAugment
                data = spec_augment_pytorch.spec_augment(
                    features,
                    using_frequency_masking=True,
                    using_time_masking=True)
                loss = loss_func(model(data), batch_data_dict['target'])

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Stop learning
        if iteration == 4500:
            break

        iteration += 1
Пример #7
0
 def augment_spec(self, melspectrogram):
     """Return ``melspectrogram`` passed through SpecAugment.

     Calls ``spec_augment_pytorch.spec_augment`` with only the spectrogram,
     so every other parameter keeps that function's default value.
     """
     return spec_augment_pytorch.spec_augment(melspectrogram)