Example #1
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_path',
                        type=str,
                        default='data/cifar10_train.zip',
                        help='Path to train dataset')
    parser.add_argument('--val_path',
                        type=str,
                        default='data/cifar10_val.zip',
                        help='Path to validation dataset')
    parser.add_argument('--test_path',
                        type=str,
                        default='data/cifar10_test.zip',
                        help='Path to test dataset')
    parser.add_argument(
        '--query_path',
        type=str,
        default='examples/data/image_classification/cifar10_test_1.png',
        help='Path(s) to query image(s), delimited by commas')
    (args, _) = parser.parse_known_args()

    queries = utils.dataset.load_images(args.query_path.split(',')).tolist()
    test_model_class(model_file_path=__file__,
                     model_class='PyDenseNetBc',
                     task='IMAGE_CLASSIFICATION',
                     dependencies={
                         ModelDependency.TORCH: '1.0.1',
                         ModelDependency.TORCHVISION: '0.2.2'
                     },
                     train_dataset_path=args.train_path,
                     val_dataset_path=args.val_path,
                     test_dataset_path=args.test_path,
                     queries=queries)
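
The queries above come from utils.dataset.load_images, which splits the comma-delimited --query_path and returns the images as an array that is then converted to nested Python lists. As a rough, framework-free sketch of that step (the helper's exact preprocessing is not shown here, so treat this as an assumption rather than its real behavior), the query images could be loaded with Pillow and NumPy:

from PIL import Image
import numpy as np

def load_query_images(query_path):
    """Split a comma-delimited path string and load each image as a nested list."""
    queries = []
    for path in query_path.split(','):
        with Image.open(path) as img:
            queries.append(np.asarray(img).tolist())
    return queries

# e.g. load_query_images('examples/data/image_classification/cifar10_test_1.png')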
Example #2
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_path',
                        type=str,
                        default='data/fashion_mnist_train.zip',
                        help='Path to train dataset')
    parser.add_argument('--val_path',
                        type=str,
                        default='data/fashion_mnist_val.zip',
                        help='Path to validation dataset')
    parser.add_argument('--test_path',
                        type=str,
                        default='data/fashion_mnist_test.zip',
                        help='Path to test dataset')
    parser.add_argument(
        '--query_path',
        type=str,
        default='examples/data/image_classification/fashion_mnist_test_1.png',
        help='Path(s) to query image(s), delimited by commas')
    (args, _) = parser.parse_known_args()

    queries = utils.dataset.load_images(args.query_path.split(',')).tolist()
    test_model_class(model_file_path=__file__,
                     model_class='TfFeedForward',
                     task='IMAGE_CLASSIFICATION',
                     dependencies={ModelDependency.TENSORFLOW: '1.12.0'},
                     train_dataset_path=args.train_path,
                     val_dataset_path=args.val_path,
                     test_dataset_path=args.test_path,
                     queries=queries)
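
Example #2 exercises a TensorFlow feed-forward classifier (TfFeedForward) on the Fashion-MNIST splits. For orientation, here is a minimal stand-alone feed-forward model in the same TensorFlow 1.12-era API, using Keras' bundled Fashion-MNIST instead of the zipped dataset files; the layer sizes are illustrative assumptions, not TfFeedForward's actual architecture:

import tensorflow as tf

# Stand-in data: Keras' bundled Fashion-MNIST rather than the framework's zip datasets
(x_train, y_train), (x_val, y_val) = tf.keras.datasets.fashion_mnist.load_data()
x_train, x_val = x_train / 255.0, x_val / 255.0

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=1)
print(model.evaluate(x_val, y_val))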
Example #3
test_model_class(model_file_path=__file__,
                 model_class='XgbReg',
                 task='TABULAR_REGRESSION',
                 dependencies={ModelDependency.XGBOOST: '0.90'},
                 train_dataset_path='data/bodyfat_train.csv',
                 val_dataset_path='data/bodyfat_val.csv',
                 train_args={
                     'features': [
                         'density', 'age', 'weight', 'height', 'neck',
                         'chest', 'abdomen', 'hip', 'thigh', 'knee',
                         'ankle', 'biceps', 'forearm', 'wrist'
                     ],
                     'target': 'bodyfat'
                 },
                 queries=[{
                     'density': 1.0207,
                     'age': 65,
                     'weight': 224.5,
                     'height': 68.25,
                     'neck': 38.8,
                     'chest': 119.6,
                     'abdomen': 118.0,
                     'hip': 114.3,
                     'thigh': 61.3,
                     'knee': 42.1,
                     'ankle': 23.4,
                     'biceps': 34.9,
                     'forearm': 30.1,
                     'wrist': 19.4
                 }])
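
Here train_args simply names the feature columns and the regression target in the CSV files. A minimal, framework-free sketch of the same setup with pandas and XGBoost directly (assuming data/bodyfat_train.csv and data/bodyfat_val.csv contain those columns; XgbReg's actual hyperparameters and tuning loop are not reproduced):

import pandas as pd
import xgboost as xgb
from sklearn.metrics import mean_squared_error

FEATURES = ['density', 'age', 'weight', 'height', 'neck', 'chest', 'abdomen',
            'hip', 'thigh', 'knee', 'ankle', 'biceps', 'forearm', 'wrist']
TARGET = 'bodyfat'

train = pd.read_csv('data/bodyfat_train.csv')
val = pd.read_csv('data/bodyfat_val.csv')

# Fit a plain XGBoost regressor on the named feature columns
model = xgb.XGBRegressor(n_estimators=100)
model.fit(train[FEATURES], train[TARGET])

# Evaluate on the validation split
preds = model.predict(val[FEATURES])
print('val MSE:', mean_squared_error(val[TARGET], preds))

# Predict on the same single query shown above
query = pd.DataFrame([{
    'density': 1.0207, 'age': 65, 'weight': 224.5, 'height': 68.25,
    'neck': 38.8, 'chest': 119.6, 'abdomen': 118.0, 'hip': 114.3,
    'thigh': 61.3, 'knee': 42.1, 'ankle': 23.4, 'biceps': 34.9,
    'forearm': 30.1, 'wrist': 19.4
}])
print('predicted bodyfat:', model.predict(query))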
Example #4
                                    subsample=subsample,
                                    colsample_bytree=colsample_bytree)
        else:
            clf = xgb.XGBClassifier(n_estimators=n_estimators,
                                    min_child_weight=min_child_weight,
                                    max_depth=max_depth,
                                    gamma=gamma,
                                    subsample=subsample,
                                    colsample_bytree=colsample_bytree,
                                    objective='multi:softmax',
                                    num_class=num_class)
        return clf


if __name__ == '__main__':
    test_model_class(model_file_path=__file__,
                     model_class='XgbClf',
                     task='TABULAR_CLASSIFICATION',
                     dependencies={ModelDependency.XGBOOST: '0.90'},
                     train_dataset_path='data/titanic_train.csv',
                     val_dataset_path='data/titanic_val.csv',
                     train_args={
                         'features': ['Pclass', 'Sex', 'Age'],
                         'target': 'Survived'
                     },
                     queries=[{
                         'Pclass': 1,
                         'Sex': 'female',
                         'Age': 2.0
                     }])
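
The code excerpt at the top of this example begins partway through XgbClf's classifier-building helper, so the method signature and the first branch are cut off. A self-contained sketch of the same branching idea, where a plain XGBoost classifier handles the binary case and an explicit multi:softmax objective handles the multi-class case; the function name and the exact condition are assumptions, not the model's verbatim code:

import xgboost as xgb

def build_classifier(n_estimators, min_child_weight, max_depth, gamma,
                     subsample, colsample_bytree, num_class):
    """Build an XGBoost classifier, switching to a multi-class objective when needed."""
    if num_class <= 2:
        # Binary case: XGBoost's default objective suffices
        clf = xgb.XGBClassifier(n_estimators=n_estimators,
                                min_child_weight=min_child_weight,
                                max_depth=max_depth,
                                gamma=gamma,
                                subsample=subsample,
                                colsample_bytree=colsample_bytree)
    else:
        # Multi-class case: softmax over num_class labels
        clf = xgb.XGBClassifier(n_estimators=n_estimators,
                                min_child_weight=min_child_weight,
                                max_depth=max_depth,
                                gamma=gamma,
                                subsample=subsample,
                                colsample_bytree=colsample_bytree,
                                objective='multi:softmax',
                                num_class=num_class)
    return clf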
Example #5
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_path',
                        type=str,
                        default='data/fashion_mnist_train.zip',
                        help='Path to train dataset')
    parser.add_argument('--val_path',
                        type=str,
                        default='data/fashion_mnist_val.zip',
                        help='Path to validation dataset')
    parser.add_argument('--test_path',
                        type=str,
                        default='data/fashion_mnist_test.zip',
                        help='Path to test dataset')
    parser.add_argument(
        '--query_path',
        type=str,
        default='examples/data/image_classification/fashion_mnist_test_1.png',
        help='Path(s) to query image(s), delimited by commas')
    (args, _) = parser.parse_known_args()

    queries = utils.dataset.load_images(args.query_path.split(',')).tolist()
    test_model_class(model_file_path=__file__,
                     model_class='SkSvm',
                     task='IMAGE_CLASSIFICATION',
                     dependencies={ModelDependency.SCIKIT_LEARN: '0.20.0'},
                     train_dataset_path=args.train_path,
                     val_dataset_path=args.val_path,
                     test_dataset_path=args.test_path,
                     queries=queries)
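
Example #5 runs the same image-classification test with a scikit-learn SVM (SkSvm) instead of a neural network. The underlying idea is to flatten each image into a feature vector and fit an SVC; a small stand-alone sketch of that pattern, using scikit-learn's bundled digits dataset in place of the zipped Fashion-MNIST files:

from sklearn import datasets, svm
from sklearn.model_selection import train_test_split

# Stand-in data: 8x8 grayscale digits, flattened to 64-dimensional vectors
digits = datasets.load_digits()
X = digits.images.reshape(len(digits.images), -1)
X_train, X_val, y_train, y_val = train_test_split(
    X, digits.target, test_size=0.2, random_state=0)

clf = svm.SVC(kernel='rbf', gamma='scale')
clf.fit(X_train, y_train)
print('val accuracy:', clf.score(X_val, y_val))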
Example #6
        words_embed_tsr = self._word_embed(words_tsr.view(-1)).view(N, W, Ew)

        # Apply dropout to word rep (N x W x Ew)
        words_rep_tsr = self._word_dropout(words_embed_tsr)

        # Apply bidirectional LSTM to word rep sequence (N x W x 2h)
        (words_hidden_rep_tsr, _) = self._word_lstm(words_rep_tsr)
        words_hidden_rep_tsr = words_hidden_rep_tsr.contiguous()

        # Apply linear + softmax operation for sentence rep for all sentences (N x W x t)
        word_probs_tsr = F.softmax(
            self._word_lin(words_hidden_rep_tsr.view(N * W, self._h * 2)),
            dim=1).view(N, W, t)

        return word_probs_tsr


if __name__ == '__main__':
    test_model_class(model_file_path=__file__,
                     model_class='PyBiLstm',
                     task='POS_TAGGING',
                     dependencies={ModelDependency.TORCH: '0.4.1'},
                     train_dataset_path='data/ptb_train.zip',
                     val_dataset_path='data/ptb_val.zip',
                     queries=[['Ms.', 'Haag', 'plays', 'Elianti', '18', '.'],
                              [
                                  'The', 'luxury', 'auto', 'maker', 'last',
                                  'year', 'sold', '1,214', 'cars', 'in', 'the',
                                  'U.S.'
                              ]])
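
The code excerpt at the top of this example is the tail of PyBiLstm's forward pass: word embeddings go through dropout, a bidirectional LSTM, and a per-word linear + softmax projection to tag probabilities (N x W x t). Below is a self-contained toy module mirroring that shape flow; the vocabulary size, embedding width, hidden size, and tag count are illustrative assumptions, not the model's actual configuration:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MiniBiLstmTagger(nn.Module):
    """Toy word-level BiLSTM tagger mirroring the shape flow in the excerpt."""

    def __init__(self, vocab_size=10000, Ew=64, h=128, t=45, dropout=0.5):
        super().__init__()
        self._h = h
        self._word_embed = nn.Embedding(vocab_size, Ew)
        self._word_dropout = nn.Dropout(dropout)
        self._word_lstm = nn.LSTM(Ew, h, batch_first=True, bidirectional=True)
        self._word_lin = nn.Linear(h * 2, t)

    def forward(self, words_tsr):
        N, W = words_tsr.size()
        # (N x W) word indices -> (N x W x Ew) embeddings, with dropout
        words_rep_tsr = self._word_dropout(self._word_embed(words_tsr))
        # Bidirectional LSTM over the word sequence -> (N x W x 2h)
        words_hidden_rep_tsr, _ = self._word_lstm(words_rep_tsr)
        # Linear + softmax per word -> (N x W x t) tag probabilities
        logits = self._word_lin(
            words_hidden_rep_tsr.contiguous().view(N * W, self._h * 2))
        return F.softmax(logits, dim=1).view(N, W, -1)

if __name__ == '__main__':
    tagger = MiniBiLstmTagger()
    probs = tagger(torch.randint(0, 10000, (2, 6)))  # batch of 2 sentences, 6 words each
    print(probs.shape)  # torch.Size([2, 6, 45])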