Example #1
def run_train(w2vsource, w2vdim, w2vnumfilters, lexdim, lexnumfilters, randomseed, model_name, is_expanded,
              attention_depth_w2v, attention_depth_lex, num_epochs, l2_reg_lambda, l1_reg_lambda, simple_run=True):
    if simple_run:
        print '======================================[simple_run]======================================'

    max_len = 60
    norm_model = []

    rt_list = ['w2vrt', 'w2vlexrt', 'attrt', 'attbrt', 'a2vrt', 'a2vindrt', 'a2vindbrt', 'a2vindw2vrt',
               'a2vindlexrt', 'cnna2vindrt', 'cnna2vindw2vrt', 'cnna2vindlexrt', 'cnnmcrt', 'w2vlexcrt',
               'w2vlexca2vrt', 'cnnmca2vrt']

    multichannel = False
    if model_name == 'cnnmc' or model_name == 'cnnmcrt' or model_name == 'cnnmca2v' or model_name == 'cnnmca2vrt':
        multichannel = True

    multichannel_a2v = False
    if model_name == 'cnnmca2v' or model_name == 'cnnmca2vrt':
        multichannel_a2v = True

    rt_data = False
    if model_name in rt_list:
        rt_data = True

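    # Lexicon loading: is_expanded == 0 uses load_lexicon_unigram(lexdim);
    # otherwise each pickle in lexfile_list is loaded individually, with
    # is_expanded selecting the default, exp_compact, exp_1.1, or new/ variant.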
    with Timer("lex"):
        if is_expanded == 0:
            print 'old way of loading lexicon'
            norm_model, raw_model = load_lexicon_unigram(lexdim)
            # with open('../data/lexicon_data/lex15.pickle', 'rb') as handle:
            #     norm_model = pickle.load(handle)

        else:
            print 'new way of loading lexicon'
            default_vector_dic = {'EverythingUnigramsPMIHS': [0],
                                  'HS-AFFLEX-NEGLEX-unigrams': [0, 0, 0],
                                  'Maxdiff-Twitter-Lexicon_0to1': [0.50403226],
                                  'S140-AFFLEX-NEGLEX-unigrams': [0, 0, 0],
                                  'unigrams-pmilexicon': [0, 0, 0],
                                  'unigrams-pmilexicon_sentiment_140': [0, 0, 0],
                                  'BL': [0]}

            lexfile_list = ['EverythingUnigramsPMIHS.pickle',
                            'HS-AFFLEX-NEGLEX-unigrams.pickle',
                            'Maxdiff-Twitter-Lexicon_0to1.pickle',
                            'S140-AFFLEX-NEGLEX-unigrams.pickle',
                            'unigrams-pmilexicon.pickle',
                            'unigrams-pmilexicon_sentiment_140.pickle',
                            'BL.pickle']

            for idx, lexfile in enumerate(lexfile_list):
                if is_expanded == 1234567:  # expand all
                    # fname = '../data/le/exp_compact.%s' % lexfile
                    # print 'expanded lexicon for exp_compact.%s' % lexfile
                    fname = '../data/le/exp_1.1.%s' % lexfile
                    print 'expanded lexicon for exp_1.1.%s' % lexfile


                elif is_expanded - 1 == idx:
                    # fname = '../data/le/exp_%s' % lexfile
                    # print 'expanded lexicon for exp_%s' % lexfile
                    fname = '../data/le/exp_compact.%s' % lexfile
                    print 'expanded lexicon for exp_compact.%s' % lexfile
                    # fname = '../data/le/exp_1.1.%s' % lexfile
                    # print 'expanded lexicon for exp_1.1.%s' % lexfile

                else:
                    fname = '../data/le/%s' % lexfile
                    print 'default lexicon for %s' % lexfile

                if is_expanded == 8:
                    fname = '../data/le/new/%s' % lexfile
                    print 'new default lexicon for %s' % lexfile

                with open(fname, 'rb') as handle:
                    each_model = pickle.load(handle)
                    default_vector = default_vector_dic[lexfile.replace('.pickle', '')]
                    each_model["<PAD/>"] = default_vector
                    norm_model.append(each_model)

    with Timer("w2v"):
        w2vmodel = load_w2v(w2vdim, simple_run=simple_run, source=w2vsource)
        # if w2vsource == "twitter":
        #     w2vmodel = load_w2v(w2vdim, simple_run=simple_run, source=w2vsource)
        # else:
        #     w2vmodel = load_w2v(w2vdim, simple_run=simple_run, source="amazon")



    unigram_lexicon_model = norm_model
    # unigram_lexicon_model = raw_model

    if simple_run:
        if multichannel_a2v is True or multichannel is True:
            x_train, y_train, x_lex_train, x_fat_train = \
                cnn_data_helpers.load_data('trn_sample', w2vmodel, unigram_lexicon_model,
                                                                       max_len, multichannel=multichannel)
            x_dev, y_dev, x_lex_dev, x_fat_dev = \
                cnn_data_helpers.load_data('dev_sample', w2vmodel, unigram_lexicon_model, max_len,
                                                                 multichannel=multichannel)
            x_test, y_test, x_lex_test, x_fat_test = \
                cnn_data_helpers.load_data('tst_sample', w2vmodel, unigram_lexicon_model, max_len,
                                                                    multichannel=multichannel)
        else:
            x_train, y_train, x_lex_train, _ = cnn_data_helpers.load_data('trn_sample', w2vmodel, unigram_lexicon_model,
                                                                       max_len, multichannel=multichannel)
            x_dev, y_dev, x_lex_dev, _ = cnn_data_helpers.load_data('dev_sample', w2vmodel, unigram_lexicon_model, max_len,
                                                                 multichannel = multichannel)
            x_test, y_test, x_lex_test, _ = cnn_data_helpers.load_data('tst_sample', w2vmodel, unigram_lexicon_model, max_len,
                                                                    multichannel=multichannel)

    else:
        if multichannel_a2v is True or multichannel is True:
            x_train, y_train, x_lex_train, x_fat_train = \
                cnn_data_helpers.load_data('trn', w2vmodel, unigram_lexicon_model, max_len,
                                           rottenTomato=rt_data, multichannel=multichannel)
            x_dev, y_dev, x_lex_dev, x_fat_dev = \
                cnn_data_helpers.load_data('dev', w2vmodel, unigram_lexicon_model, max_len,
                                           rottenTomato=rt_data, multichannel=multichannel)
            x_test, y_test, x_lex_test, x_fat_test = \
                cnn_data_helpers.load_data('tst', w2vmodel, unigram_lexicon_model, max_len,
                                           rottenTomato=rt_data, multichannel=multichannel)
        else:
            x_train, y_train, x_lex_train, _ = cnn_data_helpers.load_data('trn', w2vmodel, unigram_lexicon_model, max_len,
                                                                       rottenTomato=rt_data, multichannel=multichannel)
            x_dev, y_dev, x_lex_dev, _ = cnn_data_helpers.load_data('dev', w2vmodel, unigram_lexicon_model, max_len,
                                                                 rottenTomato=rt_data, multichannel=multichannel)
            x_test, y_test, x_lex_test, _ = cnn_data_helpers.load_data('tst', w2vmodel, unigram_lexicon_model, max_len,
                                                                    rottenTomato=rt_data, multichannel=multichannel)

    del (w2vmodel)
    del (norm_model)
    # del(raw_model)
    gc.collect()

    print("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev)))

    # Training
    # ==================================================
    if randomseed > 0:
        tf.set_random_seed(randomseed+10)
    with tf.Graph().as_default():
        max_af1_dev = 0
        index_at_max_af1_dev = 0
        af1_tst_at_max_af1_dev = 0

        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            if randomseed > 0:
                tf.set_random_seed(randomseed)

            num_classes = 3
            if model_name in rt_list:
                num_classes = 5

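            # Instantiate the CNN variant selected by model_name; the *rt
            # variants reuse the same architectures with 5 output classes.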
            if model_name == 'w2v' or model_name == 'w2vrt':
                cnn = W2V_CNN(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            elif model_name == 'w2vlex' or model_name == 'w2vlexrt':
                cnn = W2V_LEX_CNN(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            elif model_name == 'att' or model_name == 'attrt':
                cnn = TextCNNPreAttention(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            elif model_name == 'attb' or model_name == 'attbrt':
                cnn = TextCNNPreAttentionBias(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            elif model_name == 'a2v' or model_name == 'a2vrt':
                cnn = TextAttention2Vec(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            elif model_name == 'a2vind' or model_name == 'a2vindrt':
                cnn = TextAttention2VecIndividual(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    attention_depth_w2v=attention_depth_w2v,
                    attention_depth_lex=attention_depth_lex,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            elif model_name == 'a2vindb' or model_name == 'a2vindbrt':
                cnn = TextAttention2VecIndividualBias(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    attention_depth_w2v=attention_depth_w2v,
                    attention_depth_lex=attention_depth_lex,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            elif model_name == 'a2vindw2v' or model_name == 'a2vindw2vrt':
                cnn = TextAttention2VecIndividualW2v(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    attention_depth_w2v=attention_depth_w2v,
                    attention_depth_lex=attention_depth_lex,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            elif model_name == 'a2vindlex' or model_name == 'a2vindlexrt':
                cnn = TextAttention2VecIndividualLex(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    attention_depth_w2v=attention_depth_w2v,
                    attention_depth_lex=attention_depth_lex,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            elif model_name == 'cnna2vind' or model_name == 'cnna2vindrt':
                cnn = TextCNNAttention2VecIndividual(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    attention_depth_w2v=attention_depth_w2v,
                    attention_depth_lex=attention_depth_lex,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            elif model_name == 'cnna2vindw2v' or model_name == 'cnna2vindw2vrt':
                cnn = TextCNNAttention2VecIndividualW2v(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    attention_depth_w2v=attention_depth_w2v,
                    attention_depth_lex=attention_depth_lex,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            elif model_name == 'cnna2vindlex' or model_name == 'cnna2vindlexrt':
                cnn = TextCNNAttention2VecIndividualLex(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    attention_depth_w2v=attention_depth_w2v,
                    attention_depth_lex=attention_depth_lex,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)


            elif model_name =='cnnmc' or model_name =='cnnmcrt':
                cnn = W2V_LEX_CNN_MC(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            elif model_name == 'w2vlexc' or model_name == 'w2vlexcrt':
                cnn = W2V_LEX_CNN_CONCAT(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            elif model_name == 'w2vlexca2v' or model_name == 'w2vlexca2vrt':
                cnn = W2V_LEX_CNN_CONCAT_A2V(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    attention_depth_w2v=attention_depth_w2v,
                    attention_depth_lex=attention_depth_lex,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            elif model_name == 'cnnmca2v' or model_name == 'cnnmca2vrt':
                cnn = W2V_LEX_CNN_MC_A2V(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
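                    # note: attention depths are hard-coded below rather than
                    # taken from the attention_depth_w2v / attention_depth_lex arguments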
                    attention_depth_w2v=50,
                    attention_depth_lex=20,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            else: # default is w2vlex
                cnn = W2V_LEX_CNN(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)


            # Define Training procedure
            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.AdamOptimizer(1e-3)
            grads_and_vars = optimizer.compute_gradients(cnn.loss)
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in grads_and_vars:
                if g is not None:
                    grad_hist_summary = tf.histogram_summary("{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.scalar_summary("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.merge_summary(grad_summaries)

            # Output directory for models and summaries
            timestamp = str(int(time.time()))
            out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
            print("Writing to {}\n".format(out_dir))

            # Summaries for loss and accuracy
            loss_summary = tf.scalar_summary("loss", cnn.loss)
            acc_summary = tf.scalar_summary("accuracy", cnn.accuracy)
            f1_summary = tf.scalar_summary("avg_f1", cnn.avg_f1)

            # Train Summaries
            train_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.train.SummaryWriter(train_summary_dir, sess.graph_def)

            # Dev summaries
            dev_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary])
            dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
            dev_summary_writer = tf.train.SummaryWriter(dev_summary_dir, sess.graph_def)

            # Test summaries
            test_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary])
            test_summary_dir = os.path.join(out_dir, "summaries", "test")
            test_summary_writer = tf.train.SummaryWriter(test_summary_dir, sess.graph_def)

            # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            checkpoint_prefix = os.path.join(checkpoint_dir, "model")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver = tf.train.Saver(tf.all_variables())

            # Initialize all variables
            sess.run(tf.initialize_all_variables())

            def train_step(x_batch, y_batch, x_batch_lex=None, x_batch_fat=None, multichannel=False):
                """
                A single training step
                """
                if x_batch_fat is not None:
                    if x_batch_lex is None:
                        feed_dict = {
                            cnn.input_x: x_batch_fat,
                            cnn.input_y: y_batch,
                            cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                        }
                    else:
                        feed_dict = {
                            cnn.input_x_2c: x_batch_fat,
                            cnn.input_x: x_batch,
                            cnn.input_y: y_batch,
                            cnn.input_x_lexicon: x_batch_lex,
                            cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                        }

                else:
                    if x_batch_lex is None:
                        feed_dict = {
                            cnn.input_x: x_batch,
                            cnn.input_y: y_batch,
                            cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                        }
                    else:
                        feed_dict = {
                            cnn.input_x: x_batch,
                            cnn.input_y: y_batch,
                            # lexicon
                            cnn.input_x_lexicon: x_batch_lex,
                            cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                        }

                _, step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                # print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
                # print("{}: step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                #      format(time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                train_summary_writer.add_summary(summaries, step)

            def dev_step(x_batch, y_batch, x_batch_lex=None, x_batch_fat=None, writer=None, score_type='f1', multichannel=False):
                """
                Evaluates model on a dev set
                """
                if x_batch_fat is not None:
                    if x_batch_lex is None:
                        feed_dict = {
                            cnn.input_x: x_batch_fat,
                            cnn.input_y: y_batch,
                            cnn.dropout_keep_prob: 1.0
                        }
                    else:
                        feed_dict = {
                            cnn.input_x_2c: x_batch_fat,
                            cnn.input_x: x_batch,
                            cnn.input_y: y_batch,
                            cnn.input_x_lexicon: x_batch_lex,
                            cnn.dropout_keep_prob: 1.0
                        }

                else:
                    if x_batch_lex is None:
                        feed_dict = {
                            cnn.input_x: x_batch,
                            cnn.input_y: y_batch,
                            cnn.dropout_keep_prob: 1.0
                        }
                    else:
                        feed_dict = {
                            cnn.input_x: x_batch,
                            cnn.input_y: y_batch,
                            # lexicon
                            cnn.input_x_lexicon: x_batch_lex,
                            cnn.dropout_keep_prob: 1.0
                        }

                step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [global_step, dev_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("{} : {} step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                      format("DEV", time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                if writer:
                    writer.add_summary(summaries, step)

                if score_type == 'f1':
                    return avg_f1
                else:
                    return accuracy

            def test_step(x_batch, y_batch, x_batch_lex=None, x_batch_fat=None, writer=None, score_type='f1', multichannel=False):
                """
                Evaluates model on a test set
                """
                if x_batch_fat is not None:
                    if x_batch_lex is None:
                        feed_dict = {
                            cnn.input_x: x_batch_fat,
                            cnn.input_y: y_batch,
                            cnn.dropout_keep_prob: 1.0
                        }
                    else:
                        feed_dict = {
                            cnn.input_x_2c: x_batch_fat,
                            cnn.input_x: x_batch,
                            cnn.input_y: y_batch,
                            cnn.input_x_lexicon: x_batch_lex,
                            cnn.dropout_keep_prob: 1.0
                        }

                else:
                    if x_batch_lex is None:
                        feed_dict = {
                            cnn.input_x: x_batch,
                            cnn.input_y: y_batch,
                            cnn.dropout_keep_prob: 1.0
                        }
                    else:
                        feed_dict = {
                            cnn.input_x: x_batch,
                            cnn.input_y: y_batch,
                            # lexicon
                            cnn.input_x_lexicon: x_batch_lex,
                            cnn.dropout_keep_prob: 1.0
                        }

                step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [global_step, dev_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("{} : {} step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                      format("TEST", time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                if writer:
                    writer.add_summary(summaries, step)

                if score_type == 'f1':
                    return avg_f1
                else:
                    return accuracy

            # Generate batches
            if multichannel_a2v is True or multichannel is True:
                batches = cnn_data_helpers.batch_iter(
                    list(zip(x_train, y_train, x_lex_train, x_fat_train)), FLAGS.batch_size, num_epochs)
            else:
                batches = cnn_data_helpers.batch_iter(
                    list(zip(x_train, y_train, x_lex_train)), FLAGS.batch_size, num_epochs)


            # Training loop. For each batch...
            for batch in batches:
                if multichannel_a2v is True or multichannel is True:
                    x_batch, y_batch, x_batch_lex, x_batch_fat = zip(*batch)
                else:
                    x_batch, y_batch, x_batch_lex = zip(*batch)

                if model_name == 'w2v' or model_name == 'w2vrt':
                    train_step(x_batch, y_batch)

                else:
                    if multichannel_a2v is True:
                        train_step(x_batch, y_batch, x_batch_lex, x_batch_fat)
                    elif multichannel is True:
                        train_step(x_batch, y_batch, x_batch_lex=None, x_batch_fat=x_batch_fat,
                                   multichannel=multichannel)
                    else:
                        train_step(x_batch, y_batch, x_batch_lex, multichannel=multichannel)
                    # train_step(x_batch, y_batch, x_batch_lex)


                current_step = tf.train.global_step(sess, global_step)
                if current_step % FLAGS.evaluate_every == 0:
                    print("Evaluation:")
                    if rt_data == True:
                        score_type = 'acc'
                    else:
                        score_type = 'f1'

                    if model_name == 'w2v' or model_name == 'w2vrt':
                        curr_af1_dev = dev_step(x_dev, y_dev, writer=dev_summary_writer, score_type=score_type)
                        curr_af1_tst = test_step(x_test, y_test, writer=test_summary_writer, score_type=score_type)

                    else:
                        if multichannel_a2v is True:
                            curr_af1_dev = dev_step(x_dev, y_dev, x_lex_dev, x_fat_dev, writer=dev_summary_writer,
                                                    score_type=score_type, multichannel=multichannel)
                            curr_af1_tst = test_step(x_test, y_test, x_lex_test, x_fat_test, writer=test_summary_writer,
                                                     score_type=score_type, multichannel=multichannel)

                        elif multichannel is True:
                            curr_af1_dev = dev_step(x_dev, y_dev, x_batch_lex=None, x_batch_fat=x_fat_dev,
                                                    writer=dev_summary_writer,
                                                    score_type=score_type, multichannel=multichannel)
                            curr_af1_tst = test_step(x_test, y_test, x_batch_lex=None, x_batch_fat=x_fat_test,
                                                     writer=test_summary_writer,
                                                     score_type=score_type, multichannel=multichannel)
                        else:
                            curr_af1_dev = dev_step(x_dev, y_dev, x_lex_dev, writer=dev_summary_writer,
                                                    score_type=score_type, multichannel=multichannel)
                            curr_af1_tst = test_step(x_test, y_test, x_lex_test, writer=test_summary_writer,
                                                     score_type = score_type, multichannel=multichannel)


                    # if model_name == 'w2v':
                    #     curr_af1_dev = dev_step(x_dev, y_dev, writer=dev_summary_writer)
                    #     # path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    #     # print("Saved model checkpoint to {}\n".format(path))
                    #
                    #     curr_af1_tst = test_step(x_test, y_test, writer=test_summary_writer)
                    #     # path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    #     # print("Saved model checkpoint to {}\n".format(path))
                    #
                    # elif model_name == 'w2vrt':
                    #     curr_af1_dev = dev_step(x_dev, y_dev, writer=dev_summary_writer, score_type='acc')
                    #     curr_af1_tst = test_step(x_test, y_test, writer=test_summary_writer, score_type='acc')
                    #
                    # elif model_name == 'w2vlexrt':
                    #     curr_af1_dev = dev_step(x_dev, y_dev, x_lex_dev, writer=dev_summary_writer, score_type='acc')
                    #     curr_af1_tst = test_step(x_test, y_test, x_lex_test, writer=test_summary_writer,
                    #                              score_type='acc')
                    # else:
                    #     curr_af1_dev = dev_step(x_dev, y_dev, x_lex_dev, writer=dev_summary_writer)
                    #     # path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    #     # print("Saved model checkpoint to {}\n".format(path))
                    #
                    #     curr_af1_tst = test_step(x_test, y_test, x_lex_test, writer=test_summary_writer)
                    #     # path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    #     # print("Saved model checkpoint to {}\n".format(path))

                    if curr_af1_dev > max_af1_dev:
                        max_af1_dev = curr_af1_dev
                        index_at_max_af1_dev = current_step
                        af1_tst_at_max_af1_dev = curr_af1_tst

                        path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        print("Saved model checkpoint to {}\n".format(path))

                    if rt_data == True:
                        print 'Status: [%d] Max Acc for dev (%f), Max Acc for tst (%f)\n' % (
                            index_at_max_af1_dev, max_af1_dev*100, af1_tst_at_max_af1_dev*100)
                    else:
                        print 'Status: [%d] Max f1 for dev (%f), Max f1 for tst (%f)\n' % (
                            index_at_max_af1_dev, max_af1_dev, af1_tst_at_max_af1_dev)


                    sys.stdout.flush()
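
A minimal driver for this variant might look like the sketch below. The `__main__` guard and all argument values are illustrative assumptions; only the parameter names come from the function signature above, and the FLAGS the function reads are assumed to be defined elsewhere in the module.

# Hypothetical invocation sketch (not part of the original example).
# Assumes run_train and the FLAGS it reads (filter_sizes, batch_size,
# dropout_keep_prob, evaluate_every, allow_soft_placement,
# log_device_placement) are defined in the same module.
if __name__ == '__main__':
    run_train(w2vsource='twitter',          # or 'amazon'
              w2vdim=400, w2vnumfilters=64,
              lexdim=15, lexnumfilters=9,
              randomseed=42,
              model_name='w2vlex',          # any name handled by the elif chain
              is_expanded=0,                # 0 = load_lexicon_unigram, >0 = expanded pickles
              attention_depth_w2v=50, attention_depth_lex=20,
              num_epochs=25,
              l2_reg_lambda=0.0, l1_reg_lambda=0.0,
              simple_run=True)              # use the *_sample data splits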
Example #2
def run_train(w2vsource, w2vdim, w2vnumfilters, lexdim, lexnumfilters, randomseed, datasource, model_name, trainable, the_epoch):

    np.random.seed(randomseed)
    max_len = 60
    norm_model = []

    with Timer("lex"):
        norm_model, raw_model = load_lexicon_unigram(lexdim)

        # print 'new way of loading lexicon'
        # default_vector_dic = {'EverythingUnigramsPMIHS': [0],
        #                       'HS-AFFLEX-NEGLEX-unigrams': [0, 0, 0],
        #                       'Maxdiff-Twitter-Lexicon_0to1': [0.5],
        #                       'S140-AFFLEX-NEGLEX-unigrams': [0, 0, 0],
        #                       'unigrams-pmilexicon': [0, 0, 0],
        #                       'unigrams-pmilexicon_sentiment_140': [0, 0, 0],
        #                       'BL': [0]}
        #
        # lexfile_list = ['EverythingUnigramsPMIHS.pickle',
        #                 'HS-AFFLEX-NEGLEX-unigrams.pickle',
        #                 'Maxdiff-Twitter-Lexicon_0to1.pickle',
        #                 'S140-AFFLEX-NEGLEX-unigrams.pickle',
        #                 'unigrams-pmilexicon.pickle',
        #                 'unigrams-pmilexicon_sentiment_140.pickle',
        #                 'BL.pickle']
        #
        #
        # for idx, lexfile in enumerate(lexfile_list):
        #     fname = '../data/le/%s' % lexfile
        #     print 'default lexicon for %s' % lexfile
        #
        #     with open(fname, 'rb') as handle:
        #         each_model = pickle.load(handle)
        #         default_vector = default_vector_dic[lexfile.replace('.pickle', '')]
        #         each_model["<PAD/>"] = default_vector
        #         norm_model.append(each_model)

    
    unigram_lexicon_model = norm_model


    # CONFIGURE
    # ==================================================
    if datasource == 'semeval':
        numberofclass = 3
        use_rotten_tomato = False
    elif datasource == 'sst':
        numberofclass = 5
        use_rotten_tomato = True
    else:
        # guard against an unsupported datasource leaving numberofclass unset
        raise ValueError('unknown datasource: %s' % datasource)


    # Training
    # ==================================================
    if randomseed > 0:
        tf.set_random_seed(randomseed)
    with tf.Graph().as_default():
        tf.set_random_seed(randomseed)
        max_af1_dev = 0
        index_at_max_af1_dev = 0
        af1_tst_at_max_af1_dev = 0

        #WORD2VEC
        x_text, y = cnn_data_helpers.load_data_trainable("everydata", rottenTomato=use_rotten_tomato)
        max_document_length = max([len(x.split(" ")) for x in x_text])
        vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)
        vocab_processor.fit_transform(x_text)
        total_vocab_size = len(vocab_processor.vocabulary_)

        x_train, y_train = cnn_data_helpers.load_data_trainable("trn", rottenTomato=use_rotten_tomato)
        x_dev, y_dev = cnn_data_helpers.load_data_trainable("dev", rottenTomato=use_rotten_tomato)
        x_test, y_test = cnn_data_helpers.load_data_trainable("tst", rottenTomato=use_rotten_tomato)
        # the vocabulary was already fit on "everydata" above, so only
        # transform the individual splits here
        x_train = np.array(list(vocab_processor.transform(x_train)))
        x_dev = np.array(list(vocab_processor.transform(x_dev)))
        x_test = np.array(list(vocab_processor.transform(x_test)))



        del(norm_model)
        gc.collect()

        print("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev)))


        session_conf = tf.ConfigProto(
          allow_soft_placement=FLAGS.allow_soft_placement,
          log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            if randomseed > 0:
                tf.set_random_seed(randomseed)

            cnn = W2V_TRAINABLE(
                sequence_length=x_train.shape[1],
                num_classes=numberofclass,
                vocab_size=len(vocab_processor.vocabulary_),
                is_trainable=trainable,
                embedding_size=w2vdim,
                filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                num_filters=w2vnumfilters,
                embedding_size_lex=lexdim,
                num_filters_lex=lexnumfilters,
                themodel=model_name,
                l2_reg_lambda=FLAGS.l2_reg_lambda
            )
           
            # Define Training procedure
            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.AdamOptimizer(1e-3)
            grads_and_vars = optimizer.compute_gradients(cnn.loss)
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in grads_and_vars:
                if g is not None:
                    grad_hist_summary = tf.histogram_summary("{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.scalar_summary("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.merge_summary(grad_summaries)

            # Output directory for models and summaries
            timestamp = str(int(time.time()))
            out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
            print("Writing to {}\n".format(out_dir))

            # Summaries for loss and accuracy
            loss_summary = tf.scalar_summary("loss", cnn.loss)
            acc_summary = tf.scalar_summary("accuracy", cnn.accuracy)
            f1_summary = tf.scalar_summary("avg_f1", cnn.avg_f1)

            # Train Summaries
            train_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.train.SummaryWriter(train_summary_dir, sess.graph_def)

            # Dev summaries
            dev_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary])
            dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
            dev_summary_writer = tf.train.SummaryWriter(dev_summary_dir, sess.graph_def)

            # Test summaries
            test_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary])
            test_summary_dir = os.path.join(out_dir, "summaries", "test")
            test_summary_writer = tf.train.SummaryWriter(test_summary_dir, sess.graph_def)

            # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            checkpoint_prefix = os.path.join(checkpoint_dir, "model")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver = tf.train.Saver(tf.all_variables())

            # Initialize all variables
            sess.run(tf.initialize_all_variables())
            the_base_path = '../data/emory_w2v/'
            if w2vsource == "twitter":
                the_model_path = the_base_path + 'w2v-%d.bin' % w2vdim
            elif w2vsource == "amazon":
                the_model_path = the_base_path + 'w2v-%d-%s.bin' % (w2vdim, w2vsource)

            def load_w2v(w2vdim, simple_run=True, source="twitter"):
                if simple_run:
                    return {'a': np.array([np.float32(0.0)] * w2vdim)}

                else:
                    if source == "twitter":
                        model_path = '../data/emory_w2v/w2v-%d.bin' % w2vdim
                    elif source == "amazon":
                        model_path = '../data/emory_w2v/w2v-%d-%s.bin' % (w2vdim, source)

                    model = Word2Vec.load_word2vec_format(model_path, binary=True)
                    print("The vocabulary size is: " + str(len(model.vocab)))

                    return model


            with Timer("w2v"):
                w2vmodel = load_w2v(w2vdim, simple_run=False)

            # initialize the embedding matrices with zeros; rows for words found
            # in word2vec / the lexicons are overwritten below
            initW = np.zeros((total_vocab_size, w2vdim))
            initW_lex = np.zeros((total_vocab_size, lexdim))
            # load any vectors from the word2vec
            with Timer("Assigning w2v..."):
                for idx, word in enumerate(vocab_processor.vocabulary_._reverse_mapping):
                    if word in w2vmodel.vocab:
                        initW[idx] = w2vmodel[word]

            # with Timer("LOADING W2V..."):
            #     print("LOADING word2vec file {} \n".format(the_model_path))
            #     #W2V
            #     with open(the_model_path, "rb") as f:
            #         header = f.readline()
            #         vocab_size, layer1_size = map(int, header.split())
            #         binary_len = np.dtype('float32').itemsize * layer1_size
            #         for line in xrange(vocab_size):
            #             word = []
            #             while True:
            #                 ch = f.read(1)
            #                 if ch == ' ':
            #                     word = ''.join(word)
            #                     break
            #                 if ch != '\n':
            #                     word.append(ch)
            #             idx = vocab_processor.vocabulary_.get(word)
            #             if idx != 0:
            #                 #print str(idx) + " -> " + word
            #                 initW[idx] = np.fromstring(f.read(binary_len), dtype='float32')
            #             else:
            #                 f.read(binary_len)

            with Timer("LOADING LEXICON..."):
                vocabulary_set = set()
                for index, eachModel in enumerate(unigram_lexicon_model):
                    for word in eachModel:
                        vocabulary_set.add(word)

                for word in vocabulary_set:
                    lexiconList = np.empty([0, 1])
                    for index, eachModel in enumerate(unigram_lexicon_model):
                        if word in eachModel:
                            temp = np.array(np.float32(eachModel[word]))
                        else:
                            temp = np.array(np.float32(eachModel["<PAD/>"]))
                        lexiconList = np.append(lexiconList, temp)

                    idx = vocab_processor.vocabulary_.get(word)
                    if idx != 0:
                        initW_lex[idx] = lexiconList




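            # Copy the pre-built word2vec and lexicon matrices into the graph's
            # embedding variables before training starts.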
            sess.run(cnn.W.assign(initW))
            if model_name == 'w2v_lex':
                sess.run(cnn.W_lex.assign(initW_lex))

            def train_step(x_batch, y_batch):
                """
                A single training step
                """
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                }
                _, step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                train_summary_writer.add_summary(summaries, step)

            def dev_step(x_batch, y_batch, writer=None, score_type='f1'):
                """
                Evaluates model on a dev set
                """
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: 1.0
                }
                step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [global_step, dev_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("{} : {} step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                      format("DEV", time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                if writer:
                    writer.add_summary(summaries, step)

                if score_type == 'f1':
                    return avg_f1
                else:
                    return accuracy

            def test_step(x_batch, y_batch, writer=None, score_type='f1'):
                """
                Evaluates model on a test set
                """

                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: 1.0
                }
                step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [global_step, dev_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("{} : {} step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                      format("TEST", time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                if writer:
                    writer.add_summary(summaries, step)

                if score_type == 'f1':
                    return avg_f1
                else:
                    return accuracy

            # Generate batches
            batches = cnn_data_helpers.batch_iter(
                list(zip(x_train, y_train)), FLAGS.batch_size, the_epoch)

            # Training loop. For each batch...
            for batch in batches:
                x_batch, y_batch = zip(*batch)
                train_step(x_batch, y_batch)


                current_step = tf.train.global_step(sess, global_step)
                if current_step % FLAGS.evaluate_every == 0:
                    
                    print("Evaluation:")

                    if datasource == 'semeval':
                        curr_af1_dev = dev_step(x_dev, y_dev, writer=dev_summary_writer)
                        curr_af1_tst = test_step(x_test, y_test, writer=test_summary_writer)

                    elif datasource == 'sst':
                        curr_af1_dev = dev_step(x_dev, y_dev, writer=dev_summary_writer, score_type = 'acc')
                        curr_af1_tst = test_step(x_test, y_test, writer=test_summary_writer, score_type = 'acc')


                    if curr_af1_dev > max_af1_dev:
                        max_af1_dev = curr_af1_dev
                        index_at_max_af1_dev = current_step
                        af1_tst_at_max_af1_dev = curr_af1_tst

                    print 'Status: [%d] Max f1 for dev (%f), Max f1 for tst (%f)\n' % (
                        index_at_max_af1_dev, max_af1_dev, af1_tst_at_max_af1_dev)
                    sys.stdout.flush()
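
The trainable-embedding variant can be driven the same way; the sketch below is an assumed invocation with illustrative values only, not part of the original example.

# Hypothetical invocation sketch; values are illustrative and the
# surrounding FLAGS definitions are assumed to exist in the module.
if __name__ == '__main__':
    run_train(w2vsource='twitter', w2vdim=400, w2vnumfilters=64,
              lexdim=15, lexnumfilters=9, randomseed=42,
              datasource='semeval',         # or 'sst' for the 5-class setting
              model_name='w2v_lex',         # also assigns the lexicon embedding W_lex
              trainable=True,               # let the embedding layer be fine-tuned
              the_epoch=25)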
Example #3
def run_train(w2vsource, w2vdim, w2vnumfilters, lexdim, lexnumfilters, randomseed, model_name, is_expanded, simple_run = True):
    if simple_run:
        print '======================================[simple_run]======================================'


    max_len = 60
    norm_model = []

    with Timer("lex"):
        if is_expanded == 0:
            print 'old way of loading lexicon'
            norm_model, raw_model = load_lexicon_unigram(lexdim)
            # with open('../data/lexicon_data/lex15.pickle', 'rb') as handle:
            #     norm_model = pickle.load(handle)

        else:
            print 'new way of loading lexicon'
            default_vector_dic = {'EverythingUnigramsPMIHS': [0],
                                  'HS-AFFLEX-NEGLEX-unigrams': [0, 0, 0],
                                  'Maxdiff-Twitter-Lexicon_0to1': [0.5],
                                  'S140-AFFLEX-NEGLEX-unigrams': [0, 0, 0],
                                  'unigrams-pmilexicon': [0, 0, 0],
                                  'unigrams-pmilexicon_sentiment_140': [0, 0, 0],
                                  'BL': [0]}

            lexfile_list = ['EverythingUnigramsPMIHS.pickle',
                            'HS-AFFLEX-NEGLEX-unigrams.pickle',
                            'Maxdiff-Twitter-Lexicon_0to1.pickle',
                            'S140-AFFLEX-NEGLEX-unigrams.pickle',
                            'unigrams-pmilexicon.pickle',
                            'unigrams-pmilexicon_sentiment_140.pickle',
                            'BL.pickle']


            for idx, lexfile in enumerate(lexfile_list):
                if is_expanded == 1234567: # expand all
                    fname = '../data/le/exp_compact.%s' % lexfile
                    print 'expanded lexicon for exp_compact.%s' % lexfile
                    # fname = '../data/le/exp_1.1.%s' % lexfile
                    # print 'expanded lexicon for exp_1.1.%s' % lexfile


                elif is_expanded-1 == idx:
                    # fname = '../data/le/exp_%s' % lexfile
                    # print 'expanded lexicon for exp_%s' % lexfile
                    fname = '../data/le/exp_compact.%s' % lexfile
                    print 'expanded lexicon for exp_compact.%s' % lexfile
                    # fname = '../data/le/exp_1.1.%s' % lexfile
                    # print 'expanded lexicon for exp_1.1.%s' % lexfile

                else:
                    fname = '../data/le/%s' % lexfile
                    print 'default lexicon for %s' % lexfile

                with open(fname, 'rb') as handle:
                    each_model = pickle.load(handle)
                    default_vector = default_vector_dic[lexfile.replace('.pickle', '')]
                    each_model["<PAD/>"] = default_vector
                    norm_model.append(each_model)


    with Timer("w2v"):
        # note: w2vsource is not forwarded here; the same model is loaded either way
        w2vmodel = load_w2v(w2vdim, simple_run=simple_run)


    unigram_lexicon_model = norm_model
    # unigram_lexicon_model = raw_model

    if simple_run:
        x_train, y_train, x_lex_train = cnn_data_helpers.load_data('trn_sample', w2vmodel, unigram_lexicon_model,
                                                                   max_len)
        x_dev, y_dev, x_lex_dev = cnn_data_helpers.load_data('dev_sample', w2vmodel, unigram_lexicon_model, max_len)
        x_test, y_test, x_lex_test = cnn_data_helpers.load_data('tst_sample', w2vmodel, unigram_lexicon_model, max_len)
    elif model_name == "w2vrt" or  model_name == "w2vrtlex":
        x_train, y_train, x_lex_train = cnn_data_helpers.load_data('trn', w2vmodel, unigram_lexicon_model, max_len, True)
        x_dev, y_dev, x_lex_dev = cnn_data_helpers.load_data('dev', w2vmodel, unigram_lexicon_model, max_len, True)
        x_test, y_test, x_lex_test = cnn_data_helpers.load_data('tst', w2vmodel, unigram_lexicon_model, max_len, True) 
    else:
        x_train, y_train, x_lex_train = cnn_data_helpers.load_data('trn', w2vmodel, unigram_lexicon_model, max_len)
        x_dev, y_dev, x_lex_dev = cnn_data_helpers.load_data('dev', w2vmodel, unigram_lexicon_model, max_len)
        x_test, y_test, x_lex_test = cnn_data_helpers.load_data('tst', w2vmodel, unigram_lexicon_model, max_len)


    del(w2vmodel)
    del(norm_model)
    # del(raw_model)
    gc.collect()

    print("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev)))


    # Training
    # ==================================================
    if randomseed > 0:
        tf.set_random_seed(randomseed)
    with tf.Graph().as_default():
        max_af1_dev = 0
        index_at_max_af1_dev = 0
        af1_tst_at_max_af1_dev = 0

        session_conf = tf.ConfigProto(
          allow_soft_placement=FLAGS.allow_soft_placement,
          log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            if randomseed > 0:
                tf.set_random_seed(randomseed)

            if model_name=='w2v':
                cnn = W2V_CNN(
                    sequence_length=x_train.shape[1],
                    num_classes=3,
                    embedding_size=w2vdim,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=FLAGS.l2_reg_lambda
                )
            elif model_name=='w2vrt':
                cnn = W2V_CNN(
                    sequence_length=x_train.shape[1],
                    num_classes=5,
                    embedding_size=w2vdim,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=FLAGS.l2_reg_lambda
                )

            elif model_name=='w2vlex':
                cnn = W2V_LEX_CNN(
                    sequence_length=x_train.shape[1],
                    num_classes=3,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=FLAGS.l2_reg_lambda)
           
            elif model_name=='w2vrtlex':
                cnn = W2V_LEX_CNN(
                    sequence_length=x_train.shape[1],
                    num_classes=5,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=FLAGS.l2_reg_lambda)

            else: # model_name == 'attention'
                cnn = TextCNNAttention(
                    sequence_length=x_train.shape[1],
                    num_classes=3,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=FLAGS.l2_reg_lambda)


            # Define Training procedure
            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.AdamOptimizer(1e-3)
            grads_and_vars = optimizer.compute_gradients(cnn.loss)
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in grads_and_vars:
                if g is not None:
                    grad_hist_summary = tf.histogram_summary("{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.scalar_summary("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.merge_summary(grad_summaries)

            # Output directory for models and summaries
            timestamp = str(int(time.time()))
            out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
            print("Writing to {}\n".format(out_dir))

            # Summaries for loss and accuracy
            loss_summary = tf.scalar_summary("loss", cnn.loss)
            acc_summary = tf.scalar_summary("accuracy", cnn.accuracy)
            f1_summary = tf.scalar_summary("avg_f1", cnn.avg_f1)

            # Train Summaries
            train_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.train.SummaryWriter(train_summary_dir, sess.graph_def)

            # Dev summaries
            dev_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary])
            dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
            dev_summary_writer = tf.train.SummaryWriter(dev_summary_dir, sess.graph_def)

            # Test summaries
            test_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary])
            test_summary_dir = os.path.join(out_dir, "summaries", "test")
            test_summary_writer = tf.train.SummaryWriter(test_summary_dir, sess.graph_def)

            # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            checkpoint_prefix = os.path.join(checkpoint_dir, "model")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver = tf.train.Saver(tf.all_variables())

            # Initialize all variables
            sess.run(tf.initialize_all_variables())

            def train_step(x_batch, y_batch, x_batch_lex=None):
                """
                A single training step
                """
                if x_batch_lex is not None:
                    feed_dict = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        # lexicon
                        cnn.input_x_lexicon: x_batch_lex,
                        cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                    }
                else: 
                    feed_dict = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                    }
                _, step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                # print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
                #print("{}: step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                #      format(time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                train_summary_writer.add_summary(summaries, step)

            def dev_step(x_batch, y_batch, x_batch_lex=None, writer=None, score_type='f1'):
                """
                Evaluates model on a dev set
                """
                if x_batch_lex is not None:
                    feed_dict = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        # lexicon
                        cnn.input_x_lexicon: x_batch_lex,
                        cnn.dropout_keep_prob: 1.0
                    }
                else: 
                    feed_dict = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        cnn.dropout_keep_prob: 1.0
                    }
                step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [global_step, dev_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("{} : {} step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                      format("DEV", time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                if writer:
                    writer.add_summary(summaries, step)

                if score_type == 'f1':
                    return avg_f1
                else:
                    return accuracy

            def test_step(x_batch, y_batch, x_batch_lex=None, writer=None, score_type='f1'):
                """
                Evaluates model on a test set
                """
                if x_batch_lex is not None:
                    feed_dict = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        # lexicon
                        cnn.input_x_lexicon: x_batch_lex,
                        cnn.dropout_keep_prob: 1.0
                    }
                else: 
                    feed_dict = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        cnn.dropout_keep_prob: 1.0
                    }
                step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [global_step, test_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("{} : {} step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                      format("TEST", time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                if writer:
                    writer.add_summary(summaries, step)

                if score_type == 'f1':
                    return avg_f1
                else:
                    return accuracy

            # Generate batches
            batches = cnn_data_helpers.batch_iter(
                list(zip(x_train, y_train, x_lex_train)), FLAGS.batch_size, FLAGS.num_epochs)
            # Training loop. For each batch...
            for batch in batches:
                x_batch, y_batch, x_batch_lex = zip(*batch)

                if model_name=='w2v' or model_name=='w2vrt':
                    train_step(x_batch, y_batch)
                else:
                    train_step(x_batch, y_batch, x_batch_lex)


                current_step = tf.train.global_step(sess, global_step)
                if current_step % FLAGS.evaluate_every == 0:
                    print("Evaluation:")

                    if model_name == 'w2v':
                        curr_af1_dev = dev_step(x_dev, y_dev, writer=dev_summary_writer)
                        # path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        # print("Saved model checkpoint to {}\n".format(path))

                        curr_af1_tst = test_step(x_test, y_test, writer=test_summary_writer)
                        # path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        # print("Saved model checkpoint to {}\n".format(path))

                    elif model_name == 'w2vrt':
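                        # The 5-class 'rt' variants are tracked by accuracy instead of average F1.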
                        curr_af1_dev = dev_step(x_dev, y_dev, writer=dev_summary_writer, score_type = 'acc')
                        curr_af1_tst = test_step(x_test, y_test, writer=test_summary_writer, score_type = 'acc')

                    elif model_name == 'w2vrtlex':
                        curr_af1_dev = dev_step(x_dev, y_dev, x_lex_dev, writer=dev_summary_writer, score_type = 'acc')
                        curr_af1_tst = test_step(x_test, y_test, x_lex_test, writer=test_summary_writer, score_type = 'acc')
                    else:
                        curr_af1_dev = dev_step(x_dev, y_dev, x_lex_dev, writer=dev_summary_writer)
                        # path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        # print("Saved model checkpoint to {}\n".format(path))

                        curr_af1_tst = test_step(x_test, y_test, x_lex_test, writer=test_summary_writer)
                        # path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        # print("Saved model checkpoint to {}\n".format(path))

                    if curr_af1_dev > max_af1_dev:
                        max_af1_dev = curr_af1_dev
                        index_at_max_af1_dev = current_step
                        af1_tst_at_max_af1_dev = curr_af1_tst

                        path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        print("Saved model checkpoint to {}\n".format(path))

                    print 'Status: [%d] Max f1 for dev (%f), Max f1 for tst (%f)\n' % (
                        index_at_max_af1_dev, max_af1_dev, af1_tst_at_max_af1_dev)
                    sys.stdout.flush()
Example No. 4
def run_train(w2vdim, w2vnumfilters, lexdim, lexnumfilters, randomseed, lex_col, simple_run = True):
    if simple_run == True:
        print '======================================[simple_run]======================================'


    max_len = 60

    with Timer("lex"):
        norm_model, raw_model = load_lexicon_unigram(lexdim)

    with Timer("w2v"):
        w2vmodel = load_w2v(w2vdim, simple_run=simple_run)

    unigram_lexicon_model = norm_model
    # unigram_lexicon_model = raw_model

    if simple_run:
        x_train, y_train, x_lex_train = cnn_data_helpers.load_data('trn_sample', w2vmodel, unigram_lexicon_model,
                                                                   max_len)
        # print len(x_lex_train)
        # print len(x_lex_train[0])
        # print x_lex_train.shape
        # print x_lex_train[:, :, lex_col].shape
        x_dev, y_dev, x_lex_dev = cnn_data_helpers.load_data('dev_sample', w2vmodel, unigram_lexicon_model, max_len)
        x_test, y_test, x_lex_test = cnn_data_helpers.load_data('tst_sample', w2vmodel, unigram_lexicon_model, max_len)



    else:
        x_train, y_train, x_lex_train = cnn_data_helpers.load_data('trn', w2vmodel, unigram_lexicon_model, max_len)
        x_dev, y_dev, x_lex_dev = cnn_data_helpers.load_data('dev', w2vmodel, unigram_lexicon_model, max_len)
        x_test, y_test, x_lex_test = cnn_data_helpers.load_data('tst', w2vmodel, unigram_lexicon_model, max_len)

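        # Keep only the lexicon feature column(s) selected by lex_col for train, dev and test.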
        x_lex_train = x_lex_train[:, :, lex_col]
        x_lex_dev = x_lex_dev[:, :, lex_col]
        x_lex_test = x_lex_test[:, :, lex_col]

    del(w2vmodel)
    del(norm_model)
    del(raw_model)
    gc.collect()

    print("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev)))


    # Training
    # ==================================================
    if randomseed > 0:
        tf.set_random_seed(randomseed)
    with tf.Graph().as_default():
        max_af1_dev = 0
        index_at_max_af1_dev = 0
        af1_tst_at_max_af1_dev = 0

        session_conf = tf.ConfigProto(
          allow_soft_placement=FLAGS.allow_soft_placement,
          log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            if randomseed > 0:
                tf.set_random_seed(randomseed)
            cnn = W2V_LEX_CNN(
                sequence_length=x_train.shape[1],
                num_classes=3,
                embedding_size=w2vdim,
                embedding_size_lex=lexdim,
                num_filters_lex = lexnumfilters,
                filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                num_filters=w2vnumfilters,
                l2_reg_lambda=FLAGS.l2_reg_lambda)

            # Define Training procedure
            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.AdamOptimizer(1e-3)
            grads_and_vars = optimizer.compute_gradients(cnn.loss)
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in grads_and_vars:
                if g is not None:
                    grad_hist_summary = tf.histogram_summary("{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.scalar_summary("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.merge_summary(grad_summaries)

            # Output directory for models and summaries
            timestamp = str(int(time.time()))
            out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
            print("Writing to {}\n".format(out_dir))

            # Summaries for loss and accuracy
            loss_summary = tf.scalar_summary("loss", cnn.loss)
            acc_summary = tf.scalar_summary("accuracy", cnn.accuracy)
            f1_summary = tf.scalar_summary("avg_f1", cnn.avg_f1)

            # Train Summaries
            train_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.train.SummaryWriter(train_summary_dir, sess.graph_def)

            # Dev summaries
            dev_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary])
            dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
            dev_summary_writer = tf.train.SummaryWriter(dev_summary_dir, sess.graph_def)

            # Test summaries
            test_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary])
            test_summary_dir = os.path.join(out_dir, "summaries", "test")
            test_summary_writer = tf.train.SummaryWriter(test_summary_dir, sess.graph_def)

            # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            checkpoint_prefix = os.path.join(checkpoint_dir, "model")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver = tf.train.Saver(tf.all_variables())

            # Initialize all variables
            sess.run(tf.initialize_all_variables())

            def train_step(x_batch, y_batch, x_batch_lex):
                """
                A single training step
                """
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    # lexicon
                    cnn.input_x_lexicon: x_batch_lex,
                    cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                }
                _, step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                # print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
                #print("{}: step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                #      format(time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                train_summary_writer.add_summary(summaries, step)

            def dev_step(x_batch, y_batch, x_batch_lex, writer=None):
                """
                Evaluates model on a dev set
                """
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    # lexicon
                    cnn.input_x_lexicon: x_batch_lex,
                    cnn.dropout_keep_prob: 1.0
                }
                step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [global_step, dev_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("{} : {} step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                      format("DEV", time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                if writer:
                    writer.add_summary(summaries, step)

                return avg_f1

            def test_step(x_batch, y_batch, x_batch_lex, writer=None):
                """
                Evaluates model on a test set
                """
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    # lexicon
                    cnn.input_x_lexicon: x_batch_lex,
                    cnn.dropout_keep_prob: 1.0
                }
                step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [global_step, test_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("{} : {} step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                      format("TEST", time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                if writer:
                    writer.add_summary(summaries, step)

                return avg_f1

            # Generate batches
            batches = cnn_data_helpers.batch_iter(
                list(zip(x_train, y_train, x_lex_train)), FLAGS.batch_size, FLAGS.num_epochs)
            # Training loop. For each batch...
            for batch in batches:
                x_batch, y_batch, x_batch_lex = zip(*batch)
                train_step(x_batch, y_batch, x_batch_lex)
                current_step = tf.train.global_step(sess, global_step)
                if current_step % FLAGS.evaluate_every == 0:
                    print("Evaluation:")
                    curr_af1_dev = dev_step(x_dev, y_dev, x_lex_dev, writer=dev_summary_writer)
                    # path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    # print("Saved model checkpoint to {}\n".format(path))

                    curr_af1_tst = test_step(x_test, y_test, x_lex_test, writer=test_summary_writer)
                    # path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    # print("Saved model checkpoint to {}\n".format(path))

                    if curr_af1_dev > max_af1_dev:
                        max_af1_dev = curr_af1_dev
                        index_at_max_af1_dev = current_step
                        af1_tst_at_max_af1_dev = curr_af1_tst

                    print 'Status: [%d] Max f1 for dev (%f), Max f1 for tst (%f)\n' % (
                        index_at_max_af1_dev, max_af1_dev, af1_tst_at_max_af1_dev)
                    sys.stdout.flush()
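
# Not part of the original listing: a minimal invocation sketch for the run_train
# variant above; all argument values are illustrative assumptions.
# run_train(w2vdim=400, w2vnumfilters=64, lexdim=15, lexnumfilters=9,
#           randomseed=42, lex_col=[0, 1], simple_run=False)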
Example No. 5
def run_train(w2vsource, w2vdim, w2vnumfilters, lexdim, lexnumfilters, randomseed, model_name, trainable, is_expanded, simple_run = True):
    if simple_run == True:
        print '======================================[simple_run]======================================'


    max_len = 60
    norm_model = []

    if model_name != "nonstaticRT":
        with Timer("lex"):
            if is_expanded == 0:
                print 'old way of loading lexicon'
                norm_model, raw_model = load_lexicon_unigram(lexdim)


            else:
                print 'new way of loading lexicon'
                default_vector_dic = {'EverythingUnigramsPMIHS': [0],
                                      'HS-AFFLEX-NEGLEX-unigrams': [0, 0, 0],
                                      'Maxdiff-Twitter-Lexicon_0to1': [0.5],
                                      'S140-AFFLEX-NEGLEX-unigrams': [0, 0, 0],
                                      'unigrams-pmilexicon': [0, 0, 0],
                                      'unigrams-pmilexicon_sentiment_140': [0, 0, 0],
                                      'BL': [0]}

                lexfile_list = ['EverythingUnigramsPMIHS.pickle',
                                'HS-AFFLEX-NEGLEX-unigrams.pickle',
                                'Maxdiff-Twitter-Lexicon_0to1.pickle',
                                'S140-AFFLEX-NEGLEX-unigrams.pickle',
                                'unigrams-pmilexicon.pickle',
                                'unigrams-pmilexicon_sentiment_140.pickle',
                                'BL.pickle']


                for idx, lexfile in enumerate(lexfile_list):
                    if is_expanded-1 == idx:
                        fname = '../data/le/exp_%s' % lexfile
                        print 'expanded lexicon for %s' % lexfile

                    else:
                        fname = '../data/le/%s' % lexfile
                        print 'default lexicon for %s' % lexfile

                    with open(fname, 'rb') as handle:
                        each_model = pickle.load(handle)
                        default_vector = default_vector_dic[lexfile.replace('.pickle', '')]
                        each_model["<PAD/>"] = default_vector
                        norm_model.append(each_model)


        with Timer("w2v"):
            if w2vsource == "twitter":
                w2vmodel = load_w2v(w2vdim, simple_run=simple_run)
            else:
                w2vmodel = load_w2v(w2vdim, simple_run=simple_run, source = "amazon")


        unigram_lexicon_model = norm_model

    # Training
    # ==================================================
    if randomseed > 0:
        tf.set_random_seed(randomseed)
    with tf.Graph().as_default():
        tf.set_random_seed(randomseed)
        max_af1_dev = 0
        index_at_max_af1_dev = 0
        af1_tst_at_max_af1_dev = 0

        x_text, y = cnn_data_helpers.load_data_nonstatic("everydata", rottenTomato=True)

        max_document_length = max([len(x.split(" ")) for x in x_text])
        vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)
        vocab_processor.fit_transform(x_text)
        total_vocab_size = len(vocab_processor.vocabulary_)


        if model_name == "w2vrt" or  model_name == "w2vrtlex":
            x_train, y_train, x_lex_train = cnn_data_helpers.load_data('trn', w2vmodel, unigram_lexicon_model, max_len, True)
            x_dev, y_dev, x_lex_dev = cnn_data_helpers.load_data('dev', w2vmodel, unigram_lexicon_model, max_len, True)
            x_test, y_test, x_lex_test = cnn_data_helpers.load_data('tst', w2vmodel, unigram_lexicon_model, max_len, True) 
        elif model_name == 'nonstaticRT':
            x_train, y_train = cnn_data_helpers.load_data_nonstatic("trn", rottenTomato=True)
            x_train = np.array(list(vocab_processor.fit_transform(x_train)))
            x_dev, y_dev = cnn_data_helpers.load_data_nonstatic("dev", rottenTomato=True)
            x_dev = np.array(list(vocab_processor.fit_transform(x_dev)))
            x_test, y_test = cnn_data_helpers.load_data_nonstatic("tst", rottenTomato=True)
            x_test = np.array(list(vocab_processor.fit_transform(x_test)))
        else:
            x_train, y_train, x_lex_train = cnn_data_helpers.load_data('trn', w2vmodel, unigram_lexicon_model, max_len)
            x_dev, y_dev, x_lex_dev = cnn_data_helpers.load_data('dev', w2vmodel, unigram_lexicon_model, max_len)
            x_test, y_test, x_lex_test = cnn_data_helpers.load_data('tst', w2vmodel, unigram_lexicon_model, max_len)

        if model_name != "nonstaticRT":
            del(w2vmodel)
            del(norm_model)
            # del(raw_model)
            gc.collect()

        print("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev)))


        session_conf = tf.ConfigProto(
          allow_soft_placement=FLAGS.allow_soft_placement,
          log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            if randomseed > 0:
                tf.set_random_seed(randomseed)

            if model_name=='w2v':
                cnn = W2V_CNN(
                    sequence_length=x_train.shape[1],
                    num_classes=3,
                    embedding_size=w2vdim,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=FLAGS.l2_reg_lambda
                )
            elif model_name == 'nonstaticRT':
                cnn = W2V_NONSTATIC(
                    sequence_length=x_train.shape[1],
                    num_classes=5,
                    vocab_size=len(vocab_processor.vocabulary_),
                    is_trainable=trainable,
                    embedding_size=w2vdim,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=FLAGS.l2_reg_lambda
                )

            elif model_name=='w2vrt':
                cnn = W2V_CNN(
                    sequence_length=x_train.shape[1],
                    num_classes=5,
                    embedding_size=w2vdim,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=FLAGS.l2_reg_lambda
                )

            elif model_name=='w2vlex':
                cnn = W2V_LEX_CNN(
                    sequence_length=x_train.shape[1],
                    num_classes=3,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=FLAGS.l2_reg_lambda)
           
            elif model_name=='w2vrtlex':
                cnn = W2V_LEX_CNN(
                    sequence_length=x_train.shape[1],
                    num_classes=5,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=FLAGS.l2_reg_lambda)

            else: # model_name == 'attention'
                cnn = TextCNNAttention(
                    sequence_length=x_train.shape[1],
                    num_classes=3,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=FLAGS.l2_reg_lambda)


            # Define Training procedure
            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.AdamOptimizer(1e-3)
            grads_and_vars = optimizer.compute_gradients(cnn.loss)
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in grads_and_vars:
                if g is not None:
                    grad_hist_summary = tf.histogram_summary("{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.scalar_summary("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.merge_summary(grad_summaries)

            # Output directory for models and summaries
            timestamp = str(int(time.time()))
            out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
            print("Writing to {}\n".format(out_dir))

            # Summaries for loss and accuracy
            loss_summary = tf.scalar_summary("loss", cnn.loss)
            acc_summary = tf.scalar_summary("accuracy", cnn.accuracy)
            f1_summary = tf.scalar_summary("avg_f1", cnn.avg_f1)

            # Train Summaries
            train_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.train.SummaryWriter(train_summary_dir, sess.graph_def)

            # Dev summaries
            dev_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary])
            dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
            dev_summary_writer = tf.train.SummaryWriter(dev_summary_dir, sess.graph_def)

            # Test summaries
            test_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary])
            test_summary_dir = os.path.join(out_dir, "summaries", "test")
            test_summary_writer = tf.train.SummaryWriter(test_summary_dir, sess.graph_def)

            # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            checkpoint_prefix = os.path.join(checkpoint_dir, "model")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver = tf.train.Saver(tf.all_variables())

            # Initialize all variables
            sess.run(tf.initialize_all_variables())
            the_base_path = '../data/emory_w2v/'
            if w2vsource == "twitter":
                the_model_path = the_base_path + 'w2v-%d.bin' % w2vdim
            elif w2vsource == "amazon":
                the_model_path = the_base_path + 'w2v-%d-%s.bin' % (w2vdim, w2vsource)

            print the_model_path
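            # For the non-static model the trainable embedding matrix is seeded with
            # pretrained vectors: the word2vec binary is parsed (header line, then one
            # "<word> <float32 vector>" record per word) and each in-vocabulary word's
            # vector is copied into initW before it is assigned to cnn.W.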
            if model_name == 'nonstaticRT':
                # initial matrix with random uniform
                initW = np.random.uniform(-0.25,0.25,(total_vocab_size, w2vdim))
                # load any vectors from the word2vec
                with Timer("w2v"):
                    print("Load word2vec file {} for NONSTATIC \n".format(the_model_path))
                    with open(the_model_path, "rb") as f:
                        header = f.readline()
                        vocab_size, layer1_size = map(int, header.split())
                        binary_len = np.dtype('float32').itemsize * layer1_size
                        for line in xrange(vocab_size):
                            word = []
                            while True:
                                ch = f.read(1)
                                if ch == ' ':
                                    word = ''.join(word)
                                    break
                                if ch != '\n':
                                    word.append(ch)   
                            idx = vocab_processor.vocabulary_.get(word)
                            if idx != 0:
                                #print str(idx) + " -> " + word
                                initW[idx] = np.fromstring(f.read(binary_len), dtype='float32') 
                            else:
                                f.read(binary_len)    

                sess.run(cnn.W.assign(initW))

            def train_step(x_batch, y_batch, x_batch_lex=None):
                """
                A single training step
                """

                if x_batch_lex is not None:
                    feed_dict = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        # lexicon
                        cnn.input_x_lexicon: x_batch_lex,
                        cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                    }
                else: 
                    feed_dict = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                    }
                _, step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                # print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
                #print("{}: step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                #      format(time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                train_summary_writer.add_summary(summaries, step)

            def dev_step(x_batch, y_batch, x_batch_lex=None, writer=None, score_type='f1'):
                """
                Evaluates model on a dev set
                """
                if x_batch_lex is not None:
                    feed_dict = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        # lexicon
                        cnn.input_x_lexicon: x_batch_lex,
                        cnn.dropout_keep_prob: 1.0
                    }
                else: 
                    feed_dict = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        cnn.dropout_keep_prob: 1.0
                    }
                step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [global_step, dev_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("{} : {} step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                      format("DEV", time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                if writer:
                    writer.add_summary(summaries, step)

                if score_type == 'f1':
                    return avg_f1
                else:
                    return accuracy

            def test_step(x_batch, y_batch, x_batch_lex=None, writer=None, score_type='f1'):
                """
                Evaluates model on a test set
                """
                if x_batch_lex is not None:
                    feed_dict = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        # lexicon
                        cnn.input_x_lexicon: x_batch_lex,
                        cnn.dropout_keep_prob: 1.0
                    }
                else: 
                    feed_dict = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        cnn.dropout_keep_prob: 1.0
                    }
                step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [global_step, test_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("{} : {} step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                      format("TEST", time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                if writer:
                    writer.add_summary(summaries, step)

                if score_type == 'f1':
                    return avg_f1
                else:
                    return accuracy

            # Generate batches
            if model_name == 'nonstaticRT':
                batches = cnn_data_helpers.batch_iter(
                    list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)
            else:
                batches = cnn_data_helpers.batch_iter(
                    list(zip(x_train, y_train, x_lex_train)), FLAGS.batch_size, FLAGS.num_epochs)
            # Training loop. For each batch...
            for batch in batches:
                if model_name == 'nonstaticRT':
                    x_batch, y_batch = zip(*batch)
                else:
                    x_batch, y_batch, x_batch_lex = zip(*batch)

                if model_name=='w2v' or model_name=='w2vrt' or model_name == 'nonstaticRT':
                    train_step(x_batch, y_batch)
                else:
                    train_step(x_batch, y_batch, x_batch_lex)


                current_step = tf.train.global_step(sess, global_step)
                if current_step % FLAGS.evaluate_every == 0:
                    print("Evaluation:")

                    if model_name == 'w2v':
                        curr_af1_dev = dev_step(x_dev, y_dev, writer=dev_summary_writer)
                        # path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        # print("Saved model checkpoint to {}\n".format(path))

                        curr_af1_tst = test_step(x_test, y_test, writer=test_summary_writer)
                        # path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        # print("Saved model checkpoint to {}\n".format(path))

                    elif model_name == 'w2vrt' or model_name =='nonstaticRT':
                        curr_af1_dev = dev_step(x_dev, y_dev, writer=dev_summary_writer, score_type = 'acc')
                        curr_af1_tst = test_step(x_test, y_test, writer=test_summary_writer, score_type = 'acc')

                    elif model_name == 'w2vrtlex':
                        curr_af1_dev = dev_step(x_dev, y_dev, x_lex_dev, writer=dev_summary_writer, score_type = 'acc')
                        curr_af1_tst = test_step(x_test, y_test, x_lex_test, writer=test_summary_writer, score_type = 'acc')
                    else:
                        curr_af1_dev = dev_step(x_dev, y_dev, x_lex_dev, writer=dev_summary_writer)
                        # path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        # print("Saved model checkpoint to {}\n".format(path))

                        curr_af1_tst = test_step(x_test, y_test, x_lex_test, writer=test_summary_writer)
                        # path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        # print("Saved model checkpoint to {}\n".format(path))

                    if curr_af1_dev > max_af1_dev:
                        max_af1_dev = curr_af1_dev
                        index_at_max_af1_dev = current_step
                        af1_tst_at_max_af1_dev = curr_af1_tst

                    print 'Status: [%d] Max f1 for dev (%f), Max f1 for tst (%f)\n' % (
                        index_at_max_af1_dev, max_af1_dev, af1_tst_at_max_af1_dev)
                    sys.stdout.flush()
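
# Not part of the original listing: a minimal invocation sketch for the run_train
# variant above; all argument values are illustrative assumptions.
# run_train(w2vsource='twitter', w2vdim=400, w2vnumfilters=64, lexdim=15,
#           lexnumfilters=9, randomseed=42, model_name='nonstaticRT',
#           trainable=True, is_expanded=0, simple_run=True)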
Example No. 6
def run_train(w2v_path, trn_path, dev_path, model_path, lex_path_list, w2vnumfilters, lexnumfilters, randomseed,
              num_epochs, num_class, max_sentence_len, l2_reg_lambda, l1_reg_lambda, simple_run=True):
    if simple_run == True:
        print '======================================[simple_run]======================================'

    if len(lex_path_list)==0:
        model_name = 'w2v'
    else:
        model_name = 'w2vlex'

    best_model_path = None

    max_len = max_sentence_len

    multichannel = False
    multichannel_a2v = False
    rt_data = False

    with utils.smart_open(w2v_path) as fin:
        header = utils.to_unicode(fin.readline())
        w2vdim = int(header.split(' ')[1].strip())


    with Timer("w2v"):
        w2vmodel = load_w2v_withpath(w2v_path)

    with Timer("lex"):
        norm_model, raw_model = load_lexicon_unigram(lex_path_list)

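    # The total lexicon embedding width (lexdim) is the sum of the per-word vector
    # lengths of every lexicon loaded from lex_path_list.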
    lexdim = 0
    for model_idx in range(len(norm_model)):
        lexdim += len(norm_model[model_idx].values()[0])

    unigram_lexicon_model = norm_model
    # unigram_lexicon_model = raw_model

    if simple_run:
        x_train, y_train, x_lex_train, _ = cnn_data_helpers.load_data('trn_sample', w2vmodel, unigram_lexicon_model,
                                                                   max_len, multichannel=multichannel)
        x_dev, y_dev, x_lex_dev, _ = cnn_data_helpers.load_data('dev_sample', w2vmodel, unigram_lexicon_model, max_len,
                                                             multichannel = multichannel)

    else:
        x_train, y_train, x_lex_train, _ = cnn_data_helpers.load_data(trn_path, w2vmodel, unigram_lexicon_model, max_len,
                                                                   rottenTomato=rt_data, multichannel=multichannel)
        x_dev, y_dev, x_lex_dev, _ = cnn_data_helpers.load_data(dev_path, w2vmodel, unigram_lexicon_model, max_len,
                                                             rottenTomato=rt_data, multichannel=multichannel)

    del (w2vmodel)
    del (norm_model)
    # del(raw_model)
    gc.collect()

    print("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev)))

    # Training
    # ==================================================
    if randomseed > 0:
        tf.set_random_seed(randomseed+10)
    with tf.Graph().as_default():
        max_af1_dev = 0
        index_at_max_af1_dev = 0
        af1_tst_at_max_af1_dev = 0

        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            if randomseed > 0:
                tf.set_random_seed(randomseed)

            num_classes = num_class

            if model_name == 'w2v':
                cnn = W2V_CNN(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            else:
                cnn = W2V_LEX_CNN(
                    sequence_length=x_train.shape[1],
                    num_classes=num_classes,
                    embedding_size=w2vdim,
                    embedding_size_lex=lexdim,
                    num_filters_lex=lexnumfilters,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=w2vnumfilters,
                    l2_reg_lambda=l2_reg_lambda,
                    l1_reg_lambda=l1_reg_lambda)

            # Define Training procedure
            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.AdamOptimizer(1e-3)
            grads_and_vars = optimizer.compute_gradients(cnn.loss)
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in grads_and_vars:
                if g is not None:
                    grad_hist_summary = tf.histogram_summary("{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.scalar_summary("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.merge_summary(grad_summaries)

            # Output directory for models and summaries
            timestamp = str(int(time.time()))
            out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
            print("Writing to {}\n".format(out_dir))

            # Summaries for loss and accuracy
            loss_summary = tf.scalar_summary("loss", cnn.loss)
            acc_summary = tf.scalar_summary("accuracy", cnn.accuracy)
            f1_summary = tf.scalar_summary("avg_f1", cnn.avg_f1)

            # Train Summaries
            train_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.train.SummaryWriter(train_summary_dir, sess.graph_def)

            # Dev summaries
            dev_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary])
            dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
            dev_summary_writer = tf.train.SummaryWriter(dev_summary_dir, sess.graph_def)

            # Test summaries
            test_summary_op = tf.merge_summary([loss_summary, acc_summary, f1_summary])
            test_summary_dir = os.path.join(out_dir, "summaries", "test")
            test_summary_writer = tf.train.SummaryWriter(test_summary_dir, sess.graph_def)

            # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            checkpoint_prefix = os.path.join(checkpoint_dir, "model")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver = tf.train.Saver(tf.all_variables())

            # Initialize all variables
            sess.run(tf.initialize_all_variables())

            def train_step(x_batch, y_batch, x_batch_lex=None, x_batch_fat=None, multichannel=False):
                """
                A single training step
                """
                if x_batch_fat is not None:
                    if x_batch_lex is None:
                        feed_dict = {
                            cnn.input_x: x_batch_fat,
                            cnn.input_y: y_batch,
                            cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                        }
                    else:
                        feed_dict = {
                            cnn.input_x_2c: x_batch_fat,
                            cnn.input_x: x_batch,
                            cnn.input_y: y_batch,
                            cnn.input_x_lexicon: x_batch_lex,
                            cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                        }

                else:
                    if x_batch_lex is None:
                        feed_dict = {
                            cnn.input_x: x_batch,
                            cnn.input_y: y_batch,
                            cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                        }
                    else:
                        feed_dict = {
                            cnn.input_x: x_batch,
                            cnn.input_y: y_batch,
                            # lexicon
                            cnn.input_x_lexicon: x_batch_lex,
                            cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                        }

                _, step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                # print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
                # print("{}: step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                #      format(time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                train_summary_writer.add_summary(summaries, step)

            def dev_step(x_batch, y_batch, x_batch_lex=None, x_batch_fat=None, writer=None, score_type='f1', multichannel=False):
                """
                Evaluates model on a dev set
                """
                if x_batch_fat is not None:
                    if x_batch_lex is None:
                        feed_dict = {
                            cnn.input_x: x_batch_fat,
                            cnn.input_y: y_batch,
                            cnn.dropout_keep_prob: 1.0
                        }
                    else:
                        feed_dict = {
                            cnn.input_x_2c: x_batch_fat,
                            cnn.input_x: x_batch,
                            cnn.input_y: y_batch,
                            cnn.input_x_lexicon: x_batch_lex,
                            cnn.dropout_keep_prob: 1.0
                        }

                else:
                    if x_batch_lex is None:
                        feed_dict = {
                            cnn.input_x: x_batch,
                            cnn.input_y: y_batch,
                            cnn.dropout_keep_prob: 1.0
                        }
                    else:
                        feed_dict = {
                            cnn.input_x: x_batch,
                            cnn.input_y: y_batch,
                            # lexicon
                            cnn.input_x_lexicon: x_batch_lex,
                            cnn.dropout_keep_prob: 1.0
                        }

                step, summaries, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1 = sess.run(
                    [global_step, dev_summary_op, cnn.loss, cnn.accuracy,
                     cnn.neg_r, cnn.neg_p, cnn.f1_neg, cnn.f1_pos, cnn.avg_f1],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("{} : {} step {}, loss {:g}, acc {:g}, neg_r {:g} neg_p {:g} f1_neg {:g}, f1_pos {:g}, f1 {:g}".
                      format("DEV", time_str, step, loss, accuracy, neg_r, neg_p, f1_neg, f1_pos, avg_f1))
                if writer:
                    writer.add_summary(summaries, step)

                if score_type == 'f1':
                    return avg_f1
                else:
                    return accuracy

            # Generate batches
            batches = cnn_data_helpers.batch_iter(
                list(zip(x_train, y_train, x_lex_train)), FLAGS.batch_size, num_epochs)


            # Training loop. For each batch...
            for batch in batches:
                if multichannel_a2v is True or multichannel is True:
                    x_batch, y_batch, x_batch_lex, x_batch_fat = zip(*batch)
                else:
                    x_batch, y_batch, x_batch_lex = zip(*batch)

                if model_name == 'w2v' or model_name == 'w2vrt':
                    train_step(x_batch, y_batch)

                else:
                    if multichannel_a2v is True:
                        train_step(x_batch, y_batch, x_batch_lex, x_batch_fat)
                    elif multichannel is True:
                        train_step(x_batch, y_batch, x_batch_lex=None, x_batch_fat=x_batch_fat,
                                   multichannel=multichannel)
                    else:
                        train_step(x_batch, y_batch, x_batch_lex, multichannel=multichannel)

                current_step = tf.train.global_step(sess, global_step)
                if current_step % FLAGS.evaluate_every == 0:
                    print("Evaluation:")
                    if rt_data == True:
                        score_type = 'acc'
                    else:
                        score_type = 'f1'

                    if model_name == 'w2v' or model_name == 'w2vrt':
                        curr_af1_dev = dev_step(x_dev, y_dev, writer=dev_summary_writer, score_type=score_type)

                    else:
                        curr_af1_dev = dev_step(x_dev, y_dev, x_lex_dev, writer=dev_summary_writer,
                                                score_type=score_type, multichannel=multichannel)

                    if curr_af1_dev > max_af1_dev:
                        max_af1_dev = curr_af1_dev
                        index_at_max_af1_dev = current_step

                        path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        best_model_path = path
                        print("Saved model checkpoint to {}\n".format(path))
                        copyfile(best_model_path, model_path)

                    if rt_data == True:
                        print 'Status: [%d] Max Acc for dev (%f)\n' % (
                            index_at_max_af1_dev, max_af1_dev*100)
                    else:
                        print 'Status: [%d] Max f1 for dev (%f)\n' % (
                            index_at_max_af1_dev, max_af1_dev)

                    sys.stdout.flush()
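
# Not part of the original listing: a minimal invocation sketch for the run_train
# variant above; all paths and argument values are illustrative assumptions.
# run_train(w2v_path='../data/emory_w2v/w2v-400.bin', trn_path='trn', dev_path='dev',
#           model_path='./best_model.ckpt', lex_path_list=['../data/le/BL.pickle'],
#           w2vnumfilters=64, lexnumfilters=9, randomseed=42, num_epochs=25,
#           num_class=3, max_sentence_len=60, l2_reg_lambda=2.0, l1_reg_lambda=0.0,
#           simple_run=False)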