# Command-line setup for copying the layers selected by `layerspec` from a
# source network into a target network. Both networks are built from layout
# specification files with the project-local `nn` module.
import argparse
import os
import sys

# Make the project sources (../src relative to the invoked script's directory,
# which is sys.path[0]) importable regardless of the current working directory.
sys.path.insert(1, os.path.join(sys.path[0], '../src'))
import nn

parser = argparse.ArgumentParser()
parser.add_argument('source_layout', metavar='source-layout', help='Path to source network layout specification')
parser.add_argument('source_weights', metavar='source-weights', help='Path to source network weights')
parser.add_argument('target_layout', metavar='target-layout', help='Path to target network layout specification')
parser.add_argument('target_weights', metavar='target-weights', help='Path to target network weights')
parser.add_argument('layerspec', help='Which layers to copy. Format: a-b-...-z where a-z are 0-based layer numbers')
args = parser.parse_args()

#~ Load source model
print('Loading source model from {0}'.format(args.source_layout))
source_layout = nn.load_layout(args.source_layout)
source_model, source_optimizer = nn.build_model_to_layout(source_layout)

#~ Load source weights (must exist: there is nothing to copy otherwise)
print('\tLoading source weights from {0}'.format(args.source_weights))
source_model.load_weights(args.source_weights)

#~ Load target model
print('Loading target model from {0}'.format(args.target_layout))
target_layout = nn.load_layout(args.target_layout)
target_model, target_optimizer = nn.build_model_to_layout(target_layout)

#~ Load target weights
# Optional: if the file does not exist yet, the target keeps the weights it was
# built with by nn.build_model_to_layout.
if os.path.isfile(args.target_weights):
	print('\tLoading target weights from {0}'.format(args.target_weights))
	target_model.load_weights(args.target_weights)
						# paste the filter into the filter collection
						filter_collection.paste(im_filter, (xpos, ypos))

				# save the filter collection of layer 'id'
				filter_collection.save(path + "weights_on_layer_" + str(layer_id) + "_" + filename + ".png")

if __name__ == "__main__":
	# Local imports keep this entry point runnable even though the visible
	# file header does not import argparse; os is re-imported for the same reason.
	import argparse
	import os

	#~ Parse parameters
	parser = argparse.ArgumentParser()
	parser.add_argument('weights', help='Path to the weights which are to be loaded')
	parser.add_argument('layout', help='Path to network layout specification')
	parser.add_argument('-s', '--savepath', help='Path to save location of the visualized filters', default='./')
	# NOTE(review): --verbose is parsed but not consulted here; the progress
	# messages below are always printed.
	parser.add_argument('-v', '--verbose', help='Determine whether the program shall print information in the terminal or not', action="store_true")
	parser.add_argument('-n', '--filename', help='Pass a string which is appended to the created image files.', default='')
	args = parser.parse_args()

	# The filter visualization builds output file names as
	# `path + "weights_on_layer_..."` (plain string concatenation), so the save
	# path must end with a separator; otherwise `-s out` would write
	# `outweights_on_layer_...png` next to the intended directory.
	# os.path.join(p, '') appends a separator only when one is missing.
	savepath = os.path.join(args.savepath, '')

	#~ Load model
	print('Loading model from {0}'.format(args.layout))
	layout = nn.load_layout(args.layout)
	model, optimizer = nn.build_model_to_layout(layout)

	#~ Load weights
	print('Loading weights from "{0}"'.format(args.weights))
	model.load_weights(args.weights)

	#~ Visualize filters
	print('Generating visualizations and storing to {0}'.format(savepath))
	visualize_filters(model, savepath, args.filename)

	print('Done')
# Example #3
# 0
# Result grids with one row per momentum and one column per learning rate
# (presumably filled with per-combination losses further down).
loss_matrix = np.zeros((len(momentums), len(learningrates)))
test_loss_matrix = np.zeros((len(momentums), len(learningrates)))

#~ Data, layout and weight locations
data_path = "data/MUCT_fixed/muct-landmarks/MUCT_TRAIN_KAGGLE_REDUCED.csv"
test_data_path = "data/MUCT_fixed/muct-landmarks/MUCT_TEST_KAGGLE_REDUCED.csv"
layout_path = "layouts/etl_kaggle_240_320_tutorial_glorot_normal_dropout.l"
weight_store_path = "weights/240_320_color_dropout/cd_3000"

#~ Training and preprocessing settings
batchsize, epochs = 4, 3000
# Input normalization mode, per the log message below:
# 1 -> [0,1], 2 -> [-1,1], any other value -> not normalized.
normalize = 2
normalize_output = True
grayscale = False

print('Loading model from {0}'.format(layout_path))
layout = nn.load_layout(layout_path)

# The first layer's config carries the network input shape; everything after
# its leading entry is used as the image resolution
# (presumably (channels, height, width) -- TODO confirm).
input_shape = layout[0][1]['input_shape']
resolution = input_shape[1:]

# Report the data source, target resolution and normalization mode.
print(('Loading data from {0} and rescaling it to {1}x{2}. '
	+ ('Input images are normalized to [0,1]' if normalize == 1
		else 'Input images are normalized to [-1,1]' if normalize == 2
		else 'Images are not normalized!')).format(data_path, resolution[0], resolution[1]))

#~ Load training data, also keeping the images' original resolution
x_train, y_train, original_resolution = dataset_io.read_data(
	data_path,
	resolution,
	normalize=normalize,
	grayscale=grayscale,
	return_original_resolution=True
)