def test_greedy_trained():
    """A Seq2Seq built around a factory-made greedy seq-decoder can be trained."""
    embed_size = 12
    input_size = output_size = 91
    device = torch.device("cpu")
    encoder = base.Encoder(
        input_size=input_size,
        embed_size=embed_size,
        hidden_size=32,
        z_size=Z_SIZE,
        n_layers=2,
    )
    # Inner decoder config forwarded verbatim to the factory.
    decoder_args = {
        "key": "another",
        "params": {
            "z_size": Z_SIZE,
            "embed_size": embed_size,
            "hidden_size": 64,
            "n_layers": 2,
        },
    }
    seq_decoder = seq_decoder_factory.create(
        "greedy",
        output_size=output_size,
        z_size=Z_SIZE,
        device=device,
        decoder_args=decoder_args,
    )
    model = base.Seq2Seq(encoder, seq_decoder, 90)
    assert_trained(model)
def test_sample_seq2seq():
    """A factory-made sampling seq-decoder works inside a full Seq2Seq round trip."""
    embed_size = 12
    input_size = output_size = 91
    z_size = 33
    device = torch.device("cpu")
    enc = base.Encoder(
        input_size=input_size,
        embed_size=embed_size,
        hidden_size=32,
        z_size=z_size,
        n_layers=2,
    )
    inner_decoder = {
        "key": "another",
        "params": {
            "embed_size": embed_size,
            "hidden_size": 64,
            "z_size": z_size,
            "n_layers": 2,
        },
    }
    seq_decoder = seq_decoder_factory.create(
        "sample",
        output_size=output_size,
        z_size=z_size,
        device=device,
        temperature=3.0,
        decoder_args=inner_decoder,
    )
    assert_seq2seq(enc, seq_decoder, sos_token=90)
def test_factory_greedy():
    """The "greedy" factory key yields a GreedySeqDecoder with all params wired through."""
    z_size = 111
    device = torch.device("cpu")
    params = {
        "decoder_args": {
            "key": "another",
            "params": {
                "embed_size": 43,
                "hidden_size": 222,
                "n_layers": 1,
                "z_size": z_size,
            },
        },
        "output_size": 101,
        "z_size": z_size,
        "device": device,
    }

    seq_decoder = seq_decoder_factory.create("greedy", **params)

    # Outer seq-decoder.
    assert isinstance(seq_decoder, seq_decoder_module.GreedySeqDecoder)
    assert seq_decoder.z_size == z_size
    # Inner token decoder.
    decoder = seq_decoder.decoder
    assert isinstance(decoder, base.AnotherDecoder)
    assert decoder.output_size == 101
    assert decoder.embed_size == 43
    assert decoder.hidden_size == 222
    assert decoder.z_size == z_size
    assert decoder.n_layers == 1
    assert decoder.dropout_prob == 0.  # default when not supplied
def test_factory_sample():
    """The "sample" factory key yields a SampleSeqDecoder with all params wired through.

    Bug fix: the original asserted ``isinstance(seq_decoder,
    seq_decoder_module.GreedySeqDecoder)`` even though the decoder is created
    with the "sample" key and has a ``temperature`` attribute; the sibling test
    ``test_sample_seq_decoder_output_shape`` confirms "sample" produces a
    ``SampleSeqDecoder``.
    """
    z_size = 111
    device = torch.device("cpu")
    params = {
        "decoder_args": {
            "key": "another",
            "params": {
                "embed_size": 43,
                "hidden_size": 222,
                "n_layers": 4,
                "z_size": z_size,
                "dropout_prob": 0.3,
            },
        },
        "output_size": 81,
        "temperature": 0.4,
        "z_size": z_size,
        "device": device,
    }

    seq_decoder = seq_decoder_factory.create("sample", **params)

    # Outer seq-decoder: must be the sampling variant, not the greedy one.
    assert isinstance(seq_decoder, seq_decoder_module.SampleSeqDecoder)
    assert seq_decoder.z_size == z_size
    assert seq_decoder.temperature == 0.4
    # Inner token decoder.
    assert isinstance(seq_decoder.decoder, base.AnotherDecoder)
    assert seq_decoder.decoder.output_size == 81
    assert seq_decoder.decoder.embed_size == 43
    assert seq_decoder.decoder.hidden_size == 222
    assert seq_decoder.decoder.z_size == z_size
    assert seq_decoder.decoder.n_layers == 4
    assert seq_decoder.decoder.dropout_prob == .3
def test_factory_hier():
    """The "hier" factory key builds a HierarchicalSeqDecoder with a nested greedy decoder."""
    z_size = 38
    c_size = 23
    device = torch.device("cpu")

    conductor_args = {
        "key": "conductor",
        "params": {"hidden_size": 13, "n_layers": 3, "c_size": c_size},
    }
    # The nested seq-decoder operates in conductor space, hence z_size == c_size.
    nested_seq_decoder_args = {
        "key": "greedy",
        "params": {
            "decoder_args": {
                "key": "another",
                "params": {
                    "embed_size": 12,
                    "hidden_size": 12,
                    "n_layers": 2,
                    "z_size": c_size,
                },
            },
            "z_size": c_size,
        },
    }
    seq_decoder = seq_decoder_factory.create(
        "hier",
        conductor_args=conductor_args,
        seq_decoder_args=nested_seq_decoder_args,
        output_size=91,
        n_subsequences=8,
        z_size=z_size,
        device=device,
    )

    # Top-level hierarchical decoder.
    assert isinstance(seq_decoder, seq_decoder_module.HierarchicalSeqDecoder)
    assert seq_decoder.n_subsequences == 8
    assert seq_decoder.z_size == z_size
    # Conductor.
    conductor = seq_decoder.conductor
    assert isinstance(conductor, base.Conductor)
    assert conductor.hidden_size == 13
    assert conductor.c_size == c_size
    assert conductor.n_layers == 3
    # Nested greedy seq-decoder.
    nested = seq_decoder.seq_decoder
    assert isinstance(nested, seq_decoder_module.GreedySeqDecoder)
    assert nested.z_size == c_size
    # Innermost token decoder.
    assert isinstance(nested.decoder, base.AnotherDecoder)
    assert nested.decoder.output_size == 91
    assert nested.decoder.embed_size == 12
    assert nested.decoder.hidden_size == 12
    assert nested.decoder.z_size == c_size
    assert nested.decoder.n_layers == 2
    assert nested.decoder.dropout_prob == 0.  # default when not supplied
def test_sample_seq_decoder_output_shape():
    """SampleSeqDecoder emits correctly-shaped outputs for both inner decoder kinds."""
    z_size = 9
    hidden_size = 10
    embed_size = 12
    output_size = 90
    n_layers = 2
    device = torch.device("cpu")

    params = {
        "output_size": output_size,
        "z_size": z_size,
        "device": device,
        "temperature": 1.e-8,
        "decoder_args": {
            "key": "simple",
            "params": {
                "embed_size": embed_size,
                "hidden_size": hidden_size,
                "n_layers": n_layers,
                "dropout_prob": 0.2,
            },
        },
    }

    # First with the "simple" inner decoder.
    seq_decoder = seq_decoder_factory.create("sample", **params)
    assert isinstance(seq_decoder, seq_decoder_module.SampleSeqDecoder)
    assert isinstance(seq_decoder.decoder, base.SimpleDecoder)
    assert_seq_decoder_output_shape(seq_decoder, z_size)

    # Then again with the "another" inner decoder swapped in.
    params["decoder_args"] = {
        "key": "another",
        "params": {
            "embed_size": embed_size,
            "hidden_size": hidden_size,
            "n_layers": n_layers,
            "dropout_prob": 0.3,
            "z_size": z_size,
        },
    }
    seq_decoder = seq_decoder_factory.create("sample", **params)
    assert isinstance(seq_decoder.decoder, base.AnotherDecoder)
    assert_seq_decoder_output_shape(seq_decoder, z_size)
def test_hier_trained():
    """A Seq2Seq built around a factory-made hierarchical seq-decoder can be trained."""
    embed_size = 12
    input_size = output_size = 91
    c_size = 134
    n_subseq = 2
    device = torch.device("cpu")

    encoder = base.Encoder(
        input_size=input_size,
        embed_size=embed_size,
        hidden_size=32,
        z_size=Z_SIZE,
        n_layers=2,
    )
    # Nested sampling decoder operating in conductor (c_size) space.
    nested_seq_decoder = {
        "key": "sample",
        "params": {
            "z_size": c_size,
            "device": device,
            "temperature": 3.0,
            "decoder_args": {
                "key": "another",
                "params": {
                    "embed_size": embed_size,
                    "hidden_size": 64,
                    "n_layers": 2,
                    "dropout_prob": 0.3,
                    "z_size": c_size,
                },
            },
        },
    }
    seq_decoder = seq_decoder_factory.create(
        "hier",
        output_size=output_size,
        z_size=Z_SIZE,
        n_subsequences=n_subseq,
        device=device,
        conductor_args={
            "key": "conductor",
            "params": {"hidden_size": 23, "c_size": c_size, "n_layers": 3},
        },
        seq_decoder_args=nested_seq_decoder,
    )
    model = base.Seq2Seq(encoder, seq_decoder, 90)
    assert_trained(model)
def test_hier_seq2seq():
    """A factory-made hierarchical seq-decoder works inside a full Seq2Seq round trip."""
    embed_size = 12
    input_size = output_size = 91
    z_size = 33
    c_size = 121
    n_subseq = 2
    device = torch.device("cpu")

    enc = base.Encoder(
        input_size=input_size,
        embed_size=embed_size,
        hidden_size=32,
        z_size=z_size,
        n_layers=2,
    )
    # Nested greedy decoder operating in conductor (c_size) space.
    nested_seq_decoder = {
        "key": "greedy",
        "params": {
            "z_size": c_size,
            "decoder_args": {
                "key": "another",
                "params": {
                    "embed_size": embed_size,
                    "hidden_size": 64,
                    "n_layers": 2,
                    "dropout_prob": 0.3,
                    "z_size": c_size,
                },
            },
        },
    }
    seq_decoder = seq_decoder_factory.create(
        "hier",
        output_size=output_size,
        z_size=z_size,
        n_subsequences=n_subseq,
        device=device,
        conductor_args={
            "key": "conductor",
            "params": {"hidden_size": 23, "c_size": c_size, "n_layers": 3},
        },
        seq_decoder_args=nested_seq_decoder,
    )
    assert_seq2seq(enc, seq_decoder, sos_token=90)
def test_hier_seq_decoder_input():
    """The hierarchical seq-decoder consumes target tokens in the expected order.

    Uses the "fake" inner decoder so only the input bookkeeping is exercised;
    repeats with fresh random targets to cover many token sequences.
    """
    z_size = 2
    batch_size = 1
    seq_len = 32
    output_size = 90
    device = torch.device("cpu")

    seq_decoder = seq_decoder_factory.create(
        "hier",
        output_size=output_size,
        z_size=z_size,
        n_subsequences=2,
        device=device,
        conductor_args={
            "key": "conductor",
            "params": {"hidden_size": 1, "c_size": 1, "n_layers": 1},
        },
        seq_decoder_args={
            "key": "greedy",
            "params": {
                "z_size": 1,
                "device": device,
                "decoder_args": {"key": "fake", "params": {}},
            },
        },
    )
    assert isinstance(seq_decoder, seq_decoder_module.HierarchicalSeqDecoder)

    z = torch.randn((batch_size, z_size))
    for _ in range(100):
        trg = torch.randint(0, 50, size=(batch_size, seq_len))
        assert_input_order(seq_decoder, trg, z)
def test_greedy_seq_decoder_input():
    """The greedy seq-decoder consumes target tokens in the expected order.

    Uses the "fake" inner decoder so only the input bookkeeping is exercised;
    repeats with fresh random targets to cover many token sequences.
    """
    z_size = 2
    batch_size = 1
    seq_len = 10
    output_size = 90
    device = torch.device("cpu")

    seq_decoder = seq_decoder_factory.create(
        "greedy",
        output_size=output_size,
        z_size=z_size,
        device=device,
        decoder_args={"key": "fake", "params": {}},
    )
    assert isinstance(seq_decoder, seq_decoder_module.GreedySeqDecoder)

    z = torch.randn((batch_size, z_size))
    for _ in range(100):
        trg = torch.randint(0, 50, size=(batch_size, seq_len))
        assert_input_order(seq_decoder, trg, z)
def test_greedy_seq2seq_ckpt():
    """A Seq2Seq with a greedy seq-decoder survives a checkpoint save/load cycle."""
    embed_size = 12
    input_size = output_size = 91
    z_size = 33
    device = torch.device("cpu")

    encoder = base.Encoder(
        input_size=input_size,
        embed_size=embed_size,
        hidden_size=32,
        z_size=z_size,
        n_layers=2,
    )
    seq_decoder = seq_decoder_factory.create(
        "greedy",
        output_size=output_size,
        z_size=z_size,
        device=device,
        decoder_args={
            "key": "another",
            "params": {
                "z_size": z_size,
                "embed_size": embed_size,
                "hidden_size": 64,
                "n_layers": 2,
                "dropout_prob": 0.2,
            },
        },
    )
    model = base.Seq2Seq(encoder, seq_decoder, 90)
    assert_seq2seq_ckpt(model)
def run(_run, num_steps, batch_size, num_workers, z_size, beta_settings, sampling_settings,
        free_bits, encoder_params, seq_decoder_args, learning_rate, train_dir, eval_dir,
        slice_bar, lr_scheduler_factor, lr_scheduler_patience, evaluate_interval=1000,
        advanced_interval=200, print_interval=200, ckpt_path=None):
    """Build datasets, logger, model, and optimizer, then run the training loop.

    The model is a ``base.Seq2Seq`` whose seq-decoder is created via
    ``seq_decoder_factory`` from ``seq_decoder_args`` ("key" + "params"),
    so both flat and hierarchical variants are configured purely by args.

    Args:
        _run: experiment run handle, used to resolve the log directory.
        num_steps: total number of training steps.
        batch_size, num_workers: DataLoader settings for train and eval.
        z_size: latent size shared by encoder and seq-decoder config.
        beta_settings, sampling_settings, free_bits: forwarded to ``train``.
        encoder_params: extra kwargs for ``base.Encoder``.
        seq_decoder_args: dict with "key" and "params" for the factory.
        learning_rate: Adam learning rate.
        train_dir, eval_dir: MIDI directories for train/eval datasets.
        slice_bar: bars per melody slice.
        lr_scheduler_factor, lr_scheduler_patience: ReduceLROnPlateau settings.
        evaluate_interval, advanced_interval, print_interval: logging cadence.
        ckpt_path: optional checkpoint to resume from.

    Side effects: sets the module-level ``logger`` and writes TensorBoard
    logs plus checkpoints under the run directory.
    """
    global logger
    # define dataset ------------------------------------------------------------------------------
    enc_mel_to_idx = melody_dataset.MapMelodyToIndex(has_sos_token=False)
    dec_mel_to_idx = melody_dataset.MapMelodyToIndex(has_sos_token=True)
    ds_train = melody_dataset.MelodyDataset(midi_dir=train_dir, slice_bars=slice_bar,
                                            transforms=enc_mel_to_idx, train=True)
    dl_train = DataLoader(ds_train, batch_size=batch_size, num_workers=num_workers,
                          drop_last=True)
    ds_eval = melody_dataset.MelodyDataset(midi_dir=eval_dir, slice_bars=slice_bar,
                                           transforms=enc_mel_to_idx, train=False)
    dl_eval = DataLoader(ds_eval, batch_size=batch_size, num_workers=num_workers,
                         drop_last=False)
    print(f"Train/Eval files: {len(ds_train.midi_files)} / {len(ds_eval.midi_files)}")

    # define logger -------------------------------------------------------------------------------
    run_dir = utils.get_run_dir(_run)
    writer = SummaryWriter(log_dir=run_dir)
    ckpt_dir = Path(writer.log_dir) / "ckpt"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    logger = Logger(writer, ckpt_dir=ckpt_dir, melody_dict=dec_mel_to_idx)

    # define model --------------------------------------------------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    enc = base.Encoder(input_size=enc_mel_to_idx.dict_size(), z_size=z_size, **encoder_params)
    # Output size and device are fixed here; everything else comes from config.
    params = {
        **seq_decoder_args["params"],
        "output_size": dec_mel_to_idx.dict_size(),
        "device": device
    }
    seq_decoder = seq_decoder_factory.create(seq_decoder_args["key"], **params)
    model = base.Seq2Seq(enc, seq_decoder,
                         sos_token=dec_mel_to_idx.get_sos_token()).to(device)
    print(model)

    # define optimizer ----------------------------------------------------------------------------
    opt = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))

    # define scheduler ----------------------------------------------------------------------------
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode='min', factor=lr_scheduler_factor, patience=lr_scheduler_patience,
        verbose=True)

    # load checkpoint, if given -------------------------------------------------------------------
    step = 1
    if ckpt_path is not None:
        step = utils.load_ckpt(ckpt_path, model, opt, lr_scheduler, device) + 1
        print(
            f"Loaded checkpoint from \"{ckpt_path}\" start from step {step}.")

    # start train loop ----------------------------------------------------------------------------
    train(model, dl_train, opt, lr_scheduler, device, beta_settings, sampling_settings,
          free_bits, step=step, num_steps=num_steps, dl_eval=dl_eval,
          evaluate_interval=evaluate_interval, advanced_interval=advanced_interval,
          print_interval=print_interval)
def test_hier_seq_decoder_output_shape():
    """HierarchicalSeqDecoder emits correctly-shaped outputs for both inner decoder kinds."""
    z_size = 9
    hidden_size = 10
    embed_size = 12
    output_size = 90
    n_layers = 2
    n_subseq = 2
    c_size = 44
    device = torch.device("cpu")

    params = {
        "output_size": output_size,
        "z_size": z_size,
        "n_subsequences": n_subseq,
        "device": device,
        "conductor_args": {
            "key": "conductor",
            "params": {"hidden_size": 23, "c_size": c_size, "n_layers": n_layers},
        },
        "seq_decoder_args": {
            "key": "greedy",
            "params": {
                "z_size": c_size,
                "decoder_args": {
                    "key": "simple",
                    "params": {
                        "embed_size": embed_size,
                        "hidden_size": hidden_size,
                        "n_layers": n_layers,
                        "dropout_prob": 0.2,
                    },
                },
            },
        },
    }

    # First with the "simple" innermost decoder.
    seq_decoder = seq_decoder_factory.create("hier", **params)
    assert isinstance(seq_decoder, seq_decoder_module.HierarchicalSeqDecoder)
    assert isinstance(seq_decoder.seq_decoder, seq_decoder_module.GreedySeqDecoder)
    assert isinstance(seq_decoder.seq_decoder.decoder, base.SimpleDecoder)
    assert_seq_decoder_output_shape(seq_decoder, z_size)

    # Then again with the "another" innermost decoder swapped in.
    params["seq_decoder_args"]["params"]["decoder_args"] = {
        "key": "another",
        "params": {
            "embed_size": embed_size,
            "hidden_size": hidden_size,
            "n_layers": n_layers,
            "dropout_prob": 0.3,
            "z_size": c_size,
        },
    }
    seq_decoder = seq_decoder_factory.create("hier", **params)
    assert isinstance(seq_decoder, seq_decoder_module.HierarchicalSeqDecoder)
    assert isinstance(seq_decoder.seq_decoder, seq_decoder_module.GreedySeqDecoder)
    assert isinstance(seq_decoder.seq_decoder.decoder, base.AnotherDecoder)
    assert_seq_decoder_output_shape(seq_decoder, z_size)