Ejemplo n.º 1
0
def getDiscriminatorModel(netDPath='', ngpu=1):
	"""Create a Discriminator on the global `device`, apply the standard
	weight initialisation, and optionally restore saved weights.

	:param netDPath: optional path to a saved state_dict; '' skips loading
	:param ngpu: number of GPUs forwarded to the Discriminator constructor
	:return: the constructed (and printed) Discriminator module
	"""
	model = Discriminator(ngpu).to(device)
	model.apply(weights_init)
	if netDPath != '':
		model.load_state_dict(torch.load(netDPath))
	print(model)
	return model
def one_im_discrim(discrim_path, im_path):
    """Score a single image with a saved Discriminator and print the result.

    :param discrim_path: path to the discriminator state_dict checkpoint
    :param im_path: path to the input image file
    """
    discriminator = Discriminator(3, 64)
    discriminator.load_state_dict(
        torch.load(discrim_path, map_location=torch.device('cpu')))
    discriminator.eval()

    # Bug fix: the original also called `torchImage.open(im_path)` — an
    # undefined name that raised NameError before the image was scored,
    # and whose result was never used. The image is opened once below.
    to_tensor = transforms.ToTensor()
    result = discriminator(to_tensor(Image.open(im_path))).view(-1)
    print(result.data.item())
Ejemplo n.º 3
0
def main(args):
    """Full handwriting-GAN pipeline: load data, pretrain (or restore) the
    Generator and Discriminator, then run adversarial training.

    :param args: parsed CLI namespace; mutated in place — c_dimension and U
        are derived from the data here, and batch_size may be capped at 16.
    """
    os.makedirs('models', exist_ok=True)
    os.makedirs('outputs', exist_ok=True)

    # -------------- dataset ----------------------------
    g_train_loader = IAMDataLoader(args.batch_size,
                                   args.T,
                                   args.data_scale,
                                   chars=args.chars,
                                   points_per_char=args.points_per_char)
    print('number of batches:', g_train_loader.num_batches)
    # +1 presumably reserves an extra index (unknown/blank char) — TODO confirm
    args.c_dimension = len(g_train_loader.chars) + 1

    args.U = g_train_loader.max_U

    # -------------- pretrain generator ----------------------------
    generator = Generator(num_gaussians=args.M,
                          mode=args.mode,
                          c_dimension=args.c_dimension,
                          K=args.K,
                          U=args.U,
                          batch_size=args.batch_size,
                          T=args.T,
                          bias=args.b,
                          sample_random=args.sample_random,
                          learning_rate=args.g_learning_rate).to(device)
    generator = generator.train()

    # reuse a saved generator checkpoint if provided, else pretrain from scratch
    if args.g_path and os.path.exists(args.g_path):
        print('Start loading generator: %s' % (args.g_path))
        generator.load_state_dict(torch.load(args.g_path))
    else:
        print('Start pre-training generator:')
        pre_g(generator, g_train_loader, num_epochs=40, mode=args.mode)

    # -------------- pretrain discriminator ----------------------------
    if args.batch_size > 16:  # Do not set batch_size too large
        # keep the generator and args in sync with the reduced batch size
        generator.batch_size = 16
        args.batch_size = 16
    discriminator = Discriminator(learning_rate=args.d_learning_rate,
                                  weight_decay=args.d_weight_decay).to(device)
    discriminator = discriminator.train()

    # likewise reuse a saved discriminator checkpoint when available
    if args.d_path and os.path.exists(args.d_path):
        print('Start loading discriminator: %s' % (args.d_path))
        discriminator.load_state_dict(torch.load(args.d_path))
    else:
        print('Start pre-training discriminator:')
        pre_d(discriminator, generator, g_train_loader, num_steps=200)

    # switch the generator to the adversarial-phase learning rate
    generator.set_learning_rate(args.ad_g_learning_rate)

    print('Start training discriminator:')
    ad_train(args, generator, discriminator, g_train_loader, num_steps=100)
def run_model(model_path, discrim_path):
    """Qualitative check of a trained Deblurrer: for every dataset pair, show
    the blurred input, the sharp target and the model output, each titled
    with the discriminator's scalar score for that image.

    :param model_path: path to a Deblurrer state_dict (mapped onto CPU)
    :param discrim_path: path to a Discriminator state_dict (mapped onto CPU)
    """
    model = Deblurrer()
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()

    discriminator = Discriminator(3, 64)
    discriminator.load_state_dict(
        torch.load(discrim_path, map_location=torch.device('cpu')))
    discriminator.eval()

    dataset = LFWC(["../data/train/faces_blurred"], "../data/train/faces")
    #dataset = FakeData(size=1000, image_size=(3, 128, 128), transform=transforms.ToTensor())
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=1,
                                              shuffle=True)
    # NOTE(review): Variable is a legacy no-op wrapper in modern torch; also
    # this loop visits the whole dataset and blocks on plt.show() per image.
    for data in data_loader:
        blurred_img = Variable(data['blurred'])
        nonblurred = Variable(data['nonblurred'])

        # Should be near zero
        discrim_output_blurred = discriminator(blurred_img).view(
            -1).data.item()
        # Should be near one
        discrim_output_nonblurred = discriminator(nonblurred).view(
            -1).data.item()

        #im = Image.open(image_path)
        #transform = transforms.ToTensor()
        transformback = transforms.ToPILImage()
        plt.imshow(transformback(blurred_img[0]))
        plt.title('Blurred, Discrim value: ' + str(discrim_output_blurred))
        plt.show()
        plt.imshow(transformback(nonblurred[0]))
        plt.title('Non Blurred, Discrim value: ' +
                  str(discrim_output_nonblurred))
        plt.show()

        out = model(blurred_img)
        discrim_output_model = discriminator(out).view(-1).data.item()
        #print(out.shape)
        outIm = transformback(out[0])

        plt.imshow(outIm)
        plt.title('Model out, Discrim value: ' + str(discrim_output_model))
        plt.show()
Ejemplo n.º 5
0
def main():
    """Build (or restore) the GAN pair plus an InceptionV3 FID scorer, then
    train on the omacir dataset."""
    num_epochs = 80
    num_output_imgs = 1  # NOTE(review): unused here — presumably consumed elsewhere
    discrim_save_path = './discrim_omacir'
    gen_save_path = './gen_omacir'
    to_load = False

    discriminator = Discriminator()
    generator = Generator()
    if to_load:
        # resume both networks from their last saved checkpoints
        print('loading saved GAN state...')
        discriminator.load_state_dict(torch.load(discrim_save_path))
        generator.load_state_dict(torch.load(gen_save_path))

    # InceptionV3 backbone used only for FID evaluation during training
    fidmodel = torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True).float()
    fidmodel.eval()
    train_gan(discriminator, generator, num_epochs, gen_save_path, discrim_save_path, fidmodel, is_omacir=True)
Ejemplo n.º 6
0
def load(filename, use_gpu):
    """
    Loads the trained models of the discriminator and the generator from a file.
    :param filename: The file name to load the networks from.
    :param use_gpu: Determines whether to load the model to GPU or CPU.
    """
    # Choose the storage-remapping hook up front rather than branching
    # inside the lambda: move tensors to GPU when requested, else keep on CPU.
    if use_gpu:
        remap = lambda storage, loc: storage.cuda()
    else:
        remap = lambda storage, loc: storage
    checkpoint = torch.load(filename, map_location=remap)

    # Rebuild each network from its saved constructor args, then restore weights.
    discriminator = Discriminator(*checkpoint['D_init_params'])
    discriminator.load_state_dict(checkpoint['D_state'])

    generator = Generator(*checkpoint['G_init_params'])
    generator.load_state_dict(checkpoint['G_state'])

    print('Checkpoint loaded.')

    return discriminator, generator
Ejemplo n.º 7
0
def _load_component(path, prefix, cls, device):
    """Load one cycle-GAN component named `prefix` from `path`: read its
    ``<prefix>_config.json``, construct `cls`, restore ``<prefix>.pt`` and
    move the model to `device`."""
    with open(path / "{}_config.json".format(prefix), "r") as file:
        component_config = json.load(file)
    model = cls(component_config, device)
    model.load_state_dict(
        load(path / "{}.pt".format(prefix), map_location=device))
    model.to(device)
    return model


def load_cyclegan_alignment(name, device):
    """
    Load trained models for entity alignment with a cycle GAN architecture.
    :param name: the name of the model directory
    :param device: the current torch device, used for transferring the saved models (which were possibly trained on a
    different device) to the correct device
    :return: the generator and discriminator models (subclasses of torch.nn.Module) and the training configurations
    for entity alignment with a cycle gan architecture
    """
    path = Path(MODEL_PATH) / name
    with open(path / "config.json", "r") as file:
        config = json.load(file)
    # The four components follow an identical config + state_dict layout,
    # so each is loaded through the shared helper above.
    generator_a = _load_component(path, "generator_a", Generator, device)  # B->A
    generator_b = _load_component(path, "generator_b", Generator, device)  # A->B
    discriminator_a = _load_component(path, "discriminator_a", Discriminator, device)
    discriminator_b = _load_component(path, "discriminator_b", Discriminator, device)
    return generator_a, generator_b, discriminator_a, discriminator_b, config
Ejemplo n.º 8
0
class BigGAN():
    """Big GAN

    Trainer wrapper: builds a class-conditional Generator/Discriminator
    pair and runs hinge-loss adversarial training with annealed instance
    noise, periodic logging, image sampling and checkpointing.
    """
    def __init__(self, device, dataloader, num_classes, configs):
        """Cache hyperparameters from ``configs``, build models/optimizers,
        and optionally resume from a pretrained checkpoint.

        :param device: torch device all models and tensors are placed on
        :param dataloader: iterable yielding (images, labels) batches
        :param num_classes: number of conditioning classes
        :param configs: object exposing all hyperparameters as attributes
        """
        self.device = device
        self.dataloader = dataloader
        self.num_classes = num_classes

        # model settings & hyperparams
        # self.total_steps = configs.total_steps
        self.epochs = configs.epochs
        self.d_iters = configs.d_iters
        self.g_iters = configs.g_iters
        self.batch_size = configs.batch_size
        self.imsize = configs.imsize
        self.nz = configs.nz
        self.ngf = configs.ngf
        self.ndf = configs.ndf
        self.g_lr = configs.g_lr
        self.d_lr = configs.d_lr
        self.beta1 = configs.beta1
        self.beta2 = configs.beta2

        # instance noise (sigma annealed to 0 over inst_noise_sigma_iters steps)
        self.inst_noise_sigma = configs.inst_noise_sigma
        self.inst_noise_sigma_iters = configs.inst_noise_sigma_iters

        # model logging and saving
        self.log_step = configs.log_step
        self.save_epoch = configs.save_epoch
        self.model_path = configs.model_path
        self.sample_path = configs.sample_path

        # pretrained: checkpoint epoch number to resume from (falsy = fresh run)
        self.pretrained_model = configs.pretrained_model

        # building
        self.build_model()

        # archive of all losses (one entry appended per epoch)
        self.ave_d_losses = []
        self.ave_d_losses_real = []
        self.ave_d_losses_fake = []
        self.ave_g_losses = []

        if self.pretrained_model:
            self.load_pretrained()

    def build_model(self):
        """Initiate Generator and Discriminator

        Creates both networks on self.device and one Adam optimizer per
        network (only over parameters that require grad), then prints a
        summary of each model.
        """
        self.G = Generator(self.nz, self.ngf, self.num_classes).to(self.device)
        self.D = Discriminator(self.ndf, self.num_classes).to(self.device)

        self.g_optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        print("Generator Parameters: ", parameters(self.G))
        print(self.G)
        print("Discriminator Parameters: ", parameters(self.D))
        print(self.D)
        print("Number of classes: ", self.num_classes)

    def load_pretrained(self):
        """Loading pretrained model

        Restores model weights, optimizer state and the per-epoch loss
        history from "<pretrained_model>_biggan.pth" in model_path.
        """
        checkpoint = torch.load(
            os.path.join(self.model_path,
                         "{}_biggan.pth".format(self.pretrained_model)))

        # load models
        self.G.load_state_dict(checkpoint["g_state_dict"])
        self.D.load_state_dict(checkpoint["d_state_dict"])

        # load optimizers
        self.g_optimizer.load_state_dict(checkpoint["g_optimizer"])
        self.d_optimizer.load_state_dict(checkpoint["d_optimizer"])

        # load losses
        self.ave_d_losses = checkpoint["ave_d_losses"]
        self.ave_d_losses_real = checkpoint["ave_d_losses_real"]
        self.ave_d_losses_fake = checkpoint["ave_d_losses_fake"]
        self.ave_g_losses = checkpoint["ave_g_losses"]

        print("Loading pretrained models (epoch: {})..!".format(
            self.pretrained_model))

    def reset_grad(self):
        """Reset gradients on both optimizers before each D/G update."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def train(self):
        """Train model

        Alternates d_iters discriminator updates and g_iters generator
        updates per batch using hinge losses, with instance noise added to
        both real and fake images. Logs every log_step steps, samples a
        fixed-noise image grid every epoch, and checkpoints every
        save_epoch epochs.
        """
        step_per_epoch = len(self.dataloader)
        epochs = self.epochs
        total_steps = epochs * step_per_epoch

        # fixed z and labels for sampling generator images
        # NOTE(review): np.tile yields batch_size*num_classes labels while
        # fixed_z has batch_size rows — verify G accepts this size mismatch.
        fixed_z = tensor2var(torch.randn(self.batch_size, self.nz),
                             device=self.device)
        fixed_labels = tensor2var(torch.from_numpy(
            np.tile(np.arange(self.num_classes), self.batch_size)).long(),
                                  device=self.device)

        print("Initiating Training")
        print("Epochs: {}, Total Steps: {}, Steps/Epoch: {}".format(
            epochs, total_steps, step_per_epoch))

        # resume epoch counting from the loaded checkpoint, if any
        if self.pretrained_model:
            start_epoch = self.pretrained_model
        else:
            start_epoch = 0

        self.D.train()
        self.G.train()

        # Instance noise - make random noise mean (0) and std for injecting
        # (the std buffer is refilled in-place with the annealed sigma each step)
        inst_noise_mean = torch.full(
            (self.batch_size, 3, self.imsize, self.imsize), 0).to(self.device)
        inst_noise_std = torch.full(
            (self.batch_size, 3, self.imsize, self.imsize),
            self.inst_noise_sigma).to(self.device)

        # total time
        start_time = time.time()
        for epoch in range(start_epoch, epochs):
            # local losses (only populated on log steps)
            d_losses = []
            d_losses_real = []
            d_losses_fake = []
            g_losses = []

            data_iter = iter(self.dataloader)
            for step in range(step_per_epoch):
                # Instance noise std is linearly annealed from self.inst_noise_sigma to 0 thru self.inst_noise_sigma_iters
                inst_noise_sigma_curr = 0 if step > self.inst_noise_sigma_iters else (
                    1 -
                    step / self.inst_noise_sigma_iters) * self.inst_noise_sigma
                inst_noise_std.fill_(inst_noise_sigma_curr)

                # get real images
                real_images, real_labels = next(data_iter)
                real_images = real_images.to(self.device)
                real_labels = real_labels.to(self.device)

                # ================== TRAIN DISCRIMINATOR ================== #

                for _ in range(self.d_iters):
                    self.reset_grad()

                    # TRAIN REAL

                    # creating instance noise
                    inst_noise = torch.normal(mean=inst_noise_mean,
                                              std=inst_noise_std).to(
                                                  self.device)
                    # adding noise to real images
                    d_real = self.D(real_images + inst_noise, real_labels)
                    d_loss_real = loss_hinge_dis_real(d_real)
                    d_loss_real.backward()

                    # delete loss (freed early unless kept for the log step below)
                    if (step + 1) % self.log_step != 0:
                        del d_real, d_loss_real

                    # TRAIN FAKE

                    # create fake images using latent vector
                    z = tensor2var(torch.randn(real_images.size(0), self.nz),
                                   device=self.device)
                    fake_images = self.G(z, real_labels)

                    # creating instance noise
                    inst_noise = torch.normal(mean=inst_noise_mean,
                                              std=inst_noise_std).to(
                                                  self.device)
                    # adding noise to fake images
                    # detach fake_images tensor from graph
                    d_fake = self.D(fake_images.detach() + inst_noise,
                                    real_labels)
                    d_loss_fake = loss_hinge_dis_fake(d_fake)
                    d_loss_fake.backward()

                    # delete loss, output
                    del fake_images
                    if (step + 1) % self.log_step != 0:
                        del d_fake, d_loss_fake

                # optimize D (gradients from all d_iters inner passes accumulate)
                self.d_optimizer.step()

                # ================== TRAIN GENERATOR ================== #

                for _ in range(self.g_iters):
                    self.reset_grad()

                    # create new latent vector
                    z = tensor2var(torch.randn(real_images.size(0), self.nz),
                                   device=self.device)

                    # generate fake images
                    inst_noise = torch.normal(mean=inst_noise_mean,
                                              std=inst_noise_std).to(
                                                  self.device)
                    fake_images = self.G(z, real_labels)
                    g_fake = self.D(fake_images + inst_noise, real_labels)

                    # compute hinge loss for G
                    g_loss = loss_hinge_gen(g_fake)
                    g_loss.backward()

                    del fake_images
                    if (step + 1) % self.log_step != 0:
                        del g_fake, g_loss

                # optimize G
                self.g_optimizer.step()

                # logging step progression
                # (loss values come from the last d_iters/g_iters inner pass)
                if (step + 1) % self.log_step == 0:
                    d_loss = d_loss_real + d_loss_fake

                    # logging losses
                    d_losses.append(d_loss.item())
                    d_losses_real.append(d_loss_real.item())
                    d_losses_fake.append(d_loss_fake.item())
                    g_losses.append(g_loss.item())

                    # print out
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    print(
                        "Elapsed [{}], Epoch: [{}/{}], Step [{}/{}], g_loss: {:.4f}, d_loss: {:.4f},"
                        " d_loss_real: {:.4f}, d_loss_fake: {:.4f}".format(
                            elapsed, (epoch + 1), epochs, (step + 1),
                            step_per_epoch, g_loss, d_loss, d_loss_real,
                            d_loss_fake))

                    del d_real, d_loss_real, d_fake, d_loss_fake, g_fake, g_loss

            # logging average losses over epoch
            # NOTE(review): if log_step never divides any (step + 1) this epoch,
            # the loss lists are empty and `elapsed` below is unbound — confirm
            # log_step <= step_per_epoch.
            self.ave_d_losses.append(mean(d_losses))
            self.ave_d_losses_real.append(mean(d_losses_real))
            self.ave_d_losses_fake.append(mean(d_losses_fake))
            self.ave_g_losses.append(mean(g_losses))

            # epoch update
            print(
                "Elapsed [{}], Epoch: [{}/{}], ave_g_loss: {:.4f}, ave_d_loss: {:.4f},"
                " ave_d_loss_real: {:.4f}, ave_d_loss_fake: {:.4f},".format(
                    elapsed, epoch + 1, epochs, self.ave_g_losses[epoch],
                    self.ave_d_losses[epoch], self.ave_d_losses_real[epoch],
                    self.ave_d_losses_fake[epoch]))

            # sample images every epoch
            fake_images = self.G(fixed_z, fixed_labels)
            fake_images = denorm(fake_images.data)
            save_image(
                fake_images,
                os.path.join(self.sample_path,
                             "Epoch {}.png".format(epoch + 1)))

            # save model
            if (epoch + 1) % self.save_epoch == 0:
                torch.save(
                    {
                        "g_state_dict": self.G.state_dict(),
                        "d_state_dict": self.D.state_dict(),
                        "g_optimizer": self.g_optimizer.state_dict(),
                        "d_optimizer": self.d_optimizer.state_dict(),
                        "ave_d_losses": self.ave_d_losses,
                        "ave_d_losses_real": self.ave_d_losses_real,
                        "ave_d_losses_fake": self.ave_d_losses_fake,
                        "ave_g_losses": self.ave_g_losses
                    },
                    os.path.join(self.model_path,
                                 "{}_biggan.pth".format(epoch + 1)))

                print("Saving models (epoch {})..!".format(epoch + 1))

    def plot(self):
        """Plot the per-epoch averaged D/G loss curves on one figure."""
        plt.plot(self.ave_d_losses)
        plt.plot(self.ave_d_losses_real)
        plt.plot(self.ave_d_losses_fake)
        plt.plot(self.ave_g_losses)
        plt.legend(["d loss", "d real", "d fake", "g loss"], loc="upper left")
        plt.show()
def _main():
    """Train a feature-conditioned WGAN: a Generator conditioned on
    classifier feature maps (masked to a randomly chosen layer) against a
    Discriminator, with tensorboard logging, periodic model checkpoints
    and sample image grids.

    Reads all hyperparameters from the module-level ``args`` namespace.
    """
    print_gpu_details()
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
    train_root = args.train_path

    image_size = 256
    cropped_image_size = 256
    print("set image folder")
    train_set = dset.ImageFolder(root=train_root,
                                 transform=transforms.Compose([
                                     transforms.Resize(image_size),
                                     transforms.CenterCrop(cropped_image_size),
                                     transforms.ToTensor()
                                 ]))

    # the classifier gets ImageNet-style normalization, the discriminator [-1, 1]
    normalizer_clf = transforms.Compose([
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    normalizer_discriminator = transforms.Compose([
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    print('set data loader')
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

    # Network creation
    classifier = torch.load(args.classifier_path)
    classifier.eval()
    generator = Generator(gen_type=args.gen_type)
    discriminator = Discriminator(args.discriminator_norm, dis_type=args.gen_type)
    # init weights, or resume from checkpoints when paths are supplied
    if args.generator_path is not None:
        generator.load_state_dict(torch.load(args.generator_path))
    else:
        generator.init_weights()
    if args.discriminator_path is not None:
        discriminator.load_state_dict(torch.load(args.discriminator_path))
    else:
        discriminator.init_weights()

    classifier.to(device)
    generator.to(device)
    discriminator.to(device)

    # losses + optimizers
    criterion_discriminator, criterion_generator = get_wgan_losses_fn()
    criterion_features = nn.L1Loss()
    criterion_diversity_n = nn.L1Loss()
    criterion_diversity_d = nn.L1Loss()
    generator_optimizer = optim.Adam(generator.parameters(), lr=args.lr, betas=(0.5, 0.999))
    discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999))

    num_of_epochs = args.epochs

    starting_time = time.time()
    iterations = 0
    # creating dirs for keeping models checkpoint, temp created images, and loss summary
    outputs_dir = os.path.join('wgan-gp_models', args.model_name)
    if not os.path.isdir(outputs_dir):
        os.makedirs(outputs_dir, exist_ok=True)
    temp_results_dir = os.path.join(outputs_dir, 'temp_results')
    if not os.path.isdir(temp_results_dir):
        os.mkdir(temp_results_dir)
    models_dir = os.path.join(outputs_dir, 'models_checkpoint')
    if not os.path.isdir(models_dir):
        os.mkdir(models_dir)
    writer = tensorboardX.SummaryWriter(os.path.join(outputs_dir, 'summaries'))

    z = torch.randn(args.batch_size, 128, 1, 1).to(device)  # a fixed noise for sampling
    z2 = torch.randn(args.batch_size, 128, 1, 1).to(device)  # a fixed noise for diversity sampling
    fixed_features = 0
    fixed_masks = 0
    fixed_features_diversity = 0
    first_iter = True
    print("Starting Training Loop...")
    for epoch in range(num_of_epochs):
        for data in train_loader:
            # Bug fix: random.choices returns a *list* (e.g. [2]); the original
            # kept the list, so `train_type == 1` below was always False and
            # train type 2 was silently used every iteration. Take the single
            # drawn element instead.
            train_type = random.choices([1, 2], [args.train1_prob, 1-args.train1_prob])[0]  # choose train type
            iterations += 1
            if iterations % 30 == 1:
                print('epoch:', epoch, ', iter', iterations, 'start, time =', time.time() - starting_time, 'seconds')
                starting_time = time.time()
            images, _ = data
            images = images.to(device)  # change to gpu tensor
            images_discriminator = normalizer_discriminator(images)
            images_clf = normalizer_clf(images)
            _, features = classifier(images_clf)
            if first_iter: # save batch of images to keep track of the model process
                first_iter = False
                fixed_features = [torch.clone(features[x]) for x in range(len(features))]
                fixed_masks = [torch.ones(features[x].shape, device=device) for x in range(len(features))]
                # diversity set: repeat the first 8 feature maps across the batch
                fixed_features_diversity = [torch.clone(features[x]) for x in range(len(features))]
                for i in range(len(features)):
                    for j in range(fixed_features_diversity[i].shape[0]):
                        fixed_features_diversity[i][j] = fixed_features_diversity[i][j % 8]
                grid = vutils.make_grid(images_discriminator, padding=2, normalize=True, nrow=8)
                vutils.save_image(grid, os.path.join(temp_results_dir, 'original_images.jpg'))
                orig_images_diversity = torch.clone(images_discriminator)
                for i in range(orig_images_diversity.shape[0]):
                    orig_images_diversity[i] = orig_images_diversity[i % 8]
                grid = vutils.make_grid(orig_images_diversity, padding=2, normalize=True, nrow=8)
                vutils.save_image(grid, os.path.join(temp_results_dir, 'original_images_diversity.jpg'))
            # Select a features layer to train on
            features_to_train = random.randint(1, len(features) - 2) if args.fixed_layer is None else args.fixed_layer
            # Set masks according to the drawn training type
            masks = [features[i].clone() for i in range(len(features))]
            if train_type == 1:
                setMasksPart1(masks, device, features_to_train)
            else:
                setMasksPart2(masks, device, features_to_train)
            discriminator_loss_dict = train_discriminator(generator, discriminator, criterion_discriminator, discriminator_optimizer, images_discriminator, features, masks)
            for k, v in discriminator_loss_dict.items():
                writer.add_scalar('D/%s' % k, v.data.cpu().numpy(), global_step=iterations)
                if iterations % 30 == 1:
                    print('{}: {:.6f}'.format(k, v))
            # train G only once every args.discriminator_steps D updates
            if iterations % args.discriminator_steps == 1:
                generator_loss_dict = train_generator(generator, discriminator, criterion_generator, generator_optimizer, images.shape[0], features,
                                                      criterion_features, features_to_train, classifier, normalizer_clf, criterion_diversity_n,
                                                      criterion_diversity_d, masks, train_type)

                for k, v in generator_loss_dict.items():
                    writer.add_scalar('G/%s' % k, v.data.cpu().numpy(), global_step=iterations//5 + 1)
                    if iterations % 30 == 1:
                        print('{}: {:.6f}'.format(k, v))

            # Save generator and discriminator weights every 1000 iterations
            if iterations % 1000 == 1:
                torch.save(generator.state_dict(), models_dir + '/' + args.model_name + 'G')
                torch.save(discriminator.state_dict(), models_dir + '/' + args.model_name + 'D')
            # Save temp results
            if args.keep_temp_results:
                if iterations < 10000 and iterations % 1000 == 1 or iterations % 2000 == 1:
                    # regular sampling (batch of different images)
                    first_features = True
                    fake_images = None
                    fake_images_diversity = None
                    for i in range(1, 5):
                        one_layer_mask = isolate_layer(fixed_masks, i, device)
                        if first_features:
                            first_features = False
                            # NOTE(review): the first diversity sample uses z while
                            # later layers use z2 — confirm this asymmetry is intended.
                            fake_images = sample(generator, z, fixed_features, one_layer_mask)
                            fake_images_diversity = sample(generator, z, fixed_features_diversity, one_layer_mask)
                        else:
                            tmp_fake_images = sample(generator, z, fixed_features, one_layer_mask)
                            fake_images = torch.vstack((fake_images, tmp_fake_images))
                            tmp_fake_images = sample(generator, z2, fixed_features_diversity, one_layer_mask)
                            fake_images_diversity = torch.vstack((fake_images_diversity, tmp_fake_images))
                    grid = vutils.make_grid(fake_images, padding=2, normalize=True, nrow=8)
                    vutils.save_image(grid, os.path.join(temp_results_dir, 'res_iter_{}.jpg'.format(iterations // 1000)))
                    # diversity sampling (8 different images each with few different noises)
                    grid = vutils.make_grid(fake_images_diversity, padding=2, normalize=True, nrow=8)
                    vutils.save_image(grid, os.path.join(temp_results_dir, 'div_iter_{}.jpg'.format(iterations // 1000)))

                if iterations % 20000 == 1:
                    torch.save(generator.state_dict(), models_dir + '/' + args.model_name + 'G_' + str(iterations // 15000))
                    torch.save(discriminator.state_dict(), models_dir + '/' + args.model_name + 'D_' + str(iterations // 15000))
Ejemplo n.º 10
0
class CycleGAN(AlignmentModel):
    """Alignment model built from two generators and two discriminators (cycle GAN).

    ``generator_a`` maps vectors from space B (``dim_b``) into space A (``dim_a``)
    and ``generator_b`` maps the opposite way; ``discriminator_a`` and
    ``discriminator_b`` judge samples in spaces A and B respectively.  For a
    description of the implemented functions, refer to the alignment model."""

    def __init__(self,
                 device,
                 config,
                 generator_a=None,
                 generator_b=None,
                 discriminator_a=None,
                 discriminator_b=None):
        """Initialize two new generators and two discriminators from the config or use
        pre-trained ones and create an optimizer for each of the four models.

        Args:
            device: torch device newly created models are moved to.
            config: dict of hyper-parameters (dimensions, layer counts, learning
                rate, optional 'optimizer' / 'optimizer_default' selection, ...).
            generator_a: optional pre-trained generator B -> A; built from config when None.
            generator_b: optional pre-trained generator A -> B; built from config when None.
            discriminator_a: optional pre-trained discriminator for space A.
            discriminator_b: optional pre-trained discriminator for space B.
        """
        super().__init__(device, config)
        # Per-epoch loss accumulators: [gen_a, gen_b, disc_a, disc_b].
        self.epoch_losses = [0., 0., 0., 0.]

        # Generator A consumes vectors of dim_b and produces vectors of dim_a.
        if generator_a is None:
            generator_a = self._build_generator(config['dim_b'], config['dim_a'],
                                                config, device)
        self.generator_a = generator_a
        self.optimizer_g_a = self._build_optimizer(self.generator_a.parameters(), config)

        # Generator B consumes vectors of dim_a and produces vectors of dim_b.
        if generator_b is None:
            generator_b = self._build_generator(config['dim_a'], config['dim_b'],
                                                config, device)
        self.generator_b = generator_b
        self.optimizer_g_b = self._build_optimizer(self.generator_b.parameters(), config)

        if discriminator_a is None:
            discriminator_a = self._build_discriminator(config['dim_a'], config, device)
        self.discriminator_a = discriminator_a
        self.optimizer_d_a = self._build_optimizer(self.discriminator_a.parameters(), config)

        if discriminator_b is None:
            discriminator_b = self._build_discriminator(config['dim_b'], config, device)
        self.discriminator_b = discriminator_b
        self.optimizer_d_b = self._build_optimizer(self.discriminator_b.parameters(), config)

    @staticmethod
    def _build_generator(dim_1, dim_2, config, device):
        """Create a new generator mapping dim_1 -> dim_2 from the config and move it to device."""
        generator_conf = dict(
            dim_1=dim_1,
            dim_2=dim_2,
            layer_number=config['generator_layers'],
            layer_expansion=config['generator_expansion'],
            initialize_generator=config['initialize_generator'],
            norm=config['gen_norm'],
            batch_norm=config['gen_batch_norm'],
            activation=config['gen_activation'],
            dropout=config['gen_dropout'])
        generator = Generator(generator_conf, device)
        generator.to(device)
        return generator

    @staticmethod
    def _build_discriminator(dim, config, device):
        """Create a new discriminator for vectors of the given dim and move it to device."""
        discriminator_conf = dict(
            dim=dim,
            layer_number=config['discriminator_layers'],
            layer_expansion=config['discriminator_expansion'],
            batch_norm=config['disc_batch_norm'],
            activation=config['disc_activation'],
            dropout=config['disc_dropout'])
        discriminator = Discriminator(discriminator_conf, device)
        discriminator.to(device)
        return discriminator

    @staticmethod
    def _build_optimizer(parameters, config):
        """Pick the optimizer requested by the config for the given parameters.

        Resolution order: an explicit 'optimizer' entry, then 'optimizer_default'
        (SGD additionally receives the configured learning rate, any other default
        optimizer uses its library-default rate), and finally Adam with the
        configured learning rate.
        """
        if 'optimizer' in config:
            return OPTIMIZERS[config['optimizer']](parameters, config['learning_rate'])
        if 'optimizer_default' in config:
            if config['optimizer_default'] == 'sgd':
                return OPTIMIZERS[config['optimizer_default']](parameters,
                                                               config['learning_rate'])
            return OPTIMIZERS[config['optimizer_default']](parameters)
        return torch.optim.Adam(parameters, config['learning_rate'])

    def train(self):
        """Switch all four sub-models to training mode."""
        self.generator_a.train()
        self.generator_b.train()
        self.discriminator_a.train()
        self.discriminator_b.train()

    def eval(self):
        """Switch all four sub-models to evaluation mode."""
        self.generator_a.eval()
        self.generator_b.eval()
        self.discriminator_a.eval()
        self.discriminator_b.eval()

    def zero_grad(self):
        """Clear the gradients of all four optimizers."""
        self.optimizer_g_a.zero_grad()
        self.optimizer_g_b.zero_grad()
        self.optimizer_d_a.zero_grad()
        self.optimizer_d_b.zero_grad()

    def optimize_all(self):
        """Do an optimization step for all four models."""
        self.optimizer_g_a.step()
        self.optimizer_g_b.step()
        self.optimizer_d_a.step()
        self.optimizer_d_b.step()

    def optimize_generator(self):
        """Do the optimization step only for generators (e.g. when training generators and
        discriminators separately or in turns)."""
        self.optimizer_g_a.step()
        self.optimizer_g_b.step()

    def optimize_discriminator(self):
        """Do the optimization step only for discriminators (e.g. when training generators and
        discriminators separately or in turns)."""
        self.optimizer_d_a.step()
        self.optimizer_d_b.step()

    def change_lr(self, factor):
        """Scale the current learning rate by factor and apply it to both generator optimizers.

        NOTE(review): the discriminator optimizers keep their previous learning
        rate — confirm this asymmetry is intended.
        """
        self.current_lr = self.current_lr * factor
        for param_group in self.optimizer_g_a.param_groups:
            param_group['lr'] = self.current_lr
        for param_group in self.optimizer_g_b.param_groups:
            param_group['lr'] = self.current_lr

    def update_losses_batch(self, *losses):
        """Accumulate a batch's four losses (gen_a, gen_b, disc_a, disc_b) into the epoch sums."""
        loss_g_a, loss_g_b, loss_d_a, loss_d_b = losses
        self.epoch_losses[0] += loss_g_a
        self.epoch_losses[1] += loss_g_b
        self.epoch_losses[2] += loss_d_a
        self.epoch_losses[3] += loss_d_b

    def complete_epoch(self, epoch_metrics):
        """Record the epoch's metrics (plus summed loss) and reset the loss accumulators."""
        self.metrics.append(epoch_metrics + [sum(self.epoch_losses)])
        self.losses.append(self.epoch_losses)
        # Rebind (don't mutate) so the list stored in self.losses stays intact.
        self.epoch_losses = [0., 0., 0., 0.]

    def print_epoch_info(self):
        """Print the epoch number, the four latest losses and the latest metrics row."""
        print(
            f"{len(self.metrics)} ### {self.losses[-1][0]:.2f} - {self.losses[-1][1]:.2f} "
            f"- {self.losses[-1][2]:.2f} - {self.losses[-1][3]:.2f} ### {self.metrics[-1]}"
        )

    def copy_model(self):
        """Snapshot deep copies of all four state dicts so they can be restored later."""
        self.model_copy = (deepcopy(self.generator_a.state_dict()),
                           deepcopy(self.generator_b.state_dict()),
                           deepcopy(self.discriminator_a.state_dict()),
                           deepcopy(self.discriminator_b.state_dict()))

    def restore_model(self):
        """Load the state dicts saved by copy_model back into the four models."""
        self.generator_a.load_state_dict(self.model_copy[0])
        self.generator_b.load_state_dict(self.model_copy[1])
        self.discriminator_a.load_state_dict(self.model_copy[2])
        self.discriminator_b.load_state_dict(self.model_copy[3])

    def export_model(self, test_results, description=None):
        """Export the models, config and metrics plus test results under the description."""
        if description is None:
            description = f"CycleGAN_{self.config['evaluation']}_{self.config['subset']}"
        export_cyclegan_alignment(description, self.config, self.generator_a,
                                  self.generator_b, self.discriminator_a,
                                  self.discriminator_b, self.metrics)
        save_alignment_test_results(test_results, description)
        print(f"Saved model to directory {description}.")

    @classmethod
    def load_model(cls, name, device):
        """Load a previously exported cycle GAN alignment model by name."""
        generator_a, generator_b, discriminator_a, discriminator_b, config = load_cyclegan_alignment(
            name, device)
        model = cls(device, config, generator_a, generator_b, discriminator_a,
                    discriminator_b)
        return model
Ejemplo n.º 11
0
class GAN:
    """Adversarial training harness for an image-captioning generator.

    The discriminator scores (image features, caption) pairs; the generator is
    trained with a mix of teacher-forced sequence loss and a REINFORCE loss whose
    reward combines the discriminator's probability with a CIDEr score.
    """

    def __init__(self, device, args):
        """Build models, losses, optimizers, evaluator and data loaders.

        Args:
            device: torch device for models and batches.
            args: namespace with hyper-parameters, checkpoint paths and the
                load_generator / load_discriminator flags.
        """
        self.device = device
        self.args = args
        self.batch_size = args.batch_size
        self.generator_checkpoint_path = os.path.join(args.checkpoint_path, 'generator.pth')
        self.discriminator_checkpoint_path = os.path.join(args.checkpoint_path, 'discriminator.pth')
        # makedirs(exist_ok=True) replaces the racy isdir()/mkdir() pair and also
        # creates any missing parent directories.
        os.makedirs(args.checkpoint_path, exist_ok=True)
        self.generator = Generator(args).to(self.device)
        self.discriminator = Discriminator(args).to(self.device)
        self.sequence_loss = SequenceLoss()
        self.reinforce_loss = ReinforceLoss()
        self.generator_optimizer = optim.Adam(self.generator.parameters(), lr=args.generator_lr)
        self.discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=args.discriminator_lr)
        self.evaluator = Evaluator('val', self.device, args)
        self.cider = Cider(args)
        generator_dataset = CaptionDataset('train', args)
        self.generator_loader = DataLoader(generator_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)
        discriminator_dataset = DiscCaption('train', args)
        self.discriminator_loader = DataLoader(discriminator_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def train(self):
        """Load or pretrain both models, then run joint adversarial training."""
        if self.args.load_generator:
            self.generator.load_state_dict(torch.load(self.generator_checkpoint_path))
        else:
            self._pretrain_generator()
        if self.args.load_discriminator:
            self.discriminator.load_state_dict(torch.load(self.discriminator_checkpoint_path))
        else:
            self._pretrain_discriminator()
        self._train_gan()

    def _pretrain_generator(self):
        """Pretrain the generator with teacher-forced MLE; evaluate and checkpoint each epoch."""
        step = 0  # renamed from 'iter' to avoid shadowing the builtin
        for epoch in range(self.args.pretrain_generator_epochs):
            self.generator.train()
            for data in self.generator_loader:
                for name, item in data.items():
                    data[name] = item.to(self.device)
                self.generator.zero_grad()
                probs = self.generator(data['fc_feats'], data['att_feats'], data['att_masks'], data['labels'])
                loss = self.sequence_loss(probs, data['labels'])
                loss.backward()
                self.generator_optimizer.step()
                print('iter {}, epoch {}, generator loss {:.3f}'.format(step, epoch, loss.item()))
                step += 1
            self.evaluator.evaluate_generator(self.generator)
            torch.save(self.generator.state_dict(), self.generator_checkpoint_path)

    def _pretrain_discriminator(self):
        """Pretrain the discriminator on real/wrong/generated captions; checkpoint each epoch."""
        step = 0  # renamed from 'iter' to avoid shadowing the builtin
        for epoch in range(self.args.pretrain_discriminator_epochs):
            self.discriminator.train()
            for data in self.discriminator_loader:
                loss = self._train_discriminator(data)
                print('iter {}, epoch {}, discriminator loss {:.3f}'.format(step, epoch, loss))
                step += 1
            self.evaluator.evaluate_discriminator(generator=self.generator, discriminator=self.discriminator)
            torch.save(self.discriminator.state_dict(), self.discriminator_checkpoint_path)

    def _train_gan(self):
        """Alternate generator and discriminator updates, recycling each loader when exhausted.

        The inner range(1) loops set the generator:discriminator update ratio
        (currently 1:1); evaluation and checkpointing happen every 10000 iters.
        """
        generator_iter = iter(self.generator_loader)
        discriminator_iter = iter(self.discriminator_loader)
        for i in range(self.args.train_gan_iters):
            print('iter {}'.format(i))
            for j in range(1):
                try:
                    data = next(generator_iter)
                except StopIteration:
                    # Loader exhausted: restart it for the next epoch of batches.
                    generator_iter = iter(self.generator_loader)
                    data = next(generator_iter)
                result = self._train_generator(data)
                print('generator loss {:.3f}, fake prob {:.3f}, cider score {:.3f}'.format(result['loss'], result['fake_prob'], result['cider_score']))
            for j in range(1):
                try:
                    data = next(discriminator_iter)
                except StopIteration:
                    discriminator_iter = iter(self.discriminator_loader)
                    data = next(discriminator_iter)
                loss = self._train_discriminator(data)
                print('discriminator loss {:.3f}'.format(loss))
            if i != 0 and i % 10000 == 0:
                self.evaluator.evaluate_generator(self.generator)
                torch.save(self.generator.state_dict(), self.generator_checkpoint_path)
                self.evaluator.evaluate_discriminator(generator=self.generator, discriminator=self.discriminator)
                torch.save(self.discriminator.state_dict(), self.discriminator_checkpoint_path)

    def _train_generator(self, data):
        """One generator step: MLE loss plus self-critical REINFORCE loss.

        The greedy decode serves as the reward baseline for the sampled captions.
        Returns a dict with the MLE loss, mean discriminator prob of the samples
        and their mean CIDEr score.
        """
        self.generator.train()
        for name, item in data.items():
            data[name] = item.to(self.device)
        self.generator.zero_grad()

        # Teacher-forced maximum-likelihood loss on ground-truth captions.
        probs = self.generator(data['fc_feats'], data['att_feats'], data['att_masks'], data['labels'])
        loss1 = self.sequence_loss(probs, data['labels'])

        # Sampled captions get rewarded; the greedy decode is the baseline.
        seqs, probs = self.generator.sample(data['fc_feats'], data['att_feats'], data['att_masks'])
        greedy_seqs = self.generator.greedy_decode(data['fc_feats'], data['att_feats'], data['att_masks'])
        reward, fake_prob, score = self._get_reward(data, seqs)
        baseline, _, _ = self._get_reward(data, greedy_seqs)
        loss2 = self.reinforce_loss(reward, baseline, probs, seqs)

        loss = loss1 + loss2
        loss.backward()
        self.generator_optimizer.step()
        result = {
            'loss': loss1.item(),  # reported loss is the MLE component only
            'fake_prob': fake_prob,
            'cider_score': score
        }
        return result

    def _train_discriminator(self, data):
        """One discriminator step on real, mismatched ('wrong') and generated captions.

        Cross-entropy weighting: 0.5 on real, 0.25 each on the two negative
        sources; 1e-10 guards the logs against exact 0/1 probabilities.
        Returns the scalar loss value.
        """
        self.discriminator.train()
        for name, item in data.items():
            data[name] = item.to(self.device)
        self.discriminator.zero_grad()

        real_probs = self.discriminator(data['fc_feats'], data['att_feats'], data['att_masks'], data['labels'])
        wrong_probs = self.discriminator(data['fc_feats'], data['att_feats'], data['att_masks'], data['wrong_labels'])

        # generate fake data (no generator gradients needed here)
        with torch.no_grad():
            fake_seqs, _ = self.generator.sample(data['fc_feats'], data['att_feats'], data['att_masks'])
        fake_probs = self.discriminator(data['fc_feats'], data['att_feats'], data['att_masks'], fake_seqs)

        loss = -(0.5 * torch.log(real_probs + 1e-10) + 0.25 * torch.log(1 - wrong_probs + 1e-10) + 0.25 * torch.log(1 - fake_probs + 1e-10)).mean()
        loss.backward()
        self.discriminator_optimizer.step()
        return loss.item()

    def _get_reward(self, data, seqs):
        """Reward = discriminator probability + per-caption CIDEr score.

        Also returns the mean discriminator prob and mean CIDEr score for logging.
        NOTE(review): assumes cider.get_scores returns a numpy array (mean() is
        called on it) — confirm against the Cider implementation.
        """
        probs = self.discriminator(data['fc_feats'], data['att_feats'], data['att_masks'], seqs)
        scores = self.cider.get_scores(seqs.cpu().numpy(), data['images'].cpu().numpy())
        reward = probs + torch.tensor(scores, dtype=torch.float, device=self.device)
        fake_prob = probs.mean().item()
        score = scores.mean()
        return reward, fake_prob, score
Ejemplo n.º 12
0
# Load the pre-trained target classifier (a ResNet-32 wrapped in DataParallel)
# that the adversarial examples are evaluated against; the checkpoint stores
# the weights under the 'state_dict' key.
checkpoint = torch.load("./save_temp/checkpoint.th")
targeted_model = torch.nn.DataParallel(resnet_model.__dict__['resnet32']())
targeted_model.cuda()
targeted_model.load_state_dict(checkpoint['state_dict'])
targeted_model.eval()

# load the generator of adversarial gan (all three nets below are inference-only:
# weights loaded from epoch-20 checkpoints and switched to eval mode)
pretrained_generator_path = './models/lp_pretrained/netG_rl_epoch_20.pth'
pretrained_G = Generator().to(device)
pretrained_G.load_state_dict(torch.load(pretrained_generator_path))
pretrained_G.eval()

# load the discriminator of adversarial gan
pretrained_disciminator_path = './models/lp_pretrained/netDisc_rl_epoch_20.pth'
pretrained_Disc = Discriminator().to(device)
pretrained_Disc.load_state_dict(torch.load(pretrained_disciminator_path))
pretrained_Disc.eval()

# load the Pixel Valuation network
pretrained_pvrl_path = './models/lp_pretrained/netPv_rl_epoch_20.pth'
pretrained_PV = PVRL().to(device)
pretrained_PV.load_state_dict(torch.load(pretrained_pvrl_path))
pretrained_PV.eval()

# test adversarial examples in CIFAR10 training dataset
# (download=True fetches the dataset into ./data on first run)
cifar_dataset = torchvision.datasets.CIFAR10('./data',
                                             train=True,
                                             transform=transforms.ToTensor(),
                                             download=True)
train_dataloader = DataLoader(cifar_dataset,
                              batch_size=batch_size,
Ejemplo n.º 13
0
class GAIL(PPO):
    """Generative Adversarial Imitation Learning on top of PPO.

    Adds a discriminator over (state, action) pairs; the discriminator's
    log-probability of "expert" becomes the reward signal fed into the PPO
    policy update (see update()).
    """

    def __init__(
        self,
        state_dimension: Tuple,
        action_space: int,
        save_path: Path,
        hyp: HyperparametersGAIL,
        policy_params: namedtuple,
        discriminator_params: DiscrimParams,
        param_plot_num: int,
        ppo_type: str = "clip",
        adv_type: str = "monte_carlo",
        max_plot_size: int = 10000,
        policy_burn_in: int = 0,
        verbose: bool = False,
    ):
        """Build the discriminator and its optimizer, then delegate policy setup to PPO."""
        # Where the discriminator weights are checkpointed (the policy itself is
        # saved by the PPO base class under neural_net_save).
        self.discrim_net_save = save_path / "GAIL_discrim.pth"

        self.discriminator = Discriminator(
            state_dimension,
            action_space,
            discriminator_params,
        ).to(device)

        self.discrim_optim = torch.optim.Adam(self.discriminator.parameters(),
                                              lr=hyp.discrim_lr)
        # Extra plot series recorded on top of the PPO ones.
        gail_plots = [("discrim_loss", np.float64)]

        super(GAIL, self).__init__(
            state_dimension,
            action_space,
            save_path,
            hyp,
            policy_params,
            param_plot_num,
            ppo_type,
            advantage_type=adv_type,
            neural_net_save=f"GAIL-{adv_type}",
            max_plot_size=max_plot_size,
            discrim_params=discriminator_params,
            policy_burn_in=policy_burn_in,
            verbose=verbose,
            additional_plots=gail_plots,
        )

        # NLL over the discriminator's log-probabilities (learner vs expert classes).
        self.discrim_loss = torch.nn.NLLLoss()

    def update(self, buffer: GAILExperienceBuffer, ep_num: int):
        """Train the discriminator on learner vs expert samples, then use its
        expert log-probs as rewards and run the PPO policy update."""
        # Update discriminator
        state_actions = buffer.state_actions.to(device)
        # Buffer layout assumption: the first num_learner_samples rows are learner
        # samples; the remainder are expert samples, consumed in equal slices
        # across the discriminator epochs. TODO(review): confirm against the
        # GAILExperienceBuffer implementation.
        num_learner_samples = buffer.get_length()
        expert_samples_per_epoch = int(
            (state_actions.size()[0] - num_learner_samples) /
            self.hyp.num_discrim_epochs)
        for epoch in range(self.hyp.num_discrim_epochs):
            # All learner samples plus this epoch's slice of expert samples.
            step_state_actions = torch.cat(
                (
                    state_actions[:num_learner_samples],
                    state_actions[
                        num_learner_samples +
                        epoch * expert_samples_per_epoch:num_learner_samples +
                        (epoch + 1) * expert_samples_per_epoch],
                ),
                dim=0,
            )
            discrim_logprobs = self.discriminator.logprobs(
                step_state_actions).to(device)
            loss = self.discrim_loss(
                input=discrim_logprobs,
                target=buffer.discrim_labels.type(torch.long),
            )
            plotted_loss = loss.detach().cpu().numpy()
            self.plotter.record_data({"discrim_loss": plotted_loss})

            if self.verbose:
                # Mean predicted expert-probability (column 1) for each group.
                print(
                    f"Learner labels {buffer.discrim_labels[:num_learner_samples].mean()}: "
                    f"\t{torch.exp(discrim_logprobs[:num_learner_samples]).t()[1].mean()}"
                )
                print(
                    f"Expert labels {buffer.discrim_labels[num_learner_samples:].mean()}: "
                    f"\t\t{torch.exp(discrim_logprobs[num_learner_samples:]).t()[1].mean()}"
                )

            self.discrim_optim.zero_grad()
            loss.backward()
            self.discrim_optim.step()
            self.record_nn_params()

        # Update policy: the learner samples' expert log-probs become PPO rewards.
        buffer.rewards = list(
            np.squeeze(
                self.discriminator.logprob_expert(
                    state_actions[:num_learner_samples]).float().detach().cpu(
                    ).numpy()))
        if self.verbose:
            print(
                "----------------------------------------------------------------------"
            )
        super(GAIL, self).update(buffer, ep_num)

    def record_nn_params(self):
        """Gets randomly sampled actor NN parameters from 1st layer."""
        names, x_params, y_params = self.plotter.get_param_plot_nums()
        sampled_params = {}

        # Plot names prefixed "discrim" sample the discriminator; all others
        # sample the policy network.
        for name, x_param, y_param in zip(names, x_params, y_params):
            network_to_sample = (self.discriminator
                                 if name[:7] == "discrim" else self.policy)
            sampled_params[name] = (
                network_to_sample.state_dict()[name].cpu().numpy()[x_param,
                                                                   y_param])
        self.plotter.record_data(sampled_params)

    def _save_network(self):
        """Save the PPO networks, then the discriminator weights alongside them."""
        super(GAIL, self)._save_network()
        torch.save(self.discriminator.state_dict(), f"{self.discrim_net_save}")

    def _load_network(self):
        """Restore the PPO networks, then the discriminator weights."""
        super(GAIL, self)._load_network()
        print(
            f"Loading discriminator network saved at: {self.discrim_net_save}")
        net = torch.load(self.discrim_net_save, map_location=device)
        self.discriminator.load_state_dict(net)
Ejemplo n.º 14
0
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)
    print("Discriminator loaded successfully!")

    g_model_path = 'checkpoints/zhenwarm/generator.pt'
    assert os.path.exists(g_model_path)
    # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    model_dict = generator.state_dict()
    model = torch.load(g_model_path)
    pretrained_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    print("pre-trained Generator loaded successfully!")
    #
    # Load discriminator model
    d_model_path = 'checkpoints/zhenwarm/discri.pt'
    assert os.path.exists(d_model_path)
    # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    d_model_dict = discriminator.state_dict()
    d_model = torch.load(d_model_path)
    d_pretrained_dict = d_model.state_dict()
    # 1. filter out unnecessary keys
    d_pretrained_dict = {
        k: v
        for k, v in d_pretrained_dict.items() if k in d_model_dict
    }
    # 2. overwrite entries in the existing state dict
    d_model_dict.update(d_pretrained_dict)
    # 3. load the new state dict
    discriminator.load_state_dict(d_model_dict)
    print("pre-trained Discriminator loaded successfully!")

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/myzhencli5'):
        os.makedirs('checkpoints/myzhencli5')
    checkpoints_path = 'checkpoints/myzhencli5/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer
    g_optimizer = eval("torch.optim." + args.g_optimizer)(filter(
        lambda x: x.requires_grad, generator.parameters()),
                                                          args.g_learning_rate)

    d_optimizer = eval("torch.optim." + args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(trainloader):

            # set training mode
            generator.train()
            discriminator.train()
            update_learning_rate(num_update, 8e4, args.g_learning_rate,
                                 args.lr_shrink, g_optimizer)

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator

            # use policy gradient training when random.random() > 50%
            if random.random() >= 0.5:

                print("Policy Gradient Training")

                sys_out_batch = generator(sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 * 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64*50 = 3200
                prediction = torch.reshape(
                    prediction,
                    sample['net_input']['src_tokens'].shape)  # 64 X 50

                with torch.no_grad():
                    reward = discriminator(sample['net_input']['src_tokens'],
                                           prediction)  # 64 X 1

                train_trg_batch = sample['target']  # 64 x 50

                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']  # 64
                logging_loss = pg_loss / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss.item(),
                                                      sample_size)
                logging.debug(
                    f"G policy gradient loss at batch {i}: {pg_loss.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            else:
                # MLE training
                print("MLE Training")

                sys_out_batch = generator(sample)

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                train_trg_batch = sample['target'].view(-1)  # 64*50 = 3200

                loss = g_criterion(out_batch, train_trg_batch)

                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            num_update += 1

            # part II: train the discriminator
            if num_update % 5 == 0:
                bsz = sample['target'].size(0)  # batch_size = 64

                src_sentence = sample['net_input'][
                    'src_tokens']  # 64 x max-len i.e 64 X 50

                # now train with machine translation output i.e generator output
                true_sentence = sample['target'].view(-1)  # 64*50 = 3200

                true_labels = Variable(
                    torch.ones(
                        sample['target'].size(0)).float())  # 64 length vector

                with torch.no_grad():
                    sys_out_batch = generator(sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64 * 50 = 6632

                fake_labels = Variable(
                    torch.zeros(
                        sample['target'].size(0)).float())  # 64 length vector

                fake_sentence = torch.reshape(prediction,
                                              src_sentence.shape)  # 64 X 50
                true_sentence = torch.reshape(true_sentence,
                                              src_sentence.shape)
                if use_cuda:
                    fake_labels = fake_labels.cuda()
                    true_labels = true_labels.cuda()

                # fake_disc_out = discriminator(src_sentence, fake_sentence)  # 64 X 1
                # true_disc_out = discriminator(src_sentence, true_sentence)
                #
                # fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels)
                # true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels)
                #
                # fake_acc = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
                # true_acc = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels)
                # acc = (fake_acc + true_acc) / 2
                #
                # d_loss = fake_d_loss + true_d_loss
                if random.random() > 0.5:
                    fake_disc_out = discriminator(src_sentence, fake_sentence)
                    fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                              fake_labels)
                    fake_acc = torch.sum(
                        torch.round(fake_disc_out).squeeze(1) ==
                        fake_labels).float() / len(fake_labels)
                    d_loss = fake_d_loss
                    acc = fake_acc
                else:
                    true_disc_out = discriminator(src_sentence, true_sentence)
                    true_d_loss = d_criterion(true_disc_out.squeeze(1),
                                              true_labels)
                    true_acc = torch.sum(
                        torch.round(true_disc_out).squeeze(1) ==
                        true_labels).float() / len(true_labels)
                    d_loss = true_d_loss
                    acc = true_acc

                d_logging_meters['train_acc'].update(acc)
                d_logging_meters['train_loss'].update(d_loss)
                logging.debug(
                    f"D training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}"
                )
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()

            if num_update % 10000 == 0:

                # validation
                # set validation mode
                generator.eval()
                discriminator.eval()
                # Initialize dataloader
                max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
                valloader = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.joint_batch_size,
                    max_positions=max_positions_valid,
                    skip_invalid_size_inputs_valid_test=True,
                    descending=
                    True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )

                # reset meters
                for key, val in g_logging_meters.items():
                    if val is not None:
                        val.reset()
                for key, val in d_logging_meters.items():
                    if val is not None:
                        val.reset()

                for i, sample in enumerate(valloader):

                    with torch.no_grad():
                        if use_cuda:
                            # wrap input tensors in cuda tensors
                            sample = utils.make_variable(sample, cuda=cuda)

                        # generator validation
                        sys_out_batch = generator(sample)
                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))  # (64 X 50) X 6632
                        dev_trg_batch = sample['target'].view(
                            -1)  # 64*50 = 3200

                        loss = g_criterion(out_batch, dev_trg_batch)
                        sample_size = sample['target'].size(
                            0) if args.sentence_avg else sample['ntokens']
                        loss = loss / sample_size / math.log(2)
                        g_logging_meters['valid_loss'].update(
                            loss, sample_size)
                        logging.debug(
                            f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}"
                        )

                        # discriminator validation
                        bsz = sample['target'].size(0)
                        src_sentence = sample['net_input']['src_tokens']
                        # train with half human-translation and half machine translation

                        true_sentence = sample['target']
                        true_labels = Variable(
                            torch.ones(sample['target'].size(0)).float())

                        with torch.no_grad():
                            sys_out_batch = generator(sample)

                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                        _, prediction = out_batch.topk(1)
                        prediction = prediction.squeeze(1)  # 64 * 50 = 6632

                        fake_labels = Variable(
                            torch.zeros(sample['target'].size(0)).float())

                        fake_sentence = torch.reshape(
                            prediction, src_sentence.shape)  # 64 X 50
                        true_sentence = torch.reshape(true_sentence,
                                                      src_sentence.shape)
                        if use_cuda:
                            fake_labels = fake_labels.cuda()
                            true_labels = true_labels.cuda()

                        fake_disc_out = discriminator(src_sentence,
                                                      fake_sentence)  # 64 X 1
                        true_disc_out = discriminator(src_sentence,
                                                      true_sentence)

                        fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                                  fake_labels)
                        true_d_loss = d_criterion(true_disc_out.squeeze(1),
                                                  true_labels)
                        d_loss = fake_d_loss + true_d_loss
                        fake_acc = torch.sum(
                            torch.round(fake_disc_out).squeeze(1) ==
                            fake_labels).float() / len(fake_labels)
                        true_acc = torch.sum(
                            torch.round(true_disc_out).squeeze(1) ==
                            true_labels).float() / len(true_labels)
                        acc = (fake_acc + true_acc) / 2
                        d_logging_meters['valid_acc'].update(acc)
                        d_logging_meters['valid_loss'].update(d_loss)
                        logging.debug(
                            f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}"
                        )

                # torch.save(discriminator,
                #            open(checkpoints_path + f"numupdate_{num_update/10000}k.discri_{d_logging_meters['valid_loss'].avg:.3f}.pt",'wb'), pickle_module=dill)

                # if d_logging_meters['valid_loss'].avg < best_dev_loss:
                #     best_dev_loss = d_logging_meters['valid_loss'].avg
                #     torch.save(discriminator, open(checkpoints_path + "best_dmodel.pt", 'wb'), pickle_module=dill)

                torch.save(
                    generator,
                    open(
                        checkpoints_path +
                        f"numupdate_{num_update/10000}k.joint_{g_logging_meters['valid_loss'].avg:.3f}.pt",
                        'wb'),
                    pickle_module=dill)
        discriminator, _, _ = train_dis(optims_nll, discriminator, bsize,
                                        embed_dim, recom_length - 1,
                                        trainSample, validSample, testSample)
        print("Testing")
        print("Discriminator evaluation!")
        eval_acc_dis, eval_map_dis = evaluate_discriminator(
            discriminator, 101, bsize, recom_length - 1, validSample,
            testSample, device, 'test')

        #Adversarial training
        weight = torch.FloatTensor(2).fill_(1)
        print('\n--------------------------------------------')
        print("Adversarial Training")
        print('--------------------------------------------')
        generator.load_state_dict(torch.load(pretrained_gen))
        discriminator.load_state_dict(torch.load(pretrained_dis))
        agent.load_state_dict(torch.load(pretrained_agent))

        trainSample, validSample, testSample = sampleSplit(
            trainindex, validindex, testindex, Seqlist, numlabel,
            recom_length - 1, 'adv')  #No eos

        _ = evaluate_agent(agent, 101, bsize, recom_length - 1, validSample,
                           testSample, device, 'test')
        _ = pgtrain(optims_adv,
                    optims_nll,
                    generator,
                    agent,
                    discriminator,
                    bsize,
                    embed_dim,
Ejemplo n.º 16
0
class GAN_CLS(object):
    """Trainer for a text-conditioned GAN (GAN-CLS style).

    Wires a text-conditioned Generator and Discriminator together with
    their optimizers, and runs the adversarial training loop with file
    logging and periodic checkpointing.
    """

    def __init__(self, args, data_loader, SUPERVISED=True):
        """
        args : Arguments namespace with all hyper-parameters and paths.
        data_loader : An instance of class DataLoader for loading our
            dataset in batches (each batch is a dict with keys
            'true_imgs', 'true_embds', 'false_imgs').
        SUPERVISED : flag stored on the instance (not used in this view).
        """
        self.data_loader = data_loader
        self.num_epochs = args.num_epochs
        self.batch_size = args.batch_size

        self.log_step = args.log_step
        self.sample_step = args.sample_step

        self.log_dir = args.log_dir
        self.checkpoint_dir = args.checkpoint_dir
        self.sample_dir = args.sample_dir
        self.final_model = args.final_model
        self.model_save_step = args.model_save_step

        #self.dataset = args.dataset
        #self.model_name = args.model_name

        self.img_size = args.img_size
        self.z_dim = args.z_dim
        self.text_embed_dim = args.text_embed_dim
        self.text_reduced_dim = args.text_reduced_dim
        self.learning_rate = args.learning_rate
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.l1_coeff = args.l1_coeff
        self.resume_epoch = args.resume_epoch
        self.resume_idx = args.resume_idx
        self.SUPERVISED = SUPERVISED

        # Logger setting: one dated log file per day under log_dir.
        log_name = datetime.datetime.now().strftime('%Y-%m-%d') + '.log'
        self.logger = logging.getLogger('__name__')
        self.logger.setLevel(logging.INFO)
        self.formatter = logging.Formatter(
            '%(asctime)s:%(levelname)s:%(message)s')
        self.file_handler = logging.FileHandler(
            os.path.join(self.log_dir, log_name))
        self.file_handler.setFormatter(self.formatter)
        self.logger.addHandler(self.file_handler)

        self.build_model()

    def smooth_label(self, tensor, offset):
        """Return *tensor* shifted by *offset* (one-sided label smoothing)."""
        return tensor + offset

    def dump_imgs(self, images_Array, name):
        """Pickle *images_Array* to '<name>.pickle'.

        FIX: the original signature was missing ``self``, so calling it as
        ``self.dump_imgs(arr, path)`` would have bound the instance to
        *images_Array* and crashed.
        """
        with open('{}.pickle'.format(name), 'wb') as file:
            dump(images_Array, file)

    def build_model(self):
        """Instantiate the networks, optimizers and the loss function.

        Defines:
            - Generator and Discriminator
            - Adam optimizers for each network (plus a joint optimizer,
              kept for interface compatibility)
            - BCE loss criterion (on GPU)
        """

        # ---------------------------------------------------------------------#
        #                       1. Network Initialization                      #
        # ---------------------------------------------------------------------#
        self.gen = Generator(batch_size=self.batch_size,
                             img_size=self.img_size,
                             z_dim=self.z_dim,
                             text_embed_dim=self.text_embed_dim,
                             text_reduced_dim=self.text_reduced_dim)

        self.disc = Discriminator(batch_size=self.batch_size,
                                  img_size=self.img_size,
                                  text_embed_dim=self.text_embed_dim,
                                  text_reduced_dim=self.text_reduced_dim)

        self.gen_optim = optim.Adam(self.gen.parameters(),
                                    lr=self.learning_rate,
                                    betas=(self.beta1, self.beta2))

        self.disc_optim = optim.Adam(self.disc.parameters(),
                                     lr=self.learning_rate,
                                     betas=(self.beta1, self.beta2))

        # Joint optimizer over both networks.  Kept for backward
        # compatibility, but no longer stepped during training (see
        # train_model) because stepping it after the individual updates
        # re-applied stale gradients to both networks.
        self.cls_gan_optim = optim.Adam(itertools.chain(
            self.gen.parameters(), self.disc.parameters()),
                                        lr=self.learning_rate,
                                        betas=(self.beta1, self.beta2))

        print('-------------  Generator Model Info  ---------------')
        self.print_network(self.gen, 'G')
        print('------------------------------------------------')

        print('-------------  Discriminator Model Info  ---------------')
        self.print_network(self.disc, 'D')
        print('------------------------------------------------')

        self.criterion = nn.BCELoss().cuda()
        # self.CE_loss = nn.CrossEntropyLoss().cuda()
        # self.MSE_loss = nn.MSELoss().cuda()
        self.gen.train()
        self.disc.train()

    def print_network(self, model, name):
        """Print the model, its name and its total parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()

        print(model)
        print(name)
        print("Total number of parameters: {}".format(num_params))

    def load_checkpoints(self, resume_epoch, idx):
        """Restore the trained generator and discriminator."""
        print('Loading the trained models from epoch {} and iteration {}...'.
              format(resume_epoch, idx))
        G_path = os.path.join(self.checkpoint_dir,
                              '{}-{}-G.ckpt'.format(resume_epoch, idx))
        D_path = os.path.join(self.checkpoint_dir,
                              '{}-{}-D.ckpt'.format(resume_epoch, idx))
        # map_location keeps CPU-only restores working.
        self.gen.load_state_dict(
            torch.load(G_path, map_location=lambda storage, loc: storage))
        self.disc.load_state_dict(
            torch.load(D_path, map_location=lambda storage, loc: storage))

    def train_model(self):
        """Run the adversarial training loop.

        Per batch: one generator update followed by one discriminator
        update.  Logs every ``log_step`` batches, checkpoints every
        ``model_save_step`` batches, and saves the final models at the end.
        """
        data_loader = self.data_loader

        start_epoch = 0
        if self.resume_epoch >= 0:
            start_epoch = self.resume_epoch
            self.load_checkpoints(self.resume_epoch, self.resume_idx)

        print('---------------  Model Training Started  ---------------')
        start_time = time.time()

        for epoch in range(start_epoch, self.num_epochs):
            print("Epoch: {}".format(epoch + 1))
            for idx, batch in enumerate(data_loader):
                print("Index: {}".format(idx + 1), end="\t")
                true_imgs = batch['true_imgs']
                true_embed = batch['true_embds']
                false_imgs = batch['false_imgs']

                real_labels = torch.ones(true_imgs.size(0))
                fake_labels = torch.zeros(true_imgs.size(0))

                # One-sided label smoothing: real targets become 0.9.
                smooth_real_labels = torch.FloatTensor(
                    self.smooth_label(real_labels.numpy(), -0.1))

                true_imgs = Variable(true_imgs.float()).cuda()
                true_embed = Variable(true_embed.float()).cuda()
                false_imgs = Variable(false_imgs.float()).cuda()

                real_labels = Variable(real_labels).cuda()
                smooth_real_labels = Variable(smooth_real_labels).cuda()
                fake_labels = Variable(fake_labels).cuda()

                # ---------------------------------------------------------------#
                #                     2. Training the generator                  #
                # ---------------------------------------------------------------#
                self.gen.zero_grad()
                z = Variable(torch.randn(true_imgs.size(0), self.z_dim)).cuda()
                fake_imgs = self.gen.forward(true_embed, z)
                fake_out, fake_logit = self.disc.forward(fake_imgs, true_embed)

                # FIX: the original rewrapped `fake_out.data` in a new leaf
                # Variable, which detached the computation graph so
                # gen_loss.backward() never produced gradients for the
                # generator.  Use the discriminator output directly.
                g_sf = self.criterion(fake_out, real_labels)
                #g_img = self.l1_coeff * nn.L1Loss()(fake_imgs, true_imgs)
                gen_loss = g_sf

                gen_loss.backward()
                self.gen_optim.step()

                # ---------------------------------------------------------------#
                #                   3. Training the discriminator                #
                # ---------------------------------------------------------------#
                self.disc.zero_grad()
                true_out, true_logit = self.disc.forward(true_imgs, true_embed)
                false_out, false_logit = self.disc.forward(
                    false_imgs, true_embed)
                # Detach so the discriminator update does not backprop
                # into the generator.
                fake_out_d, _ = self.disc.forward(fake_imgs.detach(),
                                                  true_embed)

                # GAN-CLS discriminator objective: real (image, text) pairs
                # are classified real; mismatched images and generated
                # images are classified fake.
                # FIX: the original took torch.log() of BCE losses (BCE
                # already contains the log, and log(1 - loss) can be NaN)
                # and compared the real-pair output against fake labels.
                sr = self.criterion(true_out, smooth_real_labels)
                sw = self.criterion(false_out, fake_labels)
                sf = self.criterion(fake_out_d, fake_labels)

                disc_loss = sr + (sw + sf) / 2

                disc_loss.backward()
                self.disc_optim.step()
                # NOTE: the original additionally called
                # self.cls_gan_optim.step() here, re-applying the stale
                # gradients to both networks; removed.

                # Logging
                loss = {}
                loss['G_loss'] = gen_loss.item()
                loss['D_loss'] = disc_loss.item()

                # ---------------------------------------------------------------#
                #                   4. Logging INFO into log_dir                 #
                # ---------------------------------------------------------------#
                # FIX: the original appended to and emitted `log` on every
                # batch (outside the log_step guard), producing partial
                # and duplicated log lines.
                if (idx + 1) % self.log_step == 0:
                    end_time = time.time() - start_time
                    end_time = datetime.timedelta(seconds=end_time)
                    log = "Elapsed [{}], Epoch [{}/{}], Idx [{}]".format(
                        end_time, epoch + 1, self.num_epochs, idx)
                    for net, loss_value in loss.items():
                        log += "{}: {:.4f}".format(net, loss_value)
                    self.logger.info(log)
                    print(log)

                # ---------------------------------------------------------------#
                #               6. Saving the checkpoints & final model          #
                # ---------------------------------------------------------------#
                if (idx + 1) % self.model_save_step == 0:
                    G_path = os.path.join(
                        self.checkpoint_dir,
                        '{}-{}-G.ckpt'.format(epoch, idx + 1))
                    D_path = os.path.join(
                        self.checkpoint_dir,
                        '{}-{}-D.ckpt'.format(epoch, idx + 1))
                    torch.save(self.gen.state_dict(), G_path)
                    torch.save(self.disc.state_dict(), D_path)
                    print('Saved model checkpoints into {}...\n'.format(
                        self.checkpoint_dir))

        print('---------------  Model Training Completed  ---------------')
        # Saving final model into final_model directory
        G_path = os.path.join(self.final_model, '{}-G.pth'.format('final'))
        D_path = os.path.join(self.final_model, '{}-D.pth'.format('final'))
        torch.save(self.gen.state_dict(), G_path)
        torch.save(self.disc.state_dict(), D_path)
        print('Saved final model into {}...'.format(self.final_model))
Ejemplo n.º 17
0
def main(args):
    """Train a temporal volume interpolation GAN.

    Builds the datasets/loaders, the generator (and, unless
    ``args.gan_loss == "none"``, the discriminator), optionally resumes
    from a checkpoint, then runs the train / periodic-test /
    periodic-checkpoint loop.

    Parameters
    ----------
    args : argparse.Namespace
        All hyper-parameters, paths and switches (see references below).
    """
    # log hyperparameter
    print(args)

    # select device
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    # FIX: "cuda: 0" (with a space) is not a valid torch device string.
    device = torch.device("cuda:0" if args.cuda else "cpu")

    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # data loader
    transform = transforms.Compose([
        utils.Normalize(),
        utils.ToTensor()
    ])
    train_dataset = TVDataset(
        root=args.root,
        sub_size=args.block_size,
        volume_list=args.volume_train_list,
        max_k=args.training_step,
        train=True,
        transform=transform
    )
    test_dataset = TVDataset(
        root=args.root,
        sub_size=args.block_size,
        volume_list=args.volume_test_list,
        max_k=args.training_step,
        train=False,
        transform=transform
    )

    kwargs = {"num_workers": 4, "pin_memory": True} if args.cuda else {}
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, **kwargs)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             shuffle=False, **kwargs)

    # model weight initializers: He init matched to each net's activation
    def generator_weights_init(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def discriminator_weights_init(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    g_model = Generator(args.upsample_mode, args.forward, args.backward, args.gen_sn, args.residual)
    g_model.apply(generator_weights_init)
    if args.data_parallel and torch.cuda.device_count() > 1:
        g_model = nn.DataParallel(g_model)
    g_model.to(device)

    if args.gan_loss != "none":
        d_model = Discriminator(args.dis_sn)
        d_model.apply(discriminator_weights_init)
        if args.data_parallel and torch.cuda.device_count() > 1:
            d_model = nn.DataParallel(d_model)
        d_model.to(device)

    mse_loss = nn.MSELoss()
    adversarial_loss = nn.MSELoss()  # LSGAN-style least-squares objective
    train_losses, test_losses = [], []
    d_losses, g_losses = [], []

    # optimizer
    g_optimizer = optim.Adam(g_model.parameters(), lr=args.lr,
                             betas=(args.beta1, args.beta2))
    if args.gan_loss != "none":
        d_optimizer = optim.Adam(d_model.parameters(), lr=args.d_lr,
                                 betas=(args.beta1, args.beta2))

    Tensor = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor

    # load checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint {}".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            g_model.load_state_dict(checkpoint["g_model_state_dict"])
            # g_optimizer.load_state_dict(checkpoint["g_optimizer_state_dict"])
            if args.gan_loss != "none":
                d_model.load_state_dict(checkpoint["d_model_state_dict"])
                # d_optimizer.load_state_dict(checkpoint["d_optimizer_state_dict"])
                d_losses = checkpoint["d_losses"]
                g_losses = checkpoint["g_losses"]
            train_losses = checkpoint["train_losses"]
            test_losses = checkpoint["test_losses"]
            print("=> load chekcpoint {} (epoch {})"
                  .format(args.resume, checkpoint["epoch"]))

    # main loop
    for epoch in tqdm(range(args.start_epoch, args.epochs)):
        # training..
        g_model.train()
        if args.gan_loss != "none":
            d_model.train()
        train_loss = 0.
        volume_loss_part = np.zeros(args.training_step)
        for i, sample in enumerate(train_loader):
            # adversarial ground truths (one label per intermediate volume)
            real_label = Variable(Tensor(sample["v_i"].shape[0], sample["v_i"].shape[1], 1, 1, 1, 1).fill_(1.0), requires_grad=False)
            fake_label = Variable(Tensor(sample["v_i"].shape[0], sample["v_i"].shape[1], 1, 1, 1, 1).fill_(0.0), requires_grad=False)

            v_f = sample["v_f"].to(device)
            v_b = sample["v_b"].to(device)
            v_i = sample["v_i"].to(device)
            g_optimizer.zero_grad()
            fake_volumes = g_model(v_f, v_b, args.training_step, args.wo_ori_volume, args.norm)

            # adversarial loss
            # update discriminator
            if args.gan_loss != "none":
                avg_d_loss = 0.
                avg_d_loss_real = 0.
                avg_d_loss_fake = 0.
                for k in range(args.n_d):
                    d_optimizer.zero_grad()
                    decisions = d_model(v_i)
                    d_loss_real = adversarial_loss(decisions, real_label)
                    fake_decisions = d_model(fake_volumes.detach())

                    d_loss_fake = adversarial_loss(fake_decisions, fake_label)
                    d_loss = d_loss_real + d_loss_fake
                    d_loss.backward()
                    # FIX: accumulate python floats via .item() — the
                    # original kept live tensors in avg_d_loss_real/fake,
                    # pinning graph memory and storing tensors in the
                    # checkpointed d_losses list.
                    avg_d_loss += d_loss.item() / args.n_d
                    avg_d_loss_real += d_loss_real.item() / args.n_d
                    avg_d_loss_fake += d_loss_fake.item() / args.n_d

                    d_optimizer.step()

            # update generator
            if args.gan_loss != "none":
                avg_g_loss = 0.
            avg_loss = 0.
            for k in range(args.n_g):
                loss = 0.
                g_optimizer.zero_grad()

                # adversarial loss
                if args.gan_loss != "none":
                    fake_decisions = d_model(fake_volumes)
                    g_loss = args.gan_loss_weight * adversarial_loss(fake_decisions, real_label)
                    loss += g_loss
                    avg_g_loss += g_loss.item() / args.n_g

                # volume loss
                if args.volume_loss:
                    volume_loss = args.volume_loss_weight * mse_loss(v_i, fake_volumes)
                    # per-intermediate-step diagnostics only — no grad needed
                    with torch.no_grad():
                        for j in range(v_i.shape[1]):
                            volume_loss_part[j] += mse_loss(v_i[:, j, :], fake_volumes[:, j, :]).item() / args.n_g / args.log_every
                    loss += volume_loss

                # feature loss (NOTE: requires d_model, i.e. gan_loss != "none")
                if args.feature_loss:
                    feat_real = d_model.extract_features(v_i)
                    feat_fake = d_model.extract_features(fake_volumes)
                    for m in range(len(feat_real)):
                        loss += args.feature_loss_weight / len(feat_real) * mse_loss(feat_real[m], feat_fake[m])

                # FIX: accumulate a float, not the loss tensor (the tensor
                # kept the whole autograd graph alive across iterations).
                avg_loss += loss.item() / args.n_g
                loss.backward()
                g_optimizer.step()

            train_loss += avg_loss

            # log training status
            subEpoch = (i + 1) // args.log_every
            if (i+1) % args.log_every == 0:
                print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch, (i+1) * args.batch_size, len(train_loader.dataset), 100. * (i+1) / len(train_loader),
                    avg_loss
                ))
                print("Volume Loss: ")
                for j in range(volume_loss_part.shape[0]):
                    print("\tintermediate {}: {:.6f}".format(
                        j+1, volume_loss_part[j]
                    ))

                if args.gan_loss != "none":
                    print("DLossReal: {:.6f} DLossFake: {:.6f} DLoss: {:.6f}, GLoss: {:.6f}".format(
                        avg_d_loss_real, avg_d_loss_fake, avg_d_loss, avg_g_loss
                    ))
                    d_losses.append(avg_d_loss)
                    g_losses.append(avg_g_loss)
                train_losses.append(train_loss / args.log_every)
                print("====> SubEpoch: {} Average loss: {:.6f} Time {}".format(
                    subEpoch, train_loss / args.log_every, time.asctime(time.localtime(time.time()))
                ))
                train_loss = 0.
                volume_loss_part = np.zeros(args.training_step)

            # testing...
            if (i + 1) % args.test_every == 0:
                g_model.eval()
                if args.gan_loss != "none":
                    d_model.eval()
                test_loss = 0.
                with torch.no_grad():
                    # FIX: the original reused `i` and `sample` here,
                    # clobbering the outer training-loop counters, so the
                    # log/check conditions misfired after a test pass.
                    for test_sample in test_loader:
                        t_f = test_sample["v_f"].to(device)
                        t_b = test_sample["v_b"].to(device)
                        t_i = test_sample["v_i"].to(device)
                        fake_test = g_model(t_f, t_b, args.training_step, args.wo_ori_volume, args.norm)
                        test_loss += args.volume_loss_weight * mse_loss(t_i, fake_test).item()

                test_losses.append(test_loss * args.batch_size / len(test_loader.dataset))
                print("====> SubEpoch: {} Test set loss {:4f} Time {}".format(
                    subEpoch, test_losses[-1], time.asctime(time.localtime(time.time()))
                ))
                # FIX: return to training mode — the original left the
                # models in eval mode for the rest of the epoch.
                g_model.train()
                if args.gan_loss != "none":
                    d_model.train()

            # saving...
            if (i+1) % args.check_every == 0:
                print("=> saving checkpoint at epoch {}".format(epoch))
                if args.gan_loss != "none":
                    torch.save({"epoch": epoch + 1,
                                "g_model_state_dict": g_model.state_dict(),
                                "g_optimizer_state_dict":  g_optimizer.state_dict(),
                                "d_model_state_dict": d_model.state_dict(),
                                "d_optimizer_state_dict": d_optimizer.state_dict(),
                                "d_losses": d_losses,
                                "g_losses": g_losses,
                                "train_losses": train_losses,
                                "test_losses": test_losses},
                               os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + "_" + "pth.tar")
                               )
                else:
                    torch.save({"epoch": epoch + 1,
                                "g_model_state_dict": g_model.state_dict(),
                                "g_optimizer_state_dict": g_optimizer.state_dict(),
                                "train_losses": train_losses,
                                "test_losses": test_losses},
                               os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + "_" + "pth.tar")
                               )
                torch.save(g_model.state_dict(),
                           os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + ".pth"))

        num_subEpoch = len(train_loader) // args.log_every
        print("====> Epoch: {} Average loss: {:.6f} Time {}".format(
            epoch, np.array(train_losses[-num_subEpoch:]).mean(), time.asctime(time.localtime(time.time()))
        ))
Ejemplo n.º 18
0
# PRETRAIN DISCRIMINATOR
# random choose 1000 batch to train discriminator

print('\nStarting Discriminator Training...')
# Adagrad for the discriminator, Adam for the generator.
dis_optimizer = optim.Adagrad(dis.parameters())

optimizer = optim.Adam(gen.parameters(), lr=1e-4)

#optimizer = optim.Adagrad(dis.parameters())

#train_discriminator(dis, dis_optimizer, train_gen_batch, train_tar_batch, BATCH_SIZE, 3)
#torch.save(dis.state_dict(), './model/pretrain_discriminator/model.pt')

# Discriminator pretraining is commented out above; instead, load the
# previously saved pretrained weights from disk.
pretrained_dis_path = './model/pretrain_discriminator/model.pt'
dis.load_state_dict(torch.load(pretrained_dis_path))
dis.to(device)

#h = dis.init_hidden(10)
#dis(train_src_batch[0].transpose(0, 1).to(device), h)

# ADVERSARIAL TRAINING
print('\nStarting Adversarial Training...')

# NOTE(review): ADV_TRAIN_EPOCHS is 0, so the adversarial loop below never
# executes — confirm whether this is intentional (e.g. deliberately disabled).
ADV_TRAIN_EPOCHS = 0
for epoch in range(ADV_TRAIN_EPOCHS):
    print('\n--------\nEPOCH %d\n--------' % (epoch + 1))
    # TRAIN GENERATOR
    print('\nAdversarial Training Generator : ', end='')
    sys.stdout.flush()
    gen.train()
Ejemplo n.º 19
0
import sys
import os
from unet import UNet
from discriminator import Discriminator
from data_loader import Dataset
from predictor import Predictor

# Inference entry point: read a YAML config (path given as the first CLI
# argument), restore the trained UNet + Discriminator pair for the epoch
# named in the config, and run prediction over the test dataloader,
# writing outputs under <logdir>/predicted.
yml_path = sys.argv[1]
with open(yml_path) as f:
    # safe_load avoids arbitrary Python object construction from the YAML
    # file (yaml.load without an explicit Loader is deprecated and unsafe).
    config = yaml.safe_load(f)

if config['use_gpu']:
    # Make newly created tensors default to CUDA.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

discriminator = Discriminator(**(config['discriminator_params']))
unet = UNet(**(config['unet_params']))

dl = Dataset(**(config['dataset_params'])) \
    .flow_from_directory(**(config['test_dataloader_params']))

# Checkpoints are stored per-epoch under the training log directory.
unet_path = os.path.join(config['fit_params']['logdir'],
                         'unet_%d.pth' % config['test_epoch'])
unet.load_state_dict(torch.load(unet_path))

discriminator_path = os.path.join(
    config['fit_params']['logdir'],
    'discriminator_%d.pth' % config['test_epoch'])
discriminator.load_state_dict(torch.load(discriminator_path))

p = Predictor(unet, discriminator)
p(dl, os.path.join(config['fit_params']['logdir'], 'predicted'))
Ejemplo n.º 20
0
from discriminator import Discriminator
import torchvision.datasets as dset
from torchvision import transforms
import torch.utils.data


if __name__ == "__main__":
    # Restore a trained discriminator from its checkpoint and push a single
    # ImageNet image through it as a smoke test.
    checkpoint = torch.load(
        "C:\\Users\\ankit\\Workspaces\\CS7150\\FinalProject\\models\\trained_model_Mon_05_45.pth")
    discriminator = Discriminator(ngpu=1, num_channels=3, num_features=64)
    discriminator.load_state_dict(checkpoint['discriminator'])
    discriminator.eval()

    # Same normalization the GAN was trained with: map pixels to [-1, 1].
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = dset.ImageFolder(
        root="C:\\Users\\ankit\\Workspaces\\CS7150\\data\\imagenet",
        transform=preprocessing)

    # Create the dataloader and pull exactly one batch (batch_size=1).
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=2)
    batch = next(iter(dataloader))
    out = discriminator(batch[0])

    print()

Ejemplo n.º 21
0
def main(args):
    """Joint adversarial training loop for NMT.

    Loads (pretraining first if no checkpoint exists) a generator (LSTM
    seq2seq) and a discriminator, then for each epoch alternates:

      * generator updates -- policy-gradient (reward from the discriminator)
        with probability 0.5, otherwise plain MLE;
      * discriminator updates -- binary classification of human vs. machine
        translations (half/half, shuffled, truncated back to batch size).

    After each epoch both models are validated and the generator is
    checkpointed; the best generator by dev loss is kept separately.

    :param args: argparse.Namespace with data, model and optimizer settings.
    """
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    print("======printing args========")
    print(args)
    print("=================================")

    # Load dataset (binary fairseq-style files if present, else raw text)
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        print("Loading bin dataset")
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        print(f"Loading raw text dataset {args.data}")
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst
    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    # Running-average meters for generator / discriminator statistics.
    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # try to load generator model (pretrain it first if no checkpoint exists)
    g_model_path = 'checkpoints/generator/best_gmodel.pt'
    if not os.path.exists(g_model_path):
        print("Start training generator!")
        train_g(args, dataset)
    assert os.path.exists(g_model_path)
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    model_dict = generator.state_dict()
    pretrained_dict = torch.load(g_model_path)
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)

    print("Generator has successfully loaded!")

    # try to load discriminator model (pretrain it first if needed)
    d_model_path = 'checkpoints/discriminator/best_dmodel.pt'
    if not os.path.exists(d_model_path):
        print("Start training discriminator!")
        train_d(args, dataset)
    assert os.path.exists(d_model_path)
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)
    model_dict = discriminator.state_dict()
    pretrained_dict = torch.load(d_model_path)
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    discriminator.load_state_dict(model_dict)

    print("Discriminator has successfully loaded!")

    print("starting main training loop")

    # NOTE(review): anomaly detection slows training noticeably; disable once
    # the autograd graph is known to be sound.
    torch.autograd.set_detect_anomaly(True)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/joint'):
        os.makedirs('checkpoints/joint')
    checkpoints_path = 'checkpoints/joint/'

    # define loss functions.  reduction='sum' replaces the deprecated
    # size_average=False / reduce=True pair with identical behavior.
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCEWithLogitsLoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizers; the class is chosen by name from the CLI args
    g_optimizer = eval("torch.optim." + args.g_optimizer)(filter(
        lambda x: x.requires_grad, generator.parameters()),
                                                          args.g_learning_rate)

    d_optimizer = eval("torch.optim." + args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # set training mode
        generator.train()
        discriminator.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(itr):
            if use_cuda:
                # wrap input tensors in cuda tensors.
                # BUG FIX: previously passed the `torch.cuda` module (always
                # truthy) instead of the boolean flag, forcing the CUDA path
                # even when no GPU was requested.
                sample = utils.make_variable(sample, cuda=use_cuda)

            ## part I: use gradient policy method to train the generator

            # use policy gradient training when rand > 50%
            rand = random.random()
            if rand >= 0.5:
                # policy gradient training
                generator.decoder.is_testing = True
                sys_out_batch, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
                with torch.no_grad():
                    # reward = discriminator's judgement of the sampled output
                    reward = discriminator(
                        sample['net_input']['src_tokens'],
                        prediction)
                train_trg_batch = sample['target']
                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                g_optimizer.zero_grad()
                pg_loss.backward()
                # clip_grad_norm_ is the non-deprecated in-place variant
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

                # oracle valid: re-run MLE forward just for logging
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
            else:
                # MLE training
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()
            num_update += 1

            # part II: train the discriminator
            bsz = sample['target'].size(0)
            src_sentence = sample['net_input']['src_tokens']
            # train with half human-translation and half machine translation

            true_sentence = sample['target']
            true_labels = Variable(
                torch.ones(sample['target'].size(0)).float())

            with torch.no_grad():
                generator.decoder.is_testing = True
                _, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
            fake_sentence = prediction
            fake_labels = Variable(
                torch.zeros(sample['target'].size(0)).float())

            trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
            labels = torch.cat([true_labels, fake_labels], dim=0)

            # shuffle real/fake together, then keep a batch-sized half
            indices = np.random.permutation(2 * bsz)
            trg_sentence = trg_sentence[indices][:bsz]
            labels = labels[indices][:bsz]

            if use_cuda:
                labels = labels.cuda()

            disc_out = discriminator(src_sentence, trg_sentence)
            # BUG FIX: BCEWithLogitsLoss requires float targets; the previous
            # labels.long() raised a dtype error (the validation path below
            # already used the float labels directly).
            d_loss = d_criterion(disc_out, labels)
            # BUG FIX: torch.Sigmoid does not exist (that is torch.nn.Sigmoid);
            # use the functional torch.sigmoid on the logits.
            acc = torch.sum(torch.sigmoid(disc_out).round() ==
                            labels).float() / len(labels)
            d_logging_meters['train_acc'].update(acc)
            d_logging_meters['train_loss'].update(d_loss)
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # validation
        # set validation mode
        generator.eval()
        discriminator.eval()
        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        itr = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(itr):
            with torch.no_grad():
                if use_cuda:
                    sample['id'] = sample['id'].cuda()
                    sample['net_input']['src_tokens'] = sample['net_input'][
                        'src_tokens'].cuda()
                    sample['net_input']['src_lengths'] = sample['net_input'][
                        'src_lengths'].cuda()
                    sample['net_input']['prev_output_tokens'] = sample[
                        'net_input']['prev_output_tokens'].cuda()
                    sample['target'] = sample['target'].cuda()

                # generator validation
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                loss = loss / sample_size / math.log(2)
                g_logging_meters['valid_loss'].update(loss, sample_size)
                logging.debug("G dev loss at batch {0}: {1:.3f}".format(
                    i, g_logging_meters['valid_loss'].avg))

                # discriminator validation: same half/half scheme as training
                bsz = sample['target'].size(0)
                src_sentence = sample['net_input']['src_tokens']

                true_sentence = sample['target']
                true_labels = Variable(
                    torch.ones(sample['target'].size(0)).float())

                with torch.no_grad():
                    generator.decoder.is_testing = True
                    _, prediction, _ = generator(sample)
                    generator.decoder.is_testing = False
                fake_sentence = prediction
                fake_labels = Variable(
                    torch.zeros(sample['target'].size(0)).float())

                trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
                labels = torch.cat([true_labels, fake_labels], dim=0)

                indices = np.random.permutation(2 * bsz)
                trg_sentence = trg_sentence[indices][:bsz]
                labels = labels[indices][:bsz]

                if use_cuda:
                    labels = labels.cuda()

                # NOTE(review): training calls the discriminator with two
                # arguments but validation passes the pad index as a third --
                # confirm against Discriminator.forward's signature.
                disc_out = discriminator(src_sentence, trg_sentence,
                                         dataset.dst_dict.pad())
                d_loss = d_criterion(disc_out, labels)
                acc = torch.sum(torch.sigmoid(disc_out).round() ==
                                labels).float() / len(labels)
                d_logging_meters['valid_acc'].update(acc)
                d_logging_meters['valid_loss'].update(d_loss)

        # checkpoint the generator every epoch, tagged with dev loss ...
        torch.save(generator,
                   open(
                       checkpoints_path + "joint_{0:.3f}.epoch_{1}.pt".format(
                           g_logging_meters['valid_loss'].avg, epoch_i), 'wb'),
                   pickle_module=dill)

        # ... and keep the best-so-far model by dev loss separately
        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            torch.save(generator,
                       open(checkpoints_path + "best_gmodel.pt", 'wb'),
                       pickle_module=dill)
Ejemplo n.º 22
0
# WGAN-GP training setup: feature flags, hyper-parameters, model
# construction and best-effort checkpoint restore.
if_use_wgan_gp = True
if_vis = True

logger = Logger('./logs')
batch_size = 40
lr = 0.0003

discriminator = Discriminator()
generator = Generator()
# +1 / -1 scalars — presumably the gradient targets for the WGAN critic's
# backward passes; confirm against the training loop that uses them.
one = torch.FloatTensor([1])
mone = one * -1
LAMBDA = 1  # gradient-penalty coefficient

# Resume from previous checkpoints if they exist; start fresh otherwise.
try:
    discriminator.load_state_dict(torch.load("./discriminator.pkl"))
    generator.load_state_dict(torch.load("./generator.pkl"))
    print('Load learner previous point: Successed')
except Exception as e:
    print('Load learner previous point: Failed')

if if_vis:
    TMUX = 'TMUX 1'
    port = 8097
    from visdom import Visdom
    viz = Visdom(port=port)
    win = None
    win_dic = {}
    recorder = {
        'plot': {},
        'line': {},
Ejemplo n.º 23
0
  :param tags:
  :return: img's tensor and file path.
  '''
    # g_noise = Variable(torch.FloatTensor(1, 128)).to(device).data.normal_(.0, 1)
    # g_tag = Variable(torch.FloatTensor([utils.get_one_hot(tags)])).to(device)
    g_noise, g_tag = utils.fake_generator(1, 128, device)

    img = G(torch.cat([g_noise, g_tag], dim=1))
    label_p, tag_p = D(img)
    print(label_p)
    print(tag_p)
    vutils.save_image(img.data.view(1, 3, 128, 128),
                      os.path.join(tmp_path, '{}.png'.format(file_name)))
    print('Saved file in {}'.format(
        os.path.join(tmp_path, '{}.png'.format(file_name))))
    return img.data.view(1, 3, 128,
                         128), os.path.join(tmp_path,
                                            '{}.png'.format(file_name))


if __name__ == '__main__':
    # Restore generator/discriminator weights from the model dump, then
    # synthesize one sample image with the tag 'glasses'.
    checkpoint, _ = load_checkpoint(model_dump_path)

    G = Generator().to(device)
    G.load_state_dict(checkpoint['G'])

    D = Discriminator().to(device)
    D.load_state_dict(checkpoint['D'])

    img, _ = generate(G, 'test', ['glasses'], D)
Ejemplo n.º 24
0
class SVM_Classifier:
    """Trains a linear classifier (SVM or SGD) on CIFAR-10 features produced
    by a pretrained GAN discriminator used as a fixed feature extractor.
    """

    def __init__(self, batch_size, image_size=64):
        """Load CIFAR-10 train/test loaders and the pretrained discriminator.

        :param batch_size: batch size for both train and test loaders.
        :param image_size: input image side length passed to the discriminator.
        """
        self.image_size = image_size
        self.device = torch.device("cuda:0" if (
            torch.cuda.is_available()) else "cpu")
        # Timestamped filename for the pickled classifier.
        self.save_filename = f'model_{datetime.datetime.now().strftime("%a_%H_%M")}.sav'

        # Map pixels to [-1, 1], matching the GAN's training normalization.
        transform = transforms.Compose([
            # transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.trainset = torchvision.datasets.CIFAR10(root='./data',
                                                     train=True,
                                                     download=True,
                                                     transform=transform)
        self.trainloader = data.DataLoader(self.trainset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=2)

        self.testset = torchvision.datasets.CIFAR10(root='./data',
                                                    train=False,
                                                    download=True,
                                                    transform=transform)
        self.testloader = data.DataLoader(self.testset,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          num_workers=2)

        # Restore the pretrained discriminator; data_generation_mode=1
        # presumably makes it emit feature vectors rather than real/fake
        # scores — confirm against the Discriminator implementation.
        saved_state = torch.load(
            "C:\\Users\\ankit\\Workspaces\\CS7150\\FinalProject\\models\\imagenet\\trained_model_Tue_17_06.pth"
        )
        self.discriminator = Discriminator(ngpu=1,
                                           num_channels=3,
                                           num_features=64,
                                           data_generation_mode=1,
                                           input_size=image_size)
        self.discriminator.load_state_dict(saved_state['discriminator'])
        self.discriminator.eval()  # change the mode of the network.

    def plot_training_data(self):
        """Show a grid of sample training images (blocks until closed)."""
        # Plot some training images
        real_batch = next(iter(self.trainloader))
        real_batch = real_batch[0][0:8]
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title("Training Images")
        plt.imshow(
            np.transpose(
                vutils.make_grid(real_batch[0].to(self.device)[:64],
                                 padding=2,
                                 normalize=True).cpu(), (1, 2, 0)))
        plt.show()

    def train(self):
        """Fit a LinearSVC on discriminator features of ONE training batch
        and pickle the fitted model to ``self.save_filename``.
        """
        # NOTE(review): only a single batch is used here; the SGD variant
        # below iterates the full loader.
        train_data, train_labels = next(iter(self.trainloader))
        modified_train_data = self.discriminator(train_data)
        l2_svm = svm.LinearSVC(verbose=2, max_iter=2000)

        modified_train_data_ndarray = modified_train_data.detach().numpy()
        train_labels_ndarray = train_labels.detach().numpy()
        self.l2_svm = l2_svm.fit(modified_train_data_ndarray,
                                 train_labels_ndarray)

        # save model
        with open(self.save_filename, 'wb') as file:
            pickle.dump(self.l2_svm, file)

    def train_test_SGD_Classifier(self):
        """Incrementally fit an SGDClassifier (with StandardScaler) on
        discriminator features over the full training set, then pickle it.
        """
        est = make_pipeline(StandardScaler(), SGDClassifier(max_iter=200))
        # Fit the scaler once on features from the first batch.
        training_data = self.discriminator(next(iter(self.trainloader))[0])
        training_data = training_data.detach().numpy()
        est.steps[0][1].fit(training_data)

        self.est = est

        # Stream the remaining batches through scaler + partial_fit.
        for i, data in enumerate(self.trainloader):
            train_data, train_labels = data
            modified_train_data = self.discriminator(train_data)

            modified_train_data_ndarray = modified_train_data.detach().numpy()
            train_labels_ndarray = train_labels.detach().numpy()
            modified_train_data_ndarray = est.steps[0][1].transform(
                modified_train_data_ndarray)

            est.steps[1][1].partial_fit(
                modified_train_data_ndarray,
                train_labels_ndarray,
                classes=np.unique(train_labels_ndarray))
            print(f'Batch: {i}')

        with open(self.save_filename, 'wb') as file:
            pickle.dump(est.steps[1][1], file)

    def test(self):
        """Evaluate the SGD-trained classifier on the test set and print the
        mean per-batch accuracy.  Requires train_test_SGD_Classifier() to
        have been called first (uses ``self.est``).
        """
        l2_svm = self.est.steps[1][1]
        accuracy = []

        for i, data in enumerate(self.testloader):
            test_data, test_labels = data
            modified_test_data = self.discriminator(test_data)

            modified_test_data_ndarray = modified_test_data.detach().numpy()
            test_labels_ndarray = test_labels.detach().numpy()
            modified_test_data_ndarray = self.est.steps[0][1].transform(
                modified_test_data_ndarray)

            predictions = l2_svm.predict(modified_test_data_ndarray)

            accuracy.append(
                metrics.accuracy_score(test_labels_ndarray, predictions))

        print(f'Accuracy: {np.mean(accuracy)}')
Ejemplo n.º 25
0
def main(options):
    """Train a DCGAN: alternately update the discriminator on real/fake
    batches and the generator via the non-saturating GAN loss, periodically
    saving sample image grids, per-epoch checkpoints and loss statistics.

    :param options: argparse.Namespace produced by the CLI parser.
    """
    # 1. Make sure the options are valid argparse CLI options indeed
    assert isinstance(options, argparse.Namespace)

    # 2. Set up the logger
    logging.basicConfig(level=str(options.loglevel).upper())

    # 3. Make sure the output dir `outf` exists
    _check_out_dir(options)

    # 4. Set the random state
    _set_random_state(options)

    # 5. Configure CUDA and Cudnn, set the global `device` for PyTorch
    device = _configure_cuda(options)

    # 6. Prepare the datasets
    data_loader = _prepare_dataset(options)

    # 7. Set the parameters
    ngpu = int(options.ngpu)  # num of GPUs
    nz = int(options.nz)  # size of latent vector
    ngf = int(options.ngf)  # depth of feature maps through G
    ndf = int(options.ndf)  # depth of feature maps through D
    nc = int(options.nc
             )  # num of channels of the input images, 3 indicates color images

    # 8. Initialize (or load checkpoints for) the Generator model
    netG = Generator(ngpu, nz, ngf, nc).to(device)
    netG.apply(weights_init)
    if options.netG != '':
        logging.info(
            f'Found checkpoint of Generator at {options.netG}, loading from the saved model.\n'
        )
        netG.load_state_dict(torch.load(options.netG))
    logging.info(f'Showing the Generator model: \n {netG}\n')

    # 9. Initialize (or load checkpoints for) the Discriminator model
    netD = Discriminator(ngpu, ndf, nc).to(device)
    netD.apply(weights_init)
    if options.netD != '':
        # BUG FIX: the log line interpolated options.netG here, reporting the
        # Generator checkpoint path while loading the Discriminator.
        logging.info(
            f'Found checkpoint of Discriminator at {options.netD}, loading from the saved model.\n'
        )
        netD.load_state_dict(torch.load(options.netD))
    logging.info(f'Showing the Discriminator model: \n {netD}\n')

    # ========================
    # === Training Process ===
    # ========================

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    stats = []
    iters = 0

    # Set the loss function to Binary Cross Entropy between the target and the output
    # See https://pytorch.org/docs/stable/nn.html#torch.nn.BCELoss
    criterion = nn.BCELoss()

    # Fixed noise lets us watch the same latent points evolve across epochs.
    fixed_noise = torch.randn(options.batchSize, nz, 1, 1, device=device)
    real_label = 1
    fake_label = 0

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=options.lr,
                            betas=(options.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=options.lr,
                            betas=(options.beta1, 0.999))

    print("\nStarting Training Loop...\n")
    for epoch in range(options.niter):
        for i, data in enumerate(data_loader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            real_cpu = data[0].to(device)
            batch_size = real_cpu.size(0)
            # BUG FIX: without an explicit dtype, torch.full with an int fill
            # value produces a long tensor, which BCELoss rejects.
            label = torch.full((batch_size, ),
                               real_label,
                               dtype=torch.float,
                               device=device)

            output = netD(real_cpu)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            # train with fake
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            training_status = f"[{epoch}/{options.niter}][{i}/{len(data_loader)}] Loss_D: {errD.item():.4f} Loss_G: " \
                f"{errG.item():.4f} D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}"
            print(training_status)

            if i % int(options.notefreq) == 0:
                # {epoch:{0}{3}} zero-pads the epoch number to width 3
                vutils.save_image(
                    real_cpu,
                    f"{options.outf}/real_samples_epoch_{epoch:{0}{3}}_{i}.png",
                    normalize=True)
                fake = netG(fixed_noise)
                vutils.save_image(
                    fake.detach(),
                    f"{options.outf}/fake_samples_epoch_{epoch:{0}{3}}_{i}.png",
                    normalize=True)

                # Save Losses statistics for post-mortem
                G_losses.append(errG.item())
                D_losses.append(errD.item())
                stats.append(training_status)

                # Check how the generator is doing by saving G's output on fixed_noise
                if (iters % 500 == 0) or ((epoch == options.niter - 1) and
                                          (i == len(data_loader) - 1)):
                    with torch.no_grad():
                        fake = netG(fixed_noise).detach().cpu()
                    img_list.append(
                        vutils.make_grid(fake, padding=2, normalize=True))
                iters += 1

        # do checkpointing
        torch.save(netG.state_dict(), f"{options.outf}/netG_epoch_{epoch}.pth")
        # BUG FIX: previously saved netG's weights under the netD filename,
        # so the discriminator checkpoint was never actually written.
        torch.save(netD.state_dict(), f"{options.outf}/netD_epoch_{epoch}.pth")

    # save training stats
    _save_stats(statistic=G_losses, save_name='G_losses', options=options)
    _save_stats(statistic=D_losses, save_name='D_losses', options=options)
    _save_stats(statistic=stats, save_name='Training_stats', options=options)
Ejemplo n.º 26
0
def main(pretrain_dataset, rl_dataset, args):
    """Run the full adversarial text-generation pipeline.

    Phases: (1) pretrain the generator with MLE, (2) pretrain the
    discriminator on real vs. generated data, (3) adversarial RL training.
    Cached pretrained weights are reused unless the corresponding
    ``force_pretrain*`` flag is set.

    Relies on module-level names: ``const``, ``device``, ``GEN_MODEL_CACHE``,
    ``DSCR_MODEL_CACHE``, ``PT_CACHE_DIR``, ``TEMP_DATA_DIR`` and the trainer
    classes.
    """

    def banner(title):
        # Uniform section header for each training phase.
        print('#' * 80)
        print(title)
        print('#' * 80)

    # Seed both RNGs for reproducibility.
    random.seed(const.SEED)
    np.random.seed(const.SEED)

    # Train/validation loaders for the MLE pretraining dataset.
    pt_train_loader, pt_valid_loader = SplitDataLoader(
        pretrain_dataset, batch_size=const.BATCH_SIZE, drop_last=True).split()

    # Build both networks and move them to the runtime device.
    generator = Generator(const.VOCAB_SIZE, const.GEN_EMBED_DIM,
                          const.GEN_HIDDEN_DIM, device, args.cuda)
    discriminator = Discriminator(const.VOCAB_SIZE, const.DSCR_EMBED_DIM,
                                  const.DSCR_FILTER_LENGTHS,
                                  const.DSCR_NUM_FILTERS,
                                  const.DSCR_NUM_CLASSES, const.DSCR_DROPOUT)
    generator.to(device)
    discriminator.to(device)

    # Explicit CUDA placement when requested and available.
    if args.cuda and torch.cuda.is_available():
        generator = generator.cuda()
        discriminator = discriminator.cuda()

    # ------------------------------------------------------------------
    # Generator pretraining (MLE), or load from cache.
    # ------------------------------------------------------------------
    banner('Generator Pretraining')
    reuse_gen = (not args.force_pretrain and not args.force_pretrain_gen
                 and op.exists(GEN_MODEL_CACHE))
    if reuse_gen:
        print('Loading Pretrained Generator ...')
        ckpt = torch.load(GEN_MODEL_CACHE)
        generator.load_state_dict(ckpt['state_dict'])
        print('::INFO:: DateTime - %s.' % ckpt['datetime'])
        print('::INFO:: Model was trained for %d epochs.' % ckpt['epochs'])
        print('::INFO:: Final Training Loss - %.5f' % ckpt['train_loss'])
        print('::INFO:: Final Validation Loss - %.5f' % ckpt['valid_loss'])
    else:
        try:
            print('Pretraining Generator with MLE ...')
            GeneratorPretrainer(generator, pt_train_loader, pt_valid_loader,
                                PT_CACHE_DIR, device, args).train()
        except KeyboardInterrupt:
            # Partial pretraining is acceptable; continue with what we have.
            print('Stopped Generator Pretraining Early.')

    # ------------------------------------------------------------------
    # Discriminator pretraining (real vs. generator samples), or cache.
    # ------------------------------------------------------------------
    banner('Discriminator Pretraining')
    reuse_dscr = (not args.force_pretrain and not args.force_pretrain_dscr
                  and op.exists(DSCR_MODEL_CACHE))
    if reuse_dscr:
        print("Loading Pretrained Discriminator ...")
        ckpt = torch.load(DSCR_MODEL_CACHE)
        discriminator.load_state_dict(ckpt['state_dict'])
        print('::INFO:: DateTime - %s.' % ckpt['datetime'])
        print('::INFO:: Model was trained on %d data generations.' %
              ckpt['data_gens'])
        print('::INFO:: Model was trained for %d epochs per data generation.' %
              ckpt['epochs_per_gen'])
        print('::INFO:: Final Loss - %.5f' % ckpt['loss'])
    else:
        print('Pretraining Discriminator ...')
        try:
            DiscriminatorPretrainer(discriminator, rl_dataset, PT_CACHE_DIR,
                                    TEMP_DATA_DIR, device,
                                    args).train(generator)
        except KeyboardInterrupt:
            print('Stopped Discriminator Pretraining Early.')

    # ------------------------------------------------------------------
    # Adversarial (policy-gradient) training.
    # ------------------------------------------------------------------
    banner('Adversarial Training')
    AdversarialRLTrainer(generator, discriminator, rl_dataset, TEMP_DATA_DIR,
                         pt_valid_loader, device, args).train()
Ejemplo n.º 27
0
utils.print_network(D)
print('-----------------------------------------------')

# Optionally warm-start the SR generator from a saved state dict.
if opt.load_pretrained:
    # NOTE(review): string concatenation inside a single-argument
    # os.path.join — the join is a no-op here, and the path is only correct
    # if opt.save_folder already ends with a separator; confirm.
    model_name = os.path.join(opt.save_folder + opt.pretrained_sr)
    if os.path.exists(model_name):
        # map_location keeps the load CPU-safe regardless of the device the
        # checkpoint was saved on.
        model.load_state_dict(
            torch.load(model_name, map_location=lambda storage, loc: storage))
        print('Pre-trained SR model is loaded.')

# Optionally warm-start the discriminator the same way.
if opt.load_pretrained_D:
    D_name = os.path.join(opt.save_folder + opt.pretrained_D)
    if os.path.exists(D_name):
        D.load_state_dict(
            torch.load(D_name, map_location=lambda storage, loc: storage))
        print('Pre-trained Discriminator model is loaded.')

# Move models and loss modules to the first configured GPU.
if cuda:
    model = model.cuda(gpus_list[0])
    D = D.cuda(gpus_list[0])
    feature_extractor = feature_extractor.cuda(gpus_list[0])
    MSE_loss = MSE_loss.cuda(gpus_list[0])
    BCE_loss = BCE_loss.cuda(gpus_list[0])

# Adam optimizer for the SR generator.
optimizer = optim.Adam(model.parameters(),
                       lr=opt.lr,
                       betas=(0.9, 0.999),
                       eps=1e-8)
D_optimizer = optim.Adam(D.parameters(),
                         lr=opt.lr,
Ejemplo n.º 28
0
# Generator network and its RMSprop optimizer.
# NOTE(review): the generator optimizer uses args.learning_rate_d — this
# looks like a copy-paste of the discriminator's LR; confirm whether a
# separate generator learning rate was intended.
g_net = Generator().cuda()
g_opt = optim.RMSprop(g_net.parameters(),
                      args.learning_rate_d,
                      weight_decay=args.rmsprop_decay)
g_losses = np.empty(0)  # running generator-loss history

print("Initializing discriminator model and optimizer.")
d_net = Discriminator().cuda()
d_opt = optim.RMSprop(d_net.parameters(),
                      args.learning_rate_d,
                      weight_decay=args.rmsprop_decay)
d_losses = np.empty(0)  # running discriminator-loss history

# Resume from previously saved weights when --retrain is set.
if args.retrain:
    g_net.load_state_dict(torch.load('../data/generator_state'))
    d_net.load_state_dict(torch.load('../data/discriminator_state'))

print("Beginning training..")
# Batch loader over the image dataset.
loader = ETL(args.batch_size, args.image_size, args.path)

for iteration in range(args.iterations):

    # Train discriminator
    for _ in range(args.k_discriminator):
        d_opt.zero_grad()

        d_examples, d_targets = loader.next_batch()
        d_noise = torch.Tensor(args.batch_size, 1, args.image_size,
                               args.image_size).uniform_(-1., 1.)
        d_noise = Variable(d_noise).cuda()
        d_samples = g_net(d_noise, d_examples).detach()
Ejemplo n.º 29
0
class GAIL:
    """Generative Adversarial Imitation Learning agent.

    Holds an actor (policy) network and a discriminator trained to tell
    expert (state, action) pairs from the actor's, together with their Adam
    optimizers and an expert-trajectory buffer for sampling batches.
    """

    def __init__(self,
                 exp_dir,
                 exp_thresh,
                 state_dim,
                 action_dim,
                 learn_rate,
                 betas,
                 _device,
                 _gamma,
                 load_weights=False):
        """
            exp_dir : directory containing the expert episodes
         exp_thresh : parameter to control number of episodes to load
                      as expert based on returns (lower means more episodes)
          state_dim : dimension of state
         action_dim : dimension of action
         learn_rate : learning rate for optimizer
              betas : Adam (beta1, beta2) coefficients
            _device : GPU or cpu
            _gamma  : discount factor
       load_weights : load saved weights instead of re-initializing
        """

        # storing runtime device
        self.device = _device

        # discount factor
        self.gamma = _gamma

        # Expert trajectory buffer used to sample (state, action) batches.
        self.expert = ExpertTrajectories(exp_dir, exp_thresh, gamma=self.gamma)

        # Defining the actor and its optimizer
        self.actor = ActorNetwork(state_dim).to(self.device)
        self.optim_actor = torch.optim.Adam(self.actor.parameters(),
                                            lr=learn_rate,
                                            betas=betas)

        # Defining the discriminator and its optimizer
        self.disc = Discriminator(state_dim, action_dim).to(self.device)
        self.optim_disc = torch.optim.Adam(self.disc.parameters(),
                                           lr=learn_rate,
                                           betas=betas)

        if not load_weights:
            self.actor.apply(init_weights)
            self.disc.apply(init_weights)
        else:
            self.load()

        # Loss criterion; BCELoss expects float probabilities and labels.
        self.criterion = torch.nn.BCELoss()

    def get_action(self, state):
        """Return the actor's action for a single state as a flat numpy array."""
        state = torch.tensor(state, dtype=torch.float,
                             device=self.device).view(1, -1)
        return self.actor(state).cpu().data.numpy().flatten()

    def update(self, n_iter, batch_size=100):
        """Train discriminator and actor for ``n_iter`` mini-batches.

        Returns
        -------
        (act_losses, disc_losses) : tuple of np.ndarray
            Per-iteration mean actor and discriminator losses.
        """
        # Per-iteration loss history.  Use the builtin ``float``:
        # ``np.float`` was deprecated in NumPy 1.20 and later removed.
        disc_losses = np.zeros(n_iter, dtype=float)
        act_losses = np.zeros(n_iter, dtype=float)

        for i in range(n_iter):

            # Get expert state and actions batch (the "real" samples).
            exp_states, exp_actions = self.expert.sample(batch_size)
            exp_states = torch.FloatTensor(exp_states).to(self.device)
            exp_actions = torch.FloatTensor(exp_actions).to(self.device)

            # Get states and re-label them with the current actor policy.
            states, _ = self.expert.sample(batch_size)
            states = torch.FloatTensor(states).to(self.device)
            actions = self.actor(states)
            '''
                train the discriminator
            '''
            self.optim_disc.zero_grad()

            # Label tensors. Float fill values keep the dtype float32:
            # with integer fills, modern torch.full infers an integer dtype,
            # which BCELoss rejects.
            exp_labels = torch.full((batch_size, 1), 1.0, device=self.device)
            policy_labels = torch.full((batch_size, 1), 0.0, device=self.device)

            # with expert transitions
            prob_exp = self.disc(exp_states, exp_actions)
            exp_loss = self.criterion(prob_exp, exp_labels)

            # with policy actor transitions; detach so the discriminator
            # step does not backprop into the actor
            prob_policy = self.disc(states, actions.detach())
            policy_loss = self.criterion(prob_policy, policy_labels)

            # use backprop
            disc_loss = exp_loss + policy_loss
            disc_losses[i] = disc_loss.mean().item()

            disc_loss.backward()
            self.optim_disc.step()
            '''
                train the actor
            '''
            self.optim_actor.zero_grad()
            # Maximize D(s, pi(s))  <=>  minimize -D(s, pi(s)).
            loss_actor = -self.disc(states, actions)
            act_losses[i] = loss_actor.mean().detach().item()

            loss_actor.mean().backward()
            self.optim_actor.step()

        print("Finished training minibatch")

        return act_losses, disc_losses

    def save(
            self,
            directory='/home/aman/Programming/RL-Project/Deterministic-GAIL/weights',
            name='GAIL'):
        """Persist actor and discriminator state dicts under ``directory``."""
        torch.save(self.actor.state_dict(),
                   '{}/{}_actor.pth'.format(directory, name))
        torch.save(self.disc.state_dict(),
                   '{}/{}_discriminator.pth'.format(directory, name))

    def load(
            self,
            directory='/home/aman/Programming/RL-Project/Deterministic-GAIL/weights',
            name='GAIL'):
        """Load actor and discriminator state dicts saved by :meth:`save`."""
        print(os.getcwd())
        self.actor.load_state_dict(
            torch.load('{}/{}_actor.pth'.format(directory, name)))
        self.disc.load_state_dict(
            torch.load('{}/{}_discriminator.pth'.format(directory, name)))

    def set_mode(self, mode="train"):
        """Switch both networks between train and eval mode."""
        if mode == "train":
            self.actor.train()
            self.disc.train()
        else:
            self.actor.eval()
            self.disc.eval()
Ejemplo n.º 30
0
class trainer(object):
    """GAN-style trainer for semantic-label correction.

    Two U-Net generators — one re-encoding the old (noisy) label map, one
    predicting the corrected label map from the RGB image plus the first
    generator's features — are trained against a patch discriminator that
    scores (image, one-hot label) pairs.  All configuration comes from
    ``cfg``; a module-level visdom handle ``vis`` is used for live plots.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        # Generator over the old label map; side='out' exposes intermediate
        # features consumed by the image generator below.
        self.OldLabel_generator = U_Net(in_ch=cfg.DATASET.N_CLASS,
                                        out_ch=cfg.DATASET.N_CLASS,
                                        side='out')
        # Generator from the 3-channel RGB image (plus side features).
        self.Image_generator = U_Net(in_ch=3,
                                     out_ch=cfg.DATASET.N_CLASS,
                                     side='in')
        # Patch discriminator over concat(image [3ch], one-hot label [N_CLASS ch]).
        self.discriminator = Discriminator(cfg.DATASET.N_CLASS + 3,
                                           cfg.DATASET.IMGSIZE,
                                           patch=True)

        # Weighted multi-term generator loss (see GeneratorLoss for the
        # meaning of the three weights).
        self.criterion_G = GeneratorLoss(cfg.LOSS.LOSS_WEIGHT[0],
                                         cfg.LOSS.LOSS_WEIGHT[1],
                                         cfg.LOSS.LOSS_WEIGHT[2],
                                         ignore_index=cfg.LOSS.IGNORE_INDEX)
        self.criterion_D = DiscriminatorLoss()

        train_dataset = BaseDataset(cfg, split='train')
        valid_dataset = BaseDataset(cfg, split='val')
        self.train_dataloader = data.DataLoader(
            train_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)
        self.valid_dataloader = data.DataLoader(
            valid_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)

        # Output directories for checkpoints and validation imagery.
        self.ckpt_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints')
        if not os.path.isdir(self.ckpt_outdir):
            os.mkdir(self.ckpt_outdir)
        self.val_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'val')
        if not os.path.isdir(self.val_outdir):
            os.mkdir(self.val_outdir)
        # RESUME is the epoch index to resume from; a negative value skips
        # the checkpoint load below (train from scratch).
        self.start_epoch = cfg.TRAIN.RESUME
        self.n_epoch = cfg.TRAIN.N_EPOCH

        # Single optimizer over both generators' parameters.
        self.optimizer_G = torch.optim.Adam(
            [{
                'params': self.OldLabel_generator.parameters()
            }, {
                'params': self.Image_generator.parameters()
            }],
            lr=cfg.OPTIMIZER.G_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        self.optimizer_D = torch.optim.Adam(
            [{
                'params': self.discriminator.parameters(),
                'initial_lr': cfg.OPTIMIZER.D_LR
            }],
            lr=cfg.OPTIMIZER.D_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        # Polynomial (power-0.9) per-iteration LR decay for both optimizers.
        iter_per_epoch = len(train_dataset) // cfg.DATASET.BATCHSIZE
        lambda_poly = lambda iters: pow(
            (1.0 - iters / (cfg.TRAIN.N_EPOCH * iter_per_epoch)), 0.9)
        self.scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G,
            lr_lambda=lambda_poly,
        )
        self.scheduler_D = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D,
            lr_lambda=lambda_poly,
        )

        self.logger = logger(cfg.TRAIN.OUTDIR, name='train')
        self.running_metrics = runningScore(n_classes=cfg.DATASET.N_CLASS)

        # Resume: the checkpoint bundles both generators, the discriminator
        # and both optimizer states (keys match `state` in train()).
        # NOTE(review): the same file is re-loaded five times; caching the
        # torch.load result once would avoid redundant disk reads — confirm.
        if self.start_epoch >= 0:
            self.OldLabel_generator.load_state_dict(
                torch.load(
                    os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                                 '{}epoch.pth'.format(
                                     self.start_epoch)))['model_G_N'])
            self.Image_generator.load_state_dict(
                torch.load(
                    os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                                 '{}epoch.pth'.format(
                                     self.start_epoch)))['model_G_I'])
            self.discriminator.load_state_dict(
                torch.load(
                    os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                                 '{}epoch.pth'.format(
                                     self.start_epoch)))['model_D'])
            self.optimizer_G.load_state_dict(
                torch.load(
                    os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                                 '{}epoch.pth'.format(
                                     self.start_epoch)))['optimizer_G'])
            self.optimizer_D.load_state_dict(
                torch.load(
                    os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                                 '{}epoch.pth'.format(
                                     self.start_epoch)))['optimizer_D'])

            log = "Using the {}th checkpoint".format(self.start_epoch)
            self.logger.info(log)
        # Everything runs on CUDA; no CPU fallback is provided.
        self.Image_generator = self.Image_generator.cuda()
        self.OldLabel_generator = self.OldLabel_generator.cuda()
        self.discriminator = self.discriminator.cuda()
        self.criterion_G = self.criterion_G.cuda()
        self.criterion_D = self.criterion_D.cuda()

    def train(self):
        """Run the full optimization loop.

        Per epoch: alternate discriminator / generator updates over the
        training set, then evaluate on the validation set, push metrics to
        visdom, and checkpoint all models and optimizers.
        """
        # Averaged-every-10-iterations training losses (for the visdom plot).
        all_train_iter_total_loss = []
        all_train_iter_corr_loss = []
        all_train_iter_recover_loss = []
        all_train_iter_change_loss = []
        all_train_iter_gan_loss_gen = []
        all_train_iter_gan_loss_dis = []
        all_val_epo_iou = []
        all_val_epo_acc = []
        iter_num = [0]
        epoch_num = []
        num_batches = len(self.train_dataloader)

        for epoch_i in range(self.start_epoch + 1, self.n_epoch):
            iter_total_loss = AverageTracker()
            iter_corr_loss = AverageTracker()
            iter_recover_loss = AverageTracker()
            iter_change_loss = AverageTracker()
            iter_gan_loss_gen = AverageTracker()
            iter_gan_loss_dis = AverageTracker()
            batch_time = AverageTracker()
            tic = time.time()

            # train
            self.OldLabel_generator.train()
            self.Image_generator.train()
            self.discriminator.train()
            for i, meta in enumerate(self.train_dataloader):

                # meta: (image, old_label, new_label, name-ish extras) —
                # meta[3] is used below when naming validation outputs.
                image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                ), meta[2].cuda()
                recover_pred, feats = self.OldLabel_generator(
                    label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                corr_pred = self.Image_generator(image, feats)

                # -------------------
                # Train Discriminator
                # -------------------
                self.discriminator.set_requires_grad(True)
                self.optimizer_D.zero_grad()

                # Detached so the D step does not backprop into the generators.
                fake_sample = torch.cat((image, corr_pred), 1).detach()
                # NOTE(review): uses the module-level `cfg` rather than
                # `self.cfg` as everywhere else — confirm they are the same
                # object.
                real_sample = torch.cat(
                    (image, label2onehot(new_label, cfg.DATASET.N_CLASS)), 1)

                score_fake_d = self.discriminator(fake_sample)
                score_real = self.discriminator(real_sample)

                gan_loss_dis = self.criterion_D(pred_score=score_fake_d,
                                                real_score=score_real)
                gan_loss_dis.backward()
                self.optimizer_D.step()
                self.scheduler_D.step()

                # ---------------
                # Train Generator
                # ---------------
                self.discriminator.set_requires_grad(False)
                self.optimizer_G.zero_grad()

                # Re-score the (non-detached) fake so gradients reach both
                # generators through the frozen discriminator.
                score_fake = self.discriminator(
                    torch.cat((image, corr_pred), 1))

                total_loss, corr_loss, recover_loss, change_loss, gan_loss_gen = self.criterion_G(
                    corr_pred, recover_pred, score_fake, old_label, new_label)

                total_loss.backward()
                self.optimizer_G.step()
                self.scheduler_G.step()

                iter_total_loss.update(total_loss.item())
                iter_corr_loss.update(corr_loss.item())
                iter_recover_loss.update(recover_loss.item())
                iter_change_loss.update(change_loss.item())
                iter_gan_loss_gen.update(gan_loss_gen.item())
                iter_gan_loss_dis.update(gan_loss_dis.item())
                batch_time.update(time.time() - tic)
                tic = time.time()

                log = '{}: Epoch: [{}][{}/{}], Time: {:.2f}, ' \
                      'Total Loss: {:.6f}, Corr Loss: {:.6f}, Recover Loss: {:.6f}, Change Loss: {:.6f}, GAN_G Loss: {:.6f}, GAN_D Loss: {:.6f}'.format(
                    datetime.now(), epoch_i, i, num_batches, batch_time.avg,
                    total_loss.item(), corr_loss.item(), recover_loss.item(), change_loss.item(), gan_loss_gen.item(), gan_loss_dis.item())
                print(log)

                # Every 10 iterations: append the running averages to the
                # plot series, reset the trackers, and refresh the visdom
                # line chart.
                if (i + 1) % 10 == 0:
                    all_train_iter_total_loss.append(iter_total_loss.avg)
                    all_train_iter_corr_loss.append(iter_corr_loss.avg)
                    all_train_iter_recover_loss.append(iter_recover_loss.avg)
                    all_train_iter_change_loss.append(iter_change_loss.avg)
                    all_train_iter_gan_loss_gen.append(iter_gan_loss_gen.avg)
                    all_train_iter_gan_loss_dis.append(iter_gan_loss_dis.avg)
                    iter_total_loss.reset()
                    iter_corr_loss.reset()
                    iter_recover_loss.reset()
                    iter_change_loss.reset()
                    iter_gan_loss_gen.reset()
                    iter_gan_loss_dis.reset()

                    vis.line(X=np.column_stack(
                        np.repeat(np.expand_dims(iter_num, 0), 6, axis=0)),
                             Y=np.column_stack((all_train_iter_total_loss,
                                                all_train_iter_corr_loss,
                                                all_train_iter_recover_loss,
                                                all_train_iter_change_loss,
                                                all_train_iter_gan_loss_gen,
                                                all_train_iter_gan_loss_dis)),
                             opts={
                                 'legend': [
                                     'total_loss', 'corr_loss', 'recover_loss',
                                     'change_loss', 'gan_loss_gen',
                                     'gan_loss_dis'
                                 ],
                                 'linecolor':
                                 np.array([[255, 0, 0], [0, 255, 0],
                                           [0, 0, 255], [255, 255, 0],
                                           [0, 255, 255], [255, 0, 255]]),
                                 'title':
                                 'Train loss of generator and discriminator'
                             },
                             win='Train loss of generator and discriminator')
                    iter_num.append(iter_num[-1] + 1)

            # eval
            self.OldLabel_generator.eval()
            self.Image_generator.eval()
            self.discriminator.eval()
            with torch.no_grad():
                for j, meta in enumerate(self.valid_dataloader):
                    image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                    ), meta[2].cuda()
                    recover_pred, feats = self.OldLabel_generator(
                        label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                    corr_pred = self.Image_generator(image, feats)
                    # Hard class predictions for the confusion-matrix metrics.
                    preds = np.argmax(corr_pred.cpu().detach().numpy().copy(),
                                      axis=1)
                    target = new_label.cpu().detach().numpy().copy()
                    self.running_metrics.update(target, preds)

                    # Dump a side-by-side color map of the first two
                    # predictions of the first validation batch.
                    if j == 0:
                        color_map1 = gen_color_map(preds[0, :]).astype(
                            np.uint8)
                        color_map2 = gen_color_map(preds[1, :]).astype(
                            np.uint8)
                        color_map = cv2.hconcat([color_map1, color_map2])
                        cv2.imwrite(
                            os.path.join(
                                self.val_outdir, '{}epoch*{}*{}.png'.format(
                                    epoch_i, meta[3][0], meta[3][1])),
                            color_map)

            # The key strings literally include ': \t' — they must match the
            # exact keys produced by runningScore.get_scores().
            score = self.running_metrics.get_scores()
            oa = score['Overall Acc: \t']
            precision = score['Precision: \t'][1]
            recall = score['Recall: \t'][1]
            iou = score['Class IoU: \t'][1]
            miou = score['Mean IoU: \t']
            self.running_metrics.reset()

            epoch_num.append(epoch_i)
            all_val_epo_acc.append(oa)
            all_val_epo_iou.append(miou)
            vis.line(X=np.column_stack(
                np.repeat(np.expand_dims(epoch_num, 0), 2, axis=0)),
                     Y=np.column_stack((all_val_epo_acc, all_val_epo_iou)),
                     opts={
                         'legend':
                         ['val epoch Overall Acc', 'val epoch Mean IoU'],
                         'linecolor': np.array([[255, 0, 0], [0, 255, 0]]),
                         'title': 'Validate Accuracy and IoU'
                     },
                     win='validate Accuracy and IoU')

            log = '{}: Epoch Val: [{}], ACC: {:.2f}, Recall: {:.2f}, mIoU: {:.4f}' \
                .format(datetime.now(), epoch_i, oa, recall, miou)
            self.logger.info(log)

            # Checkpoint everything needed to resume (see __init__'s loader).
            state = {
                'epoch': epoch_i,
                "acc": oa,
                "recall": recall,
                "iou": miou,
                'model_G_N': self.OldLabel_generator.state_dict(),
                'model_G_I': self.Image_generator.state_dict(),
                'model_D': self.discriminator.state_dict(),
                'optimizer_G': self.optimizer_G.state_dict(),
                'optimizer_D': self.optimizer_D.state_dict()
            }
            save_path = os.path.join(self.cfg.TRAIN.OUTDIR, 'checkpoints',
                                     '{}epoch.pth'.format(epoch_i))
            torch.save(state, save_path)