def minibatch_ab(self, images, batchsize, side, do_shuffle=True, is_preview=False, is_timelapse=False):
    """ Launch a background batch loader and return its generator.

    Keeps a queue filled to 8x batch size via a ``FixedProducerDispatcher``
    subprocess, then hands the consuming side to ``self.minibatch``.

    Parameters: images (sized collection of image paths), batchsize (int),
    side (model side identifier, e.g. 'a'/'b'), do_shuffle / is_preview /
    is_timelapse (bools). Returns the generator from ``self.minibatch``.
    """
    # Fixed log format string: was "is_preview, %s" (comma instead of colon)
    logger.debug("Queue batches: (image_count: %s, batchsize: %s, side: '%s', do_shuffle: %s, "
                 "is_preview: %s, is_timelapse: %s)", len(images), batchsize, side, do_shuffle,
                 is_preview, is_timelapse)
    self.batchsize = batchsize
    is_display = is_preview or is_timelapse
    queue_in, queue_out = self.make_queues(side, is_preview, is_timelapse)
    training_size = self.training_opts.get("training_size", 256)
    # Shared-buffer shapes: full-size sample, model input, model output
    batch_shape = list((
        (batchsize, training_size, training_size, 3),  # sample images
        (batchsize, self.model_input_size, self.model_input_size, 3),
        (batchsize, self.model_output_size, self.model_output_size, 3)))
    if self.mask_class:
        # Single-channel mask buffer. Use the local `batchsize` consistently
        # (it equals self.batchsize, set above, but mixing the two was confusing)
        batch_shape.append((batchsize, self.model_output_size, self.model_output_size, 1))
    load_process = FixedProducerDispatcher(method=self.load_batches,
                                           shapes=batch_shape,
                                           in_queue=queue_in,
                                           out_queue=queue_out,
                                           args=(images, side, is_display, do_shuffle, batchsize))
    load_process.start()
    logger.debug("Batching to queue: (side: '%s', is_display: %s)", side, is_display)
    return self.minibatch(side, is_display, load_process)
def minibatch_ab(self, images, batchsize, side, do_shuffle=True, is_timelapse=False):
    """ Spawn a fixed-producer loader for one side and return its batch generator.

    The background process keeps the output queue topped up to 8x batch size.
    """
    logger.debug("Queue batches: (image_count: %s, batchsize: %s, side: '%s', do_shuffle: %s, "
                 "is_timelapse: %s)", len(images), batchsize, side, do_shuffle, is_timelapse)
    self.batchsize = batchsize
    queues = self.make_queues(side, is_timelapse)
    size = self.training_opts.get("training_size", 256)
    # Buffer shapes: full-size sample, model input, model output
    shapes = [(batchsize, size, size, 3),  # sample images
              (batchsize, self.model_input_size, self.model_input_size, 3),
              (batchsize, self.model_output_size, self.model_output_size, 3)]
    if self.mask_class:
        # Extra single-channel buffer when a mask is in use
        shapes.append((self.batchsize, self.model_output_size, self.model_output_size, 1))
    loader = FixedProducerDispatcher(method=self.load_batches,
                                     shapes=shapes,
                                     in_queue=queues[0],
                                     out_queue=queues[1],
                                     args=(images, side, is_timelapse, do_shuffle, batchsize))
    loader.start()
    logger.debug("Batching to queue: (side: '%s', is_timelapse: %s)", side, is_timelapse)
    return self.minibatch(side, is_timelapse, loader)
class TrainingDataGenerator():
    """ Generate training data for models.

    Dispatches image loading/augmentation to a background process
    (``FixedProducerDispatcher``) and yields filled batch buffers of
    (sample, warped input, model outputs) images for one model side.
    """

    def __init__(self, model_input_size, model_output_shapes, training_opts, config):
        """ model_input_size: side length of the model's input images.
            model_output_shapes: iterable of per-image output shapes.
            training_opts: dict of training options (landmarks, mask_type,
                           coverage_ratio, warp_to_landmarks, ...).
            config: configuration passed through to ImageManipulation. """
        logger.debug("Initializing %s: (model_input_size: %s, model_output_shapes: %s, "
                     "training_opts: %s, landmarks: %s, config: %s)",
                     self.__class__.__name__, model_input_size, model_output_shapes,
                     {key: val for key, val in training_opts.items() if key != "landmarks"},
                     bool(training_opts.get("landmarks", None)), config)
        self.batchsize = 0
        self.model_input_size = model_input_size
        self.model_output_shapes = model_output_shapes
        self.training_opts = training_opts
        self.mask_class = self.set_mask_class()
        self.landmarks = self.training_opts.get("landmarks", None)
        self.fixed_producer_dispatcher = None  # Set by FPD when loading
        self._nearest_landmarks = None  # Per-subprocess cache for warp-to-landmarks
        self.processing = ImageManipulation(model_input_size,
                                            model_output_shapes,
                                            training_opts.get("coverage_ratio", 0.625),
                                            config)
        logger.debug("Initialized %s", self.__class__.__name__)

    def set_mask_class(self):
        """ Return the mask class named by training_opts["mask_type"], or None """
        mask_type = self.training_opts.get("mask_type", None)
        if mask_type:
            logger.debug("Mask type: '%s'", mask_type)
            mask_class = getattr(masks, mask_type)
        else:
            mask_class = None
        logger.debug("Mask class: %s", mask_class)
        return mask_class

    def minibatch_ab(self, images, batchsize, side, do_shuffle=True, is_preview=False, is_timelapse=False):
        """ Keep a queue filled to 8x Batch Size.

        Starts the FixedProducerDispatcher subprocess and returns the
        generator from ``self.minibatch``. """
        # Fixed log format string: was "is_preview, %s" (comma instead of colon)
        logger.debug("Queue batches: (image_count: %s, batchsize: %s, side: '%s', do_shuffle: %s, "
                     "is_preview: %s, is_timelapse: %s)", len(images), batchsize, side,
                     do_shuffle, is_preview, is_timelapse)
        self.batchsize = batchsize
        is_display = is_preview or is_timelapse
        queue_in, queue_out = self.make_queues(side, is_preview, is_timelapse)
        training_size = self.training_opts.get("training_size", 256)
        batch_shape = list((
            (batchsize, training_size, training_size, 3),  # sample images
            (batchsize, self.model_input_size, self.model_input_size, 3)))
        # Add the output shapes
        batch_shape.extend(tuple([(batchsize, ) + shape for shape in self.model_output_shapes]))
        logger.debug("Batch shapes: %s", batch_shape)
        self.fixed_producer_dispatcher = FixedProducerDispatcher(
            method=self.load_batches,
            shapes=batch_shape,
            in_queue=queue_in,
            out_queue=queue_out,
            args=(images, side, is_display, do_shuffle, batchsize))
        self.fixed_producer_dispatcher.start()
        logger.debug("Batching to queue: (side: '%s', is_display: %s)", side, is_display)
        return self.minibatch(side, is_display, self.fixed_producer_dispatcher)

    def join_subprocess(self):
        """ Join the FixedProducerDispatcher subprocess from outside this module """
        logger.debug("Joining FixedProducerDispatcher")
        if self.fixed_producer_dispatcher is None:
            logger.debug("FixedProducerDispatcher not yet initialized. Exiting")
            return
        self.fixed_producer_dispatcher.join()
        logger.debug("Joined FixedProducerDispatcher")

    @staticmethod
    def make_queues(side, is_preview, is_timelapse):
        """ Create the buffer token queues for Fixed Producer Dispatcher.

        Queue names take the form {train|preview|timelapse}_{side}_{in|out}. """
        q_name = "_{}".format(side)
        if is_preview:
            q_name = "{}{}".format("preview", q_name)
        elif is_timelapse:
            q_name = "{}{}".format("timelapse", q_name)
        else:
            q_name = "{}{}".format("train", q_name)
        q_names = ["{}_{}".format(q_name, direction) for direction in ("in", "out")]
        logger.debug(q_names)
        queues = [queue_manager.get_queue(queue) for queue in q_names]
        return queues

    def load_batches(self, mem_gen, images, side, is_display, do_shuffle=True, batchsize=0):
        """ Load the warped images and target images to queue.

        Runs inside the FixedProducerDispatcher subprocess: endlessly cycles
        (optionally shuffled) over `images`, processing `batchsize` faces into
        each shared-memory buffer yielded by `mem_gen`. """
        logger.debug("Loading batch: (image_count: %s, side: '%s', is_display: %s, "
                     "do_shuffle: %s)", len(images), side, is_display, do_shuffle)
        self.validate_samples(images)
        # Initialize this for each subprocess
        self._nearest_landmarks = dict()

        def _img_iter(imgs):
            # Endless iterator over the image paths, reshuffling each pass
            while True:
                if do_shuffle:
                    shuffle(imgs)
                for img in imgs:
                    yield img

        img_iter = _img_iter(images)
        epoch = 0
        for memory_wrapper in mem_gen:
            memory = memory_wrapper.get()
            logger.trace("Putting to batch queue: (side: '%s', is_display: %s)",
                         side, is_display)
            for i, img_path in enumerate(img_iter):
                imgs = self.process_face(img_path, side, is_display)
                for j, img in enumerate(imgs):
                    memory[j][i][:] = img
                epoch += 1
                if i == batchsize - 1:
                    break
            memory_wrapper.ready()
        logger.debug("Finished batching: (epoch: %s, side: '%s', is_display: %s)",
                     epoch, side, is_display)

    def validate_samples(self, data):
        """ Raise a FaceswapError if there are fewer images than the batch size.

        Returns None when the sample count is sufficient. """
        length = len(data)
        # Explicit check rather than `assert`: assertions are stripped when
        # Python runs with -O, which would silently disable this validation
        if length < self.batchsize:
            msg = ("Number of images is lower than batch-size (Note that too few "
                   "images may lead to bad training). # images: {}, "
                   "batch-size: {}"
                   "\nYou should increase the number of images in your training set or lower "
                   "your batch-size.".format(length, self.batchsize))
            raise FaceswapError(msg)

    @staticmethod
    def minibatch(side, is_display, load_process):
        """ A generator function that yields epoch, batchsize of warped_img
        and batchsize of target_img from the load queue """
        logger.debug("Launching minibatch generator for queue (side: '%s', is_display: %s)",
                     side, is_display)
        for batch_wrapper in load_process:
            with batch_wrapper as batch:
                logger.trace("Yielding batch: (size: %s, item shapes: %s, side: '%s', "
                             "is_display: %s)",
                             len(batch), [item.shape for item in batch], side, is_display)
                yield batch
        load_process.stop()
        logger.debug("Finished minibatch generator for queue: (side: '%s', is_display: %s)",
                     side, is_display)
        load_process.join()

    def process_face(self, filename, side, is_display):
        """ Load an image and perform transformation and warping.

        Returns [sample, <warp outputs...>] as produced by the processing
        pipeline, with the untransformed sample prepended. """
        logger.trace("Process face: (filename: '%s', side: '%s', is_display: %s)",
                     filename, side, is_display)
        image = cv2_read_img(filename, raise_error=True)
        if self.mask_class or self.training_opts["warp_to_landmarks"]:
            src_pts = self.get_landmarks(filename, image, side)
        if self.mask_class:
            image = self.mask_class(src_pts, image, channels=4).mask
        image = self.processing.color_adjust(image,
                                             self.training_opts["augment_color"],
                                             is_display)
        if not is_display:
            # Random augmentation only applies to actual training batches
            image = self.processing.random_transform(image)
            if not self.training_opts["no_flip"]:
                image = self.processing.do_random_flip(image)
        sample = image.copy()[:, :, :3]
        if self.training_opts["warp_to_landmarks"]:
            dst_pts = self.get_closest_match(filename, side, src_pts)
            processed = self.processing.random_warp_landmarks(image, src_pts, dst_pts)
        else:
            processed = self.processing.random_warp(image)
        processed.insert(0, sample)
        logger.trace("Processed face: (filename: '%s', side: '%s', shapes: %s)",
                     filename, side, [img.shape for img in processed])
        return processed

    def get_landmarks(self, filename, image, side):
        """ Return the landmarks for this face, keyed by the sha1 of the image """
        logger.trace("Retrieving landmarks: (filename: '%s', side: '%s'", filename, side)
        lm_key = sha1(image).hexdigest()
        try:
            src_points = self.landmarks[side][lm_key]
        except KeyError as err:
            # Fixed: filename and hash were swapped in the format arguments,
            # producing a message that quoted the hash as the file name
            msg = ("At least one of your images does not have a matching entry in your alignments "
                   "file."
                   "\nIf you are training with a mask or using 'warp to landmarks' then every "
                   "face you intend to train on must exist within the alignments file."
                   "\nThe specific file that caused the failure was '{}' which has a hash of {}."
                   "\nMost likely there will be more than just this file missing from the "
                   "alignments file. You can use the Alignments Tool to help identify missing "
                   "alignments".format(filename, lm_key))
            raise FaceswapError(msg) from err
        logger.trace("Returning: (src_points: %s)", src_points)
        return src_points

    def get_closest_match(self, filename, side, src_points):
        """ Return closest matched landmarks from opposite set.

        Caches the 10 nearest landmark hashes per filename and returns the
        landmarks for a random choice among them. """
        logger.trace("Retrieving closest matched landmarks: (filename: '%s', src_points: '%s'",
                     filename, src_points)
        landmarks = self.landmarks["a"] if side == "b" else self.landmarks["b"]
        closest_hashes = self._nearest_landmarks.get(filename)
        if not closest_hashes:
            dst_points_items = list(landmarks.items())
            dst_points = list(x[1] for x in dst_points_items)
            # Mean squared distance over all landmark points; keep the 10 nearest
            closest = (np.mean(np.square(src_points - dst_points), axis=(1, 2))).argsort()[:10]
            closest_hashes = tuple(dst_points_items[i][0] for i in closest)
            self._nearest_landmarks[filename] = closest_hashes
        dst_points = landmarks[choice(closest_hashes)]
        logger.trace("Returning: (dst_points: %s)", dst_points)
        return dst_points