def test_load_wrong_args(data, kwargs, err, compiled):
    """Check that `UNet.load_model` rejects invalid keyword arguments.

    A model built from the attributes of `data` is saved to a temporary file
    and then reloaded with `kwargs`. Loading is expected to raise `err` with a
    message matching the first key of `kwargs`.

    Parameters
    ----------
    data
        Object providing `box_size`, `x_channels`, `y_channels` and `scale`;
        also passed to the model as `data_handle`.
    kwargs : dict
        Extra keyword arguments forwarded to `UNet.load_model`.
    err : type
        Exception class the load call is expected to raise.
    compiled : bool
        Whether the model is compiled before it is saved.
    """
    model = UNet(
        box_size=data.box_size,
        input_channels=data.x_channels,
        output_channels=data.y_channels,
        scale=data.scale,
        data_handle=data,
    )

    if compiled:
        model.compile(
            optimizer=Adam(lr=1e-6),
            loss='binary_crossentropy',
            metrics=[dice, dice_loss, ovl, ovl_loss],
        )

    with tempfile.NamedTemporaryFile(suffix='.hdf') as tmp:
        model.save(tmp.name)
        # The error message must mention the name of the offending argument.
        offending_arg = list(kwargs)[0]
        with pytest.raises(err, match=offending_arg):
            UNet.load_model(tmp.name, data_handle=data, **kwargs)
def _assert_same_weights(expected, actual):
    """Assert that two lists of weight arrays match in length and values."""
    # zip() stops at the shorter sequence, so the length has to be checked
    # explicitly; otherwise missing or extra arrays would go unnoticed.
    assert len(expected) == len(actual)
    for w_expected, w_actual in zip(expected, actual):
        assert np.allclose(w_expected, w_actual)


def test_save_load(data, kwargs, compiled):
    """Check that a saved model can be restored with identical weights.

    Two round trips are exercised:

    1. `model.save` followed by `UNet.load_model` - the JSON description and
       the weights of the restored model must match the original.
    2. `model.save_keras` followed by plain `keras.models.load_model` - the
       weights of the restored model must match the original.

    Parameters
    ----------
    data
        Object providing `box_size`, `x_channels`, `y_channels` and `scale`;
        also passed to the model as `data_handle`.
    kwargs : dict
        Extra keyword arguments forwarded to `UNet.load_model`.
    compiled : bool
        Whether the model is compiled before it is saved.
    """
    from keras.models import load_model as keras_load

    model1 = UNet(
        box_size=data.box_size,
        input_channels=data.x_channels,
        output_channels=data.y_channels,
        scale=data.scale,
        data_handle=data,
    )
    if compiled:
        model1.compile(
            optimizer=Adam(lr=1e-6),
            loss='binary_crossentropy',
            metrics=[dice, dice_loss, ovl, ovl_loss],
        )

    weights1 = model1.get_weights()

    # Round trip through the UNet-specific save/load pair.
    with tempfile.NamedTemporaryFile(suffix='.hdf') as f:
        model1.save(f.name)
        model2 = UNet.load_model(f.name, data_handle=data, **kwargs)
        weights2 = model2.get_weights()

    assert model1.to_json() == model2.to_json()
    _assert_same_weights(weights1, weights2)

    # Round trip through save_keras and the stock Keras loader.
    with tempfile.NamedTemporaryFile(suffix='.hdf') as f:
        model1.save_keras(f.name)
        model2 = keras_load(f.name)
        weights2 = model2.get_weights()

    _assert_same_weights(weights1, weights2)
def _read_ids(path):
    """Return the non-empty lines of the file at `path`, or None if `path` is falsy."""
    if not path:
        return None
    with open(path) as f:
        return list(filter(None, f.read().split('\n')))


def main():
    """Train a UNet model using the options returned by `parse_args`.

    Steps performed:

    1. Resolve the output directory (defaults to ``output_<YYYY-MM-DD>``),
       create it if missing and verify that it is writable.
    2. Read the optional train/test ID lists and build a `DataWrapper`.
    3. Load an existing model (`args.model`) or create and compile a new one.
    4. Train with `fit_generator`, saving a checkpoint each epoch; when test
       IDs are given, also validate and keep the best weights.
    5. Write the training history to ``history.csv`` and the final model to
       ``model.hdf`` inside the output directory.

    Raises
    ------
    IOError
        If the output directory is not writable.
    """
    args = parse_args()

    if args.output is None:
        args.output = 'output_' + time.strftime('%Y-%m-%d')
    if not os.path.exists(args.output):
        os.makedirs(args.output)
    if not os.access(args.output, os.W_OK):
        raise IOError(
            'Cannot create files inside %s (check your permissions).'
            % args.output)

    train_ids = _read_ids(args.train_ids)
    test_ids = _read_ids(args.test_ids)

    # Without train IDs no explicit ID list is passed to the DataWrapper,
    # even if test IDs were given.
    if train_ids:
        if test_ids:
            all_ids = sorted(set(train_ids) | set(test_ids))
        else:
            all_ids = train_ids
    else:
        all_ids = None

    data = DataWrapper(args.input, test_set=test_ids, pdbids=all_ids,
                       load_data=args.load)

    if args.model:
        model = UNet.load_model(args.model, data_handle=data)
    else:
        model = UNet(data_handle=data)
        # NOTE(review): compile is applied to freshly created models only;
        # a loaded model is presumed to carry its own compile state - confirm.
        model.compile(optimizer=Adam(lr=1e-6), loss=dice_loss,
                      metrics=[dice, ovl, 'binary_crossentropy'])

    train_batch_generator = data.batch_generator(batch_size=args.batch_size)

    # Always keep the most recent state, regardless of its quality.
    callbacks = [
        ModelCheckpoint(os.path.join(args.output, 'checkpoint.hdf'),
                        save_best_only=False),
    ]

    if test_ids:
        val_batch_generator = data.batch_generator(
            batch_size=args.batch_size, subset='test')
        # Validate on a fifth of the training steps, but at least one.
        num_val_steps = max(args.steps_per_epoch // 5, 1)
        callbacks.append(
            ModelCheckpoint(os.path.join(args.output, 'best_weights.hdf'),
                            save_best_only=True))
    else:
        val_batch_generator = None
        num_val_steps = None

    model.fit_generator(train_batch_generator,
                        steps_per_epoch=args.steps_per_epoch,
                        epochs=args.epochs,
                        verbose=args.verbose,
                        callbacks=callbacks,
                        validation_data=val_batch_generator,
                        validation_steps=num_val_steps)

    history = pd.DataFrame(model.history.history)
    history.to_csv(os.path.join(args.output, 'history.csv'))
    model.save(os.path.join(args.output, 'model.hdf'))