コード例 #1
0
# Hyperparameters. They are only assigned here; the code that consumes them
# lies further down the script.
cell_dropout = .925
#noise_scale = 8.
prenet_units = 128
n_filts = 128
n_stacks = 3
enc_units = 128
dec_units = 512
emb_dim = 15
# seq_len is expected to be defined earlier in the script.
truncation_len = seq_len
cell_dropout_scale = cell_dropout
epsilon = 1E-8
forward_init = "truncated_normal"
rnn_init = "truncated_normal"

basedir = "/Tmp/kastner/lj_speech/LJSpeech-1.0/"
ljspeech = rsync_fetch(fetch_ljspeech, "leto01")

# The wav and json lists cannot be paired up positionally (some entries are
# missing); the iterator pairs them up by name.
wavfiles = ljspeech["wavfiles"]
jsonfiles = ljspeech["jsonfiles"]

# The model path is the first command-line argument.
model_path = sys.argv[1]

# Derive the seed from a stable digest of the model path. The built-in hash()
# of a str is salted per process on Python 3 (see PYTHONHASHSEED), so it would
# yield a different seed -- and therefore a different train/valid split -- on
# every run for the same model path.
import zlib
seed = (zlib.crc32(model_path.encode("utf-8")) & 0xffffffff) % (2**32 - 1)

# Both random states must start from the same seed so that the train and
# validation iterators agree on the split.
train_random_state = np.random.RandomState(seed)
valid_random_state = np.random.RandomState(seed)

train_itr = wavfile_caching_mel_tbptt_iterator(wavfiles,
                                               jsonfiles,
                                               batch_size,
コード例 #2
0
                    default=10,
                    type=int)
# Remaining command-line options; `parser` is created earlier in the script.
parser.add_argument("--output_mixtures", dest="output_mixtures", type=int,
                    default=20)
parser.add_argument("--lstm_layers", dest="lstm_layers", type=int, default=3)
parser.add_argument("--cell_dropout", dest="cell_dropout", type=float,
                    default=.9)
parser.add_argument("--units_per_layer", dest="units", type=int, default=400)
parser.add_argument("--restore", dest="restore", type=str, default=None)
args = parser.parse_args()

# Settings copied straight from the parsed arguments.
batch_size = args.batch_size
truncation_len = args.seq_len
cell_dropout_scale = args.cell_dropout
h_dim = args.units
epsilon = 1e-8

# Dataset: "data" entries are iterated together with their "target" entries,
# and the vocabulary length is passed as the one-hot size of the targets.
iamondb = rsync_fetch(fetch_iamondb, "leto01")
trace_data = iamondb["data"]
char_data = iamondb["target"]
vocabulary_size = len(iamondb["vocabulary"])

# Fixed seed for the iterator's random state.
itr_random_state = np.random.RandomState(2177)
itr = tbptt_list_iterator(
    trace_data,
    [char_data],
    batch_size,
    truncation_len,
    other_one_hot_size=[vocabulary_size],
    random_state=itr_random_state,
)
コード例 #3
0
def _bivariate_normal(x, y, sigmax, sigmay, mux, muy, sigmaxy):
    """Evaluate a bivariate Gaussian density at the points (x, y).

    Same formula and parameterisation as ``matplotlib.mlab.bivariate_normal``,
    which was removed in matplotlib 3.1. ``sigmaxy`` is the covariance
    ``rho * sigmax * sigmay``.
    """
    x_mu = x - mux
    y_mu = y - muy
    rho = sigmaxy / (sigmax * sigmay)
    z = (x_mu ** 2 / sigmax ** 2 + y_mu ** 2 / sigmay ** 2
         - 2 * rho * x_mu * y_mu / (sigmax * sigmay))
    denom = 2 * np.pi * sigmax * sigmay * np.sqrt(1 - rho ** 2)
    return np.exp(-z / (2 * (1 - rho ** 2))) / denom


def _density_grid(stroke_data, epsilon=1e-8):
    """Accumulate one peak-normalised Gaussian per sampled step on a 2-D grid.

    Columns 0-1 of ``stroke_data`` are cumulatively summed and used as the
    means, columns 2-3 as the standard deviations and column 4 as the
    correlation. The grid step is 1/400 of the x extent and is used for both
    axes.
    """
    strokes = np.array(stroke_data)
    strokes[:, :2] = np.cumsum(strokes[:, :2], axis=0)
    minx, maxx = np.min(strokes[:, 0]), np.max(strokes[:, 0])
    miny, maxy = np.min(strokes[:, 1]), np.max(strokes[:, 1])

    delta = abs(maxx - minx) / 400.
    x_grid, y_grid = np.meshgrid(np.arange(minx, maxx, delta),
                                 np.arange(miny, maxy, delta))
    z_grid = np.zeros_like(x_grid)
    for mux, muy, sigmax, sigmay, rho in strokes[:, :5]:
        gauss = _bivariate_normal(x_grid,
                                  y_grid,
                                  sigmax=sigmax,
                                  sigmay=sigmay,
                                  mux=mux,
                                  muy=muy,
                                  sigmaxy=rho * sigmax * sigmay)
        # Each Gaussian is divided by its own peak (epsilon avoids 0/0) and
        # weighted by (sigmax + sigmay) ** 0.4.
        z_grid += gauss * np.power(sigmax + sigmay,
                                   0.4) / (np.max(gauss) + epsilon)
    return z_grid


def _remove_old_plots(marker):
    """Delete every ``.png`` in the working directory whose name contains *marker*."""
    for fname in os.listdir("."):
        if marker in fname and fname.endswith(".png"):
            print("Removing old plot {}".format(fname))
            os.remove(fname)


def _plot_strokes(coords, color=None):
    """Draw each stroke of *coords* on the current axes, with the y axis flipped."""
    for stroke in split_strokes(cumsum(np.array(coords))):
        if color is not None:
            plt.plot(stroke[:, 0], -stroke[:, 1], color=color)
        else:
            plt.plot(stroke[:, 0], -stroke[:, 1])


def _hide_axes():
    """Hide both axis objects and the frame of the current axes."""
    # plt.gca() returns the axes that was just drawn on; on recent matplotlib
    # plt.axes() would add a fresh, empty axes instead.
    ax = plt.gca()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    plt.axis("off")


def _save_and_close(path):
    """Save the current figure to *path*, announce it, and close the figure."""
    print("Saving to {}".format(path))
    plt.savefig(path)
    plt.close()


def main():
    """Sample handwriting for ``args.text`` from a restored model and save plots.

    With ``args.info`` set, four PNGs are written to the working directory
    (density, handwriting, phi, window); otherwise a single handwriting PNG is
    written. Previously generated plots are removed first.

    Raises:
        ValueError: if ``args.text`` was not given.
    """
    # Fail fast: check the argument before fetching data or restoring a model.
    if args.text is None:
        raise ValueError("Must pass --text argument")
    args_text = args.text

    iamondb = rsync_fetch(fetch_iamondb, "leto01")
    translation = iamondb["vocabulary"]
    rev_translation = {v: k for k, v in translation.items()}

    charset = [rev_translation[i] for i in range(len(rev_translation))]
    assert translation["<NULL>"] == 0
    # just for display purposes - replace <NULL> with ''
    charset[translation["<NULL>"]] = ""

    # device_count={'GPU': 0} requests zero GPU devices for this session.
    config = tf.ConfigProto(device_count={'GPU': 0})
    with tf.Session(config=config) as sess:
        saver = tf.train.import_meta_graph(direct_model + '.meta')
        saver.restore(sess, direct_model)

        phi_data, window_data, _, stroke_data, coords = sample_text(
            sess, args_text, translation)

        if args.info:
            z_grid = _density_grid(stroke_data)

            _remove_old_plots("plot")

            # Shared file-name stem; the last "{}" is filled in per plot kind.
            new = "plot_{}".format(int(time.time()) % 10**5) + "_{}.png"

            plt.figure()
            plt.imshow(z_grid, interpolation="bilinear", cmap=cm.jet)
            _hide_axes()
            plt.title("Density")
            _save_and_close(new.format("density"))

            plt.figure()
            _plot_strokes(coords, args.color)
            plt.title(args.text)
            plt.gca().set_aspect('equal')
            _hide_axes()
            _save_and_close(new.format("handwriting"))

            plt.figure()
            phi_img = np.vstack(phi_data).T[::-1, :]
            plt.imshow(phi_img,
                       interpolation='nearest',
                       aspect='auto',
                       cmap=cm.jet)
            plt.yticks(np.arange(0, len(args_text) + 1))
            # Rows were reversed above, so the labels are reversed to match.
            plt.gca().set_yticklabels(list(' ' + args_text[::-1]),
                                      rotation='vertical',
                                      fontsize=8)
            plt.grid(False)
            plt.title('Phi')
            _save_and_close(new.format("phi"))

            plt.figure()
            window_img = np.vstack(window_data).T
            plt.imshow(window_img,
                       interpolation='nearest',
                       aspect='auto',
                       cmap=cm.jet)
            plt.yticks(np.arange(0, len(charset)))
            plt.gca().set_yticklabels(list(charset),
                                      rotation='vertical',
                                      fontsize=8)
            plt.grid(False)
            plt.title('Window')
            _save_and_close(new.format("window"))
        else:
            _, ax = plt.subplots(1, 1)
            _plot_strokes(coords, args.color)
            ax.set_title(args.text)
            ax.set_aspect('equal')

            _remove_old_plots("gen_plot")

            new = "gen_plot_{}.png".format(int(time.time()) % 10**5)
            print("Saving to {}".format(new))
            plt.savefig(new)