Example #1
def test_greedy_trained():
    embed_size = 12
    input_size = output_size = 91
    device = torch.device("cpu")
    encoder = base.Encoder(input_size=input_size,
                           embed_size=embed_size,
                           hidden_size=32,
                           z_size=Z_SIZE,
                           n_layers=2)

    params = {
        "output_size": output_size,
        "z_size": Z_SIZE,
        "device": device,
        "decoder_args": {
            "key": "another",
            "params": {
                "z_size": Z_SIZE,
                "embed_size": embed_size,
                "hidden_size": 64,
                "n_layers": 2
            }
        }
    }
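    # the "greedy" key builds a GreedySeqDecoder around the decoder described in decoder_args (cf. Example #3)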
    seq_decoder = seq_decoder_factory.create("greedy", **params)
    model = base.Seq2Seq(encoder, seq_decoder, 90)
    assert_trained(model)
Example #2
def test_sample_seq2seq():
    embed_size = 12
    input_size = output_size = 91
    z_size = 33
    device = torch.device("cpu")
    enc = base.Encoder(input_size=input_size,
                       embed_size=embed_size,
                       hidden_size=32,
                       z_size=z_size,
                       n_layers=2)

    params = {
        "output_size": output_size,
        "z_size": z_size,
        "device": device,
        "temperature": 3.0,
        "decoder_args": {
            "key": "another",
            "params": {
                "embed_size": embed_size,
                "hidden_size": 64,
                "z_size": z_size,
                "n_layers": 2
            }
        }
    }
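    # the "sample" key builds a temperature-controlled SampleSeqDecoder (cf. Example #6)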
    seq_decoder = seq_decoder_factory.create("sample", **params)
    assert_seq2seq(enc, seq_decoder, sos_token=90)
Example #3
def test_factory_greedy():
    z_size = 111

    device = torch.device("cpu")
    params = {
        "decoder_args": {
            "key": "another",
            "params": {
                "embed_size": 43,
                "hidden_size": 222,
                "n_layers": 1,
                "z_size": z_size
            }
        },
        "output_size": 101,
        "z_size": z_size,
        "device": device
    }
    seq_decoder = seq_decoder_factory.create("greedy", **params)
    assert isinstance(seq_decoder, seq_decoder_module.GreedySeqDecoder)
    assert seq_decoder.z_size == z_size

    assert isinstance(seq_decoder.decoder, base.AnotherDecoder)
    assert seq_decoder.decoder.output_size == 101
    assert seq_decoder.decoder.embed_size == 43
    assert seq_decoder.decoder.hidden_size == 222
    assert seq_decoder.decoder.z_size == z_size
    assert seq_decoder.decoder.n_layers == 1
    assert seq_decoder.decoder.dropout_prob == 0.
Example #4
def test_factory_sample():
    z_size = 111

    device = torch.device("cpu")
    params = {
        "decoder_args": {
            "key": "another",
            "params": {
                "embed_size": 43,
                "hidden_size": 222,
                "n_layers": 4,
                "z_size": z_size,
                "dropout_prob": 0.3
            }
        },
        "output_size": 81,
        "temperature": 0.4,
        "z_size": z_size,
        "device": device
    }
    seq_decoder = seq_decoder_factory.create("sample", **params)
    assert isinstance(seq_decoder, seq_decoder_module.SampleSeqDecoder)
    assert seq_decoder.z_size == z_size
    assert seq_decoder.temperature == 0.4

    assert isinstance(seq_decoder.decoder, base.AnotherDecoder)
    assert seq_decoder.decoder.output_size == 81
    assert seq_decoder.decoder.embed_size == 43
    assert seq_decoder.decoder.hidden_size == 222
    assert seq_decoder.decoder.z_size == z_size
    assert seq_decoder.decoder.n_layers == 4
    assert seq_decoder.decoder.dropout_prob == .3
Example #5
def test_factory_hier():
    z_size = 38
    c_size = 23

    device = torch.device("cpu")
    params = {
        "conductor_args": {
            "key": "conductor",
            "params": {
                "hidden_size": 13,
                "n_layers": 3,
                "c_size": c_size
            }
        },
        "seq_decoder_args": {
            "key": "greedy",
            "params": {
                "decoder_args": {
                    "key": "another",
                    "params": {
                        "embed_size": 12,
                        "hidden_size": 12,
                        "n_layers": 2,
                        "z_size": c_size
                    }
                },
                "z_size": c_size
            }
        },
        "output_size": 91,
        "n_subsequences": 8,
        "z_size": z_size,
        "device": device
    }

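    # the "hier" key builds a HierarchicalSeqDecoder that combines a Conductor with a nested seq decoder for the subsequences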
    seq_decoder = seq_decoder_factory.create("hier", **params)
    assert isinstance(seq_decoder, seq_decoder_module.HierarchicalSeqDecoder)
    assert seq_decoder.n_subsequences == 8
    assert seq_decoder.z_size == z_size

    assert isinstance(seq_decoder.conductor, base.Conductor)
    assert seq_decoder.conductor.hidden_size == 13
    assert seq_decoder.conductor.c_size == c_size
    assert seq_decoder.conductor.n_layers == 3

    assert isinstance(seq_decoder.seq_decoder, seq_decoder_module.GreedySeqDecoder)
    assert seq_decoder.seq_decoder.z_size == c_size

    assert isinstance(seq_decoder.seq_decoder.decoder, base.AnotherDecoder)
    assert seq_decoder.seq_decoder.decoder.output_size == 91
    assert seq_decoder.seq_decoder.decoder.embed_size == 12
    assert seq_decoder.seq_decoder.decoder.hidden_size == 12
    assert seq_decoder.seq_decoder.decoder.z_size == c_size
    assert seq_decoder.seq_decoder.decoder.n_layers == 2
    assert seq_decoder.seq_decoder.decoder.dropout_prob == 0.
Example #6
def test_sample_seq_decoder_output_shape():
    z_size = 9
    hidden_size = 10
    embed_size = 12
    output_size = 90
    n_layers = 2

    device = torch.device("cpu")
    params = {
        "output_size": output_size,
        "z_size": z_size,
        "device": device,
        "temperature": 1.e-8,
        "decoder_args": {
            "key": "simple",
            "params": {
                "embed_size": embed_size,
                "hidden_size": hidden_size,
                "n_layers": n_layers,
                "dropout_prob": 0.2
            }
        }
    }
    seq_decoder = seq_decoder_factory.create("sample", **params)
    assert isinstance(seq_decoder, seq_decoder_module.SampleSeqDecoder)
    assert isinstance(seq_decoder.decoder, base.SimpleDecoder)
    assert_seq_decoder_output_shape(seq_decoder, z_size)

    params["decoder_args"] = {
        "key": "another",
        "params": {
            "embed_size": embed_size,
            "hidden_size": hidden_size,
            "n_layers": n_layers,
            "dropout_prob": 0.3,
            "z_size": z_size
        }
    }
    seq_decoder = seq_decoder_factory.create("sample", **params)
    assert isinstance(seq_decoder.decoder, base.AnotherDecoder)
    assert_seq_decoder_output_shape(seq_decoder, z_size)
Example #7
def test_hier_trained():
    embed_size = 12
    input_size = output_size = 91
    c_size = 134
    n_subseq = 2
    device = torch.device("cpu")
    encoder = base.Encoder(input_size=input_size,
                           embed_size=embed_size,
                           hidden_size=32,
                           z_size=Z_SIZE,
                           n_layers=2)

    params = {
        "output_size": output_size,
        "z_size": Z_SIZE,
        "n_subsequences": n_subseq,
        "device": device,
        "conductor_args": {
            "key": "conductor",
            "params": {
                "hidden_size": 23,
                "c_size": c_size,
                "n_layers": 3
            }
        },
        "seq_decoder_args": {
            "key": "sample",
            "params": {
                "z_size": c_size,
                "device": device,
                "temperature": 3.0,
                "decoder_args": {
                    "key": "another",
                    "params": {
                        "embed_size": embed_size,
                        "hidden_size": 64,
                        "n_layers": 2,
                        "dropout_prob": 0.3,
                        "z_size": c_size
                    }
                }
            }
        },
    }
    seq_decoder = seq_decoder_factory.create("hier", **params)
    model = base.Seq2Seq(encoder, seq_decoder, 90)
    assert_trained(model)
Example #8
def test_hier_seq2seq():
    embed_size = 12
    input_size = output_size = 91
    z_size = 33
    c_size = 121
    n_subseq = 2

    device = torch.device("cpu")
    enc = base.Encoder(input_size=input_size,
                       embed_size=embed_size,
                       hidden_size=32,
                       z_size=z_size,
                       n_layers=2)

    params = {
        "output_size": output_size,
        "z_size": z_size,
        "n_subsequences": n_subseq,
        "device": device,
        "conductor_args": {
            "key": "conductor",
            "params": {
                "hidden_size": 23,
                "c_size": c_size,
                "n_layers": 3
            }
        },
        "seq_decoder_args": {
            "key": "greedy",
            "params": {
                "z_size": c_size,
                "decoder_args": {
                    "key": "another",
                    "params": {
                        "embed_size": embed_size,
                        "hidden_size": 64,
                        "n_layers": 2,
                        "dropout_prob": 0.3,
                        "z_size": c_size
                    }
                }
            }
        },
    }
    seq_decoder = seq_decoder_factory.create("hier", **params)
    assert_seq2seq(enc, seq_decoder, sos_token=90)
Example #9
def test_hier_seq_decoder_input():
    z_size = 2
    batch_size = 1
    seq_len = 32
    output_size = 90

    device = torch.device("cpu")
    params = {
        "output_size": output_size,
        "z_size": z_size,
        "n_subsequences": 2,
        "device": device,
        "conductor_args": {
            "key": "conductor",
            "params": {
                "hidden_size": 1,
                "c_size": 1,
                "n_layers": 1
            }
        },
        "seq_decoder_args": {
            "key": "greedy",
            "params": {
                "z_size": 1,
                "device": device,
                "decoder_args": {
                    "key": "fake",
                    "params": {}
                }
            }
        },
    }
    seq_decoder = seq_decoder_factory.create("hier", **params)
    assert isinstance(seq_decoder, seq_decoder_module.HierarchicalSeqDecoder)

    z = torch.randn((batch_size, z_size))
    for _ in range(100):
        trg = torch.randint(0, 50, size=(batch_size, seq_len))
        assert_input_order(seq_decoder, trg, z)
Example #10
def test_greedy_seq_decoder_input():
    z_size = 2
    batch_size = 1
    seq_len = 10
    output_size = 90

    device = torch.device("cpu")
    params = {
        "output_size": output_size,
        "z_size": z_size,
        "device": device,
        "decoder_args": {
            "key": "fake",
            "params": {}
        }
    }
    seq_decoder = seq_decoder_factory.create("greedy", **params)
    assert isinstance(seq_decoder, seq_decoder_module.GreedySeqDecoder)

    z = torch.randn((batch_size, z_size))
    for _ in range(100):
        trg = torch.randint(0, 50, size=(batch_size, seq_len))
        assert_input_order(seq_decoder, trg, z)
Example #11
def test_greedy_seq2seq_ckpt():
    embed_size = 12
    input_size = output_size = 91
    z_size = 33
    device = torch.device("cpu")
    encoder = base.Encoder(input_size=input_size, embed_size=embed_size, hidden_size=32, z_size=z_size, n_layers=2)
    params = {
        "output_size": output_size,
        "z_size": z_size,
        "device": device,
        "decoder_args": {
            "key": "another",
            "params": {
                "z_size": z_size,
                "embed_size": embed_size,
                "hidden_size": 64,
                "n_layers": 2,
                "dropout_prob": 0.2
            }
        }
    }
    seq_decoder = seq_decoder_factory.create("greedy", **params)
    model = base.Seq2Seq(encoder, seq_decoder, 90)
    assert_seq2seq_ckpt(model)
Example #12
def run(_run,
        num_steps,
        batch_size,
        num_workers,
        z_size,
        beta_settings,
        sampling_settings,
        free_bits,
        encoder_params,
        seq_decoder_args,
        learning_rate,
        train_dir,
        eval_dir,
        slice_bar,
        lr_scheduler_factor,
        lr_scheduler_patience,
        evaluate_interval=1000,
        advanced_interval=200,
        print_interval=200,
        ckpt_path=None):
    global logger

    # define dataset ------------------------------------------------------------------------------
    enc_mel_to_idx = melody_dataset.MapMelodyToIndex(has_sos_token=False)
    dec_mel_to_idx = melody_dataset.MapMelodyToIndex(has_sos_token=True)

    ds_train = melody_dataset.MelodyDataset(midi_dir=train_dir,
                                            slice_bars=slice_bar,
                                            transforms=enc_mel_to_idx,
                                            train=True)
    dl_train = DataLoader(ds_train,
                          batch_size=batch_size,
                          num_workers=num_workers,
                          drop_last=True)

    ds_eval = melody_dataset.MelodyDataset(midi_dir=eval_dir,
                                           slice_bars=slice_bar,
                                           transforms=enc_mel_to_idx,
                                           train=False)
    dl_eval = DataLoader(ds_eval,
                         batch_size=batch_size,
                         num_workers=num_workers,
                         drop_last=False)
    print(
        f"Train/Eval files: {len(ds_train.midi_files)} / {len(ds_eval.midi_files)}"
    )

    # define logger -------------------------------------------------------------------------------
    run_dir = utils.get_run_dir(_run)
    writer = SummaryWriter(log_dir=run_dir)
    ckpt_dir = Path(writer.log_dir) / "ckpt"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    logger = Logger(writer, ckpt_dir=ckpt_dir, melody_dict=dec_mel_to_idx)

    # define model --------------------------------------------------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    enc = base.Encoder(input_size=enc_mel_to_idx.dict_size(),
                       z_size=z_size,
                       **encoder_params)
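    # merge the configured seq decoder params with the dataset-dependent output size and the runtime device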
    params = {
        **seq_decoder_args["params"],
        "output_size": dec_mel_to_idx.dict_size(),
        "device": device
    }
    seq_decoder = seq_decoder_factory.create(seq_decoder_args["key"], **params)

    # ---------------------------------------------
    # if not use_hier:
    #     # flat model ------------------------------------------------------------------------------
    #     dec = base.AnotherDecoder(output_size=dec_mel_to_idx.dict_size(), z_size=z_size, **decoder_params)
    #     seq_decoder = base.SimpleSeqDecoder(dec, z_size=z_size, device=device)
    # else:
    #     # hier model ------------------------------------------------------------------------------
    #     assert conductor_params is not None and c_size is not None and n_subsequences is not None
    #
    #     conductor = hier.Conductor(c_size=c_size, **conductor_params)
    #     dec = base.AnotherDecoder(output_size=dec_mel_to_idx.dict_size(), z_size=c_size, **decoder_params)
    #     seq_decoder = hier.HierarchicalSeqDecoder(conductor, dec, n_subsequences=n_subsequences, z_size=z_size,
    #                                               device=device)
    # ----------------------------------------------
    model = base.Seq2Seq(enc,
                         seq_decoder,
                         sos_token=dec_mel_to_idx.get_sos_token()).to(device)

    print(model)

    # define optimizer ----------------------------------------------------------------------------
    opt = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))

    # define scheduler ----------------------------------------------------------------------------
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        opt,
        mode='min',
        factor=lr_scheduler_factor,
        patience=lr_scheduler_patience,
        verbose=True)

    # load checkpoint, if given -------------------------------------------------------------------
    step = 1
    if ckpt_path is not None:
        step = utils.load_ckpt(ckpt_path, model, opt, lr_scheduler, device) + 1
        print(
            f"Loaded checkpoint from \"{ckpt_path}\", starting from step {step}.")

    # start train loop ----------------------------------------------------------------------------
    train(model,
          dl_train,
          opt,
          lr_scheduler,
          device,
          beta_settings,
          sampling_settings,
          free_bits,
          step=step,
          num_steps=num_steps,
          dl_eval=dl_eval,
          evaluate_interval=evaluate_interval,
          advanced_interval=advanced_interval,
          print_interval=print_interval)
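Note that run only expects the factory key and its params in seq_decoder_args; it injects output_size (from the decoder vocabulary) and device itself. Below is a minimal, illustrative sketch of such a value for the hierarchical decoder, with placeholder numbers borrowed from Example #8; the z_size inside params must match the z_size argument passed to run:

seq_decoder_args = {
    "key": "hier",
    "params": {
        "z_size": 33,  # must equal the z_size passed to run()
        "n_subsequences": 2,
        "conductor_args": {
            "key": "conductor",
            "params": {
                "hidden_size": 23,
                "c_size": 121,
                "n_layers": 3
            }
        },
        "seq_decoder_args": {
            "key": "greedy",
            "params": {
                "z_size": 121,  # the inner seq decoder operates on the conductor's c_size
                "decoder_args": {
                    "key": "another",
                    "params": {
                        "embed_size": 12,
                        "hidden_size": 64,
                        "n_layers": 2,
                        "dropout_prob": 0.3,
                        "z_size": 121
                    }
                }
            }
        }
    }
}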
Example #13
def test_hier_seq_decoder_output_shape():
    z_size = 9
    hidden_size = 10
    embed_size = 12
    output_size = 90
    n_layers = 2
    n_subseq = 2
    c_size = 44

    device = torch.device("cpu")
    params = {
        "output_size": output_size,
        "z_size": z_size,
        "n_subsequences": n_subseq,
        "device": device,
        "conductor_args": {
            "key": "conductor",
            "params": {
                "hidden_size": 23,
                "c_size": c_size,
                "n_layers": n_layers
            }
        },
        "seq_decoder_args": {
            "key": "greedy",
            "params": {
                "z_size": c_size,
                "decoder_args": {
                    "key": "simple",
                    "params": {
                        "embed_size": embed_size,
                        "hidden_size": hidden_size,
                        "n_layers": n_layers,
                        "dropout_prob": 0.2
                    }
                }
            }
        },
    }

    seq_decoder = seq_decoder_factory.create("hier", **params)
    assert isinstance(seq_decoder, seq_decoder_module.HierarchicalSeqDecoder)
    assert isinstance(seq_decoder.seq_decoder,
                      seq_decoder_module.GreedySeqDecoder)
    assert isinstance(seq_decoder.seq_decoder.decoder, base.SimpleDecoder)
    assert_seq_decoder_output_shape(seq_decoder, z_size)

    params["seq_decoder_args"]["params"]["decoder_args"] = {
        "key": "another",
        "params": {
            "embed_size": embed_size,
            "hidden_size": hidden_size,
            "n_layers": n_layers,
            "dropout_prob": 0.3,
            "z_size": c_size
        }
    }

    seq_decoder = seq_decoder_factory.create("hier", **params)
    assert isinstance(seq_decoder, seq_decoder_module.HierarchicalSeqDecoder)
    assert isinstance(seq_decoder.seq_decoder,
                      seq_decoder_module.GreedySeqDecoder)
    assert isinstance(seq_decoder.seq_decoder.decoder, base.AnotherDecoder)
    assert_seq_decoder_output_shape(seq_decoder, z_size)