Esempio n. 1
0
    def do(self, callback_name, *args):
        """Save selected main loop attributes and produce sample output.

        Every attribute listed in `self.save_separately` is pickled to
        its own file (falsy attributes are skipped with a message). The
        main loop object itself is not pickled by this variant. Then
        `generate_samples` is called and, if `self.epoch_src` exists, it
        is copied to a numbered ``epoch-NNN.png`` frame in
        `self.save_subdir` and ImageMagick's ``convert`` is asked to
        assemble all frames into ``training.gif``.

        If `*args` contain an argument from user, it is treated as
        saving path to be used instead of the one given at the
        construction stage.

        Any exception resets the ``SAVED_TO`` record of the current log
        row to ``None`` and is then re-raised.

        """
        # Local imports keep this method self-contained; both modules are
        # part of the standard library.
        import glob
        import subprocess

        from_main_loop, from_user = self.parse_args(callback_name, args)
        try:
            path = self.path
            if len(from_user):
                path, = from_user
            filenames = self.save_separately_filenames(path)
            for attribute in self.save_separately:
                value = getattr(self.main_loop, attribute)
                if value:
                    secure_pickle_dump(value, filenames[attribute])
                else:
                    # Interpolate the name; passing it as a second argument
                    # to print() left the "%s" placeholder unformatted.
                    print("Empty %s" % attribute)
            generate_samples(self.main_loop.model, self.save_subdir,
                             self.image_size)
            if os.path.exists(self.epoch_src):
                epoch_dst = "{0}/epoch-{1:03d}.png".format(self.save_subdir,
                                                           self.iteration)
                self.iteration += 1
                shutil.copy2(self.epoch_src, epoch_dst)
                # Expand the frame pattern ourselves and pass an argument
                # list, so directories containing spaces or shell
                # metacharacters cannot break (or inject into) the command.
                frames = sorted(
                    glob.glob("{0}/epoch-*.png".format(self.save_subdir)))
                command = (["convert", "-delay", "5", "-loop", "1"] +
                           frames +
                           ["{0}/training.gif".format(self.save_subdir)])
                try:
                    subprocess.call(command)
                except OSError as error:
                    # The animation is best-effort: a missing `convert`
                    # binary must not abort the checkpoint.
                    print("Could not run convert: %s" % error)
        except Exception:
            self.main_loop.log.current_row[SAVED_TO] = None
            raise
Esempio n. 2
0
    def do(self, callback_name, *args):
        """Pickle the main loop object to the disk.

        If `*args` contain an argument from user, it is treated as
        saving path to be used instead of the one given at the
        construction stage.

        The path is appended to the ``SAVED_TO`` tuple of the current
        log row. Each attribute listed in `self.save_separately` is
        additionally pickled to ``<root>_<attribute><ext>``, where
        ``<root>`` and ``<ext>`` come from splitting the main path.

        Any exception resets ``SAVED_TO`` of the current log row to
        ``None`` and is then re-raised.

        """
        from_main_loop, from_user = self.parse_args(callback_name, args)
        try:
            path = self.path
            if len(from_user):
                path, = from_user
            already_saved_to = self.main_loop.log.current_row[SAVED_TO]
            # A row with nothing recorded yet yields a falsy value here;
            # normalise it so the tuple concatenation below works.
            if not already_saved_to:
                already_saved_to = ()
            self.main_loop.log.current_row[SAVED_TO] = (already_saved_to +
                                                        (path, ))
            secure_pickle_dump(self.main_loop, path)
            # Split the base path once and never overwrite it: reusing
            # `path` inside the loop made suffixes pile up
            # (root_a.ext, root_a_b.ext, ...) for multiple attributes.
            root, ext = os.path.splitext(path)
            for attribute in self.save_separately:
                attribute_path = root + "_" + attribute + ext
                secure_pickle_dump(getattr(self.main_loop, attribute),
                                   attribute_path)
        except Exception:
            self.main_loop.log.current_row[SAVED_TO] = None
            raise
Esempio n. 3
0
    def do(self, callback_name, *args):
        """Save selected main loop attributes and generate samples.

        Every attribute listed in `self.save_separately` is pickled to
        its own file (falsy attributes are skipped with a message). The
        main loop object itself is not pickled by this variant.
        Afterwards `generate_samples` is called with the main loop and
        `self.image_size`.

        If `*args` contain an argument from user, it is treated as
        saving path to be used instead of the one given at the
        construction stage.

        Any exception resets the ``SAVED_TO`` record of the current log
        row to ``None`` and is then re-raised.

        """
        from_main_loop, from_user = self.parse_args(callback_name, args)
        try:
            path = self.path
            if len(from_user):
                path, = from_user
            filenames = self.save_separately_filenames(path)
            for attribute in self.save_separately:
                value = getattr(self.main_loop, attribute)
                if value:
                    secure_pickle_dump(value, filenames[attribute])
                else:
                    # Interpolate the name; passing it as a second argument
                    # to print() left the "%s" placeholder unformatted.
                    print("Empty %s" % attribute)
            generate_samples(self.main_loop, self.image_size)
        except Exception:
            self.main_loop.log.current_row[SAVED_TO] = None
            raise
Esempio n. 4
0
 def do(self, callback_name, *args):
     """Dump the main loop, and each separately saved attribute, to disk.

     The configured `self.path` is recorded under ``SAVED_TO`` in the
     current log row before anything is written. Attributes named in
     `self.save_separately` go to ``<root>_<attribute><ext>`` next to
     the main file. On any failure ``SAVED_TO`` is reset to ``None``
     and the exception propagates.

     """
     try:
         checkpoint_path = self.path
         self.main_loop.log.current_row[SAVED_TO] = checkpoint_path
         secure_pickle_dump(self.main_loop, self.path)
         for attr_name in self.save_separately:
             # Derive the per-attribute file name from the main path.
             stem, suffix = os.path.splitext(self.path)
             target = "_".join((stem, attr_name)) + suffix
             payload = getattr(self.main_loop, attr_name)
             secure_pickle_dump(payload, target)
     except Exception:
         # Mark the save as failed in the log, then let the error surface.
         self.main_loop.log.current_row[SAVED_TO] = None
         raise
Esempio n. 5
0
    def do(self, callback_name, *args):
        """Write the main loop to disk as a pickle.

        A path supplied by the user through `*args` takes precedence
        over the path given at construction time. The chosen path is
        appended to the ``SAVED_TO`` tuple of the current log row, the
        main loop is pickled there, and every attribute named in
        `self.save_separately` is pickled to the file name that
        `self.save_separately_filenames` assigns to it.

        If anything fails, ``SAVED_TO`` of the current log row is set to
        ``None`` and the exception is re-raised.

        """
        _, user_args = self.parse_args(callback_name, args)
        try:
            target = self.path
            if user_args:
                (target,) = user_args
            # Extend (or start) the record of paths saved for this row.
            row = self.main_loop.log.current_row
            row[SAVED_TO] = row.get(SAVED_TO, ()) + (target,)
            secure_pickle_dump(self.main_loop, target)
            separate_paths = self.save_separately_filenames(target)
            for name in self.save_separately:
                secure_pickle_dump(getattr(self.main_loop, name),
                                   separate_paths[name])
        except Exception:
            self.main_loop.log.current_row[SAVED_TO] = None
            raise
Esempio n. 6
0
    # Look up the configuration prototype named by ``args.proto`` on
    # ``cfg`` and call it to obtain the configuration.
    config = getattr(cfg, args.proto)()

    logger.info("Model options:\n{}".format(pprint.pformat(config)))
    tr_stream = get_tr_stream(config)

    logger.info("Will iterate up to iteration: [{}]".format(args.iters))

    # The only extension stops the loop after ``args.iters`` batches.
    extensions = [FinishAfter(after_n_batches=args.iters)]

    # Initialize main loop
    # One ``None`` placeholder model is passed per entry in config['cgs'].
    # NOTE(review): DummyAlgorithm presumably performs no training; the
    # run appears to exist only to advance the data stream -- confirm.
    main_loop = MainLoopWithMultiCGnoBlocks(
        models=[None for _ in config['cgs']],
        algorithm=DummyAlgorithm(),
        data_stream=tr_stream,
        extensions=extensions,
        num_encs=config['num_encs'],
        num_decs=config['num_decs'])

    # Run dummy main-loop
    logger.info(" ...running dummy main-loop")
    main_loop.run()

    logger.info(" ...saving iteration state")
    path_to_iteration_state = os.path.join(config['saveto'],
                                           'iterations_state.pkl')
    # Never overwrite an existing state file: write next to it instead.
    if os.path.exists(path_to_iteration_state):
        # `Logger.warn` is a deprecated alias; `warning` is the
        # supported spelling.
        logger.warning('Iteration state already exists! appending .new')
        path_to_iteration_state += '.new'
    secure_pickle_dump(main_loop.iteration_state,
                       path_to_iteration_state)