Example #1
def LD_alt(context_size, history_sizes, input_shape, recurrent_name,
           pooling_name, n_hidden, dropout_rate, path_to_results,
           is_base_network, with_embdedding_layer, word_vectors_name,
           fine_tune_word_vectors, word_vectors, with_extra_features):
    """
    This implementation is similar to the paper's structure, but not exactly the same.

    Three *non-shared* S2V models will be created for three inputs (X_i-2, X_i-1, X_i).

    e.g. history_sizes = (1,1)
    Outputs from previous layer (s_i-2, s_i-1, s_i)
    will be *concatenated* to s_i-2s_i-1, s_i-1s_i,
    then feed to two *non-shared* ordinary Dense layers

    :param context_size:
    :param history_sizes: refers to the caption of Figue.2 in the paper

    :param word_vectors:
    :param input_length:
    :param layer_type:
    :param pooling_type:
    :param n_hidden:
    :param dropout_rate:
    :return:
    """
    if with_extra_features:
        inputs = [[
            Input(shape=input_shape, dtype='int32'),
            Input(shape=(21, ), dtype='float32')
        ] for _ in range(context_size)]
    else:
        inputs = [
            Input(shape=input_shape, dtype='int32')
            for _ in range(context_size)
        ]

    S2Vs = [
        S2V(input_shape,
            recurrent_name,
            pooling_name,
            n_hidden,
            dropout_rate,
            path_to_results,
            is_base_network,
            with_embdedding_layer,
            word_vectors_name,
            fine_tune_word_vectors,
            word_vectors,
            with_extra_features,
            with_last_f_f_layer=False) for _ in range(context_size)
    ]

    outputs_queue = [[S2Vs[i](inputs[i]) for i in range(context_size)]]

    # loop over the network levels
    for history_size in history_sizes:
        # [s1,s2,s3] -> [[s1], [s1, s2], [s1, s2, s3], [s2], [s2, s3], [s3]]
        tmp = [
            outputs_queue[-1][i:j] for i, j in itertools.combinations(
                range(len(outputs_queue[-1]) + 1), 2)
        ]  # outputs_queue[-1] refers to the outputs of the last layer

        if history_size > 0:
            # e.g. history_size=1 -> [[s1, s2], [s2, s3]] -> [s1s2, s2s3]
            concatenateds = [
                concatenate(ss) for ss in tmp if len(ss) == history_size + 1
            ]
        else:
            # history_size=0 -> [[s1], [s2], [s3]] -> [s1, s2, s3]
            concatenateds = [
                ss[0] for ss in tmp if len(ss) == history_size + 1
            ]

        # Define FF
        FFs = [
            Dense(units=n_hidden, activation='tanh')
            for _ in range(len(concatenateds))
        ]

        # add new outputs to outputs_queue
        outputs_queue.append(
            [FFs[i](concatenateds[i]) for i in range(len(concatenateds))])

    # The network should have a single output if history_sizes is well-defined
    assert 1 == len(outputs_queue[-1])

    if with_extra_features:
        inputs = utlis.flatten(inputs)

    model = Model(inputs, outputs_queue[-1], name='base_network')

    plot_model(model,
               show_shapes=True,
               to_file=path_to_results + 'base_network.png')
    model.summary()

    return model
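The network-level loop in LD_alt builds its groups with a compact slicing idiom; a small standalone sketch (plain Python, values are only illustrative) shows how itertools.combinations over slice boundaries enumerates every contiguous group of the previous level's outputs:

# Standalone illustration of the grouping used in the loop above: pairs of
# slice boundaries enumerate every contiguous, non-empty slice of the list.
import itertools

outputs = ['s1', 's2', 's3']
groups = [outputs[i:j]
          for i, j in itertools.combinations(range(len(outputs) + 1), 2)]
print(groups)
# [['s1'], ['s1', 's2'], ['s1', 's2', 's3'], ['s2'], ['s2', 's3'], ['s3']]

# with history_size = 1, only slices of length history_size + 1 are kept
print([g for g in groups if len(g) == 2])
# [['s1', 's2'], ['s2', 's3']]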
Example #2
def LD(context_size, history_sizes, input_shape, recurrent_name, pooling_name,
       n_hidden, dropout_rate, path_to_results, is_base_network,
       with_embdedding_layer, word_vectors_name, fine_tune_word_vectors,
       word_vectors, with_extra_features):
    """
    This implementation is the same with the paper's description

    Single *shared* S2V model will be created for three inputs (X_i-2, X_i-1, X_i).

    e.g. history_sizes = (1,1)
    Outputs from previous layer (s_i-2, s_i-1, s_i)

    p2 = dot(s_i-2, W_-2), p1 = dot(s_i-1, W_-1), p0=dot(s_i, W_0)
    # realised with Dense(activation=None, use_bias=False)

    y_i-1 = tanh(add(p2+p1)+bias_of_layer_1)
    y_i = tanh(add(p1+p0)+bias_of_layer_1)
    # *shared* bias_of_layer_1 is realised with custom Bias layer

    :param context_size:
    :param history_sizes: refers to the caption of Figue.2 in the paper

    :param word_vectors:
    :param input_length:
    :param layer_type:
    :param pooling_type:
    :param n_hidden:
    :param dropout_rate:
    :return:
    """
    if with_extra_features:
        inputs = [[
            Input(shape=input_shape, dtype='int32'),
            Input(shape=(21, ), dtype='float32')
        ] for _ in range(context_size)]
    else:
        inputs = [
            Input(shape=input_shape, dtype='int32')
            for _ in range(context_size)
        ]

    s2v = S2V(input_shape,
              recurrent_name,
              pooling_name,
              n_hidden,
              dropout_rate,
              path_to_results,
              is_base_network,
              with_embdedding_layer,
              word_vectors_name,
              fine_tune_word_vectors,
              word_vectors,
              with_extra_features,
              with_last_f_f_layer=False)

    outputs_queue = [[s2v(inputs[i]) for i in range(context_size)]]

    # loop over the network levels
    for history_size in history_sizes:
        activations = [
            Dense(n_hidden, activation=None,
                  use_bias=False)(outputs_queue[-1][i])
            for i in range(len(outputs_queue[-1]))
        ]  # outputs_queue[-1] refers to the outputs of the last layer

        # [w1s1,w2s2,w3s3] -> [[w1s1], [w1s1, w2s2], [w1s1, w2s2, w3s3], [w2s2], [w2s2, w3s3], [w3s3]]
        tmp = [
            activations[i:j]
            for i, j in itertools.combinations(range(len(activations) + 1), 2)
        ]

        if history_size > 0:
            # history_size=1 -> [[w1s1, w2s2], [w2s2, w3s3]] -> [w1s1+w2s2, w2s2+w3s3]
            adds = [add(ss) for ss in tmp if len(ss) == history_size + 1]
        else:
            # history_size=0 -> [[w1s1], [w2s2], [w3s3]] -> [w1s1, w2s2, w3s3]
            adds = [ss[0] for ss in tmp if len(ss) == history_size + 1]

        # Define FF
        a = Activation('tanh')
        b = Bias()

        # add new outputs to outputs_queue
        outputs_queue.append([a(b(adds[i])) for i in range(len(adds))])

    # The network should have a single output if history_sizes is well-defined
    assert 1 == len(outputs_queue[-1])

    if with_extra_features:
        inputs = utlis.flatten(inputs)

    model = Model(inputs, outputs_queue[-1], name='base_network')

    plot_model(model,
               show_shapes=True,
               to_file=path_to_results + 'base_network.png')
    model.summary()

    return model
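LD relies on a custom Bias layer whose implementation is not shown in these examples; the docstring only says the shared bias is realised with a custom Bias layer. A minimal sketch of what such a layer could look like (an assumption, not the repository's actual code):

# Hypothetical sketch of the custom Bias layer referenced in LD's docstring:
# it only adds a trainable bias vector, so tanh(add(p1, p0) + bias) can be
# built from a shared Bias instance followed by a parameter-free Activation.
from keras import backend as K
from keras.layers import Layer


class Bias(Layer):
    def build(self, input_shape):
        # one trainable bias per feature dimension, shared across all calls
        self.bias = self.add_weight(name='bias',
                                    shape=(input_shape[-1],),
                                    initializer='zeros',
                                    trainable=True)
        super(Bias, self).build(input_shape)

    def call(self, inputs):
        return K.bias_add(inputs, self.bias)

    def compute_output_shape(self, input_shape):
        return input_shape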
Example #3
def TIXIER(pre_context_size, post_context_size, input_shape, recurrent_name,
           pooling_name, n_hidden, dropout_rate, path_to_results,
           is_base_network, with_embdedding_layer, word_vectors_name,
           fine_tune_word_vectors, word_vectors, with_extra_features):
    if pre_context_size <= 0:
        raise ValueError('pre_context_size should be greater than 0')

    context_size = pre_context_size + 1 + post_context_size
    if with_extra_features:
        inputs = [[
            Input(shape=input_shape, dtype='int32'),
            Input(shape=(21, ), dtype='float32')
        ] for _ in range(context_size)]
    else:
        inputs = [
            Input(shape=input_shape, dtype='int32')
            for _ in range(context_size)
        ]

    sub_model = get_sub_model(input_shape, n_hidden, dropout_rate,
                              word_vectors_name, fine_tune_word_vectors,
                              word_vectors, with_extra_features)
    attention_layer = AttentionWithVec(attend_mode='sum')
    # convert a list of 2D tensors to a 3D tensor (None, number of stacked contexts, n_hidden)
    stack_layer = Lambda(K.stack, arguments={'axis': 1})
    expand_dims_layer = Lambda(K.expand_dims, arguments={'axis': 1})

    current = sub_model(inputs[pre_context_size])
    current_pooled = Sum()(current)

    pre = expand_dims_layer(
        AttentionWithTimeDecay(reverse_decay=True)(stack_layer([
            attention_layer([sub_model(inputs[i]), current_pooled])
            for i in range(pre_context_size)
        ])))
    padded = [pre, current]

    if post_context_size > 0:
        post = expand_dims_layer(
            AttentionWithTimeDecay(reverse_decay=False)(stack_layer([
                attention_layer([sub_model(inputs[i]), current_pooled])
                for i in range(pre_context_size + 1, context_size)
            ])))
        padded.append(post)

    recurrent_layer = S2V.get_recurrent_layer(name=recurrent_name,
                                              n_hidden=n_hidden,
                                              return_sequences=True)
    pooling_layer = S2V.get_pooling_layer(name=pooling_name)
    f_f_layer = Dense(units=n_hidden, activation='tanh')

    outputs = f_f_layer(
        pooling_layer(recurrent_layer(ConcatenateContexts(axis=1)(padded))))

    if with_extra_features:
        inputs = utlis.flatten(inputs)

    model = Model(inputs, outputs, name='base_network')

    plot_model(model,
               show_shapes=True,
               to_file=path_to_results + 'base_network.png')
    model.summary()

    return model
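The two Lambda wrappers in TIXIER (and the stack_layer reused in HAN below) only rearrange shapes. A short shape check, with illustrative values and assuming the same Keras backend as the code above:

# K.stack turns a list of (None, n_hidden) tensors into one
# (None, number_of_contexts, n_hidden) tensor; K.expand_dims adds a length-1 axis.
from keras import backend as K
from keras.layers import Input, Lambda

n_hidden = 32         # illustrative value
pre_context_size = 3  # illustrative value

vectors = [Input(shape=(n_hidden,)) for _ in range(pre_context_size)]

stack_layer = Lambda(K.stack, arguments={'axis': 1})
print(K.int_shape(stack_layer(vectors)))            # (None, 3, 32)

expand_dims_layer = Lambda(K.expand_dims, arguments={'axis': 1})
print(K.int_shape(expand_dims_layer(vectors[0])))   # (None, 1, 32)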
Example #4
def HAN(context_size, input_shape, recurrent_name, pooling_name, n_hidden,
        dropout_rate, path_to_results, is_base_network, with_embdedding_layer,
        word_vectors_name, fine_tune_word_vectors, word_vectors,
        with_extra_features):
    if with_extra_features:
        inputs = [[
            Input(shape=input_shape, dtype='int32'),
            Input(shape=(21, ), dtype='float32')
        ] for _ in range(context_size)]
    else:
        inputs = [
            Input(shape=input_shape, dtype='int32')
            for _ in range(context_size)
        ]

    s2v_1 = S2V(input_shape,
                recurrent_name,
                pooling_name,
                n_hidden,
                dropout_rate,
                path_to_results,
                is_base_network,
                with_embdedding_layer,
                word_vectors_name,
                fine_tune_word_vectors,
                word_vectors,
                with_extra_features,
                with_last_f_f_layer=False)

    # convert list of 2D tensors to a 3D tensor (None, context_size, n_hidden)
    stack_layer = Lambda(K.stack, arguments={'axis': 1})

    # the first-level recurrent layer is bidirectional with merge_mode='concat',
    # so each sentence vector has dimension n_hidden * 2
    input_shape = (context_size, n_hidden * 2)

    s2v_2 = S2V(input_shape,
                recurrent_name,
                pooling_name,
                n_hidden,
                dropout_rate,
                path_to_results,
                is_base_network,
                with_embdedding_layer=False,
                word_vectors_name=None,
                fine_tune_word_vectors=None,
                word_vectors=None,
                with_extra_features=False,
                with_last_f_f_layer=False)

    f_f_layer = Dense(units=n_hidden, activation='tanh')

    output = f_f_layer(
        s2v_2(stack_layer([s2v_1(inputs[i]) for i in range(context_size)])))

    if with_extra_features:
        inputs = utlis.flatten(inputs)

    model = Model(inputs, output, name='base_network')

    plot_model(model,
               show_shapes=True,
               to_file=path_to_results + 'base_network.png')
    model.summary()

    return model
Example #5
def train(main_network_name, word_vectors_name, fine_tune_word_vectors,
          with_extra_features, base_network_name, epochs, loss, distance,
          l2_normalization, pre_context_size, post_context_size,
          data_generator_train, X_validation, Y_validation,
          max_sequence_length, word_vectors, path_to_results, global_metrics):
    context_size = pre_context_size + 1 + post_context_size

    n_hidden = 32
    dropout_rate = 0.5

    loss_function = losses_distances.get_loss_function(name=loss)
    distance_function = losses_distances.get_distance_function(name=distance)
    optimizer = Adam()
    metrics = None  # batch-wise metrics, ['accuracy']

    os.mkdir(path_to_results + 'model_on_epoch_end')
    model_checkpoint = ModelCheckpoint(
        filepath=path_to_results + 'model_on_epoch_end/' + '{epoch}.h5',
        monitor='val_loss',
        verbose=1,
        save_best_only=False,
        save_weights_only=False,
        mode='min',
        period=1)
    callbacks = [global_metrics, model_checkpoint]

    # define inputs
    inputs = []

    if with_extra_features:
        for _ in range(utlis.n_tuple(main_network_name)):
            inputs.append(utlis.flatten([
                [Input(shape=(max_sequence_length,), dtype='int32'),
                 Input(shape=(21,), dtype='float32')]
                for _ in range(context_size)
            ]))
    else:
        for _ in range(utlis.n_tuple(main_network_name)):
            inputs.append([
                Input(shape=(max_sequence_length,), dtype='int32')
                for _ in range(context_size)
            ])

    # define sentence_encoder
    if base_network_name == 'LD':
        base_network = LD(
            context_size=context_size,
            history_sizes=(context_size-1, 0),

            input_shape=(max_sequence_length,),
            recurrent_name='LSTM',
            pooling_name='max',
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            path_to_results=path_to_results,
            is_base_network=False,

            with_embdedding_layer=True,
            word_vectors_name=word_vectors_name,
            fine_tune_word_vectors=fine_tune_word_vectors,
            word_vectors=word_vectors,

            with_extra_features=with_extra_features
        )
    elif base_network_name == 'HAN':
        base_network = HAN(
            context_size=context_size,

            input_shape=(max_sequence_length,),
            recurrent_name='Bi-GRU',
            pooling_name='attention',
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            path_to_results=path_to_results,
            is_base_network=False,

            with_embdedding_layer=True,
            word_vectors_name=word_vectors_name,
            fine_tune_word_vectors=fine_tune_word_vectors,
            word_vectors=word_vectors,

            with_extra_features=with_extra_features
        )
    elif base_network_name == 'TIXIER':
        base_network = TIXIER(
            pre_context_size=pre_context_size,
            post_context_size=post_context_size,

            input_shape=(max_sequence_length,),
            recurrent_name='Bi-GRU',
            pooling_name='attention',
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            path_to_results=path_to_results,
            is_base_network=False,

            with_embdedding_layer=True,
            word_vectors_name=word_vectors_name,
            fine_tune_word_vectors=fine_tune_word_vectors,
            word_vectors=word_vectors,

            with_extra_features=with_extra_features
        )
    else:
        base_network = S2V(
            input_shape=(max_sequence_length,),
            recurrent_name=base_network_name,
            pooling_name='attention',
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            path_to_results=path_to_results,
            is_base_network=True,

            with_embdedding_layer=True,
            word_vectors_name=word_vectors_name,
            fine_tune_word_vectors=fine_tune_word_vectors,
            word_vectors=word_vectors,

            with_extra_features=with_extra_features
        )

    outputs = [base_network(branch_inputs) for branch_inputs in inputs]

    # nudge one branch by K.epsilon() so that identical inputs never yield
    # an exactly-zero distance
    outputs[1] = Lambda(lambda x: x + K.epsilon())(outputs[1])

    if l2_normalization:
        outputs = [Lambda(lambda x: K.l2_normalize(x, axis=-1))(output) for output in outputs]

    if main_network_name == 'siamese':
        d = Lambda(distance_function)([outputs[0], outputs[1]])

        model = Model(utlis.flatten(inputs), d)
    elif main_network_name == 'triplet':
        d_pos = Lambda(distance_function)([outputs[1], outputs[0]])
        d_neg = Lambda(distance_function)([outputs[1], outputs[2]])

        model = Model(utlis.flatten(inputs), concatenate([d_pos, d_neg]))
    elif main_network_name == 'quadruplet':
        d_pos = Lambda(distance_function)([outputs[1], outputs[0]])
        d_neg = Lambda(distance_function)([outputs[1], outputs[2]])
        d_neg_extra = Lambda(distance_function)([outputs[2], outputs[3]])

        model = Model(utlis.flatten(inputs), concatenate([d_pos, d_neg, d_neg_extra]))
    else:
        raise NotImplementedError()

    from keras.utils import plot_model
    plot_model(model,
               show_shapes=True,
               to_file=path_to_results + 'main_network.png')
    model.summary()

    # https://datascience.stackexchange.com/questions/23895/multi-gpu-in-keras
    # https://keras.io/getting-started/faq/#how-can-i-run-a-keras-model-on-multiple-gpus
    # automatically detect and use *Data parallelism* gpus model
    # try:
    #     model = multi_gpu_model(model, gpus=None)
    # except:
    #     pass

    model.compile(loss=loss_function, optimizer=optimizer, metrics=metrics)

    training_start_time = time()

    history = model.fit_generator(
        data_generator_train,
        epochs=epochs,
        validation_data=(X_validation, Y_validation),
        callbacks=callbacks
    )

    print("Training time finished.\n{} epochs in {}".format(epochs, datetime.timedelta(seconds=time() - training_start_time)))

    return {**history.history, **global_metrics.val_scores}
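The actual losses_distances module is not included in these examples. As a point of reference only, here is a sketch of a distance function and loss that would fit the siamese branch above (one distance per pair, contrastive loss over that distance); the names euclidean_distance and contrastive_loss are hypothetical:

# Hypothetical distance/loss pair compatible with the siamese branch above.
from keras import backend as K


def euclidean_distance(tensors):
    # tensors is the [outputs[0], outputs[1]] list passed in by the Lambda layer
    x, y = tensors
    return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=-1, keepdims=True),
                            K.epsilon()))


def contrastive_loss(y_true, d_pred, margin=1.0):
    # y_true = 1 for genuine pairs, 0 for impostor pairs (see generate_tuples)
    return K.mean(y_true * K.square(d_pred) +
                  (1.0 - y_true) * K.square(K.maximum(margin - d_pred, 0.0)))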
Example #6
def generate_tuples(path_to_summlink, utterances, n_tuple, meeting_list,
                    pre_context_size, post_context_size, training):
    # convert a singleton community (a community containing a single utterance) to a tuple by repeating its utterance
    singleton_community_to_tuple = True

    tuples_of_meetings = {}
    for meeting_id in meeting_list:
        utterance_ids = list(utterances[meeting_id].keys())
        contextualized_utterance_ids = get_contextualized_utterance_ids(
            utterance_ids, pre_context_size, post_context_size)
        communities = load_communities(path_to_summlink, meeting_id,
                                       utterance_ids)

        # generate pairs
        if n_tuple == 2:
            genuine_label = 1
            impostor_label = 0

            genuine_pairs = []
            for community in communities:
                if len(community) == 1:
                    if singleton_community_to_tuple is True:
                        genuine_pairs += [tuple(community * 2)]
                    else:
                        pass
                else:
                    genuine_pairs += list(itertools.combinations(community, 2))

            all_pairs = list(
                itertools.combinations(list(set(utlis.flatten(communities))),
                                       2))

            impostor_pairs = list(
                set(all_pairs) - set(genuine_pairs) -
                set([(t[1], t[0]) for t in genuine_pairs]))

            volume = len(list(itertools.combinations(communities, 2)))
            if training is False:
                volume *= 1  # no-op: the evaluation-time multiplier is left at 1

            replacement = False
            if volume > len(genuine_pairs):
                replacement = True

            genuine_pairs = list(
                np.array(genuine_pairs)[np.random.choice(len(genuine_pairs),
                                                         size=volume,
                                                         replace=replacement)])
            impostor_pairs = list(
                np.array(impostor_pairs)[np.random.choice(len(impostor_pairs),
                                                          size=volume,
                                                          replace=False)])

            tuples_of_meetings[meeting_id] = \
                [(contextualized_utterance_ids[pair[0]], contextualized_utterance_ids[pair[1]], genuine_label) for pair in genuine_pairs] +\
                [(contextualized_utterance_ids[pair[0]], contextualized_utterance_ids[pair[1]], impostor_label) for pair in impostor_pairs]

        # generate triplets
        elif n_tuple == 3:
            all_tuples = []

            for community_i, community_j in itertools.combinations(
                    communities, 2):
                set_i = set(community_i)
                set_j = set(community_j)
                tuples = []

                if len(set_i & set_j) == 0:
                    if len(community_i) == 1:
                        if singleton_community_to_tuple is True:
                            genuine_pairs = [tuple(community_i * 2)]
                            tuples += list(
                                itertools.product(genuine_pairs, community_j))
                        else:
                            pass
                    else:
                        genuine_pairs = list(
                            itertools.permutations(community_i, 2))
                        tuples += list(
                            itertools.product(genuine_pairs, community_j))

                    if len(community_j) == 1:
                        if singleton_community_to_tuple is True:
                            genuine_pairs = [tuple(community_j * 2)]
                            tuples += list(
                                itertools.product(genuine_pairs, community_i))
                        else:
                            pass
                    else:
                        genuine_pairs = list(
                            itertools.permutations(community_j, 2))
                        tuples += list(
                            itertools.product(genuine_pairs, community_i))

                else:
                    if set_i.issubset(set_j):
                        A = list(set_j - set_i)
                        B = list(set_i)
                        D = []
                        for community in communities:
                            if len(set(community).intersection(set_j)) == 0:
                                D += community

                        if len(A) == 1:
                            genuine_pairs = [tuple(A * 2)]
                            tuples += list(itertools.product(genuine_pairs, B))
                        else:
                            genuine_pairs = list(itertools.permutations(A, 2))
                            tuples += list(itertools.product(genuine_pairs, B))

                        if len(B) == 1:
                            genuine_pairs = [tuple(B * 2)]
                            tuples += list(itertools.product(genuine_pairs, A))
                        else:
                            genuine_pairs = list(itertools.permutations(B, 2))
                            tuples += list(itertools.product(genuine_pairs, A))

                        genuine_pairs = list(itertools.product(A, B)) + list(
                            itertools.product(B, A))
                        tuples += list(itertools.product(genuine_pairs, D))
                    elif set_j.issubset(set_i):
                        A = list(set_i - set_j)
                        B = list(set_j)
                        D = []
                        for community in communities:
                            if len(set(community).intersection(set_i)) == 0:
                                D += community

                        if len(A) == 1:
                            genuine_pairs = [tuple(A * 2)]
                            tuples += list(itertools.product(genuine_pairs, B))
                        else:
                            genuine_pairs = list(itertools.permutations(A, 2))
                            tuples += list(itertools.product(genuine_pairs, B))

                        if len(B) == 1:
                            genuine_pairs = [tuple(B * 2)]
                            tuples += list(itertools.product(genuine_pairs, A))
                        else:
                            genuine_pairs = list(itertools.permutations(B, 2))
                            tuples += list(itertools.product(genuine_pairs, A))

                        genuine_pairs = list(itertools.product(A, B)) + list(
                            itertools.product(B, A))
                        tuples += list(itertools.product(genuine_pairs, D))
                    else:
                        A = list(set_i - set_j)
                        B = list(set_i & set_j)
                        C = list(set_j - set_i)
                        D = []
                        for community in communities:
                            if len(set(community).intersection(set_i
                                                               | set_j)) == 0:
                                D += community

                        if len(A) == 1:
                            genuine_pairs = [tuple(A * 2)]
                            tuples += list(itertools.product(genuine_pairs, B))
                        else:
                            genuine_pairs = list(itertools.permutations(A, 2))
                            tuples += list(itertools.product(genuine_pairs, B))

                        if len(B) == 1:
                            genuine_pairs = [tuple(B * 2)]
                            tuples += list(itertools.product(genuine_pairs, A))
                        else:
                            genuine_pairs = list(itertools.permutations(B, 2))
                            tuples += list(itertools.product(genuine_pairs, A))

                        genuine_pairs = list(itertools.product(A, B)) + list(
                            itertools.product(B, A))
                        tuples += list(itertools.product(genuine_pairs, D))

                        if len(C) == 1:
                            genuine_pairs = [tuple(C * 2)]
                            tuples += list(itertools.product(genuine_pairs, B))
                        else:
                            genuine_pairs = list(itertools.permutations(C, 2))
                            tuples += list(itertools.product(genuine_pairs, B))

                        if len(B) == 1:
                            genuine_pairs = [tuple(B * 2)]
                            tuples += list(itertools.product(genuine_pairs, C))
                        else:
                            genuine_pairs = list(itertools.permutations(B, 2))
                            tuples += list(itertools.product(genuine_pairs, C))

                        genuine_pairs = list(itertools.product(C, B)) + list(
                            itertools.product(B, C))
                        tuples += list(itertools.product(genuine_pairs, D))

                        if len(A) == 1:
                            genuine_pairs = [tuple(A * 2)]
                            tuples += list(itertools.product(genuine_pairs, C))
                        else:
                            genuine_pairs = list(itertools.permutations(A, 2))
                            tuples += list(itertools.product(genuine_pairs, C))

                        if len(C) == 1:
                            genuine_pairs = [tuple(C * 2)]
                            tuples += list(itertools.product(genuine_pairs, A))
                        else:
                            genuine_pairs = list(itertools.permutations(C, 2))
                            tuples += list(itertools.product(genuine_pairs, A))

                        genuine_pairs = list(itertools.product(B, A))
                        tuples += list(itertools.product(genuine_pairs, C))

                        genuine_pairs = list(itertools.product(B, C))
                        tuples += list(itertools.product(genuine_pairs, A))
                if training:
                    all_tuples += list(
                        np.array(tuples)[np.random.choice(len(tuples),
                                                          size=1,
                                                          replace=False)])
                else:
                    all_tuples += list(
                        np.array(tuples)[np.random.choice(len(tuples),
                                                          size=1,
                                                          replace=True)])

            tuples_of_meetings[meeting_id] = [
                (contextualized_utterance_ids[triplet[0][0]],
                 contextualized_utterance_ids[triplet[0][1]],
                 contextualized_utterance_ids[triplet[1]])
                for triplet in all_tuples
            ]

        # generate quadruplets
        elif n_tuple == 4:
            pass
        else:
            raise NotImplementedError()

    return tuples_of_meetings
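For readers tracing the n_tuple == 2 branch, a toy run (hypothetical utterance ids) of the genuine-pair construction above:

# Toy illustration of the genuine-pair logic for n_tuple == 2 with
# singleton_community_to_tuple = True: a singleton community becomes a
# self-pair, larger communities contribute all their 2-combinations.
import itertools

communities = [['u1', 'u2'], ['u3']]

genuine_pairs = []
for community in communities:
    if len(community) == 1:
        genuine_pairs += [tuple(community * 2)]
    else:
        genuine_pairs += list(itertools.combinations(community, 2))

print(genuine_pairs)  # [('u1', 'u2'), ('u3', 'u3')]
# impostor pairs are then sampled from the remaining cross-community pairs
# (here the pairs linking u3 with u1 and u2)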
Example #7
if base_network_name in ['LD', 'HAN', 'TIXIER']:
    pre_context_size = 3  # number of previous utterances
    post_context_size = 0  # number of following utterances
else:
    pre_context_size = 0
    post_context_size = 0
####################
l1 = set([file_name.split('.')[0] for file_name in os.listdir(path_to_utterance)])
l2 = set([file_name.split('.')[0] for file_name in os.listdir(path_to_summlink)])
meeting_list = list(l1.intersection(l2))  # {'IB4003', 'TS3012c'}

# 60%, 20%, 20% -> 81, 28, 28 meetings
# meeting_list_train, meeting_list_test = train_test_split(meeting_list, test_size=0.2)
# meeting_list_train, meeting_list_validation = train_test_split(meeting_list_train, test_size=0.25)
meeting_list_train = ['ES2002', 'ES2005', 'ES2006', 'ES2007', 'ES2008', 'ES2009', 'ES2010', 'ES2012', 'ES2013', 'ES2015', 'ES2016', 'IS1000', 'IS1001', 'IS1002', 'IS1003', 'IS1004', 'IS1005', 'IS1006', 'IS1007', 'TS3005', 'TS3008', 'TS3009', 'TS3010', 'TS3011', 'TS3012']
meeting_list_train = utlis.flatten([[mid+c for c in 'abcd'] for mid in meeting_list_train])
meeting_list_train.remove('IS1002a')
meeting_list_train.remove('IS1005d')
meeting_list_train.remove('TS3012c')

meeting_list_validation = ['ES2003', 'ES2011', 'IS1008', 'TS3004', 'TS3006']
meeting_list_validation = utlis.flatten([[mid+c for c in 'abcd'] for mid in meeting_list_validation])

meeting_list_test = ['ES2004', 'ES2014', 'IS1009', 'TS3003', 'TS3007']
meeting_list_test = utlis.flatten([[mid+c for c in 'abcd'] for mid in meeting_list_test])
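Each meeting id above is expanded into its a/b/c/d parts; a quick illustration (utlis.flatten is assumed to flatten one level of nesting):

# Illustration of the expansion used for the train/validation/test lists above.
meeting_ids = ['ES2004', 'ES2014']
expanded = [mid + c for mid in meeting_ids for c in 'abcd']
print(expanded)
# ['ES2004a', 'ES2004b', 'ES2004c', 'ES2004d',
#  'ES2014a', 'ES2014b', 'ES2014c', 'ES2014d']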


dill.dump_session(path_to_results + 'variables.pkl')
#dill.load_session(path_to_results + 'variables.pkl')

if word_vectors_name == 'News':