コード例 #1
0
def generate_sequences(fuzzing_requests, checkers, fuzzing_jobs=1):
    """ Implements core restler algorithm.

    Each generation extends the previous generation's sequences with new
    request templates (see `extend`) and then renders them. This continues
    until Settings().max_sequence_length is reached, the sequence collection
    is exhausted, or a timeout fires.

    In 'random-walk' mode the whole generation loop is restarted until a
    timeout occurs; in every other mode it runs once. In 'bfs' mode the
    latest checkpoint, if one is available, is loaded first and fuzzing
    resumes from its generation.

    @param fuzzing_requests: The collection of requests that will be fuzzed
    @type  fuzzing_requests: FuzzingRequestCollection
    @param checkers: The list of checkers to apply
    @type  checkers: list[Checker]
    @param fuzzing_jobs: Optional number of fuzzing jobs for parallel fuzzing.
                            Default value passed is one (sequential fuzzing).
    @type  fuzzing_jobs: Int

    @return: None if there are no requests to fuzz; the result of
             generate_sequences_directed_smoketest in 'directed-smoke-test'
             mode; otherwise the sum, over all generations, of the sequence
             collection size after each generation's render step.
    @rtype : int or None

    """
    if not fuzzing_requests.size:
        return

    logger.create_network_log(logger.LOG_TYPE_TESTING)

    fuzzing_mode = Settings().fuzzing_mode
    max_len = Settings().max_sequence_length
    if fuzzing_mode == 'directed-smoke-test':
        return generate_sequences_directed_smoketest(fuzzing_requests,
                                                     checkers)

    if fuzzing_jobs > 1:
        render = render_parallel
        global_lock = multiprocessing.Lock()
        fuzzing_pool = ThreadPool(fuzzing_jobs)
    else:
        global_lock = None
        fuzzing_pool = None
        render = render_sequential

    should_stop = False
    timeout_reached = False
    seq_collection_exhausted = False
    num_total_sequences = 0
    # try/finally so the worker pool is shut down even when rendering raises
    # something other than the two exceptions handled below; otherwise the
    # pool's worker threads would be leaked.
    try:
        while not should_stop:

            seq_collection = [sequences.Sequence()]
            # Only for bfs: If any checkpoint file is available, load state of
            # latest generation. Note that it only makes sense to use
            # checkpoints for the bfs exploration method, since it is the only
            # systemic and exhaustive method.
            min_len = 0
            if fuzzing_mode == 'bfs':
                req_collection = GrammarRequestCollection()
                monitor = Monitor()
                req_collection, seq_collection, fuzzing_requests, monitor, min_len =\
                    saver.load(req_collection, seq_collection, fuzzing_requests, monitor)
                requests.GlobalRequestCollection.Instance(
                )._req_collection = req_collection
                # NOTE(review): this function is not inside a class body, so
                # `__instance` is NOT name-mangled here. If FuzzingMonitor keeps
                # its singleton in a class-private `__instance` attribute
                # (stored as `_FuzzingMonitor__instance`), this assignment
                # writes a different attribute and the loaded monitor is never
                # installed. Confirm against fuzzing_monitor.py.
                fuzzing_monitor.FuzzingMonitor.__instance = monitor
            # Repeat external loop only for random walk
            if fuzzing_mode != 'random-walk':
                should_stop = True

            # Initialize fuzzing schedule: every generation uses the same mode.
            fuzzing_schedule = {}
            logger.write_to_main(f"Setting fuzzing schemes: {fuzzing_mode}")
            for length in range(min_len, max_len):
                fuzzing_schedule[length] = fuzzing_mode

            # print general request-related stats
            logger.print_req_collection_stats(
                fuzzing_requests,
                GrammarRequestCollection().candidate_values_pool)

            for length in range(min_len, max_len):
                # we can set this without locking, since no one else writes
                # (main driver is single-threaded) and every potential worker
                # will just read-access this value.
                generation = length + 1
                fuzzing_mode = fuzzing_schedule[length]

                # extend sequences with new request templates
                seq_collection = extend(seq_collection, fuzzing_requests,
                                        global_lock)
                print(f"{formatting.timestamp()}: Generation: {generation} ")

                logger.write_to_main(
                    f"{formatting.timestamp()}: Generation: {generation} / "
                    f"Sequences Collection Size: {len(seq_collection)} "
                    f"(After {fuzzing_schedule[length]} Extend)")

                # render templates
                try:
                    seq_collection_exhausted = False
                    seq_collection = render(seq_collection, fuzzing_pool,
                                            checkers, generation, global_lock)

                except TimeOutException:
                    logger.write_to_main("Timed out...")
                    timeout_reached = True
                    seq_collection_exhausted = True
                    # Increase fuzzing generation after timeout because the
                    # code that does it would have never been reached. This is
                    # done so the previous generation's test summary is logged
                    # correctly.
                    Monitor().current_fuzzing_generation += 1

                except ExhaustSeqCollectionException:
                    logger.write_to_main("Exhausted collection...")
                    seq_collection = []
                    seq_collection_exhausted = True

                logger.write_to_main(
                    f"{formatting.timestamp()}: Generation: {generation} / "
                    f"Sequences Collection Size: {len(seq_collection)} "
                    f"(After {fuzzing_schedule[length]} Render)")

                # saving latest state
                saver.save(GrammarRequestCollection(), seq_collection,
                           fuzzing_requests, Monitor(), generation)

                # Print stats for iteration of the current generation
                logger.print_generation_stats(GrammarRequestCollection(),
                                              Monitor(), global_lock)

                num_total_sequences += len(seq_collection)

                logger.print_request_rendering_stats(
                    GrammarRequestCollection().candidate_values_pool,
                    fuzzing_requests, Monitor(),
                    Monitor().num_fully_rendered_requests(
                        fuzzing_requests.all_requests), generation, global_lock)

                # A timeout ends fuzzing entirely; an exhausted collection only
                # ends this pass (random-walk then starts a new pass).
                if timeout_reached or seq_collection_exhausted:
                    if timeout_reached:
                        should_stop = True
                    break
            logger.write_to_main("--\n")
    finally:
        if fuzzing_pool is not None:
            fuzzing_pool.close()
            fuzzing_pool.join()

    return num_total_sequences
コード例 #2
0
    def train(self):
        """Train the D and G networks, resuming from a checkpoint if present.

        A fixed noise batch and the first test-set text embeddings are drawn
        up front so that every sample grid uses the same inputs. Their
        captions are written to ``cfg.SAMPLE_DIR``.

        Each batch then runs one discriminator update followed by one
        generator update with a shared feed. The learning rate is halved
        every 100 epochs. Every 500 steps a sample grid (best effort) and a
        checkpoint are written.
        """
        self.define_losses()
        self.define_summaries()

        fixed_noise = np.random.normal(
            0, 1, (self.model.sample_num, self.model.z_dim))
        _, fixed_embed, _, captions = self.dataset.test.next_batch_test(
            self.model.sample_num, 0, 1)
        fixed_embed = np.squeeze(fixed_embed, axis=0)
        print(fixed_embed.shape)

        save_captions(self.cfg.SAMPLE_DIR, captions)

        step = 1
        t_begin = time.time()

        restored, restored_step = load(self.saver, self.sess,
                                       self.cfg.CHECKPOINT_DIR)
        if not restored:
            print(" [!] Load failed...")
        else:
            step = restored_step
            print(" [*] Load SUCCESS")

        initialize_uninitialized(self.sess)

        # Number of full batches that fit in the training split.
        batches_per_epoch = (self.dataset.train.num_examples //
                             self.model.batch_size)

        for epoch in range(step // batches_per_epoch, self.cfg.TRAIN.EPOCH):
            lr_halvings = epoch // 100

            for batch_no in range(batches_per_epoch):
                (real_imgs, mismatched_imgs, text_embed,
                 _, _) = self.dataset.train.next_batch(
                     self.model.batch_size, 4, embeddings=True, wrong_img=True)
                noise = np.random.normal(
                    0, 1, (self.model.batch_size, self.model.z_dim))

                feed = {
                    self.learning_rate: self.lr * (0.5**lr_halvings),
                    self.model.inputs: real_imgs,
                    self.model.wrong_inputs: mismatched_imgs,
                    self.model.embed_inputs: text_embed,
                    self.model.z: noise,
                }

                # Discriminator step, then generator step, same feed.
                _, err_d, summary = self.sess.run(
                    [self.D_optim, self.D_loss, self.D_merged_summ],
                    feed_dict=feed)
                self.writer.add_summary(summary, step)

                _, err_g, summary = self.sess.run(
                    [self.G_optim, self.G_loss, self.G_merged_summ],
                    feed_dict=feed)
                self.writer.add_summary(summary, step)

                step += 1
                print(
                    "Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                    % (epoch, batch_no, batches_per_epoch,
                       time.time() - t_begin, err_d, err_g))

                if step % 500 != 0:
                    continue

                # Periodic sample grid (failures are only reported) followed
                # by a checkpoint.
                try:
                    grid = self.sess.run(self.model.sampler, feed_dict={
                        self.model.z_sample: fixed_noise,
                        self.model.embed_sample: fixed_embed,
                    })
                    save_images(
                        grid, get_balanced_factorization(grid.shape[0]),
                        '{}train_{:02d}_{:04d}.png'.format(
                            self.cfg.SAMPLE_DIR, epoch, batch_no))
                except Exception as err:
                    print("Failed to generate sample image")
                    print(type(err))
                    print(err.args)
                    print(err)

                save(self.saver, self.sess, self.cfg.CHECKPOINT_DIR, step)
コード例 #3
0
    def train(self):
        """Train the stage II networks on top of a stage I generator.

        Stage II parameters are restored from ``cfg.CHECKPOINT_DIR`` when a
        checkpoint exists, and the step counter resumes from it. The stage I
        generator is restored from ``cfg_stage_i.CHECKPOINT_DIR``.

        Each batch runs one discriminator update and one generator update.
        Every 100 steps a sample grid is saved to ``cfg.SAMPLE_DIR``.
        Whenever ``counter % 500 == 2`` a stage II checkpoint is written.
        """
        self.define_losses()
        self.define_summaries()

        def _print_captions(caption_batches):
            # Print the first caption of each sampled image, numbered from 1.
            print('\nCaptions of the sampled images:')
            for caption_idx, caption_batch in enumerate(caption_batches):
                print('{}: {}'.format(caption_idx + 1, caption_batch[0]))
            print()

        sample_z = np.random.normal(0, 1,
                                    (self.model.sample_num, self.model.z_dim))
        # NOTE(review): if `randint` is `random.randint` its upper bound is
        # inclusive, so the start index can equal `num_examples`. If it is
        # `np.random.randint` the bound is exclusive. Confirm which one is
        # imported and whether next_batch_test accepts that index.
        _, sample_embed, _, captions = self.dataset.test.next_batch_test(
            self.model.sample_num, randint(0, self.dataset.test.num_examples),
            1)
        sample_embed = np.squeeze(sample_embed, axis=0)
        print(sample_embed.shape)

        # Display the captions of the sampled images
        _print_captions(captions)

        counter = 1
        start_time = time.time()

        # Try to load the parameters of the stage II networks
        tf.global_variables_initializer().run()
        could_load, checkpoint_counter = load(self.stageii_saver, self.sess,
                                              self.cfg.CHECKPOINT_DIR)
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS: Stage II networks are loaded.")
        else:
            print(" [!] Load failed for stage II networks...")

        # Only the stage I generator weights are restored here. Its step
        # count belongs to stage I training and must not replace the stage II
        # counter, which drives summary steps and checkpoint numbering.
        could_load, _ = load(self.stagei_g_saver, self.sess,
                             self.cfg_stage_i.CHECKPOINT_DIR)
        if could_load:
            print(" [*] Load SUCCESS: Stage I generator is loaded")
        else:
            print(
                " [!] WARNING!!! Failed to load the parameters for stage I generator..."
            )

        # Updates per epoch are given by the training data size / batch size
        updates_per_epoch = self.dataset.train.num_examples // self.model.batch_size

        for epoch in range(self.cfg.TRAIN.EPOCH):
            for idx in range(updates_per_epoch):
                images, wrong_images, embed, _, _ = self.dataset.train.next_batch(
                    self.model.batch_size, 4)
                batch_z = np.random.normal(
                    0, 1, (self.model.batch_size, self.model.z_dim))

                # Update D network (the individual loss terms are fetched but
                # only the total loss is reported).
                _, _, _, _, err_d, summary_str = self.sess.run(
                    [
                        self.D_optim, self.D_real_match_loss,
                        self.D_real_mismatch_loss, self.D_synthetic_loss,
                        self.D_loss, self.D_merged_summ
                    ],
                    feed_dict={
                        self.model.inputs: images,
                        self.model.wrong_inputs: wrong_images,
                        self.model.embed_inputs: embed,
                        self.model.z: batch_z
                    })
                self.writer.add_summary(summary_str, counter)

                # Update G network
                _, err_g, summary_str = self.sess.run(
                    [self.G_optim, self.G_loss, self.G_merged_summ],
                    feed_dict={
                        self.model.z: batch_z,
                        self.model.embed_inputs: embed
                    })
                self.writer.add_summary(summary_str, counter)

                counter += 1
                print(
                    "Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                    % (epoch, idx, updates_per_epoch, time.time() - start_time,
                       err_d, err_g))

                if np.mod(counter, 100) == 0:
                    # Best effort: a failed sample grid is reported but does
                    # not stop training.
                    try:
                        samples = self.sess.run(self.model.sampler,
                                                feed_dict={
                                                    self.model.z_sample:
                                                    sample_z,
                                                    self.model.embed_sample:
                                                    sample_embed,
                                                })
                        save_images(
                            samples, image_manifold_size(samples.shape[0]),
                            '{}train_{:02d}_{:04d}.png'.format(
                                self.cfg.SAMPLE_DIR, epoch, idx))
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" %
                              (err_d, err_g))

                        # Display the captions of the sampled images
                        _print_captions(captions)
                    except Exception as e:
                        print("Failed to generate sample image")
                        print(type(e))
                        print(e.args)
                        print(e)

                if np.mod(counter, 500) == 2:
                    save(self.stageii_saver, self.sess,
                         self.cfg.CHECKPOINT_DIR, counter)
コード例 #4
0
    def train(self):
        """Fine-tune the classification model.

        Initialisation takes one of two paths:

        - If ``cfg.TRAIN.RESTORE_PRETRAIN`` is set, the pre-trained layers are
          restored and the remaining layers plus the optimizer variables are
          initialised.
        - Otherwise training resumes from the latest checkpoint in
          ``cfg.CHECKPOINT_DIR``.

        Training then runs one optimizer step per batch up to
        ``cfg.TRAIN.MAX_STEPS``. A summary is written every
        ``cfg.TRAIN.SUMMARY_PERIOD`` steps. A checkpoint is written every 200
        steps and after the final step.

        Raises:
            RuntimeError: if no pre-trained weights are requested and no
                checkpoint could be restored.
            ValueError: if the batch size exceeds the dataset size, or a batch
                contains pixel values outside [-1, 1] or class indices outside
                [0, 50).
        """
        self.define_model()
        self.define_summaries()

        start_time = time.time()
        self.saver = tf.train.Saver(
            max_to_keep=self.cfg.TRAIN.CHECKPOINTS_TO_KEEP)

        if self.cfg.TRAIN.RESTORE_PRETRAIN:
            pretrain_saver = tf.train.Saver(self.pretrained_to_restore)

            # Load the pre-trained layer
            pretrain_saver.restore(self.sess,
                                   self.cfg.TRAIN.PRETRAINED_CHECKPOINT_DIR)

            # Initialise the not restored layers and the optimizer variables
            self.sess.run(
                tf.variables_initializer(self.not_to_restore + self.opt_vars))
            start_point = 0
        else:
            could_load, checkpoint_counter = load(self.saver, self.sess,
                                                  self.cfg.CHECKPOINT_DIR)
            if could_load:
                start_point = checkpoint_counter
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")
                raise RuntimeError(
                    'Failed to restore the complete Inception model')
        sys.stdout.flush()

        batch_size = self.cfg.TRAIN.BATCH_SIZE
        max_steps = self.cfg.TRAIN.MAX_STEPS
        summary_period = self.cfg.TRAIN.SUMMARY_PERIOD
        # NOTE(review): batches are drawn from the *test* split. This is
        # presumably intentional (the label bound below is documented as the
        # class count, 20 for flowers / 50 for birds) — confirm.
        epoch_size = self.dataset.test.num_examples // batch_size
        if epoch_size == 0:
            raise ValueError(
                'Batch size {} exceeds the dataset size ({} examples)'.format(
                    batch_size, self.dataset.test.num_examples))

        for idx in range(start_point + 1, max_steps):
            epoch = idx // epoch_size

            images, _, _, _, labels = self.dataset.test.next_batch(batch_size,
                                                                   labels=True)

            # Bring the labels in a continuous range: [0, num_classes)
            new_labels = [self.class_to_idx[label] for label in labels]

            # Sanity-check the batch. Explicit checks, unlike `assert`, are
            # not stripped under `python -O`. The `not (...)` form also
            # rejects NaN pixel values.
            if not (np.min(images) >= -1. and np.max(images) <= 1.):
                raise ValueError('Image values must lie in [-1, 1]')
            # 20 for flowers, 50 for birds
            if not (np.min(new_labels) >= 0 and np.max(new_labels) < 50):
                raise ValueError('Class indices must lie in [0, 50)')

            feed_dict = {
                self.x: images,
                self.labels: new_labels,
            }

            _, err = self.sess.run([self.opt_step, self.loss],
                                   feed_dict=feed_dict)

            if np.mod(idx, summary_period) == 0:
                summary_str = self.sess.run(self.summary_op,
                                            feed_dict=feed_dict)
                self.writer.add_summary(summary_str, idx)

                print("Epoch: [%2d] [%4d] time: %4.4f, loss: %.8f" %
                      (epoch, idx, time.time() - start_time, err))

            # Also checkpoint the last step so progress since the previous
            # multiple of 200 is not lost.
            if np.mod(idx, 200) == 0 or idx == max_steps - 1:
                save(self.saver, self.sess, self.cfg.CHECKPOINT_DIR, idx)
            sys.stdout.flush()
コード例 #5
0
ファイル: trainer.py プロジェクト: pxu4114/Generative_CVS
    def train(self):
        """Train the conditional GAN up to ``cfg.TRAIN.MAX_STEPS`` iterations.

        Update schedule:

        - The discriminator (together with ``kt_optim``) is stepped on every
          iteration.
        - The generator is stepped only on every ``cfg.TRAIN.N_CRITIC``-th
          iteration.
        - Both learning rates are multiplied by 0.95 every
          ``10000 * N_CRITIC`` iterations.

        Outputs:

        - Summaries every ``SUMMARY_PERIOD`` steps.
        - Sample grids (best effort) every ``SAMPLE_PERIOD`` steps.
        - Checkpoints whenever ``step % 500 == 2``.

        Training resumes after the latest checkpoint in ``cfg.CHECKPOINT_DIR``
        if one can be loaded.
        """
        self.define_summaries()

        self.saver = tf.train.Saver(
            max_to_keep=self.cfg.TRAIN.CHECKPOINTS_TO_KEEP)

        # Fixed noise and test conditionals reused for every sample grid.
        fixed_z = np.random.normal(
            0, 1, (self.model.sample_num, self.model.z_dim))
        _, fixed_cond, _, captions = self.dataset.test.next_batch_test(
            self.model.sample_num, 0, 1)
        fixed_cond = np.squeeze(fixed_cond, axis=0)
        print('Conditionals sampler shape: {}'.format(fixed_cond.shape))

        save_captions(self.cfg.SAMPLE_DIR, captions)

        tf.global_variables_initializer().run()

        restored, restored_step = load(self.saver, self.sess,
                                       self.cfg.CHECKPOINT_DIR)
        if not restored:
            restored_step = 0
            print(" [!] Load failed...")
        else:
            print(" [*] Load SUCCESS")
        sys.stdout.flush()

        train_cfg = self.cfg.TRAIN
        batches_per_epoch = (self.dataset.train.num_examples //
                             self.model.batch_size)

        for step in range(restored_step + 1, train_cfg.MAX_STEPS):
            epoch = step // batches_per_epoch

            real, mismatched, cond, _, _ = self.dataset.train.next_batch(
                self.model.batch_size, 1, embeddings=True, wrong_img=True)
            z = np.random.normal(
                0, 1, (self.model.batch_size, self.model.z_dim))
            eps = np.random.uniform(
                0., 1., size=(self.model.batch_size, 1, 1, 1))

            n_critic = train_cfg.N_CRITIC
            decay = 0.95**((step // n_critic) // 10000)

            feed = {
                self.model.learning_rate_d: self.lr_d * decay,
                self.model.learning_rate_g: self.lr_g * decay,
                self.model.x: real,
                self.model.x_mismatch: mismatched,
                self.model.cond: cond,
                self.model.z: z,
                self.model.epsilon: eps,
                self.model.z_sample: fixed_z,
                self.model.cond_sample: fixed_cond,
                self.model.iter: step,
            }

            # Critic (plus kt) update every step; generator every n_critic-th.
            self.sess.run([self.model.D_optim, self.model.kt_optim,
                           self.model.D_loss], feed_dict=feed)
            if step % n_critic == 0:
                self.sess.run([self.model.G_optim, self.model.G_loss],
                              feed_dict=feed)

            if np.mod(step, train_cfg.SUMMARY_PERIOD) == 0:
                summary = self.sess.run(self.summary_op, feed_dict=feed)
                self.writer.add_summary(summary, step)

            if np.mod(step, train_cfg.SAMPLE_PERIOD) == 0:
                try:
                    grid = self.sess.run(self.model.sampler, feed_dict={
                        self.model.z_sample: fixed_z,
                        self.model.cond_sample: fixed_cond,
                    })
                    save_images(grid,
                                get_balanced_factorization(grid.shape[0]),
                                '{}train_{:02d}_{:04d}.png'.format(
                                    self.cfg.SAMPLE_DIR, epoch, step))
                except Exception as err:
                    print("Failed to generate sample image")
                    print(type(err))
                    print(err.args)
                    print(err)

            if np.mod(step, 500) == 2:
                save(self.saver, self.sess, self.cfg.CHECKPOINT_DIR, step)
            sys.stdout.flush()
コード例 #6
0
ファイル: pggan.py プロジェクト: shivanikush/text-to-image-1
    def train(self):
        """Train the progressive GAN for the current stage.

        For stages other than 1, weights are first restored from
        ``check_dir_read``: through ``self.restore`` while transitioning from
        the previous stage, otherwise through ``self.saver``. The variables
        returned by ``initialize_uninitialized`` are then initialised.

        Training runs ``self.steps - 1`` iterations, each with one D update
        followed by one G update. Outputs:

        - Summaries every 20 steps.
        - Sample grids (best effort) every 2000 steps.
        - Checkpoints to ``check_dir_write`` every 2000 steps and after the
          last step.

        The summary writer is closed on exit, and the default graph is reset
        after a successful run.

        Raises:
            RuntimeError: if the weights required for a stage > 1 cannot be
                loaded.
        """
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
            # Close the writer on every exit path so buffered summaries are
            # flushed to disk.
            try:
                start_point = 0

                if self.stage != 1:
                    if self.trans:
                        could_load, _ = load(self.restore, sess, self.check_dir_read)
                        if not could_load:
                            raise RuntimeError('Could not load previous stage during transition')
                    else:
                        could_load, _ = load(self.saver, sess, self.check_dir_read)
                        if not could_load:
                            raise RuntimeError('Could not load current stage')

                # variables to init
                vars_to_init = initialize_uninitialized(sess)
                sess.run(tf.variables_initializer(vars_to_init))

                # Fixed inputs reused for every sample grid.
                sample_z = np.random.normal(0, 1, (self.sample_num, self.z_dim))
                _, sample_cond, _, captions = self.dataset.test.next_batch_test(self.sample_num, 0, 1)
                sample_cond = np.squeeze(sample_cond, axis=0)
                print('Conditionals sampler shape: {}'.format(sample_cond.shape))

                save_captions(self.sample_path, captions)
                start_time = time.time()

                epoch_size = self.dataset.train.num_examples // self.batch_size

                for idx in range(start_point + 1, self.steps):
                    if self.trans:
                        # A decaying learning-rate schedule for the transition
                        # period, lr * exp(-2 * (1 - idx / steps) ** 2), is
                        # currently disabled; the base learning rate is used.
                        self.lr_inp = self.lr
                    # NOTE(review): when not transitioning, self.lr_inp must
                    # already be set elsewhere (e.g. __init__) — confirm.

                    epoch = idx // epoch_size

                    images, wrong_images, embed, _, _ = self.dataset.train.next_batch(self.batch_size, 4,
                                                                                      wrong_img=True,
                                                                                      embeddings=True)
                    batch_z = np.random.normal(0, 1, (self.batch_size, self.z_dim))
                    eps = np.random.uniform(0., 1., size=(self.batch_size, 1, 1, 1))

                    feed_dict = {
                        self.x: images,
                        self.learning_rate: self.lr_inp,
                        self.x_mismatch: wrong_images,
                        self.cond: embed,
                        self.z: batch_z,
                        self.epsilon: eps,
                        self.z_sample: sample_z,
                        self.cond_sample: sample_cond,
                        self.iter: idx,
                    }

                    _, err_d = sess.run([self.D_optim, self.D_loss], feed_dict=feed_dict)
                    _, err_g = sess.run([self.G_optim, self.G_loss], feed_dict=feed_dict)

                    if np.mod(idx, 20) == 0:
                        summary_str = sess.run(self.summary_op, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, idx)

                        print("Epoch: [%2d] [%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                              % (epoch, idx, time.time() - start_time, err_d, err_g))

                    if np.mod(idx, 2000) == 0:
                        try:
                            samples = sess.run(self.sampler, feed_dict={
                                                        self.z_sample: sample_z,
                                                        self.cond_sample: sample_cond})
                            samples = np.clip(samples, -1., 1.)
                            # Above 256 px only the first 4 samples are saved.
                            if self.out_size > 256:
                                samples = samples[:4]

                            save_images(samples, get_balanced_factorization(samples.shape[0]),
                                        '{}train_{:02d}_{:04d}.png'.format(self.sample_path, epoch, idx))

                        except Exception as e:
                            print("Failed to generate sample image")
                            print(type(e))
                            print(e.args)
                            print(e)

                    if np.mod(idx, 2000) == 0 or idx == self.steps - 1:
                        save(self.saver, sess, self.check_dir_write, idx)
                    sys.stdout.flush()
            finally:
                summary_writer.close()

        tf.reset_default_graph()
コード例 #7
0
ファイル: trainer.py プロジェクト: pxu4114/Generative_CVS
	def train(self):
		"""Jointly train the stage II GAN and the retrieval network R.

		Stage II networks are restored from ``cfg.CHECKPOINT_DIR``, and the
		step counter resumes from that checkpoint. The stage I generator is
		restored from ``cfg_stage_i.CHECKPOINT_DIR``.

		Each batch runs one D, one G and one R update with a shared feed.
		Periodic work:

		- Every 1000 steps, a sample grid is generated from retrieval
		  sentence embeddings of a fixed test batch.
		- Whenever ``counter % 500 == 2``, a stage II checkpoint is written.
		- Every 50 epochs (excluding epoch 0), ``ret_eval`` is run.
		"""
		self.define_losses()
		self.define_summaries()

		# Fixed noise and test batch reused for every sample grid.
		sample_z = np.random.normal(0, 1, (self.model.sample_num, self.model.z_dim))
		_, sample_embed, _, captions = self.dataset.test.next_batch_test(self.model.sample_num, 0, 1)
		# Fixed retrieval test batch used for the sample grid; labels_test is not used in this method.
		im_feats_test, sent_feats_test, labels_test = self.test_data_loader.get_batch(0,self.cfg.RETRIEVAL.SAMPLE_NUM,\
														image_aug = self.cfg.RETRIEVAL.IMAGE_AUG, phase='test')
		# sample_embed is only printed here; the sampler below is fed sent_emb
		# (the embed_sample feed of sample_embed is commented out).
		sample_embed = np.squeeze(sample_embed, axis=0)
		print(sample_embed.shape)

		save_captions(self.cfg.SAMPLE_DIR, captions)

		counter = 1
		start_time = time.time()

		could_load, checkpoint_counter = load(self.stageii_saver, self.sess, self.cfg.CHECKPOINT_DIR)
		if could_load:
			counter = checkpoint_counter
			print(" [*] Load SUCCESS: Stage II networks are loaded.")
		else:
			print(" [!] Load failed for stage II networks...")

		# Only the stage I generator weights are restored; its checkpoint step is
		# not copied into counter.
		could_load, checkpoint_counter = load(self.stagei_g_saver, self.sess, self.cfg_stage_i.CHECKPOINT_DIR)
		if could_load:
			print(" [*] Load SUCCESS: Stage I generator is loaded")
		else:
			print(" [!] WARNING!!! Failed to load the parameters for stage I generator...")

		initialize_uninitialized(self.sess)

		# Updates per epoch are given by the training data size / batch size
		updates_per_epoch = self.dataset.train.num_examples // self.model.batch_size
		# Resume at the epoch containing the restored step; the inner loop
		# still starts from batch 0 of that epoch.
		epoch_start = counter // updates_per_epoch

		for epoch in range(epoch_start, self.cfg.TRAIN.EPOCH):
			# The learning rate fed below is halved every 100 epochs.
			cen_epoch = epoch // 100

			for idx in range(0, updates_per_epoch):
				images, wrong_images, embed, _, _ = self.dataset.train.next_batch(self.model.batch_size, 1,
																				  embeddings=True,
																				  wrong_img=True)
				batch_z = np.random.normal(0, 1, (self.model.batch_size, self.model.z_dim))

				# Retrieval data loader
				# idx is always < updates_per_epoch, so this is true only for the
				# first batch of each epoch (presumably a per-epoch reshuffle).
				if idx % updates_per_epoch == 0:
					self.R_loader.shuffle_inds()

				im_feats, sent_feats, labels = self.R_loader.get_batch(idx % updates_per_epoch,\
								self.cfg.RETRIEVAL.BATCH_SIZE, image_aug = self.cfg.RETRIEVAL.IMAGE_AUG)

				# One feed shared by the D, G and R updates. `embed` from
				# next_batch is not fed: the embed_inputs entries are commented out.
				feed_dict = {
					self.learning_rate: self.lr * (0.5**cen_epoch),
					self.model.inputs: images,
					self.model.wrong_inputs: wrong_images,
					# self.model.embed_inputs: embed,
					# self.model.embed_inputs: self.txt_emb,
					self.model.z: batch_z,
					self.Retrieval.image_placeholder : im_feats,
					self.Retrieval.sent_placeholder : sent_feats,
					self.Retrieval.label_placeholder : labels
				}

				# Update D network
				_, err_d, summary_str = self.sess.run([self.D_optim, self.D_loss, self.D_merged_summ],
													  feed_dict=feed_dict)
				self.writer.add_summary(summary_str, counter)

				# Update G network
				_, err_g, summary_str = self.sess.run([self.G_optim, self.G_loss, self.G_merged_summ],
													  feed_dict=feed_dict)
				self.writer.add_summary(summary_str, counter)

				# Update R network
				_, err_r, summary_str = self.sess.run([self.R_optim, self.R_loss, self.R_loss_summ],
													  feed_dict=feed_dict)
				self.writer.add_summary(summary_str, counter)

				counter += 1
				print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f, r_loss: %.8f"
					  % (epoch, idx, updates_per_epoch,
						 time.time() - start_time, err_d, err_g, err_r))

				# Every 1000 steps: embed the fixed retrieval test batch and use
				# those sentence embeddings to condition a sample grid. Failures
				# are printed and training continues.
				if np.mod(counter, 1000) == 0:
					try:
						# pdb.set_trace()
						# NOTE(review): both eval() calls run on every sampling step;
						# if they add ops to the graph, the graph grows over training — confirm.
						self.Retrieval.eval()
						sent_emb = self.sess.run(self.Retrieval.sent_embed_tensor,
												feed_dict={
															self.Retrieval.image_placeholder_test: im_feats_test,
															self.Retrieval.sent_placeholder_test: sent_feats_test,
														  })
						self.model.eval(sent_emb)
						samples = self.sess.run(self.model.sampler,
												feed_dict={
															self.model.z_sample: sample_z,
															# self.model.embed_sample: sample_embed,
															self.model.embed_sample: sent_emb,
														  })
						save_images(samples, get_balanced_factorization(samples.shape[0]),
									'{}train_{:02d}_{:04d}.png'.format(self.cfg.SAMPLE_DIR, epoch, idx))
					except Exception as e:
						print("Failed to generate sample image")
						print(type(e))
						print(e.args)
						print(e)

				if np.mod(counter, 500) == 2:
					save(self.stageii_saver, self.sess, self.cfg.CHECKPOINT_DIR, counter)

			# Retrieval evaluation every 50 epochs, skipping epoch 0.
			if np.mod(epoch, 50) == 0 and epoch!=0:
				self.ret_eval(epoch)
コード例 #8
0
ファイル: data.py プロジェクト: ttyfly/spider_example
 def save(self):
     """Write this object's attributes to ``<name>-<timestamp>.json``.

     ``self.generated_at`` is parsed with ``datetime.fromisoformat`` and
     reformatted as ``%Y-%m-%d-%H-%M-%S`` to build the timestamp. The actual
     writing of ``self.__dict__`` is delegated to the module-level ``save``
     helper.
     """
     prefix = self.name + '-'
     stamp = datetime.fromisoformat(self.generated_at).strftime('%Y-%m-%d-%H-%M-%S')
     save(prefix + stamp + '.json', self.__dict__)
コード例 #9
0
# Top y-coordinate of each of the three save-slot panels (each 200 px tall).
_SLOT_TOPS = (80, 280, 480)


def _draw_slot_previews():
    """Draw the screenshot, round text and date text of every occupied slot.

    Entries of ``player_runtime.SDATA`` are laid out 200 px apart starting
    at y=90. Entries whose ``img_path`` is empty are skipped. An error while
    drawing one entry is printed and does not stop the remaining entries.
    """
    top = 90
    for slot_data in player_runtime.SDATA:
        try:
            if not slot_data['img_path'] == '':
                # Read all fields before drawing, as the original code did.
                image = slot_data['img_res']
                round_text = slot_data['round']
                date_text = slot_data['date']
                loader.screen.blit(image, (15, top))
                round_label = loader.BAIKE_FONT.render(round_text, True,
                                                       color_rgb.WHITE, None)
                date_label = loader.BAIKE_FONT.render(date_text, True,
                                                      color_rgb.WHITE, None)
                loader.screen.blit(round_label, (450, top + 35))
                loader.screen.blit(date_label, (450, top + 90))
        except Exception as err:
            print('save read err')
            print(err)
        top += 200


def _record_slot_name(cslot, cfile_name):
    """Store ``cfile_name`` under key ``'slot<cslot>'`` in save/tsd.pkl.

    The pickled record (presumably a dict) is read, updated for slots 1-3
    only, and written back. Any other slot number leaves the record
    unchanged, although the file is still rewritten.
    """
    # NOTE: pickle is acceptable only because this is the game's own local
    # save file; never point this at untrusted data.
    with open('save/tsd.pkl', 'rb') as f:
        record = pickle.load(f)
    if cslot in (1, 2, 3):
        record['slot' + str(cslot)] = cfile_name
    with open('save/tsd.pkl', 'wb') as f:
        pickle.dump(record, f)


def _save_to_current_slot():
    """Save the game into the slot chosen in ``player_runtime.INFO['cslot']``.

    Steps:
    1. Write the data to ``save/slot<N>.pkl`` via ``saver.save``.
    2. Move ``save/tmp.jpg`` to ``save/slot<N>.jpg``, then copy it back to
       ``save/tmp.jpg``.

    When ``saver.save`` returns True, it also:
    3. Records the new save name in save/tsd.pkl.
    4. Reloads ``player_runtime.SDATA``.
    5. Leaves confirmation mode.

    Any exception is printed rather than raised, so the frame loop keeps
    running.
    """
    try:
        cslot = player_runtime.INFO['cslot'] + 1
        ctime = timer.get_stime()
        cround = player_runtime.INFO['round']
        cfile_name = str(ctime) + 'a' + str(cround)
        cimgp = 'save/slot' + str(cslot) + '.jpg'
        cdatap = 'save/slot' + str(cslot) + '.pkl'
        # On Windows os.rename cannot overwrite an existing file, so the old
        # slot image is removed first.
        if os.path.exists(cimgp):
            os.remove(cimgp)
        res = saver.save(player_runtime.INFO, cdatap)
        # NOTE(review): the image is swapped even when `res` is not True,
        # which can leave the slot image out of sync with the save record.
        os.rename('save/tmp.jpg', cimgp)
        # Copy back so save/tmp.jpg still exists afterwards (presumably for
        # later saves; confirm against the screenshot code).
        shutil.copyfile(cimgp, 'save/tmp.jpg')
        if res == True:
            print('save ok')
            _record_slot_name(cslot, cfile_name)
            print('update sd ok')
            # Refresh the slot list so the new screenshot is shown.
            player_runtime.SDATA = saver.read_sd()
            # Return to the normal (browsing) view.
            player_runtime.INFO['checksover'] = False
    except Exception as err:
        print('save error')
        print(err)


def dojob(x, y, is_mouse_down, keys):
    """Draw and handle input for the save screen for one frame.

    The back button at (20, 20) clears ``player_runtime.INFO['saving']``
    when clicked. The rest depends on ``player_runtime.INFO['checksover']``:

    * Browsing (not ``== True``): the three slot panels are drawn, and the
      one under the cursor is highlighted. Clicking a panel stores its
      index (0-2) in ``INFO['cslot']`` and sets ``INFO['checksover']``.
    * Confirming (``== True``): a prompt is drawn. It is the "overwrite"
      variant if the chosen slot already has a screenshot, otherwise the
      plain-save variant. Confirm and cancel buttons are drawn below it.
      Confirm performs the save; cancel returns to browsing.

    :param x: cursor x position, compared against screen blit coordinates.
    :param y: cursor y position, compared against screen blit coordinates.
    :param is_mouse_down: click flag for this frame; compared with ``== True``.
    :param keys: not used by this function.
    """
    clicked = is_mouse_down == True

    # Board background and the back button (highlighted while hovered).
    loader.screen.blit(loader.SAVE_BOARD_TITLE, (0, 0))
    if 20 <= x < 100 and 20 <= y < 60:
        loader.screen.blit(loader.SAVE_BOARD_TITLE_BACK2, (20, 20))
        if clicked:
            player_runtime.INFO['saving'] = False
    else:
        loader.screen.blit(loader.SAVE_BOARD_TITLE_BACK1, (20, 20))

    if player_runtime.INFO['checksover'] == True:
        # Confirmation mode: plain panels, screenshots, then the dialog.
        for top in _SLOT_TOPS:
            loader.screen.blit(loader.SAVE_BOARD_LOT, (0, top))
        _draw_slot_previews()

        # A non-empty image path means the chosen slot is already occupied.
        c_index = player_runtime.INFO['cslot']
        if not player_runtime.SDATA[c_index]['img_path'] == '':
            loader.screen.blit(loader.CHECK_SAVEOVER, (230, 200))
        else:
            loader.screen.blit(loader.CHECK_SAVE, (230, 200))
        loader.screen.blit(loader.SURE, (300, 370))      # confirm button
        loader.screen.blit(loader.SURENO, (460, 370))    # cancel button

        if 300 <= x <= 380 and 370 <= y <= 410:
            loader.screen.blit(loader.SELECT_SAVE, (300, 370))
            if clicked:
                _save_to_current_slot()
        elif 460 <= x <= 540 and 370 <= y <= 410:
            loader.screen.blit(loader.SELECT_SAVE, (460, 370))
            if clicked:
                # Cancel: simply go back to browsing the slots.
                player_runtime.INFO['checksover'] = False
    else:
        # Browsing mode: highlight the hovered panel; a click selects it.
        for index, top in enumerate(_SLOT_TOPS):
            if 0 <= x <= 860 and top <= y < top + 200:
                loader.screen.blit(loader.SAVE_BOARD_LOTA, (0, top))
                if clicked:
                    player_runtime.INFO['checksover'] = True
                    player_runtime.INFO['cslot'] = index
            else:
                loader.screen.blit(loader.SAVE_BOARD_LOT, (0, top))
        _draw_slot_previews()