Beispiel #1
0
def setup_classifier(load_weights_from):
    """Create the RNN classifier from the library's default arguments.

    The model is built by ``lib_rnn.create_RNN_model`` using the arguments
    returned by ``lib_rnn.set_default_args()``.

    Args:
        load_weights_from: Forwarded unchanged as the second positional
            argument of ``lib_rnn.create_RNN_model`` (presumably the location
            of saved weights -- confirm against ``lib_rnn``).

    Returns:
        The model object returned by ``lib_rnn.create_RNN_model``.
    """
    model_args = lib_rnn.set_default_args()
    model = lib_rnn.create_RNN_model(model_args, load_weights_from)

    # Manual smoke test, disabled by default. When enabled it feeds one random
    # (66, 12) array to ``model.predict``, prints the result and then
    # terminates the process instead of returning.
    run_random_test = False
    if run_random_test:
        label_index = model.predict(np.random.random((66, 12)))
        print("Label index of a random feature: ", label_index)
        exit("Complete test.")

    return model
    file_labels = file_labels[::GAP]
    args.num_epochs = 5
    
# Data augmentation: applied to the training dataset only, and only when
# requested through args.do_data_augment.
aug = None
if args.do_data_augment:
    Aug = lib_augment.Augmenter  # short alias
    aug = Aug(
        [
            Aug.Shift(rate=0.2, keep_size=False),
            Aug.PadZeros(time=(0, 0.3)),
            Aug.Amplify(rate=(0.5, 1.2)),
            # Disabled augmentations, kept for reference:
            #   Aug.PlaySpeed(rate=(0.7, 1.3), keep_size=False),
            #   Aug.Noise(noise_folder="data/noises/", prob_noise=0.8, intensity=(0.1, 0.4)),
            # Most of the data already contains strong white noise, so no
            # extra noise is added.
        ],
        prob_to_aug=0.8,
    )

# Partition the files and their labels into train / eval / test subsets
tr_X, tr_Y, ev_X, ev_Y, te_X, te_Y = lib_ml.split_train_eval_test(
    X=file_paths,
    Y=file_labels,
    ratios=args.train_eval_test_ratio,
    dtype='list',
)

# Datasets: only the training set receives the augmentation transform
train_dataset = lib_datasets.AudioDataset(
    file_paths=tr_X,
    file_labels=tr_Y,
    transform=aug,
)
eval_dataset = lib_datasets.AudioDataset(
    file_paths=ev_X,
    file_labels=ev_Y,
    transform=None,
)

# Batched, shuffled loaders over both datasets
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
)
eval_loader = torch.utils.data.DataLoader(
    dataset=eval_dataset,
    batch_size=args.batch_size,
    shuffle=True,
)

# Build the model, then run training ------------------------------------
model = lib_rnn.create_RNN_model(
    args,
    load_weight_from=args.load_weight_from,
)
lib_rnn.train_model(model, args, train_loader, eval_loader)
def setup_classifier(load_weights_from):
    """Create the RNN classifier from the library's default arguments.

    The model is built by ``lib_rnn.create_RNN_model`` using the arguments
    returned by ``lib_rnn.set_default_args()``.

    Args:
        load_weights_from: Forwarded unchanged as the second positional
            argument of ``lib_rnn.create_RNN_model`` (presumably the location
            of saved weights -- confirm against ``lib_rnn``).

    Returns:
        The model object returned by ``lib_rnn.create_RNN_model``.
    """
    model_args = lib_rnn.set_default_args()
    return lib_rnn.create_RNN_model(model_args, load_weights_from)