Example #1
0
    def train(self):
        """Run the adversarial training loop.

        Alternates one discriminator and one generator update per batch,
        writing loss summaries every step.  Every 500 updates a grid of
        generated samples is saved and a checkpoint is written.
        """
        self.define_losses()
        self.define_summaries()

        # Fixed noise and text embeddings reused for every sample grid so
        # sample quality is visually comparable across training.
        sample_z = np.random.normal(0, 1,
                                    (self.model.sample_num, self.model.z_dim))
        _, sample_embed, _, captions = self.dataset.test.next_batch_test(
            self.model.sample_num, 0, 1)
        sample_embed = np.squeeze(sample_embed, axis=0)
        print(sample_embed.shape)

        save_captions(self.cfg.SAMPLE_DIR, captions)

        counter = 1
        start_time = time.time()

        # Resume from the latest checkpoint when one exists; `counter`
        # continues from the stored global step.
        could_load, checkpoint_counter = load(self.saver, self.sess,
                                              self.cfg.CHECKPOINT_DIR)
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        initialize_uninitialized(self.sess)

        # Updates per epoch are given by the training data size / batch size
        updates_per_epoch = self.dataset.train.num_examples // self.model.batch_size
        epoch_start = counter // updates_per_epoch

        for epoch in range(epoch_start, self.cfg.TRAIN.EPOCH):
            # Learning rate is halved every 100 epochs (0.5 ** cen_epoch).
            cen_epoch = epoch // 100

            for idx in range(0, updates_per_epoch):
                images, wrong_images, embed, _, _ = self.dataset.train.next_batch(
                    self.model.batch_size, 4, embeddings=True, wrong_img=True)
                batch_z = np.random.normal(
                    0, 1, (self.model.batch_size, self.model.z_dim))

                feed_dict = {
                    self.learning_rate: self.lr * (0.5**cen_epoch),
                    self.model.inputs: images,
                    self.model.wrong_inputs: wrong_images,
                    self.model.embed_inputs: embed,
                    self.model.z: batch_z,
                }

                # Update D network
                _, err_d, summary_str = self.sess.run(
                    [self.D_optim, self.D_loss, self.D_merged_summ],
                    feed_dict=feed_dict)
                self.writer.add_summary(summary_str, counter)

                # Update G network
                _, err_g, summary_str = self.sess.run(
                    [self.G_optim, self.G_loss, self.G_merged_summ],
                    feed_dict=feed_dict)
                self.writer.add_summary(summary_str, counter)

                counter += 1
                print(
                    "Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                    % (epoch, idx, updates_per_epoch, time.time() - start_time,
                       err_d, err_g))

                # Every 500 updates: dump a sample grid, then checkpoint.
                # The original evaluated the same `np.mod(counter, 500) == 0`
                # condition in two back-to-back `if` blocks; it is checked
                # once here.  The checkpoint is still written even when
                # sampling fails, matching the original best-effort order.
                if np.mod(counter, 500) == 0:
                    try:
                        samples = self.sess.run(self.model.sampler,
                                                feed_dict={
                                                    self.model.z_sample:
                                                    sample_z,
                                                    self.model.embed_sample:
                                                    sample_embed,
                                                })
                        save_images(
                            samples,
                            get_balanced_factorization(samples.shape[0]),
                            '{}train_{:02d}_{:04d}.png'.format(
                                self.cfg.SAMPLE_DIR, epoch, idx))
                    except Exception as e:
                        # Sampling is best-effort: log and keep training.
                        print("Failed to generate sample image")
                        print(type(e))
                        print(e.args)
                        print(e)

                    save(self.saver, self.sess, self.cfg.CHECKPOINT_DIR,
                         counter)
Example #2
0
    def train(self):
        """Train the current stage inside its own TF session.

        Later stages (and stage transitions) warm-start from an earlier
        checkpoint.  Runs `self.steps` alternating D/G updates, logging
        summaries every 20 steps and sampling/checkpointing every 2000
        steps.  The graph is reset afterwards so the next stage can
        rebuild it from scratch.
        """
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # claim GPU memory on demand

        with tf.Session(config=config) as sess:

            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
            start_point = 0

            if self.stage != 1:
                # Warm-start: the previous stage's weights during a
                # transition, otherwise this stage's own checkpoint.
                if self.trans:
                    could_load, _ = load(self.restore, sess, self.check_dir_read)
                    if not could_load:
                        raise RuntimeError('Could not load previous stage during transition')
                else:
                    could_load, _ = load(self.saver, sess, self.check_dir_read)
                    if not could_load:
                        raise RuntimeError('Could not load current stage')

            # Initialise only the variables the restore did not cover.
            vars_to_init = initialize_uninitialized(sess)
            sess.run(tf.variables_initializer(vars_to_init))

            # Fixed noise / conditionals reused for every sample grid.
            sample_z = np.random.normal(0, 1, (self.sample_num, self.z_dim))
            _, sample_cond, _, captions = self.dataset.test.next_batch_test(self.sample_num, 0, 1)
            sample_cond = np.squeeze(sample_cond, axis=0)
            print('Conditionals sampler shape: {}'.format(sample_cond.shape))

            save_captions(self.sample_path, captions)
            start_time = time.time()

            # Loop invariant hoisted out of the step loop.
            epoch_size = self.dataset.train.num_examples // self.batch_size

            for idx in range(start_point + 1, self.steps):
                if self.trans:
                    # NOTE(review): a decayed transition schedule,
                    # self.lr * np.exp(-2 * np.square(1 - idx / self.steps)),
                    # was disabled upstream; the rate is currently constant.
                    self.lr_inp = self.lr

                epoch = idx // epoch_size

                images, wrong_images, embed, _, _ = self.dataset.train.next_batch(self.batch_size, 4,
                                                                                  wrong_img=True,
                                                                                  embeddings=True)
                batch_z = np.random.normal(0, 1, (self.batch_size, self.z_dim))
                # Per-sample mixing coefficients fed to self.epsilon —
                # presumably for a gradient-penalty interpolation; confirm
                # against the model definition.
                eps = np.random.uniform(0., 1., size=(self.batch_size, 1, 1, 1))

                feed_dict = {
                    self.x: images,
                    self.learning_rate: self.lr_inp,
                    self.x_mismatch: wrong_images,
                    self.cond: embed,
                    self.z: batch_z,
                    self.epsilon: eps,
                    self.z_sample: sample_z,
                    self.cond_sample: sample_cond,
                    self.iter: idx,
                }

                # One discriminator update, then one generator update.
                _, err_d = sess.run([self.D_optim, self.D_loss], feed_dict=feed_dict)
                _, err_g = sess.run([self.G_optim, self.G_loss], feed_dict=feed_dict)

                if np.mod(idx, 20) == 0:
                    summary_str = sess.run(self.summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, idx)

                    print("Epoch: [%2d] [%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                          % (epoch, idx, time.time() - start_time, err_d, err_g))

                if np.mod(idx, 2000) == 0:
                    try:
                        samples = sess.run(self.sampler, feed_dict={
                                                    self.z_sample: sample_z,
                                                    self.cond_sample: sample_cond})
                        samples = np.clip(samples, -1., 1.)
                        # Large outputs: keep only 4 images per grid.
                        if self.out_size > 256:
                            samples = samples[:4]

                        save_images(samples, get_balanced_factorization(samples.shape[0]),
                                    '{}train_{:02d}_{:04d}.png'.format(self.sample_path, epoch, idx))

                    except Exception as e:
                        # Sampling is best-effort: log and keep training.
                        print("Failed to generate sample image")
                        print(type(e))
                        print(e.args)
                        print(e)

                # Always checkpoint the final step as well.
                if np.mod(idx, 2000) == 0 or idx == self.steps - 1:
                    save(self.saver, sess, self.check_dir_write, idx)
                sys.stdout.flush()

        tf.reset_default_graph()
Example #3
0
    def train(self):
        """Run the training loop.

        Performs a discriminator (and kt) update every step and a generator
        update once every `N_CRITIC` steps, with periodic summaries, sample
        grids, and checkpoints.
        """
        self.define_summaries()

        self.saver = tf.train.Saver(max_to_keep=self.cfg.TRAIN.CHECKPOINTS_TO_KEEP)

        # Fixed noise / conditionals reused for every sample grid.
        sample_z = np.random.normal(0, 1, (self.model.sample_num, self.model.z_dim))
        _, sample_cond, _, captions = self.dataset.test.next_batch_test(self.model.sample_num, 0, 1)
        sample_cond = np.squeeze(sample_cond, axis=0)
        print('Conditionals sampler shape: {}'.format(sample_cond.shape))

        save_captions(self.cfg.SAMPLE_DIR, captions)

        start_time = time.time()
        tf.global_variables_initializer().run()

        # Resume from the latest checkpoint when available.
        could_load, checkpoint_counter = load(self.saver, self.sess, self.cfg.CHECKPOINT_DIR)
        if could_load:
            start_point = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_point = 0
            print(" [!] Load failed...")
        sys.stdout.flush()

        # Loop invariants hoisted out of the step loop (config/dataset
        # values do not change between steps).
        epoch_size = self.dataset.train.num_examples // self.model.batch_size
        n_critic = self.cfg.TRAIN.N_CRITIC
        summary_period = self.cfg.TRAIN.SUMMARY_PERIOD
        sample_period = self.cfg.TRAIN.SAMPLE_PERIOD

        for idx in range(start_point + 1, self.cfg.TRAIN.MAX_STEPS):
            epoch = idx // epoch_size

            images, wrong_images, embed, _, _ = self.dataset.train.next_batch(self.model.batch_size, 1, embeddings=True,
                                                                              wrong_img=True)
            batch_z = np.random.normal(0, 1, (self.model.batch_size, self.model.z_dim))
            # Per-sample mixing coefficients fed to model.epsilon —
            # presumably for a gradient-penalty interpolation; confirm
            # against the model definition.
            eps = np.random.uniform(0., 1., size=(self.model.batch_size, 1, 1, 1))
            # Both learning rates decay by 5% every 10000 generator updates.
            kiter = (idx // n_critic) // 10000

            feed_dict = {
                self.model.learning_rate_d: self.lr_d * (0.95**kiter),
                self.model.learning_rate_g: self.lr_g * (0.95**kiter),
                self.model.x: images,
                self.model.x_mismatch: wrong_images,
                self.model.cond: embed,
                self.model.z: batch_z,
                self.model.epsilon: eps,
                self.model.z_sample: sample_z,
                self.model.cond_sample: sample_cond,
                self.model.iter: idx,
            }

            # Discriminator (and kt) update every step ...
            _, _, err_d = self.sess.run([self.model.D_optim, self.model.kt_optim, self.model.D_loss],
                                         feed_dict=feed_dict)

            # ... generator update once every n_critic steps.
            if idx % n_critic == 0:
                _, err_g = self.sess.run([self.model.G_optim, self.model.G_loss],
                                         feed_dict=feed_dict)

            if np.mod(idx, summary_period) == 0:
                summary_str = self.sess.run(self.summary_op, feed_dict=feed_dict)
                self.writer.add_summary(summary_str, idx)

            if np.mod(idx, sample_period) == 0:
                try:
                    samples = self.sess.run(self.model.sampler,
                                            feed_dict={
                                                self.model.z_sample: sample_z,
                                                self.model.cond_sample: sample_cond,
                                            })
                    save_images(samples, get_balanced_factorization(samples.shape[0]),
                                '{}train_{:02d}_{:04d}.png'.format(self.cfg.SAMPLE_DIR, epoch, idx))

                except Exception as e:
                    # Sampling is best-effort: log and keep training.
                    print("Failed to generate sample image")
                    print(type(e))
                    print(e.args)
                    print(e)

            # NOTE(review): checkpoints at idx % 500 == 2 — presumably an
            # offset so saving never coincides with the sampling step;
            # confirm that this is intentional.
            if np.mod(idx, 500) == 2:
                save(self.saver, self.sess, self.cfg.CHECKPOINT_DIR, idx)
            sys.stdout.flush()
Example #4
0
	def train(self):
		"""Train stage II jointly with the retrieval network.

		Each batch performs one D, one G, and one R (retrieval) update.
		Every 1000 updates a sample grid conditioned on retrieval sentence
		embeddings is written; stage II checkpoints are saved periodically
		and retrieval evaluation runs every 50 epochs.
		"""
		self.define_losses()
		self.define_summaries()

		# Fixed noise and held-out retrieval features reused for every
		# sample grid.  The test labels are unused here, hence discarded.
		sample_z = np.random.normal(0, 1, (self.model.sample_num, self.model.z_dim))
		_, sample_embed, _, captions = self.dataset.test.next_batch_test(self.model.sample_num, 0, 1)
		im_feats_test, sent_feats_test, _ = self.test_data_loader.get_batch(0, self.cfg.RETRIEVAL.SAMPLE_NUM,
																			image_aug=self.cfg.RETRIEVAL.IMAGE_AUG, phase='test')
		sample_embed = np.squeeze(sample_embed, axis=0)
		print(sample_embed.shape)

		save_captions(self.cfg.SAMPLE_DIR, captions)

		counter = 1
		start_time = time.time()

		# Resume stage II weights when a checkpoint exists; `counter`
		# continues from the stored global step.
		could_load, checkpoint_counter = load(self.stageii_saver, self.sess, self.cfg.CHECKPOINT_DIR)
		if could_load:
			counter = checkpoint_counter
			print(" [*] Load SUCCESS: Stage II networks are loaded.")
		else:
			print(" [!] Load failed for stage II networks...")

		# The stage I generator is loaded best-effort: training proceeds
		# with a warning when it is missing.
		could_load, checkpoint_counter = load(self.stagei_g_saver, self.sess, self.cfg_stage_i.CHECKPOINT_DIR)
		if could_load:
			print(" [*] Load SUCCESS: Stage I generator is loaded")
		else:
			print(" [!] WARNING!!! Failed to load the parameters for stage I generator...")

		initialize_uninitialized(self.sess)

		# Updates per epoch are given by the training data size / batch size
		updates_per_epoch = self.dataset.train.num_examples // self.model.batch_size
		epoch_start = counter // updates_per_epoch

		for epoch in range(epoch_start, self.cfg.TRAIN.EPOCH):
			# The learning rate is halved every 100 epochs (0.5 ** cen_epoch).
			cen_epoch = epoch // 100

			for idx in range(0, updates_per_epoch):
				# The dataset's text embedding is unused: the GAN is
				# conditioned on retrieval features instead (see feed_dict).
				images, wrong_images, _, _, _ = self.dataset.train.next_batch(self.model.batch_size, 1,
																			  embeddings=True,
																			  wrong_img=True)
				batch_z = np.random.normal(0, 1, (self.model.batch_size, self.model.z_dim))

				# Reshuffle the retrieval loader once per epoch.  `idx` is
				# always < updates_per_epoch, so the original
				# `idx % updates_per_epoch` reduces to plain `idx`.
				if idx == 0:
					self.R_loader.shuffle_inds()

				im_feats, sent_feats, labels = self.R_loader.get_batch(idx,
								self.cfg.RETRIEVAL.BATCH_SIZE, image_aug=self.cfg.RETRIEVAL.IMAGE_AUG)

				feed_dict = {
					self.learning_rate: self.lr * (0.5**cen_epoch),
					self.model.inputs: images,
					self.model.wrong_inputs: wrong_images,
					self.model.z: batch_z,
					self.Retrieval.image_placeholder: im_feats,
					self.Retrieval.sent_placeholder: sent_feats,
					self.Retrieval.label_placeholder: labels
				}

				# Update D network
				_, err_d, summary_str = self.sess.run([self.D_optim, self.D_loss, self.D_merged_summ],
													  feed_dict=feed_dict)
				self.writer.add_summary(summary_str, counter)

				# Update G network
				_, err_g, summary_str = self.sess.run([self.G_optim, self.G_loss, self.G_merged_summ],
													  feed_dict=feed_dict)
				self.writer.add_summary(summary_str, counter)

				# Update R (retrieval) network
				_, err_r, summary_str = self.sess.run([self.R_optim, self.R_loss, self.R_loss_summ],
													  feed_dict=feed_dict)
				self.writer.add_summary(summary_str, counter)

				counter += 1
				print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f, r_loss: %.8f"
					  % (epoch, idx, updates_per_epoch,
						 time.time() - start_time, err_d, err_g, err_r))

				if np.mod(counter, 1000) == 0:
					try:
						# Condition the sampler on sentence embeddings
						# produced by the retrieval network in eval mode.
						self.Retrieval.eval()
						sent_emb = self.sess.run(self.Retrieval.sent_embed_tensor,
												 feed_dict={
													 self.Retrieval.image_placeholder_test: im_feats_test,
													 self.Retrieval.sent_placeholder_test: sent_feats_test,
												 })
						self.model.eval(sent_emb)
						samples = self.sess.run(self.model.sampler,
												feed_dict={
													self.model.z_sample: sample_z,
													self.model.embed_sample: sent_emb,
												})
						save_images(samples, get_balanced_factorization(samples.shape[0]),
									'{}train_{:02d}_{:04d}.png'.format(self.cfg.SAMPLE_DIR, epoch, idx))
					except Exception as e:
						# Sampling is best-effort: log and keep training.
						print("Failed to generate sample image")
						print(type(e))
						print(e.args)
						print(e)

				# NOTE(review): saves at counter % 500 == 2 — presumably an
				# offset so saving never coincides with the sampling step;
				# confirm that this is intentional.
				if np.mod(counter, 500) == 2:
					save(self.stageii_saver, self.sess, self.cfg.CHECKPOINT_DIR, counter)

			# Periodic retrieval evaluation (skips epoch 0).
			if np.mod(epoch, 50) == 0 and epoch != 0:
				self.ret_eval(epoch)