Пример #1
0
    def test_famous(self, args):
        """Transfer the hard-coded "famous song" (YMCA) with a trained generator.

        Restores the latest checkpoint if one exists, runs the generator chosen
        by ``args.which_direction`` on the pre-merged YMCA array, and writes the
        result under ``./datasets/famous_songs/P2C/transfer/`` twice:
        - as MIDI, via ``save_midis`` with 127 as its third argument;
        - as a ``.npy`` array.

        Args:
            args: Parsed command-line options; only ``which_direction``
                (``'AtoB'`` or ``'BtoA'``) is read here.

        Raises:
            ValueError: If ``args.which_direction`` is neither ``'AtoB'`` nor
                ``'BtoA'``.
        """
        song_path = './datasets/famous_songs/P2C/merged_npy/YMCA.npy'
        transfer_dir = './datasets/famous_songs/P2C/transfer'

        # Validate the direction before doing any work (matches test()).
        if args.which_direction == 'AtoB':
            generator = self.generator_A2B
        elif args.which_direction == 'BtoA':
            generator = self.generator_B2A
        else:
            raise ValueError('--which_direction must be AtoB or BtoA')

        # ``* 1.`` promotes a bool/int array (if that is how it was stored) to
        # floating point, the same way test() prepares generator inputs; it is
        # a no-op for float arrays.
        song = np.load(song_path) * 1.

        # Checkpoint.restore() returns a status object that is truthy even when
        # there is no checkpoint (path None), so branch on the path itself.
        latest_checkpoint = self.checkpoint_manager.latest_checkpoint
        if latest_checkpoint:
            self.checkpoint.restore(latest_checkpoint)
            print(" [*] Load checkpoint succeeded!")
        else:
            print(" [!] Load checkpoint failed...")

        transfer = generator(song, training=False)

        # The output directory may not exist yet; create it before saving.
        os.makedirs(transfer_dir, exist_ok=True)
        save_midis(transfer, os.path.join(transfer_dir, 'YMCA.mid'), 127)
        np.save(os.path.join(transfer_dir, 'YMCA.npy'), transfer)
Пример #2
0
    def sample_model(self, samples, sample_dir, epoch, idx):
        """Save the phrases generated during training as MIDI files.

        Args:
            samples: Indexable of six phrases, in this order:
                A2B origin, A2B transfer, A2B cycle,
                B2A origin, B2A transfer, B2A cycle.
            sample_dir: Directory under which ``A2B/`` and ``B2A/``
                sub-directories are created (if missing) and written to.
            epoch: Epoch number, formatted with ``{:02d}`` in file names.
            idx: Step index, formatted with ``{:04d}`` in file names.
        """
        print('generating samples during learning......')

        # (sub-directory, position of its first phrase in ``samples``)
        directions = (('A2B', 0), ('B2A', 3))
        for direction, _ in directions:
            os.makedirs(os.path.join(sample_dir, direction), exist_ok=True)

        # Join onto sample_dir (rather than prefixing './') so the files land
        # in the directories created above even when sample_dir is absolute.
        for direction, first in directions:
            for offset, kind in enumerate(('origin', 'transfer', 'cycle')):
                file_name = '{:02d}_{:04d}_{}.mid'.format(epoch, idx, kind)
                save_midis(samples[first + offset],
                           os.path.join(sample_dir, direction, file_name))
Пример #3
0
    def test(self, args):
        """Run the trained generators over the test set and save the results.

        Each test phrase from the source domain selected by
        ``args.which_direction`` is transferred to the other domain and then
        cycled back. Outputs are numbered from 1 in sorted input order:
        - origin, transfer and cycle phrases are saved as MIDI under
          ``<test_dir>/<run>/<direction>/mid``;
        - the same phrases are saved as ``.npy`` under
          ``<test_dir>/<run>/<direction>/npy/{origin,transfer,cycle}``.

        Args:
            args: Parsed command-line options; ``which_direction`` and
                ``test_dir`` are read.

        Raises:
            ValueError: If ``args.which_direction`` is neither ``'AtoB'`` nor
                ``'BtoA'`` (a subclass of ``Exception``, as raised before).
        """
        if args.which_direction == 'AtoB':
            source_dir = self.dataset_A_dir
        elif args.which_direction == 'BtoA':
            source_dir = self.dataset_B_dir
        else:
            raise ValueError('--which_direction must be AtoB or BtoA')
        sample_files = glob('./datasets/{}/test/*.*'.format(source_dir))
        # Input names are expected to end in ``_<number>``; sort numerically on
        # that suffix so that e.g. ``x_10`` comes after ``x_9``.
        sample_files.sort(key=lambda x: int(
            os.path.splitext(os.path.basename(x))[0].split('_')[-1]))

        # Checkpoint.restore() returns a status object that is truthy even when
        # there is no checkpoint (path None), so branch on the path itself.
        latest_checkpoint = self.checkpoint_manager.latest_checkpoint
        if latest_checkpoint:
            self.checkpoint.restore(latest_checkpoint)
            print(" [*] Load checkpoint succeeded!")
        else:
            print(" [!] Load checkpoint failed...")

        # NOTE(review): the run directory is named
        # <A>2<B>_<now_datetime>_<model>_<sigma_d>; a test() elsewhere that
        # classifies these npy files formats it as
        # <A>2<B>_<model>_<sigma_d>_<now_datetime>. Confirm the intended order.
        run_dir = os.path.join(
            args.test_dir,
            '{}2{}_{}_{}_{}/{}'.format(self.dataset_A_dir,
                                       self.dataset_B_dir,
                                       self.now_datetime, self.model,
                                       self.sigma_d, args.which_direction))
        test_dir_mid = os.path.join(run_dir, 'mid')
        test_dir_npy = os.path.join(run_dir, 'npy')
        npy_path_origin = os.path.join(test_dir_npy, 'origin')
        npy_path_transfer = os.path.join(test_dir_npy, 'transfer')
        npy_path_cycle = os.path.join(test_dir_npy, 'cycle')

        # Create every output directory once, up front, instead of re-checking
        # the npy sub-directories on every loop iteration.
        for directory in (test_dir_mid, npy_path_origin, npy_path_transfer,
                          npy_path_cycle):
            os.makedirs(directory, exist_ok=True)

        for number, sample_file in enumerate(sample_files, start=1):
            print('Processing midi: ', sample_file)
            # ``* 1.`` promotes the stored array to floating point.
            sample_npy = np.load(sample_file) * 1.
            # Add batch and channel axes: (1, rows, cols, 1).
            origin = sample_npy.reshape(1, sample_npy.shape[0],
                                        sample_npy.shape[1], 1)

            if args.which_direction == 'AtoB':
                transfer = self.generator_A2B(origin, training=False)
                cycle = self.generator_B2A(transfer, training=False)
            else:
                transfer = self.generator_B2A(origin, training=False)
                cycle = self.generator_A2B(transfer, training=False)

            save_midis(origin,
                       os.path.join(test_dir_mid,
                                    '{}_origin.mid'.format(number)))
            save_midis(transfer,
                       os.path.join(test_dir_mid,
                                    '{}_transfer.mid'.format(number)))
            save_midis(cycle,
                       os.path.join(test_dir_mid,
                                    '{}_cycle.mid'.format(number)))

            np.save(
                os.path.join(npy_path_origin, '{}_origin.npy'.format(number)),
                origin)
            np.save(
                os.path.join(npy_path_transfer,
                             '{}_transfer.npy'.format(number)), transfer)
            np.save(
                os.path.join(npy_path_cycle, '{}_cycle.npy'.format(number)),
                cycle)
Пример #4
0
    def test(self, args):
        """Classify previously generated origin/transfer/cycle phrases.

        Loads this run's ``npy/{origin,transfer,cycle}`` arrays and pairs them
        by their leading number. Each phrase is scored with ``self.classifier``
        on ``x * 2 - 1`` followed by a softmax. Then:
        - each phrase is re-saved as MIDI under ``mid_attach_prob``, with the
          source label's probability appended to the file name;
        - a ranking file sorted by ``|P_origin - P_transfer|`` (descending)
          is written;
        - the fraction of origin, transfer and cycle phrases whose argmax is
          the source label is printed.

        Args:
            args: Parsed command-line options; ``which_direction`` and
                ``test_dir`` are read.
        """
        # Inputs are read from args.test_dir (the same root the outputs go
        # to) rather than a hard-coded './test', so the two stay in sync.
        # NOTE(review): this run directory is named
        # <A>2<B>_<model>_<sigma_d>_<now_datetime>; the generator-side test()
        # that writes the npy files uses <A>2<B>_<now_datetime>_<model>_<sigma_d>.
        # Confirm the intended order.
        run_dir = os.path.join(
            args.test_dir, '{}2{}_{}_{}_{}/{}'.format(
                self.dataset_A_dir, self.dataset_B_dir, self.model,
                self.sigma_d, self.now_datetime, args.which_direction))

        def sorted_npy_files(kind):
            # Files are named ``<number>_<kind>.npy``; sort on the number.
            files = glob(os.path.join(run_dir, 'npy', kind, '*.*'))
            files.sort(key=lambda x: int(
                os.path.splitext(os.path.basename(x))[0].split('_')[0]))
            return files

        # Group the origin, transfer and cycle files of the same phrase.
        sample_files = list(
            zip(sorted_npy_files('origin'), sorted_npy_files('transfer'),
                sorted_npy_files('cycle')))

        # Checkpoint.restore() returns a status object that is truthy even when
        # there is no checkpoint (path None), so branch on the path itself.
        latest_checkpoint = self.checkpoint_manager.latest_checkpoint
        if latest_checkpoint:
            self.checkpoint.restore(latest_checkpoint)
            print(" [*] Load checkpoint succeeded!")
        else:
            print(" [!] Load checkpoint failed...")

        # Output directory for the MIDI files tagged with probabilities.
        test_dir_mid = os.path.join(run_dir, 'mid_attach_prob')
        os.makedirs(test_dir_mid, exist_ok=True)

        # Labels: (1, 0) for A, (0, 1) for B. The source domain's label index
        # is therefore 0 when transferring A to B and 1 otherwise.
        label = 0 if args.which_direction == 'AtoB' else 1

        count_origin = 0
        count_transfer = 0
        count_cycle = 0
        line_list = []

        for number, paths in enumerate(sample_files, start=1):
            print('Classifying midi: ', paths)

            origin, transfer, cycle = (np.load(path) for path in paths)

            # Class probabilities for each phrase.
            origin_softmax = tf.nn.softmax(
                self.classifier(origin * 2. - 1., training=False))
            transfer_softmax = tf.nn.softmax(
                self.classifier(transfer * 2. - 1., training=False))
            cycle_softmax = tf.nn.softmax(
                self.classifier(cycle * 2. - 1., training=False))

            origin_transfer_diff = np.abs(origin_softmax - transfer_softmax)
            content_diff = np.mean((origin * 1.0 - transfer * 1.0)**2)

            line_list.append(
                (number, content_diff, origin_transfer_diff[0][label],
                 origin_softmax[0][label], transfer_softmax[0][label],
                 cycle_softmax[0][label]))

            # Counts of phrases classified as the source domain.
            count_origin += 1 if np.argmax(origin_softmax[0]) == label else 0
            count_transfer += 1 if np.argmax(
                transfer_softmax[0]) == label else 0
            count_cycle += 1 if np.argmax(cycle_softmax[0]) == label else 0

            save_midis(
                origin,
                os.path.join(
                    test_dir_mid,
                    '{}_origin_{}.mid'.format(number,
                                              origin_softmax[0][label])))
            save_midis(
                transfer,
                os.path.join(
                    test_dir_mid,
                    '{}_transfer_{}.mid'.format(number,
                                                transfer_softmax[0][label])))
            save_midis(
                cycle,
                os.path.join(
                    test_dir_mid,
                    '{}_cycle_{}.mid'.format(number,
                                             cycle_softmax[0][label])))

        # Rank phrases by origin/transfer probability difference, largest first.
        line_list.sort(key=lambda x: x[2], reverse=True)
        rankings_path = os.path.join(
            test_dir_mid, 'Rankings_{}.txt'.format(args.which_direction))
        with open(rankings_path, 'w') as f:
            f.write(
                'Id  Content_diff  P_O - P_T  Prob_Origin  Prob_Transfer  Prob_Cycle'
            )
            for line in line_list:
                f.write("\n%5d %5f %5f %5f %5f %5f" % line)

        num_samples = len(sample_files)
        if not num_samples:
            # Avoid a ZeroDivisionError when no generated phrases were found.
            print('No test samples found; accuracy cannot be computed.')
            return

        accuracy_origin = count_origin * 1.0 / num_samples
        accuracy_transfer = count_transfer * 1.0 / num_samples
        accuracy_cycle = count_cycle * 1.0 / num_samples
        print('Accuracy of this classifier on test datasets is :',
              accuracy_origin, accuracy_transfer, accuracy_cycle)