Exemplo n.º 1
0
def _learning_rate_for_epoch(epoch, base_lr=2e-4, constant_epochs=100,
                             decay_epochs=100):
    """Return the learning rate for the 0-based *epoch*.

    The rate is *base_lr* for the first *constant_epochs* epochs, then decays
    linearly towards zero over the next *decay_epochs* epochs. It is clamped
    at 0.0 so that training past ``constant_epochs + decay_epochs`` never
    feeds a negative rate, which would presumably invert the optimizer's
    update direction.
    """
    if epoch < constant_epochs:
        return base_lr
    decayed = base_lr - base_lr * (epoch - constant_epochs) / decay_epochs
    return max(0.0, decayed)


def main(arguments):
    """Build the A/B image-translation model and run the training loop.

    Args:
        arguments: parsed command-line options. Reads ``epochs`` (number of
            epochs to train), ``gpu`` (``1`` to create a session with GPU
            options) and ``gpu_number`` (exported as ``CUDA_VISIBLE_DEVICES``).

    Summaries are written to ``LOG_DIR`` every ``SUMMARY_EVERY`` steps and a
    checkpoint is saved through ``save_model`` every ``SAVE_STEP`` epochs.
    """
    EPOCHS = arguments.epochs
    GPU = arguments.gpu
    GPU_NUMBER = arguments.gpu_number

    DATA_PATH = 'data/horse2zebra/'
    # TODO change number of steps: iterations per epoch are hard-coded.
    STEPS_PER_EPOCH = 533
    SUMMARY_EVERY = 100

    tf.reset_default_graph()

    if GPU == 1:
        os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(GPU_NUMBER)
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

        config = tf.ConfigProto(log_device_placement=True)
        config.gpu_options.per_process_gpu_memory_fraction = 0.5  # pylint: disable=no-member
        config.gpu_options.allow_growth = True  # pylint: disable=no-member
        sess = tf.Session(config=config)
    else:
        sess = tf.Session()

    it_a, train_A = Images(DATA_PATH + '_trainA.tfrecords', name='trainA').feed()
    it_b, train_B = Images(DATA_PATH + '_trainB.tfrecords', name='trainB').feed()

    # Generated images are fed back in through these placeholders after
    # passing through the ImageCache pools below.
    gen_a_sample = tf.placeholder(tf.float32, [None, WIDTH, HEIGHT, CHANNEL], name="fake_a_sample")
    gen_b_sample = tf.placeholder(tf.float32, [None, WIDTH, HEIGHT, CHANNEL], name="fake_b_sample")
    learning_rate = tf.placeholder(tf.float32, shape=[], name="lr")

    (d_a_train_op, d_b_train_op, g_a_train_op, g_b_train_op,
     g1, g2) = build_model(train_A, train_B, gen_a_sample, gen_b_sample,
                           learning_rate)

    merged = tf.summary.merge_all()

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with sess:
        sess.run(init)
        writer = tf.summary.FileWriter(LOG_DIR, tf.get_default_graph())
        try:
            cache_a = ImageCache(50)
            cache_b = ImageCache(50)

            print('Beginning training...')
            start = time.perf_counter()
            for epoch in range(EPOCHS):
                # Presumably re-initialises both input iterators for this
                # epoch (it_a/it_b come from Images.feed()) - confirm.
                sess.run(it_a)
                sess.run(it_b)
                lr = _learning_rate_for_epoch(epoch)
                try:
                    for step in tqdm(range(STEPS_PER_EPOCH)):
                        gen_a, gen_b = sess.run([g1, g2])

                        _, _, _, _, summaries = sess.run(
                            [d_b_train_op, d_a_train_op,
                             g_a_train_op, g_b_train_op, merged],
                            feed_dict={gen_b_sample: cache_b.fetch(gen_b),
                                       gen_a_sample: cache_a.fetch(gen_a),
                                       learning_rate: lr})
                        if step % SUMMARY_EVERY == 0:
                            writer.add_summary(summaries, epoch * STEPS_PER_EPOCH + step)
                except tf.errors.OutOfRangeError as e:
                    # An input iterator was exhausted before STEPS_PER_EPOCH
                    # steps; finish this epoch early and move on.
                    print(e)
                    print("Out of range: {}".format(step))

                print("Epoch {}/{} done.".format(epoch + 1, EPOCHS))

                counter = epoch + 1
                if counter % SAVE_STEP == 0:
                    save_path = save_model(saver, sess, counter)
                    print('Running for {:.2f} seconds, saving to {}'.format(
                        time.perf_counter() - start, save_path))
        finally:
            # Make sure buffered summary events reach disk even when training
            # is interrupted.
            writer.close()
Exemplo n.º 2
0
class Handler:
    """GTK signal handler for paging through pairs of similar images.

    ``scores`` is a list of entries whose items ``[0]``, ``[1]`` and ``[2]``
    are used as the distance, the left file path and the right file path;
    one entry is shown per page. The builder must provide the widgets
    ``textview1``, ``textview2``, ``textview3``, ``image1``, ``image2`` and
    ``window1``.
    """

    def __init__(self, scores, builder, page_entry):
        self.page_entry = page_entry
        self.page_num = 0
        self.show_page_number()
        self.page_count = len(scores)
        self.scores = scores
        self.builder = builder
        self.cache = ImageCache(num_workers=16, max_cache_size=40)
        self.update_page()

    def cache_nearby(self, pagenum):
        """Ask the image cache to preload the files of pages near *pagenum*.

        Pages ``pagenum .. pagenum + 9`` are queued first, followed by the
        five pages before *pagenum*. Page indices outside ``scores`` are
        skipped.
        """
        nearby = list(range(pagenum, pagenum + 10)) + list(range(pagenum - 5, pagenum))
        filenames = []
        for i in nearby:
            # Index 0 is a valid page and must be preloaded too.
            if 0 <= i < len(self.scores):
                filenames.append(self.scores[i][1])
                filenames.append(self.scores[i][2])

        self.cache.preload(filenames)

    def get_image(self, pagenum):
        """Return the ``(left, right)`` image records for page *pagenum*.

        Each record is the dict returned by ``self.cache.fetch``; a
        ``'pixbuf'`` entry converted from its ``'canvas'`` is added to it.
        Exceptions from ``fetch`` propagate (``update_page`` handles
        ``FileNotFoundError``).
        """
        entry = self.scores[pagenum]
        left_image = self.cache.fetch(entry[1])
        right_image = self.cache.fetch(entry[2])
        for image in (left_image, right_image):
            image['pixbuf'] = pil_to_pixbuf(image['canvas'])
        return left_image, right_image

    def _show_image_info(self, textview_id, filename, image, other_image):
        """Write *image*'s metadata into the text view named *textview_id*.

        The text ends with "(Larger)" when *image* has a strictly larger file
        size than *other_image*.
        """
        description = "%s\nWidth: %d\nHeight: %d\nFilesize: %d\n" % (
            filename, image["width"], image["height"], image["filesize"])
        if image["filesize"] > other_image["filesize"]:
            description = description + "(Larger)\n"

        text_buffer = Gtk.TextBuffer()
        text_buffer.set_text(description)
        self.builder.get_object(textview_id).set_buffer(text_buffer)

    def update_page(self):
        """Redraw the distance label, both images and both metadata panes."""
        textview3 = self.builder.get_object("textview3")
        if not self.scores:
            description = Gtk.TextBuffer()
            description.set_text("No similar pairs of images found")
            textview3.set_buffer(description)
            return

        entry = self.scores[self.page_num]

        description = Gtk.TextBuffer()
        description.set_text("Distance between images: %.2f" % (entry[0]))
        textview3.set_property("justification", Gtk.Justification.CENTER)
        textview3.set_buffer(description)

        try:
            left_image, right_image = self.get_image(self.page_num)
        except FileNotFoundError as e:
            print(e)
            return
        self.builder.get_object("image2").set_from_pixbuf(right_image['pixbuf'])
        self.builder.get_object("image1").set_from_pixbuf(left_image['pixbuf'])

        self.cache_nearby(self.page_num)

        self._show_image_info("textview2", entry[2], right_image, left_image)
        self._show_image_info("textview1", entry[1], left_image, right_image)

    def show_page_number(self):
        """Display the current page as a 1-based number in the page entry."""
        new_page_str = str(self.page_num + 1)
        self.page_entry.get_buffer().set_text(new_page_str, len(new_page_str))

    def set_page_number(self, num):
        """Move to page *num*, wrapping around at either end, and redraw."""
        if self.page_count == 0:
            # No pages: keep a valid default index instead of -1.
            self.page_num = 0
        elif num >= self.page_count:
            self.page_num = 0
        elif num < 0:
            self.page_num = self.page_count - 1
        else:
            self.page_num = num

        self.update_page()

    def cancel_deletion(self):
        """Ask the user to confirm a deletion.

        Returns:
            False only when the dialog answers ``Gtk.ResponseType.OK``; True
            for any other response (Cancel, or the dialog being closed), so
            that a deletion never happens without an explicit OK.
        """
        window = self.builder.get_object("window1")
        dialog = DialogExample(window)
        try:
            response = dialog.run()
        finally:
            dialog.destroy()

        return response != Gtk.ResponseType.OK

    def _delete_file(self, column):
        """Delete the file in *column* (1 = left, 2 = right) of the current page.

        After confirmation, every entry that references the file is dropped
        from ``scores``, the view is refreshed, and the file is removed from
        disk. A file that is already missing is reported and ignored.
        """
        if not self.scores:
            return
        #TODO: Disable navigation until this finishes to prevent races
        if self.cancel_deletion():
            return

        doomed = self.scores[self.page_num][column]
        self.scores = [res for res in self.scores
                       if res[1] != doomed and res[2] != doomed]
        self.page_count = len(self.scores)

        self.set_page_number(self.page_num)
        self.show_page_number()
        try:
            os.remove(doomed)
        except FileNotFoundError as e:
            print(e)

    def onDeleteRight(self, *args):
        """Delete the right-hand file of the current pair after confirmation."""
        self._delete_file(2)

    def onDeleteLeft(self, *args):
        """Delete the left-hand file of the current pair after confirmation."""
        self._delete_file(1)

    def onDeleteWindow(self, *args):
        """Stop the image cache and quit the GTK main loop."""
        self.cache.quit()
        Gtk.main_quit(*args)

    def onLeftClicked(self, *args):
        """Go to the previous page (wrapping to the last one)."""
        self.set_page_number(self.page_num - 1)
        self.show_page_number()

    def onRightClicked(self, *args):
        """Go to the next page (wrapping to the first one)."""
        self.set_page_number(self.page_num + 1)
        self.show_page_number()

    def onIgnoreClicked(self, *args):
        # Not implemented yet.
        pass

    def onDeleteBothClicked(self, *args):
        # Not implemented yet.
        pass

    def page_num_edited(self, *args):
        """Jump to the 1-based page number typed into the page entry.

        Text that is not an integer is ignored. Out-of-range numbers are
        clamped and the clamped value is written back into the entry.
        """
        if self.page_count == 0:
            # Nothing to navigate; also avoids clamping back and forth
            # between "0" and "1" when there are no pages.
            return
        try:
            requested = int(self.page_entry.get_text()) - 1
        except ValueError:
            # e.g. the entry is empty or still being typed.
            return

        if requested >= self.page_count:
            self.page_num = self.page_count - 1
            # NOTE(review): update_page() is not called on the clamped paths;
            # this relies on set_text() re-emitting the entry's signal - confirm.
            self.page_entry.set_text(str(self.page_count))
        elif requested < 0:
            self.page_num = 0
            self.page_entry.set_text(str(1))
        else:
            self.page_num = requested
            self.update_page()