Example #1
File: core.py Project: zhuty94/enet-keras
    def model(self):
        model = select_model(model_name=self.model_name)
        kwargs = self.model_config
        kwargs['nc'] = self.DatasetClass.num_classes()
        weights = self.DatasetClass.class_weights()
        kwargs['loss'] = [
            'categorical_crossentropy',
            # w_categorical_crossentropy(weights=weights)
        ]
        autoencoder, _ = model.build(**kwargs)
        if self.model_config['print_summary']:
            autoencoder.summary()
        try:
            # Load pretrained weights when an 'h5file' entry is present in the config.
            h5file = self.model_config['h5file']
            print('Loading from file {}'.format(h5file))
            h5file, ext = os.path.splitext(h5file)
            autoencoder.load_weights(h5file + ext)
        except KeyError:
            # No checkpoint configured: initialize via model.transfer_weights instead.
            autoencoder = model.transfer_weights(autoencoder)
        return autoencoder
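This method reads everything it needs from self.model_config and self.DatasetClass; the 'h5file' key is optional and decides whether weights are loaded from a checkpoint or initialized via model.transfer_weights. A minimal sketch of a config carrying the keys the method touches (values are illustrative, not taken from the project):

    model_config = {
        'print_summary': True,               # print autoencoder.summary() after building
        # 'h5file': 'weights/enet_best.h5',  # hypothetical path; resume from a checkpoint
    }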
Example #2
File: core.py Project: zhuty94/enet-keras
    def model(self):
        model_name = 'enet_unpooling'
        # model_name = 'enet_unpooling_weights_simple_setup'
        # model_name = 'enet_unpooling_no_weights'
        dataset_name = self.data_config['dataset_name']
        root_dir = 'experiments'
        pw = os.path.join(
            root_dir,
            dataset_name,
            model_name,
            'weights',
            # 'weights.enet_unpooling.02-2.59.h5'
            '{}_best.h5'.format(model_name))

        # print(pw)

        nc = getattr(datasets, dataset_name).num_classes()
        self.model_config['nc'] = nc

        autoencoder = select_model(model_name=model_name)
        # segmenter, model_name = autoencoder.build(nc=nc, w=w, h=h)
        segmenter, model_name = autoencoder.build(**self.model_config)
        segmenter.load_weights(pw)
        return segmenter
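Here the method expects a pretrained checkpoint at a fixed location under experiments/. A small, runnable sketch of the path it resolves, with mscoco as an illustrative dataset name (in the method above the name comes from self.data_config):

    import os

    root_dir = 'experiments'
    dataset_name = 'mscoco'        # illustrative; the method reads it from data_config
    model_name = 'enet_unpooling'
    pw = os.path.join(root_dir, dataset_name, model_name, 'weights',
                      '{}_best.h5'.format(model_name))
    print(pw)  # experiments/mscoco/enet_unpooling/weights/enet_unpooling_best.h5 (POSIX)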
Example #3
File: test.py Project: zhuty94/enet-keras
def load_model(h5file, model_name):
    model = models.select_model(model_name)
    # h5file = os.path.join(result_dir, 'mscoco', model_name, '{}.h5'.format(model_name, h5filename))
    # nc, dw and dh (classes, target width/height) come from module scope in test.py.
    segmenter, model_name = model.build(nc=nc, w=dw, h=dh)
    segmenter.load_weights(h5file)
    return segmenter
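load_model relies on nc, dw and dh (number of classes, target width and height) being defined at module level in test.py. A hedged usage sketch with illustrative values for those globals and a hypothetical weights filename:

    nc, dw, dh = 81, 256, 256    # illustrative: categories + background, width, height
    segmenter = load_model(h5file='enet_unpooling_best.h5', model_name='enet_unpooling')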
Example #4
def train(solver):
    dataset_name = solver['dataset_name']

    print('Preparing to train on {} data...'.format(dataset_name))

    epochs = solver['epochs']
    batch_size = solver['batch_size']
    completed_epochs = solver['completed_epochs']
    model_name = solver['model_name']

    np.random.seed(1337)  # for reproducibility

    dw = solver['dw']
    dh = solver['dh']

    resize_mode = str(solver['resize_mode'])
    instance_mode = bool(solver['instance_mode'])

    dataset = datasets.load(dataset_name)
    nc = dataset.num_classes()  # categories + background

    model = select_model(model_name=model_name)
    autoencoder, model_name = model.build(nc=nc, w=dw, h=dh)
    if 'h5file' in solver:
        h5file = solver['h5file']
        print('Loading model {}'.format(h5file))
        h5file, ext = os.path.splitext(h5file)
        autoencoder.load_weights(h5file + ext)
    else:
        autoencoder = model.transfer_weights(autoencoder)

    if K.backend() == 'tensorflow':
        print('Tensorflow backend detected; Applying memory usage constraints')
        # allow_growth lets GPU memory allocation grow on demand instead of
        # reserving all available memory up front.
        ss = K.tf.Session(config=K.tf.ConfigProto(gpu_options=K.tf.GPUOptions(
            allow_growth=True)))
        K.set_session(ss)
        ss.run(K.tf.global_variables_initializer())

    print('Done loading {} model!'.format(model_name))

    experiment_dir = os.path.join('models', dataset_name, model_name)
    log_dir = os.path.join(experiment_dir, 'logs')
    checkpoint_dir = os.path.join(experiment_dir, 'weights')
    ensure_dir(log_dir)
    ensure_dir(checkpoint_dir)

    train_dataset, train_generator = load_dataset(dataset_name=dataset_name,
                                                  data_dir=os.path.join(
                                                      'data', dataset_name),
                                                  data_type='train2014',
                                                  instance_mode=instance_mode)
    train_gen = load_data(dataset=train_dataset,
                          generator=train_generator,
                          target_h=dh,
                          target_w=dw,
                          resize_mode=resize_mode)
    train_gen = batched(train_gen, batch_size)
    nb_train_samples = train_dataset.num_instances if instance_mode else train_dataset.num_images
    steps_per_epoch = nb_train_samples // batch_size

    # Validate on one tenth as many batches as a full training epoch.
    validation_steps = steps_per_epoch // 10
    val_dataset, val_generator = load_dataset(
        dataset_name=dataset_name,
        data_dir=os.path.join('data', dataset_name),
        data_type='val2014',
        sample_size=validation_steps * batch_size,
        instance_mode=instance_mode)
    val_gen = load_data(dataset=val_dataset,
                        generator=val_generator,
                        target_h=dh,
                        target_w=dw,
                        resize_mode=resize_mode)
    val_gen = batched(val_gen, batch_size)

    autoencoder.fit_generator(
        generator=train_gen,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=1,
        callbacks=callbacks(log_dir, checkpoint_dir, model_name),
        validation_data=val_gen,
        validation_steps=validation_steps,
        initial_epoch=completed_epochs,
    )
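train() takes its entire configuration from the solver dict. A minimal sketch listing every key the function reads (values are illustrative, not taken from the project's solver files):

    solver = {
        'dataset_name': 'mscoco',        # must name a module under datasets
        'model_name': 'enet_unpooling',
        'epochs': 100,
        'completed_epochs': 0,           # passed to fit_generator as initial_epoch
        'batch_size': 8,
        'dw': 256,                       # target width
        'dh': 256,                       # target height
        'resize_mode': 'stretch',        # illustrative; forwarded to load_data
        'instance_mode': False,          # sample per object instance instead of per image
        # 'h5file': 'models/mscoco/enet_unpooling/weights/checkpoint.h5',  # hypothetical; optional resume
    }
    train(solver)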