Example #1
0
def _train(flag, input_shape=[-1,250], expand_layer=[3,20],
          layers=[250,50,1], learning_rate=0.1, momentum=0.8,
            n_epoch=100, block_bytes=3*50*8):
    """Train the synonym extraction model, report test metrics and save it.

    The training/test files are read from ``../data/temp`` (relative to the
    current working directory) and the trained model is written below
    ``<parent of cwd>/data/model/<flag>/``.

    Args:
        flag: string, format like _$(dim)_$(window)_$(min_count), identifier.
              It is used as the suffix of the variable scope name, as the
              name of the model sub-directory and as the infix of the data
              file names (e.g. ``../data/temp/train_X<flag>.bin``).

        input_shape: list of int, specify the shape of placeholder. Its
                     first entry is replaced by None for the placeholder,
                     while the unmodified list is used to reshape every
                     batch. A length of 2 selects ``Model.MLP``, a length of
                     3 selects ``Model.MLPwL``.

        expand_layer: list of int, specify the feature expand layer.
                      For example, if value is [3, 20], that means expand
                      feature from 3 to 20 with a 3*20 matrix. Only passed
                      to the model when ``input_shape`` has 3 entries.

        layers: list of int, specify the number of neurons in the MLP. For
                example, if value is [250, 50, 1], that means input layer is
                batch*250, hidden layer is 250*50, output layer is 50*1.

        learning_rate: float. NOTE(review): accepted for interface
                       compatibility but not used anywhere in this function
                       (it is not forwarded to the model) -- confirm whether
                       the model is expected to receive it.

        momentum: float, forwarded to the model constructor.

        n_epoch: int, number of passes over the training file.

        block_bytes: int, forwarded to ``Model.data_generator``.

    Returns:
        The path of the ``model`` directory (the parent of the per-flag
        directory the checkpoint is written to).

    Raises:
        ValueError: if ``input_shape`` has neither 2 nor 3 entries.

    Side effects:
        Creates the model directories if missing, writes a checkpoint and a
        ``.meta`` graph file, and on success deletes ``../data/temp``
        recursively.
    """
    # Local import so this block does not depend on a file-level import.
    import shutil

    # NOTE: the list defaults in the signature are never mutated here
    # (``input_shape`` is copied before being modified).
    train_batch_size = 50   # examples fed per training step
    test_sample_size = 50000  # at most this many test examples are evaluated

    def to_arrays(batch):
        """Split a list of (x, y) records into reshaped x and y arrays."""
        batch_x = np.array([l[0] for l in batch])
        batch_x = np.reshape(batch_x, input_shape)
        batch_y = np.array([l[1] for l in batch])
        return batch_x, batch_y

    with tf.variable_scope('var%s' % flag):

        data_dir = "%s/data" % os.path.dirname(os.getcwd())
        model_dir = "%s/model" % data_dir
        this_dir = "%s/%s" % (model_dir, flag)
        if not os.path.exists(model_dir):
            os.mkdir(model_dir)
        if not os.path.exists(this_dir):
            os.mkdir(this_dir)

        # str.join is used to splice ``flag`` between the stem and extension.
        f_train_x = flag.join(['../data/temp/train_X','.bin'])
        f_train_y = flag.join(['../data/temp/train_y','.txt'])
        f_test_x = flag.join(['../data/temp/test_X','.bin'])
        f_test_y = flag.join(['../data/temp/test_y','.txt'])

        # The placeholder accepts any batch size.
        placeholder_shape = list(input_shape)
        placeholder_shape[0] = None

        print('...building graph%s' % flag)
        sess = tf.Session(config=tf.ConfigProto(
            intra_op_parallelism_threads=4))
        # try/finally guarantees the session is released on every path,
        # including the ValueError below and failures during training.
        try:
            # Build graph
            if len(input_shape) == 2:
                model = Model.MLP(momentum=momentum,
                                  input_shape=placeholder_shape,
                                  layers=layers)
            elif len(input_shape) == 3:
                model = Model.MLPwL(momentum=momentum,
                                    input_shape=placeholder_shape,
                                    expand_layer=expand_layer,
                                    layers=layers)
            else:
                raise ValueError(
                    'input_shape must have 2 or 3 entries, got %d'
                    % len(input_shape))

            # Build operations
            x_placeholder = model.x_placeholder
            y_placeholder = model.y_placeholder
            pred = model.output(x_placeholder)
            loss = model.loss_function(pred, y_placeholder)
            train_op = model.train_op(loss)

            # Build saver and register the tensors exported with the graph
            saver = tf.train.Saver()
            tf.add_to_collection('output', pred)
            tf.add_to_collection('loss', loss)
            tf.add_to_collection('input', x_placeholder)

            # Train model
            sess.run(tf.initialize_all_variables())
            print('...start training %s' % flag)
            t0 = time.time()
            # Defined up front so the progress message still works when the
            # training generator yields no data at all.
            ls = float('nan')
            for epoch in range(n_epoch):
                # A fresh generator per epoch re-reads the training files.
                gen = Model.data_generator(f_train_x, f_train_y, block_bytes)
                batch = list(islice(gen, train_batch_size))
                while batch:
                    batch_x, batch_y = to_arrays(batch)
                    feed_dict = {x_placeholder: batch_x,
                                 y_placeholder: batch_y}
                    ls, _ = sess.run([loss, train_op], feed_dict=feed_dict)
                    batch = list(islice(gen, train_batch_size))

                # The reported loss is that of the last batch of the epoch.
                t1 = time.time()
                message = ('epoch: %d, loss: %.5f, elapsed time: %.5f'
                           % (epoch + 1, ls, t1 - t0))
                _print_bar(epoch + 1, n_epoch, message)

            # Print Model Report
            gen = Model.data_generator(f_test_x, f_test_y,
                                       block_bytes=block_bytes)
            batch = list(islice(gen, test_sample_size))
            batch_x, batch_y = to_arrays(batch)
            x = sess.run(pred, {x_placeholder: batch_x})
            # Threshold the raw outputs to -1/+1; exact zeros stay zero.
            x[x > 0] = 1
            x[x < 0] = -1
            precision, recall, f1 = Model.metric(x, batch_y)
            print('\n')
            print('precision: %f, recall: %f, f1 score: %f' % (precision,
                                                               recall, f1))

            # Saving graph and architecture and variables
            graph_def = sess.graph_def
            saver.save(sess, flag.join(['%s/model' % this_dir,'']))
            saver_def = saver.saver_def

            filename = flag.join(['%s/model' % this_dir,'.meta'])
            collection_list = ['output', 'loss', 'input']
            tf.train.export_meta_graph(filename=filename,
                                       graph_def=graph_def,
                                       saver_def=saver_def,
                                       collection_list=collection_list)
        finally:
            sess.close()

        # Remove the temporary training data (only after a successful run).
        # Equivalent to ``rm -rf`` but without spawning a shell.
        shutil.rmtree('../data/temp', ignore_errors=True)
        return model_dir