Example #1
def main(args):
    ## do-re-mi: one ascending major scale starting at pitch 70 (note offsets 0,2,4,5,7,9,11,12 span an octave)
    def easyTones():
        training_seq_lengths = torch.tensor([8]*1)
        training_data_sequences = torch.zeros(1,8,88)
        for i in range(1):
            training_data_sequences[i][0][int(70-i*10)  ] = 1
            training_data_sequences[i][1][int(70-i*10)+2] = 1
            training_data_sequences[i][2][int(70-i*10)+4] = 1
            training_data_sequences[i][3][int(70-i*10)+5] = 1
            training_data_sequences[i][4][int(70-i*10)+7] = 1
            training_data_sequences[i][5][int(70-i*10)+9] = 1
            training_data_sequences[i][6][int(70-i*10)+11] = 1
            training_data_sequences[i][7][int(70-i*10)+12] = 1
        return training_seq_lengths, training_data_sequences

    ## one sustained note per song, at a different pitch for each song
    def superEasyTones():
        training_seq_lengths = torch.tensor([8]*10)
        training_data_sequences = torch.zeros(10,8,88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(30+i*5)] = 1
        return training_seq_lengths, training_data_sequences

    ## the same single note (pitch 70) repeated in every song ("do-do-do")
    def easiestTones():
        training_seq_lengths = torch.tensor([8]*10)
        training_data_sequences = torch.zeros(10,8,88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(70)] = 1
        return training_seq_lengths, training_data_sequences

    # setup logging
    logging.basicConfig(level=logging.DEBUG, format='%(message)s', filename=args.log, filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    logging.info(args)

    data = poly.load_data(poly.JSB_CHORALES)
    # note: the JSB data loaded below is immediately overridden by the synthetic easiestTones() data
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
    training_seq_lengths, training_data_sequences = easiestTones()
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    test_seq_lengths, test_data_sequences = easiestTones()
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    val_seq_lengths, val_data_sequences = easiestTones()
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))
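    # (i.e. ceil(N_train_data / mini_batch_size); e.g. 10 sequences with a mini-batch size of 3 give 4 mini-batches)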

    logging.info("N_train_data: %d     avg. training seq. length: %.2f    N_mini_batches: %d" %
                 (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # package repeated copies of val/test data for faster evaluation
    # (i.e. set us up for vectorization)
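    # e.g. with n_eval_samples = 3, a tensor of shape (N, T, 88) becomes (3*N, T, 88),
    # with each sequence repeated 3 times in a row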
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(1, 0).reshape(rep_shape)

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * val_data_sequences.shape[0]), rep(val_data_sequences),
        val_seq_lengths, cuda=args.cuda)
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * test_data_sequences.shape[0]), rep(test_data_sequences),
        test_seq_lengths, cuda=args.cuda)

    # instantiate the dmm
    dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate, num_iafs=args.num_iafs,
              iaf_dim=args.iaf_dim, use_cuda=args.cuda)

    # setup optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta1, args.beta2),
                   "clip_norm": args.clip_norm, "lrd": args.lr_decay,
                   "weight_decay": args.weight_decay}
    adam = ClippedAdam(adam_params)

    # setup inference algorithm
    if args.tmc:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC")
        tmc_loss = TraceTMC_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel", num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=tmc_loss)
    elif args.tmcelbo:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC ELBO")
        elbo = TraceEnum_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel", num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=elbo)
    else:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)

    # now we're going to define some functions we need to form the main training loop

    # saves the model state to disk (the optimizer-state save is currently commented out)
    def save_checkpoint():
        logging.info("saving model to %s..." % args.save_model)
        torch.save(dmm.state_dict(), args.save_model)
        # logging.info("saving optimizer states to %s..." % args.save_opt)
        # adam.save(args.save_opt)
        logging.info("done saving model checkpoint to disk.")

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        logging.info("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        logging.info("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        logging.info("done loading model and optimizer states.")

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0
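        # the annealing factor thus ramps linearly from minimum_annealing_factor up to 1.0
        # over the first annealing_epochs epochs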

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = min((which_mini_batch + 1) * args.mini_batch_size, N_train_data)
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                        mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss

    # helper function for doing evaluation
    def do_evaluation():
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        dmm.rnn.eval()

        # compute the validation and test loss n_samples many times
        val_nll = svi.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask,
                                    val_seq_lengths) / float(torch.sum(val_seq_lengths))
        test_nll = svi.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask,
                                     test_seq_lengths) / float(torch.sum(test_seq_lengths))

        # put the RNN back into training mode (i.e. turn on drop-out if applicable)
        dmm.rnn.train()
        return val_nll, test_nll

    # if checkpoint files provided, load model and optimizer states from disk before we start training
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #################
    # TRAINING LOOP #
    #################
    times = [time.time()]
    for epoch in range(args.num_epochs):
        # if specified, save model and optimizer states to disk every checkpoint_freq epochs
        if args.checkpoint_freq > 0 and epoch > 0 and epoch % args.checkpoint_freq == 0:
            save_checkpoint()

        # accumulator for our estimate of the negative log likelihood (or rather -elbo) for this epoch
        epoch_nll = 0.0
        # prepare mini-batch subsampling indices for this epoch
        shuffled_indices = torch.randperm(N_train_data)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, which_mini_batch, shuffled_indices)

        # report training diagnostics
        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        logging.info("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
                     (epoch, epoch_nll / N_train_time_slices, epoch_time))

        # do evaluation on test and validation data and report results
        if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0:
            val_nll, test_nll = do_evaluation()
            logging.info("[val/test epoch %04d]  %.4f  %.4f" % (epoch, val_nll, test_nll))
Example #2
def main(args):

    ## saves the loss curves at the end of training
    def saveGraph(loss_list, sub_error_list, now):
        FS = 10
        fig = plt.figure()
        # plt.rcParams["font.size"] = FS
        plt.plot(loss_list, label="LOSS")
        plt.plot(sub_error_list, label="Reconstruction Error")
        plt.ylim(bottom=0)
        plt.title("Loss")
        plt.xlabel("epoch", fontsize=FS)
        plt.ylabel("loss", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join("saveData", now, "LOSS.png"))
        plt.close()

    def saveReconSinGraph(train_data, recon_data, length, path, number):
        FS = 10
        fig = plt.figure()
        ax = fig.add_subplot()
        plt.plot(train_data[:, 0], train_data[:, 1], label="Training data")
        plt.plot(recon_data.detach().numpy()[:, 0],
                 recon_data.detach().numpy()[:, 1],
                 label="Reconstructed data")
        plt.title("Reconstruction")
        # plt.ylim(top=10, bottom=-10)
        # plt.xlabel("time", fontsize=FS)
        # plt.ylabel("y", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join(path,
                                 "Reconstruction" + str(number) + ".png"))
        plt.close()

    def saveGeneSinGraph(gene_data, length, path, number):
        FS = 10
        fig = plt.figure()
        ax = fig.add_subplot()
        # plt.rcParams["font.size"] = FS
        plt.plot(gene_data.detach().numpy()[:, 0],
                 gene_data.detach().numpy()[:, 1],
                 label="Generated data")
        plt.title("Generation")
        # plt.ylim(top=10, bottom=-10)
        # plt.xlabel("time", fontsize=FS)
        # plt.ylabel("y", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join(path, "Generation" + str(number) + ".png"))
        plt.close()

    FS = 10
    plt.rcParams["font.size"] = FS

    ## max sequence length is 129; e.g. a sequence of length 60 is zero-padded over steps 61-129
    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths'][:args.N_songs]
    training_data_sequences = data['train']['sequences'][:args.N_songs, :8]
    # note: the JSB data above is immediately replaced by the synthetic trajectories below
    training_seq_lengths = torch.tensor([8] * args.N_songs)

    training_data = []
    for i in range(args.N_songs):
        training_data.append(
            musics.VanDelPol(torch.randn(2), args.length, args.T))
        # training_data.append(musics.Nonlinear(torch.randn(1), args.length, args.T))
    training_data_sequences = torch.stack(training_data)
    training_seq_lengths = torch.tensor([args.length] * args.N_songs)
    data_dim = training_data_sequences.size(-1)

    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    z_dim = args.z_dim  # previously tried: 100, 2
    rnn_dim = args.rnn_dim  # previously tried: 200, 30
    transition_dim = args.transition_dim  # previously tried: 200, 30
    emission_dim = args.emission_dim  # previously tried: 100, 30
    encoder = Encoder(input_dim=training_data_sequences.size(-1),
                      z_dim=z_dim,
                      rnn_dim=rnn_dim,
                      N_z0=args.N_songs)
    # encoder = Encoder(z_dim=100, rnn_dim=200)
    prior = Prior(z_dim=z_dim,
                  transition_dim=transition_dim,
                  N_z0=args.N_songs,
                  use_cuda=args.cuda)
    decoder = Emitter(input_dim=training_data_sequences.size(-1),
                      z_dim=z_dim,
                      emission_dim=emission_dim,
                      use_cuda=args.cuda)

    # Create optimizer algorithm
    # optimizer = optim.SGD(dmm.parameters(), lr=args.learning_rate)
    # optimizer = optim.Adam(dmm.parameters(), lr=args.learning_rate, betas=(0.96, 0.999), weight_decay=2.0)
    params = list(encoder.parameters()) + list(prior.parameters()) + list(
        decoder.parameters())
    optimizer = optim.SGD(params, lr=args.learning_rate)
    # optimizer = optim.Adam(params, lr=args.learning_rate)
    # Add learning rate scheduler
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)
    # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 1.)  # try a decay rate of 1.0 as well

    # make directory for Data
    now = datetime.datetime.now().strftime('%Y%m%d_%H_%M')
    os.makedirs(os.path.join("saveData", now), exist_ok=True)

    # save args as text
    with open(os.path.join("saveData", now, 'args.txt'), 'w') as f:  # open in write mode
        for name, value in vars(args).items():
            f.write(name + " = ")
            f.write(str(value) + "\n")

    #######################
    #### TRAINING LOOP ####
    #######################
    times = [time.time()]
    losses = []
    recon_errors = []
    pbar = tqdm.tqdm(range(args.num_epochs))
    for epoch in pbar:
        epoch_nll = 0
        epoch_recon_error = 0
        shuffled_indices = torch.randperm(N_train_data)
        # print("Proceeding: %.2f " % (epoch*100/5000) + "%")

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):

            # compute which sequences in the training set we should grab
            mini_batch_start = (which_mini_batch * args.mini_batch_size)
            mini_batch_end = min(
                (which_mini_batch + 1) * args.mini_batch_size, N_train_data)
            mini_batch_indices = shuffled_indices[
                mini_batch_start:mini_batch_end]

            # grab a fully prepped mini-batch using the helper function in the data loader
            mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
                = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                        training_seq_lengths, cuda=args.cuda)

            # reset gradients
            optimizer.zero_grad()
            loss = 0

            N_repeat = 1
            for i in range(N_repeat):
                seiki = 1 / N_repeat  # normalization weight ("seiki" = normalization)
                # generate mini batch from training mini batch
                pos_z, pos_z_loc, pos_z_scale = encoder(
                    mini_batch, mini_batch_reversed, mini_batch_mask,
                    mini_batch_seq_lengths)
                pri_z, pri_z_loc, pri_z_scale = prior(
                    length=mini_batch.size(1), N_generate=mini_batch.size(0))

                # Regularizer: Wasserstein distance between the posterior and prior latents
                regularizer = seiki * D_Wass(pos_z_loc, pri_z_loc, pos_z_scale,
                                             pri_z_scale)
                # regularizer = seiki * D_KL(pos_z_loc, pri_z_loc, pos_z_scale, pri_z_scale)
                # regularizer = seiki * D_JS_Monte(pos_z_loc, pri_z_loc, pos_z_scale, pri_z_scale, pos_z, pri_z)
                # Reconstruction error: mean per-step L2 norm, normalized over batch, length, and feature dims
                reconstruction_error = seiki * torch.norm(
                    mini_batch - decoder(pos_z),
                    dim=2).sum() / mini_batch.size(0) / mini_batch.size(
                        1) / mini_batch.size(2)

                loss += args.lam * regularizer + reconstruction_error
            reconed_x = decoder(pos_z)

            loss.backward()
            optimizer.step()
            scheduler.step()
            # if args.rnn_clip != None:
            #     for p in dmm.rnn.parameters():
            #         p.data.clamp_(-args.rnn_clip, args.rnn_clip)

            epoch_nll += loss.item()  # detach when accumulating so old graphs can be freed
            epoch_recon_error += reconstruction_error.item()

        # report training diagnostics
        times.append(time.time())
        losses.append(epoch_nll)
        recon_errors.append(epoch_recon_error)
        # epoch_time = times[-1] - times[-2]
        pbar.set_description("LOSS = %f " % epoch_nll)

        if epoch % args.checkpoint_freq == 0:
            path = os.path.join("saveData", now, "Epoch%d" % epoch)
            os.makedirs(path, exist_ok=True)
            for i in range(len(mini_batch)):
                saveReconSinGraph(mini_batch[i], reconed_x[i], args.length,
                                  path, i)
                saveGeneSinGraph(decoder(pri_z)[i], args.length, path, i)

        if epoch == args.num_epochs - 1:
            path = os.path.join("saveData", now, "Epoch%d" % args.num_epochs)
            os.makedirs(path, exist_ok=True)
            for i in range(len(mini_batch)):
                saveReconSinGraph(mini_batch[i], reconed_x[i], args.length,
                                  path, i)
                saveGeneSinGraph(decoder(pri_z)[i], args.length, path, i)

            saveGraph(losses, recon_errors, now)
            saveDic = {
                "optimizer": optimizer,
                "mini_batch": mini_batch,
                "Encoder_dic": encoder.state_dict(),  # call state_dict() to save weights, not the bound method
                "Prior_dic": prior.state_dict(),
                "Emitter_dic": decoder.state_dict(),
                "epoch_times": times,
                "losses": losses
            }
            # torch.save(saveDic,os.path.join("saveData", now, "dic_Epoch%d"%(epoch+1)))
            torch.save(saveDic, os.path.join("saveData", now, "DMM_dic"))
Example #3
def main(args):

    ## saves the loss curves at the end of training
    def saveGraph(loss_list, sub_error_list, now):
        FS = 10
        fig = plt.figure()
        # plt.rcParams["font.size"] = FS
        plt.plot(loss_list, label="LOSS")
        plt.plot(sub_error_list, label="Reconstruction Error")
        plt.ylim(bottom=0)
        plt.title("Loss")
        plt.xlabel("epoch", fontsize=FS)
        plt.ylabel("loss", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join("saveData", now, "LOSS.png"))
        plt.close()

    def saveReconSinGraph(train_data, recon_data, length, path, number):
        FS = 10
        fig = plt.figure()
        # plt.rcParams["font.size"] = FS
        x = np.linspace(0, 2 * np.pi, length)
        plt.plot(x, train_data, label="Training data")
        plt.plot(x, recon_data.detach().numpy(), label="Reconstructed data")
        # plt.ylim(bottom=0)
        plt.title("Sin Curves")
        # plt.ylim(top=10, bottom=-10)
        plt.xlabel("time", fontsize=FS)
        plt.ylabel("y", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join(path,
                                 "Reconstruction" + str(number) + ".png"))
        plt.close()

    def saveGeneSinGraph(gene_data, length, path, number):
        FS = 10
        fig = plt.figure()
        # plt.rcParams["font.size"] = FS
        x = np.linspace(0, 2 * np.pi, length)
        plt.plot(x, gene_data.detach().numpy(), label="Generated data")
        # plt.ylim(bottom=0)
        plt.title("Sin Curves")
        # plt.ylim(top=10, bottom=-10)
        plt.xlabel("time", fontsize=FS)
        plt.ylabel("y", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join(path, "Generation" + str(number) + ".png"))
        plt.close()

    ## saves generated and training tone sequences as MIDI files
    def save_as_midi(song, path="", name="default.mid", BPM=120, velocity=100):
        pm = pretty_midi.PrettyMIDI(resolution=960,
                                    initial_tempo=BPM)  # create a pretty_midi object
        instrument = pretty_midi.Instrument(0)  # an Instrument is roughly a track
        for i, tones in enumerate(song):
            which_tone = torch.nonzero((tones == 1),
                                       as_tuple=False).reshape(-1)
            if len(which_tone) == 0:
                note = pretty_midi.Note(
                    velocity=0, pitch=0, start=i,
                    end=i + 1)  # a Note corresponds to a NoteOn/NoteOff event pair
                instrument.notes.append(note)
            else:
                for which in which_tone:
                    note = pretty_midi.Note(
                        velocity=velocity,
                        pitch=int(which),
                        start=i,
                        end=i + 1)  # a Note corresponds to a NoteOn/NoteOff event pair
                    instrument.notes.append(note)
        pm.instruments.append(instrument)
        pm.write(os.path.join(path, name))  # write out the MIDI file

    # sample binary piano rolls from the per-time-step Bernoulli probabilities
    def generate_Xs(batch_data):
        songs_list = []
        for i, song in enumerate(batch_data):
            tones_container = []
            for t in range(batch_data.size(1)):  # avoid shadowing the `time` module
                p = dist.Bernoulli(probs=song[t])
                tone = p.sample()
                tones_container.append(tone)
            tones_container = torch.stack(tones_container)
            songs_list.append(tones_container)
        return songs_list
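    # e.g. generate_Xs(decoder(pri_z)) yields one binary piano-roll tensor per generated song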

    def saveSongs(songs_list, mini_batch, N_songs, path):
        # print(len(songs_list[0][0]))
        assert len(songs_list) == len(mini_batch), \
            "songs_list and mini_batch must have the same length"
        if N_songs <= len(songs_list):
            song_No = random.sample(range(len(songs_list)), k=N_songs)
        else:
            song_No = random.sample(range(len(songs_list)), k=len(songs_list))
        for i, Number in enumerate(song_No):
            save_as_midi(song=songs_list[Number],
                         path=path,
                         name="No%d_Gene.midi" % i)
            save_as_midi(song=mini_batch[Number],
                         path=path,
                         name="No%d_Tran.midi" % i)

    FS = 10
    plt.rcParams["font.size"] = FS

    ## max sequence length is 129; e.g. a sequence of length 60 is zero-padded over steps 61-129
    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths'][:args.N_songs]
    training_data_sequences = data['train']['sequences'][:args.N_songs, :8]
    training_seq_lengths = torch.tensor([8] * args.N_songs)

    if args.doremifasorasido:
        ## do-do-do, re-re-re, mi-mi-mi, do-re-mi
        training_seq_lengths, training_data_sequences = musics.doremifasorasido(
            args.N_songs)
    if args.dododorerere:
        ## do-do-do, re-re-re
        training_seq_lengths, training_data_sequences = musics.dododorerere(
            args.N_songs)
    if args.dodododododo:
        ## do-do-do only
        training_seq_lengths, training_data_sequences = musics.dodododododo(
            args.N_songs)

    if args.sin:
        training_data_sequences = musics.createSin_allChanged(
            args.N_songs, args.length)
        training_seq_lengths = torch.tensor([args.length] * args.N_songs)

    # note: this unconditionally overrides whichever dataset was selected above
    training_data_sequences = musics.createNewTrainingData(
        args.N_songs, args.length)
    # training_data = []
    # for i in range(args.N_songs):
    #     training_data.append(musics.Nonlinear(torch.randn(1), args.length, T = args.T))
    # training_data_sequences = torch.stack(training_data)
    training_seq_lengths = torch.tensor([args.length] * args.N_songs)
    data_dim = training_data_sequences.size(-1)

    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    z_dim = args.z_dim  # previously tried: 100, 2
    rnn_dim = args.rnn_dim  # previously tried: 200, 30
    transition_dim = args.transition_dim  # previously tried: 200, 30
    emission_dim = args.emission_dim  # previously tried: 100, 30
    encoder = Encoder(input_dim=training_data_sequences.size(-1),
                      z_dim=z_dim,
                      rnn_dim=rnn_dim,
                      N_z0=args.N_songs)
    # encoder = Encoder(z_dim=100, rnn_dim=200)
    prior = Prior(z_dim=z_dim,
                  transition_dim=transition_dim,
                  N_z0=args.N_songs,
                  use_cuda=args.cuda,
                  R=args.R)
    decoder = Emitter(input_dim=training_data_sequences.size(-1),
                      z_dim=z_dim,
                      emission_dim=emission_dim,
                      use_cuda=args.cuda)

    # Create optimizer algorithm
    # optimizer = optim.SGD(dmm.parameters(), lr=args.learning_rate)
    # optimizer = optim.Adam(dmm.parameters(), lr=args.learning_rate, betas=(0.96, 0.999), weight_decay=2.0)
    params = list(encoder.parameters()) + list(prior.parameters()) + list(
        decoder.parameters())
    optimizer = optim.Adam(params, lr=args.learning_rate)
    # Add learning rate scheduler
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)
    # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 1.)  # try a decay rate of 1.0 as well

    # make directory for Data
    now = datetime.datetime.now().strftime('%Y%m%d_%H_%M')
    os.makedirs(os.path.join("saveData", now), exist_ok=True)

    # save args as text
    with open(os.path.join("saveData", now, 'args.txt'), 'w') as f:  # open in write mode
        for name, value in vars(args).items():
            f.write(name + " = ")
            f.write(str(value) + "\n")

    #######################
    #### TRAINING LOOP ####
    #######################
    times = [time.time()]
    losses = []
    recon_errors = []
    pbar = tqdm.tqdm(range(args.num_epochs))
    for epoch in pbar:
        epoch_nll = 0
        epoch_recon_error = 0
        shuffled_indices = torch.randperm(N_train_data)
        # print("Proceeding: %.2f " % (epoch*100/5000) + "%")

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):

            # compute which sequences in the training set we should grab
            mini_batch_start = (which_mini_batch * args.mini_batch_size)
            mini_batch_end = min(
                (which_mini_batch + 1) * args.mini_batch_size, N_train_data)
            mini_batch_indices = shuffled_indices[
                mini_batch_start:mini_batch_end]

            # grab a fully prepped mini-batch using the helper function in the data loader
            mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
                = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                        training_seq_lengths, cuda=args.cuda)

            # reset gradients
            optimizer.zero_grad()
            loss = 0

            N_repeat = 1
            for i in range(N_repeat):
                seiki = 1 / N_repeat  # normalization weight ("seiki" = normalization)
                # generate mini batch from training mini batch
                pos_z, pos_z_loc, pos_z_scale = encoder(
                    mini_batch, mini_batch_reversed, mini_batch_mask,
                    mini_batch_seq_lengths)
                pri_z, pri_z_loc, pri_z_scale = prior(
                    length=mini_batch.size(1), N_generate=mini_batch.size(0))

                # Regularizer: Wasserstein distance between the posterior and prior latents
                # (commented-out NaN checks on the posterior/prior locs and scales and
                # on the regularizer itself were used here for debugging)
                regularizer = seiki * D_Wass(pos_z_loc, pri_z_loc, pos_z_scale,
                                             pri_z_scale)
                # regularizer = seiki * D_KL(pos_z_loc, pri_z_loc, pos_z_scale, pri_z_scale)
                # regularizer = seiki * D_JS_Monte(pos_z_loc, pri_z_loc, pos_z_scale, pri_z_scale, pos_z, pri_z)
                # Reconstruction error: mean per-step L2 norm, normalized over batch, length, and feature dims
                reconstruction_error = seiki * torch.norm(
                    mini_batch - decoder(pos_z),
                    dim=2).sum() / mini_batch.size(0) / mini_batch.size(
                        1) / mini_batch.size(2)

                loss += args.lam * regularizer + reconstruction_error
            reconed_x = decoder(pos_z)
            # (earlier variants, kept as commented-out history: an explicit
            # reconstruction-error term, a covariance penalty on the summed scales,
            # a NaN detector that dumped a "fail_DMM_dic" checkpoint when the loss
            # went NaN, and an unconditional per-step checkpoint save)

            # do an actual gradient step
            loss.backward()
            optimizer.step()
            scheduler.step()
            # if args.rnn_clip != None:
            #     for p in dmm.rnn.parameters():
            #         p.data.clamp_(-args.rnn_clip, args.rnn_clip)

            epoch_nll += loss.item()  # detach when accumulating so old graphs can be freed
            epoch_recon_error += reconstruction_error.item()

        # report training diagnostics
        times.append(time.time())
        losses.append(epoch_nll)
        recon_errors.append(epoch_recon_error)
        # epoch_time = times[-1] - times[-2]
        pbar.set_description("LOSS = %f " % epoch_nll)

        if epoch % args.checkpoint_freq == 0:
            path = os.path.join("saveData", now, "Epoch%d" % epoch)
            os.makedirs(path, exist_ok=True)
            if data_dim != 1:
                reco_list = generate_Xs(reconed_x)
                gene_list = generate_Xs(decoder(pri_z))
                for i in range(args.N_generate):
                    save_as_midi(song=reco_list[i],
                                 path=path,
                                 name="No%d_Reco.midi" % i)
                    save_as_midi(song=mini_batch[i],
                                 path=path,
                                 name="No%d_Real.midi" % i)
                    save_as_midi(song=gene_list[i],
                                 path=path,
                                 name="No%d_Gene.midi" % i)
            else:
                for i in range(len(mini_batch)):
                    saveReconSinGraph(mini_batch[i], reconed_x[i], args.length,
                                      path, i)
                    saveGeneSinGraph(decoder(pri_z)[i], args.length, path, i)
                saveDic = {
                    "optimizer": optimizer,
                    "mini_batch": mini_batch,
                    "Encoder_dic": encoder.state_dict(),  # call state_dict(), not the bound method
                    "Prior_dic": prior.state_dict(),
                    "Emitter_dic": decoder.state_dict(),
                    "epoch_times": times,
                    "losses": losses
                }
                # torch.save(saveDic,os.path.join("saveData", now, "dic_Epoch%d"%(epoch+1)))
                torch.save(saveDic, os.path.join(path, "DMM_dic"))

        if epoch == args.num_epochs - 1:
            path = os.path.join("saveData", now, "Epoch%d" % args.num_epochs)
            os.makedirs(path, exist_ok=True)
            if data_dim != 1:
                reco_list = generate_Xs(reconed_x)
                gene_list = generate_Xs(decoder(pri_z))
                for i in range(args.N_generate):
                    save_as_midi(song=reco_list[i],
                                 path=path,
                                 name="No%d_Reco.midi" % i)
                    save_as_midi(song=mini_batch[i],
                                 path=path,
                                 name="No%d_Real.midi" % i)
                    save_as_midi(song=gene_list[i],
                                 path=path,
                                 name="No%d_Gene.midi" % i)
            else:
                for i in range(len(mini_batch)):
                    saveReconSinGraph(mini_batch[i], reconed_x[i], args.length,
                                      path, i)
                    saveGeneSinGraph(decoder(pri_z)[i], args.length, path, i)
                saveDic = {
                    "optimizer": optimizer,
                    "mini_batch": mini_batch,
                    "Encoder_dic": encoder.state_dict(),  # call state_dict(), not the bound method
                    "Prior_dic": prior.state_dict(),
                    "Emitter_dic": decoder.state_dict(),
                    "epoch_times": times,
                    "losses": losses
                }
                # torch.save(saveDic,os.path.join("saveData", now, "dic_Epoch%d"%(epoch+1)))
                torch.save(saveDic, os.path.join(path, "DMM_dic"))

            saveGraph(losses, recon_errors, now)
            saveDic = {
                "optimizer": optimizer,
                "mini_batch": mini_batch,
                "Encoder_dic": encoder.state_dict(),  # call state_dict(), not the bound method
                "Prior_dic": prior.state_dict(),
                "Emitter_dic": decoder.state_dict(),
                "epoch_times": times,
                "losses": losses
            }
            # torch.save(saveDic,os.path.join("saveData", now, "dic_Epoch%d"%(epoch+1)))
            torch.save(saveDic, os.path.join("saveData", now, "DMM_dic"))
Example #4
def main(args):
    if args.cuda:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")

    logging.info("Loading data")
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info("-" * 40)
    model = models[args.model]
    logging.info("Training {} on {} sequences".format(
        model.__name__, len(data["train"]["sequences"])))
    sequences = data["train"]["sequences"]
    lengths = data["train"]["sequence_lengths"]

    # find all the notes that are present at least once in the training set
    present_notes = (sequences == 1).sum(0).sum(0) > 0
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]
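    # e.g. for JSB chorales this keeps 51 of the 88 pitches, so sequences go from (N, T, 88) to (N, T, 51)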

    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        poutine.block(model,
                      expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = poutine.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = poutine.trace(
            poutine.replay(poutine.enum(model, first_available_dim),
                           guide_trace)).get_trace(sequences,
                                                   lengths,
                                                   args=args,
                                                   batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    optim = Adam({"lr": args.learning_rate})
    if args.tmc:
        if args.jit:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2)
        tmc_model = poutine.infer_config(
            model,
            lambda msg: {
                "num_samples": args.tmc_num_samples,
                "expand": False
            } if msg["infer"].get("enumerate", None) == "parallel" else {},
        )  # noqa: E501
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
        elbo = Elbo(
            max_plate_nesting=1 if model is model_0 else 2,
            strict_enumeration_warning=(model is not model_7),
            jit_options={"time_compilation": args.time_compilation},
        )
        svi = SVI(model, guide, optim, elbo)

    # We'll train on small minibatches.
    logging.info("Step\tLoss")
    for step in range(args.num_steps):
        loss = svi.step(sequences,
                        lengths,
                        args=args,
                        batch_size=args.batch_size)
        logging.info("{: >5d}\t{}".format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug("time to compile: {} s.".format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model,
                           guide,
                           sequences,
                           lengths,
                           args,
                           include_prior=False)
    logging.info("training loss = {}".format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info("-" * 40)
    logging.info("Evaluating on {} test sequences".format(
        len(data["test"]["sequences"])))
    sequences = data["test"]["sequences"][..., present_notes]
    lengths = data["test"]["sequence_lengths"]
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit easier and for
    # numerical stability) this test loss may not be directly comparable to numbers
    # reported on this dataset elsewhere.
    test_loss = elbo.loss(model,
                          guide,
                          sequences,
                          lengths,
                          args=args,
                          include_prior=False)
    logging.info("test loss = {}".format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info("{} capacity = {} parameters".format(model.__name__,
                                                      capacity))
Example #5
def main(args):

    ## do-re-mi: each song plays an ascending major scale from a descending starting pitch
    def easyTones():
        base = 70  # renamed from `max` to avoid shadowing the builtin
        interval = 10
        N = 10
        training_seq_lengths = torch.tensor([8] * N)
        training_data_sequences = torch.zeros(N, 8, 88)
        # note offsets 0,2,4,5,7,9,11,12 span one octave of a major scale;
        # beware that for i >= 8 the pitch index goes negative and wraps around
        for i in range(N):
            training_data_sequences[i][0][int(base - i * interval)] = 1
            training_data_sequences[i][1][int(base - i * interval) + 2] = 1
            training_data_sequences[i][2][int(base - i * interval) + 4] = 1
            training_data_sequences[i][3][int(base - i * interval) + 5] = 1
            training_data_sequences[i][4][int(base - i * interval) + 7] = 1
            training_data_sequences[i][5][int(base - i * interval) + 9] = 1
            training_data_sequences[i][6][int(base - i * interval) + 11] = 1
            training_data_sequences[i][7][int(base - i * interval) + 12] = 1
        return training_seq_lengths, training_data_sequences

    ## one sustained note per song, at a different pitch for each song ("do-do-do", "re-re-re", ...)
    def superEasyTones():
        training_seq_lengths = torch.tensor([8] * 10)
        training_data_sequences = torch.zeros(10, 8, 88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(30 + i * 5)] = 1
        return training_seq_lengths, training_data_sequences

    ## the same single note (pitch 70) repeated in every song
    def easiestTones():
        training_seq_lengths = torch.tensor([8] * 10)
        training_data_sequences = torch.zeros(10, 8, 88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(70)] = 1
        return training_seq_lengths, training_data_sequences

    # rep is short for "repeat": it tiles each sample n_eval_samples times
    # so that validation/test evaluation during training can be vectorized
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(
            1, 0).reshape(rep_shape)

    ## saves the loss curve at the end of training
    def saveGraph(loss_list, now):
        FS = 10
        fig = plt.figure()
        plt.rcParams["font.size"] = FS
        plt.plot(loss_list)
        plt.ylim(bottom=0)
        plt.title("Wasserstein Loss")
        plt.xlabel("epoch", fontsize=FS)
        plt.ylabel("loss", fontsize=FS)
        fig.savefig(os.path.join("saveData", now, "LOSS.png"))

    ## saves generated and training tone sequences as MIDI files
    def save_as_midi(song, path="", name="default.mid", BPM=120, velocity=100):
        pm = pretty_midi.PrettyMIDI(resolution=960,
                                    initial_tempo=BPM)  # create a pretty_midi object
        instrument = pretty_midi.Instrument(0)  # an Instrument is roughly a track
        for i, tones in enumerate(song):
            which_tone = torch.nonzero((tones == 1),
                                       as_tuple=False).reshape(-1)
            if len(which_tone) == 0:
                note = pretty_midi.Note(
                    velocity=0, pitch=0, start=i,
                    end=i + 1)  # a Note corresponds to a NoteOn/NoteOff event pair
                instrument.notes.append(note)
            else:
                for which in which_tone:
                    note = pretty_midi.Note(
                        velocity=velocity,
                        pitch=int(which),
                        start=i,
                        end=i + 1)  # a Note corresponds to a NoteOn/NoteOff event pair
                    instrument.notes.append(note)
        pm.instruments.append(instrument)
        pm.write(os.path.join(path, name))  # write out the MIDI file

    # sample binary piano rolls from the generated per-time-step Bernoulli probabilities
    def generate_Xs(dmm, mini_batch, mini_batch_reversed, mini_batch_mask,
                    mini_batch_seq_lengths):
        songs_list = []
        generated_mini_batch = dmm(mini_batch, mini_batch_reversed,
                                   mini_batch_mask, mini_batch_seq_lengths)
        for i, song in enumerate(generated_mini_batch):
            tones_container = []
            for t in range(mini_batch_seq_lengths[i]):  # avoid shadowing the `time` module
                p = dist.Bernoulli(probs=song[t])
                tone = p.sample()
                tones_container.append(tone)
            tones_container = torch.stack(tones_container)
            songs_list.append(tones_container)
        return songs_list

    def saveSongs(songs_list, mini_batch, N_songs, path):
        # print(len(songs_list[0][0]))
        assert len(songs_list) == len(mini_batch), \
            "songs_list and mini_batch must have the same length"
        if N_songs <= len(songs_list):
            song_No = random.sample(range(len(songs_list)), k=N_songs)
        else:
            song_No = random.sample(range(len(songs_list)), k=len(songs_list))
        for i, Number in enumerate(song_No):
            save_as_midi(song=songs_list[Number],
                         path=path,
                         name="No%d_Gene.midi" % i)
            save_as_midi(song=mini_batch[Number],
                         path=path,
                         name="No%d_Tran.midi" % i)

    ## max sequence length is 129; e.g. a sequence of length 60 is zero-padded over steps 61-129
    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths'][:3]
    training_data_sequences = data['train']['sequences'][:3, :8]
    training_seq_lengths = torch.tensor([8] * 10)
    # training_seq_lengths = torch.tensor([8])

    if args.eas:
        ## do-do-do, re-re-re, mi-mi-mi, do-re-mi
        training_seq_lengths, training_data_sequences = easyTones()

    if args.sup:
        ## do-do-do, re-re-re
        training_seq_lengths, training_data_sequences = superEasyTones()

    if args.est:
        ## do-do-do only
        training_seq_lengths, training_data_sequences = easiestTones()

    test_seq_lengths = data['test']['sequence_lengths'][:3]
    test_seq_lengths = torch.tensor([8, 8, 8])
    test_data_sequences = data['test']['sequences'][:3, :8]
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    N_train_data = len(training_seq_lengths)
    # N_train_data = len(test_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    # val_seq_lengths = rep(val_seq_lengths)
    # # test_seq_lengths = rep(test_seq_lengths)
    # val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
    #     torch.arange(n_eval_samples * val_data_sequences.shape[0]), rep(val_data_sequences),
    #     val_seq_lengths, cuda=args.cuda)
    # test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
    #     torch.arange(n_eval_samples * test_data_sequences.shape[0]), rep(test_data_sequences),
    #     test_seq_lengths, cuda=args.cuda)

    dmm = DMM(use_cuda=args.cuda, rnn_check=args.rck, rpt=args.rpt)
    WN = WGAN_network(use_cuda=args.cuda)
    W = WassersteinLoss(WN,
                        N_loops=args.N_loops,
                        lr=args.w_learning_rate,
                        use_cuda=args.cuda)

    # Create optimizer algorithm
    # optimizer = optim.SGD(dmm.parameters(), lr=args.learning_rate)
    # optimizer = optim.Adam(dmm.parameters(), lr=args.learning_rate, betas=(0.96, 0.999), weight_decay=2.0)
    optimizer = optim.Adam(dmm.parameters(), lr=args.learning_rate)
    # Add learning rate scheduler
    # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 1.)  # try a decay rate of 1.0 as well

    # make directory for Data
    now = datetime.datetime.now().strftime('%Y%m%d_%H_%M')
    os.makedirs(os.path.join("saveData", now), exist_ok=True)

    # save args as text
    with open(os.path.join("saveData", now, 'args.txt'), 'w') as f:  # open in write mode
        for name, value in vars(args).items():
            f.write(name + " = ")
            f.write(str(value) + "\n")

    #######################
    #### TRAINING LOOP ####
    #######################
    times = [time.time()]
    losses = []
    pbar = tqdm.tqdm(range(args.num_epochs))
    for epoch in pbar:
        epoch_nll = 0
        # shuffled_indices = torch.randperm(N_train_data)
        shuffled_indices = torch.arange(N_train_data)
        # print("Proceeding: %.2f " % (epoch*100/5000) + "%")

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):

            # compute which sequences in the training set we should grab
            mini_batch_start = (which_mini_batch * args.mini_batch_size)
            mini_batch_end = min(
                (which_mini_batch + 1) * args.mini_batch_size, N_train_data)
            mini_batch_indices = shuffled_indices[
                mini_batch_start:mini_batch_end]

            # grab a fully prepped mini-batch using the helper function in the data loader
            mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
                = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                        training_seq_lengths, cuda=args.cuda)
            # mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            #     = poly.get_mini_batch(mini_batch_indices, test_data_sequences,
            #                             test_seq_lengths, cuda=args.cuda)

            # reset gradients
            optimizer.zero_grad()

            # generate a mini-batch from the training mini-batch
            # (commented-out NaN_detect calls and isnan checks were used here for debugging)
            generated_mini_batch = dmm(mini_batch, mini_batch_reversed,
                                       mini_batch_mask, mini_batch_seq_lengths)

            # calculate the loss; note that exactly one of --wass / --wgan must be
            # set, otherwise `loss` is never assigned
            if args.wass:
                loss = EMD(mini_batch, generated_mini_batch)
            if args.wgan:
                loss = W.calc(mini_batch, generated_mini_batch)

            # do an actual gradient step
            loss.backward()
            optimizer.step()
            scheduler.step()
            if args.rnn_clip is not None:
                for p in dmm.rnn.parameters():
                    p.data.clamp_(-args.rnn_clip, args.rnn_clip)

            epoch_nll += loss.item()  # detach when accumulating so old graphs can be freed

        # report training diagnostics
        times.append(time.time())
        losses.append(epoch_nll)
        epoch_time = times[-1] - times[-2]
        # logging.info("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
        #                 (epoch, epoch_nll / N_train_time_slices, epoch_time))
        # if epoch % 10 == 0:
        # print("epoch %d time : %d sec" % (epoch, int(epoch_time)))
        pbar.set_description("LOSS = %f " % epoch_nll)
        # tqdm.write("\r"+"        loss : %f " % epoch_nll, end="")

        if epoch % args.checkpoint_freq == 0:
            path = os.path.join("saveData", now, "Epoch%d" % epoch)
            os.makedirs(path, exist_ok=True)
            songs_list = generate_Xs(dmm, mini_batch, mini_batch_reversed,
                                     mini_batch_mask, mini_batch_seq_lengths)
            saveSongs(songs_list, mini_batch, N_songs=2, path=path)

        if epoch == args.num_epochs - 1:
            path = os.path.join("saveData", now, "Epoch%d" % args.num_epochs)
            os.makedirs(path, exist_ok=True)
            songs_list = generate_Xs(dmm, mini_batch, mini_batch_reversed,
                                     mini_batch_mask, mini_batch_seq_lengths)
            saveSongs(songs_list, mini_batch, N_songs=2, path=path)

            saveGraph(losses, now)
            saveDic = {
                "DMM_dic": dmm.state_dict(),  # call state_dict(), not the bound method
                "WGAN_Network_dic": WN.state_dict(),
                "epoch_times": times,
                "losses": losses
            }
            # torch.save(saveDic,os.path.join("saveData", now, "dic_Epoch%d"%(epoch+1)))
            torch.save(saveDic, os.path.join("saveData", now, "DMM_dic"))