def main(options):
    """Train a model on the configured imagery and optionally save it.

    Args:
        options: Parsed command-line options. Reads ``autoencoder`` (build an
            autoencoder dataset instead of an image/label dataset), ``resume``
            (path passed to ``tf.keras.models.load_model``, or ``None`` to
            build a fresh model via ``config_model``) and ``model`` (path to
            save the trained model to, or ``None``).

    Returns:
        1 when no images are configured, or when labels are required but
        missing; 0 otherwise (including when training is interrupted).
    """
    images = config.dataset.images()
    if not images:
        print('No images specified.', file=sys.stderr)
        return 1

    spec = config.train.spec()

    if options.autoencoder:
        dataset = imagery_dataset.AutoencoderDataset(
            images,
            config.train.network.chunk_size(),
            spec.chunk_stride,
        )
    else:
        labels = config.dataset.labels()
        if not labels:
            print('No labels specified.', file=sys.stderr)
            return 1
        dataset = imagery_dataset.ImageryDataset(
            images,
            labels,
            config.train.network.chunk_size(),
            config.train.network.output_size(),
            spec.chunk_stride,
        )

    try:
        # Either continue from a saved model or build a new one sized to the data.
        model = (
            tf.keras.models.load_model(options.resume, custom_objects=ALL_LAYERS)
            if options.resume is not None
            else config_model(dataset.num_bands())
        )
        model, _ = train(model, dataset, spec)
        if options.model is not None:
            model.save(options.model)
    except KeyboardInterrupt:
        # Ctrl-C is treated as a clean, user-requested stop.
        print()
        print('Training cancelled.')
    return 0
def train_ae(ae_fn, ae_dataset):
    """Train an autoencoder and save it to a fresh temporary directory.

    Args:
        ae_fn: Model factory passed through to ``train.train``.
        ae_dataset: Dataset passed through to ``train.train``.

    Returns:
        A ``(model_path, tmpdir)`` tuple. ``model_path`` is the saved
        ``ae_model.h5`` file inside ``tmpdir``. The directory is created with
        ``tempfile.mkdtemp`` and is not removed here, so the caller owns its
        cleanup.
    """
    spec = TrainingSpec(100, 5, 'mse', ['Accuracy'])
    trained, _ = train.train(ae_fn, ae_dataset, spec)

    out_dir = tempfile.mkdtemp()
    saved_path = os.path.join(out_dir, 'ae_model.h5')
    trained.save(saved_path)
    return saved_path, out_dir
def main(options):
    """Train (or resume training) a model on the configured dataset.

    When ``config.dataset.log_folder()`` is set, dataset read progress is
    recorded there. Unless resuming, the folder's existing entries are cleared
    first.

    Args:
        options: Parsed command-line options. Reads ``resume`` (path passed to
            ``tf.keras.models.load_model``, or ``None`` to build a new model),
            ``autoencoder`` (use an autoencoder dataset instead of an
            image/label dataset) and ``model`` (path to save the trained model
            to, or ``None``).

    Returns:
        1 when no images are configured, or when labels are required but
        missing; 0 otherwise (including when training is interrupted).
    """
    import glob  # local import: only needed to clear the progress-log folder

    log_folder = config.dataset.log_folder()
    if log_folder:
        if not options.resume:
            # Start fresh and clear the read logs. This mirrors `rm <folder>/*`
            # but runs in-process. Building a shell command from the path broke
            # on spaces/metacharacters and allowed command injection. It also
            # printed a spurious error when the folder was already empty.
            for entry in glob.glob(os.path.join(glob.escape(log_folder), '*')):
                # Like plain `rm`, remove files and links but leave subdirectories.
                if os.path.isdir(entry) and not os.path.islink(entry):
                    continue
                try:
                    os.remove(entry)
                except OSError as err:
                    # Best effort, as with `rm`: report and keep going.
                    print('Could not remove ' + entry + ': ' + str(err), file=sys.stderr)
            print('Dataset progress recording in: ' + log_folder)
        else:
            print('Resuming dataset progress recorded in: ' + log_folder)

    start_time = time.time()

    images = config.dataset.images()
    if not images:
        print('No images specified.', file=sys.stderr)
        return 1

    tc = config.train.spec()
    if options.autoencoder:
        ids = imagery_dataset.AutoencoderDataset(images, config.train.network.chunk_size(),
                                                 tc.chunk_stride, resume_mode=options.resume,
                                                 log_folder=log_folder)
    else:
        labels = config.dataset.labels()
        if not labels:
            print('No labels specified.', file=sys.stderr)
            return 1
        ids = imagery_dataset.ImageryDataset(images, labels, config.train.network.chunk_size(),
                                             config.train.network.output_size(), tc.chunk_stride,
                                             resume_mode=options.resume, log_folder=log_folder)

    try:
        if options.resume is not None:
            model = tf.keras.models.load_model(options.resume, custom_objects=ALL_LAYERS)
        else:
            model = config_model(ids.num_bands())
        model, _ = train(model, ids, tc)
        if options.model is not None:
            save_model(model, options.model)
    except KeyboardInterrupt:
        # Ctrl-C is treated as a clean, user-requested stop.
        print()
        print('Training cancelled.')

    stop_time = time.time()
    print('Elapsed time = ', stop_time-start_time)
    return 0
def test_train(dataset): #pylint: disable=redefined-outer-name
    """Train a tiny dense classifier on ``dataset`` and sanity-check its predictions.

    The model maps a 3x3 single-band window to a 1x1 two-class output. The
    predicted label image therefore lacks a one-pixel border compared with
    the generated tile, and the reference labels are cropped to match.
    """
    def build_model():
        inputs = keras.layers.Input((3, 3, 1))
        hidden = keras.layers.Flatten()(inputs)
        hidden = keras.layers.Dense(3 * 3, activation=tf.nn.relu)(hidden)
        hidden = keras.layers.Dense(2, activation=tf.nn.softmax)(hidden)
        outputs = keras.layers.Reshape((1, 1, 2))(hidden)
        return keras.Model(inputs=inputs, outputs=outputs)

    spec = TrainingSpec(100, 5, 'sparse_categorical_crossentropy', ['accuracy'])
    model, _ = train.train(build_model, dataset, spec)

    # Index 1 is the 'accuracy' metric; index 0 is the loss.
    scores = model.evaluate(x=dataset.dataset().batch(1000))
    assert scores[1] > 0.70

    test_image, full_label = conftest.generate_tile()
    expected = full_label[1:-1, 1:-1]

    writer = npy.NumpyImageWriter()
    predictor = predict.LabelPredictor(model, output_image=writer)
    predictor.predict(npy.NumpyImage(test_image))

    mismatches = np.logical_xor(writer.buffer()[:,:,0], expected)
    # very easy test since we don't train much
    assert sum(sum(mismatches)) < 200
def evaluate_model(model_fn, dataset, output_trim=0, threshold=0.3, max_wrong=200, batch_size=10):
    """Train ``model_fn`` briefly on ``dataset`` and assert it learned something.

    Args:
        model_fn: Model factory passed through to ``train.train``.
        dataset: Training dataset. ``dataset.dataset()`` is batched for evaluation.
        output_trim: Border width cropped from the reference labels so they
            match the size of the model's predicted image (0 disables cropping).
        threshold: Minimum accepted ``sparse_categorical_accuracy``.
        max_wrong: Maximum number of mismatching predicted pixels allowed.
        batch_size: Training batch size passed to ``TrainingSpec``.
    """
    spec = TrainingSpec(batch_size, 5, 'sparse_categorical_crossentropy',
                        ['sparse_categorical_accuracy'])
    model, _ = train.train(model_fn, dataset, spec)

    # Index 1 is the accuracy metric; index 0 is the loss.
    scores = model.evaluate(x=dataset.dataset().batch(1000))
    assert scores[1] > threshold # very loose test since not much training

    test_image, test_label = conftest.generate_tile()
    if output_trim > 0:
        trim = output_trim
        test_label = test_label[trim:-trim, trim:-trim]

    writer = npy.NumpyWriter()
    predictor = predict.LabelPredictor(model, output_image=writer)
    predictor.predict(npy.NumpyImage(test_image))

    wrong = np.logical_xor(writer.buffer()[:, :], test_label)
    # very easy test since we don't train much
    assert sum(sum(wrong)) < max_wrong
def main(options):
    """Train (or resume training) a tiled-imagery model and optionally save it.

    The model's input/output shapes decide how imagery is tiled. For a model
    whose spatial input size is ``None`` (fully convolutional), the output
    size is computed for a ``config.io.tile_size()`` tile. The size
    difference is then used as the tile overlap.

    Args:
        options: Parsed command-line options. Reads ``resume`` (path passed to
            ``tf.keras.models.load_model`` and to ``train``, or ``None``),
            ``autoencoder`` (use an autoencoder dataset instead of an
            image/label dataset) and ``model`` (path to save the trained model
            to, or ``None``).

    Returns:
        1 when no images are configured, or when labels are required but
        missing; 0 otherwise (including when training is interrupted).

    Raises:
        AssertionError: If the model's spatial input is not square, or its
            band count differs from the dataset's.
    """
    import glob  # local import: only needed to clear the progress-log folder

    log_folder = config.train.log_folder()
    if log_folder:
        if not options.resume:
            # Start fresh and clear the read logs. This mirrors `rm -f <folder>/*`
            # in-process instead of passing a path-built string to the shell,
            # which broke on spaces/metacharacters and allowed command injection.
            for entry in glob.glob(os.path.join(glob.escape(log_folder), '*')):
                # Like plain `rm -f`, remove files and links but leave subdirectories.
                if os.path.isdir(entry) and not os.path.islink(entry):
                    continue
                try:
                    os.remove(entry)
                except OSError as err:
                    # Best effort: report and keep going.
                    print('Could not remove ' + entry + ': ' + str(err), file=sys.stderr)
            print('Dataset progress recording in: ' + log_folder)
        else:
            print('Resuming dataset progress recorded in: ' + log_folder)

    images = config.dataset.images()
    if not images:
        print('No images specified.', file=sys.stderr)
        return 1
    img = images.load(0)
    model = config_model(img.num_bands())
    if options.resume is not None:
        temp_model = tf.keras.models.load_model(options.resume, custom_objects=custom_objects())
    else:
        # this one is not built with proper scope, just used to get input and output shapes
        temp_model = model()

    start_time = time.time()
    tile_size = config.io.tile_size()
    tile_overlap = None
    stride = config.train.spec().stride

    # compute input and output sizes
    if temp_model.input_shape[1] is None:
        in_shape = None
        out_shape = temp_model.compute_output_shape((0, tile_size[0], tile_size[1],
                                                     temp_model.input_shape[3]))
        out_shape = out_shape[1:3]
        tile_overlap = (tile_size[0] - out_shape[0], tile_size[1] - out_shape[1])
    else:
        in_shape = temp_model.input_shape[1:3]
        out_shape = temp_model.output_shape[1:3]

    if options.autoencoder:
        ids = imagery_dataset.AutoencoderDataset(images, in_shape, tile_shape=tile_size,
                                                 tile_overlap=tile_overlap, stride=stride)
    else:
        labels = config.dataset.labels()
        if not labels:
            print('No labels specified.', file=sys.stderr)
            return 1
        ids = imagery_dataset.ImageryDataset(images, labels, out_shape, in_shape,
                                             tile_shape=tile_size, tile_overlap=tile_overlap,
                                             stride=stride)
    if log_folder is not None:
        ids.set_resume_mode(options.resume, log_folder)

    assert temp_model.input_shape[1] == temp_model.input_shape[2], 'Must have square chunks in model.'
    assert temp_model.input_shape[3] == ids.num_bands(), 'Model takes wrong number of bands.'
    # temp_model was only needed for shapes; reset Keras global state before training.
    tf.keras.backend.clear_session()

    try:
        model, _ = train(model, ids, config.train.spec(), options.resume)
        if options.model is not None:
            save_model(model, options.model)
    except KeyboardInterrupt:
        # Ctrl-C is treated as a clean, user-requested stop.
        print('Training cancelled.')

    stop_time = time.time()
    print('Elapsed time = ', stop_time-start_time)
    return 0
def main(options):
    """Train (or resume training) a tiled-imagery model and optionally save it.

    Enables the ``mixed_float16`` global policy when
    ``mixed_policy_device_compatible()`` is true and
    ``config.train.disable_mixed_precision()`` is false. Tiling is derived from
    the model's input/output shapes, as for fully convolutional models (see the
    shape computation below).

    Args:
        options: Parsed command-line options. Reads ``resume`` (path passed to
            ``load_model`` and to ``train``, or ``None``), ``autoencoder`` (use
            an autoencoder dataset instead of an image/label dataset) and
            ``model`` (path to save the trained model to, or ``None``).

    Returns:
        1 when no images are configured, or when labels are required but
        missing; 0 otherwise (including when training is interrupted).

    Raises:
        AssertionError: If the model's spatial input is not square, or its
            band count differs from the dataset's.
    """
    if mixed_policy_device_compatible() and not config.train.disable_mixed_precision():
        mixed_precision.set_global_policy('mixed_float16')
        print(
            'Tensorflow Mixed Precision is enabled. This improves training performance on compatible GPUs. '
            'However certain precautions should be taken and several additional changes can be made to improve '
            'performance further. Details: https://www.tensorflow.org/guide/mixed_precision#summary'
        )

    images = config.dataset.images()
    if not images:
        print('No images specified.', file=sys.stderr)
        return 1
    img = images.load(0)
    model = config_model(img.num_bands())
    if options.resume is not None and not options.resume.endswith('.h5'):
        temp_model = load_model(options.resume)
    else:
        # this one is not built with proper scope, just used to get input and output shapes
        temp_model = model()

    start_time = time.time()
    tile_size = config.io.tile_size()
    tile_overlap = None
    stride = config.train.spec().stride

    # compute input and output sizes
    if temp_model.input_shape[1] is None:
        in_shape = None
        out_shape = temp_model.compute_output_shape(
            (0, tile_size[0], tile_size[1], temp_model.input_shape[3]))
        out_shape = out_shape[1:3]
        tile_overlap = (tile_size[0] - out_shape[0], tile_size[1] - out_shape[1])
    else:
        in_shape = temp_model.input_shape[1:3]
        out_shape = temp_model.output_shape[1:3]

    if options.autoencoder:
        ids = imagery_dataset.AutoencoderDataset(
            images,
            in_shape,
            tile_shape=tile_size,
            tile_overlap=tile_overlap,
            stride=stride,
            max_rand_offset=config.train.spec().max_tile_offset)
    else:
        labels = config.dataset.labels()
        if not labels:
            print('No labels specified.', file=sys.stderr)
            return 1
        ids = imagery_dataset.ImageryDataset(
            images,
            labels,
            out_shape,
            in_shape,
            tile_shape=tile_size,
            tile_overlap=tile_overlap,
            stride=stride,
            max_rand_offset=config.train.spec().max_tile_offset)

    assert temp_model.input_shape[1] == temp_model.input_shape[
        2], 'Must have square chunks in model.'
    assert temp_model.input_shape[3] == ids.num_bands(
    ), 'Model takes wrong number of bands.'
    # temp_model was only needed for shapes; reset Keras global state before training.
    tf.keras.backend.clear_session()

    # Try to have the internal model format we use match the output model format.
    # Decide by the file extension, the same test used for options.resume above.
    # A substring check would also match paths that merely contain '.h5'
    # (e.g. in a directory name).
    internal_model_extension = '.savedmodel'
    if options.model and options.model.endswith('.h5'):
        internal_model_extension = '.h5'

    try:
        model, _ = train(model, ids, config.train.spec(), options.resume,
                         internal_model_extension)
        if options.model is not None:
            save_model(model, options.model)
    except KeyboardInterrupt:
        # Ctrl-C is treated as a clean, user-requested stop.
        print('Training cancelled.')

    stop_time = time.time()
    print('Elapsed time = ', stop_time - start_time)
    return 0