Example #1
def run_task(args, extra_args):
    # load data
    task_func = get_task_function(args.module, args.task, args.submodule)
    module_kwargs = get_task_function_defaults(args.module, args.dataset)
    module_kwargs.update(extra_args)
    if check_datasets(args.dataset):
        data_dict = get_dataset_builder(args.dataset)(args.batch_size,
                                                      args.resize_data_shape,
                                                      args.data_limit,
                                                      args.data_dir)
    else:
        data_dict = setup_tfds(args.dataset,
                               args.batch_size,
                               args.resize_data_shape,
                               args.data_limit,
                               args.data_dir,
                               shuffle_seed=args.seed)

    print('Running {} from {} on {} with arguments \n{}'.format(
        args.task, args.module, args.dataset, module_kwargs))
    model = task_func(data_dict,
                      seed=args.seed,
                      output_dir=args.output_dir,
                      debug=args.debug,
                      **module_kwargs)

    return model
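
A minimal sketch of how run_task might be driven from a CLI, assuming an argparse namespace supplies the attributes the function reads; the flag names below simply mirror those attributes, and the defaults are hypothetical:

import argparse

# Hypothetical driver for run_task; the attribute names match what the
# function reads from `args`, the defaults are illustrative only.
parser = argparse.ArgumentParser()
parser.add_argument('--module')
parser.add_argument('--task')
parser.add_argument('--submodule', default=None)
parser.add_argument('--dataset', default='mnist')
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--resize_data_shape', type=int, nargs='*', default=None)
parser.add_argument('--data_limit', type=int, default=-1)
parser.add_argument('--data_dir', default=None)
parser.add_argument('--seed', type=int, default=None)
parser.add_argument('--output_dir', default='./')
parser.add_argument('--debug', action='store_true')
args, unknown = parser.parse_known_args()

model = run_task(args, {})  # the extra_args dict could be parsed from `unknown`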
Example #2
    def setUp(self):
        super(VisualizableMixinTest, self).setUp()
        self.args = dict()
        self.args['data_limit'] = 80
        self.args['batch_size'] = 8

        self.model = BasicMnistVisualizable()
        self.ds = setup_tfds('mnist', self.args['batch_size'], None,
                             self.args['data_limit'])
Example #3
    def test_mnist(self):
        """Run one small epoch of MNIST just to make sure no errors are thrown"""
        # takes ~4 seconds on a laptop
        mnist_ds = setup_tfds(self.args['dataset'], self.args['batch_size'],
                              None, self.args['data_limit'])
        learn(mnist_ds,
              encoder=self.args['encoder'],
              decoder=self.args['decoder'],
              latent_dim=self.args['latent_dim'],
              epochs=self.args['epochs'],
              output_dir=self.args['output_dir'])
Example #4
    def setUp(self):
        self.args = dict()
        self.args['batch_size'] = 8
        self.args['data_limit'] = 24
        self.args['num_prototypes'] = 100
        self.args['prototype_dim'] = 64
        self.ds = setup_tfds('mnist', self.args['batch_size'], None,
                             self.args['data_limit'])
        conv_stack = get_network_builder('mnist_conv')()
        self.model = ProtoPNet(conv_stack, self.args['num_prototypes'],
                               self.args['prototype_dim'],
                               self.ds['num_classes'])
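
If this fixture were extended with a test, a hedged sketch might look like the following; the batch iteration and tf.cast pattern come from Example #8 below, while the assumption that ProtoPNet is callable on an image batch and returns batch-major outputs is mine:

    def test_forward_pass(self):
        # Hypothetical smoke test: run one batch through the model and
        # check the leading (batch) dimension of the output.
        for batch in self.ds['train']:
            out = self.model(tf.cast(batch['image'], tf.float32))
            self.assertEqual(out.shape[0], self.args['batch_size'])
            break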
Example #5
def train(args, extra_args):
    # load data
    learn = get_learn_function(args.alg)
    alg_kwargs = get_learn_function_defaults(args.alg, args.dataset)
    alg_kwargs.update(extra_args)
    data_dict = setup_tfds(args.dataset, args.batch_size,
                           args.resize_data_shape, args.data_limit,
                           args.data_dir)

    print('Training {} on {} with arguments \n{}'.format(
        args.alg, args.dataset, alg_kwargs))
    model = learn(data_dict,
                  seed=args.seed,
                  output_dir=args.output_dir,
                  debug=args.debug,
                  **alg_kwargs)

    return model
Example #6
    def setUp(self):
        self.args = dict()
        self.args['batch_size'] = 8
        self.args['data_limit'] = 24
        self.args['conv_stack_name'] = 'ross_net'
        self.args['output_dir'] = './features_test'
        self.args['model_save_name'] = 'features_test'

        self.ds = setup_tfds('mnist', self.args['batch_size'], None,
                             self.args['data_limit'])
        objects = build_savable_objects(self.args['conv_stack_name'], self.ds,
                                        0.0001, self.args['output_dir'],
                                        self.args['model_save_name'])

        self.model = objects['model']
        self.optimizer = objects['optimizer']
        self.global_step = objects['global_step']
        self.checkpoint = objects['checkpoint']
        self.ckpt_manager = objects['ckpt_manager']
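
The unpacked names follow the standard tf.train.CheckpointManager workflow; a short hypothetical sketch of a test built on this fixture, assuming ckpt_manager is a tf.train.CheckpointManager and global_step an integer tf.Variable:

    def test_checkpointing(self):
        # Hypothetical: persist the current state, then restore the
        # newest checkpoint and verify the tracked objects matched.
        self.ckpt_manager.save(checkpoint_number=self.global_step)
        status = self.checkpoint.restore(self.ckpt_manager.latest_checkpoint)
        status.assert_existing_objects_matched()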
Example #7
    def setUp(self):
        super(CpVAETest, self).setUp()
        self.args = dict()
        self.args['dataset'] = 'mnist'
        self.args['encoder'] = 'mnist_encoder'
        self.args['decoder'] = 'mnist_decoder'
        self.args['epochs'] = 1
        self.args['data_limit'] = 80
        self.args['latent_dim'] = 10
        self.args['batch_size'] = 8
        self.args['output_dir'] = 'cpvae_cpvae_test'

        self.ds = setup_tfds(self.args['dataset'], self.args['batch_size'],
                             None, self.args['data_limit'])
        self.encoder = VAEEncoder(self.args['encoder'],
                                  self.args['latent_dim'])
        self.decoder = VAEDecoder(self.args['decoder'], self.ds['shape'][-1])
        decision_tree = sklearn.tree.DecisionTreeClassifier(
            max_depth=2, min_weight_fraction_leaf=0.01, max_leaf_nodes=4)
        self.DDT = DDT(decision_tree, 10)
Example #8
    def setUp(self):
        super(FeatureClassifierMixinTest, self).setUp()
        self.args = dict()
        self.args['data_limit'] = 24
        self.args['batch_size'] = 8
        self.args[
            'model_params_prefix'] = './feature_classifier_test/mnist_params'

        self.model = BasicMnistFeatureClassifier()
        self.ds = setup_tfds('mnist',
                             self.args['batch_size'],
                             None,
                             self.args['data_limit'],
                             shuffle_seed=431)

        # if saved model isn't present, train and save, otherwise load
        checkpoint = tf.train.Checkpoint(model=self.model)
        if not osp.exists(self.args['model_params_prefix'] + '-1.index'):
            optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
            for batch in self.ds['train']:
                with tf.GradientTape() as tape:
                    mean_loss = tf.reduce_mean(
                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels=batch['label'],
                            logits=self.model(
                                tf.cast(batch['image'], tf.float32))))
                gradients = tape.gradient(mean_loss,
                                          self.model.trainable_variables)

                optimizer.apply_gradients(
                    zip(gradients, self.model.trainable_variables))
            checkpoint.save(self.args['model_params_prefix'])
        else:
            for batch in self.ds['train']:
                self.model(tf.cast(batch['image'], tf.float32))
                break
            checkpoint.restore(self.args['model_params_prefix'] +
                               '-1').assert_consumed()
Example #9
    def wrapper(*args, **kwargs):
        dataset = kwargs['data_dict']
        batch_size = kwargs.get('batch_size', 32)
        resize_data_shape = kwargs.get('resize_data_shape', None)
        data_limit = kwargs.get('data_limit', -1)
        data_dir = kwargs.get('data_dir', None)
        seed = kwargs.get('seed', None)
        output_dir = kwargs.get('output_dir', './')
        if check_datasets(dataset):
            data_dict = get_dataset_builder(dataset)(batch_size,
                                                     resize_data_shape,
                                                     data_limit, data_dir)
        else:
            data_dict = setup_tfds(dataset,
                                   batch_size,
                                   resize_data_shape,
                                   data_limit,
                                   data_dir,
                                   shuffle_seed=seed)

        kwargs['data_dict'] = data_dict
        func(*args, **kwargs)
        (Path(output_dir) / 'done').touch()
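
Since wrapper closes over a func that is never defined in the snippet, it presumably lives inside a decorator; a hedged reconstruction of the enclosing factory (the name loads_data and the decorated function are hypothetical):

import functools
from pathlib import Path

def loads_data(func):
    # Hypothetical enclosing decorator: swaps the `data_dict` kwarg
    # (initially a dataset name) for the loaded dataset, calls `func`,
    # then drops a 'done' marker file in the output directory.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        ...  # body as in Example #9 above
    return wrapper

@loads_data
def learn(data_dict, **kwargs):
    ...  # trains on the already-loaded data_dict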