Exemplo n.º 1
0
def setup_model(checkpoint_file, fatal=True):
    """
    Build an st_model from a checkpoint file and publish its settings as globals.

    Parameters:
        checkpoint_file:  checkpoint to read; passed to st.misc.load_checkpoint
        fatal:            forwarded unchanged to st.misc.load_checkpoint

    Returns:
        the model with the checkpoint weights loaded, or None when the loader
        hands back an empty state dict.

    Side effects:
        on success, sets the module globals knob_names, knob_ranges, num_knobs,
        sr, chunk_size and out_chunk_size.
    """
    global knob_names, knob_ranges, num_knobs, sr, chunk_size, out_chunk_size

    # The checkpoint is always loaded onto the CPU here
    weights, meta = st.misc.load_checkpoint(checkpoint_file,
                                            fatal=fatal,
                                            device="cpu")
    if weights == {}:
        # nothing usable came back; the globals are left untouched
        return None

    # Metadata stored alongside the weights
    scale_factor, shrink_factor = meta['scale_factor'], meta['shrink_factor']
    knob_names = meta['knob_names']
    knob_ranges = meta['knob_ranges']
    num_knobs = len(knob_names)
    sr = meta['sr']

    # Construct the network, then replace its weights with the saved ones
    model_kwargs = dict(scale_factor=scale_factor,
                        shrink_factor=shrink_factor,
                        num_knobs=num_knobs,
                        sr=sr)
    net = nn_proc.st_model(**model_kwargs)
    net.load_state_dict(weights)

    # Chunk sizes are taken from the constructed model
    chunk_size = net.in_chunk_size
    out_chunk_size = net.out_chunk_size
    return net
Exemplo n.º 2
0
    # Echo the parsed command-line arguments
    print("args =",args)

    # load from checkpoint
    # load_checkpoint returns the saved weights (state_dict) and a metadata
    # mapping (rv) that is indexed by string keys below.
    print("Looking for checkpoint at",args.checkpoint)
    state_dict, rv = st.misc.load_checkpoint(args.checkpoint, fatal=True)
    scale_factor, shrink_factor = rv['scale_factor'], rv['shrink_factor']
    knob_names, knob_ranges = rv['knob_names'], rv['knob_ranges']
    num_knobs = len(knob_names)
    sr = rv['sr']
    # NOTE: these two values are overwritten further down with the sizes
    # reported by the constructed model.
    chunk_size, out_chunk_size = rv['in_chunk_size'], rv['out_chunk_size']
    print(f"Effect name = {rv['effect_name']}")
    print(f"knob_names = {knob_names}")
    print(f"knob_ranges = {knob_ranges}")

    # Setup model
    # The model is built from the checkpoint metadata; the model type comes
    # from the command line (args.model).
    model = nn_proc.st_model(scale_factor=scale_factor, shrink_factor=shrink_factor, num_knobs=num_knobs, sr=sr, model_type=args.model)
    model.load_state_dict(state_dict)   # overwrite the weights using the checkpoint
    # Chunk sizes as reported by the model replace the checkpoint values read above
    chunk_size = model.in_chunk_size
    out_chunk_size = model.out_chunk_size
    print("out_chunk_size = ",out_chunk_size)


    # When Apex is available, wrap the model with amp at opt_level "O2".
    # NOTE(review): the Adam optimizer is not used again in the visible code;
    # presumably it exists only because amp.initialize takes one - confirm.
    if have_apex:
        optimizer = torch.optim.Adam(list(model.parameters()))
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")


    # Input Data
    # The audio file to process is given on the command line (args.audiofile);
    # the commented-out line below is an old hard-coded path.
    #infile="/home/shawley/datasets/signaltrain/music/Test/WindyPlaces.ITB.Mix10-2488-1644.wav"
    infile = args.audiofile
    print("reading input file ",infile)
Exemplo n.º 3
0
    # load from checkpoint
    # load_checkpoint returns the saved weights (state_dict) and a metadata
    # mapping (rv) that is indexed by string keys below.
    print("Looking for checkpoint at", args.checkpoint)
    state_dict, rv = st.misc.load_checkpoint(args.checkpoint, fatal=True)
    scale_factor, shrink_factor = rv['scale_factor'], rv['shrink_factor']
    knob_names, knob_ranges = rv['knob_names'], rv['knob_ranges']
    num_knobs = len(knob_names)
    sr = rv['sr']
    # NOTE: these two values are overwritten further down with the sizes
    # reported by the constructed model.
    chunk_size, out_chunk_size = rv['in_chunk_size'], rv['out_chunk_size']
    print(f"Effect name = {rv['effect_name']}")
    print(f"knob_names = {knob_names}")
    print(f"knob_ranges = {knob_ranges}")

    # Setup model
    # The model is built from the checkpoint metadata only; no model_type
    # argument is passed here.
    model = nn_proc.st_model(scale_factor=scale_factor,
                             shrink_factor=shrink_factor,
                             num_knobs=num_knobs,
                             sr=sr)
    model.load_state_dict(
        state_dict)  # overwrite the weights using the checkpoint
    # Chunk sizes as reported by the model replace the checkpoint values read above
    chunk_size = model.in_chunk_size
    out_chunk_size = model.out_chunk_size
    print("out_chunk_size = ", out_chunk_size)

    # When Apex is available, wrap the model with amp at opt_level "O2".
    # NOTE(review): the Adam optimizer is not used again in the visible code;
    # presumably it exists only because amp.initialize takes one - confirm.
    if have_apex:
        optimizer = torch.optim.Adam(list(model.parameters()))
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    # Input Data
    # The audio file to process is given on the command line (args.audiofile);
    # the commented-out line below is an old hard-coded path.
    #infile="/home/shawley/datasets/signaltrain/music/Test/WindyPlaces.ITB.Mix10-2488-1644.wav"
    infile = args.audiofile
    print("reading input file ", infile)
Exemplo n.º 4
0
def train(effect=audio.Compressor_4c(),
          epochs=100,
          n_data_points=200000,
          batch_size=20,
          device=torch.device("cuda:0"),
          plot_every=10,
          cp_every=25,
          sr=44100,
          datapath=None,
          scale_factor=1,
          shrink_factor=4,
          apex_opt="O0",
          target_type="stream",
          lr_max=1e-4,
          in_checkpointname='modelcheckpoint.tar'):
    """
    Main training routine for signaltrain.

    Builds the model (warm-started from a checkpoint when one is found), the
    1-cycle learning-rate schedule, the Adam optimizer and the training and
    validation data loaders, then hands everything to train_loop().

    Parameters:
        effect:           audio effect instance to learn (see audio.py); its
                          knob_names attribute and info() method are used here.
                          Note the default instance is created once, when this
                          function is defined, and shared between calls.
        epochs:           how many epochs to run over
        n_data_points:    data instances per epoch (or iterations per epoch);
                          the validation set uses a quarter of this
        batch_size:       batch size
        device:           pytorch device to run on, either cpu or cuda (GPU)
        plot_every:       how often to generate plots of sample outputs
                          (accepted but not used inside this function)
        cp_every:         save checkpoint every this many iterations
                          (accepted but not used inside this function)
        sr:               sample rate; replaced by the checkpoint's value when
                          a checkpoint is loaded
        datapath:         directory holding "Train/" and "Val/" audio files;
                          None means the data are synthesized instead
        scale_factor:     change overall dimensionality of i/o chunks by this
                          factor; replaced by the checkpoint's value if loaded
        shrink_factor:    output shrink factor, i.e. fraction of output actually
                          trained on; replaced by the checkpoint's value if loaded
        apex_opt:         option for apex multi-precision training. default is
                          "O0" which means none. For Turing cards
                          (e.g. RTX 2080 Ti), set this to "O2"
        target_type:      "chunk" (re-run the effect for each chunk) or "stream"
                          (apply effect to whole audio stream); only consulted
                          when datapath is given
        lr_max:           maximum learning rate
        in_checkpointname:   filename of previous checkpoint to load from
                          (if it exists)

    Returns:
        the model object after train_loop() has returned
    """

    # print info about this training run
    print(f'SignalTrain training execution began at {time.ctime()}. Options:')
    print(
        f'    epochs = {epochs}, n_data_points = {n_data_points}, batch_size = {batch_size}'
    )
    print(
        f'    scale_factor = {scale_factor}, shrink_factor = {shrink_factor}, apex_opt = {apex_opt}'
    )
    num_knobs = len(effect.knob_names)
    print(f'    num_knobs = {num_knobs}')
    effect.info()  # Print effect settings

    # Setup the Model
    # check to see if there's a checkpoint (fatal=False: a missing one is fine)
    state_dict, rv = misc.load_checkpoint(in_checkpointname, fatal=False)
    have_checkpoint = state_dict != {}
    if have_checkpoint:
        # The checkpoint's model geometry and sample rate override the
        # arguments, so the saved weights fit the model built below.  Only the
        # keys that are actually needed are read.
        scale_factor, shrink_factor = rv['scale_factor'], rv['shrink_factor']
        sr = rv['sr']
        checkpoint_num_knobs = len(rv['knob_names'])
        if checkpoint_num_knobs != num_knobs:
            # The model is sized from the effect, not from the checkpoint, so
            # flag the disagreement before the weights are loaded.
            print(f'WARNING: checkpoint has {checkpoint_num_knobs} knobs but '
                  f'effect has {num_knobs}; loading the weights may fail')

    # initialize from scratch
    model = nn_proc.st_model(scale_factor=scale_factor,
                             shrink_factor=shrink_factor,
                             num_knobs=num_knobs,
                             sr=sr)
    if have_checkpoint:
        # overwrite the weights using the input checkpoint
        model.load_state_dict(state_dict)
    # The chunk sizes always come from the constructed model
    chunk_size, out_chunk_size = model.in_chunk_size, model.out_chunk_size
    y_size = out_chunk_size

    print("Model defined.  Number of trainable parameters:",
          sum(p.numel() for p in model.parameters() if p.requires_grad))
    print("      model.in_chunk_size, model.out_chunk_size = ",
          model.in_chunk_size, model.out_chunk_size)
    model.to(device)

    # Specify learning rate schedule...although we don't bother stepping the momentum
    # Note: lr_max should be obtained by running lr_finder in learningrate.py
    lr_sched, mom_sched = learningrate.get_1cycle_schedule(
        lr_max=lr_max,
        n_data_points=n_data_points,
        epochs=epochs,
        batch_size=batch_size)

    # Initialize optimizer. Given our "random" training data, weight decay
    # doesn't seem to help but rather slows training way down
    optimizer = torch.optim.Adam(list(model.parameters()),
                                 lr=lr_sched[0],
                                 weight_decay=0)
    # TODO: should also set optimizer state from checkpoint

    # Setup/load data
    synth_data = datapath is None  # Are we synthesizing data or do we expect it to come from files
    if synth_data:  # synthesize input & target data
        dataset = datasets.SynthAudioDataSet(chunk_size,
                                             effect,
                                             sr=sr,
                                             datapoints=n_data_points,
                                             y_size=out_chunk_size,
                                             augment=True)
        dataset_val = datasets.SynthAudioDataSet(chunk_size,
                                                 effect,
                                                 sr=sr,
                                                 datapoints=n_data_points // 4,
                                                 recycle=True,
                                                 y_size=out_chunk_size,
                                                 augment=False)
    else:  # use prerecorded files for input & target data
        rerun = target_type != "stream"
        dataset = datasets.AudioFileDataSet(chunk_size,
                                            effect,
                                            sr=sr,
                                            datapoints=n_data_points,
                                            path=datapath + "/Train/",
                                            y_size=out_chunk_size,
                                            rerun=rerun,
                                            augment=True,
                                            preload=True)
        dataset_val = datasets.AudioFileDataSet(
            chunk_size,
            effect,
            sr=sr,
            datapoints=n_data_points // 4,
            path=datapath + "/Val/",
            y_size=out_chunk_size,
            rerun=rerun,
            augment=False)

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=10,
                            shuffle=True,
                            worker_init_fn=datasets.worker_init
                            )  # need worker_init for more variance
    dataloader_val = DataLoader(dataset_val,
                                batch_size=batch_size,
                                num_workers=10,
                                shuffle=False)

    # Mixed precision:  Initialize NVIDIA Apex/Amp.  Amp accepts either values or strings for
    #     the optional override arguments, for convenient interoperation with argparse.
    if have_apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level=apex_opt)

    # Copy model to (other) GPU if possible
    parallel = False  # torch.cuda.device_count() > 1
    if parallel:  # For Hawley's 2 Titan X GPUs this typically cuts execution time down by ~30% (not 50%)
        print("Replicating NN model for data-parallel execution across",
              torch.cuda.device_count(), "GPUs")
        model = nn.DataParallel(model)

    # Setup log file
    logfilename = "vl_avg_out.dat"  # Val Loss, average, output
    # Opening in append mode creates the file if missing and keeps earlier
    # contents; the with-block closes it.
    with open(logfilename, "a"):
        pass

    # Now that everything's set up, call the training loop
    out_checkpointname = "modelcheckpoint.tar"
    train_loop(model,
               effect,
               device,
               optimizer,
               epochs,
               batch_size,
               lr_sched,
               mom_sched,
               dataloader,
               dataloader_val,
               y_size,
               parallel,
               logfilename,
               out_checkpointname,
               sr=sr,
               lr_max=lr_max)

    return model