Exemplo n.º 1
0
    def build_model(self):
        """Build the CycleGAN training graph and run the full training loop.

        The method declares the four image inputs, wraps the two training
        readers in iterable ``PyReader`` objects, constructs the generator
        trainer and the two discriminator trainers, optionally calls
        ``utility.init_checkpoints`` for each of them, and then trains for
        ``self.cfg.epoch`` epochs.  Every batch runs three programs in order:
        generator, discriminator A, discriminator B.  After each epoch it
        optionally calls ``utility.save_test_image`` and
        ``utility.checkpoints``.  In continuous-evaluation (CE) mode it also
        prints ``kpis`` lines once training has finished.

        Reads from ``self``: ``cfg``, ``batch_num``, ``A_reader``,
        ``B_reader``, ``A_test_reader``, ``B_test_reader``, ``A_id2name`` and
        ``B_id2name``.

        Returns:
            None.  Progress is reported through ``print``.
        """
        # The first dimension is left unspecified (None); the remaining
        # dimensions are 3 x crop_size x crop_size.
        data_shape = [None, 3, self.cfg.crop_size, self.cfg.crop_size]

        input_A = fluid.data(name='input_A', shape=data_shape, dtype='float32')
        input_B = fluid.data(name='input_B', shape=data_shape, dtype='float32')
        fake_pool_A = fluid.data(
            name='fake_pool_A', shape=data_shape, dtype='float32')
        fake_pool_B = fluid.data(
            name='fake_pool_B', shape=data_shape, dtype='float32')
        # Continuous-evaluation mode: pin the startup program's random seed.
        if self.cfg.enable_ce:
            fluid.default_startup_program().random_seed = 90

        # Iterable readers that feed input_A and input_B respectively.
        A_py_reader = fluid.io.PyReader(
            feed_list=[input_A],
            capacity=4,
            iterable=True,
            use_double_buffer=True)

        B_py_reader = fluid.io.PyReader(
            feed_list=[input_B],
            capacity=4,
            iterable=True,
            use_double_buffer=True)

        # Note the pairing: the "A" discriminator trainer is built from
        # input_B / fake_pool_B, and the "B" one from input_A / fake_pool_A.
        gen_trainer = GTrainer(input_A, input_B, self.cfg, self.batch_num)
        d_A_trainer = DATrainer(input_B, fake_pool_B, self.cfg, self.batch_num)
        d_B_trainer = DBTrainer(input_A, fake_pool_A, self.cfg, self.batch_num)

        # prepare environment
        # The executor itself is bound to GPU 0 or to the CPU.
        place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()

        # Attach the batch generators; the readers are given whatever
        # fluid.cuda_places() / fluid.cpu_places() return.
        A_py_reader.decorate_batch_generator(
            self.A_reader,
            places=fluid.cuda_places()
            if self.cfg.use_gpu else fluid.cpu_places())
        B_py_reader.decorate_batch_generator(
            self.B_reader,
            places=fluid.cuda_places()
            if self.cfg.use_gpu else fluid.cpu_places())

        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        # One pool per domain; generator outputs pass through these before
        # being fed to the discriminators (see the batch loop below).
        A_pool = utility.ImagePool()
        B_pool = utility.ImagePool()

        # Presumably restores previously saved parameters for each trainer --
        # confirm against utility.init_checkpoints.
        if self.cfg.init_model:
            utility.init_checkpoints(self.cfg, exe, gen_trainer, "net_G")
            utility.init_checkpoints(self.cfg, exe, d_A_trainer, "net_DA")
            utility.init_checkpoints(self.cfg, exe, d_B_trainer, "net_DB")

        # Build strategy shared by all three compiled programs; in-place
        # execution is enabled.
        build_strategy = fluid.BuildStrategy()
        build_strategy.enable_inplace = True

        # Compile each trainer's program for data-parallel execution, keyed
        # on that trainer's own loss variable.
        gen_trainer_program = fluid.CompiledProgram(
            gen_trainer.program).with_data_parallel(
                loss_name=gen_trainer.g_loss.name,
                build_strategy=build_strategy)
        d_A_trainer_program = fluid.CompiledProgram(
            d_A_trainer.program).with_data_parallel(
                loss_name=d_A_trainer.d_loss_A.name,
                build_strategy=build_strategy)
        d_B_trainer_program = fluid.CompiledProgram(
            d_B_trainer.program).with_data_parallel(
                loss_name=d_B_trainer.d_loss_B.name,
                build_strategy=build_strategy)

        # Running total of per-batch times.  It is accumulated below but not
        # read anywhere within this method.
        t_time = 0

        for epoch_id in range(self.cfg.epoch):
            batch_id = 0
            # zip() ends the epoch as soon as either reader is exhausted.
            for data_A, data_B in zip(A_py_reader(), B_py_reader()):
                s_time = time.time()
                # Only element [0] of each reader item is used.
                # NOTE(review): presumably that is the feed dict for the first
                # place and entries for any further places are dropped --
                # confirm for multi-device runs.
                tensor_A, tensor_B = data_A[0]['input_A'], data_B[0]['input_B']
                # Generator step: run the generator program and fetch its six
                # loss values plus fake_A and fake_B.
                g_A_loss, g_A_cyc_loss, g_A_idt_loss, g_B_loss, g_B_cyc_loss,\
                g_B_idt_loss, fake_A_tmp, fake_B_tmp = exe.run(
                    gen_trainer_program,
                    fetch_list=[
                        gen_trainer.G_A, gen_trainer.cyc_A_loss,
                        gen_trainer.idt_loss_A, gen_trainer.G_B,
                        gen_trainer.cyc_B_loss, gen_trainer.idt_loss_B,
                        gen_trainer.fake_A, gen_trainer.fake_B
                    ],
                    feed={"input_A": tensor_A,
                          "input_B": tensor_B})

                # Pass the fresh generator outputs through the pools.  These
                # assignments rebind the local names fake_pool_A / fake_pool_B,
                # which until now referred to the graph inputs declared at the
                # top of the method.
                fake_pool_B = B_pool.pool_image(fake_B_tmp)
                fake_pool_A = A_pool.pool_image(fake_A_tmp)

                # In CE mode the pooled values are discarded and the raw
                # generator outputs are fed instead (pool_image has still been
                # called above).
                if self.cfg.enable_ce:
                    fake_pool_B = fake_B_tmp
                    fake_pool_A = fake_A_tmp

                # optimize the d_A network
                # [0] selects the single fetched value (d_loss_A).
                d_A_loss = exe.run(
                    d_A_trainer_program,
                    fetch_list=[d_A_trainer.d_loss_A],
                    feed={"input_B": tensor_B,
                          "fake_pool_B": fake_pool_B})[0]

                # optimize the d_B network
                # [0] selects the single fetched value (d_loss_B).
                d_B_loss = exe.run(
                    d_B_trainer_program,
                    fetch_list=[d_B_trainer.d_loss_B],
                    feed={"input_A": tensor_A,
                          "fake_pool_A": fake_pool_A})[0]

                # batch_time covers the three exe.run calls and the pool
                # updates, measured from s_time above.
                batch_time = time.time() - s_time
                t_time += batch_time
                # Log every cfg.print_freq batches, starting with batch 0.
                if batch_id % self.cfg.print_freq == 0:
                    print("epoch{}: batch{}: \n\
                         d_A_loss: {}; g_A_loss: {}; g_A_cyc_loss: {}; g_A_idt_loss: {}; \n\
                         d_B_loss: {}; g_B_loss: {}; g_B_cyc_loss: {}; g_B_idt_loss: {}; \n\
                         Batch_time_cost: {}".format(
                        epoch_id, batch_id, d_A_loss[0], g_A_loss[0],
                        g_A_cyc_loss[0], g_A_idt_loss[0], d_B_loss[0], g_B_loss[
                            0], g_B_cyc_loss[0], g_B_idt_loss[0], batch_time))

                # stdout is flushed after every batch, whether or not a line
                # was printed.
                sys.stdout.flush()
                batch_id += 1
                # used for continuous evaluation
                # batch_id was just incremented, so this stops each epoch
                # after 10 batches.
                if self.cfg.enable_ce and batch_id == 10:
                    break

            # NOTE(review): the name inputs and both test readers are created
            # again on every epoch.  They look loop-invariant and could
            # probably be built once before the epoch loop -- confirm that
            # doing so does not affect the trainers' programs.
            if self.cfg.run_test:
                A_image_name = fluid.data(
                    name='A_image_name', shape=[None, 1], dtype='int32')
                B_image_name = fluid.data(
                    name='B_image_name', shape=[None, 1], dtype='int32')
                # Test readers feed an image input together with its int32
                # name id.
                A_test_py_reader = fluid.io.PyReader(
                    feed_list=[input_A, A_image_name],
                    capacity=4,
                    iterable=True,
                    use_double_buffer=True)

                B_test_py_reader = fluid.io.PyReader(
                    feed_list=[input_B, B_image_name],
                    capacity=4,
                    iterable=True,
                    use_double_buffer=True)

                A_test_py_reader.decorate_batch_generator(
                    self.A_test_reader,
                    places=fluid.cuda_places()
                    if self.cfg.use_gpu else fluid.cpu_places())
                B_test_py_reader.decorate_batch_generator(
                    self.B_test_reader,
                    places=fluid.cuda_places()
                    if self.cfg.use_gpu else fluid.cpu_places())
                # The generator trainer's inference program is used for the
                # test pass.
                test_program = gen_trainer.infer_program
                utility.save_test_image(
                    epoch_id,
                    self.cfg,
                    exe,
                    place,
                    test_program,
                    gen_trainer,
                    A_test_py_reader,
                    B_test_py_reader,
                    A_id2name=self.A_id2name,
                    B_id2name=self.B_id2name)

            # Per-epoch checkpoint call for each of the three trainers.
            if self.cfg.save_checkpoints:
                utility.checkpoints(epoch_id, self.cfg, exe, gen_trainer,
                                    "net_G")
                utility.checkpoints(epoch_id, self.cfg, exe, d_A_trainer,
                                    "net_DA")
                utility.checkpoints(epoch_id, self.cfg, exe, d_B_trainer,
                                    "net_DB")

        # used for continuous evaluation
        # Prints the values from the last batch that ran.
        # NOTE(review): the loss names and batch_time are bound only inside
        # the batch loop, so they would be undefined here if no batch ran.
        if self.cfg.enable_ce:
            device_num = fluid.core.get_cuda_device_count(
            ) if self.cfg.use_gpu else 1
            print("kpis\tcyclegan_g_A_loss_card{}\t{}".format(device_num,
                                                              g_A_loss[0]))
            print("kpis\tcyclegan_g_A_cyc_loss_card{}\t{}".format(
                device_num, g_A_cyc_loss[0]))
            print("kpis\tcyclegan_g_A_idt_loss_card{}\t{}".format(
                device_num, g_A_idt_loss[0]))
            print("kpis\tcyclegan_d_A_loss_card{}\t{}".format(device_num,
                                                              d_A_loss[0]))
            print("kpis\tcyclegan_g_B_loss_card{}\t{}".format(device_num,
                                                              g_B_loss[0]))
            print("kpis\tcyclegan_g_B_cyc_loss_card{}\t{}".format(
                device_num, g_B_cyc_loss[0]))
            print("kpis\tcyclegan_g_B_idt_loss_card{}\t{}".format(
                device_num, g_B_idt_loss[0]))
            print("kpis\tcyclegan_d_B_loss_card{}\t{}".format(device_num,
                                                              d_B_loss[0]))
            print("kpis\tcyclegan_Batch_time_cost_card{}\t{}".format(
                device_num, batch_time))
Exemplo n.º 2
0
    def build_model(self):
        """Build the generator/discriminator training graph and train it.

        The method declares the four image inputs, constructs the generator
        trainer and the two discriminator trainers, optionally calls
        ``utility.init_checkpoints`` for each of them, and then trains for
        ``self.cfg.epoch`` epochs of ``self.batch_num`` batches.  Every batch
        runs three programs in order: generator, discriminator A,
        discriminator B.  After each epoch it optionally calls
        ``utility.save_test_image`` and ``utility.checkpoints``.

        Reads from ``self``: ``cfg``, ``batch_num``, ``A_reader``,
        ``B_reader``, ``A_test_reader`` and ``B_test_reader``.

        Returns:
            None.  Progress is reported through ``print``.

        Raises:
            StopIteration: if ``self.A_reader()`` or ``self.B_reader()``
                yields no batch at all.
        """
        # The first dimension is left open (-1); the remaining dimensions are
        # 3 x crop_size x crop_size.
        data_shape = [-1, 3, self.cfg.crop_size, self.cfg.crop_size]

        input_A = fluid.layers.data(name='input_A',
                                    shape=data_shape,
                                    dtype='float32')
        input_B = fluid.layers.data(name='input_B',
                                    shape=data_shape,
                                    dtype='float32')
        fake_pool_A = fluid.layers.data(name='fake_pool_A',
                                        shape=data_shape,
                                        dtype='float32')
        fake_pool_B = fluid.layers.data(name='fake_pool_B',
                                        shape=data_shape,
                                        dtype='float32')

        # Note the pairing: the "A" discriminator trainer is built from
        # input_B / fake_pool_B, and the "B" one from input_A / fake_pool_A.
        gen_trainer = GTrainer(input_A, input_B, self.cfg, self.batch_num)
        d_A_trainer = DATrainer(input_B, fake_pool_B, self.cfg, self.batch_num)
        d_B_trainer = DBTrainer(input_A, fake_pool_A, self.cfg, self.batch_num)

        # prepare environment
        place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        # One pool per domain; generator outputs pass through these before
        # being fed to the discriminators.
        A_pool = utility.ImagePool()
        B_pool = utility.ImagePool()

        if self.cfg.init_model:
            utility.init_checkpoints(self.cfg, exe, gen_trainer, "net_G")
            utility.init_checkpoints(self.cfg, exe, d_A_trainer, "net_DA")
            utility.init_checkpoints(self.cfg, exe, d_B_trainer, "net_DB")

        # Build strategy shared by the three compiled programs; both in-place
        # execution and memory optimisation are switched off.
        build_strategy = fluid.BuildStrategy()
        build_strategy.enable_inplace = False
        build_strategy.memory_optimize = False

        gen_trainer_program = fluid.CompiledProgram(
            gen_trainer.program).with_data_parallel(
                loss_name=gen_trainer.g_loss.name,
                build_strategy=build_strategy)
        d_A_trainer_program = fluid.CompiledProgram(
            d_A_trainer.program).with_data_parallel(
                loss_name=d_A_trainer.d_loss_A.name,
                build_strategy=build_strategy)
        d_B_trainer_program = fluid.CompiledProgram(
            d_B_trainer.program).with_data_parallel(
                loss_name=d_B_trainer.d_loss_B.name,
                build_strategy=build_strategy)

        def batch_stream(reader):
            """Yield batches from ``reader()``, restarting it when exhausted.

            The stream ends only if a fresh ``reader()`` yields nothing, so
            ``next()`` on it raises ``StopIteration`` for an empty reader.
            """
            while True:
                produced = False
                for batch in reader():
                    produced = True
                    yield batch
                if not produced:
                    # An empty reader would otherwise make this loop spin
                    # forever.
                    return

        # Keep one iterator per domain alive across batches and epochs.
        # Calling the reader anew for every batch and taking only its first
        # item would never advance past that first batch whenever the reader
        # returns a fresh iterator on each call.
        stream_A = batch_stream(self.A_reader)
        stream_B = batch_stream(self.B_reader)

        for epoch_id in range(self.cfg.epoch):
            for batch_id in range(self.batch_num):
                data_A = next(stream_A)
                data_B = next(stream_B)
                tensor_A = fluid.LoDTensor()
                tensor_B = fluid.LoDTensor()
                tensor_A.set(data_A, place)
                tensor_B.set(data_B, place)
                # Timing starts after the batch has been loaded, so
                # batch_time below excludes data loading.
                s_time = time.time()
                # Generator step: run the generator program and fetch its six
                # loss values plus fake_A and fake_B.
                g_A_loss, g_A_cyc_loss, g_A_idt_loss, g_B_loss, g_B_cyc_loss,\
                g_B_idt_loss, fake_A_tmp, fake_B_tmp = exe.run(
                    gen_trainer_program,
                    fetch_list=[
                        gen_trainer.G_A, gen_trainer.cyc_A_loss,
                        gen_trainer.idt_loss_A, gen_trainer.G_B,
                        gen_trainer.cyc_B_loss, gen_trainer.idt_loss_B,
                        gen_trainer.fake_A, gen_trainer.fake_B
                    ],
                    feed={"input_A": tensor_A,
                          "input_B": tensor_B})

                # Pass the fresh generator outputs through the pools.  Separate
                # local names keep them distinct from the graph inputs
                # fake_pool_A / fake_pool_B declared above.
                pooled_B = B_pool.pool_image(fake_B_tmp)
                pooled_A = A_pool.pool_image(fake_A_tmp)

                # optimize the d_A network
                # [0] selects the single fetched value (d_loss_A).
                d_A_loss = exe.run(d_A_trainer_program,
                                   fetch_list=[d_A_trainer.d_loss_A],
                                   feed={
                                       "input_B": tensor_B,
                                       "fake_pool_B": pooled_B
                                   })[0]

                # optimize the d_B network
                # [0] selects the single fetched value (d_loss_B).
                d_B_loss = exe.run(d_B_trainer_program,
                                   fetch_list=[d_B_trainer.d_loss_B],
                                   feed={
                                       "input_A": tensor_A,
                                       "fake_pool_A": pooled_A
                                   })[0]

                batch_time = time.time() - s_time
                # Log every cfg.print_freq batches, starting with batch 0.
                if batch_id % self.cfg.print_freq == 0:
                    print("epoch{}: batch{}: \n\
                         d_A_loss: {}; g_A_loss: {}; g_A_cyc_loss: {}; g_A_idt_loss: {}; \n\
                         d_B_loss: {}; g_B_loss: {}; g_B_cyc_loss: {}; g_B_idt_loss: {}; \n\
                         Batch_time_cost: {:.2f}".format(
                        epoch_id, batch_id, d_A_loss[0], g_A_loss[0],
                        g_A_cyc_loss[0], g_A_idt_loss[0], d_B_loss[0],
                        g_B_loss[0], g_B_cyc_loss[0], g_B_idt_loss[0],
                        batch_time))

                sys.stdout.flush()

            if self.cfg.run_test:
                # The generator trainer's inference program is used for the
                # test pass.
                test_program = gen_trainer.infer_program
                utility.save_test_image(epoch_id, self.cfg, exe, place,
                                        test_program, gen_trainer,
                                        self.A_test_reader, self.B_test_reader)

            if self.cfg.save_checkpoints:
                utility.checkpoints(epoch_id, self.cfg, exe, gen_trainer,
                                    "net_G")
                utility.checkpoints(epoch_id, self.cfg, exe, d_A_trainer,
                                    "net_DA")
                utility.checkpoints(epoch_id, self.cfg, exe, d_B_trainer,
                                    "net_DB")