Exemplo n.º 1
0
def create_mb_source(img_height, img_width, img_channels, n_classes, n_rois, data_path, data_set):
    """Build a composite minibatch source of images, ROIs and ROI labels.

    Args:
        img_height: Target image height passed to the scale transform.
        img_width: Target image width passed to the scale transform.
        img_channels: Channel count passed to the scale transform.
        n_classes: Number of classes; the label stream has
            ``n_classes * n_rois`` values per sample.
        n_rois: Number of ROIs; the ROI stream has ``4 * n_rois`` values
            per sample.
        data_path: Data directory, joined onto the module-level ``abs_path``.
        data_set: ``'test'`` selects the test map file; any other value
            selects the train map file. It is also the filename prefix of
            the ROI and label files, and randomization is enabled only when
            it equals ``"train"``.

    Returns:
        A ``MinibatchSource`` over the image, ROI and label deserializers.

    Raises:
        RuntimeError: If the map, ROI or label file does not exist.
    """
    roi_stream_dim = 4 * n_rois
    label_stream_dim = n_classes * n_rois

    data_dir = os.path.normpath(os.path.join(abs_path, data_path))
    map_name = test_map_filename if data_set == 'test' else train_map_filename
    map_file = os.path.join(data_dir, map_name)
    roi_file = os.path.join(data_dir, data_set + rois_filename_postfix)
    label_file = os.path.join(data_dir, data_set + roilabels_filename_postfix)

    # All three inputs must be present before any deserializer is created.
    required_files = (map_file, roi_file, label_file)
    if not all(os.path.exists(candidate) for candidate in required_files):
        raise RuntimeError("File '%s', '%s' or '%s' does not exist. "
                           "Please run install_fastrcnn.py from Examples/Image/Detection/FastRCNN to fetch them" %
                           required_files)

    # Image stream: labels from the map file are ignored, and every image is
    # scaled to the requested size using padding.
    images = ImageDeserializer(map_file)
    images.ignore_labels()
    scale_transform = ImageDeserializer.scale(
        width=img_width,
        height=img_height,
        channels=img_channels,
        scale_mode="pad",
        pad_value=114,
        interpolations='linear',
    )
    images.map_features(features_stream_name, [scale_transform])

    # ROI coordinates and ROI labels are both read as dense CTF streams.
    rois = CTFDeserializer(roi_file)
    rois.map_input(roi_stream_name, dim=roi_stream_dim, format="dense")
    roi_labels = CTFDeserializer(label_file)
    roi_labels.map_input(label_stream_name, dim=label_stream_dim, format="dense")

    # Combine the three deserializers into one reader.
    deserializers = [images, rois, roi_labels]
    return MinibatchSource(deserializers, epoch_size=sys.maxsize, randomize=(data_set == "train"))
Exemplo n.º 2
0
def train():
    """Create the trainer, feed it minibatches from the CTF file and train.

    Reads sparse word/context/negative inputs and dense targets from
    ``G.CTF_input_file``, then runs
    ``total_training_instances // G.minibatch_size`` minibatches through the
    trainer returned by ``create_trainer()``, updating a ``ProgressPrinter``
    after each one.

    Returns:
        The ``word_negative_context_product`` object returned by
        ``create_trainer()`` (its sixth element).
    """
    # Declared global as in the original code; this function does not assign them.
    global sentences, vocabulary, reverse_vocabulary

    # Frequency argument handed to the progress printer.
    report_frequency = 50
    progress = ProgressPrinter(report_frequency)

    # Only the targets variable, the trainer and the product node are used
    # below; the remaining elements are kept bound but otherwise unused.
    (_word_input, _context_inputs, _negative_inputs,
     target_variable, trainer, product_node, _embedding) = create_trainer()

    # CTF reader: one sparse stream per word / context / negative input,
    # registered in that order, followed by the dense targets stream.
    print("reader started")
    reader = CTFDeserializer(G.CTF_input_file)
    sparse_fields = [G.word_input_field]
    sparse_fields.extend(G.context_input_field.format(k) for k in range(context_size))
    sparse_fields.extend(G.negative_input_field.format(k) for k in range(G.negative))
    for field_name in sparse_fields:
        reader.map_input(field_name, dim=G.embedding_vocab_size, format="sparse")
    reader.map_input(G.target_input_field, dim=(G.negative + 1), format="dense")
    print("reader done")

    # Wrap the reader in a minibatch source.
    is_training = True
    if is_training:
        sweep_size = INFINITELY_REPEAT
    else:
        sweep_size = FULL_DATA_SWEEP
    minibatch_source = MinibatchSource(reader, randomize=is_training, epoch_size=sweep_size)

    # Re-key the targets stream by the targets variable instead of its field name.
    minibatch_source.streams[target_variable] = minibatch_source.streams[G.target_input_field]
    del minibatch_source.streams[G.target_input_field]
    print("minibatch source done")

    total_minibatches = total_training_instances // G.minibatch_size
    print("traning started")
    print("Total minibatches to train =", total_minibatches)
    for _ in range(total_minibatches):
        batch = minibatch_source.next_minibatch(G.minibatch_size, input_map=minibatch_source.streams)
        trainer.train_minibatch(batch)
        progress.update_with_trainer(trainer)

    print("Total training instances =", total_training_instances)
    return product_node