Пример #1
0
def main():
    """Build the Efficient GAN networks and data pipeline, then train them."""
    latent_dim = 20

    # Create generator, discriminator and encoder (in that order) with the
    # same latent size, then run the shared weight initialiser over each.
    G = Generator(z_dim=latent_dim)
    D = Discriminator(z_dim=latent_dim)
    E = Encoder(z_dim=latent_dim)
    for network in (G, D, E):
        network.apply(weights_init)

    # One-element mean/std tuples are handed to the image transform.
    image_paths = make_datapath_list(num=200)
    dataset = GAN_Img_Dataset(
        file_list=image_paths,
        transform=ImageTransform((0.5,), (0.5,)),
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

    # train_model returns three objects; they are not used any further here.
    _g_trained, _d_trained, _e_trained = train_model(
        G, D, E,
        dataloader=loader,
        num_epochs=1500,
        save_model_name='Efficient_GAN',
    )
Пример #2
0
def train(FLAGS):
    """Train a discriminator/generator pair on an image-folder dataset.

    Args:
        FLAGS: Object whose attributes supply the configuration. The
            attributes read are ``p_every`` (print interval in batches),
            ``s_every`` (checkpoint interval in epochs), ``epochs``, ``dlr``,
            ``glr``, ``beta1``, ``beta2``, ``zsize``, ``batch_size``,
            ``resize_height``, ``resize_width``, ``dataset_path`` and
            ``dataset_type``.

    Raises:
        ValueError: If ``dataset_path`` is ``None`` and ``dataset_type`` is
            not one of ``"cars"``, ``"flowers"`` or ``"dogs"``.

    Side effects:
        May run ``./datasets/dload.sh`` to fetch a dataset, writes
        ``d-nm-{epoch}.pth`` / ``g-nm-{epoch}.pth`` checkpoints every
        ``s_every`` epochs, writes ``fake_{epoch}.png`` after every epoch
        that processed at least one batch, and prints progress.
    """
    # Define the hyperparameters
    p_every = FLAGS.p_every
    s_every = FLAGS.s_every
    epochs = FLAGS.epochs
    dlr = FLAGS.dlr
    glr = FLAGS.glr
    beta1 = FLAGS.beta1
    beta2 = FLAGS.beta2
    z_size = FLAGS.zsize
    batch_size = FLAGS.batch_size
    rh = FLAGS.resize_height
    rw = FLAGS.resize_width
    d_path = FLAGS.dataset_path
    d_type = FLAGS.dataset_type

    # Preprocessing Data
    transform = transforms.Compose([
        transforms.Resize((rh, rw)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # No explicit path: fall back to a named dataset, downloading on demand.
    if d_path is None:
        if d_type == "cars":
            if not os.path.exists('./datasets/cars_train'):
                os.system('sh ./datasets/dload.sh cars')
            d_path = './datasets/cars_train/'

        elif d_type == "flowers":
            if not os.path.exists('./datasets/flowers/'):
                os.system('sh ./datasets/dload.sh flowers')
            d_path = './datasets/flowers/'

        elif d_type == "dogs":
            if not os.path.exists('./datasets/jpg'):
                os.system('sh ./datasets/dload.sh dogs')
            d_path = './datasets/jpg/'

        else:
            # Fail early instead of handing ``None`` to ImageFolder.
            raise ValueError(
                'dataset_path is not set and dataset_type {!r} is not one of '
                '"cars", "flowers", "dogs"'.format(d_type))

    train_data = datasets.ImageFolder(d_path, transform=transform)
    trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    # Define the D and G
    dis = Discriminator(64)
    gen = Generator()

    # Apply weight initialization
    dis.apply(init_weight)
    gen.apply(init_weight)

    # Define the loss function
    criterion = nn.BCELoss()

    # Optimizers
    d_opt = optim.Adam(dis.parameters(), lr=dlr, betas=(beta1, beta2))
    g_opt = optim.Adam(gen.parameters(), lr=glr, betas=(beta1, beta2))

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Per-epoch [discriminator, generator] loss sums (local bookkeeping only;
    # this function does not return them).
    train_losses = []

    dis.to(device)
    gen.to(device)

    real_label = 1
    fake_label = 0

    # Train loop
    for e in range(epochs):

        td_loss = 0.0
        tg_loss = 0.0
        fake_images = None

        for batch_i, (real_images, _) in enumerate(trainloader):

            real_images = real_images.to(device)

            # The last batch may be smaller than the configured batch size.
            cur_batch_size = real_images.size(0)

            #### Train the Discriminator ####

            d_opt.zero_grad()

            d_real = dis(real_images)

            # BCELoss needs floating-point targets; an integer fill value
            # without an explicit dtype produces an integer tensor on current
            # PyTorch releases.
            label = torch.full((cur_batch_size, ), real_label,
                               dtype=torch.float, device=device)
            r_loss = criterion(d_real, label)
            r_loss.backward()

            z = torch.randn(cur_batch_size, z_size, 1, 1, device=device)

            fake_images = gen(z)

            label.fill_(fake_label)

            # Detach so the discriminator update does not reach the generator.
            d_fake = dis(fake_images.detach())

            f_loss = criterion(d_fake, label)
            f_loss.backward()

            d_loss = r_loss + f_loss

            d_opt.step()

            #### Train the Generator ####
            g_opt.zero_grad()

            # The generator is rewarded when its fakes are classified as real.
            label.fill_(real_label)
            d_fake2 = dis(fake_images)

            g_loss = criterion(d_fake2, label)
            g_loss.backward()

            g_opt.step()

            # Accumulate detached values so no autograd graph is kept alive.
            td_loss += d_loss.detach()
            tg_loss += g_loss.detach()

            if batch_i % p_every == 0:
                print ('Epoch [{:5d} / {:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'. \
                        format(e+1, epochs, d_loss.item(), g_loss.item()))

        train_losses.append([float(td_loss), float(tg_loss)])

        if e % s_every == 0:
            d_ckpt = {
                'model_state_dict': dis.state_dict(),
                'opt_state_dict': d_opt.state_dict()
            }

            g_ckpt = {
                'model_state_dict': gen.state_dict(),
                'opt_state_dict': g_opt.state_dict()
            }

            torch.save(d_ckpt, 'd-nm-{}.pth'.format(e))
            torch.save(g_ckpt, 'g-nm-{}.pth'.format(e))

        # Save the fakes from the epoch's last batch; nothing to save when the
        # loader yielded no batches.
        if fake_images is not None:
            utils.save_image(fake_images.detach(),
                             'fake_{}.png'.format(e),
                             normalize=True)

    print('[INFO] Training Completed successfully!')
Пример #3
0
class BaseExperiment:
    """Adversarial training experiment pairing a Generator and Discriminator.

    The constructor builds the data loader and both networks from ``args``;
    ``train`` runs a Wasserstein-style critic/generator loop, optionally with
    a gradient penalty when ``args.model_name == 'WGAN'``.
    """

    def __init__(self, args):
        """Seed the RNGs, echo the arguments, and build loader and networks.

        Args:
            args: Namespace-like object. Attributes read here are ``seed``,
                ``dataset``, ``data``, ``len_seg``, ``net_name`` and
                ``batch_size``; other methods read further attributes.
        """
        self.args = args
        torch.manual_seed(self.args.seed)
        np.random.seed(self.args.seed)
        print('> Training arguments:')
        for arg in vars(args):
            print('>>> {}: {}'.format(arg, getattr(args, arg)))
        # NOTE(review): ``data_path`` is not defined in this class; it is
        # presumably a module-level name - confirm.
        white_noise = dp.DatasetReader(white_noise=self.args.dataset,
                                       data_path=data_path,
                                       data_source=args.data,
                                       len_seg=self.args.len_seg)
        dataset, _ = white_noise(args.net_name)
        self.data_loader = DataLoader(dataset=dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True)
        self.Generator = Generator(args)  # Generator
        self.Discriminator = Discriminator(args)  # Discriminator

    def select_optimizer(self, model):
        """Return the optimizer named by ``args.optimizer`` for ``model``.

        Only parameters with ``requires_grad`` set are handed to the
        optimizer; the learning rate comes from ``args.learning_rate``.

        Raises:
            ValueError: If ``args.optimizer`` is not one of ``'Adam'``,
                ``'RMS'``, ``'SGD'``, ``'Adagrad'`` or ``'Adadelta'``.
        """
        name = self.args.optimizer
        lr = self.args.learning_rate
        params = filter(lambda p: p.requires_grad, model.parameters())
        if name == 'Adam':
            return optim.Adam(params, lr=lr, betas=(0.5, 0.9))
        if name == 'RMS':
            return optim.RMSprop(params, lr=lr)
        if name == 'SGD':
            return optim.SGD(params, lr=lr, momentum=0.9)
        if name == 'Adagrad':
            return optim.Adagrad(params, lr=lr)
        if name == 'Adadelta':
            return optim.Adadelta(params, lr=lr)
        raise ValueError('Unsupported optimizer: {!r}'.format(name))

    def weights_init(self, model):
        """Initialise an ``nn.Linear`` module in place; ignore other modules.

        The weight initialiser is looked up by ``args.initializer`` and the
        bias, when the layer has one, is zeroed. Intended for use with
        ``Module.apply``.

        Raises:
            KeyError: If ``model`` is an ``nn.Linear`` and
                ``args.initializer`` is not a known initialiser name.
        """
        initializers = {
            'xavier_uniform_': nn.init.xavier_uniform_,
            'xavier_normal_': nn.init.xavier_normal_,
            'orthogonal_': nn.init.orthogonal_,
            'kaiming_normal_': nn.init.kaiming_normal_
        }
        if isinstance(model, nn.Linear):
            initializer = initializers[self.args.initializer]
            initializer(model.weight)
            # Layers built with bias=False have no bias tensor to reset.
            if model.bias is not None:
                model.bias.data.fill_(0)

    def file_name(self):
        """Return an underscore-joined tag built from the run's settings."""
        return '{}_{}_{}_{}_{}_{}'.format(
            self.args.model_name, self.args.net_name, self.args.len_seg,
            self.args.optimizer, self.args.learning_rate, self.args.num_epoch)

    def gradient_penalty(self, x_real, x_fake, batch_size, beta=0.3):
        """Compute a gradient penalty on random real/fake interpolates.

        Args:
            x_real: Batch of real samples, first dimension ``batch_size``.
            x_fake: Batch of generated samples with the same shape.
            batch_size: Number of samples in each batch.
            beta: Weight applied to the penalty.

        Returns:
            Scalar tensor ``beta * mean((||grad||_2 - 1) ** 2)``, where the
            gradient of the discriminator output is taken with respect to the
            interpolated inputs and the norm is computed per sample.
        """
        x_real = x_real.detach()
        x_fake = x_fake.detach()
        # One mixing coefficient per sample, broadcast over all remaining
        # dimensions and created on the same device as the data.
        alpha_shape = [batch_size] + [1] * (x_real.dim() - 1)
        alpha = torch.rand(*alpha_shape, device=x_real.device)
        alpha = alpha.expand_as(x_real)
        interpolates = alpha * x_real + ((1 - alpha) * x_fake)
        interpolates.requires_grad_()
        dis_interpolates = self.Discriminator(interpolates)
        # create_graph=True keeps the penalty differentiable with respect to
        # the discriminator parameters.
        gradients = autograd.grad(
            outputs=dis_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(dis_interpolates),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        # Flatten so the norm is taken over every non-batch dimension.
        gradients = gradients.reshape(gradients.size(0), -1)
        grad_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() * beta
        return grad_penalty

    def train(self, n_critic=5):
        """Run the adversarial training loop for ``args.num_epoch`` epochs.

        For every batch the discriminator is updated ``n_critic`` times on
        the loss ``-mean(D(x)) + mean(D(G(z)))`` (plus the gradient penalty
        when ``args.model_name == 'WGAN'``), then the generator is updated
        once on ``-mean(D(G(z)))``. After each epoch a summary line is
        printed and a real-vs-fake figure is created; all figures are shown
        once training ends.

        Args:
            n_critic: Discriminator updates per generator update.
        """
        self.Generator.apply(self.weights_init)
        self.Discriminator.apply(self.weights_init)
        gen_optimizer = self.select_optimizer(self.Generator)
        # The discriminator optimizer must own the discriminator parameters,
        # otherwise the critic updates would modify the generator instead.
        dis_optimizer = self.select_optimizer(self.Discriminator)
        criterion = nn.MSELoss()
        dis_losses, gen_losses = [0], [0]
        for epoch in range(self.args.num_epoch):
            t0 = time.time()
            for _, sample_batched in enumerate(self.data_loader):
                data_real = torch.tensor(sample_batched, dtype=torch.float32)
                batch_size = sample_batched.size(0)
                # 1. Train Discriminator: minimize -D(x) + D(G(z)) (+ penalty)
                for _ in range(n_critic):
                    pred_real = self.Discriminator(data_real)
                    loss_real = -pred_real.mean()
                    # Generate data; detached so this update leaves the
                    # generator untouched.
                    z = torch.randn(batch_size, self.args.dim_noise)
                    data_fake = self.Generator(z).detach()
                    pred_fake = self.Discriminator(data_fake)
                    loss_fake = pred_fake.mean()
                    # Discriminator loss
                    if self.args.model_name == 'WGAN':
                        grad_penalty = self.gradient_penalty(
                            data_real, data_fake, batch_size)
                    else:
                        grad_penalty = 0
                    dis_loss = loss_real + loss_fake + grad_penalty
                    dis_optimizer.zero_grad()
                    dis_loss.backward()
                    dis_optimizer.step()
                # 2. Train Generator: minimize -D(G(z)). The sample is NOT
                # detached here so gradients reach the generator.
                z = torch.randn(batch_size, self.args.dim_noise)
                generated = self.Generator(z)
                pred_fake = self.Discriminator(generated)
                gen_loss = -pred_fake.mean()
                gen_optimizer.zero_grad()
                gen_loss.backward()
                gen_optimizer.step()
                # Detached copy for the metric and for plotting.
                data_fake = generated.detach()
                mse = criterion(data_fake, data_real)
            t1 = time.time()
            print('\033[1;31m[Epoch {:>4}]\033[0m  '
                  '\033[1;31mD(x) = {:.5f}\033[0m  '
                  '\033[1;32mD(G(z)) = {:.5f}\033[0m  '
                  '\033[1;32mMSE = {:.5f}\033[0m  '
                  'Time cost={:.2f}s'.format(epoch + 1, -loss_real.item(),
                                             -gen_loss.item(), mse.item(),
                                             t1 - t0))
            dis_losses.append(dis_loss.item())
            gen_losses.append(-gen_loss.item())
            _, ax = plt.subplots()
            ax.plot(data_real[0], label='real')
            ax.plot(data_fake[0], ls='--', lw=0.5, label='fake')
            ax.legend()
        plt.show()
Пример #4
0
class Pix2PixMain(object):
    """Training driver for a pix2pix-style conditional GAN.

    All configuration is read from the ``Settings`` object. Constructing an
    instance builds the models, data loaders, optimizers, loss functions, the
    metric recorder and the output directories; calling the instance runs the
    training loop. ``learning_rate_decay_module`` is provided but is not
    invoked by ``__call__``.
    """

    def __init__(self):
        """Seed the RNGs and build models, data, optimizers and output dirs."""

        # -----------------------------------
        # global
        # -----------------------------------
        np.random.seed(Settings.SEED)
        torch.manual_seed(Settings.SEED)
        random.seed(Settings.SEED)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(Settings.SEED)
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # -----------------------------------
        # model
        # -----------------------------------
        self.generator = Generator(in_c=Settings.IN_CHANNEL,
                                   out_c=Settings.OUT_CHANNEL,
                                   ngf=Settings.NGF).to(self.device)
        self.generator.apply(self.generator.weights_init)
        self.discriminator = Discriminator(
            in_c=Settings.IN_CHANNEL,
            out_c=Settings.OUT_CHANNEL,
            ndf=Settings.NDF,
            n_layers=Settings.DISCRIMINATOR_LAYER).to(self.device)
        self.discriminator.apply(self.discriminator.weights_init)
        print("model init done")

        # -----------------------------------
        # data
        # -----------------------------------
        train_transforms = transforms.Compose([
            transforms.Resize((Settings.INPUT_SIZE, Settings.INPUT_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        data_prepare = get_dataloader(
            dataset_name=Settings.DATASET,
            batch_size=Settings.BATCH_SIZE,
            data_root=Settings.DATASET_ROOT,
            train_num_workers=Settings.TRAIN_NUM_WORKERS,
            transforms=train_transforms,
            val_num_workers=Settings.TEST_NUM_WORKERS)
        self.train_dataloader = data_prepare.train_dataloader
        self.test_dataloader = data_prepare.test_dataloader
        print("data init done.....")

        # -----------------------------------
        # optimizer and criterion
        # -----------------------------------
        self.optimG = optim.Adam([{
            "params": self.generator.parameters()
        }],
                                 lr=Settings.G_LR,
                                 betas=Settings.G_BETAS)
        self.optimD = optim.Adam([{
            "params": self.discriminator.parameters()
        }],
                                 lr=Settings.D_LR,
                                 betas=Settings.D_BETAS)

        self.criterion_l1loss = nn.L1Loss()
        self.criterion_BCE = nn.BCELoss()
        print("optimizer and criterion init done.....")

        # -----------------------------------
        # recorder
        # -----------------------------------
        # One value is appended to every list per training step.
        self.recorder = {
            "errD_fake": list(),
            "errD_real": list(),
            "errG_l1loss": list(),
            "errG_bce": list(),
            "errG": list(),
            "accD": list()
        }

        # Timestamped run directory with model / log / image sub-folders.
        output_file = time.strftime(
            "{}_{}_%Y_%m_%d_%H_%M_%S".format("pix2pix", Settings.DATASET),
            time.localtime())
        self.output_root = os.path.join(Settings.OUTPUT_ROOT, output_file)
        os.makedirs(os.path.join(self.output_root, Settings.OUTPUT_MODEL_KEY))
        os.makedirs(os.path.join(self.output_root, Settings.OUTPUT_LOG_KEY))
        os.makedirs(os.path.join(self.output_root, Settings.OUTPUT_IMAGE_KEY))
        print("recorder init done.....")

    def __call__(self):
        """Run training for ``Settings.EPOCHS`` epochs.

        Every step trains on one batch and may print metrics; validation is
        attempted only in epochs that are a multiple of the epoch interval
        derived from ``Settings.BATCH_FREQUENT``. Logs are saved after every
        epoch.
        """

        # Step intervals are fractions of one epoch; the epoch interval is a
        # fraction of the total number of epochs. All are at least 1.
        print_steps = max(
            1, int(len(self.train_dataloader) * Settings.PRINT_FREQUENT))
        eval_steps = max(
            1, int(len(self.train_dataloader) * Settings.EVAL_FREQUENT))
        batch_steps = max(1, int(Settings.EPOCHS * Settings.BATCH_FREQUENT))

        print("begin train.....")
        for epoch in range(1, Settings.EPOCHS + 1):
            for step, batch in enumerate(self.train_dataloader):

                # train
                self.train_module(batch)

                # print
                self.print_module(epoch, step, print_steps)

                if epoch % batch_steps == 0:
                    # val
                    self.val_module(epoch, step, eval_steps)

            # save log
            self.log_save_module()

    def train_module(self, batch):
        """Run one discriminator update and one generator update on ``batch``.

        Args:
            batch: Mapping that provides ``"edge_images"`` (generator input)
                and ``"color_images"`` (target) tensors.

        Raises:
            KeyError: If ``Settings.DATASET`` is not a supported dataset name.
        """
        self.generator.train()
        self.discriminator.train()

        # Both supported datasets expose the same batch keys.
        if Settings.DATASET not in ("edge2shoes", "Mogaoku"):
            raise KeyError(
                "DataSet {} doesn't exist".format(Settings.DATASET))
        input_images = batch["edge_images"].to(self.device)
        target_images = batch["color_images"].to(self.device)

        # Discriminator update
        self.optimD.zero_grad()
        true_image_d_pred = self.discriminator(input_images, target_images)
        true_images_label = torch.full(true_image_d_pred.shape,
                                       Settings.REAL_LABEL,
                                       dtype=torch.float32,
                                       device=self.device)
        errD_real_bce = self.criterion_BCE(true_image_d_pred,
                                           true_images_label)
        errD_real_bce.backward()

        fake_images = self.generator(input_images)
        # Detach so the discriminator loss does not reach the generator.
        fake_images_d_pred = self.discriminator(input_images,
                                                fake_images.detach())
        fake_images_label = torch.full(fake_images_d_pred.shape,
                                       Settings.FAKE_LABEL,
                                       dtype=torch.float32,
                                       device=self.device)
        errD_fake_bce = self.criterion_BCE(fake_images_d_pred,
                                           fake_images_label)
        errD_fake_bce.backward()
        self.optimD.step()

        # Discriminator accuracy: predictions thresholded at 0.5 and compared
        # element-wise with the labels.
        real_image_pred_true_num = ((true_image_d_pred > 0.5).float() ==
                                    true_images_label).sum().float()
        fake_image_pred_true_num = ((fake_images_d_pred > 0.5).float() ==
                                    fake_images_label).sum().float()

        accD = (real_image_pred_true_num + fake_image_pred_true_num) / \
               (true_images_label.numel() + fake_images_label.numel())

        # Generator update
        self.optimG.zero_grad()
        fake_images_d_pred = self.discriminator(input_images, fake_images)
        true_images_label = torch.full(fake_images_d_pred.shape,
                                       Settings.REAL_LABEL,
                                       dtype=torch.float32,
                                       device=self.device)
        errG_bce = self.criterion_BCE(fake_images_d_pred, true_images_label)
        errG_l1loss = self.criterion_l1loss(fake_images, target_images)

        errG = errG_bce + errG_l1loss * Settings.L1_LOSS_LAMUDA
        errG.backward()
        self.optimG.step()

        # recorder
        self.recorder["errD_real"].append(errD_real_bce.item())
        self.recorder["errD_fake"].append(errD_fake_bce.item())
        self.recorder["errG_l1loss"].append(errG_l1loss.item())
        self.recorder["errG_bce"].append(errG_bce.item())
        self.recorder["errG"].append(errG.item())
        # Store a plain float (like the other metrics) so logging and
        # plotting never receive a tensor that may live on the GPU.
        self.recorder["accD"].append(accD.item())

    def val_module(self, epoch, step, eval_steps):
        """Every ``eval_steps`` steps, save prediction grids and checkpoints.

        Each saved image places input, target and prediction side by side
        along the width. Nothing happens unless ``step + 1`` is a multiple of
        ``eval_steps``.
        """
        def apply_dropout(m):
            if isinstance(m, nn.Dropout):
                m.train()

        if (step + 1) % eval_steps != 0:
            return

        output_images = None
        output_count = 0

        self.generator.eval()
        self.discriminator.eval()

        # Optionally keep dropout layers active during evaluation.
        if Settings.USING_DROPOUT_DURING_EVAL:
            self.generator.apply(apply_dropout)
            self.discriminator.apply(apply_dropout)

        # Images per saved grid: a quarter of the number of test batches,
        # but never fewer than one.
        images_per_grid = max(1, int(len(self.test_dataloader) / 4))

        # No gradients are needed while generating validation images.
        with torch.no_grad():
            for eval_step, eval_batch in enumerate(self.test_dataloader):

                input_images = eval_batch["edge_images"].to(self.device)
                target_images = eval_batch["color_images"]

                pred_images = self.generator(input_images).detach().cpu()

                output_image = torch.cat(
                    [input_images.cpu(), target_images, pred_images], dim=3)

                if output_images is None:
                    output_images = output_image
                else:
                    output_images = torch.cat([output_images, output_image],
                                              dim=0)

                # ">=" rather than "==": with batches larger than one the
                # running count can step over the threshold.
                if output_images.shape[0] >= images_per_grid:

                    output_images = make_grid(
                        output_images,
                        padding=2,
                        normalize=True,
                        nrow=Settings.CONSTANT_FEATURE_DIS_LEN).numpy()
                    # CHW values scaled by 255 -> HWC uint8 for PIL.
                    output_images = np.array(
                        np.transpose(output_images, (1, 2, 0)) * 255,
                        dtype=np.uint8)
                    output_images = Image.fromarray(output_images)
                    output_images.save(
                        os.path.join(
                            self.output_root, Settings.OUTPUT_IMAGE_KEY,
                            "epoch_{}_step_{}_count_{}.jpg".format(
                                epoch, step, output_count)))

                    output_count += 1
                    output_images = None

        self.model_save_module(epoch, step)
        self.log_save_module()

    def print_module(self, epoch, step, print_steps):
        """Print the latest recorded metrics every ``print_steps`` steps."""
        if (step + 1) % print_steps == 0:
            print("[{}/{}]\t [{}/{}]\t ".format(epoch, Settings.EPOCHS,
                                                step + 1,
                                                len(self.train_dataloader)),
                  end=" ")

            for key in self.recorder:
                print("[{}:{}]\t".format(key, self.recorder[key][-1]), end=" ")

            print(" ")

    def model_save_module(self, epoch, step):
        """Save generator and discriminator state dicts for this epoch/step."""
        torch.save(
            self.generator.state_dict(),
            os.path.join(
                self.output_root, Settings.OUTPUT_MODEL_KEY,
                "pix2pix_generator_epoch_{}_step_{}.pth".format(epoch, step)))
        torch.save(
            self.discriminator.state_dict(),
            os.path.join(
                self.output_root, Settings.OUTPUT_MODEL_KEY,
                "pix2pix_discriminator_epoch_{}_step_{}.pth".format(
                    epoch, step)))

    def log_save_module(self):
        """Rewrite ``log.txt`` and one plot per metric in the log directory."""
        # Save the records: one line per training step, overwriting the file.
        with open(
                os.path.join(self.output_root, Settings.OUTPUT_LOG_KEY,
                             "log.txt"), "w") as f:
            for item_ in range(len(self.recorder["accD"])):
                for key in self.recorder:
                    f.write("{}:{}\t".format(key, self.recorder[key][item_]))
                f.write("\n")

        # Save the charts
        for key in self.recorder:
            plt.figure(figsize=(10, 5))
            plt.title("{} During Training".format(key))
            plt.plot(self.recorder[key], label=key)
            plt.xlabel("iterations")
            plt.ylabel("value")
            plt.legend()
            if "acc" in key:
                # np.arange(0, 1, 0.5) yields ticks at 0 and 0.5.
                plt.yticks(np.arange(0, 1, 0.5))
            plt.savefig(
                os.path.join(self.output_root, Settings.OUTPUT_LOG_KEY,
                             "{}.jpg".format(key)))

        # Release every figure created above.
        plt.close("all")

    def learning_rate_decay_module(self, epoch):
        """Scale both optimizers' learning rates by 0.2 on decay epochs.

        The decay is applied when ``epoch`` is a multiple of
        ``Settings.LR_DECAY_EPOCHS``.
        """
        if epoch % Settings.LR_DECAY_EPOCHS == 0:
            for param_group in self.optimD.param_groups:
                param_group["lr"] *= 0.2
            for param_group in self.optimG.param_groups:
                param_group["lr"] *= 0.2