Example #1
def train():
    # placeholder for z
    z = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.z_dim], name='z')

    # get images and labels
    images, labels = ac_gan.inputs(batch_size=FLAGS.batch_size)

    # logits
    [
        source_logits_real, class_logits_real, source_logits_fake,
        class_logits_fake, generated_images
    ] = ac_gan.inference(images, labels, z)

    # loss
    d_loss, g_loss = ac_gan.loss(labels, source_logits_real, class_logits_real,
                                 source_logits_fake, class_logits_fake,
                                 generated_images)

    # train the model
    train_d_op, train_g_op = ac_gan.train(d_loss, g_loss)

    sess = tf.Session()
    with sess.as_default():
        init = tf.global_variables_initializer()
        sess.run(init)
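        # start the input-queue threads that feed `images` and `labels`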
        tf.train.start_queue_runners(sess=sess)

        saver = tf.train.Saver()

        training_steps = FLAGS.train_steps

        for step in range(training_steps):

            random_z = np.random.uniform(-1,
                                         1,
                                         size=(FLAGS.batch_size,
                                               FLAGS.z_dim)).astype(np.float32)

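            # one discriminator update followed by two generator updates per step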
            sess.run(train_d_op, feed_dict={z: random_z})
            sess.run(train_g_op, feed_dict={z: random_z})
            sess.run(train_g_op, feed_dict={z: random_z})

            discriminator_loss, generator_loss = sess.run(
                [d_loss, g_loss], feed_dict={z: random_z})

            time_str = datetime.datetime.now().isoformat()
            print("{}: step {}, d_loss {:g}, g_loss {:g}".format(
                time_str, step, discriminator_loss, generator_loss))

            if step % 10 == 0:
                test_images = sess.run(generated_images,
                                       feed_dict={z: random_z})

                image_path = os.path.join(FLAGS.sample_dir,
                                          "sampled_images_%d.jpg" % step)

                utils.grid_plot(test_images, [8, 8], image_path)

            if step % 100 == 0:
                saver.save(sess, os.path.join(FLAGS.log_dir, "model.ckp"))
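
utils.grid_plot is not shown in this example; a minimal sketch of what it might look like, assuming it tiles a batch of images in [-1, 1] into a rows-by-cols grid and writes a jpg (the imageio call and the value range are assumptions, not the original implementation):

import numpy as np
import imageio


def grid_plot(images, grid_shape, path):
    # images: [batch, height, width, channels], values assumed in [-1, 1]
    rows, cols = grid_shape
    batch, h, w, c = images.shape
    canvas = np.zeros((rows * h, cols * w, c), dtype=np.float32)
    for idx in range(min(batch, rows * cols)):
        r, col = divmod(idx, cols)
        canvas[r * h:(r + 1) * h, col * w:(col + 1) * w, :] = images[idx]
    # rescale to [0, 255] and drop a singleton channel axis before saving
    canvas = ((canvas + 1.0) * 127.5).astype(np.uint8)
    if canvas.shape[-1] == 1:
        canvas = canvas[:, :, 0]
    imageio.imwrite(path, canvas)

Example #2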
def train_generator():
    for start in range(0, len(file_index[:data_split_index]), BATCH_SIZE):
        end = min(start + BATCH_SIZE, data_split_index)
        x_batch, y_batch = list(), list()

        for index in range(start, end):
            # decode the stored filename into an image path
            image_path = ascii2str(filenames[file_index[index]])
            # copy so the in-place rescaling below does not alter the shared keypoints array
            person_keypoints = keypoints[index].copy()
            temp_keypoints = person_keypoints

            image = cv2.imread(image_path)

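            # keep only keypoints that are actually annotated (visibility != -1,
            # non-zero coordinates), then crop a 100-pixel margin around them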
            temp_keypoints = temp_keypoints[np.where(
                (person_keypoints[:, 2] != -1) *
                (person_keypoints[:, 0] != 0) * (person_keypoints[:, 1] != 0))]
            x_start = max(min(temp_keypoints[:, 1]) - 100, 0)
            x_end = min(max(temp_keypoints[:, 1]) + 100, image.shape[0])
            y_start = max(min(temp_keypoints[:, 0]) - 100, 0)
            y_end = min(max(temp_keypoints[:, 0]) + 100, image.shape[1])

            x_start, x_end, y_start, y_end = map(
                int, (x_start, x_end, y_start, y_end))

            person_keypoints[:, 0] = person_keypoints[:, 0] - y_start
            person_keypoints[:, 1] = person_keypoints[:, 1] - x_start

            crop_image = image[x_start:x_end, y_start:y_end, :]
            height, width = crop_image.shape[0], crop_image.shape[1]
            person_keypoints[:, 0] = person_keypoints[:, 0] * 46.0 / width
            person_keypoints[:, 1] = person_keypoints[:, 1] * 46.0 / height

            crop_image = cv2.resize(crop_image, (IMAGE_WIDTH, IMAGE_HEIGHT))
            heatmap_image = cv2.resize(crop_image, (46, 46))
            crop_heatmap = create_heatmap(heatmap_image, person_keypoints)
            x_batch.append(crop_image)
            y_batch.append(crop_heatmap)
            if GRID_PLOT:
                grid_plot(crop_heatmap)
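        # scale pixel values to [-0.5, 0.5]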
        x_batch = np.array(x_batch) / 255.0 - 0.5
        y_batch = np.array(y_batch)
        yield x_batch, y_batch
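
create_heatmap is also not defined here; a minimal sketch, assuming it renders one Gaussian channel per keypoint on a grid the size of the 46x46 image it receives (the sigma value and the skip rule for unannotated joints are assumptions):

import numpy as np


def create_heatmap(image, keypoints, sigma=1.0):
    # keypoints: [num_joints, 3] rows of (x, y, visibility), already scaled to the 46x46 grid
    size = image.shape[0]
    num_joints = keypoints.shape[0]
    heatmaps = np.zeros((size, size, num_joints), dtype=np.float32)
    xs, ys = np.meshgrid(np.arange(size), np.arange(size))
    for j in range(num_joints):
        x, y, visibility = keypoints[j]
        if visibility == -1:
            continue  # joint is not annotated in this image
        heatmaps[:, :, j] = np.exp(
            -((xs - x) ** 2 + (ys - y) ** 2) / (2.0 * sigma ** 2))
    return heatmaps
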
Example #3
    def sample(self):
        # build model
        model = self.model
        model.build_model()

        with tf.Session(config=self.config) as sess:
            # load trained parameters
            print('loading testing model..')
            saver = tf.train.Saver()
            self.load_latest(saver, sess)

            batch_size = min(self.batch_size, 32)  # cap the sampling batch at 32
            print('start sampling..!')
            for i in range(64):
                z = np.random.normal(size=[batch_size, model.latent_dim])
                feed_dict = {model.z: z}
                sample = sess.run(model.reconst, feed_dict)

                output_name = os.path.join(self.sample_save_path,
                                           '{:03d}.jpg'.format(i))
                print('Saving to {}'.format(output_name))
                grid_plot(sample, 8, output_name)
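
self.load_latest is not part of this snippet; a minimal sketch, assuming it restores the newest checkpoint from a directory held in a hypothetical self.model_save_path attribute:

    def load_latest(self, saver, sess):
        # restore the most recent checkpoint, if one exists
        ckpt = tf.train.latest_checkpoint(self.model_save_path)  # assumed attribute
        if ckpt is not None:
            saver.restore(sess, ckpt)
            print('restored parameters from {}'.format(ckpt))
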
Example #4
    def sample(self):
        # build model
        print("SAMPLE!!!!!!!!!!!!")
        model = self.model
        model.build_model()

        with tf.Session(config=self.config) as sess:
            # load trained parameters
            print('loading testing model..')
            saver = tf.train.Saver()
            self.load_latest(saver, sess)

            batch_size = min(self.batch_size, 32)  # cap the sampling batch at 32
            print('start sampling..!')
            for i in range(64):
                z = self.generate_z(batch_size, self.z_dim)
                feed_dict = {model.z: z}

                sample = sess.run(model.fake_images, feed_dict)

                output_name = os.path.join(self.sample_save_path,
                                           '{:03d}.jpg'.format(i))
                print('Saving to {}'.format(output_name))
                grid_plot(sample, 8, output_name)
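
self.generate_z is not shown either; a minimal sketch, assuming it draws uniform noise in [-1, 1] and that numpy is imported as np:

    def generate_z(self, batch_size, z_dim):
        # one latent vector per sample
        return np.random.uniform(-1, 1, size=[batch_size, z_dim]).astype(np.float32)

Example #5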
def main(argv):
    train_log_save_dir = os.path.join('model', 'logs', 'train')
    test_log_save_dir = os.path.join('model', 'logs', 'test')
    os.makedirs(train_log_save_dir, exist_ok=True)
    os.makedirs(test_log_save_dir, exist_ok=True)

    model = CPM(stages=3, joints=16)
    model.build_model()
    model.build_loss(decay_rate=decay_rate,
                     decay_steps=decay_steps,
                     lr=init_lr)
    merged_summary = tf.summary.merge_all()

    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(train_log_save_dir, sess.graph)

        saver = tf.train.Saver(max_to_keep=None)
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        num_iterations = int(data_split_index / BATCH_SIZE) + 1

        # path = 'model/weights/model.ckpt-1'
        # saver.restore(sess, path)
        # begin = int(path.split('-')[1]) + 1
        begin = 0

        print('Starting training...')
        for epoch in range(begin, EPOCHS):
            loss = 0
            count = 0
            for start in range(0, len(file_index[:data_split_index]),
                               BATCH_SIZE):
                end = min(start + BATCH_SIZE, data_split_index)
                x_batch, y_batch = list(), list()
                count += 1

                for index in range(start, end):
                    image_path = ascii2str(filenames[file_index[index]])
                    # copy so the in-place rescaling below does not alter the shared keypoints array
                    person_keypoints = keypoints[index].copy()
                    temp_keypoints = person_keypoints

                    image = cv2.imread(image_path)

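                    # keep only keypoints that are actually annotated, crop a
                    # 100-pixel margin around them, and skip images without any
                    # valid keypoints (the ValueError branch below)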
                    temp_keypoints = temp_keypoints[np.where(
                        (person_keypoints[:, 2] != -1) *
                        (person_keypoints[:, 0] != 0) *
                        (person_keypoints[:, 1] != 0))]
                    try:
                        x_start = max(min(temp_keypoints[:, 1]) - 100, 0)
                        x_end = min(
                            max(temp_keypoints[:, 1]) + 100, image.shape[0])
                        y_start = max(min(temp_keypoints[:, 0]) - 100, 0)
                        y_end = min(
                            max(temp_keypoints[:, 0]) + 100, image.shape[1])
                    except ValueError:
                        continue

                    x_start, x_end, y_start, y_end = map(
                        int, (x_start, x_end, y_start, y_end))

                    person_keypoints[:, 0] = person_keypoints[:, 0] - y_start
                    person_keypoints[:, 1] = person_keypoints[:, 1] - x_start

                    crop_image = image[x_start:x_end, y_start:y_end, :]
                    height, width = crop_image.shape[0], crop_image.shape[1]
                    person_keypoints[:, 0] = person_keypoints[:, 0] * 46.0 / width
                    person_keypoints[:, 1] = person_keypoints[:, 1] * 46.0 / height

                    crop_image = cv2.resize(crop_image,
                                            (IMAGE_WIDTH, IMAGE_HEIGHT))
                    heatmap_image = cv2.resize(crop_image, (46, 46))
                    crop_heatmap = create_heatmap(heatmap_image,
                                                  person_keypoints)
                    x_batch.append(crop_image)
                    y_batch.append(crop_heatmap)
                    if GRID_PLOT:
                        grid_plot(crop_heatmap)
                x_batch = np.array(x_batch) / 255.0 - 0.5
                y_batch = np.array(y_batch)

                fetches = [
                    model.stage_loss, model.total_loss, model.train_op,
                    merged_summary, model.heatmaps, model.global_step,
                    model.learning_rate
                ]
                feed_dict = {
                    model.images: x_batch,
                    model.true_heatmaps: y_batch
                }
                (stage_loss, total_loss, train_op, summaries, heatmaps,
                 global_step, learning_rate) = sess.run(fetches,
                                                        feed_dict=feed_dict)
                train_writer.add_summary(summaries, global_step)
                loss += total_loss
                print('\tEpoch: {} Iteration: {}/{}, Total Average loss: {}'.
                      format(epoch + 1, count, num_iterations, loss / count))
            print('Epoch {} completed.'.format(epoch + 1))
            saver.save(sess=sess,
                       save_path='model/weights/model.ckpt',
                       global_step=epoch + 1)