# Example #1
# 0
def run_tensorflow(args, train_data=None, test_data=None):
    """Build a CNN/CNN2 model in a TensorFlow session and run the
    phases selected on ``args`` (train, validate, predict).

    Parameters
    ----------
    args : namespace-like object (e.g. ``argparse.Namespace``)
        Carries the model hyperparameters (vector_size, output_dim,
        conv_* / hidden_* settings, batch options), the checkpoint and
        summary directories, and the phase flags ``train_cnn``,
        ``validate_cnn``, ``predict_cnn`` plus ``debug`` and ``use_cnn2``.
    train_data : sequence of 2-D arrays, optional
        Samples used for training and validation; row counts may differ
        between items.
    test_data : sequence of 2-D arrays, optional
        Samples used for prediction.

    NOTE(review): ``train_labels`` is referenced below but is neither a
    parameter nor defined in this function — it presumably resolves to a
    module-level name; confirm it is bound before the train/validate
    phases run, otherwise this raises ``NameError``.
    """
    # Grow GPU memory on demand instead of grabbing it all up front.
    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    # The model input height must accommodate the tallest sample seen
    # in either split, so collect row counts from both.
    row_shapes = []
    if test_data is not None:
        row_shapes += [h.shape[0] for h in test_data]
    if train_data is not None:
        row_shapes += [h.shape[0] for h in train_data]
    max_rows = np.max(row_shapes)

    def _restore_checkpoint():
        # Restore model weights for a phase that did not just train them.
        # Exits the process when no usable checkpoint is found.
        # (Reads `network` and `args` from the enclosing scope; only
        # called after `network` is built below.)
        if args.checkpoint_file is not None:
            if not network.load_file(args.checkpoint_file)[0]:
                utils.logger.error(
                    "specified checkpoint_file is not valid")
                exit(-1)
        elif not network.load(args.checkpoint_dir)[0]:
            utils.logger.error(
                "No model available in checkpoint directory, train one first"
            )
            exit(-1)

    def _log_scores(true_labels, predicted_labels):
        # Weighted-average classification metrics, reported identically
        # after training and after validation.
        utils.logger.info(
            "scores: f1: %.4f, recall: %.4f, precision: %.4f" %
            (sklearn.metrics.f1_score(
                true_labels, predicted_labels, average='weighted'),
             sklearn.metrics.recall_score(
                 true_labels, predicted_labels, average='weighted'),
             sklearn.metrics.precision_score(
                 true_labels, predicted_labels, average='weighted')))

    with tf.Session(config=run_config) as sess:
        if args.debug:
            # Wrap the session in the interactive CLI debugger.
            sess = tf_debug.LocalCLIDebugWrapperSession(sess)

        if not args.use_cnn2:
            network = CNN(sess, args.vector_size, args.output_dim, max_rows,
                          args.conv_dim, args.conv_fmaps, args.hidden_units,
                          args.hidden_activation, args.dropout_rates,
                          args.batch_enabled, args.batch_size,
                          args.conv_activation, args.checkpoint_dir,
                          args.summary_dir)
        else:
            network = CNN2(sess,
                           args.vector_size,
                           args.output_dim,
                           max_rows,
                           enable_batch=args.batch_enabled,
                           batch_size=args.batch_size,
                           checkpoint_dir=args.checkpoint_dir,
                           summary_dir=args.summary_dir)
        network.build_model()
        # Dump the graph's variables for inspection.
        print("\n".join([str(v) for v in tf.global_variables()]))

        if args.train_cnn:
            network.train(train_data, train_labels, args.learning_rate,
                          args.beta1, args.epochs, 500)
            utils.logger.info("training finished")
            utils.logger.info("start train data prediction test")
            _log_scores(train_labels, network.predict(train_data))
        if args.validate_cnn:
            if not args.train_cnn:
                # No freshly trained weights in this session; load them.
                _restore_checkpoint()
            utils.logger.info("start validation on train data")
            _log_scores(train_labels, network.predict(train_data))
        if args.predict_cnn:
            if not args.train_cnn:
                _restore_checkpoint()
            result = network.predict(test_data)
            utils.save_labels(args.o, result)
            utils.logger.info("labels predicted and saved")