示例#1
0
def main():
    """Transcribe the first Flickr8k training batch with a trained
    speech-to-text model and print, for each utterance, the reference
    character transcription followed by the beam-search hypotheses
    (cumulative score and decoded text).
    """
    import vg.flickr8k_provider as dp_f
    import vg.simple_data as sd
    batch_size = 16
    prov_flickr = dp_f.getDataProvider('flickr8k', root='.', audio_kind='mfcc')
    data_flickr = sd.SimpleData(prov_flickr,
                                tokenize=sd.characters,
                                min_df=1,
                                scale=False,
                                batch_size=batch_size,
                                shuffle=False)
    net = torch.load(
        "experiments/s2-t.-s2i2-s2t.-t2s.-s2d0-d1-embed-128-joint-e/model.23.pkl"
    )
    # Compact the RNN weights into contiguous memory after unpickling
    # (avoids the cuDNN non-contiguous-weights warning and speeds up runs).
    net.SpeechTranscriber.TextDecoder.Decoder.RNN.flatten_parameters()
    net.SpeechTranscriber.SpeechEncoderBottom.RNN.flatten_parameters()
    batches = data_flickr.iter_train_batches()
    first = next(batches)
    texts = list(data_flickr.mapper.inverse_transform(first['input']))
    args = net.SpeechTranscriber.args(first)
    args = [torch.autograd.Variable(torch.from_numpy(x)).cuda() for x in args]
    audio, target_t, target_prev_t = args
    # Fix: iterate over the batch via batch_size instead of a duplicated
    # magic constant 16, so changing batch_size cannot desynchronize the loop.
    for j in range(batch_size):
        print(''.join(texts[j]))
        # Beam-search decode a single utterance (beam width 5, max 25 steps).
        for seq in transcribe(net, audio[j:j + 1], K=5, maxlen=25):
            vals, ids = zip(*seq)
            chars = list(data_flickr.mapper.inverse_transform([ids]))[0]
            text = ''.join(
                ['_' if char == '<BEG>' else char for char in chars])
            print("{:.2f} {}".format(sum(vals), text))
        print()
示例#2
0
def phoneme_data(nets,
                 alignment_path="./data/flickr8k/dataset.val.fa.json",
                 batch_size=64):
    """Generate data for training a phoneme decoding model.

    Parameters
    ----------
    nets : iterable of (name, net) pairs
        Models whose recurrent-state activations are extracted.
    alignment_path : str
        JSON-lines file with one forced-alignment record per sentence.
    batch_size : int
        Batch size used when computing network activations.

    Returns
    -------
    dict
        Maps 'mfcc' and each net name to the output of ``fa_data``.
    """
    import vg.flickr8k_provider as dp
    logging.getLogger().setLevel('INFO')
    logging.info("Loading alignments")
    data = {}
    # Fix: close the alignment file deterministically instead of leaking
    # the open handle.
    with open(alignment_path) as alignment_file:
        for line in alignment_file:
            item = json.loads(line)
            data[item['sentid']] = item
    logging.info("Loading audio features")
    prov = dp.getDataProvider('flickr8k', root='.', audio_kind='mfcc')
    val = list(prov.iterSentences(split='val'))
    alignments_all = [data[sent['sentid']] for sent in val]
    # Keep only utterances whose alignment has a start time for every word.
    alignments = [
        item for item in alignments_all
        if np.all([word.get('start', False) for word in item['words']])
    ]
    sentids = set(item['sentid'] for item in alignments)
    audio = [sent['audio'] for sent in val if sent['sentid'] in sentids]
    result = {}
    logging.info("Computing data for MFCC")
    y, X = phoneme_activations(audio, alignments, mfcc=True)
    result['mfcc'] = fa_data(y, X)
    for name, net in nets:
        logging.info("Computing data for {}".format(name))
        activations = vg.activations.get_state_stack(net,
                                                     audio,
                                                     batch_size=batch_size)
        y, X = phoneme_activations(activations, alignments, mfcc=False)
        result[name] = fa_data(y, X)
    return result
示例#3
0
def main():
    """Dump hidden-state stacks and raw transcriptions for the first 1000
    Flickr8k validation sentences to .npy files."""
    provider = dp.getDataProvider('flickr8k', root='.', audio_kind='mfcc')
    sentences = list(provider.iterSentences(split='val'))[:1000]
    model = torch.load("models/stack-s2-t.-s2i2-s2t.-t2s.-t2i.--f/model.23.pkl")
    audio_inputs = [item['audio'] for item in sentences]
    transcriptions = [item['raw'] for item in sentences]
    state_stack = vg.activations.get_state_stack(model, audio_inputs,
                                                 batch_size=16)
    np.save("state_stack_flickr8k_val.npy", state_stack)
    np.save("transcription_flickr8k_val.npy", transcriptions)
示例#4
0
def rsa_results():
    """Table 4 (tab:rsa)"""
    from sklearn.metrics.pairwise import cosine_similarity

    import vg.flickr8k_provider as dp_f
    logging.getLogger().setLevel('INFO')
    logging.info("Loading data")
    provider = dp_f.getDataProvider('flickr8k',
                                    root='..',
                                    audio_kind='mfcc')
    logging.info("Setting up scorer")
    scorer = S.Scorer(
        provider,
        dict(split='val', tokenize=lambda sent: sent['audio'], batch_size=16))
    batch_size = scorer.config['batch_size']
    # Reference similarity matrices: raw MFCC means, text, and images.
    logging.info("Computing MFCC similarity matrix")
    mean_mfcc = np.array([a.mean(axis=0) for a in scorer.sentence_data])
    sim = {
        'mfcc': cosine_similarity(mean_mfcc),
        'text': scorer.string_sim,
        'image': scorer.sim_images,
    }
    # Model-based similarity matrices.
    logging.info("Computing M1,s2i similarity matrix")
    net = load_best_run('{}/s2-t.-s2i2-s2t.-t2s.-t2i.'.format(PREFIX), cond='')
    sim['m1,s2i'] = cosine_similarity(
        S.encode_sentences(net, scorer.sentence_data, batch_size=batch_size))
    logging.info("Computing M6,s2i similarity matrix")
    net = load_best_run('{}/s2-t1-s2i2-s2t0-t2s0-t2i1'.format(PREFIX),
                        cond='joint')
    sim['m6,s2i'] = cosine_similarity(
        S.encode_sentences(net, scorer.sentence_data, batch_size=batch_size))
    logging.info("Computing M6,s2t similarity matrix")
    sim['m6,s2t'] = cosine_similarity(
        S.encode_sentences_SpeechText(net, scorer.sentence_data,
                                      batch_size=batch_size))
    logging.info("Computing RSA scores")
    rows = ['m1,s2i', 'm6,s2i', 'm6,s2t', 'image']
    cols = {
        col: [S.RSA(sim[row], sim[col]) for row in rows]
        for col in ['mfcc', 'text', 'image']
    }
    return pd.DataFrame(data=cols, index=rows)
示例#5
0
def phoneme_decoding_data(
        nets,
        alignment_path="../data/flickr8k/dataset.val.fa.json",
        dataset_path="../data/flickr8k/dataset.json",
        max_size=5000,
        directory="."):
    """Generate data for training a phoneme decoding model.

    Parameters
    ----------
    nets : iterable of (name, net) pairs
        Models whose recurrent-layer states are sliced into phoneme examples.
    alignment_path : str
        JSON-lines file with one forced-alignment record per sentence.
    dataset_path : str
        Kept for interface compatibility; not read by this function.
    max_size : int
        Maximum number of (alignment, sentence) pairs to process.
    directory : str
        Kept for interface compatibility; not read by this function.

    Returns
    -------
    dict
        Maps 'mfcc' and each net name to the output of ``fa_data``.
    """
    import vg.flickr8k_provider as dp

    logging.getLogger().setLevel('INFO')
    logging.info("Loading alignments")
    data = {}
    # Fix: close the alignment file deterministically instead of leaking
    # the open handle.
    with open(alignment_path) as alignment_file:
        for line in alignment_file:
            item = json.loads(line)
            data[item['sentid']] = item
    logging.info("Loading audio features")
    prov = dp.getDataProvider('flickr8k', root='..', audio_kind='mfcc')
    val = list(prov.iterSentences(split='val'))
    # Keep only utterances whose alignment has a start time for every word.
    data_filter = [(data[sent['sentid']], sent) for sent in val if np.all(
        [word.get('start', False) for word in data[sent['sentid']]['words']])]
    data_filter = data_filter[:max_size]
    data_state = [
        phoneme for (utt, sent) in data_filter
        for phoneme in slices(utt, sent['audio'])
    ]
    result = {}
    logging.info("Extracting MFCC examples")
    result['mfcc'] = fa_data(data_state)
    for name, net in nets:
        # Fixes: dropped the unused local `L = 1` and the redundant
        # `result[name] = {}` that was immediately overwritten below;
        # renamed local `S` to `stride` so it no longer shadows the
        # module-level name `S` used elsewhere in this file.
        stride = net.SpeechEncoderBottom.stride
        logging.info("Extracting recurrent layer states")
        audio = [sent['audio'] for utt, sent in data_filter]
        states = get_layer_states(net, audio, batch_size=32)
        layer = 0

        def aggregate(x):
            # Mean-pool the selected layer's states over time.
            return x[:, layer, :].mean(axis=0)

        data_state = [
            phoneme for i in range(len(data_filter))
            for phoneme in slices(data_filter[i][0],
                                  states[i],
                                  index=lambda x: index(x, stride=stride),
                                  aggregate=aggregate)
        ]
        result[name] = fa_data(data_state)
    return result
示例#6
0
# Parse command line parameters
parser = argparse.ArgumentParser()
parser.add_argument('path', metavar='path', help='Model\'s path', nargs='+')
parser.add_argument('-t',
                    help='Test mode',
                    dest='testmode',
                    action='store_true',
                    default=False)
args = parser.parse_args()

# Setup test mode: shrink the run to one epoch over a small data subset.
# NOTE(review): when -t is NOT given, `epochs` and `limit` (and
# `batch_size` used below) must already be bound earlier in the file,
# otherwise this script raises NameError — confirm against the full source.
if args.testmode:
    epochs = 1
    limit = 100

# Flickr8k data provider, character-tokenized, shuffled for training;
# `limit`/`limit_val` cap the number of training/validation items.
prov_flickr = dp_f.getDataProvider('flickr8k', root='..', audio_kind='mfcc')
data_flickr = sd.SimpleData(prov_flickr,
                            tokenize=sd.characters,
                            min_df=1,
                            scale=False,
                            batch_size=batch_size,
                            shuffle=True,
                            limit=limit,
                            limit_val=limit)

def get_audio(sent):
    """Return the audio features stored under the 'audio' key of *sent*."""
    features = sent['audio']
    return features


scorer = vg.scorer.ScorerASR(
示例#7
0
 def __init__(self, root='.', truncate=None):
     """Load the Places and Flickr8k data providers.

     Args:
         root: dataset root directory passed to both providers.
         truncate: optional limit forwarded to the Places provider only.
     """
     self.places = places.getDataProvider('places', root=root, truncate=truncate)
     self.flickr = flickr.getDataProvider('flickr8k', root=root)