Exemple #1
0
def test_sample_seq2seq():
    """Create a "sample" sequence decoder through the factory and run ``assert_seq2seq``.

    The encoder and the inner "another" decoder share the same embedding
    and latent sizes; the start-of-sequence token handed to the check is 90.
    """
    vocab_size = 91  # used for both the encoder input and the decoder output
    embedding_dim = 12
    latent_dim = 33
    cpu = torch.device("cpu")

    encoder = base.Encoder(
        input_size=vocab_size,
        embed_size=embedding_dim,
        hidden_size=32,
        z_size=latent_dim,
        n_layers=2,
    )

    # Configuration of the inner decoder that the factory instantiates by key.
    inner_decoder_cfg = {
        "key": "another",
        "params": {
            "embed_size": embedding_dim,
            "hidden_size": 64,
            "z_size": latent_dim,
            "n_layers": 2,
        },
    }

    sampling_decoder = seq_decoder_factory.create(
        "sample",
        output_size=vocab_size,
        z_size=latent_dim,
        device=cpu,
        temperature=3.0,
        decoder_args=inner_decoder_cfg,
    )
    assert_seq2seq(encoder, sampling_decoder, sos_token=90)
Exemple #2
0
def test_greedy_trained():
    """Wrap an encoder and a "greedy" sequence decoder in ``Seq2Seq`` and run ``assert_trained``.

    The latent size comes from the module-level ``Z_SIZE`` constant; the
    third positional argument given to ``base.Seq2Seq`` is 90.
    """
    vocab_size = 91  # used for both the encoder input and the decoder output
    embedding_dim = 12
    cpu = torch.device("cpu")

    encoder = base.Encoder(
        input_size=vocab_size,
        embed_size=embedding_dim,
        hidden_size=32,
        z_size=Z_SIZE,
        n_layers=2,
    )

    # Configuration of the inner decoder that the factory instantiates by key.
    inner_decoder_cfg = {
        "key": "another",
        "params": {
            "z_size": Z_SIZE,
            "embed_size": embedding_dim,
            "hidden_size": 64,
            "n_layers": 2,
        },
    }

    greedy_decoder = seq_decoder_factory.create(
        "greedy",
        output_size=vocab_size,
        z_size=Z_SIZE,
        device=cpu,
        decoder_args=inner_decoder_cfg,
    )
    assert_trained(base.Seq2Seq(encoder, greedy_decoder, 90))
Exemple #3
0
def test_hier_trained():
    """Build a "hier" sequence decoder (conductor + "sample" sub-decoder) and run ``assert_trained``.

    The conductor's ``c_size`` is reused as the ``z_size`` of the
    sub-decoder and of its inner "another" decoder. The encoder's latent
    size is the module-level ``Z_SIZE`` constant.
    """
    vocab_size = 91  # used for both the encoder input and the decoder output
    embedding_dim = 12
    conductor_out_size = 134
    subsequence_count = 2
    cpu = torch.device("cpu")

    encoder = base.Encoder(
        input_size=vocab_size,
        embed_size=embedding_dim,
        hidden_size=32,
        z_size=Z_SIZE,
        n_layers=2,
    )

    conductor_cfg = {
        "key": "conductor",
        "params": {
            "hidden_size": 23,
            "c_size": conductor_out_size,
            "n_layers": 3,
        },
    }
    inner_decoder_cfg = {
        "key": "another",
        "params": {
            "embed_size": embedding_dim,
            "hidden_size": 64,
            "n_layers": 2,
            "dropout_prob": 0.3,
            "z_size": conductor_out_size,
        },
    }
    sub_decoder_cfg = {
        "key": "sample",
        "params": {
            "z_size": conductor_out_size,
            "device": cpu,
            "temperature": 3.0,
            "decoder_args": inner_decoder_cfg,
        },
    }

    hier_decoder = seq_decoder_factory.create(
        "hier",
        output_size=vocab_size,
        z_size=Z_SIZE,
        n_subsequences=subsequence_count,
        device=cpu,
        conductor_args=conductor_cfg,
        seq_decoder_args=sub_decoder_cfg,
    )
    assert_trained(base.Seq2Seq(encoder, hier_decoder, 90))
Exemple #4
0
def test_hier_seq2seq():
    """Build a "hier" sequence decoder (conductor + "greedy" sub-decoder) and run ``assert_seq2seq``.

    The conductor's ``c_size`` is reused as the ``z_size`` of the
    sub-decoder and of its inner "another" decoder; the start-of-sequence
    token handed to the check is 90.
    """
    vocab_size = 91  # used for both the encoder input and the decoder output
    embedding_dim = 12
    latent_dim = 33
    conductor_out_size = 121
    subsequence_count = 2
    cpu = torch.device("cpu")

    encoder = base.Encoder(
        input_size=vocab_size,
        embed_size=embedding_dim,
        hidden_size=32,
        z_size=latent_dim,
        n_layers=2,
    )

    conductor_cfg = {
        "key": "conductor",
        "params": {
            "hidden_size": 23,
            "c_size": conductor_out_size,
            "n_layers": 3,
        },
    }
    inner_decoder_cfg = {
        "key": "another",
        "params": {
            "embed_size": embedding_dim,
            "hidden_size": 64,
            "n_layers": 2,
            "dropout_prob": 0.3,
            "z_size": conductor_out_size,
        },
    }
    sub_decoder_cfg = {
        "key": "greedy",
        "params": {
            "z_size": conductor_out_size,
            "decoder_args": inner_decoder_cfg,
        },
    }

    hier_decoder = seq_decoder_factory.create(
        "hier",
        output_size=vocab_size,
        z_size=latent_dim,
        n_subsequences=subsequence_count,
        device=cpu,
        conductor_args=conductor_cfg,
        seq_decoder_args=sub_decoder_cfg,
    )
    assert_seq2seq(encoder, hier_decoder, sos_token=90)
Exemple #5
0
def test_greedy_seq2seq_ckpt():
    """Wrap an encoder and a "greedy" sequence decoder in ``Seq2Seq`` and run ``assert_seq2seq_ckpt``.

    The inner "another" decoder is configured with a dropout probability
    of 0.2; the third positional argument given to ``base.Seq2Seq`` is 90.
    """
    vocab_size = 91  # used for both the encoder input and the decoder output
    embedding_dim = 12
    latent_dim = 33
    cpu = torch.device("cpu")

    encoder = base.Encoder(
        input_size=vocab_size,
        embed_size=embedding_dim,
        hidden_size=32,
        z_size=latent_dim,
        n_layers=2,
    )

    # Configuration of the inner decoder that the factory instantiates by key.
    inner_decoder_cfg = {
        "key": "another",
        "params": {
            "z_size": latent_dim,
            "embed_size": embedding_dim,
            "hidden_size": 64,
            "n_layers": 2,
            "dropout_prob": 0.2,
        },
    }

    greedy_decoder = seq_decoder_factory.create(
        "greedy",
        output_size=vocab_size,
        z_size=latent_dim,
        device=cpu,
        decoder_args=inner_decoder_cfg,
    )
    assert_seq2seq_ckpt(base.Seq2Seq(encoder, greedy_decoder, 90))
Exemple #6
0
def run(_run,
        num_steps,
        batch_size,
        num_workers,
        z_size,
        beta_settings,
        sampling_settings,
        free_bits,
        encoder_params,
        seq_decoder_args,
        learning_rate,
        train_dir,
        eval_dir,
        slice_bar,
        lr_scheduler_factor,
        lr_scheduler_patience,
        evaluate_interval=1000,
        advanced_interval=200,
        print_interval=200,
        ckpt_path=None):
    """Set up data, logging, model, optimizer and scheduler, then call ``train``.

    Side effects: rebinds the module-level ``logger``, creates a ``ckpt``
    directory below the summary writer's log directory, and prints the
    dataset file counts, the model and (if used) the checkpoint that was
    loaded.

    Args:
        _run: Passed to ``utils.get_run_dir`` to obtain the run directory
            (presumably the experiment-framework run object -- confirm).
        num_steps: Forwarded to ``train``.
        batch_size: Batch size of both data loaders.
        num_workers: Worker count of both data loaders.
        z_size: ``z_size`` given to ``base.Encoder``.
        beta_settings: Forwarded to ``train``.
        sampling_settings: Forwarded to ``train``.
        free_bits: Forwarded to ``train``.
        encoder_params: Extra keyword arguments for ``base.Encoder``.
        seq_decoder_args: Mapping with a ``"key"`` entry (factory key) and
            a ``"params"`` entry (keyword arguments for the factory).
        learning_rate: Learning rate of the Adam optimizer.
        train_dir: MIDI directory of the training dataset.
        eval_dir: MIDI directory of the evaluation dataset.
        slice_bar: ``slice_bars`` argument of both datasets.
        lr_scheduler_factor: ``factor`` of ``ReduceLROnPlateau``.
        lr_scheduler_patience: ``patience`` of ``ReduceLROnPlateau``.
        evaluate_interval: Forwarded to ``train``.
        advanced_interval: Forwarded to ``train``.
        print_interval: Forwarded to ``train``.
        ckpt_path: Optional checkpoint to load before training; when given,
            training starts one step after the value returned by
            ``utils.load_ckpt``.
    """
    global logger

    # --- datasets and loaders ---------------------------------------------------------------------
    # Two melody-to-index mappings: one without and one with a start-of-sequence token.
    enc_mel_to_idx = melody_dataset.MapMelodyToIndex(has_sos_token=False)
    dec_mel_to_idx = melody_dataset.MapMelodyToIndex(has_sos_token=True)

    loader_kwargs = {"batch_size": batch_size, "num_workers": num_workers}

    ds_train = melody_dataset.MelodyDataset(
        midi_dir=train_dir,
        slice_bars=slice_bar,
        transforms=enc_mel_to_idx,
        train=True,
    )
    train_loader = DataLoader(ds_train, drop_last=True, **loader_kwargs)

    ds_eval = melody_dataset.MelodyDataset(
        midi_dir=eval_dir,
        slice_bars=slice_bar,
        transforms=enc_mel_to_idx,
        train=False,
    )
    eval_loader = DataLoader(ds_eval, drop_last=False, **loader_kwargs)

    print(
        f"Train/Eval files: {len(ds_train.midi_files)} / {len(ds_eval.midi_files)}"
    )

    # --- logging ----------------------------------------------------------------------------------
    writer = SummaryWriter(log_dir=utils.get_run_dir(_run))
    ckpt_dir = Path(writer.log_dir) / "ckpt"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    logger = Logger(writer, ckpt_dir=ckpt_dir, melody_dict=dec_mel_to_idx)

    # --- model ------------------------------------------------------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    encoder = base.Encoder(input_size=enc_mel_to_idx.dict_size(),
                           z_size=z_size,
                           **encoder_params)

    # Start from the configured decoder parameters; output size and device
    # are always set here and override any values supplied in the config.
    # The factory call stands in for the manual flat/hierarchical decoder
    # construction that used to live here as commented-out code.
    decoder_kwargs = dict(seq_decoder_args["params"])
    decoder_kwargs["output_size"] = dec_mel_to_idx.dict_size()
    decoder_kwargs["device"] = device
    seq_decoder = seq_decoder_factory.create(seq_decoder_args["key"],
                                             **decoder_kwargs)

    model = base.Seq2Seq(encoder,
                         seq_decoder,
                         sos_token=dec_mel_to_idx.get_sos_token())
    model = model.to(device)

    print(model)

    # --- optimizer and learning-rate scheduler ----------------------------------------------------
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=lr_scheduler_factor,
        patience=lr_scheduler_patience,
        verbose=True)

    # --- optional checkpoint restore --------------------------------------------------------------
    if ckpt_path is None:
        step = 1
    else:
        # Resume one step after the step returned by the checkpoint loader.
        step = utils.load_ckpt(ckpt_path, model, optimizer, scheduler,
                               device) + 1
        print(
            f"Loaded checkpoint from \"{ckpt_path}\" start from step {step}.")

    # --- training ---------------------------------------------------------------------------------
    train(model,
          train_loader,
          optimizer,
          scheduler,
          device,
          beta_settings,
          sampling_settings,
          free_bits,
          step=step,
          num_steps=num_steps,
          dl_eval=eval_loader,
          evaluate_interval=evaluate_interval,
          advanced_interval=advanced_interval,
          print_interval=print_interval)