Example #1
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size,
                                 N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                        mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss
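
A minimal standalone sketch of the linear KL annealing schedule computed above (the values chosen for min_af, annealing_epochs, and N_mini_batches are hypothetical):

# the factor ramps linearly from min_af up to 1.0 over the first
# annealing_epochs * N_mini_batches gradient steps, then stays at 1.0
min_af, annealing_epochs, N_mini_batches = 0.2, 2, 5

def annealing_factor(epoch, which_mini_batch):
    if epoch < annealing_epochs:
        return min_af + (1.0 - min_af) * \
            float(which_mini_batch + epoch * N_mini_batches + 1) / \
            float(annealing_epochs * N_mini_batches)
    return 1.0

print([round(annealing_factor(e, b), 2)
       for e in range(3) for b in range(N_mini_batches)])
# [0.28, 0.36, 0.44, 0.52, 0.6, 0.68, 0.76, 0.84, 0.92, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]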
Example #2
def main(args):
    ## do-re-mi (one octave of an ascending major scale)
    def easyTones():
        training_seq_lengths = torch.tensor([8] * 1)
        training_data_sequences = torch.zeros(1, 8, 88)
        for i in range(1):
            training_data_sequences[i][0][int(70 - i * 10)] = 1
            training_data_sequences[i][1][int(70 - i * 10) + 2] = 1
            training_data_sequences[i][2][int(70 - i * 10) + 4] = 1
            training_data_sequences[i][3][int(70 - i * 10) + 5] = 1
            training_data_sequences[i][4][int(70 - i * 10) + 7] = 1
            training_data_sequences[i][5][int(70 - i * 10) + 9] = 1
            training_data_sequences[i][6][int(70 - i * 10) + 11] = 1
            training_data_sequences[i][7][int(70 - i * 10) + 12] = 1
        return training_seq_lengths, training_data_sequences

    def superEasyTones():
        training_seq_lengths = torch.tensor([8]*10)
        training_data_sequences = torch.zeros(10,8,88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(30+i*5)] = 1
        return training_seq_lengths, training_data_sequences

    ## do-do-do, do-do-do, do-do-do (the same note repeated)
    def easiestTones():
        training_seq_lengths = torch.tensor([8]*10)
        training_data_sequences = torch.zeros(10,8,88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(70)] = 1
        return training_seq_lengths, training_data_sequences

    # setup logging
    logging.basicConfig(level=logging.DEBUG, format='%(message)s', filename=args.log, filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    logging.info(args)

    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
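    # NOTE: the JSB splits loaded from disk are immediately overridden below with
    # the toy easiestTones() data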
    training_seq_lengths, training_data_sequences = easiestTones()
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    test_seq_lengths, test_data_sequences = easiestTones()
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    val_seq_lengths, val_data_sequences = easiestTones()
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
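    # number of mini-batches, i.e. ceil(N_train_data / args.mini_batch_size)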
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    logging.info("N_train_data: %d     avg. training seq. length: %.2f    N_mini_batches: %d" %
                 (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # package repeated copies of val/test data for faster evaluation
    # (i.e. set us up for vectorization)
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(1, 0).reshape(rep_shape)
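    # with n_eval_samples = 1 (as set above) rep() is an identity map; for a 1-D
    # tensor and n_eval_samples = n it interleaves copies along dim 0,
    # e.g. (l0, l1) -> (l0, l0, l1, l1)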

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * val_data_sequences.shape[0]), rep(val_data_sequences),
        val_seq_lengths, cuda=args.cuda)
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * test_data_sequences.shape[0]), rep(test_data_sequences),
        test_seq_lengths, cuda=args.cuda)

    # instantiate the dmm
    dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate, num_iafs=args.num_iafs,
              iaf_dim=args.iaf_dim, use_cuda=args.cuda)

    # setup optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta1, args.beta2),
                   "clip_norm": args.clip_norm, "lrd": args.lr_decay,
                   "weight_decay": args.weight_decay}
    adam = ClippedAdam(adam_params)

    # setup inference algorithm
    if args.tmc:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC")
        tmc_loss = TraceTMC_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel", num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=tmc_loss)
    elif args.tmcelbo:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC ELBO")
        elbo = TraceEnum_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel", num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=elbo)
    else:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)

    # now we're going to define some functions we need to form the main training loop

    # saves the model and optimizer states to disk
    def save_checkpoint():
        logging.info("saving model to %s..." % args.save_model)
        torch.save(dmm.state_dict(), args.save_model)
        # logging.info("saving optimizer states to %s..." % args.save_opt)
        # adam.save(args.save_opt)
        logging.info("done saving model and optimizer checkpoints to disk.")

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        logging.info("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        logging.info("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        logging.info("done loading model and optimizer states.")

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size, N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                        mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss

    # helper function for doing evaluation
    def do_evaluation():
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        dmm.rnn.eval()

        # compute the validation and test loss n_samples many times
        val_nll = svi.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask,
                                    val_seq_lengths) / float(torch.sum(val_seq_lengths))
        test_nll = svi.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask,
                                     test_seq_lengths) / float(torch.sum(test_seq_lengths))

        # put the RNN back into training mode (i.e. turn on drop-out if applicable)
        dmm.rnn.train()
        return val_nll, test_nll

    # if checkpoint files provided, load model and optimizer states from disk before we start training
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #################
    # TRAINING LOOP #
    #################
    times = [time.time()]
    for epoch in range(args.num_epochs):
        # if specified, save model and optimizer states to disk every checkpoint_freq epochs
        if args.checkpoint_freq > 0 and epoch > 0 and epoch % args.checkpoint_freq == 0:
            save_checkpoint()

        # accumulator for our estimate of the negative log likelihood (or rather -elbo) for this epoch
        epoch_nll = 0.0
        # prepare mini-batch subsampling indices for this epoch
        shuffled_indices = torch.randperm(N_train_data)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, which_mini_batch, shuffled_indices)

        # report training diagnostics
        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        logging.info("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
                     (epoch, epoch_nll / N_train_time_slices, epoch_time))

        # do evaluation on test and validation data and report results
        if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0:
            val_nll, test_nll = do_evaluation()
            logging.info("[val/test epoch %04d]  %.4f  %.4f" % (epoch, val_nll, test_nll))
Example #3
def main(args):

    ## saves the loss curves at the end of training
    def saveGraph(loss_list, sub_error_list, now):
        FS = 10
        fig = plt.figure()
        # plt.rcParams["font.size"] = FS
        plt.plot(loss_list, label="LOSS")
        plt.plot(sub_error_list, label="Reconstruction Error")
        plt.ylim(bottom=0)
        plt.title("Loss")
        plt.xlabel("epoch", fontsize=FS)
        plt.ylabel("loss", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join("saveData", now, "LOSS.png"))
        plt.close()

    def saveReconSinGraph(train_data, recon_data, length, path, number):
        FS = 10
        fig = plt.figure()
        ax = fig.add_subplot()
        # plt.plot(train_data[:,0], train_data[:,1], label="Training data")
        plt.plot(train_data[:, 0], train_data[:, 1], label="Training data")
        # plt.plot(recon_data.detach().numpy(), label="Reconstructed data")
        plt.plot(recon_data.detach().numpy()[:, 0],
                 recon_data.detach().numpy()[:, 1],
                 label="Reconstructed data")
        # plt.plot(recon_data.detach().numpy(), label="Reconstructed data")
        plt.title("Reconstruction")
        # plt.ylim(top=10, bottom=-10)
        # plt.xlabel("time", fontsize=FS)
        # plt.ylabel("y", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join(path,
                                 "Reconstruction" + str(number) + ".png"))
        plt.close()

    def saveGeneSinGraph(gene_data, length, path, number):
        FS = 10
        fig = plt.figure()
        ax = fig.add_subplot()
        # plt.rcParams["font.size"] = FS
        plt.plot(gene_data.detach().numpy()[:, 0],
                 gene_data.detach().numpy()[:, 1],
                 label="Generated data")
        # plt.plot(gene_data.detach().numpy(), label="Generated data")
        # plt.plot(gene_data.detach().numpy(), label="Generated data")
        plt.title("Genetration")
        # plt.ylim(top=10, bottom=-10)
        # plt.xlabel("time", fontsize=FS)
        # plt.ylabel("y", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join(path, "Generation" + str(number) + ".png"))
        plt.close()

    FS = 10
    plt.rcParams["font.size"] = FS

    ## sequences are at most 129 steps long; e.g. one of length 60 is zero-padded over steps 61-129
    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths'][:args.N_songs]
    training_data_sequences = data['train']['sequences'][:args.N_songs, :8]
    training_seq_lengths = torch.tensor([8] * args.N_songs)
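    # NOTE: the JSB placeholders above are discarded; the actual training data are
    # the Van der Pol trajectories built below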

    training_data = []
    for i in range(args.N_songs):
        training_data.append(
            musics.VanDelPol(torch.randn(2), args.length, args.T))
        # training_data.append(musics.Nonlinear(torch.randn(1), args.length, args.T))
    training_data_sequences = torch.stack(training_data)
    training_seq_lengths = torch.tensor([args.length] * args.N_songs)
    data_dim = training_data_sequences.size(-1)

    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    z_dim = args.z_dim  #100 #2
    rnn_dim = args.rnn_dim  #200 #30
    transition_dim = args.transition_dim  #200 #30
    emission_dim = args.emission_dim  #100 #30
    encoder = Encoder(input_dim=training_data_sequences.size(-1),
                      z_dim=z_dim,
                      rnn_dim=rnn_dim,
                      N_z0=args.N_songs)
    # encoder = Encoder(z_dim=100, rnn_dim=200)
    prior = Prior(z_dim=z_dim,
                  transition_dim=transition_dim,
                  N_z0=args.N_songs,
                  use_cuda=args.cuda)
    decoder = Emitter(input_dim=training_data_sequences.size(-1),
                      z_dim=z_dim,
                      emission_dim=emission_dim,
                      use_cuda=args.cuda)

    # Create optimizer algorithm
    # optimizer = optim.SGD(dmm.parameters(), lr=args.learning_rate)
    # optimizer = optim.Adam(dmm.parameters(), lr=args.learning_rate, betas=(0.96, 0.999), weight_decay=2.0)
    params = list(encoder.parameters()) + list(prior.parameters()) + list(
        decoder.parameters())
    optimizer = optim.SGD(params, lr=args.learning_rate)
    # optimizer = optim.Adam(params, lr=args.learning_rate)
    # Add learning rate scheduler
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)
    # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 1.)
    # (try gamma = 1.0, i.e. no decay, as well)

    # make directory for Data
    now = datetime.datetime.now().strftime('%Y%m%d_%H_%M')
    os.makedirs(os.path.join("saveData", now), exist_ok=True)

    # save args as TEXT
    with open(os.path.join("saveData", now, 'args.txt'), 'w') as f:
        for name, value in vars(args).items():
            f.write(name + " = ")
            f.write(str(value) + "\n")

    #######################
    #### TRAINING LOOP ####
    #######################
    times = [time.time()]
    losses = []
    recon_errors = []
    pbar = tqdm.tqdm(range(args.num_epochs))
    for epoch in pbar:
        epoch_nll = 0
        epoch_recon_error = 0
        shuffled_indices = torch.randperm(N_train_data)
        # print("Proceeding: %.2f " % (epoch*100/5000) + "%")

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):

            # compute which sequences in the training set we should grab
            mini_batch_start = (which_mini_batch * args.mini_batch_size)
            mini_batch_end = np.min([
                (which_mini_batch + 1) * args.mini_batch_size, N_train_data
            ])
            mini_batch_indices = shuffled_indices[
                mini_batch_start:mini_batch_end]

            # grab a fully prepped mini-batch using the helper function in the data loader
            mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
                = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                        training_seq_lengths, cuda=args.cuda)

            # reset gradients
            optimizer.zero_grad()
            loss = 0

            N_repeat = 1
            for i in range(N_repeat):
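                # "seiki" is a normalization weight (1 / N_repeat) applied to each draw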
                seiki = 1 / N_repeat
                # generate mini batch from training mini batch
                pos_z, pos_z_loc, pos_z_scale = encoder(
                    mini_batch, mini_batch_reversed, mini_batch_mask,
                    mini_batch_seq_lengths)
                pri_z, pri_z_loc, pri_z_scale = prior(
                    length=mini_batch.size(1), N_generate=mini_batch.size(0))

                # Regularizer: Wasserstein distance between the posterior and the prior
                regularizer = seiki * D_Wass(pos_z_loc, pri_z_loc, pos_z_scale,
                                             pri_z_scale)
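                # D_Wass is defined elsewhere; for diagonal Gaussians the squared
                # 2-Wasserstein distance has the closed form
                # ||mu_q - mu_p||^2 + ||sigma_q - sigma_p||^2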
                # regularizer = seiki * D_KL(pos_z_loc, pri_z_loc, pos_z_scale, pri_z_scale)
                # regularizer = seiki * D_JS_Monte(pos_z_loc, pri_z_loc, pos_z_scale, pri_z_scale, pos_z, pri_z)
                reconstruction_error = seiki * torch.norm(
                    mini_batch - decoder(pos_z),
                    dim=2).sum() / mini_batch.size(0) / mini_batch.size(
                        1) / mini_batch.size(2)
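                # per-step L2 error, normalized by the batch, time, and feature sizes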

                loss += args.lam * regularizer + reconstruction_error
            reconed_x = decoder(pos_z)

            loss.backward()
            optimizer.step()
            scheduler.step()
            # if args.rnn_clip != None:
            #     for p in dmm.rnn.parameters():
            #         p.data.clamp_(-args.rnn_clip, args.rnn_clip)

            epoch_nll += loss.item()  # .item() keeps bookkeeping off the autograd graph
            epoch_recon_error += reconstruction_error.item()

        # report training diagnostics
        times.append(time.time())
        losses.append(epoch_nll)
        recon_errors.append(epoch_recon_error)
        # epoch_time = times[-1] - times[-2]
        pbar.set_description("LOSS = %f " % epoch_nll)

        if epoch % args.checkpoint_freq == 0:
            path = os.path.join("saveData", now, "Epoch%d" % epoch)
            os.makedirs(path, exist_ok=True)
            for i in range(len(mini_batch)):
                saveReconSinGraph(mini_batch[i], reconed_x[i], args.length,
                                  path, i)
                saveGeneSinGraph(decoder(pri_z)[i], args.length, path, i)

        if epoch == args.num_epochs - 1:
            path = os.path.join("saveData", now, "Epoch%d" % args.num_epochs)
            os.makedirs(path, exist_ok=True)
            for i in range(len(mini_batch)):
                saveReconSinGraph(mini_batch[i], reconed_x[i], args.length,
                                  path, i)
                saveGeneSinGraph(decoder(pri_z)[i], args.length, path, i)

            saveGraph(losses, recon_errors, now)
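            # NOTE: the *_dic entries below store the bound state_dict methods (so
            # the whole module is pickled); the evaluation scripts call them on load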
            saveDic = {
                "optimizer": optimizer,
                "mini_batch": mini_batch,
                "Encoder_dic": encoder.state_dict,
                "Prior_dic": prior.state_dict,
                "Emitter_dic": decoder.state_dict,
                "epoch_times": times,
                "losses": losses
            }
            # torch.save(saveDic,os.path.join("saveData", now, "dic_Epoch%d"%(epoch+1)))
            torch.save(saveDic, os.path.join("saveData", now, "DMM_dic"))
Example #4
def main(args):

    ## saves the loss curves at the end of training
    def saveGraph(loss_list, sub_error_list, now):
        FS = 10
        fig = plt.figure()
        # plt.rcParams["font.size"] = FS
        plt.plot(loss_list, label="LOSS")
        plt.plot(sub_error_list, label="Reconstruction Error")
        plt.ylim(bottom=0)
        plt.title("Loss")
        plt.xlabel("epoch", fontsize=FS)
        plt.ylabel("loss", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join("saveData", now, "LOSS.png"))
        plt.close()

    def saveReconSinGraph(train_data, recon_data, length, path, number):
        FS = 10
        fig = plt.figure()
        # plt.rcParams["font.size"] = FS
        x = np.linspace(0, 2 * np.pi, length)
        plt.plot(x, train_data, label="Training data")
        plt.plot(x, recon_data.detach().numpy(), label="Reconstructed data")
        # plt.ylim(bottom=0)
        plt.title("Sin Curves")
        # plt.ylim(top=10, bottom=-10)
        plt.xlabel("time", fontsize=FS)
        plt.ylabel("y", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join(path,
                                 "Reconstruction" + str(number) + ".png"))
        plt.close()

    def saveGeneSinGraph(gene_data, length, path, number):
        FS = 10
        fig = plt.figure()
        # plt.rcParams["font.size"] = FS
        x = np.linspace(0, 2 * np.pi, length)
        plt.plot(x, gene_data.detach().numpy(), label="Generated data")
        # plt.ylim(bottom=0)
        plt.title("Sin Curves")
        # plt.ylim(top=10, bottom=-10)
        plt.xlabel("time", fontsize=FS)
        plt.ylabel("y", fontsize=FS)
        plt.legend()
        fig.savefig(os.path.join(path, "Generation" + str(number) + ".png"))
        plt.close()

    ## saves generated and training tone sequences as MIDI files
    def save_as_midi(song, path="", name="default.mid", BPM=120, velocity=100):
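        # each sequence step i becomes a one-second note (start=i, end=i+1); the
        # indices of the active piano-roll keys are used directly as MIDI pitches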
        pm = pretty_midi.PrettyMIDI(resolution=960,
                                    initial_tempo=BPM)  # create a PrettyMIDI object
        instrument = pretty_midi.Instrument(0)  # an Instrument is roughly one track
        for i, tones in enumerate(song):
            which_tone = torch.nonzero((tones == 1),
                                       as_tuple=False).reshape(-1)
            if len(which_tone) == 0:
                note = pretty_midi.Note(
                    velocity=0, pitch=0, start=i,
                    end=i + 1)  # a Note corresponds to a NoteOn/NoteOff event pair
                instrument.notes.append(note)
            else:
                for which in which_tone:
                    note = pretty_midi.Note(
                        velocity=velocity,
                        pitch=int(which),
                        start=i,
                        end=i + 1)  # a Note corresponds to a NoteOn/NoteOff event pair
                    instrument.notes.append(note)
        pm.instruments.append(instrument)
        pm.write(os.path.join(path, name))  # write the MIDI file

    # generate Xs
    def generate_Xs(batch_data):
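        # treat the decoder output at each time step as per-key Bernoulli
        # probabilities and sample a binary piano roll from them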
        songs_list = []
        for i, song in enumerate(batch_data):
            tones_container = []
            for time in range(batch_data.size(1)):
                p = dist.Bernoulli(probs=song[time])
                tone = p.sample()
                tones_container.append(tone)
            tones_container = torch.stack(tones_container)
            songs_list.append(tones_container)
        return songs_list

    def saveSongs(songs_list, mini_batch, N_songs, path):
        # print(len(songs_list[0][0]))
        assert len(songs_list) == len(mini_batch)
        if N_songs <= len(songs_list):
            song_No = random.sample(range(len(songs_list)), k=N_songs)
        else:
            song_No = random.sample(range(len(songs_list)), k=len(songs_list))
        for i, Number in enumerate(song_No):
            save_as_midi(song=songs_list[Number],
                         path=path,
                         name="No%d_Gene.midi" % i)
            save_as_midi(song=mini_batch[Number],
                         path=path,
                         name="No%d_Tran.midi" % i)

    FS = 10
    plt.rcParams["font.size"] = FS

    ## sequences are at most 129 steps long; e.g. one of length 60 is zero-padded over steps 61-129
    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths'][:args.N_songs]
    training_data_sequences = data['train']['sequences'][:args.N_songs, :8]
    training_seq_lengths = torch.tensor([8] * args.N_songs)

    if args.doremifasorasido:
        ## do-do-do, re-re-re, mi-mi-mi, do-re-mi
        training_seq_lengths, training_data_sequences = musics.doremifasorasido(
            args.N_songs)
    if args.dododorerere:
        # do-do-do, re-re-re
        training_seq_lengths, training_data_sequences = musics.dododorerere(
            args.N_songs)
    if args.dodododododo:
        ## do-do-do only
        training_seq_lengths, training_data_sequences = musics.dodododododo(
            args.N_songs)

    if args.sin:
        training_data_sequences = musics.createSin_allChanged(
            args.N_songs, args.length)
        training_seq_lengths = torch.tensor([args.length] * args.N_songs)

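    # NOTE: this overrides whichever dataset was selected above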
    training_data_sequences = musics.createNewTrainingData(
        args.N_songs, args.length)
    # training_data = []
    # for i in range(args.N_songs):
    #     training_data.append(musics.Nonlinear(torch.randn(1), args.length, T = args.T))
    # training_data_sequences = torch.stack(training_data)
    training_seq_lengths = torch.tensor([args.length] * args.N_songs)
    data_dim = training_data_sequences.size(-1)

    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    z_dim = args.z_dim  #100 #2
    rnn_dim = args.rnn_dim  #200 #30
    transition_dim = args.transition_dim  #200 #30
    emission_dim = args.emission_dim  #100 #30
    encoder = Encoder(input_dim=training_data_sequences.size(-1),
                      z_dim=z_dim,
                      rnn_dim=rnn_dim,
                      N_z0=args.N_songs)
    # encoder = Encoder(z_dim=100, rnn_dim=200)
    prior = Prior(z_dim=z_dim,
                  transition_dim=transition_dim,
                  N_z0=args.N_songs,
                  use_cuda=args.cuda,
                  R=args.R)
    decoder = Emitter(input_dim=training_data_sequences.size(-1),
                      z_dim=z_dim,
                      emission_dim=emission_dim,
                      use_cuda=args.cuda)

    # Create optimizer algorithm
    # optimizer = optim.SGD(dmm.parameters(), lr=args.learning_rate)
    # optimizer = optim.Adam(dmm.parameters(), lr=args.learning_rate, betas=(0.96, 0.999), weight_decay=2.0)
    params = list(encoder.parameters()) + list(prior.parameters()) + list(
        decoder.parameters())
    optimizer = optim.Adam(params, lr=args.learning_rate)
    # Add learning rate scheduler
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)
    # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 1.)
    #1でもやってみる

    # make directory for Data
    now = datetime.datetime.now().strftime('%Y%m%d_%H_%M')
    os.makedirs(os.path.join("saveData", now), exist_ok=True)

    # save args as TEXT
    with open(os.path.join("saveData", now, 'args.txt'), 'w') as f:
        for name, value in vars(args).items():
            f.write(name + " = ")
            f.write(str(value) + "\n")

    #######################
    #### TRAINING LOOP ####
    #######################
    times = [time.time()]
    losses = []
    recon_errors = []
    pbar = tqdm.tqdm(range(args.num_epochs))
    for epoch in pbar:
        epoch_nll = 0
        epoch_recon_error = 0
        shuffled_indices = torch.randperm(N_train_data)
        # print("Proceeding: %.2f " % (epoch*100/5000) + "%")

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):

            # compute which sequences in the training set we should grab
            mini_batch_start = (which_mini_batch * args.mini_batch_size)
            mini_batch_end = np.min([
                (which_mini_batch + 1) * args.mini_batch_size, N_train_data
            ])
            mini_batch_indices = shuffled_indices[
                mini_batch_start:mini_batch_end]

            # grab a fully prepped mini-batch using the helper function in the data loader
            mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
                = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                        training_seq_lengths, cuda=args.cuda)

            # reset gradients
            optimizer.zero_grad()
            loss = 0

            N_repeat = 1
            for i in range(N_repeat):
                seiki = 1 / N_repeat
                # generate mini batch from training mini batch
                pos_z, pos_z_loc, pos_z_scale = encoder(
                    mini_batch, mini_batch_reversed, mini_batch_mask,
                    mini_batch_seq_lengths)
                pri_z, pri_z_loc, pri_z_scale = prior(
                    length=mini_batch.size(1), N_generate=mini_batch.size(0))

                # Regularizer: Wasserstein distance between the posterior and the prior
                # if any(torch.isnan(pos_z_loc.reshape(-1))):
                #     print("pos_z_loc")
                #     assert False
                # if any(torch.isnan(pos_z_scale.reshape(-1))):
                #     print("pos_z_scale")
                #     assert False
                # if any(torch.isnan(pri_z_loc.reshape(-1))):
                #     print("pri_z_loc")
                #     assert False
                # if any(torch.isnan(pri_z_scale.reshape(-1))):
                #     print("pri_z_scale")
                #     assert False
                regularizer = seiki * D_Wass(pos_z_loc, pri_z_loc, pos_z_scale,
                                             pri_z_scale)
                # if any(torch.isnan(regularizer.reshape(-1))):
                #     print("regula")
                #     assert False
                # print(regularizer)
                # regularizer = seiki * D_KL(pos_z_loc, pri_z_loc, pos_z_scale, pri_z_scale)
                # regularizer = seiki * D_JS_Monte(pos_z_loc, pri_z_loc, pos_z_scale, pri_z_scale, pos_z, pri_z)
                reconstruction_error = seiki * torch.norm(
                    mini_batch - decoder(pos_z),
                    dim=2).sum() / mini_batch.size(0) / mini_batch.size(
                        1) / mini_batch.size(2)

                loss += args.lam * regularizer + reconstruction_error
            reconed_x = decoder(pos_z)
            # Reconstruction Error
            # reconstruction_error = torch.norm(mini_batch - reconed_x, dim=2).sum()/mini_batch.size(0)/mini_batch.size(1)/mini_batch.size(2)
            # loss += reconstruction_error

            # # Covariance Penalty
            # sumOfCovariance = torch.sum(pos_z_scale + pri_z_scale)
            # loss -= 0.00007 * sumOfCovariance

            # # generate mini batch from training mini batch
            # pos_z, pos_z_loc, pos_z_scale = encoder(mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths)
            # pri_z, pri_z_loc, pri_z_scale = prior(length = mini_batch.size(1), N_generate= mini_batch.size(0))

            # reconed_x = decoder(pos_z)
            # # Reconstruction Error
            # reconstruction_error = torch.norm(mini_batch - reconed_x, dim=2).sum()/mini_batch.size(0)/mini_batch.size(1)/mini_batch.size(2)
            # # Regularizer (KL-divergence(pos||pri))
            # regularizer = D_Wass(pos_z_loc, pri_z_loc, pos_z_scale, pri_z_scale)

            # loss += reconstruction_error + args.lam * regularizer

            # # NaN detector
            # if torch.isnan(loss):
            #     saveDic = {
            #     "optimizer":optimizer,
            #     "mini_batch":mini_batch,
            #     "Encoder_dic": encoder.state_dict,
            #     "Prior_dic": prior.state_dict,
            #     "Emitter_dic": decoder.state_dict,
            #     }
            #     torch.save(saveDic,os.path.join("saveData", now, "fail_DMM_dic"))
            #     assert False, ("LOSS ia NaN!")

            # saveDic = {
            # "optimizer":optimizer,
            # "mini_batch":mini_batch,
            # "Encoder_dic": encoder.state_dict,
            # "Prior_dic": prior.state_dict,
            # "Emitter_dic": decoder.state_dict,
            # }
            # torch.save(saveDic,os.path.join("saveData", now, "DMM_dic"))

            # do an actual gradient step
            loss.backward()
            optimizer.step()
            scheduler.step()
            # if args.rnn_clip != None:
            #     for p in dmm.rnn.parameters():
            #         p.data.clamp_(-args.rnn_clip, args.rnn_clip)

            epoch_nll += loss.item()  # .item() keeps bookkeeping off the autograd graph
            epoch_recon_error += reconstruction_error.item()

        # report training diagnostics
        times.append(time.time())
        losses.append(epoch_nll)
        recon_errors.append(epoch_recon_error)
        # epoch_time = times[-1] - times[-2]
        pbar.set_description("LOSS = %f " % epoch_nll)

        if epoch % args.checkpoint_freq == 0:
            path = os.path.join("saveData", now, "Epoch%d" % epoch)
            os.makedirs(path, exist_ok=True)
            if data_dim != 1:
                reco_list = generate_Xs(reconed_x)
                gene_list = generate_Xs(decoder(pri_z))
                for i in range(args.N_generate):
                    save_as_midi(song=reco_list[i],
                                 path=path,
                                 name="No%d_Reco.midi" % i)
                    save_as_midi(song=mini_batch[i],
                                 path=path,
                                 name="No%d_Real.midi" % i)
                    save_as_midi(song=gene_list[i],
                                 path=path,
                                 name="No%d_Gene.midi" % i)
            else:
                for i in range(len(mini_batch)):
                    saveReconSinGraph(mini_batch[i], reconed_x[i], args.length,
                                      path, i)
                    saveGeneSinGraph(decoder(pri_z)[i], args.length, path, i)
                saveDic = {
                    "optimizer": optimizer,
                    "mini_batch": mini_batch,
                    "Encoder_dic": encoder.state_dict,
                    "Prior_dic": prior.state_dict,
                    "Emitter_dic": decoder.state_dict,
                    "epoch_times": times,
                    "losses": losses
                }
                # torch.save(saveDic,os.path.join("saveData", now, "dic_Epoch%d"%(epoch+1)))
                torch.save(saveDic, os.path.join(path, "DMM_dic"))

        if epoch == args.num_epochs - 1:
            path = os.path.join("saveData", now, "Epoch%d" % args.num_epochs)
            os.makedirs(path, exist_ok=True)
            if data_dim != 1:
                reco_list = generate_Xs(reconed_x)
                gene_list = generate_Xs(decoder(pri_z))
                for i in range(args.N_generate):
                    save_as_midi(song=reco_list[i],
                                 path=path,
                                 name="No%d_Reco.midi" % i)
                    save_as_midi(song=mini_batch[i],
                                 path=path,
                                 name="No%d_Real.midi" % i)
                    save_as_midi(song=gene_list[i],
                                 path=path,
                                 name="No%d_Gene.midi" % i)
            else:
                for i in range(len(mini_batch)):
                    saveReconSinGraph(mini_batch[i], reconed_x[i], args.length,
                                      path, i)
                    saveGeneSinGraph(decoder(pri_z)[i], args.length, path, i)
                saveDic = {
                    "optimizer": optimizer,
                    "mini_batch": mini_batch,
                    "Encoder_dic": encoder.state_dict,
                    "Prior_dic": prior.state_dict,
                    "Emitter_dic": decoder.state_dict,
                    "epoch_times": times,
                    "losses": losses
                }
                # torch.save(saveDic,os.path.join("saveData", now, "dic_Epoch%d"%(epoch+1)))
                torch.save(saveDic, os.path.join(path, "DMM_dic"))

            saveGraph(losses, recon_errors, now)
            saveDic = {
                "optimizer": optimizer,
                "mini_batch": mini_batch,
                "Encoder_dic": encoder.state_dict,
                "Prior_dic": prior.state_dict,
                "Emitter_dic": decoder.state_dict,
                "epoch_times": times,
                "losses": losses
            }
            # torch.save(saveDic,os.path.join("saveData", now, "dic_Epoch%d"%(epoch+1)))
            torch.save(saveDic, os.path.join("saveData", now, "DMM_dic"))
Example #5
def main(args):
    # make directory for Data
    now = datetime.datetime.now().strftime('%Y%m%d_%H_%M')
    os.makedirs(os.path.join("saveEstimate", now), exist_ok=True)

    # save args as TEXT
    with open(os.path.join("saveEstimate", now, 'args.txt'), 'w') as f:
        for name, value in vars(args).items():
            f.write(name + " = ")
            f.write(str(value) + "\n")

    #Arguments
    N_songs = 10
    length = 20
    mini_batch_size = 20
    z_dim = 100
    rnn_dim = 200
    transition_dim = 200
    emission_dim = 100

    reconstructions = []
    generations = []
    trainings = []
    for i in range(args.N_estimate):
        encoder = Encoder(input_dim=1,
                          z_dim=z_dim,
                          rnn_dim=rnn_dim,
                          N_z0=N_songs,
                          var_clip=args.clips[i])
        prior = Prior(z_dim=z_dim,
                      transition_dim=transition_dim,
                      N_z0=N_songs,
                      var_clip=args.clips[i])
        decoder = Emitter(input_dim=1, z_dim=z_dim, emission_dim=emission_dim)
        date = "2021" + args.dates[i]
        path = os.path.join("saveData", date)
        DMM_dics = torch.load(os.path.join(path, "DMM_dic"))
        training_data_sequences = DMM_dics["mini_batch"]

        encoder.load_state_dict(DMM_dics["Encoder_dic"]())
        prior.load_state_dict(DMM_dics["Prior_dic"]())
        decoder.load_state_dict(DMM_dics["Emitter_dic"]())

        which_mini_batch = 0
        training_seq_lengths = torch.tensor([length] * N_songs)
        data_dim = training_data_sequences.size(-1)

        N_train_data = len(training_seq_lengths)
        N_train_time_slices = float(torch.sum(training_seq_lengths))
        N_mini_batches = int(N_train_data / mini_batch_size +
                             int(N_train_data % mini_batch_size > 0))
        shuffled_indices = torch.randperm(N_train_data)
        mini_batch_start = (which_mini_batch * mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * mini_batch_size,
                                 N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]

        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                    training_seq_lengths)

        pri_z, pri_z_loc, pri_z_scale = prior(length=mini_batch.size(1),
                                              N_generate=args.N_generate)
        pos_z, a, b = encoder(mini_batch, mini_batch_reversed, mini_batch_mask,
                              mini_batch_seq_lengths)
        reconed_x = decoder(pos_z)
        reconstructions.append(reconed_x)
        generations.append(decoder(pri_z))
        trainings.append(mini_batch)
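        # Monte Carlo estimate of the prior's entropy: -(1/N) * sum_i log p(z_i),
        # plus the Gaussian log(2*pi) term (presumably omitted by multi_normal_prob)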
        entropy = 0
        entropy -= torch.sum(
            torch.log(multi_normal_prob(pri_z_loc, pri_z_scale,
                                        pri_z))) / args.N_generate
        entropy += 0.5 * pri_z.size(1) * pri_z.size(2) * torch.log(
            torch.tensor(2 * math.pi))
        how_close = calc_how_close_to_mini_batch(decoder(pri_z), mini_batch)

        ppath = os.path.join("saveEstimate", now, date)
        os.makedirs(ppath, exist_ok=True)
        os.makedirs(os.path.join(ppath, "Reconstruction"), exist_ok=True)
        os.makedirs(os.path.join(ppath, "Generation"), exist_ok=True)
        os.makedirs(os.path.join(ppath, "Training"), exist_ok=True)
        for j in range(args.N_pics):
            saveReconSinGraph(mini_batch[j], reconed_x[j], pri_z.size(1),
                              os.path.join(ppath, "Reconstruction"), j, date)
        saveGeneSinGraph(
            decoder(pri_z)[:3], pri_z.size(1),
            os.path.join(ppath, "Generation"), j, date)
        #saveTrainSinGraph(mini_batch, pri_z.size(1), os.path.join(ppath,"Training"), 1, date)

        # append the evaluation results as text
        with open(os.path.join("saveEstimate", now, 'estimate.txt'), 'a') as f:
            f.write("2021" + args.dates[i] + "\n")
            f.write("Entropy = " + str(entropy) + "\n")
            f.write("How close to mini_batch = " + str(how_close) + "\n\n")

    ppath = os.path.join("saveEstimate", now)
    saveTrainSinGraph(mini_batch, pri_z.size(1), os.path.join(ppath), 1, date)
    ## PAINT PICTURE ##
    # figure size: height fixed at 8 (overriding 2 * args.N_estimate), width 6
    Ver = 8
    Hor = 6
    fig = plt.figure(figsize=(Hor, Ver))
    plt.subplots_adjust(left=0.08, right=0.95, bottom=0.08, top=0.95)
    FS = 6
    plt.rcParams["font.size"] = FS

    x = np.linspace(0, args.length, args.length)
    for i in range(args.N_estimate):
        plt.subplot(args.N_estimate, 2, 2 * i + 1)
        index = np.random.randint(0, args.N_songs)
        plt.plot(x, trainings[i][index], label="Training data", color="black")
        plt.plot(x,
                 reconstructions[i][index].detach().numpy(),
                 label="Reconstructed data",
                 color="orange")
        plt.grid()
        plt.ylim(top=10, bottom=-10)
        plt.xlim(0, length)
        if i == args.N_estimate - 1:
            plt.xlabel("t", fontsize=10)
        plt.legend()
        plt.gca().get_xaxis().set_major_locator(
            ticker.MaxNLocator(integer=True))

        plt.subplot(args.N_estimate, 2, 2 * i + 2)
        for j in range(3):
            plt.plot(x,
                     generations[i][j].detach().numpy(),
                     label="Generated data" + str(j + 1))
        plt.grid()
        plt.ylim(top=10, bottom=-10)
        plt.xlim(0, args.length)
        plt.gca().get_xaxis().set_major_locator(
            ticker.MaxNLocator(integer=True))
        if i == args.N_estimate - 1:
            plt.xlabel("t", fontsize=10)
        plt.legend()
    fig.savefig(os.path.join(ppath, "Results" + ".pdf"))
Example #6
def main(args):

    ## do-re-mi (one octave of an ascending major scale)
    def easyTones():
        max_pitch = 70
        interval = 10
        N = 10
        training_seq_lengths = torch.tensor([8] * N)
        training_data_sequences = torch.zeros(N, 8, 88)
        for i in range(N):
            training_data_sequences[i][0][max_pitch - i * interval] = 1
            training_data_sequences[i][1][max_pitch - i * interval + 2] = 1
            training_data_sequences[i][2][max_pitch - i * interval + 4] = 1
            training_data_sequences[i][3][max_pitch - i * interval + 5] = 1
            training_data_sequences[i][4][max_pitch - i * interval + 7] = 1
            training_data_sequences[i][5][max_pitch - i * interval + 9] = 1
            training_data_sequences[i][6][max_pitch - i * interval + 11] = 1
            training_data_sequences[i][7][max_pitch - i * interval + 12] = 1
        return training_seq_lengths, training_data_sequences

    ## do-do-do, re-re-re
    def superEasyTones():
        training_seq_lengths = torch.tensor([8] * 10)
        training_data_sequences = torch.zeros(10, 8, 88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(30 + i * 5)] = 1
        return training_seq_lengths, training_data_sequences

    ## do-do-do, do-do-do, do-do-do (the same note repeated)
    def easiestTones():
        training_seq_lengths = torch.tensor([8] * 10)
        training_data_sequences = torch.zeros(10, 8, 88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(70)] = 1
        return training_seq_lengths, training_data_sequences

    # rep is short for "repeat"
    # which means how many times we use certain sample to do validation/test evaluation during training
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(
            1, 0).reshape(rep_shape)

    ## saves the loss curve at the end of training
    def saveGraph(loss_list, now):
        FS = 10
        fig = plt.figure()
        plt.rcParams["font.size"] = FS
        plt.plot(loss_list)
        plt.ylim(bottom=0)
        plt.title("Wasserstein Loss")
        plt.xlabel("epoch", fontsize=FS)
        plt.ylabel("loss", fontsize=FS)
        fig.savefig(os.path.join("saveData", now, "LOSS.png"))

    ## saves generated and training tone sequences as MIDI files
    def save_as_midi(song, path="", name="default.mid", BPM=120, velocity=100):
        pm = pretty_midi.PrettyMIDI(resolution=960,
                                    initial_tempo=BPM)  # create a PrettyMIDI object
        instrument = pretty_midi.Instrument(0)  # an Instrument is roughly one track
        for i, tones in enumerate(song):
            which_tone = torch.nonzero((tones == 1),
                                       as_tuple=False).reshape(-1)
            if len(which_tone) == 0:
                note = pretty_midi.Note(
                    velocity=0, pitch=0, start=i,
                    end=i + 1)  # a Note corresponds to a NoteOn/NoteOff event pair
                instrument.notes.append(note)
            else:
                for which in which_tone:
                    note = pretty_midi.Note(
                        velocity=velocity,
                        pitch=int(which),
                        start=i,
                        end=i + 1)  # a Note corresponds to a NoteOn/NoteOff event pair
                    instrument.notes.append(note)
        pm.instruments.append(instrument)
        pm.write(os.path.join(path, name))  # write the MIDI file

    # generate Xs
    def generate_Xs(dmm, mini_batch, mini_batch_reversed, mini_batch_mask,
                    mini_batch_seq_lengths):
        songs_list = []
        generated_mini_batch = dmm(mini_batch, mini_batch_reversed,
                                   mini_batch_mask, mini_batch_seq_lengths)
        for i, song in enumerate(generated_mini_batch):
            tones_container = []
            for time in range(mini_batch_seq_lengths[i]):
                p = dist.Bernoulli(probs=song[time])
                tone = p.sample()
                tones_container.append(tone)
            tones_container = torch.stack(tones_container)
            songs_list.append(tones_container)
        return songs_list

    def saveSongs(songs_list, mini_batch, N_songs, path):
        # print(len(songs_list[0][0]))
        assert len(songs_list) == len(mini_batch)
        if N_songs <= len(songs_list):
            song_No = random.sample(range(len(songs_list)), k=N_songs)
        else:
            song_No = random.sample(range(len(songs_list)), k=len(songs_list))
        for i, Number in enumerate(song_No):
            save_as_midi(song=songs_list[Number],
                         path=path,
                         name="No%d_Gene.midi" % i)
            save_as_midi(song=mini_batch[Number],
                         path=path,
                         name="No%d_Tran.midi" % i)

    ## sequences are at most 129 steps long; e.g. one of length 60 is zero-padded over steps 61-129
    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths'][:3]
    training_data_sequences = data['train']['sequences'][:3, :8]
    training_seq_lengths = torch.tensor([8] * 10)
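    # NOTE: these 10 lengths do not match the 3 sequences sliced above; the toy
    # datasets selected below are expected to replace both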
    # training_seq_lengths = torch.tensor([8])

    if args.eas:
        ## do-do-do, re-re-re, mi-mi-mi, do-re-mi
        training_seq_lengths, training_data_sequences = easyTones()

    if args.sup:
        # do-do-do, re-re-re
        training_seq_lengths, training_data_sequences = superEasyTones()

    if args.est:
        ## do-do-do only
        training_seq_lengths, training_data_sequences = easiestTones()

    test_seq_lengths = data['test']['sequence_lengths'][:3]
    test_seq_lengths = torch.tensor([8, 8, 8])
    test_data_sequences = data['test']['sequences'][:3, :8]
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    N_train_data = len(training_seq_lengths)
    # N_train_data = len(test_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    # val_seq_lengths = rep(val_seq_lengths)
    # # test_seq_lengths = rep(test_seq_lengths)
    # val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
    #     torch.arange(n_eval_samples * val_data_sequences.shape[0]), rep(val_data_sequences),
    #     val_seq_lengths, cuda=args.cuda)
    # test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
    #     torch.arange(n_eval_samples * test_data_sequences.shape[0]), rep(test_data_sequences),
    #     test_seq_lengths, cuda=args.cuda)

    dmm = DMM(use_cuda=args.cuda, rnn_check=args.rck, rpt=args.rpt)
    WN = WGAN_network(use_cuda=args.cuda)
    W = WassersteinLoss(WN,
                        N_loops=args.N_loops,
                        lr=args.w_learning_rate,
                        use_cuda=args.cuda)

    # Create optimizer algorithm
    # optimizer = optim.SGD(dmm.parameters(), lr=args.learning_rate)
    # optimizer = optim.Adam(dmm.parameters(), lr=args.learning_rate, betas=(0.96, 0.999), weight_decay=2.0)
    optimizer = optim.Adam(dmm.parameters(), lr=args.learning_rate)
    # Add learning rate scheduler
    # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 1.)
    # (try gamma = 1.0, i.e. no decay, as well)

    # make directory for Data
    now = datetime.datetime.now().strftime('%Y%m%d_%H_%M')
    os.makedirs(os.path.join("saveData", now), exist_ok=True)

    # save args as TEXT
    with open(os.path.join("saveData", now, 'args.txt'), 'w') as f:
        for name, value in vars(args).items():
            f.write(name + " = ")
            f.write(str(value) + "\n")

    #######################
    #### TRAINING LOOP ####
    #######################
    times = [time.time()]
    losses = []
    pbar = tqdm.tqdm(range(args.num_epochs))
    for epoch in pbar:
        epoch_nll = 0
        # shuffled_indices = torch.randperm(N_train_data)
        shuffled_indices = torch.arange(N_train_data)
        # print("Proceeding: %.2f " % (epoch*100/5000) + "%")

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):

            # compute which sequences in the training set we should grab
            mini_batch_start = (which_mini_batch * args.mini_batch_size)
            mini_batch_end = np.min([
                (which_mini_batch + 1) * args.mini_batch_size, N_train_data
            ])
            mini_batch_indices = shuffled_indices[
                mini_batch_start:mini_batch_end]

            # grab a fully prepped mini-batch using the helper function in the data loader
            mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
                = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                        training_seq_lengths, cuda=args.cuda)
            # mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            #     = poly.get_mini_batch(mini_batch_indices, test_data_sequences,
            #                             test_seq_lengths, cuda=args.cuda)

            # reset gradients
            optimizer.zero_grad()

            # generate mini batch from training mini batch
            # NaN_detect(dmm, epoch, message="Before Generate")
            generated_mini_batch = dmm(mini_batch, mini_batch_reversed,
                                       mini_batch_mask, mini_batch_seq_lengths)
            # NaN_detect(dmm, epoch, message="After Generate")
            # if any(torch.isnan(generated_mini_batch.reshape(-1))):
            #     print("GENERATED")
            #     assert False
            # NaN_detect(WN, epoch, message="calc_Before")
            # calculate loss
            if args.wass:
                loss = EMD(mini_batch, generated_mini_batch)
            if args.wgan:
                loss = W.calc(mini_batch, generated_mini_batch)
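            # NOTE: if neither args.wass nor args.wgan is set, loss is never
            # assigned and loss.backward() below raises a NameError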
            # NaN_detect(WN, epoch, message="calc_After")

            # NaN_detect(dmm, epoch, message="step_Before")
            # do an actual gradient step
            loss.backward()
            optimizer.step()
            scheduler.step()
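            # optionally clamp the RNN weights to [-rnn_clip, rnn_clip]
            # (WGAN-style weight clipping)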
            if args.rnn_clip is not None:
                for p in dmm.rnn.parameters():
                    p.data.clamp_(-args.rnn_clip, args.rnn_clip)
                    # p.data.clamp_(-0.01, 0.01)
            # NaN_detect(dmm, epoch, message="step_After")

            epoch_nll += loss.item()  # .item() keeps bookkeeping off the autograd graph

        # report training diagnostics
        times.append(time.time())
        losses.append(epoch_nll)
        epoch_time = times[-1] - times[-2]
        # logging.info("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
        #                 (epoch, epoch_nll / N_train_time_slices, epoch_time))
        # if epoch % 10 == 0:
        # print("epoch %d time : %d sec" % (epoch, int(epoch_time)))
        pbar.set_description("LOSS = %f " % epoch_nll)
        # tqdm.write("\r"+"        loss : %f " % epoch_nll, end="")

        if epoch % args.checkpoint_freq == 0:
            path = os.path.join("saveData", now, "Epoch%d" % epoch)
            os.makedirs(path, exist_ok=True)
            songs_list = generate_Xs(dmm, mini_batch, mini_batch_reversed,
                                     mini_batch_mask, mini_batch_seq_lengths)
            saveSongs(songs_list, mini_batch, N_songs=2, path=path)

        if epoch == args.num_epochs - 1:
            path = os.path.join("saveData", now, "Epoch%d" % args.num_epochs)
            os.makedirs(path, exist_ok=True)
            songs_list = generate_Xs(dmm, mini_batch, mini_batch_reversed,
                                     mini_batch_mask, mini_batch_seq_lengths)
            saveSongs(songs_list, mini_batch, N_songs=2, path=path)

            saveGraph(losses, now)
            saveDic = {
                "DMM_dic": dmm.state_dict,
                "WGAN_Network_dic": WN.state_dict,
                "epoch_times": times,
                "losses": losses
            }
            # torch.save(saveDic,os.path.join("saveData", now, "dic_Epoch%d"%(epoch+1)))
            torch.save(saveDic, os.path.join("saveData", now, "DMM_dic"))
Example #7
def main(args):
    # make directory for Data
    now = args.date
    est_path = os.path.join("saveEstimate", now)
    os.makedirs(est_path, exist_ok=True)

    # get the name of dir of saveData
    data_path = os.path.join("saveData", now)
    files = os.listdir(data_path)
    files_dir = [f for f in files if os.path.isdir(os.path.join(data_path, f))]

    # # save args as TEXT
    # f = open(os.path.join("saveEstimate", now, 'args.txt'), 'w')  # open in write mode
    # for name, site in vars(args).items():
    #     f.write(name + " = ")
    #     f.write(str(site) + "\n")
    # f.close()  # close the file

    #Arguments
    N_songs = args.N_songs
    length = args.length
    mini_batch_size = 20
    z_dim = args.z_dim  #100 #2
    rnn_dim = args.rnn_dim  #200 #30
    transition_dim = args.transition_dim  #200 #30
    emission_dim = args.emission_dim  #100 #30

    reconstructions = []
    generations = []
    trainings = []
    for i in range(len(files_dir)):
        encoder = Encoder(input_dim=1,
                          z_dim=z_dim,
                          rnn_dim=rnn_dim,
                          N_z0=N_songs)
        prior = Prior(z_dim=z_dim, transition_dim=transition_dim, N_z0=N_songs)
        decoder = Emitter(input_dim=1, z_dim=z_dim, emission_dim=emission_dim)
        # if files_dir[i] == "Epoch2000":
        #     continue
        DMM_dics = torch.load(os.path.join(data_path, files_dir[i], "DMM_dic"))
        training_data_sequences = DMM_dics["mini_batch"]
        encoder.load_state_dict(DMM_dics["Encoder_dic"]())
        prior.load_state_dict(DMM_dics["Prior_dic"]())
        decoder.load_state_dict(DMM_dics["Emitter_dic"]())

        which_mini_batch = 0
        training_seq_lengths = torch.tensor([length] * N_songs)
        data_dim = training_data_sequences.size(-1)

        N_train_data = len(training_seq_lengths)
        N_train_time_slices = float(torch.sum(training_seq_lengths))
        N_mini_batches = int(N_train_data / mini_batch_size +
                             int(N_train_data % mini_batch_size > 0))
        shuffled_indices = torch.randperm(N_train_data)
        mini_batch_start = (which_mini_batch * mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * mini_batch_size,
                                 N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]

        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                    training_seq_lengths)

        pri_z, pri_z_loc, pri_z_scale = prior(length=mini_batch.size(1),
                                              N_generate=args.N_generate)
        pos_z, a, b = encoder(mini_batch, mini_batch_reversed, mini_batch_mask,
                              mini_batch_seq_lengths)
        reconed_x = decoder(pos_z)
        reconstructions.append(reconed_x)
        generations.append(decoder(pri_z))
        trainings.append(mini_batch)
        # entropy = 0
        # entropy -= torch.sum(torch.log(multi_normal_prob(pri_z_loc, pri_z_scale, pri_z))) / args.N_generate
        # entropy += 0.5 * pri_z.size(1) * pri_z.size(2) * torch.log(torch.tensor(2*math.pi))
        # how_close = calc_how_close_to_mini_batch(decoder(pri_z), mini_batch)

        ppath = os.path.join(est_path, files_dir[i])
        os.makedirs(ppath, exist_ok=True)
        os.makedirs(os.path.join(ppath, "Reconstruction"), exist_ok=True)
        os.makedirs(os.path.join(ppath, "Generation"), exist_ok=True)
        os.makedirs(os.path.join(ppath, "Training"), exist_ok=True)
        for j in range(args.N_pics):
            saveReconSinGraph(mini_batch[j], reconed_x[j], pri_z.size(1),
                              os.path.join(ppath, "Reconstruction"), j,
                              args.date)
        saveGeneSinGraph(
            decoder(pri_z)[:3], pri_z.size(1),
            os.path.join(ppath, "Generation"), j, args.date)
        saveTrainSinGraph(mini_batch[:5], pri_z.size(1),
                          os.path.join(ppath, "Training"), 1, args.date)