Example #1
import os

import yaml

import config
# setup_log, the two MUSDB18 dataset classes, Spectrogramer, Trainer and the
# get_* helpers are project-local and must also be imported here; their module
# paths are not shown in this excerpt.


def main():
    # Set up logging
    log_filename = config.LOG_FILENAME

    log = setup_log(log_filename)

    # Set up the datasets
    instruments = config.INSTRUMENTS
    musdbwav_path = config.MUSDB18_WAV_PATH
    # sample_time_frames = config.SAMPLE_TIME_FRAMES
    audio_samples_per_chunk = config.AUDIO_SAMPLES_PER_CHUNK
    batch_size = config.BATCH_SIZE
    prefetch_factor = config.PREFETCH_FACTOR

    train_dataset = TrainMUSDB18Dataset(musdbwav_path,
                                        instruments,
                                        sample_length=audio_samples_per_chunk)
    validation_dataset = TestMUSDB18Dataset(
        musdbwav_path,
        instruments,
        sample_length=audio_samples_per_chunk,
        subset_split='valid')

    # Set up the model
    model_config_path = config.MODEL_CONFIG_YAML_PATH
    model_config_name = config.MODEL_CONFIGURATION

    with open(model_config_path, 'r') as file:
        model_configurations = yaml.load(file, Loader=yaml.FullLoader)
    train_model_config = model_configurations[model_config_name]
    # The 'class' key names the model class; the remaining keys are passed to
    # its constructor. eval() assumes that class is already in scope (a safer
    # name-lookup variant is sketched after this example).
    train_model_class = train_model_config.pop('class')
    model = eval(train_model_class)(**train_model_config)

    # Set up the trainer
    checkpoint_folder_path = config.CHECKPOINT_FOLDER_PATH
    epochs = config.EPOCHS
    checkpoint_frequency = config.CHECKPOINT_FREQUENCY
    logging_frequency = config.LOGGING_FREQUENCY
    optimizer_class = get_optimizer_class(config.OPTIMIZER)
    optimizer_params = config.OPTIMIZER_PARAMS
    optimizer = optimizer_class(model.parameters(), **optimizer_params)
    lr_scheduler_class = get_lr_scheduler_class(config.LR_SCHEDULER)
    lr_scheduler_params = config.LR_SCHEDULER_PARAMS
    lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_params)
    loss_function = get_loss_function(config.LOSS_FUNCTION)
    gpu = config.GPU
    gpu_device = config.GPU_DEVICE
    device = get_device(gpu, gpu_device)

    # os.makedirs with exist_ok avoids a race between the isdir check and
    # mkdir, and also creates missing parent directories.
    os.makedirs(checkpoint_folder_path, exist_ok=True)

    # Set up the spectrogram
    spectrogram_type = config.SPECTROGRAM_TYPE
    n_fft = config.N_FFT
    hop_length = config.HOP_LENGTH
    window = config.WINDOW
    window_length = config.WINDOW_LENGTH

    spectrogramer = Spectrogramer(spectrogram_type, n_fft, hop_length, window,
                                  window_length, device)

    # Initialize the trainer
    trainer = Trainer(model, spectrogramer, optimizer, loss_function,
                      lr_scheduler, train_dataset, validation_dataset, log,
                      checkpoint_folder_path, epochs, logging_frequency,
                      checkpoint_frequency, batch_size, prefetch_factor,
                      instruments, train_model_class, device)
    # Start training/evaluation
    trainer.train()
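
The excerpt ends here; presumably the script closes with the usual entry-point guard:

if __name__ == '__main__':
    main()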
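
Several names used above are project-local and not shown. TrainMUSDB18Dataset and TestMUSDB18Dataset, for instance, serve fixed-length audio chunks; as a rough sketch of that idea, a minimal torch Dataset might look like the following. Every name, the I/O convention, and the random-chunk policy here are assumptions, not the project's implementation.

import random

import soundfile as sf
import torch
from torch.utils.data import Dataset


class ChunkedAudioDataset(Dataset):
    # Hypothetical sketch: serve fixed-length chunks from a list of wav files.

    def __init__(self, wav_paths, sample_length):
        self.wav_paths = wav_paths
        self.sample_length = sample_length

    def __len__(self):
        return len(self.wav_paths)

    def __getitem__(self, index):
        # (samples, channels) float32 array
        audio, _ = sf.read(self.wav_paths[index], dtype='float32')
        # Draw a random fixed-length excerpt from the file.
        start = random.randint(0, max(0, len(audio) - self.sample_length))
        chunk = audio[start:start + self.sample_length]
        # Transpose to (channels, sample_length) and make it contiguous.
        return torch.from_numpy(chunk.T.copy())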
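
The YAML lookup in the model setup expects each entry to carry a 'class' key naming the model class, with the remaining keys forwarded as constructor kwargs. A self-contained sketch of the same pattern, with an invented configuration entry and model class, using a name lookup instead of eval():

import yaml


class DummyModel:
    def __init__(self, hidden_size, num_layers):
        self.hidden_size = hidden_size
        self.num_layers = num_layers


EXAMPLE_YAML = """
baseline:
  class: DummyModel
  hidden_size: 128
  num_layers: 4
"""

configs = yaml.safe_load(EXAMPLE_YAML)
cfg = dict(configs['baseline'])
model_class = globals()[cfg.pop('class')]  # avoids eval() on config strings
model = model_class(**cfg)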
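
get_optimizer_class, get_lr_scheduler_class, get_loss_function and get_device are likewise not shown. A plausible minimal sketch, assuming they resolve names against the corresponding torch namespaces (the real helpers may well differ):

import torch


def get_optimizer_class(name):
    return getattr(torch.optim, name)               # e.g. 'Adam'


def get_lr_scheduler_class(name):
    return getattr(torch.optim.lr_scheduler, name)  # e.g. 'StepLR'


def get_loss_function(name):
    return getattr(torch.nn, name)()                # e.g. 'L1Loss'


def get_device(gpu, gpu_device):
    # Fall back to CPU when no GPU is requested or available.
    if gpu and torch.cuda.is_available():
        return torch.device('cuda:{}'.format(gpu_device))
    return torch.device('cpu')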
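
Finally, Spectrogramer presumably wraps an STFT with the configured window. A minimal magnitude-spectrogram sketch in plain torch, assuming a Hann window (the function name and exact conventions are illustrative):

import torch


def magnitude_spectrogram(waveform, n_fft, hop_length, window_length, device):
    # waveform: (channels, samples) float tensor
    window = torch.hann_window(window_length, device=device)
    stft = torch.stft(waveform.to(device), n_fft=n_fft, hop_length=hop_length,
                      win_length=window_length, window=window,
                      return_complex=True)
    return stft.abs()  # (channels, n_fft // 2 + 1, time_frames)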
Example #2
    # Excerpt starts mid-script: `parser` (an argparse.ArgumentParser) and
    # earlier arguments such as --nb_cl_fg, --nb_cl, --random_seed and --gpu
    # are defined above this point.
    parser.add_argument('--load_order', type=str, default='-')
    # Meta-learning (MAML-style) hyperparameters
    parser.add_argument('--maml_lr', type=float, default=0.1)
    parser.add_argument('--maml_epoch', type=int, default=50)
    # Mnemonics hyperparameters
    parser.add_argument('--mnemonics_images_per_class_per_step', type=int, default=1)
    parser.add_argument('--mnemonics_steps', type=int, default=20)
    parser.add_argument('--mnemonics_epochs', type=int, default=5)
    parser.add_argument('--mnemonics_lr', type=float, default=0.01)
    parser.add_argument('--mnemonics_decay_factor', type=float, default=0.5)
    parser.add_argument('--mnemonics_outer_lr', type=float, default=1e-6)
    parser.add_argument('--mnemonics_total_epochs', type=int, default=10)
    parser.add_argument('--mnemonics_decay_epochs', type=int, default=40)

    the_args = parser.parse_args()

    # The initial batch of classes must be a (non-zero) multiple of the
    # per-phase class count.
    assert the_args.nb_cl_fg % the_args.nb_cl == 0
    assert the_args.nb_cl_fg >= the_args.nb_cl

    print(the_args)

    # Note: this seeds NumPy only; framework RNGs (if any) are seeded elsewhere.
    np.random.seed(the_args.random_seed)

    # Must be set before the CUDA runtime is initialized to take effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = the_args.gpu
    print('Using gpu:', the_args.gpu)

    occupy_memory(the_args.gpu)
    print('Occupying GPU memory in advance')

    trainer = Trainer(the_args)
    trainer.train()
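
Assuming this excerpt belongs to a script run directly (the script name below is illustrative) and that flags such as --nb_cl_fg, --nb_cl and --gpu are defined earlier in the parser, an invocation could look like:

python main.py --nb_cl_fg 50 --nb_cl 10 --gpu 0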