예제 #1
0
def infer(args):
    """Build the generator selected by ``args.model_net`` and run inference.

    The function declares the input placeholders, constructs the generator
    network for the requested model, loads persistable variables from
    ``os.path.join(args.init_model, 'net_G')`` and then runs the generator,
    writing the resulting images under ``args.output`` (the directory is
    created when it does not exist).

    Args:
        args: configuration namespace. Attributes read directly here are
            ``image_size``, ``c_dim``, ``n_samples``, ``model_net``,
            ``use_gpu``, ``init_model``, ``output`` and ``selected_attrs``;
            depending on the model also ``input_style``, ``noise_size``,
            ``dataset_dir`` and ``test_list``. ``args`` is additionally
            forwarded to the network builders and reader creators.

    Raises:
        NotImplementedError: if ``args.model_net`` is not one of the handled
            models, or if ``args.input_style`` is neither ``"A"`` nor ``"B"``
            when the model is CycleGAN.
    """

    def _make_py_reader(feed_list, capacity):
        # Every reader in this function uses the same iterable,
        # double-buffered configuration; only feed list and capacity vary.
        return fluid.io.PyReader(feed_list=feed_list,
                                 capacity=capacity,
                                 iterable=True,
                                 use_double_buffer=True)

    def _reader_places():
        return fluid.cuda_places() if args.use_gpu else fluid.cpu_places()

    def _compute_start_end(image_name):
        # Logs the id range covered by the batch and returns the file name
        # derived from the first id.
        image_name_start = np.array(image_name)[0].astype('int32')
        image_name_end = image_name_start + args.n_samples - 1
        print("read {}.jpg ~ {}.jpg".format(image_name_start, image_name_end))
        return str(image_name_start) + '.jpg'

    def _rescale(labels):
        # Maps each label row x to ((x * 2) - 1) * 0.5.
        return [((x * 2) - 1) * 0.5 for x in labels]

    data_shape = [-1, 3, args.image_size, args.image_size]
    input_img = fluid.layers.data(name='input',
                                  shape=data_shape,
                                  dtype='float32')
    label_org_ = fluid.layers.data(name='label_org_',
                                   shape=[args.c_dim],
                                   dtype='float32')
    label_trg_ = fluid.layers.data(name='label_trg_',
                                   shape=[args.c_dim],
                                   dtype='float32')
    image_name = fluid.layers.data(name='image_name',
                                   shape=[args.n_samples],
                                   dtype='int32')

    model_name = 'net_G'
    if args.model_net == 'CycleGAN':
        py_reader = _make_py_reader([input_img, image_name], capacity=4)
        from network.CycleGAN_network import CycleGAN_model
        model = CycleGAN_model()
        if args.input_style == "A":
            fake = model.network_G(input_img, name="GA", cfg=args)
        elif args.input_style == "B":
            fake = model.network_G(input_img, name="GB", cfg=args)
        else:
            # Raising a bare string is itself a TypeError in Python 3, which
            # hid the intended message; raise a real exception instead.
            raise NotImplementedError(
                "Input with style [%s] is not supported." % args.input_style)
    elif args.model_net == 'Pix2pix':
        py_reader = _make_py_reader([input_img, image_name], capacity=4)

        from network.Pix2pix_network import Pix2pix_model
        model = Pix2pix_model()
        fake = model.network_G(input_img, "generator", cfg=args)
    elif args.model_net == 'StarGAN':
        py_reader = _make_py_reader(
            [input_img, label_org_, label_trg_, image_name], capacity=32)

        from network.StarGAN_network import StarGAN_model
        model = StarGAN_model()
        fake = model.network_G(input_img,
                               label_trg_,
                               name="g_main",
                               cfg=args)
    elif args.model_net == 'STGAN':
        from network.STGAN_network import STGAN_model

        py_reader = _make_py_reader(
            [input_img, label_org_, label_trg_, image_name], capacity=32)

        model = STGAN_model()
        fake, _ = model.network_G(input_img,
                                  label_org_,
                                  label_trg_,
                                  cfg=args,
                                  name='generator',
                                  is_test=True)
    elif args.model_net == 'AttGAN':
        from network.AttGAN_network import AttGAN_model

        py_reader = _make_py_reader(
            [input_img, label_org_, label_trg_, image_name], capacity=32)

        model = AttGAN_model()
        fake, _ = model.network_G(input_img,
                                  label_org_,
                                  label_trg_,
                                  cfg=args,
                                  name='generator',
                                  is_test=True)
    elif args.model_net == 'CGAN':
        noise = fluid.layers.data(name='noise',
                                  shape=[args.noise_size],
                                  dtype='float32')
        conditions = fluid.layers.data(name='conditions',
                                       shape=[1],
                                       dtype='float32')

        from network.CGAN_network import CGAN_model
        model = CGAN_model(args.n_samples)
        fake = model.network_G(noise, conditions, name="G")
    elif args.model_net == 'DCGAN':
        noise = fluid.layers.data(name='noise',
                                  shape=[args.noise_size],
                                  dtype='float32')

        from network.DCGAN_network import DCGAN_model
        model = DCGAN_model(args.n_samples)
        fake = model.network_G(noise, name="G")
    else:
        raise NotImplementedError("model_net {} is not support".format(
            args.model_net))

    # prepare environment
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    for var in fluid.default_main_program().global_block().all_parameters():
        print(var.name)
    print(args.init_model + '/' + model_name)
    fluid.io.load_persistables(exe, os.path.join(args.init_model, model_name))
    print('load params done')
    if not os.path.exists(args.output):
        os.makedirs(args.output)

    attr_names = args.selected_attrs.split(',')

    if args.model_net == 'AttGAN' or args.model_net == 'STGAN':
        test_reader = celeba_reader_creator(image_dir=args.dataset_dir,
                                            list_filename=args.test_list,
                                            args=args,
                                            mode="VAL")
        reader_test = test_reader.make_reader(return_name=True)
        py_reader.decorate_batch_generator(reader_test,
                                           places=_reader_places())
        for data in py_reader():
            batch = data[0]
            real_img = batch['input']
            label_org = batch['label_org_']
            label_trg = batch['label_trg_']
            image_name_save = _compute_start_end(batch['image_name'])

            label_org_np = np.array(label_org)
            # The rescaled source labels do not depend on the attribute
            # being flipped, so compute them once per batch.
            label_org_tmp = _rescale(label_org_np)

            images = [save_batch_image(np.array(real_img))]
            # One generated image column per attribute: flip attribute i.
            for i in range(args.c_dim):
                label_trg_tmp = copy.deepcopy(np.array(label_trg))
                for j in range(len(label_trg_tmp)):
                    label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
                    label_trg_tmp = check_attribute_conflict(
                        label_trg_tmp, attr_names[i], attr_names)
                label_trg_tmp = _rescale(label_trg_tmp)
                if args.model_net == 'AttGAN':
                    # AttGAN doubles the value of the flipped attribute.
                    for k in range(len(label_trg_tmp)):
                        label_trg_tmp[k][i] = label_trg_tmp[k][i] * 2.0
                tensor_label_org_ = fluid.LoDTensor()
                tensor_label_trg_ = fluid.LoDTensor()
                tensor_label_org_.set(label_org_tmp, place)
                tensor_label_trg_.set(label_trg_tmp, place)
                out = exe.run(feed={
                    "input": real_img,
                    "label_org_": tensor_label_org_,
                    "label_trg_": tensor_label_trg_
                },
                              fetch_list=[fake.name])
                images.append(save_batch_image(out[0]))
            images_concat = np.concatenate(images, 1)
            if len(label_org_np) > 1:
                images_concat = np.concatenate(images_concat, 1)
            imageio.imwrite(
                os.path.join(args.output, "fake_img_" + image_name_save),
                ((images_concat + 1) * 127.5).astype(np.uint8))
    elif args.model_net == 'StarGAN':
        test_reader = celeba_reader_creator(image_dir=args.dataset_dir,
                                            list_filename=args.test_list,
                                            args=args,
                                            mode="VAL")
        reader_test = test_reader.make_reader(return_name=True)
        py_reader.decorate_batch_generator(reader_test,
                                           places=_reader_places())
        for data in py_reader():
            batch = data[0]
            real_img = batch['input']
            label_org = batch['label_org_']
            image_name_save = _compute_start_end(batch['image_name'])

            label_org_np = np.array(label_org)
            images = [save_batch_image(np.array(real_img))]
            # Target labels start from the source labels with attribute i
            # flipped.
            for i in range(args.c_dim):
                label_trg_tmp = copy.deepcopy(label_org_np)
                for j in range(len(label_org_np)):
                    label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
                    label_trg_tmp = check_attribute_conflict(
                        label_trg_tmp, attr_names[i], attr_names)
                tensor_label_trg_ = fluid.LoDTensor()
                tensor_label_trg_.set(label_trg_tmp, place)
                out = exe.run(feed={
                    "input": real_img,
                    "label_trg_": tensor_label_trg_
                },
                              fetch_list=[fake.name])
                images.append(save_batch_image(out[0]))
            images_concat = np.concatenate(images, 1)
            if len(label_org_np) > 1:
                images_concat = np.concatenate(images_concat, 1)
            imageio.imwrite(
                os.path.join(args.output, "fake_img_" + image_name_save),
                ((images_concat + 1) * 127.5).astype(np.uint8))

    elif args.model_net == 'Pix2pix' or args.model_net == 'CycleGAN':
        test_reader = reader_creator(image_dir=args.dataset_dir,
                                     list_filename=args.test_list,
                                     shuffle=False,
                                     batch_size=args.n_samples,
                                     mode="VAL")
        reader_test = test_reader.make_reader(args, return_name=True)
        py_reader.decorate_batch_generator(reader_test,
                                           places=_reader_places())
        id2name = test_reader.id2name
        for data in py_reader():
            real_img = data[0]['input']
            # The reader yields numeric ids; map the first one back to the
            # original file name.
            image_id = np.array(data[0]['image_name']).astype('int32')[0]
            out_name = id2name[image_id]
            print("read: ", out_name)
            fake_temp = exe.run(fetch_list=[fake.name],
                                feed={"input": real_img})
            # CHW -> HWC for image writing.
            fake_temp = np.squeeze(fake_temp[0]).transpose([1, 2, 0])

            imageio.imwrite(os.path.join(args.output, "fake_" + out_name),
                            ((fake_temp + 1) * 127.5).astype(np.uint8))

    elif args.model_net == 'CGAN':
        noise_data = np.random.uniform(
            low=-1.0, high=1.0,
            size=[args.n_samples, args.noise_size]).astype('float32')
        # NOTE(review): randint's upper bound is exclusive, so labels are
        # drawn from 0..8 only -- confirm whether class 9 should be included.
        label = np.random.randint(0, 9, size=[args.n_samples,
                                              1]).astype('float32')
        noise_tensor = fluid.LoDTensor()
        conditions_tensor = fluid.LoDTensor()
        noise_tensor.set(noise_data, place)
        conditions_tensor.set(label, place)
        fake_temp = exe.run(fetch_list=[fake.name],
                            feed={
                                "noise": noise_tensor,
                                "conditions": conditions_tensor
                            })[0]
        fake_image = np.reshape(fake_temp, (args.n_samples, -1))

        fig = utility.plot(fake_image)
        plt.savefig(os.path.join(args.output, 'fake_cgan.png'),
                    bbox_inches='tight')
        plt.close(fig)

    elif args.model_net == 'DCGAN':
        noise_data = np.random.uniform(
            low=-1.0, high=1.0,
            size=[args.n_samples, args.noise_size]).astype('float32')
        noise_tensor = fluid.LoDTensor()
        noise_tensor.set(noise_data, place)
        fake_temp = exe.run(fetch_list=[fake.name],
                            feed={"noise": noise_tensor})[0]
        fake_image = np.reshape(fake_temp, (args.n_samples, -1))

        fig = utility.plot(fake_image)
        plt.savefig(os.path.join(args.output, 'fake_dcgan.png'),
                    bbox_inches='tight')
        plt.close(fig)
    else:
        raise NotImplementedError("model_net {} is not support".format(
            args.model_net))
def infer(args):
    """Build the generator selected by ``args.model_net`` and run inference.

    Declares the input placeholders, constructs the generator for the
    requested model, loads persistable variables from the ``net_G``
    sub-directory of ``args.init_model`` and writes the generated images
    under ``args.output`` (created when it does not exist).

    Args:
        args: configuration namespace. Attributes read directly here are
            ``image_size``, ``c_dim``, ``model_net``, ``use_gpu``,
            ``init_model``, ``output`` and ``selected_attrs``; depending on
            the model also ``input_style``, ``noise_size``, ``batch_size``,
            ``dataset_dir`` and ``test_list``. ``args`` is additionally
            forwarded to the network builders and reader creators.

    Raises:
        NotImplementedError: if ``args.model_net`` is not one of the handled
            models, or if ``args.input_style`` is neither ``"A"`` nor ``"B"``
            when the model is CycleGAN.
    """
    data_shape = [-1, 3, args.image_size, args.image_size]
    input_img = fluid.layers.data(
        name='input', shape=data_shape, dtype='float32')
    label_org_ = fluid.layers.data(
        name='label_org_', shape=[args.c_dim], dtype='float32')
    label_trg_ = fluid.layers.data(
        name='label_trg_', shape=[args.c_dim], dtype='float32')

    model_name = 'net_G'
    if args.model_net == 'CycleGAN':
        from network.CycleGAN_network import CycleGAN_model
        model = CycleGAN_model()
        if args.input_style == "A":
            fake = model.network_G(input_img, name="GA", cfg=args)
        elif args.input_style == "B":
            fake = model.network_G(input_img, name="GB", cfg=args)
        else:
            # Raising a bare string is itself a TypeError in Python 3, which
            # hid the intended message; raise a real exception instead.
            raise NotImplementedError(
                "Input with style [%s] is not supported." % args.input_style)
    elif args.model_net == 'Pix2pix':
        from network.Pix2pix_network import Pix2pix_model
        model = Pix2pix_model()
        fake = model.network_G(input_img, "generator", cfg=args)
    elif args.model_net == 'StarGAN':
        from network.StarGAN_network import StarGAN_model
        model = StarGAN_model()
        fake = model.network_G(
            input_img, label_trg_, name="g_main", cfg=args)
    elif args.model_net == 'STGAN':
        from network.STGAN_network import STGAN_model
        model = STGAN_model()
        fake, _ = model.network_G(
            input_img,
            label_org_,
            label_trg_,
            cfg=args,
            name='generator',
            is_test=True)
    elif args.model_net == 'AttGAN':
        from network.AttGAN_network import AttGAN_model
        model = AttGAN_model()
        fake, _ = model.network_G(
            input_img,
            label_org_,
            label_trg_,
            cfg=args,
            name='generator',
            is_test=True)
    elif args.model_net == 'CGAN':
        noise = fluid.layers.data(
            name='noise', shape=[args.noise_size], dtype='float32')
        conditions = fluid.layers.data(
            name='conditions', shape=[1], dtype='float32')

        from network.CGAN_network import CGAN_model
        model = CGAN_model()
        fake = model.network_G(noise, conditions, name="G")
    elif args.model_net == 'DCGAN':
        noise = fluid.layers.data(
            name='noise', shape=[args.noise_size], dtype='float32')

        from network.DCGAN_network import DCGAN_model
        model = DCGAN_model()
        fake = model.network_G(noise, name="G")
    else:
        raise NotImplementedError("model_net {} is not support".format(
            args.model_net))

    # prepare environment
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    for var in fluid.default_main_program().global_block().all_parameters():
        print(var.name)
    print(args.init_model + '/' + model_name)
    fluid.io.load_persistables(exe, os.path.join(args.init_model, model_name))
    print('load params done')
    if not os.path.exists(args.output):
        os.makedirs(args.output)

    attr_names = args.selected_attrs.split(',')

    if args.model_net == 'AttGAN' or args.model_net == 'STGAN':
        test_reader = celeba_reader_creator(
            image_dir=args.dataset_dir,
            list_filename=args.test_list,
            batch_size=args.batch_size,
            drop_last=False,
            args=args)
        reader_test = test_reader.get_test_reader(
            args, shuffle=False, return_name=True)
        for real_img, label_org, name in reader_test():
            print("read {}".format(name))
            tensor_img = fluid.LoDTensor()
            tensor_label_org_ = fluid.LoDTensor()
            tensor_label_trg_ = fluid.LoDTensor()
            tensor_img.set(real_img, place)

            # The rescaled source labels do not depend on the attribute
            # being flipped, so compute them once per batch.
            label_org_scaled = [((x * 2) - 1) * 0.5 for x in label_org]

            images = [save_batch_image(real_img)]
            # One generated image column per attribute: flip attribute i of
            # a copy of the source labels.
            for i in range(args.c_dim):
                label_trg_tmp = copy.deepcopy(label_org)
                for j in range(len(label_org)):
                    label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
                    label_trg_tmp = check_attribute_conflict(
                        label_trg_tmp, attr_names[i], attr_names)
                label_trg_scaled = [((x * 2) - 1) * 0.5
                                    for x in label_trg_tmp]
                if args.model_net == 'AttGAN':
                    # AttGAN doubles the value of the flipped attribute.
                    for k in range(len(label_org)):
                        label_trg_scaled[k][i] = label_trg_scaled[k][i] * 2.0
                tensor_label_org_.set(label_org_scaled, place)
                tensor_label_trg_.set(label_trg_scaled, place)
                out = exe.run(feed={
                    "input": tensor_img,
                    "label_org_": tensor_label_org_,
                    "label_trg_": tensor_label_trg_
                },
                              fetch_list=[fake.name])
                images.append(save_batch_image(out[0]))
            images_concat = np.concatenate(images, 1)
            if len(label_org) > 1:
                images_concat = np.concatenate(images_concat, 1)
            imageio.imwrite(
                os.path.join(args.output, "fake_img_" + name[0]),
                ((images_concat + 1) * 127.5).astype(np.uint8))
    elif args.model_net == 'StarGAN':
        test_reader = celeba_reader_creator(
            image_dir=args.dataset_dir,
            list_filename=args.test_list,
            batch_size=args.batch_size,
            drop_last=False,
            args=args)
        reader_test = test_reader.get_test_reader(
            args, shuffle=False, return_name=True)
        for real_img, label_org, name in reader_test():
            print("read {}".format(name))
            tensor_img = fluid.LoDTensor()
            tensor_img.set(real_img, place)
            images = [save_batch_image(real_img)]
            for i in range(args.c_dim):
                label_trg_tmp = copy.deepcopy(label_org)
                for j in range(len(label_org)):
                    label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
                    label_trg = check_attribute_conflict(
                        label_trg_tmp, attr_names[i], attr_names)
                tensor_label_trg = fluid.LoDTensor()
                tensor_label_trg.set(label_trg, place)
                out = exe.run(
                    feed={"input": tensor_img,
                          "label_trg_": tensor_label_trg},
                    fetch_list=[fake.name])
                images.append(save_batch_image(out[0]))
            images_concat = np.concatenate(images, 1)
            if len(label_org) > 1:
                images_concat = np.concatenate(images_concat, 1)
            imageio.imwrite(
                os.path.join(args.output, "fake_img_" + name[0]),
                ((images_concat + 1) * 127.5).astype(np.uint8))

    elif args.model_net == 'Pix2pix' or args.model_net == 'CycleGAN':
        # args.dataset_dir is used as a glob pattern here.
        for file in glob.glob(args.dataset_dir):
            print("read {}".format(file))
            image_name = os.path.basename(file)
            image = Image.open(file).convert('RGB')
            # NOTE(review): the resize target is fixed at 256x256 while the
            # placeholder is declared with args.image_size -- confirm they
            # are expected to agree.
            image = image.resize((256, 256), Image.BICUBIC)
            # HWC uint8 -> CHW float32 normalised to [-1, 1].
            image = np.array(image).transpose([2, 0, 1]).astype('float32')
            image = image / 255.0
            image = (image - 0.5) / 0.5
            data = image[np.newaxis, :]
            tensor = fluid.LoDTensor()
            tensor.set(data, place)

            fake_temp = exe.run(fetch_list=[fake.name], feed={"input": tensor})
            # CHW -> HWC for image writing.
            fake_temp = np.squeeze(fake_temp[0]).transpose([1, 2, 0])

            imageio.imwrite(
                os.path.join(args.output, "fake_" + image_name),
                ((fake_temp + 1) * 127.5).astype(np.uint8))

    elif args.model_net == 'CGAN':
        noise_data = np.random.uniform(
            low=-1.0, high=1.0,
            size=[args.batch_size, args.noise_size]).astype('float32')
        # NOTE(review): randint's upper bound is exclusive, so labels are
        # drawn from 0..8 only -- confirm whether class 9 should be included.
        label = np.random.randint(
            0, 9, size=[args.batch_size, 1]).astype('float32')
        noise_tensor = fluid.LoDTensor()
        conditions_tensor = fluid.LoDTensor()
        noise_tensor.set(noise_data, place)
        conditions_tensor.set(label, place)
        fake_temp = exe.run(
            fetch_list=[fake.name],
            feed={"noise": noise_tensor,
                  "conditions": conditions_tensor})[0]
        fake_image = np.reshape(fake_temp, (args.batch_size, -1))

        fig = utility.plot(fake_image)
        plt.savefig(
            os.path.join(args.output, 'fake_cgan.png'), bbox_inches='tight')
        plt.close(fig)

    elif args.model_net == 'DCGAN':
        noise_data = np.random.uniform(
            low=-1.0, high=1.0,
            size=[args.batch_size, args.noise_size]).astype('float32')
        noise_tensor = fluid.LoDTensor()
        noise_tensor.set(noise_data, place)
        fake_temp = exe.run(fetch_list=[fake.name],
                            feed={"noise": noise_tensor})[0]
        fake_image = np.reshape(fake_temp, (args.batch_size, -1))

        fig = utility.plot(fake_image)
        plt.savefig(
            os.path.join(args.output, 'fake_dcgan.png'), bbox_inches='tight')
        plt.close(fig)
    else:
        raise NotImplementedError("model_net {} is not support".format(
            args.model_net))
예제 #3
0
    def build_model(self):
        """Build the conditional GAN trainers and run the training loop.

        Declares the ``img``/``condition``/``noise``/``label`` placeholders,
        creates the generator and discriminator trainers, optionally
        restores checkpoints when ``self.cfg.init_model`` is set, and then
        trains for ``self.cfg.epoch`` epochs over ``self.train_reader()``.

        Every ``self.cfg.print_freq`` batches a figure of real and generated
        samples is saved under ``<self.cfg.output>/images`` and the current
        losses are printed. When ``self.cfg.save_checkpoints`` is set, both
        trainers are checkpointed at the end of each epoch.
        """
        img = fluid.layers.data(name='img', shape=[784], dtype='float32')
        condition = fluid.layers.data(name='condition',
                                      shape=[1],
                                      dtype='float32')
        noise = fluid.layers.data(name='noise',
                                  shape=[self.cfg.noise_size],
                                  dtype='float32')
        label = fluid.layers.data(name='label', shape=[1], dtype='float32')

        g_trainer = GTrainer(noise, condition, self.cfg)
        d_trainer = DTrainer(img, condition, label, self.cfg)

        # prepare environment
        place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        # Fixed noise reused for every preview figure so that successive
        # previews are drawn from the same latent inputs.
        const_n = np.random.uniform(
            low=-1.0,
            high=1.0,
            size=[self.cfg.batch_size, self.cfg.noise_size]).astype('float32')

        if self.cfg.init_model:
            utility.init_checkpoints(self.cfg, exe, g_trainer, "net_G")
            utility.init_checkpoints(self.cfg, exe, d_trainer, "net_D")

        ### memory optim
        build_strategy = fluid.BuildStrategy()
        build_strategy.enable_inplace = True

        g_trainer_program = fluid.CompiledProgram(
            g_trainer.program).with_data_parallel(
                loss_name=g_trainer.g_loss.name, build_strategy=build_strategy)
        d_trainer_program = fluid.CompiledProgram(
            d_trainer.program).with_data_parallel(
                loss_name=d_trainer.d_loss.name, build_strategy=build_strategy)

        for epoch_id in range(self.cfg.epoch):
            for batch_id, data in enumerate(self.train_reader()):
                # Skip incomplete batches; const_n and the reshape below
                # assume exactly cfg.batch_size samples.
                if len(data) != self.cfg.batch_size:
                    continue

                noise_data = np.random.uniform(
                    low=-1.0,
                    high=1.0,
                    size=[self.cfg.batch_size,
                          self.cfg.noise_size]).astype('float32')
                real_image = np.array([x[0] for x in data]).reshape(
                    [-1, 784]).astype('float32')
                condition_data = np.array([x[1] for x in data]).reshape(
                    [-1, 1]).astype('float32')
                real_label = np.ones(shape=[real_image.shape[0], 1],
                                     dtype='float32')
                fake_label = np.zeros(shape=[real_image.shape[0], 1],
                                      dtype='float32')
                s_time = time.time()

                # NOTE(review): exe.run returns a list of fetched values and
                # the whole list is fed as 'img' below -- confirm this is
                # intended rather than generate_image[0].
                generate_image = exe.run(g_trainer.infer_program,
                                         feed={
                                             'noise': noise_data,
                                             'condition': condition_data
                                         },
                                         fetch_list=[g_trainer.fake])

                # Discriminator step on real, then on generated images.
                d_real_loss = exe.run(d_trainer_program,
                                      feed={
                                          'img': real_image,
                                          'condition': condition_data,
                                          'label': real_label
                                      },
                                      fetch_list=[d_trainer.d_loss])[0]
                d_fake_loss = exe.run(d_trainer_program,
                                      feed={
                                          'img': generate_image,
                                          'condition': condition_data,
                                          'label': fake_label
                                      },
                                      fetch_list=[d_trainer.d_loss])[0]
                d_loss = d_real_loss + d_fake_loss

                # Generator step(s); g_loss keeps the value of the last one.
                for _ in six.moves.xrange(self.cfg.num_generator_time):
                    g_loss = exe.run(g_trainer_program,
                                     feed={
                                         'noise': noise_data,
                                         'condition': condition_data
                                     },
                                     fetch_list=[g_trainer.g_loss])[0]

                batch_time = time.time() - s_time

                if batch_id % self.cfg.print_freq == 0:
                    image_path = os.path.join(self.cfg.output, 'images')
                    if not os.path.exists(image_path):
                        os.makedirs(image_path)
                    generate_const_image = exe.run(
                        g_trainer.infer_program,
                        feed={
                            'noise': const_n,
                            'condition': condition_data
                        },
                        fetch_list=[g_trainer.fake])[0]

                    generate_image_reshape = np.reshape(
                        generate_const_image, (self.cfg.batch_size, -1))
                    total_images = np.concatenate(
                        [real_image, generate_image_reshape])
                    fig = utility.plot(total_images)
                    print(
                        'Epoch ID: {} Batch ID: {} D_loss: {} G_loss: {} Batch_time_cost: {}'
                        .format(epoch_id, batch_id, d_loss[0], g_loss[0],
                                batch_time))
                    plt.title('Epoch ID={}, Batch ID={}'.format(
                        epoch_id, batch_id))
                    img_name = '{:04d}_{:04d}.png'.format(epoch_id, batch_id)
                    plt.savefig(os.path.join(image_path, img_name),
                                bbox_inches='tight')
                    plt.close(fig)

            if self.cfg.save_checkpoints:
                utility.checkpoints(epoch_id, self.cfg, exe, g_trainer,
                                    "net_G")
                utility.checkpoints(epoch_id, self.cfg, exe, d_trainer,
                                    "net_D")
예제 #4
0
파일: DCGAN.py 프로젝트: zhong110020/models
    def build_model(self):
        """Build the DCGAN trainers and run the training loop.

        Declares the ``img``/``noise``/``label`` placeholders, creates the
        generator and discriminator trainers, optionally restores
        checkpoints when ``self.cfg.init_model`` is set, and trains for
        ``self.cfg.epoch`` epochs over ``self.train_reader()``.

        When ``self.cfg.run_test`` is set, a figure of real and generated
        samples is saved under ``<self.cfg.output>/test`` after every batch.
        When ``self.cfg.save_checkpoints`` is set, both trainers are
        checkpointed at the end of each epoch. When ``self.cfg.enable_ce``
        is set, random seeds are fixed and KPI lines are printed after
        training.
        """
        img = fluid.data(name='img', shape=[None, 784], dtype='float32')
        noise = fluid.data(name='noise',
                           shape=[None, self.cfg.noise_size],
                           dtype='float32')
        label = fluid.data(name='label', shape=[None, 1], dtype='float32')
        # used for continuous evaluation: fix all seeds for reproducible runs
        if self.cfg.enable_ce:
            fluid.default_startup_program().random_seed = 90
            random.seed(0)
            np.random.seed(0)

        g_trainer = GTrainer(noise, label, self.cfg)
        d_trainer = DTrainer(img, label, self.cfg)

        # prepare environment
        place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        # Fixed noise reused for every preview figure so that successive
        # previews are drawn from the same latent inputs.
        const_n = np.random.uniform(
            low=-1.0,
            high=1.0,
            size=[self.cfg.batch_size, self.cfg.noise_size]).astype('float32')

        if self.cfg.init_model:
            utility.init_checkpoints(self.cfg, g_trainer, "net_G")
            utility.init_checkpoints(self.cfg, d_trainer, "net_D")

        ### memory optim
        build_strategy = fluid.BuildStrategy()
        build_strategy.enable_inplace = True

        g_trainer_program = fluid.CompiledProgram(
            g_trainer.program).with_data_parallel(
                loss_name=g_trainer.g_loss.name, build_strategy=build_strategy)
        d_trainer_program = fluid.CompiledProgram(
            d_trainer.program).with_data_parallel(
                loss_name=d_trainer.d_loss.name, build_strategy=build_strategy)

        # image_path is only bound (and only used below) when run_test is set.
        if self.cfg.run_test:
            image_path = os.path.join(self.cfg.output, 'test')
            if not os.path.exists(image_path):
                os.makedirs(image_path)

        # NOTE(review): t_time is accumulated below but never read in this
        # method.
        t_time = 0
        for epoch_id in range(self.cfg.epoch):
            for batch_id, data in enumerate(self.train_reader()):
                # Skip incomplete batches; const_n and the reshape below
                # assume exactly cfg.batch_size samples.
                if len(data) != self.cfg.batch_size:
                    continue

                noise_data = np.random.uniform(
                    low=-1.0,
                    high=1.0,
                    size=[self.cfg.batch_size,
                          self.cfg.noise_size]).astype('float32')
                real_image = np.array(list(map(lambda x: x[0], data))).reshape(
                    [-1, 784]).astype('float32')
                real_label = np.ones(shape=[real_image.shape[0], 1],
                                     dtype='float32')
                fake_label = np.zeros(shape=[real_image.shape[0], 1],
                                      dtype='float32')
                s_time = time.time()

                # NOTE(review): fake images are fetched by running the
                # compiled training program (g_trainer_program) rather than
                # g_trainer.infer_program; if that program contains optimizer
                # ops the generator is also updated here -- confirm intent.
                generate_image = exe.run(g_trainer_program,
                                         feed={'noise': noise_data},
                                         fetch_list=[g_trainer.fake])

                # Discriminator step on real, then on generated images.
                d_real_loss = exe.run(d_trainer_program,
                                      feed={
                                          'img': real_image,
                                          'label': real_label
                                      },
                                      fetch_list=[d_trainer.d_loss])[0]
                d_fake_loss = exe.run(d_trainer_program,
                                      feed={
                                          'img': generate_image[0],
                                          'label': fake_label
                                      },
                                      fetch_list=[d_trainer.d_loss])[0]
                d_loss = d_real_loss + d_fake_loss

                # Generator step(s), each with freshly sampled noise; g_loss
                # keeps the value of the last step.
                for _ in six.moves.xrange(self.cfg.num_generator_time):
                    noise_data = np.random.uniform(
                        low=-1.0,
                        high=1.0,
                        size=[self.cfg.batch_size,
                              self.cfg.noise_size]).astype('float32')
                    g_loss = exe.run(g_trainer_program,
                                     feed={'noise': noise_data},
                                     fetch_list=[g_trainer.g_loss])[0]

                batch_time = time.time() - s_time

                if batch_id % self.cfg.print_freq == 0:
                    print(
                        'Epoch ID: {} Batch ID: {} D_loss: {} G_loss: {} Batch_time_cost: {}'
                        .format(epoch_id, batch_id, d_loss[0], g_loss[0],
                                batch_time))

                t_time += batch_time

                # Save a preview figure (real batch + samples from const_n)
                # after every batch when run_test is enabled.
                if self.cfg.run_test:
                    generate_const_image = exe.run(g_trainer.infer_program,
                                                   feed={'noise': const_n},
                                                   fetch_list=[g_trainer.fake
                                                               ])[0]

                    generate_image_reshape = np.reshape(
                        generate_const_image, (self.cfg.batch_size, -1))
                    total_images = np.concatenate(
                        [real_image, generate_image_reshape])
                    fig = utility.plot(total_images)

                    plt.title('Epoch ID={}, Batch ID={}'.format(
                        epoch_id, batch_id))
                    img_name = '{:04d}_{:04d}.png'.format(epoch_id, batch_id)
                    plt.savefig(os.path.join(image_path, img_name),
                                bbox_inches='tight')
                    plt.close(fig)

            if self.cfg.save_checkpoints:
                utility.checkpoints(epoch_id, self.cfg, g_trainer, "net_G")
                utility.checkpoints(epoch_id, self.cfg, d_trainer, "net_D")
        # used for continuous evaluation
        # NOTE(review): d_loss, g_loss and batch_time are only bound inside
        # the training loop above, so this section relies on at least one
        # full batch having been processed.
        if self.cfg.enable_ce:
            device_num = fluid.core.get_cuda_device_count(
            ) if self.cfg.use_gpu else 1
            print("kpis\tdcgan_d_loss_card{}\t{}".format(
                device_num, d_loss[0]))
            print("kpis\tdcgan_g_loss_card{}\t{}".format(
                device_num, g_loss[0]))
            print("kpis\tdcgan_Batch_time_cost_card{}\t{}".format(
                device_num, batch_time))