Ejemplo n.º 1
0
 def __call__(self, input):
     """
     Transform an ImageSet or a TextSet with this preprocessing step.

     # Arguments
     input: The ImageSet or TextSet to be transformed.

     # Returns
     A new ImageSet wrapping the result of the "transformImageSet" call when
     input is an ImageSet, or a new TextSet wrapping the result of the
     "transformTextSet" call when input is a TextSet. Any other input type
     falls through both checks and yields None.
     """
     # Imported inside the method (not at module level) to break a circular
     # import. The imports are unconditional on purpose: the previous guards
     # tested dotted *class* paths against sys.modules, which is keyed by
     # module names, so they never skipped the import -- and had they ever
     # done so, ImageSet/TextSet would have been unbound locals below.
     # Re-importing an already loaded module is only a cheap cache lookup.
     from zoo.feature.image import ImageSet
     from zoo.feature.text import TextSet

     if isinstance(input, ImageSet):
         jset = callBigDlFunc(self.bigdl_type, "transformImageSet", self.value, input)
         return ImageSet(jvalue=jset)
     elif isinstance(input, TextSet):
         jset = callBigDlFunc(self.bigdl_type, "transformTextSet", self.value, input)
         return TextSet(jvalue=jset)
Ejemplo n.º 2
0
    def predict(self, x, batch_per_thread=4, distributed=True):
        """
        Run inference with this model.

        # Arguments
        x: Data to predict on. An ImageSet or TextSet, an RDD (distributed mode),
           a Numpy array (either mode) or a list (local mode only).
        batch_per_thread:
          Defaults to 4.
          With distributed=True the total batch size is
          batch_per_thread * rdd.getNumPartitions.
          With distributed=False the total batch size is
          batch_per_thread * numOfCores.
        distributed: Boolean, default True. Selects distributed or local prediction.
                     It is not consulted when x is an ImageSet or TextSet.

        # Returns
        An ImageSet/TextSet for ImageSet/TextSet input, the mapped result of the
        backend call in distributed mode, or a list of converted outputs in
        local mode.

        # Raises
        TypeError: if x is not one of the types accepted by the selected mode.
        """
        # ImageSet / TextSet inputs go straight to the backend and are wrapped
        # back into the same kind of set they came from.
        if isinstance(x, (ImageSet, TextSet)):
            predicted = callBigDlFunc(self.bigdl_type, "zooPredict",
                                      self.value, x, batch_per_thread)
            if isinstance(x, ImageSet):
                return ImageSet(predicted)
            return TextSet(predicted)

        if not distributed:
            # Local mode: only in-memory data (ndarray or list) is accepted.
            if not isinstance(x, (np.ndarray, list)):
                raise TypeError("Unsupported prediction data type: %s" % type(x))
            local_outputs = callBigDlFunc(self.bigdl_type, "zooPredict",
                                          self.value,
                                          self._to_jtensors(x),
                                          batch_per_thread)
            return [Layer.convert_output(out) for out in local_outputs]

        # Distributed mode: the backend is fed an RDD.
        if isinstance(x, np.ndarray):
            # The second argument is a zero vector with one entry per row of x
            # (presumably placeholder labels -- confirm against to_sample_rdd).
            sample_rdd = to_sample_rdd(x, np.zeros([x.shape[0]]))
        elif isinstance(x, RDD):
            sample_rdd = x
        else:
            raise TypeError("Unsupported prediction data type: %s" % type(x))
        rdd_outputs = callBigDlFunc(self.bigdl_type, "zooPredict",
                                    self.value,
                                    sample_rdd,
                                    batch_per_thread)
        return rdd_outputs.map(lambda out: Layer.convert_output(out))
Ejemplo n.º 3
0
    # Remaining command-line options. Defaults are given as strings; the
    # visible code converts them with int()/float() at the point of use.
    parser.add_option("--training_split", dest="training_split", default="0.8")
    parser.add_option("-b", "--batch_size", dest="batch_size", default="128")
    parser.add_option("-e", "--nb_epoch", dest="nb_epoch", default="20")
    parser.add_option("-l",
                      "--learning_rate",
                      dest="learning_rate",
                      default="0.01")
    parser.add_option("--log_dir",
                      dest="log_dir",
                      default="/tmp/.analytics-zoo")
    # No default is given for --model, so options.model stays None unless
    # the flag is supplied.
    parser.add_option("-m", "--model", dest="model")

    # The full sys.argv (program name included) is handed to the parser.
    (options, args) = parser.parse_args(sys.argv)
    sc = init_nncontext("Text Classification Example")

    # Read the text data from options.data_path and call to_distributed with
    # the context and the requested partition count.
    text_set = TextSet.read(path=options.data_path).to_distributed(
        sc, int(options.partition_num))
    print("Processing text dataset...")
    # Preprocessing chain, in order: tokenize -> normalize -> word2idx ->
    # shape_sequence -> generate_sample.
    transformed = text_set.tokenize().normalize()\
        .word2idx(remove_topN=10, max_words_num=int(options.max_words_num))\
        .shape_sequence(len=int(options.sequence_length)).generate_sample()
    # Split with weights [p, 1 - p], where p is --training_split.
    train_set, val_set = transformed.random_split(
        [float(options.training_split), 1 - float(options.training_split)])

    if options.model:
        model = TextClassifier.load_model(options.model)
    else:
        token_length = int(options.token_length)
        if not (token_length == 50 or token_length == 100
                or token_length == 200 or token_length == 300):
            raise ValueError(
                'token_length for GloVe can only be 50, 100, 200, 300, but got '
Ejemplo n.º 4
0
    # Command-line options. Defaults are given as strings; the visible code
    # converts the numeric ones with int() at the point of use.
    parser.add_option("--data_path", dest="data_path")
    parser.add_option("--embedding_file", dest="embedding_file")
    parser.add_option("--question_length", dest="question_length", default="10")
    parser.add_option("--answer_length", dest="answer_length", default="40")
    parser.add_option("--partition_num", dest="partition_num", default="4")
    parser.add_option("-b", "--batch_size", dest="batch_size", default="200")
    parser.add_option("-e", "--nb_epoch", dest="nb_epoch", default="30")
    parser.add_option("-l", "--learning_rate", dest="learning_rate", default="0.001")
    parser.add_option("-m", "--model", dest="model")
    parser.add_option("--output_path", dest="output_path")

    # The full sys.argv (program name included) is handed to the parser.
    (options, args) = parser.parse_args(sys.argv)
    sc = init_nncontext("QARanker Example")

    # Question corpus: read the CSV, then tokenize -> normalize -> word2idx ->
    # shape_sequence with --question_length.
    q_set = TextSet.read_csv(options.data_path + "/question_corpus.csv",
                             sc, int(options.partition_num)).tokenize().normalize()\
        .word2idx(min_freq=2).shape_sequence(int(options.question_length))
    # Answer corpus: same chain, but word2idx receives the question corpus's
    # word index as existing_map, and sequences use --answer_length.
    a_set = TextSet.read_csv(options.data_path+"/answer_corpus.csv",
                             sc, int(options.partition_num)).tokenize().normalize()\
        .word2idx(min_freq=2, existing_map=q_set.get_word_index())\
        .shape_sequence(int(options.answer_length))

    # Training data is built with from_relation_pairs, validation data with
    # from_relation_lists; both combine a relations file with q_set and a_set.
    train_relations = Relations.read(options.data_path + "/relation_train.csv",
                                     sc, int(options.partition_num))
    train_set = TextSet.from_relation_pairs(train_relations, q_set, a_set)
    validate_relations = Relations.read(options.data_path + "/relation_valid.csv",
                                        sc, int(options.partition_num))
    validate_set = TextSet.from_relation_lists(validate_relations, q_set, a_set)

    if options.model:
        knrm = KNRM.load_model(options.model)