示例#1
0
def main():
    """Load a trained encoder/attention-decoder pair and evaluate the held-out split."""
    # Build vocabularies and question/answer pairs from the debug corpus.
    input_lang, output_lang, pairs = prepare_data('ques', 'ans',
                                                  '../debug.json',
                                                  reverse=False)

    encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
    attn_decoder = AttnDecoderRNN(hidden_size, output_lang.n_words,
                                  dropout_p=0.1, max_length=1000).to(device)

    # 90/10 train/test split; only the test slice is evaluated below.
    train_ratio = 0.9
    split_at = int(len(pairs) * train_ratio)
    pairs_train, pairs_test = pairs[:split_at], pairs[split_at:]

    # Restore the saved weights and switch both nets to inference mode.
    encoder.load_state_dict(torch.load('model/encoder-0.model'))
    encoder.eval()
    attn_decoder.load_state_dict(torch.load('model/decoder-0.model'))
    attn_decoder.eval()

    evaluate_all(encoder, attn_decoder, pairs_test, max_length=1000,
                 input_lang=input_lang, output_lang=output_lang,
                 n=len(pairs_test))
    # show_plot(loss_history)
    print('done test')
示例#2
0
def main():
    """Load a trained Transformer and run evaluation over the data split.

    NOTE(review): evaluate_all is called on pairs_train even though a
    pairs_test split is computed and 'done test' is printed afterwards —
    confirm whether the test slice was intended here.
    """
    input_lang, output_lang, pairs = prepare_data('ques',
                                                  'ans',
                                                  '../test.json',
                                                  reverse=False)
    model = Transformer(
        src_vocab_size=input_lang.n_words,
        src_max_len=MAX_LENGTH,
        tgt_vocab_size=output_lang.n_words,
        tgt_max_len=MAX_LENGTH,
    ).to(device)

    # 90/10 split; only the first (train) slice is used below.
    rate = 0.9
    pairs_train, pairs_test = pairs[0:int(len(pairs) *
                                          rate)], pairs[int(len(pairs) *
                                                            rate):]
    # Restore saved weights and switch to inference mode.
    model.load_state_dict(torch.load('model/transformer-0.model'))
    model.eval()
    evaluate_all(model,
                 pairs_train,
                 max_length=100,
                 input_lang=input_lang,
                 output_lang=output_lang,
                 n=len(pairs_train))
    # show_plot(loss_history)
    print('done test')
示例#3
0
def make_data_cqt_spectrogram():
    """Generate CQT spectrograms for the ASVSpoof2017 test partition.

    Note that in audio.py the spectrogram for CQT uses default parameters,
    so the fs/fft/window values passed below currently have no effect there.
    The train/dev preparation calls are kept but commented out.
    """
    fs = 16000
    fft_size = 512
    win_size = 512
    hop_size = 160

    duration = 1  # seconds per spectrogram

    # 1 second corresponds to a (32, 84) spectrogram under current defaults.
    inputType = 'cqt_spec'
    augment = True
    data_window = 32
    window_shift = 30  # keep 30 frames as shift (32 would discard many frames)
    save = True

    # Augmented runs go to a dedicated 1-second-shift directory.
    spectrogramPath = (
        '/homes/bc305/myphd/stage2/deeplearning.experiment1/spectrograms_augmented/1sec_shift/'
        if augment else
        '/homes/bc305/myphd/stage2/deeplearning.experiment1/spectrograms/')

    basePath = '/import/c4dm-datasets/SpeakerRecognitionDatasets/ASVSpoof2017/'
    outPath = (spectrogramPath + inputType + '/' + str(fft_size) + 'FFT/' +
               str(duration) + 'sec/')

    # Prepare training data (disabled)
    #prepare_data(basePath,'train',outPath,inputType,duration,fs,fft_size,win_size,hop_size,data_window,window_shift,
    #             augment,save)

    # Prepare validation data (disabled)
    #prepare_data(basePath,'dev',outPath,inputType,duration,fs,fft_size,win_size,hop_size,data_window,window_shift,
    #             augment,save)

    # Prepare test data
    print('Preparing the test data')
    prepare_data(basePath, 'test', outPath, inputType, duration, fs, fft_size,
                 win_size, hop_size, data_window, window_shift, augment, save)
示例#4
0
def make_data_mag_spectrogram():
    """Generate magnitude spectrograms for the ASVSpoof2017 test partition.

    The train/dev preparation calls are kept but commented out; only the
    test partition is processed when this function runs.
    """
    fs = 16000
    fft_size = 512  # alternatives tried: 256, 512
    win_size = 512  # alternatives tried: 256, 512
    hop_size = 160

    duration = 1  # seconds per spectrogram
    inputType = 'mag_spec'

    augment = True
    data_window = 100  # for FFT based and for cqt =
    window_shift = 100  # each frame is 32ms; a shift of 10 would be 320ms
    save = True
    #minimum_length=1  # in seconds

    # Augmented runs go to a dedicated 1-second-shift directory.
    spectrogramPath = (
        '/homes/bc305/myphd/stage2/deeplearning.experiment1/spectrograms_augmented/1sec_shift/'
        if augment else
        '/homes/bc305/myphd/stage2/deeplearning.experiment1/spectrograms/')

    basePath = '/import/c4dm-datasets/SpeakerRecognitionDatasets/ASVSpoof2017/'
    outPath = (spectrogramPath + inputType + '/' + str(fft_size) + 'FFT/' +
               str(duration) + 'sec/')

    # Prepare training data (disabled)
    #prepare_data(basePath,'train',outPath,inputType,duration,fs,fft_size,win_size,hop_size,data_window,window_shift,
    #             augment,save)

    # Prepare validation data (disabled)
    #prepare_data(basePath,'dev',outPath,inputType,duration,fs,fft_size,win_size,hop_size,data_window,window_shift,
    #             augment,save)

    # Prepare test data
    print('Preparing the test data')
    prepare_data(basePath, 'test', outPath, inputType, duration, fs, fft_size,
                 win_size, hop_size, data_window, window_shift, augment, save)
示例#5
0
def main(argv):
    """Entry point: dispatch to training or evaluation based on FLAGS.phase.

    The 'eval' branch and the fall-through testing branch in the original
    ran identical code, so they are handled by a single else clause here.
    """
    config = Config()
    config.phase = FLAGS.phase
    config.beam_size = FLAGS.beam_size

    with tf.Session() as sess:
        if FLAGS.phase == 'train':
            # Training phase.
            train_data, vocabulary1 = prepare_data(config)
            model = Model(config, vocabulary1)
            saver = tf.train.Saver(max_to_keep=1000)
            sess.run(tf.global_variables_initializer())
            if FLAGS.load:
                model.load(sess, saver, FLAGS.model_file)
            tf.get_default_graph().finalize()
            if config.is_sc:
                # Self-critical training also needs the test-phase data; the
                # phase flag is flipped only around that prepare_data call.
                config.phase = 'test'
                eval_data, _ = prepare_data(config)
                config.phase = 'train'
                model.train_sc(sess, saver, train_data, vocabulary1, eval_data)
            else:
                model.train(sess, saver, train_data, vocabulary1)
        else:
            # Evaluation / testing phase (identical pipeline for both).
            eval_data, vocabulary = prepare_data(config)
            model = Model(config, vocabulary)
            saver = tf.train.Saver(max_to_keep=1000)
            model.load(sess, saver, FLAGS.model_file)
            tf.get_default_graph().finalize()
            model.eval(sess, eval_data, vocabulary)
示例#6
0
def main(_):
    """Build an image-captioning model and run it on a single sample image."""
    # Retrieve run parameters.
    config = get_parameters()

    # Prepare the data; the word table is needed to build the model.
    word_table, data = dataset.prepare_data(config)

    # Load the one test image and wrap it in a batch of size 1.
    image = pre.load_image('data/laska.png', (224, 224))
    test_data = np.array([image])

    # Build model.
    model = ImageCaptioner(config, word_table)

    # model.train(train_data)
    model.test(test_data)
def main(config):
    """Prepare data and evaluate previously trained ensemble models.

    The ensemble-training section is currently commented out, so this entry
    point only runs evaluation.
    """
    # Create save directories
    utils.create_directories(config)

    # Prepare and load the data
    data = dataset.prepare_data(config.dataset_dir, config)

    # Train the ensemble models (disabled)
    # if config.training_type == 'bagging':
    # 	ensemble_trainer.bagging_ensemble_training(data, config)
    # elif config.training_type == 'boosting':
    # 	ensemble_trainer.boosted_ensemble_training(data, config)

    # Evaluate the model on the separate test dataset
    test_data = dataset.prepare_test_data(config.test_dataset_dir, config)
    evaluator.evaluate(data, test_data, config)

    print(config.model_dir, config.boosting_type, config.voting_type)
示例#8
0
def main():
    """Train the encoder/attention-decoder pair, evaluating and checkpointing each epoch."""
    input_lang, output_lang, pairs = prepare_data('ques', 'ans', '../data.json',
                                                  reverse=False)
    encoder = Encoder(input_lang.n_words, MAX_LENGTH).to(device)
    attn_decoder = AttnDecoderRNN(hidden_size, output_lang.n_words,
                                  dropout_p=0.1,
                                  max_length=MAX_LENGTH).to(device)

    # 90/10 train/test split.
    split_ratio = 0.9
    num_epochs = 10
    split_at = int(len(pairs) * split_ratio)
    pairs_train, pairs_test = pairs[:split_at], pairs[split_at:]

    for epoch_idx in range(num_epochs):
        # One pass over the training pairs.
        encoder.train()
        attn_decoder.train()
        train(encoder, attn_decoder, len(pairs_train), pairs=pairs_train,
              input_lang=input_lang, output_lang=output_lang, print_every=10)

        # Evaluate on the held-out pairs.
        encoder.eval()
        attn_decoder.eval()
        evaluate_all(encoder, attn_decoder, pairs_test, max_length=MAX_LENGTH,
                     input_lang=input_lang, output_lang=output_lang,
                     n=len(pairs_test))

        # Checkpoint both networks for this epoch.
        torch.save(encoder.state_dict(),
                   'model/encoder-' + str(epoch_idx) + '.model')
        torch.save(attn_decoder.state_dict(),
                   'model/decoder-' + str(epoch_idx) + '.model')
    #show_plot(loss_history)
    print('done training')
示例#9
0
def main(config):
    """Prepare data, train the ensemble models, and evaluate them.

    When 'silences' is among the model types, a newer data-preparation path
    is used and the separate test-set evaluation at the end is skipped.
    """
    # Create save directories
    utils.create_directories(config)

    # Prepare and load the data
    if 'silences' in config.model_types:
        data = dataset.prepare_data_new(config.dataset_dir, config)
    else:
        data = dataset.prepare_data(config.dataset_dir, config)
    # print(data)
    # return
    # Train the ensemble models; training_type selects the ensembling scheme.
    if config.training_type == 'bagging':
        ensemble_trainer.bagging_ensemble_training(data, config)
    elif config.training_type == 'boosting':
        ensemble_trainer.boosted_ensemble_training(data, config)

    # Evaluate the model (only for non-'silences' configurations)
    if 'silences' not in config.model_types:
        test_data = dataset.prepare_test_data(config.test_dataset_dir, config)
        evaluator.evaluate(data, test_data, config)
                                  scale_each=True)
            Imgn = utils.make_grid(imgn_train.data,
                                   nrow=8,
                                   normalize=True,
                                   scale_each=True)
            Irecon = utils.make_grid(out_train.data,
                                     nrow=8,
                                     normalize=True,
                                     scale_each=True)
            writer.add_image('clean image', Img, epoch)
            writer.add_image('noisy image', Imgn, epoch)
            writer.add_image('reconstructed image', Irecon, epoch)
        except:
            print('[{}] Get error when log the images ...'.format(epoch))
        # save model
        if not max(psnr_list) > psnr_val:
            torch.save(model.state_dict(),
                       os.path.join(opt.outf, 'net_best.pth'))
        torch.save(model.state_dict(), os.path.join(opt.outf, 'net_final.pth'))


if __name__ == "__main__":
    # print('begin to run ...')
    # Optionally build the patch dataset before training starts.
    if opt.preprocess:
        prepare_data(data_path='../data',
                     patch_size=40,
                     stride=10,
                     aug_times=1,
                     debug=opt.debug)
    main()
示例#11
0
        model_dir = os.path.join('saved_models', model_name)
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        torch.save(model.state_dict(),
                   os.path.join(model_dir, 'net_%d.pth' % (epoch)))


if __name__ == "__main__":
    if opt.preprocess:
        if opt.color == 0:
            grayscale = True
            stride = 10
            if opt.net_mode == 'F':
                prepare_data(data_path='training_data',
                             patch_size=50,
                             stride=stride,
                             aug_times=1,
                             grayscale=grayscale)
            else:
                prepare_data(data_path='training_data',
                             patch_size=50,
                             stride=stride,
                             aug_times=2,
                             grayscale=grayscale,
                             scales_bool=True)
        else:
            stride = 25
            grayscale = False
            prepare_data(data_path='training_data',
                         patch_size=50,
                         stride=stride,
示例#12
0
                              normalize=True,
                              scale_each=True)
        Imgn = utils.make_grid(imgn_train.data,
                               nrow=8,
                               normalize=True,
                               scale_each=True)
        Irecon = utils.make_grid(out_train.data,
                                 nrow=8,
                                 normalize=True,
                                 scale_each=True)
        # writer.add_image('clean image', Img, epoch)
        # writer.add_image('noisy image', Imgn, epoch)
        # writer.add_image('reconstructed image', Irecon, epoch)
        # save model
        torch.save(model.state_dict(), os.path.join(opt.outf, 'net.pth'))


if __name__ == "__main__":
    # Optionally rebuild the patch dataset from the NWPU images; opt.mode
    # selects between two patch-size/augmentation configurations.
    if opt.preprocess:
        if opt.mode == 'S':
            prepare_data(data_path='../dataset/NWPU',
                         patch_size=40,
                         stride=10,
                         aug_times=1)
        if opt.mode == 'B':
            prepare_data(data_path='../dataset/NWPU',
                         patch_size=50,
                         stride=10,
                         aug_times=2)
    main()
示例#13
0
#        # validate
#        psnr_val = 0
#        for k in range(len(dataset_val)):
#            img_val = torch.unsqueeze(dataset_val[k], 0)
#            noise = torch.FloatTensor(img_val.size()).normal_(mean=0, std=opt.val_noiseL/255.)
#            imgn_val = img_val + noise
#            img_val, imgn_val = Variable(img_val.cuda(), volatile=True), Variable(imgn_val.cuda(), volatile=True)
#            out_val = torch.clamp(imgn_val-model(imgn_val), 0., 1.)
#            psnr_val += batch_PSNR(out_val, img_val, 1.)
#        psnr_val /= len(dataset_val)
#        print("\n[epoch %d] PSNR_val: %.4f" % (epoch+1, psnr_val))
#        writer.add_scalar('PSNR on validation data', psnr_val, epoch)
#        # log the images
#        out_train = torch.clamp(imgn_train-model(imgn_train), 0., 1.)
#        Img = utils.make_grid(img_train.data, nrow=8, normalize=True, scale_each=True)
#        Imgn = utils.make_grid(imgn_train.data, nrow=8, normalize=True, scale_each=True)
#        Irecon = utils.make_grid(out_train.data, nrow=8, normalize=True, scale_each=True)
#        writer.add_image('clean image', Img, epoch)
#        writer.add_image('noisy image', Imgn, epoch)
#        writer.add_image('reconstructed image', Irecon, epoch)
        # save model
        torch.save(model.module.state_dict(), os.path.join(opt.outf, 'netB50color_'+str(epoch)+'.pth'))

if __name__ == "__main__":
    # Optionally rebuild the patch dataset; mode 'B' uses larger patches and stride.
    if opt.preprocess:
        if opt.mode == 'S':
            prepare_data(data_path='data', patch_size=40, stride=10, aug_times=1)
        if opt.mode == 'B':
            prepare_data(data_path='data', patch_size=60, stride=20, aug_times=1)
    main()
示例#14
0
import numpy as np
import sys
from neural import NeuralNet
from dataset import get_all_categories_shuffled, numpy_array_from_file, prepare_data

# Class index mapping shared by the dataset loader and the prediction printout.
labels_dictionary = {'Circle': 0, 'L': 1, 'RightArrow': 2}
labels_map = ['Circle', 'L', 'RightArrow']

#Replace with your root folder that holds the dataset of images
# Loads the labelled images (presumably as 56x56 arrays — matches the
# NeuralNet(56, 56) dimensions below; confirm in dataset.prepare_data).
all_bytes, all_labels = prepare_data("Datasets/DatasetSample0/Images",
                                     labels_dictionary, 56, 56)

# Train a fresh network on the whole dataset and persist the weights.
nn = NeuralNet(56, 56)
nn.build_layers()
nn.fit_data(training_data=all_bytes,
            training_labels=all_labels,
            epochs=10,
            accuracy=0.999)
nn.save_model('easter_egg_ahlabikyafraise_')
#nn.load_model('easter_egg_ahlabikyafraise_')

# Predict the class of a single image passed on the command line.
image_path = sys.argv[1]
test_image = numpy_array_from_file(image_path)
test_image = test_image / 255  # scale pixel values to [0, 1]

arg_max, prediction_level = nn.predict_element(test_image)
print('arg_max : {} which is {} with prediction : {}'.format(
    arg_max, labels_map[arg_max], prediction_level))
    parser.add_argument("--trainset_dir", type=str, default=None, \
         help='path of trainset')
    parser.add_argument("--valset_dir", type=str, default=None, \
          help='path of validation set')
    args = parser.parse_args()

    if args.gray:
        if args.trainset_dir is None:
            args.trainset_dir = 'data/gray/train'
        if args.valset_dir is None:
            args.valset_dir = 'data/gray/Set12'
    else:
        if args.trainset_dir is None:
            args.trainset_dir = 'data/rgb/CImageNet_expl'
        if args.valset_dir is None:
            args.valset_dir = 'data/rgb/Kodak24'

    print("\n### Building databases ###")
    print("> Parameters:")
    for p, v in zip(args.__dict__.keys(), args.__dict__.values()):
        print('\t{}: {}'.format(p, v))
    print('\n')

    prepare_data(args.trainset_dir,\
        args.valset_dir,\
        args.patch_size,\
        args.stride,\
        args.max_number_patches,\
        aug_times=args.aug_times,\
        gray_mode=args.gray)
示例#16
0
文件: main.py 项目: zfxu/Automatting
    print("Finished!")


if __name__ == '__main__':
    # Load the data
    print("Loading the data ...")
    # Read the class names and label colour values from the dataset's CSV.
    class_names_list, label_values = helpers.get_label_info(
        os.path.join(cfg.data_dir, 'class_dict.csv'))
    # Build a comma-separated display string of all class names.
    class_names_string = ""
    for class_name in class_names_list:
        if not class_name == class_names_list[-1]:
            class_names_string = class_names_string + class_name + ", "
        else:
            class_names_string = class_names_string + class_name
    num_classes = len(label_values)
    train_input_names, train_output_names, val_input_names, val_output_names, test_input_names, test_output_names = dataset.prepare_data(
    )

    # Let GPU memory grow on demand instead of grabbing it all up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # Compute your softmax cross entropy loss
    print("Preparing the model ...")
    # Placeholders: RGB input and one-hot per-pixel labels of dynamic size.
    net_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
    net_output = tf.placeholder(tf.float32,
                                shape=[None, None, None, num_classes])

    network, init_fn = buildNetwork(cfg.model, net_input, num_classes)

    saver = tf.train.Saver(max_to_keep=cfg.num_keep)
示例#17
0
            result = torch.clamp(pred_res, 0., 1.)
            psnr_train = batch_PSNR(result, gt_img, 1.)

            ave_loss = (ave_loss*i + loss.item()) / (i+1)
            ave_psnr = (ave_psnr*i + psnr_train) / (i+1)
            ave_ssim = (ave_ssim*i + 1-loss2.item()*2) / (i+1)

            time2 = time.time()

            if i % 100 == 0:

                print("[epoch %d][%d/%d] time: %.3f t_time: %.3f loss: %.4f PSNR_train: %.4f SSIM_train: %.4f" %
                    (epoch+1, i, len(loader_train), (time2 - time1), (time2 - start_time), ave_loss, ave_psnr, ave_ssim))

            if step % 1000 == 0:
                torch.save(model.state_dict(), os.path.join(model_dir, 'latest_net.pth'))
            step += 1
        print('Time for the epoch is %f' % (time.time() - start_time))
        ## the end of each epoch

        # save model
        save_name = '%d_net.pth' % (epoch+1)
        torch.save(model.state_dict(), os.path.join(model_dir, save_name))


if __name__ == "__main__":
    # BUG FIX: the original `else:` was indented with a tab and misaligned,
    # which is a SyntaxError in Python. Aligned with `if opt.preprocess` so
    # preprocessing and training are mutually exclusive, as the explicit
    # else suggests. NOTE(review): sibling scripts in this codebase call
    # main() unconditionally after prepare_data() — confirm the intent.
    if opt.preprocess:
        prepare_data(root=opt.root, data_path='data', patch_size=256, stride=200, aug_times=1)
    else:
        main()
示例#18
0
                        outest = torch.clamp(est_model(imgn_val), 0., 1.)
                        out_val_b = torch.clamp( imgn_val - model(imgn_val, outest), 0., 1.)
                        out_val_nb = torch.clamp( imgn_val - model(imgn_val, NM_tensor_val), 0., 1.)
                        psnr_val_nb += batch_PSNR(out_val_nb, img_val, 1.)
                    elif opt.mode == "B":
                        out_val_b = torch.clamp(imgn_val-model(imgn_val), 0., 1.)
                        psnr_val_nb = 0

                    #crit += evl_criterion(out_val_b, img_val).item()
                    psnr_val_b += batch_PSNR(out_val_b, img_val, 1.)

                print("\n[val at epoch %d] PSNR_val_b: %.4f, PSNR_val_nb: %.4f" % (epoch+1, psnr_val_b, psnr_val_nb))

                writer.add_scalar('PSNR on validation data (blind)', psnr_val_b, epoch*len(loader_train) + i)
                writer.add_scalar('PSNR on validation data (non_blind)', psnr_val_nb, epoch*len(loader_train) + i)
            '''
            step += 1
            
        ## the end of each epoch
        # save model
        if not os.path.exists(opt.outf):
            os.makedirs(opt.outf)
        torch.save(model.state_dict(), os.path.join(opt.outf, 'net.pth'))
        if opt.mode == "MC":
            torch.save(est_model.state_dict(), os.path.join(opt.outf, 'est_net.pth'))

if __name__ == "__main__":
    # Optionally build the patch dataset first (opt.preprocess is compared
    # as an int flag here, unlike the truthiness checks in sibling scripts).
    if opt.preprocess==1:
        prepare_data(data_path='data', patch_size=50, stride=10, aug_times=2, color=opt.color)
    main()
示例#19
0
from dataset import prepare_data
from params import params
from train import train
from model import MarketPredictionModel

# Build the train/validation datasets and fit a fresh model on them.
train_dataset, validation_dataset = prepare_data()

model = MarketPredictionModel()

train(model, (train_dataset, validation_dataset))
示例#20
0
                    errG_D += criterionBCE(output, label) / 4.

            out_train = modelG(imgn_train)
            loss = criterionMSE(out_train, noise) + 0.01 * errG_D
            loss.backward()
            optimizerG.step()

            # results
            modelG.eval()
            denoise_image = torch.clamp(imgn_train - modelG(imgn_train), 0.,
                                        1.)
            psnr_train = batch_PSNR(denoise_image, img_train, 1.)

            print(
                "[epoch %d][%d/%d] Loss_G: %.4f PSNR_train: %.4f" %
                (epoch + 1, i + 1, len(loader_train), loss.item(), psnr_train))
            step += 1

        # log the images
        torch.save({
            'epoch': epoch + 1,
            'state_dict': modelG.state_dict()
        }, 'model/modelG.pth')


if __name__ == "__main__":
    # Optionally build the train/test datasets before training starts.
    if opt.preprocess:
        prepare_data(opt.train_root, opt.test_root)

    main()
示例#21
0
文件: train.py 项目: YeobKim/FBDN
        ax2 = fig.add_subplot(rows, cols, 2)
        ax2.imshow(np.transpose(Imgn.cpu(), (1,2,0)), cmap="gray")
        ax2.set_title('noisy image')

        ax3 = fig.add_subplot(rows, cols, 4)
        ax3.imshow(np.transpose(Irecon.cpu(), (1, 2, 0)), cmap="gray")
        ax3.set_title('denoising image')

        # plt.savefig('./fig_result/epoch_{:d}.png'.format(epoch + 1))
        plt.show()


        # save model
        # torch.save(model.state_dict(), os.path.join(opt.outf, 'UsingIQRnNoiseblock_Dualnet_25.pth'))
        # nl(noise level)25 => 30.6000 nl15 => 32.8000 nl50 => 27.3000
        if psnr_val >= 30.6000:
            torch.save(model.state_dict(), os.path.join(opt.outf, 'FBDNet_' + str(round(psnr_val, 4)) + '.pth'))

    end_time = datetime.now()
    print('Training Finished!!')
    print(end_time)

if __name__ == "__main__":
    if opt.preprocess:
        # prepare_data splits the data into patch_size patches, packs everything
        # into an .h5 file, and that file feeds the dataset used in main().
        if opt.mode == 'S':
            prepare_data(data_path='data', patch_size=opt.patchsize, stride=96, aug_times=1)
        if opt.mode == 'B':
            prepare_data(data_path='data', patch_size=50, stride=10, aug_times=2)
    main()
示例#22
0
        print("[epoch %d][%d/%d] loss: %.4f PSNR_val: %.4f" %
              (epoch+1, i+1, len(loader_train), loss.item(), psnr_val))
        writer.add_scalar('PSNR on validation data', psnr_val, epoch)
        # log the images
        out_train = torch.clamp(model(imgn_train), 0., 1.)
        Img = utils.make_grid(img_train.data, nrow=8, normalize=True, scale_each=True)
        Imgn = utils.make_grid(imgn_train.data, nrow=8, normalize=True, scale_each=True)
        Irecon = utils.make_grid(out_train.data, nrow=8, normalize=True, scale_each=True)
        writer.add_image('clean image', Img, epoch)
        writer.add_image('noisy image', Imgn, epoch)
        writer.add_image('reconstructed image', Irecon, epoch)
        # save model
        torch.save(model.state_dict(), os.path.join(opt.outf, 'net_' + str(round(psnr_val, 4)) + '.pth'))
        '''


if __name__ == "__main__":
    # Optionally rebuild the patch dataset; opt.mode selects the patch settings.
    if opt.preprocess:
        if opt.mode == 'S':
            # prepare_data(data_path='data', patch_size=40, stride=10, aug_times=1)
            prepare_data(data_path='data',
                         patch_size=48,
                         stride=48,
                         aug_times=1)
        if opt.mode == 'B':
            prepare_data(data_path='data',
                         patch_size=50,
                         stride=10,
                         aug_times=2)
    main()
示例#23
0
                        help='batch size for training')
    parser.add_argument('--embedding_type',
                        type=str,
                        required=True,
                        help='loss function to train with')
    parser.add_argument('--embedding_save_path',
                        type=str,
                        default=None,
                        help='path to store embeddings')
    args = parser.parse_args()

    print(args.data_root, args.preloaded, args.n_epochs)
    if not args.preloaded:
        if args.data_root is None:
            raise Exception('data_root passed is None')
        song_data = dataset.prepare_data(args.data_root)
        print('{} songs loaded...'.format(len(song_data)))
        train_split, val_split = dataset.train_test_split(song_data,
                                                          split_ratio=0.90)
    else:
        song_data = None
        train_split = val_split = None

    train_dataset = dataset.MusicDataset(train_split,
                                         mode='train',
                                         preloaded=args.preloaded == 1)
    val_dataset = dataset.MusicDataset(val_split,
                                       mode='val',
                                       preloaded=args.preloaded == 1)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")