def train_vae(options_dict):
    """Train and save a VAE."""

    # PRELIMINARY

    print(datetime.now())

    # Output directory
    hasher = hashlib.md5(repr(sorted(options_dict.items())).encode("ascii"))
    hash_str = hasher.hexdigest()[:10]
    # model_dir = path.join(
    #     "models", path.split(options_dict["data_dir"])[-1] + "." +
    #     options_dict["train_tag"], options_dict["script"], hash_str
    #     )
    model_dir = path.join(
        "models",
        path.split(options_dict["data_dir"])[-1] + "." +
        options_dict["train_tag"], "vae_model", "70")
    options_dict_fn = path.join(model_dir, "options_dict.pkl")
    print("Model directory:", model_dir)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    print("Options:", options_dict)

    # Random seeds
    np.random.seed(options_dict["rnd_seed"])
    tf.set_random_seed(options_dict["rnd_seed"])

    # LOAD AND FORMAT DATA

    # Training data
    train_tag = options_dict["train_tag"]
    min_length = None
    if options_dict["train_tag"] == "rnd":
        min_length = options_dict["min_length"]
        train_tag = "all"
    npz_fn = path.join(options_dict["data_dir"], "train." + train_tag + ".npz")
    train_x, train_labels, train_lengths, train_keys = (
        data_io.load_data_from_npz(npz_fn, min_length))

    # Validation data
    if options_dict["use_test_for_val"]:
        npz_fn = path.join(options_dict["data_dir"], "test.npz")
    else:
        npz_fn = path.join(options_dict["data_dir"], "val.npz")
    val_x, val_labels, val_lengths, val_keys = (
        data_io.load_data_from_npz(npz_fn))

    # Truncate and limit dimensionality
    max_length = options_dict["max_length"]
    d_frame = 108  # None
    options_dict["n_input"] = d_frame
    print("Limiting dimensionality:", d_frame)
    print("Limiting length:", max_length)
    data_io.trunc_and_limit_dim(train_x, train_lengths, d_frame, max_length)
    data_io.trunc_and_limit_dim(val_x, val_lengths, d_frame, max_length)

    # DEFINE MODEL

    print(datetime.now())
    print("Building model")

    # Model filenames
    intermediate_model_fn = path.join(model_dir, "vae.tmp.ckpt")
    model_fn = path.join(model_dir, "vae.best_val.ckpt")

    # Model graph
    x = tf.placeholder(TF_DTYPE, [None, None, options_dict["n_input"]])
    x_lengths = tf.placeholder(TF_ITYPE, [None])
    network_dict = build_vae_from_options_dict(x, x_lengths, options_dict)
    encoder_states = network_dict["encoder_states"]
    vae = network_dict["latent_layer"]
    z_mean = vae["z_mean"]
    z_log_sigma_sq = vae["z_log_sigma_sq"]
    z = vae["z"]
    y = network_dict["decoder_output"]
    mask = network_dict["mask"]

    # VAE loss
    # reconstruction_loss = tf.reduce_mean(
    #     tf.reduce_sum(tf.reduce_mean(tf.square(x - y), -1), -1) /
    #     tf.reduce_sum(mask, 1)
    #     )  # https://danijar.com/variable-sequence-lengths-in-tensorflow/
    # loss = tflego.vae_loss_gaussian(
    #     x, y, options_dict["sigma_sq"], z_mean, z_log_sigma_sq,
    #     reconstruction_loss=reconstruction_loss
    #     )
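    # VAE loss: the length-normalised squared reconstruction error scaled by
    # 1/(2*sigma_sq) (a Gaussian decoder with fixed variance sigma_sq), plus
    # the KL divergence between the approximate posterior
    # N(z_mean, exp(z_log_sigma_sq)) and a standard normal prior.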
    reconstruction_loss = 1. / (2 * options_dict["sigma_sq"]) * tf.reduce_mean(
        tf.reduce_sum(tf.reduce_mean(tf.square(x - y), -1), -1) /
        tf.reduce_sum(mask, 1)
    )  # https://danijar.com/variable-sequence-lengths-in-tensorflow/
    regularisation_loss = -0.5 * tf.reduce_sum(
        1 + z_log_sigma_sq - tf.square(z_mean) - tf.exp(z_log_sigma_sq), 1)
    loss = reconstruction_loss + tf.reduce_mean(regularisation_loss)
    optimizer = tf.train.AdamOptimizer(
        learning_rate=options_dict["learning_rate"]).minimize(loss)

    # TRAIN AND VALIDATE

    print(datetime.now())
    print("Training model")

    # Validation function
    def samediff_val(normalise=False):
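        """Embed the validation set and evaluate same-different performance.

        The current model is restored from val_model_fn, the validation set is
        embedded using z_mean, and average precision (AP) is computed from
        cosine distances. Returns [precision-recall breakeven, -AP]; AP is
        negated so that smaller values indicate better models.
        """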
        # Embed validation
        np.random.seed(options_dict["rnd_seed"])
        val_batch_iterator = batching.SimpleIterator(val_x, len(val_x), False)
        labels = [val_labels[i] for i in val_batch_iterator.indices]
        saver = tf.train.Saver()
        with tf.Session() as session:
            saver.restore(session, val_model_fn)
            for batch_x_padded, batch_x_lengths in val_batch_iterator:
                np_x = batch_x_padded
                np_x_lengths = batch_x_lengths
                np_z = session.run([z_mean],
                                   feed_dict={
                                       x: np_x,
                                       x_lengths: np_x_lengths
                                   })[0]
                break  # single batch

        embed_dict = {}
        for i, utt_key in enumerate(
            [val_keys[i] for i in val_batch_iterator.indices]):
            embed_dict[utt_key] = np_z[i]

        # Same-different
        if normalise:
            np_z_normalised = (np_z - np_z.mean(axis=0)) / np_z.std(axis=0)
            distances = pdist(np_z_normalised, metric="cosine")
            matches = samediff.generate_matches_array(labels)
            ap, prb = samediff.average_precision(distances[matches == True],
                                                 distances[matches == False])
        else:
            distances = pdist(np_z, metric="cosine")
            matches = samediff.generate_matches_array(labels)
            ap, prb = samediff.average_precision(distances[matches == True],
                                                 distances[matches == False])
        return [prb, -ap]

    # Train VAE
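    # samediff_val() restores its weights from val_model_fn, so point it at
    # the intermediate checkpoint written during training.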
    val_model_fn = intermediate_model_fn
    if options_dict["train_tag"] == "rnd":
        train_batch_iterator = batching.RandomSegmentsIterator(
            train_x,
            options_dict["batch_size"],
            options_dict["n_buckets"],
            shuffle_every_epoch=True)
    else:
        train_batch_iterator = batching.SimpleBucketIterator(
            train_x,
            options_dict["batch_size"],
            options_dict["n_buckets"],
            shuffle_every_epoch=True)
    record_dict = training.train_fixed_epochs_external_val(
        options_dict["n_epochs"],
        optimizer,
        loss,
        train_batch_iterator, [x, x_lengths],
        samediff_val,
        save_model_fn=intermediate_model_fn,
        save_best_val_model_fn=model_fn,
        n_val_interval=options_dict["n_val_interval"])

    # Save record
    record_dict_fn = path.join(model_dir, "record_dict.pkl")
    print("Writing:", record_dict_fn)
    with open(record_dict_fn, "wb") as f:
        pickle.dump(record_dict, f, -1)

    # Save options_dict
    options_dict_fn = path.join(model_dir, "options_dict.pkl")
    print("Writing:" + options_dict_fn)
    with open(options_dict_fn, "wb") as f:
        pickle.dump(options_dict, f, -1)

    # FINAL EXTRINSIC EVALUATION

    print("Performing final validation")
    if options_dict["extrinsic_usefinal"]:
        val_model_fn = intermediate_model_fn
    else:
        val_model_fn = model_fn
    prb, ap = samediff_val(normalise=False)
    ap = -ap
    prb_normalised, ap_normalised = samediff_val(normalise=True)
    ap_normalised = -ap_normalised
    print("Validation AP:", ap)
    print("Validation AP with normalisation:", ap_normalised)
    ap_fn = path.join(model_dir, "val_ap.txt")
    print("Writing:", ap_fn)
    with open(ap_fn, "w") as f:
        f.write(str(ap) + "\n")
        f.write(str(ap_normalised) + "\n")
    print("Validation model:", val_model_fn)

    print(datetime.now())
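
# A minimal sketch of the options read directly by train_vae() above. The
# values are illustrative assumptions only, and build_vae_from_options_dict()
# will additionally need whatever architecture options it consumes (not shown):
#
# example_options_dict = {
#     "data_dir": "data/mfcc",      # hypothetical feature directory
#     "train_tag": "gt",            # "rnd" switches to random-segment training
#     "min_length": 50,             # only used when train_tag is "rnd"
#     "max_length": 100,
#     "use_test_for_val": False,
#     "rnd_seed": 1,
#     "sigma_sq": 0.001,
#     "learning_rate": 0.001,
#     "n_epochs": 25,
#     "batch_size": 300,
#     "n_buckets": 3,
#     "n_val_interval": 1,
#     "extrinsic_usefinal": False,
#     }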
Example #2
def train_cae(options_dict):
    """Train and save a CAE."""

    # PRELIMINARY
    assert (options_dict["train_tag"] != "rnd") or \
        (options_dict["cae_n_epochs"] == 0), \
        "random segment training only possible with AE (cae_n_epochs=0)"
    print(datetime.now())

    # Output directory
    hasher = hashlib.md5(repr(sorted(options_dict.items())).encode("ascii"))
    # hash_str = (
    #     datetime.now().strftime("%y%m%d.%Hh%M") + "." +
    #     # datetime.now().strftime("%y%m%d.%Hh%Mm%Ss") + "." +
    #     hasher.hexdigest()[:5]
    #     )
    hash_str = hasher.hexdigest()[:10]
    model_dir = path.join(
        "models", path.split(options_dict["data_dir"])[-1] + "." +
        options_dict["train_tag"], "cvae_model", "70"
        )
    options_dict_fn = path.join(model_dir, "options_dict.pkl")
    print("Model directory:", model_dir)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    print("Options:", options_dict)

    # Random seeds
    # random.seed(options_dict["rnd_seed"])
    np.random.seed(options_dict["rnd_seed"])
    tf.set_random_seed(options_dict["rnd_seed"])


    # LOAD AND FORMAT DATA

    # Training data
    train_tag = options_dict["train_tag"]
    min_length = None
    if options_dict["train_tag"] == "rnd":
        min_length = options_dict["min_length"]
        train_tag = "all"
    npz_fn = path.join(
        options_dict["data_dir"], "train." + train_tag + ".npz"
        )
    # Speaker labels are not loaded here; re-enable the commented line below
    # if speaker embeddings are used:
    # train_x, train_labels, train_lengths, train_keys, train_speakers = (
    #     data_io.load_data_from_npz(npz_fn, min_length)
    #     )
    train_x, train_labels, train_lengths, train_keys = (
        data_io.load_data_from_npz(npz_fn, min_length)
        )
    # Pretraining data (if specified)
    pretrain_tag = options_dict["pretrain_tag"]
    if options_dict["pretrain_tag"] is not None:
        min_length = None
        if options_dict["pretrain_tag"] == "rnd":
            min_length = options_dict["min_length"]
            pretrain_tag = "all"
        npz_fn = path.join(
            options_dict["data_dir"], "train." + pretrain_tag + ".npz"
            )
        (pretrain_x, pretrain_labels, pretrain_lengths,
            pretrain_keys) = data_io.load_data_from_npz(npz_fn, min_length)

    # Validation data
    if options_dict["use_test_for_val"]:
        npz_fn = path.join(options_dict["data_dir"], "test.npz")
    else:
        npz_fn = path.join(options_dict["data_dir"], "val.npz")
    val_x, val_labels, val_lengths, val_keys = (
        data_io.load_data_from_npz(npz_fn)
        )

    # Convert training speakers, if speaker embeddings
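    # Note: train_speakers is only available if the speaker-loading line above
    # is re-enabled, so this block assumes d_speaker_embedding is None with
    # the current loader.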
    if options_dict["d_speaker_embedding"] is not None:
        train_speaker_set = set(train_speakers)
        speaker_to_id = {}
        id_to_speaker = {}
        for i, speaker in enumerate(sorted(list(train_speaker_set))):
            speaker_to_id[speaker] = i
            id_to_speaker[i] = speaker
        train_speaker_ids = []
        for speaker in train_speakers:
            train_speaker_ids.append(speaker_to_id[speaker])
        train_speaker_ids = np.array(train_speaker_ids, dtype=NP_ITYPE)
        options_dict["n_speakers"] = max(speaker_to_id.values()) + 1

    # Truncate and limit dimensionality
    max_length = options_dict["max_length"]
    d_frame = 108  # originally 13
    options_dict["n_input"] = d_frame
    print("Limiting dimensionality:", d_frame)
    print("Limiting length:", max_length)
    data_io.trunc_and_limit_dim(train_x, train_lengths, d_frame, max_length)
    if options_dict["pretrain_tag"] is not None:
        data_io.trunc_and_limit_dim(
            pretrain_x, pretrain_lengths, d_frame, max_length
            )
    data_io.trunc_and_limit_dim(val_x, val_lengths, d_frame, max_length)

    # Get pairs
    pair_list = batching.get_pair_list(train_labels)
    # pair_list = batching.get_pair_list(train_labels, both_directions=False)
    print("No. pairs:", int(len(pair_list)/2.0))  # pairs in both directions
    # print("No. pairs:", len(pair_list))

    # DEFINE MODEL

    print(datetime.now())
    print("Building model")

    # Model filenames
    pretrain_intermediate_model_fn = path.join(model_dir, "ae.tmp.ckpt")
    pretrain_model_fn = path.join(model_dir, "ae.best_val.ckpt")
    intermediate_model_fn = path.join(model_dir, "cvae.tmp.ckpt")
    model_fn = path.join(model_dir, "cvae.best_val.ckpt")

    # Model graph
    a = tf.placeholder(TF_DTYPE, [None, None, options_dict["n_input"]])
    a_lengths = tf.placeholder(TF_ITYPE, [None])
    b = tf.placeholder(TF_DTYPE, [None, None, options_dict["n_input"]])
    b_lengths = tf.placeholder(TF_ITYPE, [None])
    network_dict = build_cae_from_options_dict(
        a, a_lengths, b_lengths, options_dict
        )

    z_mean = network_dict["z_mean"]
    z_log_sigma_sq = network_dict["z_log_sigma_sq"]
    z = network_dict["z"]
    mask = network_dict["mask"]
    y = network_dict["y"]  # y: [n_data, n_sample, max_length, d_frame]
    if options_dict["d_speaker_embedding"] is not None:
        speaker_id = network_dict["speaker_id"]

    # CVAE loss. The decoder output y has shape
    # [n_data, n_sample, max_length, d_frame]: the squared error is averaged
    # over the feature dimension, summed over time, and divided by the masked
    # length; the minimum over the sample dimension is kept and scaled by
    # 1/(2*sigma_sq). The regularisation term is the standard Gaussian KL.
    b_cvae = tf.expand_dims(b, 1)
    temp = (
        tf.reduce_sum(tf.reduce_mean(tf.square(b_cvae - y), -1), -1) /
        tf.reduce_sum(mask, -1)
        )
    temp = tf.reduce_min(temp, -1)
    reconstruction_loss = 1./(2*options_dict["sigma_sq"])*tf.reduce_mean(temp)
    # https://danijar.com/variable-sequence-lengths-in-tensorflow/
    regularisation_loss = -0.5*tf.reduce_sum(
        1 + z_log_sigma_sq - tf.square(z_mean) - tf.exp(z_log_sigma_sq), 1
        )
    loss = reconstruction_loss + tf.reduce_mean(regularisation_loss)

    optimizer = tf.train.AdamOptimizer(
        learning_rate=options_dict["learning_rate"]
        ).minimize(loss)


    # The AE (pretraining) loss differs from the CVAE loss above: average the
    # decoder output and the mask over the sample dimension, then use the
    # masked, length-normalised reconstruction error against b.
    y_ae = tf.reduce_mean(y, 1)
    mask_ae = tf.reduce_mean(mask, 1)
    loss_ae = tf.reduce_mean(
        tf.reduce_sum(tf.reduce_mean(tf.square(b - y_ae), -1), -1) /
        tf.reduce_sum(mask_ae, 1)
        )  # https://danijar.com/variable-sequence-lengths-in-tensorflow/
    optimizer_ae = tf.train.AdamOptimizer(
        learning_rate=options_dict["learning_rate"]
        ).minimize(loss_ae)

    # AUTOENCODER PRETRAINING: TRAIN AND VALIDATE

    print(datetime.now())
    print("Pretraining model")

    # Validation function
    def samediff_val(normalise=True):
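        """Embed the validation set with z_mean and return the same-different
        precision-recall breakeven and negated average precision."""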
        # Embed validation
        np.random.seed(options_dict["rnd_seed"])
        val_batch_iterator = batching.SimpleIterator(val_x, len(val_x), False)
        labels = [val_labels[i] for i in val_batch_iterator.indices]
        saver = tf.train.Saver()
        with tf.Session() as session:
            saver.restore(session, val_model_fn)
            for batch_x_padded, batch_x_lengths in val_batch_iterator:
                np_x = batch_x_padded
                np_x_lengths = batch_x_lengths
                # np_z = session.run(
                    # [z], feed_dict={a: np_x, a_lengths: np_x_lengths}
                    # )[0]
                np_z = session.run(
                    [z_mean], feed_dict={a: np_x, a_lengths: np_x_lengths}
                    )[0]
                break  # single batch

        embed_dict = {}
        for i, utt_key in enumerate(
                [val_keys[i] for i in val_batch_iterator.indices]):
            embed_dict[utt_key] = np_z[i]

        # Same-different
        if normalise:
            np_z_normalised = (np_z - np_z.mean(axis=0))/np_z.std(axis=0)
            distances = pdist(np_z_normalised, metric="cosine")
            matches = samediff.generate_matches_array(labels)
            ap, prb = samediff.average_precision(
                distances[matches == True], distances[matches == False]
                )
        else:
            distances = pdist(np_z, metric="cosine")
            matches = samediff.generate_matches_array(labels)
            ap, prb = samediff.average_precision(
                distances[matches == True], distances[matches == False]
                )    
        return [prb, -ap]

    # Train AE
    val_model_fn = pretrain_intermediate_model_fn
    if options_dict["pretrain_tag"] is not None:
        if options_dict["pretrain_tag"] == "rnd":
            train_batch_iterator = batching.RandomSegmentsIterator(
                pretrain_x, options_dict["ae_batch_size"],
                options_dict["ae_n_buckets"], shuffle_every_epoch=True,
                paired=True
                )
        else:
            train_batch_iterator = batching.PairedBucketIterator(
                pretrain_x, [(i, i) for i in range(len(pretrain_x))],
                options_dict["ae_batch_size"], options_dict["ae_n_buckets"],
                shuffle_every_epoch=True, speaker_ids=None if
                options_dict["d_speaker_embedding"] is None else
                train_speaker_ids
                )
    else:
        if options_dict["train_tag"] == "rnd":
            train_batch_iterator = batching.RandomSegmentsIterator(
                train_x, options_dict["ae_batch_size"],
                options_dict["ae_n_buckets"], shuffle_every_epoch=True,
                paired=True
                )
        else:
            train_batch_iterator = batching.PairedBucketIterator(
                train_x, [(i, i) for i in range(len(train_x))],
                options_dict["ae_batch_size"], options_dict["ae_n_buckets"],
                shuffle_every_epoch=True, speaker_ids=None if
                options_dict["d_speaker_embedding"] is None else
                train_speaker_ids
                )
    if options_dict["d_speaker_embedding"] is None:
        ae_record_dict = training.train_fixed_epochs_external_val(
            options_dict["ae_n_epochs"], optimizer_ae, loss_ae, train_batch_iterator,
            [a, a_lengths, b, b_lengths], samediff_val,
            save_model_fn=pretrain_intermediate_model_fn,
            save_best_val_model_fn=pretrain_model_fn,
            n_val_interval=options_dict["ae_n_val_interval"]
            )
    else:
        ae_record_dict = training.train_fixed_epochs_external_val(
            options_dict["ae_n_epochs"], optimizer_ae, loss_ae, train_batch_iterator,
            [a, a_lengths, b, b_lengths, speaker_id], samediff_val,
            save_model_fn=pretrain_intermediate_model_fn,
            save_best_val_model_fn=pretrain_model_fn,
            n_val_interval=options_dict["ae_n_val_interval"]
            )


    # CORRESPONDENCE TRAINING: TRAIN AND VALIDATE

    if options_dict["cae_n_epochs"] > 0:
        print("Training model")

        cae_pretrain_model_fn = pretrain_model_fn
        if options_dict["pretrain_usefinal"]:
            cae_pretrain_model_fn = pretrain_intermediate_model_fn
        if options_dict["ae_n_epochs"] == 0:
            cae_pretrain_model_fn = None

        # Train CAE
        val_model_fn = intermediate_model_fn
        train_batch_iterator = batching.PairedBucketIterator(
            train_x, pair_list, batch_size=options_dict["cae_batch_size"],
            n_buckets=options_dict["cae_n_buckets"], shuffle_every_epoch=True,
            speaker_ids=None if options_dict["d_speaker_embedding"] is None
            else train_speaker_ids
            )
        # Alternative: pair every item with itself (plain AE-style training):
        # train_batch_iterator = batching.PairedBucketIterator(
        #     train_x, [(i, i) for i in range(len(train_x))],
        #     batch_size=options_dict["cae_batch_size"],
        #     n_buckets=options_dict["cae_n_buckets"],
        #     shuffle_every_epoch=True,
        #     speaker_ids=None if options_dict["d_speaker_embedding"] is None
        #     else train_speaker_ids
        #     )

        if options_dict["d_speaker_embedding"] is None:
            cae_record_dict = training.train_fixed_epochs_external_val(
                options_dict["cae_n_epochs"], optimizer, loss,
                train_batch_iterator, [a, a_lengths, b, b_lengths],
                samediff_val, save_model_fn=intermediate_model_fn,
                save_best_val_model_fn=model_fn,
                n_val_interval=options_dict["cae_n_val_interval"],
                # load_model_fn=cae_pretrain_model_fn
                load_model_fn=path.join(
                    "models", path.split(options_dict["data_dir"])[-1] + "." +
                    options_dict["train_tag"], "pretrain_ae", "50",
                    "ae.best_val.ckpt"
                    )
                )
        else:
            cae_record_dict = training.train_fixed_epochs_external_val(
                options_dict["cae_n_epochs"], optimizer, loss,
                train_batch_iterator, [a, a_lengths, b, b_lengths, speaker_id],
                samediff_val, save_model_fn=intermediate_model_fn,
                save_best_val_model_fn=model_fn,
                n_val_interval=options_dict["cae_n_val_interval"],
                load_model_fn=cae_pretrain_model_fn
                )

    # Save record
    record_dict_fn = path.join(model_dir, "record_dict.pkl")
    print("Writing:", record_dict_fn)
    with open(record_dict_fn, "wb") as f:
        pickle.dump(ae_record_dict, f, -1)
        if options_dict["cae_n_epochs"] > 0:
            pickle.dump(cae_record_dict, f, -1)

    # Save options_dict
    options_dict_fn = path.join(model_dir, "options_dict.pkl")
    print("Writing:", options_dict_fn)
    with open(options_dict_fn, "wb") as f:
        pickle.dump(options_dict, f, -1)


    # FINAL EXTRINSIC EVALUATION

    print("Performing final validation")
    if options_dict["cae_n_epochs"] == 0:
        if options_dict["extrinsic_usefinal"]:
            val_model_fn = pretrain_intermediate_model_fn
        else:
            val_model_fn = pretrain_model_fn
    else:
        if options_dict["extrinsic_usefinal"]:
            val_model_fn = intermediate_model_fn
        else:
            val_model_fn = model_fn
    prb, ap = samediff_val(normalise=False)
    ap = -ap
    prb_normalised, ap_normalised = samediff_val(normalise=True)
    ap_normalised = -ap_normalised
    print("Validation AP:", ap)
    print("Validation AP with normalisation:", ap_normalised)
    ap_fn = path.join(model_dir, "val_ap.txt")
    print("Writing:", ap_fn)
    with open(ap_fn, "w") as f:
        f.write(str(ap) + "\n")
        f.write(str(ap_normalised) + "\n")
    print("Validation model:", val_model_fn)

    print(datetime.now())
Example #3
def train_cae(options_dict):
    """Train and save a CAE."""

    # PRELIMINARY
    assert (options_dict["train_tag"] != "rnd") or \
        (options_dict["cae_n_epochs"] == 0), \
        "random segment training only possible with AE (cae_n_epochs=0)"
    print(datetime.now())

    # Output directory
    hasher = hashlib.md5(repr(sorted(options_dict.items())).encode("ascii"))
    hash_str = hasher.hexdigest()[:10]
    model_dir = path.join(
        "models", options_dict["train_lang"] + "." + options_dict["train_tag"],
        options_dict["script"], hash_str)
    options_dict_fn = path.join(model_dir, "options_dict.pkl")
    print("Model directory:", model_dir)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    print("Options:", options_dict)

    # Random seeds
    random.seed(options_dict["rnd_seed"])
    np.random.seed(options_dict["rnd_seed"])
    tf.set_random_seed(options_dict["rnd_seed"])

    # LOAD AND FORMAT DATA

    # Training data
    train_tag = options_dict["train_tag"]
    min_length = None
    if options_dict["train_tag"] == "rnd":
        min_length = options_dict["min_length"]
        train_tag = "all"
    if "+" in options_dict["train_lang"]:
        train_x = []
        train_labels = []
        train_lengths = []
        train_keys = []
        train_speakers = []
        train_languages = []
        for cur_lang in options_dict["train_lang"].split("+"):
            cur_npz_fn = path.join("data", cur_lang,
                                   "train." + train_tag + ".npz")
            (cur_train_x, cur_train_labels, cur_train_lengths, cur_train_keys,
             cur_train_speakers) = data_io.load_data_from_npz(
                 cur_npz_fn, min_length)
            (cur_train_x, cur_train_labels, cur_train_lengths, cur_train_keys,
             cur_train_speakers) = data_io.filter_data(
                 cur_train_x,
                 cur_train_labels,
                 cur_train_lengths,
                 cur_train_keys,
                 cur_train_speakers,
                 n_min_tokens_per_type=options_dict["n_min_tokens_per_type"],
                 n_max_types=options_dict["n_max_types"],
                 n_max_tokens=options_dict["n_max_tokens"],
                 n_max_tokens_per_type=options_dict["n_max_tokens_per_type"],
             )
            train_x.extend(cur_train_x)
            train_labels.extend(cur_train_labels)
            train_lengths.extend(cur_train_lengths)
            train_keys.extend(cur_train_keys)
            train_speakers.extend(cur_train_speakers)
            train_languages.extend([cur_lang] * len(cur_train_speakers))
        print("Total no. items:", len(train_labels))
    else:
        npz_fn = path.join("data", options_dict["train_lang"],
                           "train." + train_tag + ".npz")
        train_x, train_labels, train_lengths, train_keys, train_speakers = (
            data_io.load_data_from_npz(npz_fn, min_length))
        train_x, train_labels, train_lengths, train_keys, train_speakers = (
            data_io.filter_data(
                train_x,
                train_labels,
                train_lengths,
                train_keys,
                train_speakers,
                n_min_tokens_per_type=options_dict["n_min_tokens_per_type"],
                n_max_types=options_dict["n_max_types"],
                n_max_tokens=options_dict["n_max_tokens"],
                n_max_tokens_per_type=options_dict["n_max_tokens_per_type"],
            ))

    # Pretraining data (if specified)
    pretrain_tag = options_dict["pretrain_tag"]
    if options_dict["pretrain_tag"] is not None:
        min_length = None
        if options_dict["pretrain_tag"] == "rnd":
            min_length = options_dict["min_length"]
            pretrain_tag = "all"
        npz_fn = path.join("data", options_dict["train_lang"],
                           "train." + pretrain_tag + ".npz")
        (pretrain_x, pretrain_labels, pretrain_lengths, pretrain_keys,
         pretrain_speakers) = data_io.load_data_from_npz(npz_fn, min_length)

    # Validation data
    if options_dict["val_lang"] is not None:
        npz_fn = path.join("data", options_dict["val_lang"], "val.npz")
        val_x, val_labels, val_lengths, val_keys, val_speakers = (
            data_io.load_data_from_npz(npz_fn))

    # # Convert training speakers, if speaker embeddings
    # # To-do: Untested
    # if options_dict["d_speaker_embedding"] is not None:
    #     train_speaker_set = set(train_speakers)
    #     speaker_to_id = {}
    #     id_to_speaker = {}
    #     for i, speaker in enumerate(sorted(list(train_speaker_set))):
    #         speaker_to_id[speaker] = i
    #         id_to_speaker[i] = speaker
    #     train_speaker_ids = []
    #     for speaker in train_speakers:
    #         train_speaker_ids.append(speaker_to_id[speaker])
    #     train_speaker_ids = np.array(train_speaker_ids, dtype=NP_ITYPE)
    #     options_dict["n_speakers"] = max(speaker_to_id.values()) + 1

    # Convert training languages to integers, if language embeddings
    if options_dict["d_language_embedding"] is not None:
        train_language_set = set(train_languages)
        language_to_id = {}
        id_to_lang = {}
        for i, lang in enumerate(sorted(list(train_language_set))):
            language_to_id[lang] = i
            id_to_lang[i] = lang
        train_language_ids = []
        for lang in train_languages:
            train_language_ids.append(language_to_id[lang])
        train_language_ids = np.array(train_language_ids, dtype=NP_ITYPE)
        options_dict["n_languages"] = max(language_to_id.values()) + 1

    # Truncate and limit dimensionality
    max_length = options_dict["max_length"]
    d_frame = 13  # None
    options_dict["n_input"] = d_frame
    print("Limiting dimensionality:", d_frame)
    print("Limiting length:", max_length)
    data_io.trunc_and_limit_dim(train_x, train_lengths, d_frame, max_length)
    if options_dict["pretrain_tag"] is not None:
        data_io.trunc_and_limit_dim(pretrain_x, pretrain_lengths, d_frame,
                                    max_length)
    if options_dict["val_lang"] is not None:
        data_io.trunc_and_limit_dim(val_x, val_lengths, d_frame, max_length)

    # Get pairs
    pair_list = batching.get_pair_list(train_labels,
                                       both_directions=True,
                                       n_max_pairs=options_dict["n_max_pairs"])
    print("No. pairs:", int(len(pair_list) / 2.0))  # pairs in both directions

    # DEFINE MODEL

    print(datetime.now())
    print("Building model")

    # Model filenames
    pretrain_intermediate_model_fn = path.join(model_dir, "ae.tmp.ckpt")
    pretrain_model_fn = path.join(model_dir, "ae.best_val.ckpt")
    intermediate_model_fn = path.join(model_dir, "cae.tmp.ckpt")
    model_fn = path.join(model_dir, "cae.best_val.ckpt")

    # Model graph
    a = tf.placeholder(TF_DTYPE, [None, None, options_dict["n_input"]])
    a_lengths = tf.placeholder(TF_ITYPE, [None])
    b = tf.placeholder(TF_DTYPE, [None, None, options_dict["n_input"]])
    b_lengths = tf.placeholder(TF_ITYPE, [None])
    network_dict = build_cae_from_options_dict(a, a_lengths, b_lengths,
                                               options_dict)
    mask = network_dict["mask"]
    z = network_dict["z"]
    y = network_dict["y"]
    # if options_dict["d_speaker_embedding"] is not None:
    #     speaker_id = network_dict["speaker_id"]
    if options_dict["d_language_embedding"] is not None:
        language_id = network_dict["language_id"]

    # Reconstruction loss
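    # The squared error is averaged over the feature dimension, summed over
    # time, and divided by the true sequence length obtained from the mask.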
    loss = tf.reduce_mean(
        tf.reduce_sum(tf.reduce_mean(tf.square(b - y), -1), -1) /
        tf.reduce_sum(mask, 1)
    )  # https://danijar.com/variable-sequence-lengths-in-tensorflow/
    optimizer = tf.train.AdamOptimizer(
        learning_rate=options_dict["learning_rate"]).minimize(loss)

    # Save options_dict
    options_dict_fn = path.join(model_dir, "options_dict.pkl")
    print("Writing:", options_dict_fn)
    with open(options_dict_fn, "wb") as f:
        pickle.dump(options_dict, f, -1)

    # AUTOENCODER PRETRAINING: TRAIN AND VALIDATE

    print(datetime.now())
    print("Pretraining model")

    # Validation function
    def samediff_val(normalise=True):
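        """Embed the validation set and return SWDP same-different scores.

        Embeddings (z) are compared with cosine distances, and
        samediff.average_precision_swdp splits the same-word pairs by speaker.
        The returned values are the same-word different-speaker (SWDP)
        precision-recall breakeven and the negated SWDP average precision.
        """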
        # Embed validation
        np.random.seed(options_dict["rnd_seed"])
        val_batch_iterator = batching.SimpleIterator(val_x, len(val_x), False)
        labels = [val_labels[i] for i in val_batch_iterator.indices]
        speakers = [val_speakers[i] for i in val_batch_iterator.indices]
        saver = tf.train.Saver()
        with tf.Session() as session:
            saver.restore(session, val_model_fn)
            for batch_x_padded, batch_x_lengths in val_batch_iterator:
                np_x = batch_x_padded
                np_x_lengths = batch_x_lengths
                np_z = session.run([z],
                                   feed_dict={
                                       a: np_x,
                                       a_lengths: np_x_lengths
                                   })[0]
                break  # single batch

        embed_dict = {}
        for i, utt_key in enumerate(
            [val_keys[i] for i in val_batch_iterator.indices]):
            embed_dict[utt_key] = np_z[i]

        # Same-different
        if normalise:
            np_z_normalised = (np_z - np_z.mean(axis=0)) / np_z.std(axis=0)
            distances = pdist(np_z_normalised, metric="cosine")
        else:
            distances = pdist(np_z, metric="cosine")
        # matches = samediff.generate_matches_array(labels)
        # ap, prb = samediff.average_precision(
        #     distances[matches == True], distances[matches == False]
        #     )
        word_matches = samediff.generate_matches_array(labels)
        speaker_matches = samediff.generate_matches_array(speakers)
        sw_ap, sw_prb, swdp_ap, swdp_prb = samediff.average_precision_swdp(
            distances[np.logical_and(word_matches, speaker_matches)],
            distances[np.logical_and(word_matches, speaker_matches == False)],
            distances[word_matches == False])
        # return [sw_prb, -sw_ap, swdp_prb, -swdp_ap]
        return [swdp_prb, -swdp_ap]

    # Train AE
    val_model_fn = pretrain_intermediate_model_fn
    if options_dict["pretrain_tag"] is not None:
        if options_dict["pretrain_tag"] == "rnd":
            train_batch_iterator = batching.RandomSegmentsIterator(
                pretrain_x,
                options_dict["ae_batch_size"],
                options_dict["ae_n_buckets"],
                shuffle_every_epoch=True,
                paired=True)
        else:
            train_batch_iterator = batching.PairedBucketIterator(
                pretrain_x, [(i, i) for i in range(len(pretrain_x))],
                options_dict["ae_batch_size"],
                options_dict["ae_n_buckets"],
                shuffle_every_epoch=True,
                language_ids=None
                if options_dict["d_language_embedding"] is None else
                train_language_ids,
                flip_output=options_dict["flip_output"])
    else:
        if options_dict["train_tag"] == "rnd":
            train_batch_iterator = batching.RandomSegmentsIterator(
                train_x,
                options_dict["ae_batch_size"],
                options_dict["ae_n_buckets"],
                shuffle_every_epoch=True,
                paired=True)
        else:
            train_batch_iterator = batching.PairedBucketIterator(
                train_x, [(i, i) for i in range(len(train_x))],
                options_dict["ae_batch_size"],
                options_dict["ae_n_buckets"],
                shuffle_every_epoch=True,
                language_ids=None
                if options_dict["d_language_embedding"] is None else
                train_language_ids,
                flip_output=options_dict["flip_output"])
    if options_dict["d_language_embedding"] is None:
        if options_dict["val_lang"] is None:
            ae_record_dict = training.train_fixed_epochs(
                options_dict["ae_n_epochs"],
                optimizer,
                loss,
                train_batch_iterator, [a, a_lengths, b, b_lengths],
                save_model_fn=pretrain_intermediate_model_fn)
        else:
            ae_record_dict = training.train_fixed_epochs_external_val(
                options_dict["ae_n_epochs"],
                optimizer,
                loss,
                train_batch_iterator, [a, a_lengths, b, b_lengths],
                samediff_val,
                save_model_fn=pretrain_intermediate_model_fn,
                save_best_val_model_fn=pretrain_model_fn,
                n_val_interval=options_dict["ae_n_val_interval"])
    else:
        if options_dict["val_lang"] is None:
            ae_record_dict = training.train_fixed_epochs(
                options_dict["ae_n_epochs"],
                optimizer,
                loss,
                train_batch_iterator,
                [a, a_lengths, b, b_lengths, language_id],
                save_model_fn=pretrain_intermediate_model_fn)
        else:
            ae_record_dict = training.train_fixed_epochs_external_val(
                options_dict["ae_n_epochs"],
                optimizer,
                loss,
                train_batch_iterator,
                [a, a_lengths, b, b_lengths, language_id],
                samediff_val,
                save_model_fn=pretrain_intermediate_model_fn,
                save_best_val_model_fn=pretrain_model_fn,
                n_val_interval=options_dict["ae_n_val_interval"])

    # CORRESPONDENCE TRAINING: TRAIN AND VALIDATE

    if options_dict["cae_n_epochs"] > 0:
        print("Training model")

        cae_pretrain_model_fn = pretrain_model_fn
        if (options_dict["pretrain_usefinal"]
                or options_dict["val_lang"] is None):
            cae_pretrain_model_fn = pretrain_intermediate_model_fn
        if options_dict["ae_n_epochs"] == 0:
            cae_pretrain_model_fn = None

        # Train CAE
        val_model_fn = intermediate_model_fn
        train_batch_iterator = batching.PairedBucketIterator(
            train_x,
            pair_list,
            batch_size=options_dict["cae_batch_size"],
            n_buckets=options_dict["cae_n_buckets"],
            shuffle_every_epoch=True,
            language_ids=None if options_dict["d_language_embedding"] is None
            else train_language_ids,
            flip_output=options_dict["flip_output"])
        if options_dict["d_language_embedding"] is None:
            if options_dict["val_lang"] is None:
                cae_record_dict = training.train_fixed_epochs(
                    options_dict["cae_n_epochs"],
                    optimizer,
                    loss,
                    train_batch_iterator, [a, a_lengths, b, b_lengths],
                    save_model_fn=intermediate_model_fn,
                    load_model_fn=cae_pretrain_model_fn)
            else:
                cae_record_dict = training.train_fixed_epochs_external_val(
                    options_dict["cae_n_epochs"],
                    optimizer,
                    loss,
                    train_batch_iterator, [a, a_lengths, b, b_lengths],
                    samediff_val,
                    save_model_fn=intermediate_model_fn,
                    save_best_val_model_fn=model_fn,
                    n_val_interval=options_dict["cae_n_val_interval"],
                    load_model_fn=cae_pretrain_model_fn)
        else:
            if options_dict["val_lang"] is None:
                cae_record_dict = training.train_fixed_epochs(
                    options_dict["cae_n_epochs"],
                    optimizer,
                    loss,
                    train_batch_iterator,
                    [a, a_lengths, b, b_lengths, language_id],
                    save_model_fn=intermediate_model_fn,
                    load_model_fn=cae_pretrain_model_fn)
            else:
                cae_record_dict = training.train_fixed_epochs_external_val(
                    options_dict["cae_n_epochs"],
                    optimizer,
                    loss,
                    train_batch_iterator,
                    [a, a_lengths, b, b_lengths, language_id],
                    samediff_val,
                    save_model_fn=intermediate_model_fn,
                    save_best_val_model_fn=model_fn,
                    n_val_interval=options_dict["cae_n_val_interval"],
                    load_model_fn=cae_pretrain_model_fn)

    # Save record
    record_dict_fn = path.join(model_dir, "record_dict.pkl")
    print("Writing:", record_dict_fn)
    with open(record_dict_fn, "wb") as f:
        pickle.dump(ae_record_dict, f, -1)
        if options_dict["cae_n_epochs"] > 0:
            pickle.dump(cae_record_dict, f, -1)

    # FINAL EXTRINSIC EVALUATION

    if options_dict["val_lang"] is not None:
        print("Performing final validation")
        if options_dict["cae_n_epochs"] == 0:
            if options_dict["extrinsic_usefinal"]:
                val_model_fn = pretrain_intermediate_model_fn
            else:
                val_model_fn = pretrain_model_fn
        else:
            if options_dict["extrinsic_usefinal"]:
                val_model_fn = intermediate_model_fn
            else:
                val_model_fn = model_fn
        # sw_prb, sw_ap, swdp_prb, swdp_ap = samediff_val(normalise=False)
        swdp_prb, swdp_ap = samediff_val(normalise=False)
        # sw_ap = -sw_ap
        swdp_ap = -swdp_ap
        swdp_prb_normalised, swdp_ap_normalised = samediff_val(normalise=True)
        # sw_ap_normalised = -sw_ap_normalised
        swdp_ap_normalised = -swdp_ap_normalised
        print("Validation SWDP AP:", swdp_ap)
        print("Validation SWDP AP with normalisation:", swdp_ap_normalised)
        ap_fn = path.join(model_dir, "val_ap.txt")
        print("Writing:", ap_fn)
        with open(ap_fn, "w") as f:
            f.write(str(swdp_ap) + "\n")
            f.write(str(swdp_ap_normalised) + "\n")
        print("Validation model:", val_model_fn)

    print(datetime.now())