示例#1
0
文件: cli.py 项目: nhu2000/faceswap
    def process_arguments(self, arguments):
        """ Parse the command line arguments and run the job.

        Selects the alignments serializer, resolves the output folder and the
        list of input images, loads the filter and then calls ``process``
        followed by ``finalize``.

        Parameters
        ----------
        arguments:
            The parsed command line arguments (presumably an
            ``argparse.Namespace``). Reads ``input_dir``, ``output_dir``,
            ``serializer`` and ``alignments_path``.
        """
        import sys  # local import keeps this fix self-contained

        self.arguments = arguments
        print("Input Directory: {}".format(self.arguments.input_dir))
        print("Output Directory: {}".format(self.arguments.output_dir))
        self.serializer = None
        if self.arguments.serializer is None and self.arguments.alignments_path is not None:
            # No explicit serializer: infer it from the alignments file extension
            ext = os.path.splitext(self.arguments.alignments_path)[-1]
            self.serializer = Serializer.get_serializer_fromext(ext)
            print(self.serializer, self.arguments.alignments_path)
        else:
            # Explicit serializer requested (or nothing to infer from): default to json
            self.serializer = Serializer.get_serializer(self.arguments.serializer or "json")
        print("Using {} serializer".format(self.serializer.ext))

        print('Starting, this may take a while...')

        self.output_dir = get_folder(self.arguments.output_dir)
        try:
            self.input_dir = get_image_paths(self.arguments.input_dir)
        except Exception:  # pylint: disable=broad-except
            # A bare ``except:`` would also trap KeyboardInterrupt / SystemExit
            print('Input directory not found. Please ensure it exists.')
            sys.exit(1)

        self.filter = self.load_filter()
        self.process()
        self.finalize()
示例#2
0
    def load(self, swapped):
        """ Restore the epoch counter and load the encoder/decoder weights.

        The epoch counter is read from the json state file in
        ``self.model_dir``. If that file cannot be opened, is not valid json
        or utf-8, or has no usable ``epoch_no`` entry, the error is printed
        and the counter is reset to 0.

        Parameters
        ----------
        swapped:
            When truthy the A and B decoder weight files are loaded into the
            opposite decoders.

        Returns
        -------
        bool
            ``True`` if all three weight files were loaded, ``False`` otherwise.
        """
        serializer = Serializer.get_serializer('json')
        state_fn = ".".join([hdf['state'], serializer.ext])
        try:
            with open(str(self.model_dir / state_fn), 'rb') as fp:
                state = serializer.unmarshal(fp.read().decode('utf-8'))
            self._epoch_no = state['epoch_no']
        except IOError as e:
            print('Error loading training info:', e.strerror)
            self._epoch_no = 0
        except JSONDecodeError as e:
            print('Error loading training info:', e.msg)
            self._epoch_no = 0
        except (KeyError, TypeError, UnicodeDecodeError) as e:
            # The file was read but does not hold a usable epoch number
            print('Error loading training info: invalid state file ({})'.format(e))
            self._epoch_no = 0

        if swapped:
            face_A, face_B = hdf['decoder_BH5'], hdf['decoder_AH5']
        else:
            face_A, face_B = hdf['decoder_AH5'], hdf['decoder_BH5']

        try:
            self.encoder.load_weights(str(self.model_dir / hdf['encoderH5']))
            self.decoder_A.load_weights(str(self.model_dir / face_A))
            self.decoder_B.load_weights(str(self.model_dir / face_B))
            print('loaded model weights')
            return True
        except Exception as e:  # pylint: disable=broad-except
            # Best effort: a missing/incompatible weight file means "start fresh"
            print('Failed loading existing training data.')
            print(e)
            return False
示例#3
0
    def load(self, swapped):
        """ Restore the epoch counter and load the encoder/decoder weights.

        The epoch counter is read from the json state file in
        ``self.model_dir``. If that file cannot be opened, is not valid json
        or utf-8, or has no usable ``epoch_no`` entry, the error is printed
        and the counter is reset to 0.

        Parameters
        ----------
        swapped:
            When truthy the A and B decoder weight files are loaded into the
            opposite decoders.

        Returns
        -------
        bool
            ``True`` if all three weight files were loaded, ``False`` otherwise.
        """
        serializer = Serializer.get_serializer('json')
        state_fn = ".".join([hdf['state'], serializer.ext])
        try:
            with open(str(self.model_dir / state_fn), 'rb') as fp:
                state = serializer.unmarshal(fp.read().decode('utf-8'))
            self._epoch_no = state['epoch_no']
        except IOError as e:
            print('Error loading training info:', e.strerror)
            self._epoch_no = 0
        except JSONDecodeError as e:
            print('Error loading training info:', e.msg)
            self._epoch_no = 0
        except (KeyError, TypeError, UnicodeDecodeError) as e:
            # The file was read but does not hold a usable epoch number
            print('Error loading training info: invalid state file ({})'.format(e))
            self._epoch_no = 0

        if swapped:
            face_A, face_B = hdf['decoder_BH5'], hdf['decoder_AH5']
        else:
            face_A, face_B = hdf['decoder_AH5'], hdf['decoder_BH5']

        try:
            self.encoder.load_weights(str(self.model_dir / hdf['encoderH5']))
            self.decoder_A.load_weights(str(self.model_dir / face_A))
            self.decoder_B.load_weights(str(self.model_dir / face_B))
            print('loaded model weights')
            return True
        except Exception as e:  # pylint: disable=broad-except
            # Best effort: a missing/incompatible weight file means "start fresh"
            print('Failed loading existing training data.')
            print(e)
            return False
示例#4
0
文件: cli.py 项目: cjayyy/faceit
    def process_arguments(self, arguments):
        """ Parse the command line arguments and run the job.

        Selects the alignments serializer, resolves the output folder and the
        list of input images, loads the filter and then calls ``process``
        followed by ``finalize``.

        Parameters
        ----------
        arguments:
            The parsed command line arguments (presumably an
            ``argparse.Namespace``). Reads ``input_dir``, ``output_dir``,
            ``serializer`` and ``alignments_path``.
        """
        import sys  # local import keeps this fix self-contained

        self.arguments = arguments
        print("Input Directory: {}".format(self.arguments.input_dir))
        print("Output Directory: {}".format(self.arguments.output_dir))
        self.serializer = None
        if self.arguments.serializer is None and self.arguments.alignments_path is not None:
            # No explicit serializer: infer it from the alignments file extension
            ext = os.path.splitext(self.arguments.alignments_path)[-1]
            self.serializer = Serializer.get_serializer_fromext(ext)
            print(self.serializer, self.arguments.alignments_path)
        else:
            # Explicit serializer requested (or nothing to infer from): default to json
            self.serializer = Serializer.get_serializer(self.arguments.serializer or "json")
        print("Using {} serializer".format(self.serializer.ext))

        print('Starting, this may take a while...')

        self.output_dir = get_folder(self.arguments.output_dir)
        try:
            self.input_dir = get_image_paths(self.arguments.input_dir)
        except Exception:  # pylint: disable=broad-except
            # A bare ``except:`` would also trap KeyboardInterrupt / SystemExit
            print('Input directory not found. Please ensure it exists.')
            sys.exit(1)

        self.filter = self.load_filter()
        self.process()
        self.finalize()
示例#5
0
 def get_serializer(self):
     """ Choose the serializer used to load and save the alignments file.

     If a serializer was requested, or there is no alignments path to inspect,
     the requested serializer is used (``"json"`` when none was given).
     Otherwise the serializer is derived from the alignments file extension.

     Returns
     -------
     The serializer produced by ``Serializer``; it exposes an ``ext`` attribute.
     """
     requested = self.args.serializer
     if requested or not self.args.alignments_path:
         serializer = Serializer.get_serializer(requested or "json")
     else:
         alignments = self.args.alignments_path
         extension = os.path.splitext(alignments)[-1]
         serializer = Serializer.get_serializer_fromext(extension)
         print("Alignments Output: {}".format(alignments))
     print("Using {} serializer".format(serializer.ext))
     return serializer
示例#6
0
文件: fsmedia.py 项目: Nioy/faceswap
 def get_serializer(self):
     """ Set the serializer to be used for loading and saving alignments.

     When no serializer was requested but an alignments path was given, the
     serializer is derived from that path's extension. Otherwise the requested
     serializer is used, falling back to ``"json"`` when none was requested.

     Returns
     -------
     The serializer produced by ``Serializer``; it exposes an ``ext`` attribute.
     """
     if not self.args.serializer and self.args.alignments_path:
         ext = os.path.splitext(self.args.alignments_path)[-1]
         serializer = Serializer.get_serializer_from_ext(ext)
         print("Alignments Output: {}".format(self.args.alignments_path))
     else:
         # Never hand an empty/None name to get_serializer: default to json,
         # as the other serializer selectors do
         serializer = Serializer.get_serializer(self.args.serializer or "json")
     print("Using {} serializer".format(serializer.ext))
     return serializer
示例#7
0
 def get_session_names(self):
     """ Return the names of the sessions recorded in the model's state file.

     The json state file ``<model_name>_state.<ext>`` inside ``model_dir`` is
     read, and every key of its ``sessions`` mapping becomes a
     ``session_<key>`` name.

     Returns
     -------
     list of str
         One ``session_<key>`` entry per recorded session.
     """
     serializer = Serializer.get_serializer("json")
     basename = "{}_state.{}".format(self.model_name, serializer.ext)
     state_path = os.path.join(self.model_dir, basename)
     with open(state_path, "rb") as handle:
         raw_state = handle.read()
     state = serializer.unmarshal(raw_state.decode("utf-8"))
     names = list(map("session_{}".format, state["sessions"].keys()))
     logger.debug("Session to restore: %s", names)
     return names
示例#8
0
 def get_serializer(self):
     """ Set the serializer to be used for loading and saving alignments.

     If no serializer was requested (or ``args`` has no ``serializer``
     attribute), the serializer is chosen from the alignments file extension,
     falling back to ``.json`` when no alignments path was given. Otherwise the
     requested serializer is used.

     Returns
     -------
     The serializer produced by ``Serializer``; it exposes an ``ext`` attribute.
     """
     requested = getattr(self.args, "serializer", None)
     if not requested:
         if self.args.alignments_path:
             ext = os.path.splitext(self.args.alignments_path)[-1]
         else:
             # Use the same dotted form ``os.path.splitext`` yields in the branch
             # above (".json"), not a bare "json"
             ext = ".json"
         serializer = Serializer.get_serializer_from_ext(ext)
     else:
         serializer = Serializer.get_serializer(requested)
     print("Using {} serializer".format(serializer.ext))
     return serializer
示例#9
0
 def __init__(self, model_dir, model_name, training_image_size):
     """ Set up training-state bookkeeping and open a new session.

     Parameters
     ----------
     model_dir:
         Model folder; combined with ``/`` (presumably a ``pathlib.Path``).
     model_name: str
         Model name; the state file is ``<model_name>_state.<ext>``.
     training_image_size:
         Stored unchanged as ``self.training_size``.
     """
     self.serializer = Serializer.get_serializer("json")
     self.filename = str(model_dir / "{}_state.{}".format(model_name, self.serializer.ext))
     self.name = model_name
     self.training_size = training_image_size
     # Counters start at zero
     self.iterations = 0
     self.session_iterations = 0
     # Per-session containers start empty
     self.sessions = {}
     self.lowest_avg_loss = {}
     self.inputs = {}
     self.config = {}
     # Presumably restores any saved state before the new session is created
     self.load()
     self.session_id = self.new_session_id()
     self.create_new_session()
示例#10
0
    def __init__(self, in_queue, queue_size, arguments):
        """ Set up the face predictor and start its worker.

        Parameters
        ----------
        in_queue:
            Queue the predictor consumes from (type not visible here).
        queue_size: int
            Size of the incoming queue; the batch size is capped at 16.
        arguments:
            Parsed command line arguments; ``swap_model`` is handed to the
            model's converter.
        """
        logger.debug("Initializing %s: (args: %s, queue_size: %s, in_queue: %s)",
                     self.__class__.__name__, arguments, queue_size, in_queue)
        # Batch no more than 16 items, and never more than the queue size
        self.batchsize = min(queue_size, 16)
        self.args = arguments
        self.in_queue = in_queue
        self.out_queue = queue_manager.get_queue("patch")
        self.serializer = Serializer.get_serializer("json")
        self.faces_count = 0
        self.verify_output = False
        # The model must exist before its converter can build the predictor
        self.model = self.load_model()
        self.predictor = self.model.converter(self.args.swap_model)
        self.queues = {}

        # predict_faces is driven by a single-thread MultiThread worker
        worker = MultiThread(self.predict_faces, thread_count=1)
        self.thread = worker
        worker.start()
        logger.debug("Initialized %s: (out_queue: %s)", self.__class__.__name__, self.out_queue)
示例#11
0
    def __init__(self, in_queue, queue_size, arguments):
        """ Set up the face predictor and start its worker.

        Parameters
        ----------
        in_queue:
            Queue the predictor consumes from (type not visible here).
        queue_size:
            Size of the incoming queue; passed to ``get_batchsize``.
        arguments:
            Parsed command line arguments; ``swap_model`` is handed to the
            model's converter.
        """
        logger.debug("Initializing %s: (args: %s, queue_size: %s, in_queue: %s)",
                     self.__class__.__name__, arguments, queue_size, in_queue)
        self.batchsize = self.get_batchsize(queue_size)
        self.args = arguments
        self.in_queue = in_queue
        self.out_queue = queue_manager.get_queue("patch")
        self.serializer = Serializer.get_serializer("json")
        self.faces_count = 0
        self.verify_output = False
        # The model must exist before its converter can build the predictor
        self.model = self.load_model()
        self.predictor = self.model.converter(self.args.swap_model)
        self.queues = {}

        # predict_faces is driven by a single-thread MultiThread worker
        worker = MultiThread(self.predict_faces, thread_count=1)
        self.thread = worker
        worker.start()
        logger.debug("Initialized %s: (out_queue: %s)", self.__class__.__name__, self.out_queue)
示例#12
0
    def get_serializer(self, filename, serializer):
        """ Set the serializer to be used for loading and saving alignments.

        If ``filename`` has a recognised extension (``.json``, ``.p``,
        ``.yaml``, ``.yml``) the serializer is chosen from that extension;
        otherwise the ``serializer`` name is used.

        Parameters
        ----------
        filename: str
            Path of the alignments file.
        serializer: str
            Fallback serializer name: ``"json"``, ``"pickle"`` or ``"yaml"``.

        Returns
        -------
        The serializer produced by ``Serializer``; it exposes an ``ext`` attribute.

        Raises
        ------
        ValueError
            If the extension is not recognised and ``serializer`` is not one
            of the supported names.
        """
        extension = os.path.splitext(filename)[1]
        if extension in (".json", ".p", ".yaml", ".yml"):
            retval = Serializer.get_serializer_from_ext(extension)
        elif serializer not in ("json", "pickle", "yaml"):
            # Fill the "{}" placeholder so the message names the rejected value
            raise ValueError("Error: {} is not a valid serializer. Use "
                             "'json', 'pickle' or 'yaml'".format(serializer))
        else:
            retval = Serializer.get_serializer(serializer)
        if self.verbose:
            print("Using {} serializer for alignments".format(retval.ext))
        return retval
示例#13
0
 def __init__(self, model_dir, model_name, no_logs, training_image_size):
     """ Set up training-state bookkeeping and open a new session.

     Parameters
     ----------
     model_dir:
         Model folder; combined with ``/`` (presumably a ``pathlib.Path``).
     model_name: str
         Model name; the state file is ``<model_name>_state.<ext>``.
     no_logs:
         Forwarded to ``create_new_session``.
     training_image_size:
         Stored unchanged as ``self.training_size``.
     """
     logger.debug("Initializing %s: (model_dir: '%s', model_name: '%s', no_logs: %s, "
                  "training_image_size: '%s'", self.__class__.__name__, model_dir,
                  model_name, no_logs, training_image_size)
     self.serializer = Serializer.get_serializer("json")
     self.filename = str(model_dir / "{}_state.{}".format(model_name, self.serializer.ext))
     self.name = model_name
     self.training_size = training_image_size
     # Counters start at zero
     self.iterations = 0
     self.session_iterations = 0
     # Per-session containers start empty
     self.sessions = {}
     self.lowest_avg_loss = {}
     self.inputs = {}
     self.config = {}
     # Presumably restores any saved state before the new session is created
     self.load()
     self.session_id = self.new_session_id()
     self.create_new_session(no_logs)
     logger.debug("Initialized %s:", self.__class__.__name__)
示例#14
0
    def save_weights(self):
        """ Save the encoder/decoder weights and the epoch counter.

        Every file named in ``hdf`` is first passed to ``backup_file``
        (presumably to keep a backup copy). The three weight files are then
        written into ``self.model_dir``. Finally ``{'epoch_no': ...}`` is
        written as json to ``<hdf['state']>.<ext>``; an IOError while writing
        that file is printed rather than raised.
        """
        model_dir = str(self.model_dir)
        for model in hdf.values():
            backup_file(model_dir, model)
        self.encoder.save_weights(str(self.model_dir / hdf['encoderH5']))
        self.decoder_A.save_weights(str(self.model_dir / hdf['decoder_AH5']))
        self.decoder_B.save_weights(str(self.model_dir / hdf['decoder_BH5']))

        print('saved model weights')

        serializer = Serializer.get_serializer('json')
        state_fn = ".".join([hdf['state'], serializer.ext])
        state_path = str(self.model_dir / state_fn)
        # Serialize BEFORE opening: 'wb' truncates the file immediately, so a
        # marshal failure inside the ``with`` would leave an empty state file.
        state_json = serializer.marshal({'epoch_no': self.epoch_no})
        try:
            with open(state_path, 'wb') as fp:
                fp.write(state_json.encode('utf-8'))
        except IOError as e:
            print(e.strerror)
示例#15
0
 def __init__(self, model_dir, model_name, config_changeable_items,
              no_logs, pingpong, training_image_size):
     """ Set up training-state bookkeeping and open a new session.

     Parameters
     ----------
     model_dir:
         Model folder; combined with ``/`` (presumably a ``pathlib.Path``).
     model_name: str
         Model name; the state file is ``<model_name>_state.<ext>``.
     config_changeable_items:
         Forwarded to ``load``.
     no_logs, pingpong:
         Forwarded to ``create_new_session``.
     training_image_size:
         Stored unchanged as ``self.training_size``.
     """
     logger.debug("Initializing %s: (model_dir: '%s', model_name: '%s', "
                  "config_changeable_items: '%s', no_logs: %s, pingpong: %s, "
                  "training_image_size: '%s'", self.__class__.__name__, model_dir, model_name,
                  config_changeable_items, no_logs, pingpong, training_image_size)
     self.serializer = Serializer.get_serializer("json")
     self.filename = str(model_dir / "{}_state.{}".format(model_name, self.serializer.ext))
     self.name = model_name
     self.training_size = training_image_size
     # Counters start at zero
     self.iterations = 0
     self.session_iterations = 0
     # Per-session containers start empty
     self.sessions = {}
     self.lowest_avg_loss = {}
     self.inputs = {}
     self.config = {}
     # Presumably restores any saved state before the new session is created
     self.load(config_changeable_items)
     self.session_id = self.new_session_id()
     self.create_new_session(no_logs, pingpong)
     logger.debug("Initialized %s:", self.__class__.__name__)
示例#16
0
    def get_serializer(filename, serializer):
        """ Set the serializer to be used for loading and saving alignments.

        If ``filename`` has a recognised extension (``.json``, ``.p``,
        ``.yaml``, ``.yml``) the serializer is chosen from that extension;
        otherwise the ``serializer`` name is used.

        Parameters
        ----------
        filename: str
            Path of the alignments file.
        serializer: str
            Fallback serializer name: ``"json"``, ``"pickle"`` or ``"yaml"``.

        Returns
        -------
        The serializer produced by ``Serializer``; it exposes an ``ext`` attribute.

        Raises
        ------
        ValueError
            If the extension is not recognised and ``serializer`` is not one
            of the supported names.
        """
        logger.debug("Getting serializer: (filename: '%s', serializer: '%s')",
                     filename, serializer)
        extension = os.path.splitext(filename)[1]
        if extension in (".json", ".p", ".yaml", ".yml"):
            logger.debug("Serializer set from file extension: '%s'", extension)
            retval = Serializer.get_serializer_from_ext(extension)
        elif serializer not in ("json", "pickle", "yaml"):
            # Fill the "{}" placeholder so the message names the rejected value
            raise ValueError("Error: {} is not a valid serializer. Use "
                             "'json', 'pickle' or 'yaml'".format(serializer))
        else:
            logger.debug("Serializer set from argument: '%s'", serializer)
            retval = Serializer.get_serializer(serializer)
        logger.verbose("Using '%s' serializer for alignments", retval.ext)
        return retval
示例#17
0
 def save_weights(self):
     """ Save the encoder/decoder weights and the epoch counter.

     Every file named in ``hdf`` is first passed to ``backup_file``
     (presumably to keep a backup copy). The three weight files are then
     written into ``self.model_dir``. Finally ``{'epoch_no': ...}`` is written
     as json to ``<hdf['state']>.<ext>``; an IOError while writing that file
     is printed rather than raised.
     """
     model_dir = str(self.model_dir)
     for model in hdf.values():
         backup_file(model_dir, model)
     self.encoder.save_weights(str(self.model_dir / hdf['encoderH5']))
     self.decoder_A.save_weights(str(self.model_dir / hdf['decoder_AH5']))
     self.decoder_B.save_weights(str(self.model_dir / hdf['decoder_BH5']))

     print('saved model weights')

     serializer = Serializer.get_serializer('json')
     state_fn = ".".join([hdf['state'], serializer.ext])
     state_path = str(self.model_dir / state_fn)
     # Serialize BEFORE opening: 'wb' truncates the file immediately, so a
     # marshal failure inside the ``with`` would leave an empty state file.
     state_json = serializer.marshal({'epoch_no': self.epoch_no})
     try:
         with open(state_path, 'wb') as fp:
             fp.write(state_json.encode('utf-8'))
     except IOError as e:
         print(e.strerror)
示例#18
0
    def get_serializer(filename, serializer):
        """ Set the serializer to be used for loading and saving alignments.

        If ``filename`` has a recognised extension (``.json``, ``.p``,
        ``.yaml``, ``.yml``) the serializer is chosen from that extension;
        otherwise the ``serializer`` name is used.

        Parameters
        ----------
        filename: str
            Path of the alignments file.
        serializer: str
            Fallback serializer name: ``"json"``, ``"pickle"`` or ``"yaml"``.

        Returns
        -------
        The serializer produced by ``Serializer``; it exposes an ``ext`` attribute.

        Raises
        ------
        ValueError
            If the extension is not recognised and ``serializer`` is not one
            of the supported names.
        """
        logger.debug("Getting serializer: (filename: '%s', serializer: '%s')",
                     filename, serializer)
        extension = os.path.splitext(filename)[1]
        if extension in (".json", ".p", ".yaml", ".yml"):
            logger.debug("Serializer set from file extension: '%s'", extension)
            retval = Serializer.get_serializer_from_ext(extension)
        elif serializer not in ("json", "pickle", "yaml"):
            # Fill the "{}" placeholder so the message names the rejected value
            raise ValueError("Error: {} is not a valid serializer. Use "
                             "'json', 'pickle' or 'yaml'".format(serializer))
        else:
            logger.debug("Serializer set from argument: '%s'", serializer)
            retval = Serializer.get_serializer(serializer)
        logger.verbose("Using '%s' serializer for alignments", retval.ext)
        return retval
示例#19
0
    def process_arguments(self, arguments):
        """ Parse the command line arguments and run the job.

        Selects the alignments serializer, works out the optional image
        rotation angles, and resolves the output folder and the list of input
        images. With ``skip_existing``, images already in the output folder
        are excluded. It then loads the filter and calls ``process`` followed
        by ``finalize``.

        ``rotate_images`` and ``skip_existing`` are optional: when ``arguments``
        does not have them they are treated as disabled.

        Parameters
        ----------
        arguments:
            The parsed command line arguments (presumably an
            ``argparse.Namespace``).
        """
        import sys  # local import keeps this fix self-contained

        self.arguments = arguments
        print("Input Directory: {}".format(self.arguments.input_dir))
        print("Output Directory: {}".format(self.arguments.output_dir))
        print("Filter: {}".format(self.arguments.filter))
        self.serializer = None
        if self.arguments.serializer is None and self.arguments.alignments_path is not None:
            # No explicit serializer: infer it from the alignments file extension
            ext = os.path.splitext(self.arguments.alignments_path)[-1]
            self.serializer = Serializer.get_serializer_fromext(ext)
            print(self.serializer, self.arguments.alignments_path)
        else:
            self.serializer = Serializer.get_serializer(
                self.arguments.serializer or "json")
        print("Using {} serializer".format(self.serializer.ext))

        rotate_images = getattr(self.arguments, "rotate_images", None)
        if rotate_images is not None and rotate_images != "off":
            if rotate_images == "on":
                self.rotation_angles = range(90, 360, 90)
            else:
                rotation_angles = [int(angle) for angle in rotate_images.split(",")]
                if len(rotation_angles) == 1:
                    # A single value is a step size: use every multiple below 360
                    step = rotation_angles[0]
                    self.rotation_angles = range(step, 360, step)
                else:
                    # Several values are an explicit list of angles
                    self.rotation_angles = rotation_angles

        print('Starting, this may take a while...')

        # Resolve the output folder before listing its contents, so that
        # skip_existing does not list a folder get_folder has not prepared yet.
        self.output_dir = get_folder(self.arguments.output_dir)

        skip_existing = getattr(self.arguments, "skip_existing", False)
        if skip_existing:
            self.already_processed = get_image_paths(self.arguments.output_dir)

        try:
            if skip_existing:
                self.input_dir = get_image_paths(self.arguments.input_dir,
                                                 self.already_processed)
                print('Excluding %s files' % len(self.already_processed))
            else:
                self.input_dir = get_image_paths(self.arguments.input_dir)
        except Exception:  # pylint: disable=broad-except
            # A bare ``except:`` would also trap KeyboardInterrupt / SystemExit
            print('Input directory not found. Please ensure it exists.')
            sys.exit(1)

        self.filter = self.load_filter()
        self.process()
        self.finalize()
示例#20
0
 def set_destination_serializer(self):
     """ Build the serializer named by ``self.destination_format`` and store
     it on ``self.serializer``. """
     fmt = self.destination_format
     self.serializer = Serializer.get_serializer(fmt)