def main():
    # --------------------------------------------------------------------------
    # Parse script args and handle options:
    # --------------------------------------------------------------------------
    ARGS = check_arguments()

    # Set numpy and tensorflow random seeds
    np.random.seed(ARGS.random_seed)
    tf.set_random_seed(ARGS.random_seed)

    # Get specified model directory (default cwd)
    model_dir = ARGS.model_dir
    # Check if not using a previous run, and create a unique run directory
    if not os.path.exists(os.path.join(model_dir, LOG_FILENAME)):
        if not ARGS.no_unique_dir:
            unique_dir = "{}_{}_{}".format(
                'speech', ARGS.model_version,
                datetime.datetime.now().strftime("%y%m%d_%Hh%Mm%Ss_%f"))
            model_dir = os.path.join(model_dir, unique_dir)

    # Create directories if required ...
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Set logging to print to console and log to file
    utils.set_logger(model_dir, log_fn=LOG_FILENAME)
    logging.info("Training speech model: version={}".format(
        ARGS.model_version))
    logging.info("Using model directory: {}".format(model_dir))

    # Save base parameters and exit if `--save-base-params` flag encountered
    if ARGS.save_base_params:
        base_params = speech.MODEL_BASE_PARAMS[ARGS.model_version].copy()
        base_params['model_version'] = ARGS.model_version
        base_params_path = os.path.join(model_dir, MODEL_PARAMS_STORE_FN)
        with open(base_params_path, 'w') as fp:
            logging.info("Writing base model parameters to file: {}"
                         "".format(base_params_path))
            json.dump(base_params, fp, indent=4)
        return  # exit ...

    # Load JSON model params from specified file or a previous run if available
    params_file = None
    model_params_store_fn = os.path.join(model_dir, MODEL_PARAMS_STORE_FN)
    if ARGS.params_file is not None:
        params_file = os.path.join(model_dir, ARGS.params_file)
        if not os.path.exists(params_file):
            logging.info("Could not find specified model parameters file: "
                         "{}.".format(params_file))
            return  # exit ...
        else:
            logging.info("Using stored model parameters file: "
                         "{}".format(params_file))
    elif os.path.exists(model_params_store_fn):
        params_file = model_params_store_fn
        logging.info("Using stored model parameters file: "
                     "{}".format(params_file))

    # If a model params file is found, load JSON into a model params dict
    if params_file is not None:
        try:
            with open(params_file, 'r') as fp:
                model_params = json.load(fp)
            logging.info("Successfully loaded JSON model parameters!")
        except json.JSONDecodeError as ex:
            logging.info("Could not read JSON model parameters! "
                         "Caught exception: {}".format(ex))
            return  # exit ...
    else:
        # Get the default base model params for the specified model version
        model_params = speech.MODEL_BASE_PARAMS[ARGS.model_version].copy()
        logging.info("No model parameters file found. "
                     "Using base model parameters.")

    # Read and write training and model options from specified/default args
    train_options = {}
    var_args = vars(ARGS)
    for arg in var_args:
        if arg in speech.base_model_dict:
            if var_args[arg] != -1:  # if value explicitly set for model param
                model_params[arg] = var_args[arg]
        else:
            train_options[arg] = getattr(ARGS, arg)
    logging.info("Training parameters:")
    for train_opt, opt_val in train_options.items():
        logging.info("\t{}: {}".format(train_opt, opt_val))
    train_options_path = os.path.join(model_dir, 'train_options.json')
    with open(train_options_path, 'w') as fp:
        logging.info("Writing most recent training parameters to file: {}"
                     "".format(train_options_path))
        json.dump(train_options, fp, indent=4)

    # --------------------------------------------------------------------------
    # Add additional model parameters and save:
    # --------------------------------------------------------------------------
    n_filters = 39 if (ARGS.feats_type == 'mfcc') else 40
    n_padding = ARGS.n_padded if ARGS.model_version != 'dtw' else None
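    # (None leaves the time axis unspecified in the shapes below, since the
    # DTW baseline compares variable-length segments directly)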

    model_params['model_version'] = ARGS.model_version  # for later rebuilding
    model_params['x_input_shape'] = [None, n_filters, n_padding, 1]  # for data
    model_params['conv_input_shape'] = [n_filters, n_padding,
                                        1]  # conv reshape
    model_params['feats_type'] = ARGS.feats_type  # mfcc or fbanks
    model_params['n_padded'] = n_padding  # segment pad length
    model_params['center_padded'] = ARGS.center_padded  # center or end padded

    model_params_path = os.path.join(model_dir, MODEL_PARAMS_STORE_FN)
    with open(model_params_path, 'w') as fp:
        logging.info(
            "Writing model parameters to file: {}".format(model_params_path))
        json.dump(model_params, fp, indent=4)

    # For dtw model we simply want the model params, no training ...
    if ARGS.model_version == 'dtw':
        logging.info("Dynamic Time Warping (DTW) model params ready for test!")
        return

    # --------------------------------------------------------------------------
    # Load pre-train dataset:
    # --------------------------------------------------------------------------
    if ARGS.train_set == 'flickr-audio':  # load Flickr-Audio (default) train set
        logging.info(
            "Training speech model on dataset: {}".format('flickr-audio'))
        flickr_data = data.load_flickraudio(path=os.path.join(
            ARGS.data_dir, 'flickr_audio.npz'),
                                            feats_type=ARGS.feats_type,
                                            remove_labels=TIDIGITS_INTERSECTION
                                            )  # remove digit words from flickr
        train_data = flickr_data[0]
        val_data = flickr_data[1]

    else:  # load TIDigits train set
        logging.info("Training speech model on dataset: {}".format('tidigits'))
        tidigits_data = data.load_tidigits(path=os.path.join(
            ARGS.data_dir, 'tidigits_audio.npz'),
                                           feats_type=ARGS.feats_type,
                                           dev_size=ARGS.val_size)
        train_data = tidigits_data[0]
        val_data = tidigits_data[1]

    # --------------------------------------------------------------------------
    # Data processing pipeline (placed on CPU so GPU is free):
    # --------------------------------------------------------------------------
    with tf.device('/cpu:0'):
        # ------------------------------------
        # Create (pre-)train dataset pipeline:
        # ------------------------------------
        x_train = train_data[0]
        y_train = train_data[1]
        x_train_placeholder = tf.placeholder(
            TF_FLOAT, shape=[None, n_filters, n_padding])
        y_train_placeholder = tf.placeholder(TF_INT, shape=[None])
        # Get number of train batches/episodes per epoch
        if ARGS.n_train_episodes is not None:
            n_train_batches = ARGS.n_train_episodes
        else:
            n_train_batches = int(x_train.shape[0] / ARGS.batch_size)
        # Preprocess speech data and labels
        x_train = data.pad_sequences(x_train,
                                     ARGS.n_padded,
                                     center_padded=ARGS.center_padded)
        x_train = np.swapaxes(x_train, 2,
                              1)  # switch to shape (n_filters, n_pad)
        # Encode labels to integer indices
        train_encoder = preprocessing.LabelEncoder()
        y_train = train_encoder.fit_transform(y_train)
        model_params['n_output_logits'] = np.unique(y_train).shape[0]
        # Shuffle data
        x_train_preprocess = tf.random_shuffle(x_train_placeholder,
                                               seed=ARGS.random_seed)
        y_train_preprocess = tf.random_shuffle(y_train_placeholder,
                                               seed=ARGS.random_seed)
        # Add single depth channel to feature image so it is a 'grayscale image'
        x_train_with_depth = tf.expand_dims(x_train_preprocess, axis=-1)
        # Use balanced batching pipeline if specified, else batch full dataset
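        # (balanced batching draws K examples for each of P sampled classes,
        # giving P*K items per batch -- the usual setup for metric losses)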
        if ARGS.balanced_batching:
            train_pipeline = (data.batch_k_examples_for_p_concepts(
                x_data=x_train_with_depth,
                y_labels=y_train_preprocess,
                p_batch=ARGS.p_batch,
                k_batch=ARGS.k_batch,
                seed=ARGS.random_seed))
        else:
            train_pipeline = data.batch_dataset(x_data=x_train_with_depth,
                                                y_labels=y_train_preprocess,
                                                batch_size=ARGS.batch_size,
                                                shuffle=True,
                                                seed=ARGS.random_seed,
                                                drop_remainder=True)
        # Triplet sampling from data pipeline for offline siamese models
        if (ARGS.model_version == 'siamese_triplet'):
            train_pipeline = data.sample_dataset_triplets(
                train_pipeline,
                use_dummy_data=True,
                n_max_same_pairs=ARGS.max_offline_pairs)
        train_pipeline = train_pipeline.prefetch(
            1)  # prefetch 1 batch per step

        # --------------------------------------------
        # Create few-shot validation dataset pipeline:
        # --------------------------------------------
        x_val = val_data[0]
        y_val = val_data[1]
        z_val = val_data[2]
        x_val_placeholder = tf.placeholder(TF_FLOAT,
                                           shape=[None, n_filters, n_padding])
        y_val_placeholder = tf.placeholder(tf.string, shape=[None])
        z_val_placeholder = tf.placeholder(tf.string, shape=[None])
        # Preprocess speech data and labels
        x_val = data.pad_sequences(x_val,
                                   ARGS.n_padded,
                                   center_padded=ARGS.center_padded)
        x_val = np.swapaxes(x_val, 2, 1)  # switch to shape (n_filters, n_pad)
        # Add single depth channel to feature image so it is a 'grayscale image'
        x_val_with_depth = tf.expand_dims(x_val_placeholder, axis=-1)
        # Split data into disjoint support and query sets
        x_val_split, y_val_split, z_val_split = (data.make_train_test_split(
            x_val_with_depth,
            y_val_placeholder,
            z_val_placeholder,
            test_ratio=0.5,
            shuffle=True,
            seed=ARGS.random_seed))
        # Batch episodes of support and query sets for few-shot validation
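        # (each episode holds a support set of k_shot examples for each of
        # l_way classes, plus n_queries query examples to classify)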
        val_pipeline = (data.batch_few_shot_episodes(
            x_support_data=x_val_split[0],
            y_support_labels=y_val_split[0],
            z_support_originators=z_val_split[0],
            x_query_data=x_val_split[1],
            y_query_labels=y_val_split[1],
            z_query_originators=z_val_split[1],
            k_shot=ARGS.k_shot,
            l_way=ARGS.l_way,
            n_queries=ARGS.n_queries,
            seed=ARGS.random_seed))
        val_pipeline = val_pipeline.prefetch(1)  # prefetch 1 batch per step

        # Create pipeline iterators, and model train inputs
        train_iterator = train_pipeline.make_initializable_iterator()
        x_train_input, y_train_input = train_iterator.get_next()
        train_feed_dict = {
            x_train_placeholder: x_train,
            y_train_placeholder: y_train
        }
        val_iterator = val_pipeline.make_initializable_iterator()
        val_feed_dict = {
            x_val_placeholder: x_val,
            y_val_placeholder: y_val,
            z_val_placeholder: z_val
        }

    # --------------------------------------------------------------------------
    # Build, train, and validate model:
    # --------------------------------------------------------------------------
    # Build selected model version from base/loaded model params dict
    model_embedding, embed_input, train_flag, train_loss, train_metrics = (
        speech.build_speech_model(model_params,
                                  x_train_data=x_train_input,
                                  y_train_labels=y_train_input))
    # Get optimizer and training operation specified in model params dict
    optimizer_class = utils.literal_to_optimizer_class(
        train_options['optimizer'])
    train_optimizer = training.get_training_op(
        optimizer_class=optimizer_class,
        loss_func=train_loss,
        learn_rate=train_options['learning_rate'],
        decay_rate=train_options['decay_rate'],
        n_epoch_batches=n_train_batches)
    # Build few-shot 1-Nearest Neighbour memory comparison model
    query_input = tf.placeholder(TF_FLOAT, shape=[None, None])
    support_memory_input = tf.placeholder(TF_FLOAT, shape=[None, None])
    model_nn_memory = nearest_neighbour.fast_knn_cos(
        q_batch=query_input,
        m_keys=support_memory_input,
        k_nn=1,
        normalize=True)
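    # (each query embedding is matched to the support embedding with the
    # highest cosine similarity; the matched item's label is the prediction)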
    # Create tensorboard summaries
    tf.summary.scalar('loss', train_loss)
    for m_key, m_value in train_metrics.items():
        tf.summary.scalar(m_key, m_value)
    # Train the few-shot model
    train_few_shot_model(  # Train params:
        train_iterator=train_iterator,
        train_feed_dict=train_feed_dict,
        train_flag=train_flag,
        train_loss=train_loss,
        train_metrics=train_metrics,
        train_optimizer=train_optimizer,
        n_epochs=ARGS.n_max_epochs,
        max_batches=n_train_batches,
        # Validation params:
        val_iterator=val_iterator,
        val_feed_dict=val_feed_dict,
        model_embedding=model_embedding,
        embed_input=embed_input,
        query_input=query_input,
        support_memory_input=support_memory_input,
        nearest_neighbour=model_nn_memory,
        n_episodes=ARGS.n_test_episodes,
        # Other params:
        log_interval=int(n_train_batches / 5),
        model_dir=model_dir,
        summary_dir='summaries/train',
        save_filename='trained_model',
        restore_checkpoint=ARGS.restore_checkpoint)
def main():
    # --------------------------------------------------------------------------
    # Parse script args and handle options:
    # --------------------------------------------------------------------------
    ARGS = check_arguments()

    # Set numpy and tensorflow random seeds
    np.random.seed(ARGS.random_seed)
    tf.set_random_seed(ARGS.random_seed)

    # Get specified model directory (default cwd)
    model_dir = ARGS.model_dir
    # Check if not using a previous run, and create a unique run directory
    if not os.path.exists(os.path.join(model_dir, LOG_FILENAME)):
        if not ARGS.no_unique_dir:
            unique_dir = "{}_{}_{}".format(
                'vision', ARGS.model_version,
                datetime.datetime.now().strftime("%y%m%d_%Hh%Mm%Ss_%f"))
            model_dir = os.path.join(model_dir, unique_dir)

    # Create directories if required ...
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Set logging to print to console and log to file
    utils.set_logger(model_dir, log_fn=LOG_FILENAME)
    logging.info("Training vision model: version={}".format(
        ARGS.model_version))
    logging.info("Using model directory: {}".format(model_dir))

    # Save base parameters and exit if `--save-base-params` flag encountered
    if ARGS.save_base_params:
        base_params = vision.MODEL_BASE_PARAMS[ARGS.model_version].copy()
        base_params['model_version'] = ARGS.model_version
        base_params_path = os.path.join(model_dir, MODEL_PARAMS_STORE_FN)
        with open(base_params_path, 'w') as fp:
            logging.info("Writing base model parameters to file: {}"
                         "".format(base_params_path))
            json.dump(base_params, fp, indent=4)
        return  # exit ...

    # Load JSON model params from specified file or a previous run if available
    params_file = None
    model_params_store_fn = os.path.join(model_dir, MODEL_PARAMS_STORE_FN)
    if ARGS.params_file is not None:
        params_file = os.path.join(model_dir, ARGS.params_file)
        if not os.path.exists(params_file):
            logging.info("Could not find specified model parameters file: "
                         "{}.".format(params_file))
            return  # exit ...
        else:
            logging.info("Using stored model parameters file: "
                         "{}".format(params_file))
    elif os.path.exists(model_params_store_fn):
        params_file = model_params_store_fn
        logging.info("Using stored model parameters file: "
                     "{}".format(params_file))

    # If a model params file is found, load JSON into a model params dict
    if params_file is not None:
        try:
            with open(params_file, 'r') as fp:
                model_params = json.load(fp)
            logging.info("Successfully loaded JSON model parameters!")
        except json.JSONDecodeError as ex:
            logging.info("Could not read JSON model parameters! "
                         "Caught exception: {}".format(ex))
            return  # exit ...
    else:
        # Get the default base model params for the specified model version
        model_params = vision.MODEL_BASE_PARAMS[ARGS.model_version].copy()
        logging.info("No model parameters file found. "
                     "Using base model parameters.")

    # Read and write training and model options from specified/default args
    train_options = {}
    var_args = vars(ARGS)
    for arg in var_args:
        if arg in vision.base_model_dict:
            if var_args[arg] != -1:  # if value explicitly set for model param
                model_params[arg] = var_args[arg]
        else:
            train_options[arg] = getattr(ARGS, arg)
    logging.info("Training parameters:")
    for train_opt, opt_val in train_options.items():
        logging.info("\t{}: {}".format(train_opt, opt_val))
    train_options_path = os.path.join(model_dir, 'train_options.json')
    with open(train_options_path, 'w') as fp:
        logging.info("Writing most recent training parameters to file: {}"
                     "".format(train_options_path))
        json.dump(train_options, fp, indent=4)

    # --------------------------------------------------------------------------
    # Add additional model parameters and save:
    # --------------------------------------------------------------------------
    image_size = 105 if (ARGS.train_set == 'omniglot') else 28
    model_params['model_version'] = ARGS.model_version  # for later rebuilding
    model_params_path = os.path.join(model_dir, MODEL_PARAMS_STORE_FN)
    with open(model_params_path, 'w') as fp:
        print("Writing model parameters to file: {}".format(model_params_path))
        json.dump(model_params, fp, indent=4)

    # For pixel matching model we simply want the model params, no training ...
    if ARGS.model_version == 'pixels':
        logging.info("Pure pixel matching model params ready for test!")
        return

    # --------------------------------------------------------------------------
    # Load pre-train dataset:
    # --------------------------------------------------------------------------
    if ARGS.train_set == 'omniglot':  # load omniglot (default) train set
        logging.info("Training vision model on dataset: {}".format('omniglot'))
        train_data = data.load_omniglot(
            path=os.path.join(ARGS.data_dir, 'omniglot.npz'))
        inverse_data = True  # inverse omniglot grayscale
    else:  # load mnist train set
        logging.info("Training vision model on dataset: {}".format('mnist'))
        train_data = data.load_mnist()
        inverse_data = False  # don't inverse mnist grayscale

    # --------------------------------------------------------------------------
    # Data processing pipeline (placed on CPU so GPU is free):
    # --------------------------------------------------------------------------
    with tf.device('/cpu:0'):
        # ------------------------------------
        # Create (pre-)train dataset pipeline:
        # ------------------------------------
        x_train = train_data[0][0][ARGS.val_size:].copy()
        y_train = train_data[0][1][ARGS.val_size:]
        x_train_placeholder = tf.placeholder(
            TF_FLOAT, shape=[None, image_size, image_size])
        y_train_placeholder = tf.placeholder(TF_INT, shape=[None])
        # Get number of train batches/episodes per epoch
        if ARGS.n_train_episodes is not None:
            n_train_batches = ARGS.n_train_episodes
        else:
            n_train_batches = int(x_train.shape[0] / ARGS.batch_size)
        # Preprocess image data and labels
        x_train_preprocess = (
            data.preprocess_images(
                images=x_train_placeholder,
                normalize=True,
                inverse_gray=inverse_data,  # only invert omniglot grayscale
                resize_shape=model_params['resize_shape'],
                resize_method=tf.image.ResizeMethod.BILINEAR,
                expand_dims=True,
                dtype=TF_FLOAT))
        # Encode labels to integer indices
        train_encoder = preprocessing.LabelEncoder()
        y_train = train_encoder.fit_transform(y_train)
        model_params['n_output_logits'] = np.unique(y_train).shape[0]
        # y_train = tf.cast(y_train, TF_INT)
        # Shuffle data
        x_train_preprocess = tf.random_shuffle(x_train_preprocess,
                                               seed=ARGS.random_seed)
        y_train_preprocess = tf.random_shuffle(y_train_placeholder,
                                               seed=ARGS.random_seed)
        # Use balanced batching pipeline if specified, else batch full dataset
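        # (balanced batches draw K examples for each of P sampled classes,
        # giving P*K items per batch)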
        if ARGS.balanced_batching:
            train_pipeline = (data.batch_k_examples_for_p_concepts(
                x_data=x_train_preprocess,
                y_labels=y_train_preprocess,
                p_batch=ARGS.p_batch,
                k_batch=ARGS.k_batch,
                seed=ARGS.random_seed))


#         elif ARGS.model_version == 'siamese_triplet':
#             train_triplets = data.sample_triplets(x_data=x_train_preprocess,
#                                                   y_labels=y_train_preprocess,
#                                                   use_dummy_data=True,
#                                                   n_max_same_pairs=ARGS.max_offline_pairs)
#             x_triplet_data = tf.data.Dataset.zip((
#                 tf.data.Dataset.from_tensor_slices(train_triplets[0][0]).batch(ARGS.batch_size, drop_remainder=True),
#                 tf.data.Dataset.from_tensor_slices(train_triplets[0][1]).batch(ARGS.batch_size, drop_remainder=True),
#                 tf.data.Dataset.from_tensor_slices(train_triplets[0][1]).batch(ARGS.batch_size, drop_remainder=True)))
#             y_triplet_data = tf.data.Dataset.zip((
#                 tf.data.Dataset.from_tensor_slices(train_triplets[1][0]).batch(ARGS.batch_size, drop_remainder=True),
#                 tf.data.Dataset.from_tensor_slices(train_triplets[1][1]).batch(ARGS.batch_size, drop_remainder=True),
#                 tf.data.Dataset.from_tensor_slices(train_triplets[1][1]).batch(ARGS.batch_size, drop_remainder=True)))
#             train_pipeline = tf.data.Dataset.zip((x_triplet_data, y_triplet_data))
#             n_train_batches = 1000  # not sure of batch size, loop until out of range ...
#             # Quick hack for num triplet batches in offline siamese ...
# #             with tf.Session() as sess:
# #                 n_triplets = sess.run(tf.shape(train_triplets[0])[1], feed_dict={
# #                     x_train_placeholder: x_train, y_train_placeholder: y_train})
# #                 n_train_batches = int(n_triplets/ARGS.batch_size)
# #                 logging.info("Calculated triplet batches: {} batches for batch size {} (total triplets: {})"
# #                              .format(n_train_batches, ARGS.batch_size, n_triplets))
        else:
            train_pipeline = data.batch_dataset(x_data=x_train_preprocess,
                                                y_labels=y_train_preprocess,
                                                batch_size=ARGS.batch_size,
                                                shuffle=True,
                                                seed=ARGS.random_seed,
                                                drop_remainder=True)
        # Triplet sampling from data pipeline for offline siamese models
        if (ARGS.model_version == 'siamese_triplet'):
            train_pipeline = data.sample_dataset_triplets(
                train_pipeline,
                use_dummy_data=True,
                n_max_same_pairs=ARGS.max_offline_pairs)
        train_pipeline = train_pipeline.prefetch(
            1)  # prefetch 1 batch per step

        # --------------------------------------------
        # Create few-shot validation dataset pipeline:
        # --------------------------------------------
        x_val = train_data[0][0][:ARGS.val_size]
        y_val = train_data[0][1][:ARGS.val_size]
        x_val_placeholder = tf.placeholder(
            TF_FLOAT, shape=[None, image_size, image_size])
        y_val_placeholder = tf.placeholder(tf.string, shape=[None])
        # Preprocess image data and labels
        x_val_preprocess = (
            data.preprocess_images(
                images=x_val_placeholder,
                normalize=True,
                inverse_gray=inverse_data,  # only invert omniglot grayscale
                resize_shape=model_params['resize_shape'],
                resize_method=tf.image.ResizeMethod.BILINEAR,
                expand_dims=True,
                dtype=TF_FLOAT))
        # y_val = tf.cast(y_val, TF_INT)
        # Split data into disjoint support and query sets
        x_val_split, y_val_split = (data.make_train_test_split(
            x_val_preprocess,
            y_val_placeholder,
            test_ratio=0.5,
            shuffle=True,
            seed=ARGS.random_seed))
        # Batch episodes of support and query sets for few-shot validation
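        # (each episode: a support set of k_shot examples per class over
        # l_way classes, plus n_queries query examples)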
        val_pipeline = (
            data.batch_few_shot_episodes(x_support_data=x_val_split[0],
                                         y_support_labels=y_val_split[0],
                                         x_query_data=x_val_split[1],
                                         y_query_labels=y_val_split[1],
                                         k_shot=ARGS.k_shot,
                                         l_way=ARGS.l_way,
                                         n_queries=ARGS.n_queries,
                                         seed=ARGS.random_seed))
        val_pipeline = val_pipeline.prefetch(1)  # prefetch 1 batch per step

        # Create pipeline iterators and model train inputs
        train_iterator = train_pipeline.make_initializable_iterator()
        x_train_input, y_train_input = train_iterator.get_next()
        train_feed_dict = {
            x_train_placeholder: x_train,
            y_train_placeholder: y_train
        }
        val_iterator = val_pipeline.make_initializable_iterator()
        val_feed_dict = {x_val_placeholder: x_val, y_val_placeholder: y_val}

    # --------------------------------------------------------------------------
    # Build, train, and validate model:
    # --------------------------------------------------------------------------
    # Build selected model version from base/loaded model params dict
    model_embedding, embed_input, train_flag, train_loss, train_metrics = (
        vision.build_vision_model(model_params,
                                  x_train_data=x_train_input,
                                  y_train_labels=y_train_input))
    # Get optimizer and training operation specified in model params dict
    optimizer_class = utils.literal_to_optimizer_class(
        train_options['optimizer'])
    train_optimizer = training.get_training_op(
        optimizer_class=optimizer_class,
        loss_func=train_loss,
        learn_rate=train_options['learning_rate'],
        decay_rate=train_options['decay_rate'],
        n_epoch_batches=n_train_batches)
    # Build few-shot 1-Nearest Neighbour memory comparison model
    query_input = tf.placeholder(TF_FLOAT, shape=[None, None])
    support_memory_input = tf.placeholder(TF_FLOAT, shape=[None, None])
    model_nn_memory = nearest_neighbour.fast_knn_cos(
        q_batch=query_input,
        m_keys=support_memory_input,
        k_nn=1,
        normalize=True)
    # Create tensorboard summaries
    tf.summary.scalar('loss', train_loss)
    for m_key, m_value in train_metrics.items():
        tf.summary.scalar(m_key, m_value)
    # Train the few-shot model
    train_few_shot_model(  # Train params:
        train_iterator=train_iterator,
        train_feed_dict=train_feed_dict,
        train_flag=train_flag,
        train_loss=train_loss,
        train_metrics=train_metrics,
        train_optimizer=train_optimizer,
        n_epochs=ARGS.n_max_epochs,
        max_batches=n_train_batches,
        # Validation params:
        val_iterator=val_iterator,
        val_feed_dict=val_feed_dict,
        model_embedding=model_embedding,
        embed_input=embed_input,
        query_input=query_input,
        support_memory_input=support_memory_input,
        nearest_neighbour=model_nn_memory,
        n_episodes=ARGS.n_test_episodes,
        # Other params:
        log_interval=int(n_train_batches / 5),
        model_dir=model_dir,
        summary_dir='summaries/train',
        save_filename='trained_model',
        restore_checkpoint=ARGS.restore_checkpoint)
def main():
    # --------------------------------------------------------------------------
    # Parse script args and handle options:
    # --------------------------------------------------------------------------
    ARGS = check_arguments()

    # Set numpy and tensorflow random seeds
    np.random.seed(ARGS.random_seed)
    tf.set_random_seed(ARGS.random_seed)

    # Get specified model and output directories
    speech_model_dir = ARGS.speech_model_dir
    vision_model_dir = ARGS.vision_model_dir
    test_model_dir = ARGS.output_dir

    # Check if not using a previous run, and create a unique run directory
    if not os.path.exists(os.path.join(test_model_dir, LOG_FILENAME)):
        unique_dir = "{}_{}".format(
            'multimodal_test',
            datetime.datetime.now().strftime("%y%m%d_%Hh%Mm%Ss_%f"))
        test_model_dir = os.path.join(test_model_dir, unique_dir)

    # Create directories
    if not os.path.exists(test_model_dir):
        os.makedirs(test_model_dir)

    # Set logging to print to console and log to file
    utils.set_logger(test_model_dir, log_fn=LOG_FILENAME)
    logging.info("Using vision model directory: {}".format(vision_model_dir))
    logging.info("Using speech model directory: {}".format(speech_model_dir))

    # Load JSON model params
    speech_model_params = load_model_params(speech_model_dir,
                                            ARGS.params_file,
                                            modality='speech')
    vision_model_params = load_model_params(vision_model_dir,
                                            ARGS.params_file,
                                            modality='vision')
    if speech_model_params is None or vision_model_params is None:
        return  # exit ...

    # Read and write testing options from specified/default args
    test_options = {}
    var_args = vars(ARGS)
    for arg in var_args:
        test_options[arg] = getattr(ARGS, arg)
    logging.info("Testing parameters: {}".format(test_options))
    test_options_path = os.path.join(test_model_dir, 'test_options.json')
    with open(test_options_path, 'w') as fp:
        logging.info("Writing most recent testing parameters to file: {}"
                     "".format(test_options_path))
        json.dump(test_options, fp, indent=4)

    # --------------------------------------------------------------------------
    # Get additional model parameters:
    # --------------------------------------------------------------------------
    feats_type = speech_model_params['feats_type']
    n_padding = speech_model_params['n_padded']
    center_padded = speech_model_params['center_padded']
    n_filters = 39 if (feats_type == 'mfcc') else 40
    image_size = 28 if (ARGS.test_set == 'digits') else None  # TODO(rpeloff)

    if n_padding is None or speech_model_params['model_version'] == 'dtw':
        n_padding = 110  # pad to longest segment length in TIDigits (DTW)
        center_padded = False

    # --------------------------------------------------------------------------
    # Load test datasets:
    # --------------------------------------------------------------------------
    if ARGS.test_set == 'digits':  # load digits (default) test set
        # Load MNIST data arrays
        logging.info("Testing vision model on dataset: {}".format('mnist'))
        vision_test_data = data.load_mnist()
        vision_inverse_data = False  # don't inverse mnist grayscale

        logging.info("Testing speech model on dataset: {}".format('tidigits'))
        tidigits_data = data.load_tidigits(path=os.path.join(
            ARGS.speech_data_dir, 'tidigits_audio.npz'),
                                           feats_type=feats_type)
        speech_test_data = tidigits_data[2]
    else:  # load flickr test set TODO(rpeloff)
        raise NotImplementedError()

    # --------------------------------------------------------------------------
    # Data processing pipeline (placed on CPU so GPU is free):
    # --------------------------------------------------------------------------
    with tf.device('/cpu:0'):
        # ---------------------------------------------
        # Create speech few-shot test dataset pipeline:
        # ---------------------------------------------
        x_speech_test = speech_test_data[0]
        y_speech_test = speech_test_data[1]
        z_speech_test = speech_test_data[2]
        x_speech_test_placeholder = tf.placeholder(
            TF_FLOAT, shape=[None, n_filters, n_padding])
        y_speech_test_placeholder = tf.placeholder(tf.string, shape=[None])
        z_speech_test_placeholder = tf.placeholder(tf.string, shape=[None])
        # Preprocess speech data and labels
        x_speech_test = data.pad_sequences(x_speech_test,
                                           n_padding,
                                           center_padded=center_padded)
        x_speech_test = np.swapaxes(x_speech_test, 2, 1)  # (n_filters, n_pad)
        if not ARGS.zeros_different:  # treat 'oh' and 'zero' as same class
            y_speech_test = [
                word if word != 'o' else 'z' for word in y_speech_test
            ]
        # Add single depth channel to feature image so it is a 'grayscale image'
        x_speech_test_with_depth = tf.expand_dims(x_speech_test_placeholder,
                                                  axis=-1)
        # Split data into disjoint support and query sets
        x_speech_test_split, y_speech_test_split, z_speech_test_split = (
            data.make_train_test_split(x_speech_test_with_depth,
                                       y_speech_test_placeholder,
                                       z_speech_test_placeholder,
                                       test_ratio=0.5,
                                       shuffle=True,
                                       seed=ARGS.random_seed))

        # ---------------------------------------------
        # Create vision few-shot test dataset pipeline:
        # ---------------------------------------------
        x_vision_test = vision_test_data[1][0]
        y_vision_test = vision_test_data[1][1]
        x_vision_test_placeholder = tf.placeholder(
            TF_FLOAT, shape=[None, image_size, image_size])
        y_vision_test_placeholder = tf.placeholder(tf.string, shape=[None])
        # Preprocess image data and labels
        x_vision_test_preprocess = (data.preprocess_images(
            images=x_vision_test_placeholder,
            normalize=True,
            inverse_gray=vision_inverse_data,
            resize_shape=vision_model_params['resize_shape'],
            resize_method=tf.image.ResizeMethod.BILINEAR,
            expand_dims=True,
            dtype=TF_FLOAT))
        y_vision_test = np.array(
            [  # convert MNIST classes to TIDigits labels
                DIGITS_LABELS[digit] for digit in y_vision_test
            ],
            dtype=str)
        if ARGS.zeros_different:  # treat 'oh' and 'zero' as different classes
            zero_ind = np.where(np.isin(y_vision_test, 'z'))[0]
            oh_ind = np.random.choice(zero_ind,
                                      size=int(zero_ind.shape[0] / 2),
                                      replace=False)
            # randomly replace half of the 'zero' labels with 'oh'
            y_vision_test[oh_ind] = 'o'
        # Split data into disjoint support and query sets
        x_vision_test_split, y_vision_test_split = (data.make_train_test_split(
            x_vision_test_preprocess,
            y_vision_test_placeholder,
            test_ratio=0.5,
            shuffle=True,
            seed=ARGS.random_seed))

        # -----------------------------------------
        # Create multimodal few-shot test pipeline:
        # -----------------------------------------
        if ARGS.query_type == 'speech':
            speech_matching_set = False
            speech_queries = ARGS.n_queries
            vision_matching_set = True
            vision_queries = ARGS.l_way
        else:
            speech_matching_set = True
            speech_queries = ARGS.l_way
            vision_matching_set = False
            vision_queries = ARGS.n_queries
        # Create multimodal few-shot episode label set
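        # (a single label set is shared below by the speech and vision
        # pipelines, so both modalities sample episodes over the same l_way
        # classes and queries in one modality can be matched in the other)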
        episode_label_set = data.create_episode_label_set(
            y_speech_test_split[0],
            y_speech_test_split[1],
            y_vision_test_split[0],
            y_vision_test_split[1],
            z_difficult_originators=(z_speech_test_split[0]
                                     if ARGS.originator_type == 'difficult'
                                     else None),
            l_way=ARGS.l_way,
            seed=ARGS.random_seed)
        # Batch episodes of support/query/matching sets for few-shot speech test
        speech_test_pipeline = (data.batch_few_shot_episodes(
            x_support_data=x_speech_test_split[0],
            y_support_labels=y_speech_test_split[0],
            z_support_originators=z_speech_test_split[0],
            x_query_data=x_speech_test_split[1],
            y_query_labels=y_speech_test_split[1],
            z_query_originators=z_speech_test_split[1],
            episode_label_set=episode_label_set,
            make_matching_set=speech_matching_set,
            originator_type=ARGS.originator_type,
            k_shot=ARGS.k_shot,
            l_way=ARGS.l_way,
            n_queries=speech_queries,
            seed=ARGS.random_seed))
        # Batch episodes of support/query/matching sets for few-shot vision test
        vision_test_pipeline = (data.batch_few_shot_episodes(
            x_support_data=x_vision_test_split[0],
            y_support_labels=y_vision_test_split[0],
            x_query_data=x_vision_test_split[1],
            y_query_labels=y_vision_test_split[1],
            episode_label_set=episode_label_set,
            make_matching_set=vision_matching_set,
            k_shot=ARGS.k_shot,
            l_way=ARGS.l_way,
            n_queries=vision_queries,
            seed=ARGS.random_seed))
        speech_test_pipeline = speech_test_pipeline.prefetch(
            1)  # prefetch 1 batch per step
        vision_test_pipeline = vision_test_pipeline.prefetch(
            1)  # prefetch 1 batch per step

        # Create pipeline iterators
        speech_test_iterator = speech_test_pipeline.make_initializable_iterator(
        )
        vision_test_iterator = vision_test_pipeline.make_initializable_iterator(
        )
        test_feed_dict = {
            x_speech_test_placeholder: x_speech_test,
            y_speech_test_placeholder: y_speech_test,
            z_speech_test_placeholder: z_speech_test,
            x_vision_test_placeholder: x_vision_test,
            y_vision_test_placeholder: y_vision_test
        }

    # --------------------------------------------------------------------------
    # Build and test models:
    # --------------------------------------------------------------------------
    # Build speech model version from loaded model params dict
    speech_graph = tf.Graph()
    with speech_graph.as_default():  #pylint: disable=E1129
        speech_model_embedding, speech_embed_input, speech_train_flag, _, _ = (
            speech.build_speech_model(speech_model_params, training=False))
    # Build selected model version from loaded model params dict
    vision_graph = tf.Graph()
    with vision_graph.as_default():  #pylint: disable=E1129
        vision_model_embedding, vision_embed_input, vision_train_flag, _, _ = (
            vision.build_vision_model(vision_model_params, training=False))
    # Build few-shot 1-Nearest Neighbour memory comparison model
    query_input = tf.placeholder(TF_FLOAT, shape=[None, None])
    support_memory_input = tf.placeholder(TF_FLOAT, shape=[None, None])
    model_nn_memory = nearest_neighbour.fast_knn_cos(
        q_batch=query_input,
        m_keys=support_memory_input,
        k_nn=1,
        normalize=True)
    # Check if test using Dynamic Time Warping instead of 1-NN for speech model
    test_dtw = False
    dtw_cost_func = None
    dtw_post_process = None
    if speech_model_params['model_version'] == 'dtw':
        test_dtw = True
        dtw_cost_func = speech_dtw.multivariate_dtw_cost_cosine
        dtw_post_process = lambda x: np.ascontiguousarray(  # as cython C-order
            np.swapaxes(  # time on x-axis for DTW
                _get_unpadded_image(np.squeeze(x, axis=-1), n_padding), 1, 0),
            dtype=float)
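    # (the post-process strips the padding, swaps to time-major order, and
    # forces a C-contiguous array, as the cython DTW cost function expects)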
    # Check if test using pure pixel matching
    test_pixels = False
    if vision_model_params['model_version'] == 'pixels':
        test_pixels = True
    # Test the multimodal few-shot model
    test_mulitmodal_few_shot_model(
        test_feed_dict=test_feed_dict,
        # Speech test params ...
        speech_graph=speech_graph,
        speech_train_flag=speech_train_flag,
        speech_test_iterator=speech_test_iterator,
        speech_model_embedding=speech_model_embedding,
        speech_embed_input=speech_embed_input,
        # Vision test params ...
        vision_graph=vision_graph,
        vision_train_flag=vision_train_flag,
        vision_test_iterator=vision_test_iterator,
        vision_model_embedding=vision_model_embedding,
        vision_embed_input=vision_embed_input,
        # Nearest neighbour params ...
        query_input=query_input,
        support_memory_input=support_memory_input,
        nearest_neighbour=model_nn_memory,
        n_episodes=ARGS.n_test_episodes,
        query_type=ARGS.query_type,
        test_pixels=test_pixels,
        test_dtw=test_dtw,
        dtw_cost_func=dtw_cost_func,
        dtw_post_process=dtw_post_process,
        test_invariance=(ARGS.originator_type == 'same'
                         or ARGS.originator_type == 'difficult'),
        # Other params ...
        log_interval=int(ARGS.n_test_episodes / 10),
        model_dir=test_model_dir,
        speech_model_dir=speech_model_dir,
        vision_model_dir=vision_model_dir,
        summary_dir='summaries/test',
        speech_restore_checkpoint=ARGS.speech_restore_checkpoint,
        vision_restore_checkpoint=ARGS.vision_restore_checkpoint)
def main():
    # --------------------------------------------------------------------------
    # Parse script args and handle options:
    # --------------------------------------------------------------------------
    ARGS = check_arguments()
    
    # Set numpy and tensorflow random seeds
    np.random.seed(ARGS.random_seed)
    tf.set_random_seed(ARGS.random_seed)

    # Get specified model and directories (default cwd)
    model_dir = ARGS.model_dir
    test_model_dir = ARGS.output_dir
    if test_model_dir is None:
        test_model_dir = model_dir
    else:
        test_model_dir = os.path.abspath(test_model_dir)
    
    # Check if not using a previous run, and create a unique run directory
    if not os.path.exists(os.path.join(test_model_dir, LOG_FILENAME)):
        unique_dir = "{}_{}".format(
            'vision_test', 
            datetime.datetime.now().strftime("%y%m%d_%Hh%Mm%Ss_%f"))
        test_model_dir = os.path.join(test_model_dir, unique_dir)
    
    # Create directories
    if not os.path.exists(test_model_dir):
        os.makedirs(test_model_dir)
    
    # Set logging to print to console and log to file
    utils.set_logger(test_model_dir, log_fn=LOG_FILENAME)
    logging.info("Using model directory: {}".format(model_dir))

    # Load JSON model params from specified file or a previous run if available
    model_params_store_fn = os.path.join(model_dir, MODEL_PARAMS_STORE_FN)
    if ARGS.params_file is not None:
        params_file = os.path.join(model_dir, ARGS.params_file)
        if not os.path.exists(params_file):
            logging.info("Could not find specified model parameters file: "
                         "{}.".format(params_file))
            return  # exit ...
        else:
            logging.info("Using stored model parameters file: "
                         "{}".format(params_file))
    elif os.path.exists(model_params_store_fn):
        params_file = model_params_store_fn
        logging.info("Using stored model parameters file: "
                     "{}".format(params_file))
    else:
        logging.info("Model parameters file {} could not be found!"
                     "".format(model_params_store_fn))
        return  # exit ...

    # Load JSON into a model params dict 
    try:
        with open(params_file, 'r') as fp:
            model_params = json.load(fp)
        logging.info("Successfully loaded JSON model parameters!")
        logging.info("Testing vision model: version={}".format(
            model_params['model_version']))
    except json.JSONDecodeError as ex:
        logging.info("Could not read JSON model parameters! "
                        "Caught exception: {}".format(ex))
        return  # exit ...
    
    # Read and write testing options from specified/default args
    test_options = {}
    var_args = vars(ARGS)
    for arg in var_args:
        test_options[arg] = getattr(ARGS, arg)
    logging.info("Testing parameters: {}".format(test_options))
    test_options_path = os.path.join(test_model_dir, 'test_options.json')
    with open(test_options_path, 'w') as fp:
        logging.info("Writing most recent testing parameters to file: {}"
                        "".format(test_options_path))
        json.dump(test_options, fp, indent=4)

    # --------------------------------------------------------------------------
    # Get additional model parameters:
    # --------------------------------------------------------------------------
    image_size = 28 if (ARGS.test_set == 'mnist') else 105

    # --------------------------------------------------------------------------
    # Load test dataset:
    # --------------------------------------------------------------------------
    if ARGS.test_set == 'mnist':  # load mnist (default) test set
        # Load MNIST data arrays
        logging.info("Testing vision model on dataset: {}".format('mnist'))
        test_data = data.load_mnist()
        inverse_data = False  # don't inverse mnist grayscale
    else:  # load omniglot test set
        logging.info("Testing vision model on dataset: {}".format('omniglot'))
        test_data = data.load_omniglot(
            path=os.path.join(ARGS.data_dir, 'omniglot.npz'))
        inverse_data = True  # inverse omniglot grayscale

    # --------------------------------------------------------------------------
    # Data processing pipeline (placed on CPU so GPU is free):
    # --------------------------------------------------------------------------
    with tf.device('/cpu:0'):
        # --------------------------------------------
        # Create few-shot test dataset pipeline:
        # --------------------------------------------
        x_test = test_data[1][0]
        y_test = [str(label) for label in test_data[1][1]]
        x_test_placeholder = tf.placeholder(TF_FLOAT, 
                                            shape=[None, image_size, image_size])
        y_test_placeholder = tf.placeholder(tf.string, shape=[None])
        # Preprocess image data
        x_test_preprocess = (
            data.preprocess_images(images=x_test_placeholder,
                                   normalize=True,
                                   inverse_gray=inverse_data,  
                                   resize_shape=model_params['resize_shape'],
                                   resize_method=tf.image.ResizeMethod.BILINEAR,
                                   expand_dims=True,
                                   dtype=TF_FLOAT))
        # Split data into disjoint support and query sets
        x_test_split, y_test_split = (
            data.make_train_test_split(x_test_preprocess,
                                       y_test_placeholder,
                                       test_ratio=0.5,
                                       shuffle=True,
                                       seed=ARGS.random_seed))
        # Batch episodes of support and query sets for few-shot testing
        test_pipeline = (
            data.batch_few_shot_episodes(x_support_data=x_test_split[0],
                                         y_support_labels=y_test_split[0],
                                         x_query_data=x_test_split[1],
                                         y_query_labels=y_test_split[1],
                                         k_shot=ARGS.k_shot,
                                         l_way=ARGS.l_way,
                                         n_queries=ARGS.n_queries,
                                         seed=ARGS.random_seed))
        test_pipeline = test_pipeline.prefetch(1)  # prefetch 1 batch per step
        # Create pipeline iterator
        test_iterator = test_pipeline.make_initializable_iterator()
        test_feed_dict = {
            x_test_placeholder: x_test,
            y_test_placeholder: y_test
        }

    # --------------------------------------------------------------------------
    # Build and test model:
    # --------------------------------------------------------------------------
    # Build selected model version from base/loaded model params dict
    model_embedding, embed_input, train_flag, _, _ = (
        vision.build_vision_model(model_params, training=False))
    # Build few-shot 1-Nearest Neighbour memory comparison model
    query_input = tf.placeholder(TF_FLOAT, shape=[None, None])
    support_memory_input = tf.placeholder(TF_FLOAT, shape=[None, None])
    model_nn_memory = nearest_neighbour.fast_knn_cos(q_batch=query_input,
                                                     m_keys=support_memory_input,
                                                     k_nn=1,
                                                     normalize=True)
    # Check if test using pure pixel matching
    test_pixels = False
    if model_params['model_version'] == 'pixels':
        test_pixels = True
    # Test the few-shot model
    test_few_shot_model(# Test params:
                        train_flag=train_flag,
                        test_iterator=test_iterator,
                        test_feed_dict=test_feed_dict,
                        model_embedding=model_embedding,
                        embed_input=embed_input,
                        query_input=query_input,
                        support_memory_input=support_memory_input,
                        nearest_neighbour=model_nn_memory,
                        n_episodes=ARGS.n_test_episodes,
                        test_pixels=test_pixels,
                        # Other params:
                        log_interval=int(ARGS.n_test_episodes/10),
                        model_dir=model_dir,
                        output_dir=test_model_dir,
                        summary_dir='summaries/test',
                        restore_checkpoint=ARGS.restore_checkpoint)
def main():
    # --------------------------------------------------------------------------
    # Parse script args and handle options:
    # --------------------------------------------------------------------------
    ARGS = check_arguments()

    # Set numpy and tensorflow random seeds
    np.random.seed(ARGS.random_seed)
    tf.set_random_seed(ARGS.random_seed)

    # Get specified model directory (default cwd)
    model_dir = ARGS.model_dir
    test_model_dir = ARGS.output_dir
    if test_model_dir is None:
        test_model_dir = model_dir
    else:
        test_model_dir = os.path.abspath(test_model_dir)

    # Check if not using a previous run, and create a unique run directory
    if not os.path.exists(os.path.join(test_model_dir, LOG_FILENAME)):
        unique_dir = "{}_{}".format(
            'speech_test',
            datetime.datetime.now().strftime("%y%m%d_%Hh%Mm%Ss_%f"))
        test_model_dir = os.path.join(test_model_dir, unique_dir)

    # Create directories
    if not os.path.exists(test_model_dir):
        os.makedirs(test_model_dir)

    # Set logging to print to console and log to file
    utils.set_logger(test_model_dir, log_fn=LOG_FILENAME)
    logging.info("Using model directory: {}".format(model_dir))

    # Load JSON model params from specified file or a previous run if available
    model_params_store_fn = os.path.join(model_dir, MODEL_PARAMS_STORE_FN)
    if ARGS.params_file is not None:
        params_file = os.path.join(model_dir, ARGS.params_file)
        if not os.path.exists(params_file):
            logging.info("Could not find specified model parameters file: "
                         "{}.".format(params_file))
            return  # exit ...
        else:
            logging.info("Using stored model parameters file: "
                         "{}".format(params_file))
    elif os.path.exists(model_params_store_fn):
        params_file = model_params_store_fn
        logging.info("Using stored model parameters file: "
                     "{}".format(params_file))
    else:
        logging.info("Model parameters file {} could not be found!"
                     "".format(model_params_store_fn))
        return  # exit ...

    # Load JSON into a model params dict
    try:
        with open(params_file, 'r') as fp:
            model_params = json.load(fp)
        logging.info("Successfully loaded JSON model parameters!")
        logging.info("Testing speech model: version={}".format(
            model_params['model_version']))
    except json.JSONDecodeError as ex:
        logging.info("Could not read JSON model parameters! "
                     "Caught exception: {}".format(ex))
        return  # exit ...

    # Read and write testing options from specified/default args
    test_options = {}
    var_args = vars(ARGS)
    for arg in var_args:
        test_options[arg] = getattr(ARGS, arg)
    logging.info("Testing parameters: {}".format(test_options))
    test_options_path = os.path.join(test_model_dir, 'test_options.json')
    with open(test_options_path, 'w') as fp:
        logging.info("Writing most recent testing parameters to file: {}"
                     "".format(test_options_path))
        json.dump(test_options, fp, indent=4)

    # --------------------------------------------------------------------------
    # Get additional model parameters:
    # --------------------------------------------------------------------------
    feats_type = model_params['feats_type']
    n_padding = model_params['n_padded']
    center_padded = model_params['center_padded']
    n_filters = 39 if (feats_type == 'mfcc') else 40

    if n_padding is None or model_params['model_version'] == 'dtw':
        n_padding = 110  # pad to longest segment length in TIDigits (DTW)
        center_padded = False

    # --------------------------------------------------------------------------
    # Load test dataset:
    # --------------------------------------------------------------------------
    if ARGS.test_set == 'flickr-audio':  # load Flickr-Audio test set
        logging.info(
            "Testing speech model on dataset: {}".format('flickr-audio'))
        flickr_data = data.load_flickraudio(path=os.path.join(
            ARGS.data_dir, 'flickr_audio.npz'),
                                            feats_type=feats_type,
                                            remove_labels=TIDIGITS_INTERSECTION
                                            )  # remove digit words from flickr
        test_data = flickr_data[2]

    else:  # load TIDigits (default) test set
        logging.info("Testing speech model on dataset: {}".format('tidigits'))
        tidigits_data = data.load_tidigits(path=os.path.join(
            ARGS.data_dir, 'tidigits_audio.npz'),
                                           feats_type=feats_type)
        test_data = tidigits_data[2]

    # --------------------------------------------------------------------------
    # Data processing pipeline (placed on CPU so GPU is free):
    # --------------------------------------------------------------------------
    with tf.device('/cpu:0'):
        # --------------------------------------------
        # Create few-shot test dataset pipeline:
        # --------------------------------------------
        x_test = test_data[0]
        y_test = test_data[1]
        z_test = test_data[2]
        x_test_placeholder = tf.placeholder(TF_FLOAT,
                                            shape=[None, n_filters, n_padding])
        y_test_placeholder = tf.placeholder(tf.string, shape=[None])
        z_test_placeholder = tf.placeholder(tf.string, shape=[None])
        # Preprocess speech data
        x_test = data.pad_sequences(x_test,
                                    n_padding,
                                    center_padded=center_padded)
        x_test = np.swapaxes(x_test, 2, 1)  # swap to (n_filters, n_pad)
        # Add single depth channel to feature image so it is a 'grayscale image'
        x_test_with_depth = tf.expand_dims(x_test_placeholder, axis=-1)
        # Split data into disjoint support and query sets
        x_test_split, y_test_split, z_test_split = (data.make_train_test_split(
            x_test_with_depth,
            y_test_placeholder,
            z_test_placeholder,
            test_ratio=0.5,
            shuffle=True,
            seed=ARGS.random_seed))
        # Batch episodes of support and query sets for few-shot testing
        test_pipeline = (data.batch_few_shot_episodes(
            x_support_data=x_test_split[0],
            y_support_labels=y_test_split[0],
            z_support_originators=z_test_split[0],
            x_query_data=x_test_split[1],
            y_query_labels=y_test_split[1],
            z_query_originators=z_test_split[1],
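            # ('same'/'difficult' constrains episode speakers; used with
            # test_invariance below to test speaker invariance)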
            originator_type=ARGS.originator_type,
            k_shot=ARGS.k_shot,
            l_way=ARGS.l_way,
            n_queries=ARGS.n_queries,
            seed=ARGS.random_seed))
        test_pipeline = test_pipeline.prefetch(1)  # prefetch 1 batch per step

        # Create pipeline iterator
        test_iterator = test_pipeline.make_initializable_iterator()
        test_feed_dict = {
            x_test_placeholder: x_test,
            y_test_placeholder: y_test,
            z_test_placeholder: z_test
        }

    # --------------------------------------------------------------------------
    # Build and test model:
    # --------------------------------------------------------------------------
    # Build selected model version from base/loaded model params dict
    model_embedding, embed_input, train_flag, _, _ = (
        speech.build_speech_model(model_params, training=False))
    # Build few-shot 1-Nearest Neighbour memory comparison model
    query_input = tf.placeholder(TF_FLOAT, shape=[None, None])
    support_memory_input = tf.placeholder(TF_FLOAT, shape=[None, None])
    model_nn_memory = nearest_neighbour.fast_knn_cos(
        q_batch=query_input,
        m_keys=support_memory_input,
        k_nn=1,
        normalize=True)
    # Check if test using Dynamic Time Warping instead of the memory model above
    test_dtw = False
    dtw_cost_func = None
    dtw_post_process = None
    if model_params['model_version'] == 'dtw':
        test_dtw = True
        dtw_cost_func = speech_dtw.multivariate_dtw_cost_cosine
        dtw_post_process = lambda x: np.ascontiguousarray(  # as cython C-order
            np.swapaxes(  # time on x-axis for DTW
                _get_unpadded_image(np.squeeze(x, axis=-1), n_padding), 1, 0),
            dtype=float)
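    # (strips padding, swaps to time-major order, and forces a C-contiguous
    # array, as the cython DTW cost function expects)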
    # Test the few-shot model ...
    test_few_shot_model(  # Test params:
        train_flag=train_flag,
        test_iterator=test_iterator,
        test_feed_dict=test_feed_dict,
        model_embedding=model_embedding,
        embed_input=embed_input,
        query_input=query_input,
        support_memory_input=support_memory_input,
        nearest_neighbour=model_nn_memory,
        n_episodes=ARGS.n_test_episodes,
        test_dtw=test_dtw,
        dtw_cost_func=dtw_cost_func,
        dtw_post_process=dtw_post_process,
        test_invariance=(ARGS.originator_type == 'same'
                         or ARGS.originator_type == 'difficult'),
        # Other params:
        log_interval=int(ARGS.n_test_episodes / 10),
        model_dir=model_dir,
        output_dir=test_model_dir,
        summary_dir='summaries/test',
        restore_checkpoint=ARGS.restore_checkpoint)