Ejemplo n.º 1
0
def run(args):
    """Load Lasagne GAN weights from ``args.weights`` and export the
    converted networks as PyTorch state dicts.

    The ``.pth`` filename mirrors the source ``.pkl`` name. The Generator
    (from the smoothed ``Gs``) and the Discriminator ``D`` are each
    converted and saved only when their output directory argument
    (``args.generator`` / ``args.discriminator``) is set.
    """
    print('Loading lasagne weights')
    with open(args.weights, "rb") as f:
        G, D, Gs = cPickle.load(f)

    # Output filename is derived once and shared by both exports.
    pth_name = os.path.split(args.weights)[1].replace('.pkl', '.pth')

    if args.generator:
        print('Converting Generator model')
        generator = Generator()
        convert_generator(Gs, generator)
        output_path = os.path.join(args.generator, pth_name)
        print('Saving model to {}'.format(output_path))
        torch.save(generator.state_dict(), output_path)

    if args.discriminator:
        print('Converting Discriminator model')
        discriminator = Discriminator()
        convert_discriminator(D, discriminator)
        output_path = os.path.join(args.discriminator, pth_name)
        print('Saving model to {}'.format(output_path))
        torch.save(discriminator.state_dict(), output_path)
Ejemplo n.º 2
0
class gan(nn.Module):
    """Pose-transfer GAN trainer.

    Bundles the generator (``MModel``), the ``Discriminator``, the
    perceptual + L1 reconstruction losses and their Adam optimizers, and
    drives the train/test loops with TensorBoard logging and periodic
    checkpointing.
    """

    def __init__(self, params, args):
        """Build models, losses and optimizers.

        Args:
            params: hyper-parameters forwarded to ``MModel`` and
                ``Discriminator``.
            args: namespace providing ``use_cuda``, ``g_weight_dir``,
                ``d_weight_dir``, ``g_lr``, ``d_lr``, ``save_dir``,
                ``d_update_freq``, ``save_freq`` and ``start_epoch``.
        """
        super(gan, self).__init__()
        self.G = MModel(params, use_cuda=True)
        self.D = Discriminator(params, bias=True)
        self.vgg_loss = VGGPerceptualLoss()
        self.L1_loss = nn.L1Loss()
        if args.use_cuda:
            self.G = self.G.cuda()
            self.D = self.D.cuda()
            self.vgg_loss = self.vgg_loss.cuda()
            self.L1_loss = self.L1_loss.cuda()
        # Optional warm start. The discriminator load is non-strict so
        # checkpoints with extra/missing keys still load.
        if args.g_weight_dir:
            self.G.load_state_dict(torch.load(args.g_weight_dir), strict=True)
        if args.d_weight_dir:
            self.D.load_state_dict(torch.load(args.d_weight_dir), strict=False)

        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=args.g_lr)
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=args.d_lr)

        self.save_dir = args.save_dir
        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)

        self.d_update_freq = args.d_update_freq
        self.save_freq = args.save_freq
        self.writer = SummaryWriter('runs/' + args.save_dir)
        self.use_cuda = args.use_cuda
        self.start_epoch = args.start_epoch

    def G_loss(self, input, target):
        """Generator reconstruction loss: VGG perceptual loss + L1."""
        vgg = self.vgg_loss(input, target)
        L1 = self.L1_loss(input, target)
        return vgg + L1

    def update_D(self, loss, epoch):
        """Backprop and step the discriminator, but only on epochs that are
        multiples of ``d_update_freq`` (throttles D relative to G)."""
        if epoch % self.d_update_freq == 0:
            loss.backward()
            self.optimizer_D.step()

    def get_patch_weight(self, pose, size=62):
        """Per-patch BCE weight map from channel 0 of ``pose`` (presumably
        the head heatmap — TODO confirm), resized to ``size`` and scaled so
        marked patches weigh 6x the background."""
        heads = pose[:, 0, :, :]
        heads = heads.unsqueeze(1)
        heads = torch.nn.functional.interpolate(heads, size=size)
        heads = heads * 5 + torch.ones_like(heads)
        return heads

    def gan_loss(self, out, label, pose):
        """Vanilla GAN BCE loss against an all-ones (label==1) or all-zeros
        target. ``pose`` is unused but kept for interface compatibility with
        the patch-weighted variant."""
        return nn.BCELoss()(
            out, torch.ones_like(out) if label == 1 else torch.zeros_like(out))

    def train(self, dl, epoch):
        """Run one training epoch over dataloader ``dl``; logs per-epoch
        averaged losses to TensorBoard and checkpoints every ``save_freq``
        epochs."""
        # BUGFIX: ``test`` switches G to eval mode and nothing switched it
        # back, so epochs after the first trained with eval-mode
        # batchnorm/dropout. Restore training mode explicitly.
        self.G.train()
        self.D.train()
        cnt = 0
        loss_D_real_sum, loss_D_fake_sum, loss_D_sum, loss_G_gan_sum, loss_G_img_sum, loss_G_sum = 0, 0, 0, 0, 0, 0
        for batch_idx, (src_img, y, src_pose, tgt_pose, src_mask_prior,
                        x_trans, src_mask_gt, tgt_face, tgt_face_box,
                        src_face_box) in enumerate(dl):
            print('epoch:', epoch, 'iter:', batch_idx)
            self.optimizer_D.zero_grad()
            if self.use_cuda:
                src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans = src_img.cuda(
                ), y.cuda(), src_pose.cuda(), tgt_pose.cuda(
                ), src_mask_prior.cuda(), x_trans.cuda()

            out = self.G(src_img, src_pose, tgt_pose, src_mask_prior, x_trans)
            gen = out[0]
            # D: real target toward 1, detached fake toward 0.
            loss_D_real = self.gan_loss(self.D(y, tgt_pose), 1, tgt_pose)
            loss_D_fake = self.gan_loss(self.D(gen.detach(), tgt_pose), 0,
                                        tgt_pose)
            loss_D = loss_D_real + loss_D_fake
            self.update_D(loss_D, epoch)

            # G: fool D and reconstruct the target image.
            self.optimizer_G.zero_grad()
            loss_G_gan = self.gan_loss(self.D(gen, tgt_pose), 1, tgt_pose)
            loss_G_img = self.G_loss(gen, y)  # vgg_loss + L1_loss
            loss_G = loss_G_gan + loss_G_img
            loss_G.backward()
            self.optimizer_G.step()

            loss_D_real_sum += loss_D_real.item()
            loss_D_fake_sum += loss_D_fake.item()
            loss_D_sum += loss_D.item()
            loss_G_gan_sum += loss_G_gan.item()
            loss_G_img_sum += loss_G_img.item()
            loss_G_sum += loss_G.item()
            cnt += 1

        self.writer.add_scalar('loss_D_real', loss_D_real_sum / cnt, epoch)
        self.writer.add_scalar('loss_D_fake', loss_D_fake_sum / cnt, epoch)
        self.writer.add_scalar('loss_D', loss_D_sum / cnt, epoch)
        self.writer.add_scalar('loss_G_gan', loss_G_gan_sum / cnt, epoch)
        self.writer.add_scalar('loss_G_img', loss_G_img_sum / cnt, epoch)
        self.writer.add_scalar('loss_G', loss_G_sum / cnt, epoch)
        # BUGFIX: previously divided the *last batch's* loss tensors
        # (``loss_D`` / ``loss_G``) by ``cnt``; log the epoch averages.
        self.writer.add_scalars('DG', {
            'D': loss_D_sum / cnt,
            'G': loss_G_sum / cnt
        }, epoch)
        if epoch % self.save_freq == 0:
            torch.save(self.G.state_dict(),
                       os.path.join(self.save_dir, 'g_epoch_%d.pth' % epoch))
            torch.save(self.D.state_dict(),
                       os.path.join(self.save_dir, 'd_epoch_%d.pth' % epoch))

    def test(self, test_dl, epoch):
        """Run the generator in eval mode over ``test_dl`` and log sample
        images (first batch only) to TensorBoard."""
        self.G.eval()
        for batch_idx, (src_img, y, src_pose, tgt_pose, src_mask_prior,
                        x_trans, src_mask_gt, tgt_face, tgt_face_box,
                        src_face_box) in enumerate(test_dl):
            print('test', 'epoch:', epoch, 'iter:', batch_idx)
            if self.use_cuda:
                src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans = src_img.cuda(
                ), y.cuda(), src_pose.cuda(), tgt_pose.cuda(
                ), src_mask_prior.cuda(), x_trans.cuda()
            with torch.no_grad():
                out = self.G(src_img, src_pose, tgt_pose, src_mask_prior,
                             x_trans)
            gen = out[0]
            if batch_idx == 0:
                # Outputs are in [-1, 1]; rescale to [0, 1] for display.
                self.writer.add_images('test_gen/epoch%d' % epoch,
                                       gen * 0.5 + 0.5)
                self.writer.add_images('test_y/epoch%d' % epoch, y * 0.5 + 0.5)
                self.writer.add_images('test_src/epoch%d' % epoch,
                                       src_img * 0.5 + 0.5)
                self.writer.add_images(
                    'test_src_mask/epoch%d' % epoch, out[2].view(
                        (out[2].size(0) * out[2].size(1), 1, out[2].size(2),
                         out[2].size(3))))
class AdvGAN_Pretrain:
    """AdvGAN pre-training: learns a generator that produces bounded
    adversarial perturbations against a fixed target classifier, with a
    discriminator enforcing realism of the perturbed images."""

    def __init__(self,
                 device,
                 model,
                 model_num_labels,
                 box_min,
                 box_max):
        """
        Args:
            device: torch device for all networks and tensors.
            model: target classifier under attack (only queried, not trained).
            model_num_labels: number of classes of the target model.
            box_min: lower bound of the valid pixel range.
            box_max: upper bound of the valid pixel range.
        """
        self.device = device
        self.model_num_labels = model_num_labels
        self.model = model
        self.box_min = box_min
        self.box_max = box_max

        self.netG = Generator().to(device)
        self.netDisc = Discriminator().to(device)

        # initialize all weights
        self.netG.apply(weights_init)
        self.netDisc.apply(weights_init)

        # initialize optimizers
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=1e-3)
        self.optimizer_D = torch.optim.Adam(self.netDisc.parameters(),
                                            lr=1e-3)

        # NOTE(review): ``models_path`` must exist as a module-level
        # constant — confirm it is defined where this class is used.
        if not os.path.exists(models_path):
            os.makedirs(models_path)

    def train_batch(self, x, labels):
        """Run one optimization step of D then G on a single batch.

        Args:
            x: batch of clean input images.
            labels: ground-truth class labels for ``x``.

        Returns:
            Tuple of floats ``(loss_D_GAN, loss_G_fake, loss_perturb,
            loss_adv)``.
        """
        # ---- optimize D ----
        perturbation = self.netG(x)

        # Clip the perturbation to [-0.3, 0.3], then clip the perturbed
        # image back into the valid pixel box.
        adv_images = torch.clamp(perturbation, -0.3, 0.3) + x
        adv_images = torch.clamp(adv_images, self.box_min, self.box_max)

        self.optimizer_D.zero_grad()
        pred_real = self.netDisc(x)
        loss_D_real = F.mse_loss(pred_real, torch.ones_like(pred_real, device=self.device))
        loss_D_real.backward()

        pred_fake = self.netDisc(adv_images.detach())
        loss_D_fake = F.mse_loss(pred_fake, torch.zeros_like(pred_fake, device=self.device))
        loss_D_fake.backward()
        loss_D_GAN = loss_D_fake + loss_D_real
        self.optimizer_D.step()

        # ---- optimize G ----
        self.optimizer_G.zero_grad()

        # GAN term: G wants D to score adversarial images as real.
        pred_fake = self.netDisc(adv_images)
        loss_G_fake = F.mse_loss(pred_fake, torch.ones_like(pred_fake, device=self.device))
        loss_G_fake.backward(retain_graph=True)

        # Perturbation magnitude: mean per-sample L2 norm.
        loss_perturb = torch.mean(torch.norm(perturbation.view(perturbation.shape[0], -1), 2, dim=1))

        # Adversarial term on the target model.
        logits_model = self.model(adv_images)
        probs_model = F.softmax(logits_model, dim=1)
        onehot_labels = torch.eye(self.model_num_labels, device=self.device)[labels]

        # C&W-style loss: push the true-class probability below the best
        # non-true class (the -10000 masks out the true class from the max).
        real = torch.sum(onehot_labels * probs_model, dim=1)
        other, _ = torch.max((1 - onehot_labels) * probs_model - onehot_labels * 10000, dim=1)
        zeros = torch.zeros_like(other)
        loss_adv_arr = torch.max(real - other, zeros)
        # BUGFIX: previously ``torch.sum(loss_adv)`` summed a name that was
        # not yet defined (NameError); sum the per-sample array instead.
        loss_adv = torch.sum(loss_adv_arr)

        adv_lambda = 10
        pert_lambda = 1
        loss_G = adv_lambda * loss_adv + pert_lambda * loss_perturb
        loss_G.backward()
        self.optimizer_G.step()

        return loss_D_GAN.item(), loss_G_fake.item(), loss_perturb.item(), loss_adv.item()

    def train(self, train_dataloader, epochs):
        """Full training loop with a manual step LR schedule (1e-3, then
        1e-4 at epoch 50, then 1e-5 at epoch 80); logs epoch averages to
        TensorBoard and checkpoints both nets every 20 epochs."""
        writer = SummaryWriter(log_dir="visualization/orig_advgan/", comment='Original AdvGAN stats')

        for epoch in range(1, epochs + 1):

            # Manual LR schedule: rebuild the optimizers at the milestones
            # (this also resets Adam's moment estimates).
            if epoch == 50:
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=1e-4)
                self.optimizer_D = torch.optim.Adam(self.netDisc.parameters(),
                                                    lr=1e-4)

            if epoch == 80:
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=1e-5)
                self.optimizer_D = torch.optim.Adam(self.netDisc.parameters(),
                                                    lr=1e-5)

            loss_D_sum = 0
            loss_G_fake_sum = 0
            loss_perturb_sum = 0
            loss_adv_sum = 0
            for i, data in enumerate(train_dataloader, start=0):
                images, labels = data
                images, labels = images.to(self.device), labels.to(self.device)

                loss_D_batch, loss_G_fake_batch, loss_perturb_batch, loss_adv_batch = \
                    self.train_batch(images, labels)
                loss_D_sum += loss_D_batch
                loss_G_fake_sum += loss_G_fake_batch
                loss_perturb_sum += loss_perturb_batch
                loss_adv_sum += loss_adv_batch

            # Log and print epoch-averaged statistics.
            num_batch = len(train_dataloader)
            writer.add_scalar('discriminator_loss', loss_D_sum/num_batch, epoch)
            writer.add_scalar('generator_loss', loss_G_fake_sum/num_batch, epoch)
            writer.add_scalar('perturbation_loss', loss_perturb_sum/num_batch, epoch)
            writer.add_scalar('adversarial_loss', loss_adv_sum/num_batch, epoch)
            print("epoch %d:\nloss_D: %.5f, loss_G_fake: %.5f,\
             \nloss_perturb: %.5f, loss_adv: %.5f\n" %
                  (epoch, loss_D_sum/num_batch, loss_G_fake_sum/num_batch,
                   loss_perturb_sum/num_batch, loss_adv_sum/num_batch))

            # Checkpoint generator and discriminator.
            if epoch % 20 == 0:
                netG_file_name = models_path + 'netG_original_epoch_' + str(epoch) + '.pth'
                torch.save(self.netG.state_dict(), netG_file_name)
                netDisc_file_name = models_path + 'netDisc_original_epoch_' + str(epoch) + '.pth'
                torch.save(self.netDisc.state_dict(), netDisc_file_name)

        writer.close()
Ejemplo n.º 4
0
     accumulate, generator.generate_7, inference_generator.generate_7,
     discriminator.discriminate_7, generator_optimizer,
     discriminator_optimizer, dataroot, device)
# torch.save(generator.state_dict(), f'./split_generator')
# torch.save(discriminator.state_dict(), f'./split_discriminator')
# torch.save(inference_generator.state_dict(), f'./split_inference_generator')

# generator.load_state_dict(torch.load('./generator'))
# discriminator.load_state_dict(torch.load('./discriminator'))

# interpolation_step(
#     48, 400, test_noise,
#     generator, inference_generator, discriminator, accumulate,
#     generator.generate_8, inference_generator.generate_8,
#     discriminator.discriminate_8,
#     generator_optimizer, discriminator_optimizer,
#     dataroot, device
# )
# step(
#     48, 400, test_noise,
#     generator, inference_generator, discriminator, accumulate,
#     generator.generate_9, inference_generator.generate_9,
#     discriminator.discriminate_9,
#     generator_optimizer, discriminator_optimizer,
#     dataroot, device
# )
torch.save(generator.state_dict(), f'./split_90_generator_final')
torch.save(discriminator.state_dict(), f'./split_90_discriminator_final')
torch.save(inference_generator.state_dict(),
           f'./split_90_inference_generator_final')
Ejemplo n.º 5
0
def main(args):
    """Train a DCGAN-style generator on an anime-face dataset.

    Expects ``args`` to provide: tags, imgs (dataset paths), bs (batch
    size), noise (latent dim), momentum, lr_g, lr_d, betas, epochs,
    label_smoothing (bool), check (bool, gradient sanity checks), and
    d / g (output directories for discriminator / generator checkpoints).
    """

    # Preprocessing: resize to 64 px and normalize pixels to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
    ])

    # Dataset and loader.
    # NOTE(review): the loop below unpacks each batch as a single tensor,
    # so AnimeData is assumed to yield images only — confirm.
    anime = AnimeData(args.tags, args.imgs, transform=transform)
    dataloder = DataLoader(anime, batch_size=args.bs, shuffle=True)

    # Models: G maps an ``args.noise``-dim latent vector to an image.
    gen = Generator(args.noise, momentum=args.momentum)
    dis = Discriminator(momentum=args.momentum)

    # Binary cross-entropy for both D's real/fake targets and G's loss.
    criterion = nn.BCELoss()

    if torch.cuda.is_available():
        gen = gen.cuda()
        dis = dis.cuda()
        criterion = criterion.cuda()

    # Separate Adam optimizers for G and D.
    optimizer_gen = optim.Adam(gen.parameters(),
                               lr=args.lr_g,
                               betas=args.betas)
    optimizer_dis = optim.Adam(dis.parameters(),
                               lr=args.lr_d,
                               betas=args.betas)

    # Per-step histories, used for diagnostic plots when a gradient
    # check fails and for the progress printout.
    loss_history_d = []
    loss_history_g = []
    out_history_true_d = []
    out_history_fake_d = []
    out_history_fake_g = []

    for epoch in range(args.epochs):
        print('----------------start epoch %d ---------------' % epoch)
        step = 0
        for data in dataloder:
            step += 1
            start = time.time()
            # NOTE(review): Variable() is a no-op since PyTorch 0.4.
            img = Variable(data)
            noise = Variable(torch.randn(img.shape[0], args.noise))
            labels_true = Variable(torch.ones(img.shape[0], 1))
            labels_fake = Variable(torch.zeros(img.shape[0], 1))
            # Label smoothing: soften both targets by up to 0.1.
            if args.label_smoothing:
                labels_true = labels_true - torch.rand(img.shape[0], 1) * 0.1
                labels_fake = labels_fake + torch.rand(img.shape[0], 1) * 0.1

            # Move the batch and targets to GPU when available.
            if torch.cuda.is_available():
                img = img.cuda()
                noise = noise.cuda()
                labels_true = labels_true.cuda()
                labels_fake = labels_fake.cuda()

            # ---- train D: real batch toward 1, generated batch toward 0 ----
            out_true_d = dis(img)
            out_fake_d = dis(gen(noise))
            out_history_true_d.append(torch.mean(out_true_d).item())
            out_history_fake_d.append(torch.mean(out_fake_d).item())
            loss_true_d = criterion(out_true_d, labels_true)
            loss_fake_d = criterion(out_fake_d, labels_fake)
            loss_d = loss_true_d + loss_fake_d
            optimizer_dis.zero_grad()
            loss_d.backward()

            # Optional gradient sanity check: on failure, plot the history,
            # dump "bad" checkpoints of both nets, and abort training.
            if args.check:
                print('>>>>>>>>>>check_d_grad<<<<<<<<<<')
                try:
                    check_grad(dis, 'conv2.weight')
                except ValueError as e:
                    print(e)
                    show(loss_history_d, loss_history_g, out_history_true_d,
                         out_history_fake_d, out_history_fake_g)
                    torch.save(dis.state_dict(),
                               os.path.join(os.getcwd(), args.d, 'bad.pth'))
                    torch.save(gen.state_dict(),
                               os.path.join(os.getcwd(), args.g, 'bad.pth'))
                    return
            loss_history_d.append(loss_d.item())
            optimizer_dis.step()

            # ---- train G: fresh noise, push D's score toward "real" ----
            noise = Variable(torch.randn(img.shape[0], args.noise))
            if torch.cuda.is_available():
                noise = noise.cuda()
            out_fake_g = dis(gen(noise))
            # Reuse the (possibly smoothed) fake labels inverted: ~1 targets.
            labels_fake = 1. - labels_fake
            out_history_fake_g.append(torch.mean(out_fake_g).item())
            loss_g = criterion(out_fake_g, labels_fake)
            optimizer_gen.zero_grad()
            loss_g.backward()

            # Same sanity check on the generator's gradients.
            if args.check:
                print('>>>>>>>>>>check_g_grad<<<<<<<<<<')
                try:
                    check_grad(gen, 'convTrans.weight')
                except ValueError as e:
                    print(e)
                    show(loss_history_d, loss_history_g, out_history_true_d,
                         out_history_fake_d, out_history_fake_g)
                    torch.save(dis.state_dict(),
                               os.path.join(os.getcwd(), args.d, 'bad.pth'))
                    torch.save(gen.state_dict(),
                               os.path.join(os.getcwd(), args.g, 'bad.pth'))
                    return
            loss_history_g.append(loss_g.item())
            optimizer_gen.step()
            end = time.time()
            print(
                'epoch: %d  step: %d  d_true: %.2f  d_fake: %.2f  g_fake: %.2f time: %.2f'
                % (epoch, step, out_history_true_d[-1], out_history_fake_d[-1],
                   out_history_fake_g[-1], end - start))

        # Checkpoint both nets at the end of every epoch.
        torch.save(dis.state_dict(),
                   os.path.join(os.getcwd(), args.d, '{}.pth'.format(epoch)))
        torch.save(gen.state_dict(),
                   os.path.join(os.getcwd(), args.g, '{}.pth'.format(epoch)))
Ejemplo n.º 6
0
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)
    print("Discriminator loaded successfully!")

    g_model_path = 'checkpoints/zhenwarm/generator.pt'
    assert os.path.exists(g_model_path)
    # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    model_dict = generator.state_dict()
    model = torch.load(g_model_path)
    pretrained_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    print("pre-trained Generator loaded successfully!")
    #
    # Load discriminator model
    d_model_path = 'checkpoints/zhenwarm/discri.pt'
    assert os.path.exists(d_model_path)
    # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    d_model_dict = discriminator.state_dict()
    d_model = torch.load(d_model_path)
    d_pretrained_dict = d_model.state_dict()
    # 1. filter out unnecessary keys
    d_pretrained_dict = {
        k: v
        for k, v in d_pretrained_dict.items() if k in d_model_dict
    }
    # 2. overwrite entries in the existing state dict
    d_model_dict.update(d_pretrained_dict)
    # 3. load the new state dict
    discriminator.load_state_dict(d_model_dict)
    print("pre-trained Discriminator loaded successfully!")

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/myzhencli5'):
        os.makedirs('checkpoints/myzhencli5')
    checkpoints_path = 'checkpoints/myzhencli5/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer
    g_optimizer = eval("torch.optim." + args.g_optimizer)(filter(
        lambda x: x.requires_grad, generator.parameters()),
                                                          args.g_learning_rate)

    d_optimizer = eval("torch.optim." + args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(trainloader):

            # set training mode
            generator.train()
            discriminator.train()
            update_learning_rate(num_update, 8e4, args.g_learning_rate,
                                 args.lr_shrink, g_optimizer)

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator

            # use policy gradient training when random.random() > 50%
            if random.random() >= 0.5:

                print("Policy Gradient Training")

                sys_out_batch = generator(sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 * 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64*50 = 3200
                prediction = torch.reshape(
                    prediction,
                    sample['net_input']['src_tokens'].shape)  # 64 X 50

                with torch.no_grad():
                    reward = discriminator(sample['net_input']['src_tokens'],
                                           prediction)  # 64 X 1

                train_trg_batch = sample['target']  # 64 x 50

                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']  # 64
                logging_loss = pg_loss / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss.item(),
                                                      sample_size)
                logging.debug(
                    f"G policy gradient loss at batch {i}: {pg_loss.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            else:
                # MLE training
                print("MLE Training")

                sys_out_batch = generator(sample)

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                train_trg_batch = sample['target'].view(-1)  # 64*50 = 3200

                loss = g_criterion(out_batch, train_trg_batch)

                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            num_update += 1

            # part II: train the discriminator
            if num_update % 5 == 0:
                bsz = sample['target'].size(0)  # batch_size = 64

                src_sentence = sample['net_input'][
                    'src_tokens']  # 64 x max-len i.e 64 X 50

                # now train with machine translation output i.e generator output
                true_sentence = sample['target'].view(-1)  # 64*50 = 3200

                true_labels = Variable(
                    torch.ones(
                        sample['target'].size(0)).float())  # 64 length vector

                with torch.no_grad():
                    sys_out_batch = generator(sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64 * 50 = 6632

                fake_labels = Variable(
                    torch.zeros(
                        sample['target'].size(0)).float())  # 64 length vector

                fake_sentence = torch.reshape(prediction,
                                              src_sentence.shape)  # 64 X 50
                true_sentence = torch.reshape(true_sentence,
                                              src_sentence.shape)
                if use_cuda:
                    fake_labels = fake_labels.cuda()
                    true_labels = true_labels.cuda()

                # fake_disc_out = discriminator(src_sentence, fake_sentence)  # 64 X 1
                # true_disc_out = discriminator(src_sentence, true_sentence)
                #
                # fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels)
                # true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels)
                #
                # fake_acc = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
                # true_acc = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels)
                # acc = (fake_acc + true_acc) / 2
                #
                # d_loss = fake_d_loss + true_d_loss
                if random.random() > 0.5:
                    fake_disc_out = discriminator(src_sentence, fake_sentence)
                    fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                              fake_labels)
                    fake_acc = torch.sum(
                        torch.round(fake_disc_out).squeeze(1) ==
                        fake_labels).float() / len(fake_labels)
                    d_loss = fake_d_loss
                    acc = fake_acc
                else:
                    true_disc_out = discriminator(src_sentence, true_sentence)
                    true_d_loss = d_criterion(true_disc_out.squeeze(1),
                                              true_labels)
                    true_acc = torch.sum(
                        torch.round(true_disc_out).squeeze(1) ==
                        true_labels).float() / len(true_labels)
                    d_loss = true_d_loss
                    acc = true_acc

                d_logging_meters['train_acc'].update(acc)
                d_logging_meters['train_loss'].update(d_loss)
                logging.debug(
                    f"D training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}"
                )
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()

            if num_update % 10000 == 0:

                # validation
                # set validation mode
                generator.eval()
                discriminator.eval()
                # Initialize dataloader
                max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
                valloader = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.joint_batch_size,
                    max_positions=max_positions_valid,
                    skip_invalid_size_inputs_valid_test=True,
                    descending=
                    True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )

                # reset meters
                for key, val in g_logging_meters.items():
                    if val is not None:
                        val.reset()
                for key, val in d_logging_meters.items():
                    if val is not None:
                        val.reset()

                for i, sample in enumerate(valloader):

                    with torch.no_grad():
                        if use_cuda:
                            # wrap input tensors in cuda tensors
                            sample = utils.make_variable(sample, cuda=cuda)

                        # generator validation
                        sys_out_batch = generator(sample)
                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))  # (64 X 50) X 6632
                        dev_trg_batch = sample['target'].view(
                            -1)  # 64*50 = 3200

                        loss = g_criterion(out_batch, dev_trg_batch)
                        sample_size = sample['target'].size(
                            0) if args.sentence_avg else sample['ntokens']
                        loss = loss / sample_size / math.log(2)
                        g_logging_meters['valid_loss'].update(
                            loss, sample_size)
                        logging.debug(
                            f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}"
                        )

                        # discriminator validation
                        bsz = sample['target'].size(0)
                        src_sentence = sample['net_input']['src_tokens']
                        # train with half human-translation and half machine translation

                        true_sentence = sample['target']
                        true_labels = Variable(
                            torch.ones(sample['target'].size(0)).float())

                        with torch.no_grad():
                            sys_out_batch = generator(sample)

                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                        _, prediction = out_batch.topk(1)
                        prediction = prediction.squeeze(1)  # 64 * 50 = 6632

                        fake_labels = Variable(
                            torch.zeros(sample['target'].size(0)).float())

                        fake_sentence = torch.reshape(
                            prediction, src_sentence.shape)  # 64 X 50
                        true_sentence = torch.reshape(true_sentence,
                                                      src_sentence.shape)
                        if use_cuda:
                            fake_labels = fake_labels.cuda()
                            true_labels = true_labels.cuda()

                        fake_disc_out = discriminator(src_sentence,
                                                      fake_sentence)  # 64 X 1
                        true_disc_out = discriminator(src_sentence,
                                                      true_sentence)

                        fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                                  fake_labels)
                        true_d_loss = d_criterion(true_disc_out.squeeze(1),
                                                  true_labels)
                        d_loss = fake_d_loss + true_d_loss
                        fake_acc = torch.sum(
                            torch.round(fake_disc_out).squeeze(1) ==
                            fake_labels).float() / len(fake_labels)
                        true_acc = torch.sum(
                            torch.round(true_disc_out).squeeze(1) ==
                            true_labels).float() / len(true_labels)
                        acc = (fake_acc + true_acc) / 2
                        d_logging_meters['valid_acc'].update(acc)
                        d_logging_meters['valid_loss'].update(d_loss)
                        logging.debug(
                            f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}"
                        )

                # torch.save(discriminator,
                #            open(checkpoints_path + f"numupdate_{num_update/10000}k.discri_{d_logging_meters['valid_loss'].avg:.3f}.pt",'wb'), pickle_module=dill)

                # if d_logging_meters['valid_loss'].avg < best_dev_loss:
                #     best_dev_loss = d_logging_meters['valid_loss'].avg
                #     torch.save(discriminator, open(checkpoints_path + "best_dmodel.pt", 'wb'), pickle_module=dill)

                torch.save(
                    generator,
                    open(
                        checkpoints_path +
                        f"numupdate_{num_update/10000}k.joint_{g_logging_meters['valid_loss'].avg:.3f}.pt",
                        'wb'),
                    pickle_module=dill)
Ejemplo n.º 7
0
class GAN_CLS(object):
    """GAN-CLS: a text-conditional GAN trainer (generator + discriminator).

    Builds the two networks with their optimizers and a BCE criterion,
    optionally restores a checkpoint, then runs the training loop with
    periodic logging and model saving.
    """

    def __init__(self, args, data_loader, SUPERVISED=True):
        """
        args : parsed command-line arguments carrying all hyperparameters
        data_loader : DataLoader yielding dict batches with keys
                      'true_imgs', 'true_embds', 'false_imgs'
        SUPERVISED : stored flag; not consumed elsewhere in this class
        """
        self.data_loader = data_loader
        self.num_epochs = args.num_epochs
        self.batch_size = args.batch_size

        self.log_step = args.log_step
        self.sample_step = args.sample_step

        self.log_dir = args.log_dir
        self.checkpoint_dir = args.checkpoint_dir
        self.sample_dir = args.sample_dir
        self.final_model = args.final_model
        self.model_save_step = args.model_save_step

        #self.dataset = args.dataset
        #self.model_name = args.model_name

        self.img_size = args.img_size
        self.z_dim = args.z_dim
        self.text_embed_dim = args.text_embed_dim
        self.text_reduced_dim = args.text_reduced_dim
        self.learning_rate = args.learning_rate
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.l1_coeff = args.l1_coeff
        self.resume_epoch = args.resume_epoch
        self.resume_idx = args.resume_idx
        self.SUPERVISED = SUPERVISED

        # Logger setting: one log file per day inside log_dir.
        log_name = datetime.datetime.now().strftime('%Y-%m-%d') + '.log'
        # BUG FIX: was logging.getLogger('__name__') -- the literal string --
        # which created a logger literally named "__name__" instead of using
        # this module's name.
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        self.formatter = logging.Formatter(
            '%(asctime)s:%(levelname)s:%(message)s')
        self.file_handler = logging.FileHandler(
            os.path.join(self.log_dir, log_name))
        self.file_handler.setFormatter(self.formatter)
        self.logger.addHandler(self.file_handler)

        self.build_model()

    def smooth_label(self, tensor, offset):
        """One-sided label smoothing: shift labels by `offset` (e.g. -0.1)."""
        return tensor + offset

    def dump_imgs(self, images_Array, name):
        """Pickle `images_Array` to '<name>.pickle'.

        BUG FIX: `self` was missing from the signature, so calling
        `self.dump_imgs(arr, path)` bound the instance to `images_Array`.
        NOTE(review): relies on a module-level `dump` (presumably
        pickle.dump) -- confirm the import at the top of the file.
        """
        with open('{}.pickle'.format(name), 'wb') as file:
            dump(images_Array, file)

    def build_model(self):
        """Define the following instances:

        -----  Generator
        -----  Discriminator
        -----  Optimizer for Generator
        -----  Optimizer for Discriminator
        -----  Loss function (BCE)
        """

        # ---------------------------------------------------------------------#
        #                       1. Network Initialization                      #
        # ---------------------------------------------------------------------#
        self.gen = Generator(batch_size=self.batch_size,
                             img_size=self.img_size,
                             z_dim=self.z_dim,
                             text_embed_dim=self.text_embed_dim,
                             text_reduced_dim=self.text_reduced_dim)

        self.disc = Discriminator(batch_size=self.batch_size,
                                  img_size=self.img_size,
                                  text_embed_dim=self.text_embed_dim,
                                  text_reduced_dim=self.text_reduced_dim)

        self.gen_optim = optim.Adam(self.gen.parameters(),
                                    lr=self.learning_rate,
                                    betas=(self.beta1, self.beta2))

        self.disc_optim = optim.Adam(self.disc.parameters(),
                                     lr=self.learning_rate,
                                     betas=(self.beta1, self.beta2))

        # Joint optimizer over both networks; stepped once more per batch in
        # train_model() in addition to the two individual optimizers.
        self.cls_gan_optim = optim.Adam(itertools.chain(
            self.gen.parameters(), self.disc.parameters()),
                                        lr=self.learning_rate,
                                        betas=(self.beta1, self.beta2))

        print('-------------  Generator Model Info  ---------------')
        self.print_network(self.gen, 'G')
        print('------------------------------------------------')

        print('-------------  Discriminator Model Info  ---------------')
        self.print_network(self.disc, 'D')
        print('------------------------------------------------')

        self.criterion = nn.BCELoss().cuda()
        # self.CE_loss = nn.CrossEntropyLoss().cuda()
        # self.MSE_loss = nn.MSELoss().cuda()
        self.gen.train()
        self.disc.train()

    def print_network(self, model, name):
        """Print the model, its short name and its total parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()

        print(model)
        print(name)
        print("Total number of parameters: {}".format(num_params))

    def load_checkpoints(self, resume_epoch, idx):
        """Restore the trained generator and discriminator."""
        print('Loading the trained models from epoch {} and iteration {}...'.
              format(resume_epoch, idx))
        G_path = os.path.join(self.checkpoint_dir,
                              '{}-{}-G.ckpt'.format(resume_epoch, idx))
        D_path = os.path.join(self.checkpoint_dir,
                              '{}-{}-D.ckpt'.format(resume_epoch, idx))
        # map_location keeps loading CPU-safe regardless of saving device.
        self.gen.load_state_dict(
            torch.load(G_path, map_location=lambda storage, loc: storage))
        self.disc.load_state_dict(
            torch.load(D_path, map_location=lambda storage, loc: storage))

    def train_model(self):
        """Run the main training loop over self.num_epochs epochs."""

        data_loader = self.data_loader

        start_epoch = 0
        if self.resume_epoch >= 0:
            start_epoch = self.resume_epoch
            self.load_checkpoints(self.resume_epoch, self.resume_idx)

        print('---------------  Model Training Started  ---------------')
        start_time = time.time()

        for epoch in range(start_epoch, self.num_epochs):
            print("Epoch: {}".format(epoch + 1))
            for idx, batch in enumerate(data_loader):
                print("Index: {}".format(idx + 1), end="\t")
                true_imgs = batch['true_imgs']
                true_embed = batch['true_embds']
                false_imgs = batch['false_imgs']

                real_labels = torch.ones(true_imgs.size(0))
                fake_labels = torch.zeros(true_imgs.size(0))

                # One-sided label smoothing on the real labels only.
                smooth_real_labels = torch.FloatTensor(
                    self.smooth_label(real_labels.numpy(), -0.1))

                true_imgs = Variable(true_imgs.float()).cuda()
                true_embed = Variable(true_embed.float()).cuda()
                false_imgs = Variable(false_imgs.float()).cuda()

                real_labels = Variable(real_labels).cuda()
                smooth_real_labels = Variable(smooth_real_labels).cuda()
                fake_labels = Variable(fake_labels).cuda()

                # ---------------------------------------------------------------#
                #                     2. Training the generator                  #
                # ---------------------------------------------------------------#
                self.gen.zero_grad()
                z = Variable(torch.randn(true_imgs.size(0), self.z_dim)).cuda()
                fake_imgs = self.gen.forward(true_embed, z)
                fake_out, fake_logit = self.disc.forward(fake_imgs, true_embed)
                # NOTE(review): re-wrapping .data in a fresh Variable detaches
                # fake_out from the generator's graph, so g_sf.backward() below
                # cannot reach self.gen's parameters -- confirm this is intended.
                fake_out = Variable(fake_out.data, requires_grad=True).cuda()

                true_out, true_logit = self.disc.forward(true_imgs, true_embed)
                true_out = Variable(true_out.data, requires_grad=True).cuda()

                g_sf = self.criterion(fake_out, real_labels)
                #g_img = self.l1_coeff * nn.L1Loss()(fake_imgs, true_imgs)
                gen_loss = g_sf

                gen_loss.backward()
                self.gen_optim.step()

                # ---------------------------------------------------------------#
                #                   3. Training the discriminator                #
                # ---------------------------------------------------------------#
                self.disc.zero_grad()
                false_out, false_logit = self.disc.forward(
                    false_imgs, true_embed)
                false_out = Variable(false_out.data, requires_grad=True)

                sr = self.criterion(true_out, smooth_real_labels)
                sw = self.criterion(true_out, fake_labels)
                sf = self.criterion(false_out, smooth_real_labels)

                # NOTE(review): sr/sw/sf are already BCE (negative log)
                # values; applying torch.log to them again is unusual for
                # GAN-CLS, and log(1 - loss) produces NaN whenever a loss
                # exceeds 1 -- verify against the intended objective.
                disc_loss = torch.log(sr) + (torch.log(1 - sw) +
                                             torch.log(1 - sf)) / 2

                disc_loss.backward()
                self.disc_optim.step()

                # NOTE(review): this joint step reuses whatever gradients are
                # currently stored on the parameters (no fresh backward pass)
                # -- confirm it is intentional.
                self.cls_gan_optim.step()

                # Logging
                loss = {}
                loss['G_loss'] = gen_loss.item()
                loss['D_loss'] = disc_loss.item()

                # ---------------------------------------------------------------#
                #                   4. Logging INFO into log_dir                 #
                # ---------------------------------------------------------------#
                if (idx + 1) % self.log_step == 0:
                    end_time = time.time() - start_time
                    end_time = datetime.timedelta(seconds=end_time)
                    log = "Elapsed [{}], Epoch [{}/{}], Idx [{}]".format(
                        end_time, epoch + 1, self.num_epochs, idx)
                    # BUG FIX: the loss loop previously ran on every batch
                    # (even with an empty `log`) and called logger.info/print
                    # once per loss entry, emitting duplicated partially-built
                    # lines. Build the full line first, then emit it once.
                    for net, loss_value in loss.items():
                        log += " {}: {:.4f}".format(net, loss_value)
                    self.logger.info(log)
                    print(log)
                """
				# ---------------------------------------------------------------#
				# 					5. Saving generated images					 #
				# ---------------------------------------------------------------#
				if (idx + 1) % self.sample_step == 0:
					concat_imgs = torch.cat((true_imgs, fake_imgs), 0)  # ??????????
					concat_imgs = (concat_imgs + 1) / 2
					# out.clamp_(0, 1)
					 
					save_path = os.path.join(self.sample_dir, '{}-{}-images.jpg'.format(epoch, idx + 1))
					# concat_imgs.cpu().detach().numpy()
					self.dump_imgs(concat_imgs.cpu().numpy(), save_path)
					
					#save_image(concat_imgs.data.cpu(), self.sample_dir, nrow=1, padding=0)
					print ('Saved real and fake images into {}...'.format(self.sample_dir))
				"""

                # ---------------------------------------------------------------#
                #               6. Saving the checkpoints & final model          #
                # ---------------------------------------------------------------#
                if (idx + 1) % self.model_save_step == 0:
                    G_path = os.path.join(
                        self.checkpoint_dir,
                        '{}-{}-G.ckpt'.format(epoch, idx + 1))
                    D_path = os.path.join(
                        self.checkpoint_dir,
                        '{}-{}-D.ckpt'.format(epoch, idx + 1))
                    torch.save(self.gen.state_dict(), G_path)
                    torch.save(self.disc.state_dict(), D_path)
                    print('Saved model checkpoints into {}...\n'.format(
                        self.checkpoint_dir))

        print('---------------  Model Training Completed  ---------------')
        # Saving final model into final_model directory
        G_path = os.path.join(self.final_model, '{}-G.pth'.format('final'))
        D_path = os.path.join(self.final_model, '{}-D.pth'.format('final'))
        torch.save(self.gen.state_dict(), G_path)
        torch.save(self.disc.state_dict(), D_path)
        print('Saved final model into {}...'.format(self.final_model))
Ejemplo n.º 8
0
class BigGAN():
    """Big GAN trainer: hinge-loss conditional GAN with instance noise.

    Wraps a Generator/Discriminator pair, their Adam optimizers, optional
    checkpoint resume, the alternating D/G training loop, per-epoch image
    sampling and loss plotting.
    """
    def __init__(self, device, dataloader, num_classes, configs):
        """
        device : torch device the models and tensors are moved to
        dataloader : iterable of (images, labels) batches
        num_classes : number of class labels for conditioning
        configs : namespace carrying all hyperparameters listed below
        """
        self.device = device
        self.dataloader = dataloader
        self.num_classes = num_classes

        # model settings & hyperparams
        # self.total_steps = configs.total_steps
        self.epochs = configs.epochs
        self.d_iters = configs.d_iters
        self.g_iters = configs.g_iters
        self.batch_size = configs.batch_size
        self.imsize = configs.imsize
        self.nz = configs.nz
        self.ngf = configs.ngf
        self.ndf = configs.ndf
        self.g_lr = configs.g_lr
        self.d_lr = configs.d_lr
        self.beta1 = configs.beta1
        self.beta2 = configs.beta2

        # instance noise (linearly annealed during training)
        self.inst_noise_sigma = configs.inst_noise_sigma
        self.inst_noise_sigma_iters = configs.inst_noise_sigma_iters

        # model logging and saving
        self.log_step = configs.log_step
        self.save_epoch = configs.save_epoch
        self.model_path = configs.model_path
        self.sample_path = configs.sample_path

        # pretrained (epoch number to resume from, falsy to start fresh)
        self.pretrained_model = configs.pretrained_model

        # building
        self.build_model()

        # archive of all losses (one entry per epoch)
        self.ave_d_losses = []
        self.ave_d_losses_real = []
        self.ave_d_losses_fake = []
        self.ave_g_losses = []

        if self.pretrained_model:
            self.load_pretrained()

    def build_model(self):
        """Initiate Generator and Discriminator with their optimizers."""
        self.G = Generator(self.nz, self.ngf, self.num_classes).to(self.device)
        self.D = Discriminator(self.ndf, self.num_classes).to(self.device)

        # only optimize trainable parameters
        self.g_optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        print("Generator Parameters: ", parameters(self.G))
        print(self.G)
        print("Discriminator Parameters: ", parameters(self.D))
        print(self.D)
        print("Number of classes: ", self.num_classes)

    def load_pretrained(self):
        """Load model/optimizer state and loss history from a checkpoint."""
        checkpoint = torch.load(
            os.path.join(self.model_path,
                         "{}_biggan.pth".format(self.pretrained_model)))

        # load models
        self.G.load_state_dict(checkpoint["g_state_dict"])
        self.D.load_state_dict(checkpoint["d_state_dict"])

        # load optimizers
        self.g_optimizer.load_state_dict(checkpoint["g_optimizer"])
        self.d_optimizer.load_state_dict(checkpoint["d_optimizer"])

        # load losses
        self.ave_d_losses = checkpoint["ave_d_losses"]
        self.ave_d_losses_real = checkpoint["ave_d_losses_real"]
        self.ave_d_losses_fake = checkpoint["ave_d_losses_fake"]
        self.ave_g_losses = checkpoint["ave_g_losses"]

        print("Loading pretrained models (epoch: {})..!".format(
            self.pretrained_model))

    def reset_grad(self):
        """Reset gradients of both optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def train(self):
        """Train D and G alternately for self.epochs epochs."""
        step_per_epoch = len(self.dataloader)
        epochs = self.epochs
        total_steps = epochs * step_per_epoch

        # fixed z and labels for sampling generator images
        fixed_z = tensor2var(torch.randn(self.batch_size, self.nz),
                             device=self.device)
        fixed_labels = tensor2var(torch.from_numpy(
            np.tile(np.arange(self.num_classes), self.batch_size)).long(),
                                  device=self.device)

        print("Initiating Training")
        print("Epochs: {}, Total Steps: {}, Steps/Epoch: {}".format(
            epochs, total_steps, step_per_epoch))

        if self.pretrained_model:
            start_epoch = self.pretrained_model
        else:
            start_epoch = 0

        self.D.train()
        self.G.train()

        # Instance noise - make random noise mean (0) and std for injecting.
        # BUG FIX: fill value must be a float (0.0) -- an int fill makes
        # torch.full return an integer tensor on current PyTorch, which
        # torch.normal rejects.
        # NOTE(review): these buffers are sized for self.batch_size; a smaller
        # final batch would mismatch real_images below -- confirm the loader
        # drops the last partial batch.
        inst_noise_mean = torch.full(
            (self.batch_size, 3, self.imsize, self.imsize),
            0.0).to(self.device)
        inst_noise_std = torch.full(
            (self.batch_size, 3, self.imsize, self.imsize),
            self.inst_noise_sigma).to(self.device)

        # total time
        start_time = time.time()
        for epoch in range(start_epoch, epochs):
            # per-epoch losses, recorded only on log steps
            d_losses = []
            d_losses_real = []
            d_losses_fake = []
            g_losses = []

            data_iter = iter(self.dataloader)
            for step in range(step_per_epoch):
                # Instance noise std is linearly annealed from
                # self.inst_noise_sigma to 0 thru self.inst_noise_sigma_iters.
                # NOTE(review): `step` resets each epoch, so the anneal
                # restarts every epoch -- confirm this is intended.
                inst_noise_sigma_curr = 0 if step > self.inst_noise_sigma_iters else (
                    1 -
                    step / self.inst_noise_sigma_iters) * self.inst_noise_sigma
                inst_noise_std.fill_(inst_noise_sigma_curr)

                # get real images
                real_images, real_labels = next(data_iter)
                real_images = real_images.to(self.device)
                real_labels = real_labels.to(self.device)

                # ================== TRAIN DISCRIMINATOR ================== #

                for _ in range(self.d_iters):
                    self.reset_grad()

                    # TRAIN REAL

                    # creating instance noise
                    inst_noise = torch.normal(mean=inst_noise_mean,
                                              std=inst_noise_std).to(
                                                  self.device)
                    # adding noise to real images
                    d_real = self.D(real_images + inst_noise, real_labels)
                    d_loss_real = loss_hinge_dis_real(d_real)
                    d_loss_real.backward()

                    # free loss tensors early except on log steps, where they
                    # are still needed for reporting below
                    if (step + 1) % self.log_step != 0:
                        del d_real, d_loss_real

                    # TRAIN FAKE

                    # create fake images using latent vector
                    z = tensor2var(torch.randn(real_images.size(0), self.nz),
                                   device=self.device)
                    fake_images = self.G(z, real_labels)

                    # creating instance noise
                    inst_noise = torch.normal(mean=inst_noise_mean,
                                              std=inst_noise_std).to(
                                                  self.device)
                    # adding noise to fake images;
                    # detach fake_images tensor from graph so G gets no grads
                    d_fake = self.D(fake_images.detach() + inst_noise,
                                    real_labels)
                    d_loss_fake = loss_hinge_dis_fake(d_fake)
                    d_loss_fake.backward()

                    # delete loss, output
                    del fake_images
                    if (step + 1) % self.log_step != 0:
                        del d_fake, d_loss_fake

                # optimize D (gradients accumulated over d_iters)
                self.d_optimizer.step()

                # ================== TRAIN GENERATOR ================== #

                for _ in range(self.g_iters):
                    self.reset_grad()

                    # create new latent vector
                    z = tensor2var(torch.randn(real_images.size(0), self.nz),
                                   device=self.device)

                    # generate fake images
                    inst_noise = torch.normal(mean=inst_noise_mean,
                                              std=inst_noise_std).to(
                                                  self.device)
                    fake_images = self.G(z, real_labels)
                    g_fake = self.D(fake_images + inst_noise, real_labels)

                    # compute hinge loss for G
                    g_loss = loss_hinge_gen(g_fake)
                    g_loss.backward()

                    del fake_images
                    if (step + 1) % self.log_step != 0:
                        del g_fake, g_loss

                # optimize G
                self.g_optimizer.step()

                # logging step progression
                if (step + 1) % self.log_step == 0:
                    d_loss = d_loss_real + d_loss_fake

                    # logging losses
                    d_losses.append(d_loss.item())
                    d_losses_real.append(d_loss_real.item())
                    d_losses_fake.append(d_loss_fake.item())
                    g_losses.append(g_loss.item())

                    # print out
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    print(
                        "Elapsed [{}], Epoch: [{}/{}], Step [{}/{}], g_loss: {:.4f}, d_loss: {:.4f},"
                        " d_loss_real: {:.4f}, d_loss_fake: {:.4f}".format(
                            elapsed, (epoch + 1), epochs, (step + 1),
                            step_per_epoch, g_loss, d_loss, d_loss_real,
                            d_loss_fake))

                    del d_real, d_loss_real, d_fake, d_loss_fake, g_fake, g_loss

            # logging average losses over epoch.
            # BUG FIX: the lists are filled only on log steps, so an epoch
            # shorter than log_step left them empty and mean() raised; record
            # NaN instead so indices still line up with epochs.
            self.ave_d_losses.append(
                mean(d_losses) if d_losses else float("nan"))
            self.ave_d_losses_real.append(
                mean(d_losses_real) if d_losses_real else float("nan"))
            self.ave_d_losses_fake.append(
                mean(d_losses_fake) if d_losses_fake else float("nan"))
            self.ave_g_losses.append(
                mean(g_losses) if g_losses else float("nan"))

            # epoch update.
            # BUG FIX: recompute elapsed here -- it was previously only
            # assigned inside the log-step branch and raised NameError when
            # no step in the epoch hit log_step.
            elapsed = str(datetime.timedelta(seconds=time.time() - start_time))
            print(
                "Elapsed [{}], Epoch: [{}/{}], ave_g_loss: {:.4f}, ave_d_loss: {:.4f},"
                " ave_d_loss_real: {:.4f}, ave_d_loss_fake: {:.4f},".format(
                    elapsed, epoch + 1, epochs, self.ave_g_losses[epoch],
                    self.ave_d_losses[epoch], self.ave_d_losses_real[epoch],
                    self.ave_d_losses_fake[epoch]))

            # sample images every epoch using the fixed z / labels
            fake_images = self.G(fixed_z, fixed_labels)
            fake_images = denorm(fake_images.data)
            save_image(
                fake_images,
                os.path.join(self.sample_path,
                             "Epoch {}.png".format(epoch + 1)))

            # save model
            if (epoch + 1) % self.save_epoch == 0:
                torch.save(
                    {
                        "g_state_dict": self.G.state_dict(),
                        "d_state_dict": self.D.state_dict(),
                        "g_optimizer": self.g_optimizer.state_dict(),
                        "d_optimizer": self.d_optimizer.state_dict(),
                        "ave_d_losses": self.ave_d_losses,
                        "ave_d_losses_real": self.ave_d_losses_real,
                        "ave_d_losses_fake": self.ave_d_losses_fake,
                        "ave_g_losses": self.ave_g_losses
                    },
                    os.path.join(self.model_path,
                                 "{}_biggan.pth".format(epoch + 1)))

                print("Saving models (epoch {})..!".format(epoch + 1))

    def plot(self):
        """Plot the per-epoch average loss curves."""
        plt.plot(self.ave_d_losses)
        plt.plot(self.ave_d_losses_real)
        plt.plot(self.ave_d_losses_fake)
        plt.plot(self.ave_g_losses)
        plt.legend(["d loss", "d real", "d fake", "g loss"], loc="upper left")
        plt.show()
Ejemplo n.º 9
0
def main(args):
    """Train a volume-interpolation generator, optionally with a GAN loss.

    Builds train/test datasets and loaders, initializes the generator (and a
    discriminator when ``args.gan_loss != "none"``), optionally resumes from a
    checkpoint, then runs the epoch loop with periodic logging, evaluation and
    checkpointing.
    """
    # log hyperparameter
    print(args)

    # select device — note "cuda: 0" (with an embedded space) is not a valid
    # torch device string; it must be "cuda:0"
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if args.cuda else "cpu")

    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # data loader
    transform = transforms.Compose([
        utils.Normalize(),
        utils.ToTensor()
    ])
    train_dataset = TVDataset(
        root=args.root,
        sub_size=args.block_size,
        volume_list=args.volume_train_list,
        max_k=args.training_step,
        train=True,
        transform=transform
    )
    test_dataset = TVDataset(
        root=args.root,
        sub_size=args.block_size,
        volume_list=args.volume_test_list,
        max_k=args.training_step,
        train=False,
        transform=transform
    )

    kwargs = {"num_workers": 4, "pin_memory": True} if args.cuda else {}
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, **kwargs)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             shuffle=False, **kwargs)

    # model — Kaiming init matched to each network's activation function
    def generator_weights_init(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def discriminator_weights_init(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    g_model = Generator(args.upsample_mode, args.forward, args.backward, args.gen_sn, args.residual)
    g_model.apply(generator_weights_init)
    if args.data_parallel and torch.cuda.device_count() > 1:
        g_model = nn.DataParallel(g_model)
    g_model.to(device)

    if args.gan_loss != "none":
        d_model = Discriminator(args.dis_sn)
        d_model.apply(discriminator_weights_init)
        if args.data_parallel and torch.cuda.device_count() > 1:
            d_model = nn.DataParallel(d_model)
        d_model.to(device)

    # LSGAN-style adversarial objective: MSE against 0/1 labels
    mse_loss = nn.MSELoss()
    adversarial_loss = nn.MSELoss()
    train_losses, test_losses = [], []
    d_losses, g_losses = [], []

    # optimizer
    g_optimizer = optim.Adam(g_model.parameters(), lr=args.lr,
                             betas=(args.beta1, args.beta2))
    if args.gan_loss != "none":
        d_optimizer = optim.Adam(d_model.parameters(), lr=args.d_lr,
                                 betas=(args.beta1, args.beta2))

    Tensor = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor

    # load checkpoint (model weights and loss histories only; optimizer state
    # is deliberately not restored)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint {}".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            g_model.load_state_dict(checkpoint["g_model_state_dict"])
            # g_optimizer.load_state_dict(checkpoint["g_optimizer_state_dict"])
            if args.gan_loss != "none":
                d_model.load_state_dict(checkpoint["d_model_state_dict"])
                # d_optimizer.load_state_dict(checkpoint["d_optimizer_state_dict"])
                d_losses = checkpoint["d_losses"]
                g_losses = checkpoint["g_losses"]
            train_losses = checkpoint["train_losses"]
            test_losses = checkpoint["test_losses"]
            print("=> load checkpoint {} (epoch {})"
                  .format(args.resume, checkpoint["epoch"]))

    # main loop
    for epoch in tqdm(range(args.start_epoch, args.epochs)):
        # training..
        g_model.train()
        if args.gan_loss != "none":
            d_model.train()
        train_loss = 0.
        volume_loss_part = np.zeros(args.training_step)
        for i, sample in enumerate(train_loader):
            # adversarial ground truths: one 1/0 label per intermediate volume
            real_label = Variable(Tensor(sample["v_i"].shape[0], sample["v_i"].shape[1], 1, 1, 1, 1).fill_(1.0), requires_grad=False)
            fake_label = Variable(Tensor(sample["v_i"].shape[0], sample["v_i"].shape[1], 1, 1, 1, 1).fill_(0.0), requires_grad=False)

            v_f = sample["v_f"].to(device)
            v_b = sample["v_b"].to(device)
            v_i = sample["v_i"].to(device)
            g_optimizer.zero_grad()
            fake_volumes = g_model(v_f, v_b, args.training_step, args.wo_ori_volume, args.norm)

            # adversarial loss
            # update discriminator (args.n_d steps per generator batch)
            if args.gan_loss != "none":
                avg_d_loss = 0.
                avg_d_loss_real = 0.
                avg_d_loss_fake = 0.
                for k in range(args.n_d):
                    d_optimizer.zero_grad()
                    decisions = d_model(v_i)
                    d_loss_real = adversarial_loss(decisions, real_label)
                    fake_decisions = d_model(fake_volumes.detach())

                    d_loss_fake = adversarial_loss(fake_decisions, fake_label)
                    d_loss = d_loss_real + d_loss_fake
                    d_loss.backward()
                    # accumulate plain floats via .item(); accumulating the
                    # loss tensors themselves would retain their autograd
                    # graphs across iterations
                    avg_d_loss += d_loss.item() / args.n_d
                    avg_d_loss_real += d_loss_real.item() / args.n_d
                    avg_d_loss_fake += d_loss_fake.item() / args.n_d

                    d_optimizer.step()

            # update generator
            # NOTE(review): for args.n_g > 1 this backpropagates repeatedly
            # through the same fake_volumes graph — confirm retain_graph
            # semantics if n_g is ever set above 1
            if args.gan_loss != "none":
                avg_g_loss = 0.
            avg_loss = 0.
            for k in range(args.n_g):
                loss = 0.
                g_optimizer.zero_grad()

                # adversarial loss
                if args.gan_loss != "none":
                    fake_decisions = d_model(fake_volumes)
                    g_loss = args.gan_loss_weight * adversarial_loss(fake_decisions, real_label)
                    loss += g_loss
                    avg_g_loss += g_loss.item() / args.n_g

                # volume loss (overall plus per-intermediate-step breakdown)
                if args.volume_loss:
                    volume_loss = args.volume_loss_weight * mse_loss(v_i, fake_volumes)
                    for j in range(v_i.shape[1]):
                        volume_loss_part[j] += mse_loss(v_i[:, j, :], fake_volumes[:, j, :]).item() / args.n_g / args.log_every
                    loss += volume_loss

                # feature loss — NOTE(review): requires d_model, so this is
                # only valid together with args.gan_loss != "none"
                if args.feature_loss:
                    feat_real = d_model.extract_features(v_i)
                    feat_fake = d_model.extract_features(fake_volumes)
                    for m in range(len(feat_real)):
                        loss += args.feature_loss_weight / len(feat_real) * mse_loss(feat_real[m], feat_fake[m])

                # track a float, not the loss tensor, to avoid retaining graphs
                avg_loss += loss.item() / args.n_g
                loss.backward()
                g_optimizer.step()

            train_loss += avg_loss

            # log training status
            subEpoch = (i + 1) // args.log_every
            if (i+1) % args.log_every == 0:
                print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch, (i+1) * args.batch_size, len(train_loader.dataset), 100. * (i+1) / len(train_loader),
                    avg_loss
                ))
                print("Volume Loss: ")
                for j in range(volume_loss_part.shape[0]):
                    print("\tintermediate {}: {:.6f}".format(
                        j+1, volume_loss_part[j]
                    ))

                if args.gan_loss != "none":
                    print("DLossReal: {:.6f} DLossFake: {:.6f} DLoss: {:.6f}, GLoss: {:.6f}".format(
                        avg_d_loss_real, avg_d_loss_fake, avg_d_loss, avg_g_loss
                    ))
                    d_losses.append(avg_d_loss)
                    g_losses.append(avg_g_loss)
                train_losses.append(train_loss / args.log_every)
                print("====> SubEpoch: {} Average loss: {:.6f} Time {}".format(
                    subEpoch, train_loss / args.log_every, time.asctime(time.localtime(time.time()))
                ))
                train_loss = 0.
                volume_loss_part = np.zeros(args.training_step)

            # testing...
            if (i + 1) % args.test_every == 0:
                g_model.eval()
                if args.gan_loss != "none":
                    d_model.eval()
                test_loss = 0.
                with torch.no_grad():
                    # do NOT reuse the training loop variable `i` here: the
                    # original shadowed it, corrupting the log/test/save
                    # schedule for the rest of the epoch
                    for test_sample in test_loader:
                        v_f = test_sample["v_f"].to(device)
                        v_b = test_sample["v_b"].to(device)
                        v_i = test_sample["v_i"].to(device)
                        fake_volumes = g_model(v_f, v_b, args.training_step, args.wo_ori_volume, args.norm)
                        test_loss += args.volume_loss_weight * mse_loss(v_i, fake_volumes).item()

                test_losses.append(test_loss * args.batch_size / len(test_loader.dataset))
                print("====> SubEpoch: {} Test set loss {:4f} Time {}".format(
                    subEpoch, test_losses[-1], time.asctime(time.localtime(time.time()))
                ))
                # restore training mode so the remaining batches of this epoch
                # keep training with batch-norm/dropout behavior enabled
                g_model.train()
                if args.gan_loss != "none":
                    d_model.train()

            # saving...
            if (i+1) % args.check_every == 0:
                print("=> saving checkpoint at epoch {}".format(epoch))
                if args.gan_loss != "none":
                    torch.save({"epoch": epoch + 1,
                                "g_model_state_dict": g_model.state_dict(),
                                "g_optimizer_state_dict":  g_optimizer.state_dict(),
                                "d_model_state_dict": d_model.state_dict(),
                                "d_optimizer_state_dict": d_optimizer.state_dict(),
                                "d_losses": d_losses,
                                "g_losses": g_losses,
                                "train_losses": train_losses,
                                "test_losses": test_losses},
                               os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + "_" + "pth.tar")
                               )
                else:
                    torch.save({"epoch": epoch + 1,
                                "g_model_state_dict": g_model.state_dict(),
                                "g_optimizer_state_dict": g_optimizer.state_dict(),
                                "train_losses": train_losses,
                                "test_losses": test_losses},
                               os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + "_" + "pth.tar")
                               )
                torch.save(g_model.state_dict(),
                           os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + ".pth"))

        num_subEpoch = len(train_loader) // args.log_every
        print("====> Epoch: {} Average loss: {:.6f} Time {}".format(
            epoch, np.array(train_losses[-num_subEpoch:]).mean(), time.asctime(time.localtime(time.time()))
        ))
Ejemplo n.º 10
0
        # Persist the per-epoch loss/score histories as .npy arrays so the
        # curves can be re-plotted later without re-running training.
        np.save(os.path.join(save_dir, 'd_losses.npy'), d_losses)
        np.save(os.path.join(save_dir, 'g_losses.npy'), g_losses)
        np.save(os.path.join(save_dir, 'fake_scores.npy'), fake_scores)
        np.save(os.path.join(save_dir, 'real_scores.npy'), real_scores)
    
        # Loss curves for D and G across all epochs so far.
        plt.figure()
        pylab.xlim(0, num_epochs + 1)
        plt.plot(range(1, num_epochs + 1), d_losses, label='d loss')
        plt.plot(range(1, num_epochs + 1), g_losses, label='g loss')    
        plt.legend()
        plt.savefig(os.path.join(save_dir, 'loss.pdf'))
        plt.close()

        # Discriminator outputs on fake vs. real samples, clamped to [0, 1]
        # on the y-axis.
        plt.figure()
        pylab.xlim(0, num_epochs + 1)
        pylab.ylim(0, 1)
        plt.plot(range(1, num_epochs + 1), fake_scores, label='fake score')
        plt.plot(range(1, num_epochs + 1), real_scores, label='real score')    
        plt.legend()
        plt.savefig(os.path.join(save_dir, 'accuracy.pdf'))
        plt.close()

        # Save model at checkpoints (every 50 epochs)
        if (epoch + 1) % 50 == 0:
            torch.save(G.state_dict(), os.path.join(save_dir, 'G--{}.ckpt'.format(epoch+1)))
            torch.save(D.state_dict(), os.path.join(save_dir, 'D--{}.ckpt'.format(epoch+1)))

    # Save the model checkpoints 
    torch.save(G.state_dict(), 'G.ckpt')
    torch.save(D.state_dict(), 'D.ckpt')
    
Ejemplo n.º 11
0
def main():
    """Pretrain a SeqGAN-style generator/discriminator pair on text data,
    then run adversarial training with rollout-based policy-gradient updates.

    NOTE(review): `calc_bleu(...)` followed by `exit()` makes everything
    below unreachable — presumably left in deliberately for a BLEU-only run;
    confirm before removing.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    calc_bleu([1, 10, 12])
    exit()
    # Build up dataset
    s_train, s_test = load_from_big_file('../data/train_data_obama.txt')
    # idx_to_word: List of id to word
    # word_to_idx: Dictionary mapping word to id
    idx_to_word, word_to_idx = fetch_vocab(s_train, s_train, s_test)
    # TODO: 1. Prepare data for attention model
    # input_seq, target_seq = prepare_data(DATA_GERMAN, DATA_ENGLISH, word_to_idx)

    global VOCAB_SIZE
    VOCAB_SIZE = len(idx_to_word)

    # Persist vocabulary and model hyperparameters so inference can rebuild
    # the same generator configuration later.
    save_vocab(CHECKPOINT_PATH + 'metadata.data', idx_to_word, word_to_idx,
               VOCAB_SIZE, g_emb_dim, g_hidden_dim, g_sequence_len)

    print('VOCAB SIZE:', VOCAB_SIZE)
    # Define Networks
    generator = Generator(VOCAB_SIZE, g_emb_dim, g_hidden_dim, g_sequence_len,
                          BATCH_SIZE, opt.cuda)
    discriminator = Discriminator(d_num_class, VOCAB_SIZE, d_emb_dim,
                                  d_filter_sizes, d_num_filters, d_dropout)
    target_lstm = TargetLSTM(VOCAB_SIZE, g_emb_dim, g_hidden_dim, opt.cuda)
    if opt.cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        target_lstm = target_lstm.cuda()
    # Generate toy data using target lstm
    print('Generating data ...')
    generate_real_data('../data/train_data_obama.txt', BATCH_SIZE,
                       GENERATED_NUM, idx_to_word, word_to_idx, POSITIVE_FILE,
                       TEST_FILE)
    # Create Test data iterator for testing
    test_iter = GenDataIter(TEST_FILE, BATCH_SIZE)
    # generate_samples(target_lstm, BATCH_SIZE, GENERATED_NUM, POSITIVE_FILE, idx_to_word)

    # Load data from file
    gen_data_iter = GenDataIter(POSITIVE_FILE, BATCH_SIZE)

    # Pretrain Generator using MLE
    # gen_criterion = nn.NLLLoss(size_average=False)
    gen_criterion = nn.CrossEntropyLoss()
    gen_optimizer = optim.Adam(generator.parameters())
    if opt.cuda:
        gen_criterion = gen_criterion.cuda()
    print('Pretrain with MLE ...')
    for epoch in range(PRE_EPOCH_NUM):
        loss = train_epoch(generator, gen_data_iter, gen_criterion,
                           gen_optimizer)
        print('Epoch [%d] Model Loss: %f' % (epoch, loss))
        print('Training Output')
        test_predict(generator, test_iter, idx_to_word, train_mode=True)

        sys.stdout.flush()
        # TODO: 2. Flags to ensure dimension of model input is handled
        # generate_samples(generator, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        """
        eval_iter = GenDataIter(EVAL_FILE, BATCH_SIZE)
        print('Iterator Done')
        loss = eval_epoch(target_lstm, eval_iter, gen_criterion)
        print('Epoch [%d] True Loss: %f' % (epoch, loss))
        """
    print('OUTPUT AFTER PRE-TRAINING')
    test_predict(generator, test_iter, idx_to_word, train_mode=True)

    # Pretrain Discriminator on real (POSITIVE_FILE) vs. generated
    # (NEGATIVE_FILE) sequences.
    # NOTE(review): nn.NLLLoss(size_average=False) is deprecated in modern
    # PyTorch in favor of reduction='sum' — confirm the installed version.
    dis_criterion = nn.NLLLoss(size_average=False)
    dis_optimizer = optim.Adam(discriminator.parameters())
    if opt.cuda:
        dis_criterion = dis_criterion.cuda()
    print('Pretrain Discriminator ...')
    for epoch in range(3):
        generate_samples(generator, BATCH_SIZE, GENERATED_NUM, NEGATIVE_FILE)
        dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, BATCH_SIZE)
        for _ in range(3):
            loss = train_epoch(discriminator, dis_data_iter, dis_criterion,
                               dis_optimizer)
            print('Epoch [%d], loss: %f' % (epoch, loss))
            sys.stdout.flush()
    # Adversarial Training
    rollout = Rollout(generator, 0.8)
    print('#####################################################')
    print('Start Adversarial Training...\n')
    gen_gan_loss = GANLoss()

    gen_gan_optm = optim.Adam(generator.parameters())
    if opt.cuda:
        gen_gan_loss = gen_gan_loss.cuda()
    gen_criterion = nn.NLLLoss(size_average=False)
    if opt.cuda:
        gen_criterion = gen_criterion.cuda()
    dis_criterion = nn.NLLLoss(size_average=False)
    dis_optimizer = optim.Adam(discriminator.parameters())
    if opt.cuda:
        dis_criterion = dis_criterion.cuda()
    real_iter = GenDataIter(POSITIVE_FILE, BATCH_SIZE)
    for total_batch in range(TOTAL_BATCH):
        ## Train the generator for one step
        for it in range(1):
            if real_iter.idx >= real_iter.data_num:
                real_iter.reset()
            inputs = real_iter.next()[0]
            # NOTE(review): unconditional .cuda() here will fail on CPU-only
            # runs even though the rest of the script guards on opt.cuda.
            inputs = inputs.cuda()
            samples = generator.sample(BATCH_SIZE, g_sequence_len, inputs)
            samples = samples.cpu()
            # Monte-Carlo rollout reward from the discriminator (16 rollouts)
            rewards = rollout.get_reward(samples, 16, discriminator)
            rewards = Variable(torch.Tensor(rewards))
            if opt.cuda:
                rewards = torch.exp(rewards.cuda()).contiguous().view((-1, ))
            prob = generator.forward(inputs)
            mini_batch = prob.shape[0]
            # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for the loss
            prob = torch.reshape(
                prob,
                (prob.shape[0] * prob.shape[1], -1))  #prob.view(-1, g_emb_dim)
            targets = copy.deepcopy(inputs).contiguous().view((-1, ))
            loss = gen_gan_loss(prob, targets, rewards)
            gen_gan_optm.zero_grad()
            loss.backward()
            gen_gan_optm.step()
            """
            samples = generator.sample(BATCH_SIZE, g_sequence_len)
            # construct the input to the genrator, add zeros before samples and delete the last column
            zeros = torch.zeros((BATCH_SIZE, 1)).type(torch.LongTensor)
            if samples.is_cuda:
                zeros = zeros.cuda()
            inputs = Variable(torch.cat([zeros, samples.data], dim = 1)[:, :-1].contiguous())
            targets = Variable(samples.data).contiguous().view((-1,))
            print('', inputs.shape, targets.shape)
            print(inputs, targets)
            # calculate the reward
            rewards = rollout.get_reward(samples, 16, discriminator)
            rewards = Variable(torch.Tensor(rewards))
            if opt.cuda:
                rewards = torch.exp(rewards.cuda()).contiguous().view((-1,))
            prob = generator.forward(inputs)
            mini_batch = prob.shape[0]
            prob = torch.reshape(prob, (prob.shape[0] * prob.shape[1], -1)) #prob.view(-1, g_emb_dim)
            loss = gen_gan_loss(prob, targets, rewards)
            gen_gan_optm.zero_grad()
            loss.backward()
            gen_gan_optm.step()
            """
        print('Batch [%d] True Loss: %f' % (total_batch, loss))

        # Periodically decode the generator's current output and checkpoint
        # both networks.
        if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
            # generate_samples(generator, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
            # eval_iter = GenDataIter(EVAL_FILE, BATCH_SIZE)
            # loss = eval_epoch(target_lstm, eval_iter, gen_criterion)
            if len(prob.shape) > 2:
                prob = torch.reshape(prob, (prob.shape[0] * prob.shape[1], -1))
            predictions = torch.max(prob, dim=1)[1]
            predictions = predictions.view(mini_batch, -1)
            for each_sen in list(predictions):
                print('Train Output:',
                      generate_sentence_from_id(idx_to_word, each_sen))

            test_predict(generator, test_iter, idx_to_word, train_mode=True)
            torch.save(generator.state_dict(),
                       CHECKPOINT_PATH + 'generator.model')
            torch.save(discriminator.state_dict(),
                       CHECKPOINT_PATH + 'discriminator.model')
        rollout.update_params()

        # Discriminator update: regenerate negatives and retrain several times
        for _ in range(4):
            generate_samples(generator, BATCH_SIZE, GENERATED_NUM,
                             NEGATIVE_FILE)
            dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE,
                                        BATCH_SIZE)
            for _ in range(2):
                loss = train_epoch(discriminator, dis_data_iter, dis_criterion,
                                   dis_optimizer)
                        'deblur_average_mse_loss'] += mse_loss.data.item()
                    loss_values[
                        'deblur_total_average_loss'] += total_loss.data.item()

                # ===================log========================
                loss_values = {
                    k: v / num_epochs
                    for k, v in loss_values.items()
                }
                losses_per_epoch.append(loss_values)

                print('epoch [{}/{}], {}'.format(epoch + 1, num_epochs,
                                                 loss_values))

                #print('epoch [{}/{}], Deblurrer Total Average Loss: {:.4f}, ' +
                #                'Discrim Average Loss: {:.4f}, '
                #.format(epoch + 1, num_epochs, total_loss.data, discrim_total_error.data))
        except KeyboardInterrupt:
            torch.save(model.state_dict(), 'semantic_model_interrupt.pth')
            torch.save(discriminator.state_dict(), 'discrim_interrupt.pth')
            f = open("losses.txt", "w")
            f.write(str(losses_per_epoch))
            f.close()
            sys.exit()
        break

    torch.save(model.state_dict(), 'semanticmodel.pth')
    torch.save(discriminator.state_dict(), 'discrim.pth')
    f = open("losses.txt", "w")
    f.write(str(losses_per_epoch))
    f.close()
Ejemplo n.º 13
0
class GAN:
    """Adversarial caption-generation trainer.

    Pretrains a generator (MLE) and a discriminator, then alternates
    single-step generator/discriminator updates.  The generator's
    reinforcement reward is the discriminator probability plus a CIDEr score,
    with a greedy-decode baseline (self-critical style).
    """

    def __init__(self, device, args):
        self.device = device
        self.args = args
        self.batch_size = args.batch_size
        # Checkpoint locations; the directory is created on demand.
        self.generator_checkpoint_path = os.path.join(args.checkpoint_path, 'generator.pth')
        self.discriminator_checkpoint_path = os.path.join(args.checkpoint_path, 'discriminator.pth')
        if not os.path.isdir(args.checkpoint_path):
            os.mkdir(args.checkpoint_path)
        self.generator = Generator(args).to(self.device)
        self.discriminator = Discriminator(args).to(self.device)
        self.sequence_loss = SequenceLoss()
        self.reinforce_loss = ReinforceLoss()
        self.generator_optimizer = optim.Adam(self.generator.parameters(), lr=args.generator_lr)
        self.discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=args.discriminator_lr)
        self.evaluator = Evaluator('val', self.device, args)
        self.cider = Cider(args)
        generator_dataset = CaptionDataset('train', args)
        self.generator_loader = DataLoader(generator_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)
        discriminator_dataset = DiscCaption('train', args)
        self.discriminator_loader = DataLoader(discriminator_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def train(self):
        """Load or pretrain both networks, then run adversarial training."""
        if self.args.load_generator:
            self.generator.load_state_dict(torch.load(self.generator_checkpoint_path))
        else:
            self._pretrain_generator()
        if self.args.load_discriminator:
            self.discriminator.load_state_dict(torch.load(self.discriminator_checkpoint_path))
        else:
            self._pretrain_discriminator()
        self._train_gan()

    def _pretrain_generator(self):
        """MLE pretraining of the generator with the sequence loss."""
        # Named `step` rather than `iter` so the builtin iter() — which
        # _train_gan relies on — is never shadowed.
        step = 0
        for epoch in range(self.args.pretrain_generator_epochs):
            self.generator.train()
            for data in self.generator_loader:
                for name, item in data.items():
                    data[name] = item.to(self.device)
                self.generator.zero_grad()
                probs = self.generator(data['fc_feats'], data['att_feats'], data['att_masks'], data['labels'])
                loss = self.sequence_loss(probs, data['labels'])
                loss.backward()
                self.generator_optimizer.step()
                print('iter {}, epoch {}, generator loss {:.3f}'.format(step, epoch, loss.item()))
                step += 1
            # Evaluate and checkpoint once per epoch.
            self.evaluator.evaluate_generator(self.generator)
            torch.save(self.generator.state_dict(), self.generator_checkpoint_path)

    def _pretrain_discriminator(self):
        """Pretrain the discriminator on real/wrong/generated captions."""
        step = 0  # avoid shadowing builtin iter()
        for epoch in range(self.args.pretrain_discriminator_epochs):
            self.discriminator.train()
            for data in self.discriminator_loader:
                loss = self._train_discriminator(data)
                print('iter {}, epoch {}, discriminator loss {:.3f}'.format(step, epoch, loss))
                step += 1
            self.evaluator.evaluate_discriminator(generator=self.generator, discriminator=self.discriminator)
            torch.save(self.discriminator.state_dict(), self.discriminator_checkpoint_path)

    def _train_gan(self):
        """Alternate one generator step and one discriminator step per
        iteration, cycling each data loader indefinitely; evaluate and
        checkpoint every 10000 iterations."""
        generator_iter = iter(self.generator_loader)
        discriminator_iter = iter(self.discriminator_loader)
        for i in range(self.args.train_gan_iters):
            print('iter {}'.format(i))
            for j in range(1):
                try:
                    data = next(generator_iter)
                except StopIteration:
                    # Loader exhausted: restart from the beginning.
                    generator_iter = iter(self.generator_loader)
                    data = next(generator_iter)
                result = self._train_generator(data)
                print('generator loss {:.3f}, fake prob {:.3f}, cider score {:.3f}'.format(result['loss'], result['fake_prob'], result['cider_score']))
            for j in range(1):
                try:
                    data = next(discriminator_iter)
                except StopIteration:
                    discriminator_iter = iter(self.discriminator_loader)
                    data = next(discriminator_iter)
                loss = self._train_discriminator(data)
                print('discriminator loss {:.3f}'.format(loss))
            if i != 0 and i % 10000 == 0:
                self.evaluator.evaluate_generator(self.generator)
                torch.save(self.generator.state_dict(), self.generator_checkpoint_path)
                self.evaluator.evaluate_discriminator(generator=self.generator, discriminator=self.discriminator)
                torch.save(self.discriminator.state_dict(), self.discriminator_checkpoint_path)

    def _train_generator(self, data):
        """One generator update: sequence (MLE) loss plus self-critical
        REINFORCE loss; returns the MLE loss, mean fake prob and CIDEr score."""
        self.generator.train()
        for name, item in data.items():
            data[name] = item.to(self.device)
        self.generator.zero_grad()

        # Supervised sequence loss on ground-truth captions.
        probs = self.generator(data['fc_feats'], data['att_feats'], data['att_masks'], data['labels'])
        loss1 = self.sequence_loss(probs, data['labels'])

        # REINFORCE with a greedy-decode baseline (self-critical training).
        seqs, probs = self.generator.sample(data['fc_feats'], data['att_feats'], data['att_masks'])
        greedy_seqs = self.generator.greedy_decode(data['fc_feats'], data['att_feats'], data['att_masks'])
        reward, fake_prob, score = self._get_reward(data, seqs)
        baseline, _, _ = self._get_reward(data, greedy_seqs)
        loss2 = self.reinforce_loss(reward, baseline, probs, seqs)

        loss = loss1 + loss2
        loss.backward()
        self.generator_optimizer.step()
        result = {
            'loss': loss1.item(),
            'fake_prob': fake_prob,
            'cider_score': score
        }
        return result

    def _train_discriminator(self, data):
        """One discriminator update against real, mismatched ('wrong') and
        generated captions; returns the scalar loss value."""
        self.discriminator.train()
        for name, item in data.items():
            data[name] = item.to(self.device)
        self.discriminator.zero_grad()

        real_probs = self.discriminator(data['fc_feats'], data['att_feats'], data['att_masks'], data['labels'])
        wrong_probs = self.discriminator(data['fc_feats'], data['att_feats'], data['att_masks'], data['wrong_labels'])

        # generate fake data (no gradient through the generator here)
        with torch.no_grad():
            fake_seqs, _ = self.generator.sample(data['fc_feats'], data['att_feats'], data['att_masks'])
        fake_probs = self.discriminator(data['fc_feats'], data['att_feats'], data['att_masks'], fake_seqs)

        # Weighted log-likelihood: real up, wrong/fake down; 1e-10 guards log(0).
        loss = -(0.5 * torch.log(real_probs + 1e-10) + 0.25 * torch.log(1 - wrong_probs + 1e-10) + 0.25 * torch.log(1 - fake_probs + 1e-10)).mean()
        loss.backward()
        self.discriminator_optimizer.step()
        return loss.item()

    def _get_reward(self, data, seqs):
        """Score sampled sequences: discriminator prob + CIDEr per sequence.

        Returns (reward tensor, mean fake prob, mean CIDEr score).
        """
        probs = self.discriminator(data['fc_feats'], data['att_feats'], data['att_masks'], seqs)
        scores = self.cider.get_scores(seqs.cpu().numpy(), data['images'].cpu().numpy())
        reward = probs + torch.tensor(scores, dtype=torch.float, device=self.device)
        fake_prob = probs.mean().item()
        score = scores.mean()
        return reward, fake_prob, score
Ejemplo n.º 14
0
def run():
    """Train a CycleGAN-style pair of UNet generators with relativistic
    average discriminators (RaGAN loss) on the ukiyoe2photo dataset.

    Gx maps domain B -> A, Gy maps domain A -> B. Dx discriminates domain A,
    Dy discriminates domain B. Weights are saved to ./genx, ./fcnx, ./geny,
    ./fcny at the end of every epoch.
    """
    print('loop')
    # torch.backends.cudnn.enabled = False
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cpu")
    # Assuming that we are on a CUDA machine, this should print a CUDA device:
    print(device)

    Dx = Discriminator().to(device)
    Gx = UNet(3, 3).to(device)

    Dy = Discriminator().to(device)
    Gy = UNet(3, 3).to(device)

    # Optional warm start from previously saved checkpoints; a failed load
    # is reported but training continues from scratch.
    ld = False
    if ld:
        try:
            Gx.load_state_dict(torch.load('./genx'))
            Dx.load_state_dict(torch.load('./fcnx'))
            Gy.load_state_dict(torch.load('./geny'))
            Dy.load_state_dict(torch.load('./fcny'))
            print('net loaded')
        except Exception as e:
            print(e)

    dataset = 'ukiyoe2photo'
    # A 562
    image_path_A = './datasets/' + dataset + '/trainA/*.jpg'
    image_path_B = './datasets/' + dataset + '/trainB/*.jpg'

    # Interactive mode so the preview figures below refresh without blocking.
    plt.ion()

    train_image_paths_A = glob.glob(image_path_A)
    train_image_paths_B = glob.glob(image_path_B)
    print(len(train_image_paths_A), len(train_image_paths_B))

    b_size = 8

    train_dataset_A = CustomDataset(train_image_paths_A, train=True)
    train_loader_A = torch.utils.data.DataLoader(train_dataset_A,
                                                 batch_size=b_size,
                                                 shuffle=True,
                                                 num_workers=4,
                                                 pin_memory=False,
                                                 drop_last=True)

    # NOTE(review): the extra positional args (True, 562) presumably cap /
    # resample dataset B to match A's 562 images — confirm against
    # CustomDataset's signature.
    train_dataset_B = CustomDataset(train_image_paths_B, True, 562, train=True)
    train_loader_B = torch.utils.data.DataLoader(train_dataset_B,
                                                 batch_size=b_size,
                                                 shuffle=True,
                                                 num_workers=4,
                                                 pin_memory=False,
                                                 drop_last=True)

    Gx.train()
    Dx.train()

    Gy.train()
    Dy.train()

    # Adversarial loss on relativistic logits; L1 for cycle consistency.
    criterion = nn.BCEWithLogitsLoss().to(device)
    # criterion2 = nn.SmoothL1Loss().to(device)
    criterion2 = nn.L1Loss().to(device)

    g_lr = 2e-4
    d_lr = 2e-4
    optimizer_x = optim.Adam(Gx.parameters(), lr=g_lr, betas=(0.5, 0.999))
    optimizer_x_d = optim.Adam(Dx.parameters(), lr=d_lr, betas=(0.5, 0.999))

    optimizer_y = optim.Adam(Gy.parameters(), lr=g_lr, betas=(0.5, 0.999))
    optimizer_y_d = optim.Adam(Dy.parameters(), lr=d_lr, betas=(0.5, 0.999))

    # cp = cropper().to(device)

    # Constant target tensors for the BCE loss, created once per run.
    _zero = torch.from_numpy(np.zeros((b_size, 1))).float().to(device)
    _zero.requires_grad = False

    _one = torch.from_numpy(np.ones((b_size, 1))).float().to(device)
    _one.requires_grad = False

    for epoch in trange(100, desc='epoch'):
        # loop = tqdm(zip(train_loader_A, train_loader_B), desc='iteration')
        # zip stops at the shorter loader, so each iteration yields one
        # batch from each domain.
        loop = zip(tqdm(train_loader_A, desc='iteration'), train_loader_B)
        batch_idx = 0
        for data_A, data_B in loop:
            batch_idx += 1
            zero = _zero
            one = _one
            _data_A = data_A.to(device)
            _data_B = data_B.to(device)

            # Dy loss (A -> B): relativistic average GAN — real logits should
            # beat the mean fake logit and vice versa. Generator output is
            # detached so only Dy receives gradients here.
            gen = Gy(_data_A)

            optimizer_y_d.zero_grad()

            output2_p = Dy(_data_B.detach())
            output_p = Dy(gen.detach())

            errD = (
                criterion(output2_p - torch.mean(output_p), one.detach()) +
                criterion(output_p - torch.mean(output2_p), zero.detach())) / 2
            errD.backward()
            optimizer_y_d.step()

            # Dx loss (B -> A): same RaGAN update for the domain-A critic.
            gen = Gx(_data_B)

            optimizer_x_d.zero_grad()

            output2_p = Dx(_data_A.detach())
            output_p = Dx(gen.detach())

            errD = (
                criterion(output2_p - torch.mean(output_p), one.detach()) +
                criterion(output_p - torch.mean(output2_p), zero.detach())) / 2
            errD.backward()
            optimizer_x_d.step()

            # Gy loss (A -> B): targets are swapped relative to the D update
            # so the generator pushes fake logits above the real mean.
            optimizer_y.zero_grad()
            gen = Gy(_data_A)
            output_p = Dy(gen)
            output2_p = Dy(_data_B.detach())
            g_loss = (
                criterion(output2_p - torch.mean(output_p), zero.detach()) +
                criterion(output_p - torch.mean(output2_p), one.detach())) / 2

            # Gy cycle loss (B -> A -> B). fA is detached, so this term
            # updates only Gy, not Gx.
            fA = Gx(_data_B)
            gen = Gy(fA.detach())
            c_loss = criterion2(gen, _data_B)

            errG = g_loss + c_loss
            errG.backward()
            optimizer_y.step()

            # Every 10 batches, preview: the real B image, its Gx translation
            # to A (fA), and the Gy reconstruction back to B (gen). Images are
            # mapped from [-1, 1] back to [0, 1] for display.
            if batch_idx % 10 == 0:

                fig = plt.figure(1)
                fig.clf()
                plt.imshow((np.transpose(_data_B.detach().cpu().numpy()[0],
                                         (1, 2, 0)) + 1) / 2)
                fig.canvas.draw()
                fig.canvas.flush_events()

                fig = plt.figure(2)
                fig.clf()
                plt.imshow((np.transpose(fA.detach().cpu().numpy()[0],
                                         (1, 2, 0)) + 1) / 2)
                fig.canvas.draw()
                fig.canvas.flush_events()

                fig = plt.figure(3)
                fig.clf()
                plt.imshow((np.transpose(gen.detach().cpu().numpy()[0],
                                         (1, 2, 0)) + 1) / 2)
                fig.canvas.draw()
                fig.canvas.flush_events()

            # Gx loss (B -> A): mirror of the Gy adversarial update.
            optimizer_x.zero_grad()
            gen = Gx(_data_B)
            output_p = Dx(gen)
            output2_p = Dx(_data_A.detach())
            g_loss = (
                criterion(output2_p - torch.mean(output_p), zero.detach()) +
                criterion(output_p - torch.mean(output2_p), one.detach())) / 2

            # Gx cycle loss (A -> B -> A); fB detached so only Gx updates.
            fB = Gy(_data_A)
            gen = Gx(fB.detach())
            c_loss = criterion2(gen, _data_A)

            errG = g_loss + c_loss
            errG.backward()
            optimizer_x.step()

        # Checkpoint all four networks at the end of each epoch (overwrites
        # the previous epoch's files).
        torch.save(Gx.state_dict(), './genx')
        torch.save(Dx.state_dict(), './fcnx')
        torch.save(Gy.state_dict(), './geny')
        torch.save(Dy.state_dict(), './fcny')
    print('\nFinished Training')
Ejemplo n.º 15
0
class SAGAN():
    """Self-Attention GAN trainer with hinge loss and annealed instance noise.

    Wraps a Generator/Discriminator pair, their Adam optimizers, TensorBoard
    logging, checkpointing, and epoch-level loss bookkeeping. All models are
    placed on CUDA unconditionally.
    """

    def __init__(self, dataloader, configs):
        """Store hyperparameters from `configs`, build the models, and
        optionally resume from a pretrained checkpoint."""

        # Data Loader
        self.dataloader = dataloader

        # model settings & hyperparams
        self.total_steps = configs.total_steps
        self.d_iters = configs.d_iters
        self.g_iters = configs.g_iters
        self.batch_size = configs.batch_size
        self.imsize = configs.imsize
        self.nz = configs.nz
        self.ngf = configs.ngf
        self.ndf = configs.ndf
        self.g_lr = configs.g_lr
        self.d_lr = configs.d_lr
        self.beta1 = configs.beta1
        self.beta2 = configs.beta2

        # instance noise (sigma annealed to 0 over inst_noise_sigma_iters)
        self.inst_noise_sigma = configs.inst_noise_sigma
        self.inst_noise_sigma_iters = configs.inst_noise_sigma_iters

        # model logging and saving
        self.log_step = configs.log_step
        self.save_epoch = configs.save_epoch
        self.model_path = configs.model_path
        self.sample_path = configs.sample_path

        # pretrained: epoch number of the checkpoint to resume from (falsy
        # means train from scratch)
        self.pretrained_model = configs.pretrained_model

        # building
        self.build_model()

        # archive of all losses (one entry per epoch)
        self.ave_d_losses = []
        self.ave_d_losses_real = []
        self.ave_d_losses_fake = []
        self.ave_d_gamma1 = []
        self.ave_d_gamma2 = []

        self.ave_g_losses = []
        self.ave_g_gamma1 = []
        self.ave_g_gamma2 = []

        if self.pretrained_model:
            self.load_pretrained()

    def build_model(self):
        """Instantiate G and D on CUDA, their optimizers, and the
        TensorBoard writer; print parameter counts."""
        # initialize Generator and Discriminator
        self.G = Generator(self.imsize, self.nz, self.ngf).cuda()
        self.D = Discriminator(self.ndf).cuda()

        # optimizers (only over trainable parameters)
        self.g_optimizer = optim.Adam(filter(
            lambda p: p.requires_grad, self.G.parameters()), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(filter(
            lambda p: p.requires_grad, self.D.parameters()), self.d_lr, [self.beta1, self.beta2])

        # tensorboard writer
        self.tb = SummaryWriter()

        print("Generator Parameters: ", parameters(self.G))
        print(self.G)
        print("Discriminator Parameters: ", parameters(self.D))
        print(self.D)

    def load_pretrained(self):
        """Loading pretrained model"""
        checkpoint = torch.load(
            os.path.join(self.model_path, "{}_sagan.pth".format(
                self.pretrained_model)))
        # load models
        self.G.load_state_dict(checkpoint["gen_state_dict"])
        self.D.load_state_dict(checkpoint["disc_state_dict"])

        # load optimizers
        self.g_optimizer.load_state_dict(checkpoint["gen_optimizer"])
        self.d_optimizer.load_state_dict(checkpoint["disc_optimizer"])

        # load losses
        self.ave_d_losses = checkpoint["ave_d_losses"]
        self.ave_d_losses_real = checkpoint["ave_d_losses_real"]
        self.ave_d_losses_fake = checkpoint["ave_d_losses_fake"]
        self.ave_d_gamma1 = checkpoint["ave_d_gamma1"]
        self.ave_d_gamma2 = checkpoint["ave_d_gamma2"]

        self.ave_g_losses = checkpoint["ave_g_losses"]
        self.ave_g_gamma1 = checkpoint["ave_g_gamma1"]
        self.ave_g_gamma2 = checkpoint["ave_g_gamma2"]

        print("Loading pretrained models (epoch: {})..!".format(
            self.pretrained_model))

    def reset_grad(self):
        """Reset gradients"""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def train(self):
        """Run the full training loop: alternate d_iters discriminator hinge
        updates and g_iters generator hinge updates per batch, log averages
        per epoch, sample images, and checkpoint every save_epoch epochs."""
        step_per_epoch = len(self.dataloader)
        epochs = int(self.total_steps / step_per_epoch)

        # fixed z for sampling generator images
        fixed_z = tensor2var(torch.randn(self.batch_size, self.nz))

        print("Initiating Training")
        print("Epochs: {}, Total Steps: {}, Steps/Epoch: {}".
              format(epochs, self.total_steps, step_per_epoch))

        if self.pretrained_model:
            start_epoch = self.pretrained_model
        else:
            start_epoch = 0

        # train layers
        self.D.train()
        self.G.train()

        # total time
        start_time = time.time()
        for epoch in range(start_epoch, epochs):
            # local losses (populated only on log steps; NOTE(review): if
            # log_step never divides a step in this epoch these stay empty
            # and the mean() calls below would fail — confirm log_step <=
            # step_per_epoch)
            d_losses = []
            d_losses_real = []
            d_losses_fake = []
            d_gamma1 = []
            d_gamma2 = []

            g_losses = []
            g_gamma1 = []
            g_gamma2 = []

            data_iter = iter(self.dataloader)
            for step in range(step_per_epoch):
                # get real images
                real_images, _ = next(data_iter)
                real_images = tensor2var(real_images)

                # Instance noise - make random noise mean (0) and std for injecting
                inst_noise_mean = torch.full(
                    (real_images.size(0), 3, self.imsize, self.imsize), 0).cuda()
                inst_noise_std = torch.full(
                    (real_images.size(0), 3, self.imsize, self.imsize), self.inst_noise_sigma).cuda()

                # Instance noise std is linearly annealed from self.inst_noise_sigma to 0 thru self.inst_noise_sigma_iters
                # NOTE(review): `step` resets every epoch, so the anneal
                # restarts each epoch rather than spanning total iterations —
                # confirm this is intended.
                inst_noise_sigma_curr = 0 if step > self.inst_noise_sigma_iters else (
                    1 - step/self.inst_noise_sigma_iters)*self.inst_noise_sigma
                inst_noise_std.fill_(inst_noise_sigma_curr)

                # ================== TRAIN DISCRIMINATOR ================== #

                # Gradients accumulate across the d_iters inner passes; a
                # single optimizer step is applied after the loop.
                for _ in range(self.d_iters):
                    self.reset_grad()

                    # TRAIN REAL
                    # creating instance noise
                    inst_noise = torch.normal(
                        mean=inst_noise_mean, std=inst_noise_std).cuda()
                    # get D output for real images + noise
                    d_real = self.D(real_images + inst_noise)
                    # compute hinge loss of D with real images
                    d_loss_real = loss_hinge_dis_real(d_real)
                    d_loss_real.backward()

                    # TRAIN FAKE
                    # generate fake images and get D output for fake images
                    z = tensor2var(torch.randn(real_images.size(0), self.nz))
                    fake_images = self.G(z)

                    # creating instance noise
                    inst_noise = torch.normal(
                        mean=inst_noise_mean, std=inst_noise_std).cuda()
                    # adding noise to fake images
                    # get D output for fake images
                    d_fake = self.D(fake_images + inst_noise)
                    # compute hinge loss of D with fake images
                    d_loss_fake = loss_hinge_dis_fake(d_fake)
                    d_loss_fake.backward()

                    d_loss = d_loss_real + d_loss_fake

                # optimize D
                self.d_optimizer.step()

                # ================== TRAIN GENERATOR ================== #

                for _ in range(self.g_iters):
                    self.reset_grad()

                    # create new latent vector
                    z = tensor2var(torch.randn(real_images.size(0), self.nz))

                    inst_noise = torch.normal(
                        mean=inst_noise_mean, std=inst_noise_std).cuda()
                    # generate fake images
                    fake_images = self.G(z)
                    g_fake = self.D(fake_images + inst_noise)

                    # compute hinge loss for G
                    g_loss = loss_hinge_gen(g_fake)
                    g_loss.backward()

                self.g_optimizer.step()

                # logging step progression
                if (step+1) % self.log_step == 0:
                    # logging losses and attention gamma values
                    d_losses.append(d_loss.item())
                    d_losses_real.append(d_loss_real.item())
                    d_losses_fake.append(d_loss_fake.item())
                    d_gamma1.append(self.D.attn1.gamma.data.item())
                    d_gamma2.append(self.D.attn2.gamma.data.item())

                    g_losses.append(g_loss.item())
                    g_gamma1.append(self.G.attn1.gamma.data.item())
                    g_gamma2.append(self.G.attn2.gamma.data.item())

                    # print out
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    print("Elapsed [{}], Epoch: [{}/{}], Step [{}/{}], g_loss: {:.4f}, d_loss: {:.4f},"
                          " d_loss_real: {:.4f}, d_loss_fake: {:.4f}".
                          format(elapsed, epoch+1, epochs, (step + 1), step_per_epoch,
                                 g_loss, d_loss, d_loss_real, d_loss_fake))

            # logging average losses over epoch
            self.ave_d_losses.append(mean(d_losses))
            self.ave_d_losses_real.append(mean(d_losses_real))
            self.ave_d_losses_fake.append(mean(d_losses_fake))
            self.ave_d_gamma1.append(mean(d_gamma1))
            self.ave_d_gamma2.append(mean(d_gamma2))

            self.ave_g_losses.append(mean(g_losses))
            self.ave_g_gamma1.append(mean(g_gamma1))
            self.ave_g_gamma2.append(mean(g_gamma2))

            # adding tensorboard logs
            self.tb.add_scalar("d loss", self.ave_d_losses[epoch], epoch)
            self.tb.add_scalar('d real', self.ave_d_losses_real[epoch], epoch)
            self.tb.add_scalar('d fake', self.ave_d_losses_fake[epoch], epoch)
            self.tb.add_scalar("g loss", self.ave_g_losses[epoch], epoch)

            self.tb.add_scalar("g gamma 1", self.ave_g_gamma1[epoch], epoch)
            self.tb.add_scalar("g gamma 2", self.ave_g_gamma2[epoch], epoch)
            self.tb.add_scalar("d gamma 1", self.ave_d_gamma1[epoch], epoch)
            self.tb.add_scalar("d gamma 2", self.ave_d_gamma2[epoch], epoch)

            # epoch update (NOTE(review): `elapsed` is set inside the
            # log-step branch above; it would be unbound here if no log
            # step fired this epoch)
            print("Elapsed [{}], Epoch: [{}/{}], ave_g_loss: {:.4f}, ave_d_loss: {:.4f},"
                  " ave_d_loss_real: {:.4f}, ave_d_loss_fake: {:.4f},"
                  " ave_g_gamma1: {:.4f}, ave_g_gamma2: {:.4f}, ave_d_gamma1: {:.4f}, ave_d_gamma2: {:.4f}".
                  format(elapsed, epoch+1, epochs, self.ave_g_losses[epoch], self.ave_d_losses[epoch],
                         self.ave_d_losses_real[epoch], self.ave_d_losses_fake[epoch],
                         self.ave_g_gamma1[epoch], self.ave_g_gamma2[epoch], self.ave_d_gamma1[epoch], self.ave_d_gamma2[epoch]))

            # sample images every epoch using the fixed latent vector
            fake_images = self.G(fixed_z)
            fake_images = denorm(fake_images.data)
            save_image(fake_images,
                       os.path.join(self.sample_path,
                                    "Epoch {}.png".format(epoch+1)))

            # save model
            if (epoch+1) % self.save_epoch == 0:
                torch.save({
                    "gen_state_dict": self.G.state_dict(),
                    "disc_state_dict": self.D.state_dict(),
                    "gen_optimizer": self.g_optimizer.state_dict(),
                    "disc_optimizer": self.d_optimizer.state_dict(),
                    "ave_d_losses": self.ave_d_losses,
                    "ave_d_losses_real": self.ave_d_losses_real,
                    "ave_d_losses_fake": self.ave_d_losses_fake,
                    "ave_d_gamma1": self.ave_d_gamma1,
                    "ave_d_gamma2": self.ave_d_gamma2,
                    "ave_g_losses": self.ave_g_losses,
                    "ave_g_gamma1": self.ave_g_gamma1,
                    "ave_g_gamma2": self.ave_g_gamma2
                }, os.path.join(self.model_path, "{}_sagan.pth".format(epoch+1)))

                print("Saving models (epoch {})..!".format(epoch+1))

    def plot(self):
        """Plot the archived per-epoch average losses."""
        plt.plot(self.ave_d_losses)
        plt.plot(self.ave_d_losses_real)
        plt.plot(self.ave_d_losses_fake)
        plt.plot(self.ave_g_losses)
        plt.legend(["d loss", "d real", "d fake", "g loss"], loc="upper left")
        plt.show()

    def sample(self, samples):
        """Generate `samples` images from random latents and show them as a grid."""
        z = tensor2var(torch.randn(samples, self.nz))
        images = self.G(z)
        images = denorm(images.data)
        # https://pytorch.org/docs/stable/_modules/torchvision/utils.html#save_image
        grid = make_grid(images, nrow=8, padding=2, pad_value=0,
                         normalize=False, range=None, scale_each=False)
        # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
        ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(
            1, 2, 0).to('cpu', torch.uint8).numpy()
        im = Image.fromarray(ndarr)
        plt.imshow(im)
        plt.show()
def _main():
    """WGAN-GP training entry point driven by the module-level `args`.

    Loads a frozen classifier to extract per-image feature maps, then trains
    a generator/discriminator pair conditioned on those features and masks.
    Checkpoints and sample grids are written under
    wgan-gp_models/<model_name>/.

    Bug fixed vs. the original: `random.choices` returns a *list*, so the
    original `train_type == 1` comparison was always False and training
    type 2 was always selected; we now unpack the single sampled element.
    """
    print_gpu_details()
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
    train_root = args.train_path

    image_size = 256
    cropped_image_size = 256
    print("set image folder")
    train_set = dset.ImageFolder(root=train_root,
                                 transform=transforms.Compose([
                                     transforms.Resize(image_size),
                                     transforms.CenterCrop(cropped_image_size),
                                     transforms.ToTensor()
                                 ]))

    # ImageNet normalization for the classifier; [-1, 1] for the discriminator.
    normalizer_clf = transforms.Compose([
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    normalizer_discriminator = transforms.Compose([
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    print('set data loader')
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

    # Network creation. The classifier is loaded whole (not a state dict)
    # and kept in eval mode as a frozen feature extractor.
    classifier = torch.load(args.classifier_path)
    classifier.eval()
    generator = Generator(gen_type=args.gen_type)
    discriminator = Discriminator(args.discriminator_norm, dis_type=args.gen_type)
    # init weights: resume from checkpoints when paths are given, otherwise
    # use the models' own initialization.
    if args.generator_path is not None:
        generator.load_state_dict(torch.load(args.generator_path))
    else:
        generator.init_weights()
    if args.discriminator_path is not None:
        discriminator.load_state_dict(torch.load(args.discriminator_path))
    else:
        discriminator.init_weights()

    classifier.to(device)
    generator.to(device)
    discriminator.to(device)

    # losses + optimizers
    criterion_discriminator, criterion_generator = get_wgan_losses_fn()
    criterion_features = nn.L1Loss()
    criterion_diversity_n = nn.L1Loss()
    criterion_diversity_d = nn.L1Loss()
    generator_optimizer = optim.Adam(generator.parameters(), lr=args.lr, betas=(0.5, 0.999))
    discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999))

    num_of_epochs = args.epochs

    starting_time = time.time()
    iterations = 0
    # creating dirs for keeping models checkpoint, temp created images, and loss summary
    outputs_dir = os.path.join('wgan-gp_models', args.model_name)
    if not os.path.isdir(outputs_dir):
        os.makedirs(outputs_dir, exist_ok=True)
    temp_results_dir = os.path.join(outputs_dir, 'temp_results')
    if not os.path.isdir(temp_results_dir):
        os.mkdir(temp_results_dir)
    models_dir = os.path.join(outputs_dir, 'models_checkpoint')
    if not os.path.isdir(models_dir):
        os.mkdir(models_dir)
    writer = tensorboardX.SummaryWriter(os.path.join(outputs_dir, 'summaries'))

    z = torch.randn(args.batch_size, 128, 1, 1).to(device)  # a fixed noise for sampling
    z2 = torch.randn(args.batch_size, 128, 1, 1).to(device)  # a fixed noise for diversity sampling
    fixed_features = 0
    fixed_masks = 0
    fixed_features_diversity = 0
    first_iter = True
    print("Starting Training Loop...")
    for epoch in range(num_of_epochs):
        for data in train_loader:
            # choose train type; random.choices returns a one-element list,
            # so unpack it — otherwise `train_type == 1` below is always False
            train_type = random.choices([1, 2], [args.train1_prob, 1 - args.train1_prob])[0]
            iterations += 1
            if iterations % 30 == 1:
                print('epoch:', epoch, ', iter', iterations, 'start, time =', time.time() - starting_time, 'seconds')
                starting_time = time.time()
            images, _ = data
            images = images.to(device)  # change to gpu tensor
            images_discriminator = normalizer_discriminator(images)
            images_clf = normalizer_clf(images)
            _, features = classifier(images_clf)
            if first_iter: # save batch of images to keep track of the model process
                first_iter = False
                fixed_features = [torch.clone(features[x]) for x in range(len(features))]
                fixed_masks = [torch.ones(features[x].shape, device=device) for x in range(len(features))]
                # diversity set: repeat the first 8 feature maps across the
                # batch so one noise vector per row shows output variation
                fixed_features_diversity = [torch.clone(features[x]) for x in range(len(features))]
                for i in range(len(features)):
                    for j in range(fixed_features_diversity[i].shape[0]):
                        fixed_features_diversity[i][j] = fixed_features_diversity[i][j % 8]
                grid = vutils.make_grid(images_discriminator, padding=2, normalize=True, nrow=8)
                vutils.save_image(grid, os.path.join(temp_results_dir, 'original_images.jpg'))
                orig_images_diversity = torch.clone(images_discriminator)
                for i in range(orig_images_diversity.shape[0]):
                    orig_images_diversity[i] = orig_images_diversity[i % 8]
                grid = vutils.make_grid(orig_images_diversity, padding=2, normalize=True, nrow=8)
                vutils.save_image(grid, os.path.join(temp_results_dir, 'original_images_diversity.jpg'))
            # Select a features layer to train on
            features_to_train = random.randint(1, len(features) - 2) if args.fixed_layer is None else args.fixed_layer
            # Set masks
            masks = [features[i].clone() for i in range(len(features))]
            setMasksPart1(masks, device, features_to_train) if train_type == 1 else setMasksPart2(masks, device, features_to_train)
            discriminator_loss_dict = train_discriminator(generator, discriminator, criterion_discriminator, discriminator_optimizer, images_discriminator, features, masks)
            for k, v in discriminator_loss_dict.items():
                writer.add_scalar('D/%s' % k, v.data.cpu().numpy(), global_step=iterations)
                if iterations % 30 == 1:
                    print('{}: {:.6f}'.format(k, v))
            # generator is updated once every args.discriminator_steps iterations
            if iterations % args.discriminator_steps == 1:
                generator_loss_dict = train_generator(generator, discriminator, criterion_generator, generator_optimizer, images.shape[0], features,
                                                      criterion_features, features_to_train, classifier, normalizer_clf, criterion_diversity_n,
                                                      criterion_diversity_d, masks, train_type)

                for k, v in generator_loss_dict.items():
                    # NOTE(review): the //5 assumes discriminator_steps == 5 —
                    # confirm, or derive from args.discriminator_steps
                    writer.add_scalar('G/%s' % k, v.data.cpu().numpy(), global_step=iterations//5 + 1)
                    if iterations % 30 == 1:
                        print('{}: {:.6f}'.format(k, v))

            # Save generator and discriminator weights every 1000 iterations
            if iterations % 1000 == 1:
                torch.save(generator.state_dict(), models_dir + '/' + args.model_name + 'G')
                torch.save(discriminator.state_dict(), models_dir + '/' + args.model_name + 'D')
            # Save temp results
            if args.keep_temp_results:
                if iterations < 10000 and iterations % 1000 == 1 or iterations % 2000 == 1:
                    # regular sampling (batch of different images)
                    first_features = True
                    fake_images = None
                    fake_images_diversity = None
                    for i in range(1, 5):
                        one_layer_mask = isolate_layer(fixed_masks, i, device)
                        if first_features:
                            first_features = False
                            fake_images = sample(generator, z, fixed_features, one_layer_mask)
                            fake_images_diversity = sample(generator, z, fixed_features_diversity, one_layer_mask)
                        else:
                            tmp_fake_images = sample(generator, z, fixed_features, one_layer_mask)
                            fake_images = torch.vstack((fake_images, tmp_fake_images))
                            tmp_fake_images = sample(generator, z2, fixed_features_diversity, one_layer_mask)
                            fake_images_diversity = torch.vstack((fake_images_diversity, tmp_fake_images))
                    grid = vutils.make_grid(fake_images, padding=2, normalize=True, nrow=8)
                    vutils.save_image(grid, os.path.join(temp_results_dir, 'res_iter_{}.jpg'.format(iterations // 1000)))
                    # diversity sampling (8 different images each with few different noises)
                    grid = vutils.make_grid(fake_images_diversity, padding=2, normalize=True, nrow=8)
                    vutils.save_image(grid, os.path.join(temp_results_dir, 'div_iter_{}.jpg'.format(iterations // 1000)))

                # NOTE(review): saved every 20000 iterations but the filename
                # suffix divides by 15000 — confirm the intended interval
                if iterations % 20000 == 1:
                    torch.save(generator.state_dict(), models_dir + '/' + args.model_name + 'G_' + str(iterations // 15000))
                    torch.save(discriminator.state_dict(), models_dir + '/' + args.model_name + 'D_' + str(iterations // 15000))
Ejemplo n.º 17
0
def train(config):
    """Run CycleGAN-style adversarial training between image domains A and B.

    Builds two generators (genAB: A->B, genBA: B->A) and two discriminators,
    trains them with identity / cycle-consistency / adversarial terms weighted
    by ``config.loss``, logs scalars and image grids to TensorBoard, optionally
    validates each epoch, and checkpoints all model and optimizer states to
    ``<config.name>/model.pth`` after every epoch.

    Args:
        config: nested config object providing ``model``, ``dataset``,
            ``loss`` and ``train`` sections (see attribute accesses below),
            plus ``name``, ``bs``, ``num_workers`` and ``validate``.
    """
    # --- models: two generators and two per-domain discriminators ---
    genAB = UNet(3, 3, bilinear=config.model.bilinear_upsample).cuda()
    init_weights(genAB, 'normal')
    genBA = UNet(3, 3, bilinear=config.model.bilinear_upsample).cuda()
    init_weights(genBA, 'normal')
    discrA = Discriminator(3).cuda()
    init_weights(discrA, 'normal')
    discrB = Discriminator(3).cuda()
    init_weights(discrB, 'normal')

    # --- logging and data ---
    writer = SummaryWriter(config.name)
    data_train, data_test = datasets_by_name(config.dataset.name,
                                             config.dataset)
    train_dataloader = DataLoader(data_train,
                                  batch_size=config.bs,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    test_dataloader = DataLoader(data_test,
                                 batch_size=config.bs,
                                 shuffle=True,
                                 num_workers=config.num_workers)

    # --- loss functions and their weights ---
    idt_loss = nn.L1Loss()
    cycle_consistency = nn.L1Loss()
    l2_loss = nn.MSELoss()  # used only for pixel-error monitoring, not training
    discriminator_loss = nn.BCELoss()
    lambda_idt, lambda_C, lambda_D = config.loss.lambda_idt, config.loss.lambda_C, config.loss.lambda_D

    # One optimizer over both generators, one over both discriminators.
    optG = torch.optim.Adam(itertools.chain(genAB.parameters(),
                                            genBA.parameters()),
                            lr=config.train.lr,
                            betas=(config.train.beta1, 0.999))
    optD = torch.optim.Adam(itertools.chain(discrA.parameters(),
                                            discrB.parameters()),
                            lr=config.train.lr,
                            betas=(config.train.beta1, 0.999))

    # Resume models/optimizers and starting epoch from a checkpoint, if any.
    genAB, genBA, discrA, discrB, optG, optD, start_epoch = load_if_exsists(
        config, genAB, genBA, discrA, discrB, optG, optD)

    for epoch in range(start_epoch, config.train.epochs):
        set_train([genAB, genBA, discrA, discrB])
        set_requires_grad([genAB, genBA, discrA, discrB], True)
        for i, (batch_A, batch_B) in enumerate(tqdm(train_dataloader)):
            batch_A, batch_B = batch_A.cuda(), batch_B.cuda()
            optG.zero_grad()
            # NOTE(review): loss_G/loss_D start as int 0; if all lambdas are 0,
            # loss_G.backward()/.item() below would fail — confirm configs
            # always enable at least one generator loss term.
            loss_G, loss_D = 0, 0
            # Forward passes: translations and their cycle reconstructions.
            fake_B = genAB(batch_A)
            cycle_A = genBA(fake_B)
            fake_A = genBA(batch_B)
            cycle_B = genAB(fake_A)
            if lambda_idt > 0:
                # NOTE(review): these "identity" terms compare genAB(A) with B
                # (paired L1 supervision), not the usual CycleGAN identity loss
                # genAB(B) ~ B — confirm this is intended for this dataset.
                loss_G += idt_loss(fake_B, batch_B) * lambda_idt
                loss_G += idt_loss(fake_A, batch_A) * lambda_idt
            if lambda_C > 0:
                # Cycle consistency: A -> B -> A and B -> A -> B round trips.
                loss_G += cycle_consistency(cycle_A, batch_A) * lambda_C
                loss_G += cycle_consistency(cycle_B, batch_B) * lambda_C
            if lambda_D > 0:
                # Adversarial term for the generators: freeze the discriminators
                # so only generator parameters receive gradients here.
                set_requires_grad([discrA, discrB], False)
                discr_feedbackA = discrA(fake_A)
                discr_feedbackB = discrB(fake_B)
                loss_G += discriminator_loss(
                    discr_feedbackA,
                    torch.ones_like(discr_feedbackA)) * lambda_D
                loss_G += discriminator_loss(
                    discr_feedbackB,
                    torch.ones_like(discr_feedbackB)) * lambda_D
            loss_G.backward()
            torch.nn.utils.clip_grad_norm_(
                itertools.chain(genAB.parameters(), genBA.parameters()), 15)
            optG.step()
            if lambda_D > 0:
                # Discriminator update, done in two separate optimizer steps:
                # first on fakes (detached so generators get no gradient),
                # then on real batches.
                set_requires_grad([discrA, discrB], True)
                loss_D_fake, loss_D_true = 0, 0
                optD.zero_grad()
                logits = discrA(fake_A.detach())
                loss_D_fake += discriminator_loss(logits,
                                                  torch.zeros_like(logits))

                logits = discrB(fake_B.detach())
                loss_D_fake += discriminator_loss(logits,
                                                  torch.zeros_like(logits))
                loss_D_fake.backward()
                torch.nn.utils.clip_grad_norm_(
                    itertools.chain(discrA.parameters(), discrB.parameters()),
                    15)
                optD.step()

                optD.zero_grad()
                logits = discrA(batch_A)
                loss_D_true += discriminator_loss(logits,
                                                  torch.ones_like(logits))
                logits = discrB(batch_B)
                loss_D_true += discriminator_loss(logits,
                                                  torch.ones_like(logits))
                loss_D_true.backward()
                torch.nn.utils.clip_grad_norm_(
                    itertools.chain(discrA.parameters(), discrB.parameters()),
                    15)
                optD.step()
                loss_D = loss_D_fake + loss_D_true
            if (i % config.train.verbose_period == 0):
                # Scalar logging keyed by global step (batches seen so far).
                writer.add_scalar('train/loss_G', loss_G.item(),
                                  len(train_dataloader) * epoch + i)
                writer.add_scalar('train/pixel_error_A',
                                  l2_loss(fake_A, batch_A).mean().item(),
                                  len(train_dataloader) * epoch + i)
                writer.add_scalar('train/pixel_error_B',
                                  l2_loss(fake_B, batch_B).mean().item(),
                                  len(train_dataloader) * epoch + i)
                if lambda_D > 0:
                    writer.add_scalar('train/loss_D', loss_D.item(),
                                      len(train_dataloader) * epoch + i)
                    writer.add_scalar('train/mean_D_A',
                                      discr_feedbackA.mean().item(),
                                      len(train_dataloader) * epoch + i)
                    writer.add_scalar('train/mean_D_B',
                                      discr_feedbackB.mean().item(),
                                      len(train_dataloader) * epoch + i)
                # Image logging: fake next to its target, shifted from
                # (presumably) [-1, 1] into [0, 1] — confirm model output range.
                for batch_i in range(fake_A.shape[0]):
                    concat = (torch.cat([fake_A[batch_i], batch_B[batch_i]],
                                        dim=-1) + 1.) / 2.
                    writer.add_image('train/fake_A_' + str(batch_i), concat,
                                     len(train_dataloader) * epoch + i)
                for batch_i in range(fake_B.shape[0]):
                    concat = (torch.cat([fake_B[batch_i], batch_A[batch_i]],
                                        dim=-1) + 1.) / 2.
                    writer.add_image('train/fake_B_' + str(batch_i), concat,
                                     len(train_dataloader) * epoch + i)
        if not config.validate:
            continue
        # --- validation: same losses, no parameter updates ---
        set_eval([genAB, genBA, discrA, discrB])
        set_requires_grad([genAB, genBA, discrA, discrB], False)
        loss_G, loss_D, discr_feedbackA_mean, discr_feedbackB_mean = 0, 0, 0, 0
        pixel_error_A, pixel_error_B = 0, 0
        for i, (batch_A, batch_B) in enumerate(tqdm(test_dataloader)):
            batch_A, batch_B = batch_A.cuda(), batch_B.cuda()
            fake_B = genAB(batch_A)
            cycle_A = genBA(fake_B)
            fake_A = genBA(batch_B)
            cycle_B = genAB(fake_A)
            pixel_error_A += l2_loss(fake_A, batch_A).mean()
            pixel_error_B += l2_loss(fake_B, batch_B).mean()
            if lambda_idt > 0:
                loss_G += idt_loss(fake_B, batch_B) * lambda_idt
                loss_G += idt_loss(fake_A, batch_A) * lambda_idt
            if lambda_C > 0:
                loss_G += cycle_consistency(cycle_A, batch_A) * lambda_C
                loss_G += cycle_consistency(cycle_B, batch_B) * lambda_C
            if lambda_D > 0:
                discr_feedbackA = discrA(fake_A)
                discr_feedbackB = discrB(fake_B)
                loss_G += discriminator_loss(
                    discr_feedbackA,
                    torch.ones_like(discr_feedbackA)) * lambda_D
                loss_G += discriminator_loss(
                    discr_feedbackB,
                    torch.ones_like(discr_feedbackB)) * lambda_D
                discr_feedbackA_mean += discr_feedbackA.mean()
                discr_feedbackB_mean += discr_feedbackB.mean()
            if lambda_D > 0:
                # Discriminator validation loss mirrors the training objective.
                loss_D_fake, loss_D_true = 0, 0
                logits = discrA(fake_A.detach())
                loss_D_fake += discriminator_loss(logits,
                                                  torch.zeros_like(logits))
                logits = discrB(fake_B.detach())
                loss_D_fake += discriminator_loss(logits,
                                                  torch.zeros_like(logits))
                logits = discrA(batch_A)
                loss_D_true += discriminator_loss(logits,
                                                  torch.ones_like(logits))
                logits = discrB(batch_B)
                loss_D_true += discriminator_loss(logits,
                                                  torch.ones_like(logits))
                loss_D += loss_D_fake + loss_D_true
            if i == 0:
                # Log example images only for the first validation batch.
                for batch_i in range(fake_A.shape[0]):
                    concat = (torch.cat([fake_A[batch_i], batch_B[batch_i]],
                                        dim=-1) + 1.) / 2.
                    writer.add_image('val/fake_A_' + str(batch_i), concat,
                                     epoch)
                for batch_i in range(fake_B.shape[0]):
                    concat = (torch.cat([fake_B[batch_i], batch_A[batch_i]],
                                        dim=-1) + 1.) / 2.
                    writer.add_image('val/fake_B_' + str(batch_i), concat,
                                     epoch)
        # Average accumulated metrics over the validation set.
        loss_G /= len(test_dataloader)
        pixel_error_A /= len(test_dataloader)
        pixel_error_B /= len(test_dataloader)
        writer.add_scalar('val/loss_G', loss_G.item(), epoch)
        writer.add_scalar('val/pixel_error_A', pixel_error_A.item(), epoch)
        writer.add_scalar('val/pixel_error_B', pixel_error_B.item(), epoch)
        if lambda_D > 0:
            loss_D /= len(test_dataloader)
            discr_feedbackA_mean /= len(test_dataloader)
            discr_feedbackB_mean /= len(test_dataloader)
            writer.add_scalar('val/loss_D', loss_D.item(), epoch)
            writer.add_scalar('val/mean_D_A', discr_feedbackA_mean.item(),
                              epoch)
            writer.add_scalar('val/mean_D_B', discr_feedbackB_mean.item(),
                              epoch)
        # Checkpoint everything after each epoch (overwrites the same file).
        torch.save(
            {
                'genAB': genAB.state_dict(),
                'genBA': genBA.state_dict(),
                'discrA': discrA.state_dict(),
                'discrB': discrB.state_dict(),
                'optG': optG.state_dict(),
                'optD': optD.state_dict(),
                'epoch': epoch
            }, os.path.join(config.name, 'model.pth'))
Ejemplo n.º 18
0
            if opt.cuda:
                rewards = rewards.cuda()
            rewards = torch.exp(rewards).contiguous().view((-1, ))
            prob = generators[i].forward(inputs)
            loss = gen_gan_losses[i](prob, targets, rewards)
            gen_gan_optm[i].zero_grad()
            loss.backward()
            gen_gan_optm[i].step()

    if total_batch % 10 == 0 or total_batch == TOTAL_BATCH - 1:
        for generator in generators:
            print('Saving generator {} with bleu_4: {}'.format(
                generator.name,
                bleu_4(TEXT, corpus, generator, g_sequence_len, count=100)))
            torch.save(
                generator.state_dict(), CHECKPOINT_PATH +
                'generator_seqgan_{}.gen'.format(generator.name))
        torch.save(discriminator.state_dict(),
                   CHECKPOINT_PATH + 'discriminator_seqgan.dis')
    for rollout in rollouts:
        rollout.update_params()

    for _ in range(1):
        loss, acc = train_discriminator(discriminator, generators,
                                        real_data_iterator, dis_criterion,
                                        dis_optimizer)
        print('Epoch [%d], loss: %f, accuracy: %f' % (total_batch, loss, acc))

# if __name__ == '__main__':
# main()
Ejemplo n.º 19
0
def main():
    """Train a simple fully-connected GAN on MNIST.

    Loads MNIST, builds a Discriminator/Generator pair, alternates
    discriminator and generator updates for ``num_epochs`` epochs, saves
    the trained weights ("G.pth"/"D.pth"), pickles per-epoch sample images
    ("train_samples.pkl"), and plots the loss curves ("loss.png").
    """
    # # -------------------- Data --------------------
    num_workers = 8  # number of subprocesses to use for data loading
    batch_size = 64  # how many samples per batch to load
    transform = transforms.ToTensor()  # convert data to torch.FloatTensor
    train_data = datasets.MNIST(root='../data',
                                train=True,
                                download=True,
                                transform=transform)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               num_workers=num_workers)

    # # -------------------- Discriminator and Generator --------------------
    # Discriminator hyperparams
    input_size = 784  # Size of input image to discriminator (28*28)
    d_output_size = 1  # Size of discriminator output (real or fake)
    d_hidden_size = 32  # Size of last hidden layer in the discriminator
    # Generator hyperparams
    z_size = 100  # Size of latent vector to give to generator
    g_output_size = 784  # Size of generator output (generated image)
    g_hidden_size = 32  # Size of first hidden layer in the generator
    # Instantiate discriminator and generator
    D = Discriminator(input_size, d_hidden_size, d_output_size)
    G = Generator(z_size, g_hidden_size, g_output_size)

    # # -------------------- Optimizers and Criterion --------------------
    # Training hyperparams
    num_epochs = 100
    print_every = 400
    lr = 0.002

    # Create optimizers for the discriminator and generator, respectively
    d_optimizer = optim.Adam(D.parameters(), lr)
    g_optimizer = optim.Adam(G.parameters(), lr)
    losses = []  # keep track of (d_loss, g_loss) per epoch

    criterion = nn.BCEWithLogitsLoss()

    # -------------------- Training --------------------
    D.train()
    G.train()

    # Get some fixed data for sampling. These are images that are held
    # constant throughout training, and allow us to inspect the model's performance
    sample_size = 16
    fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
    fixed_z = torch.from_numpy(fixed_z).float()
    samples = []  # keep track of generated "fake" samples

    for epoch in range(num_epochs):
        for batch_i, (real_images, _) in enumerate(train_loader):
            batch_size = real_images.size(0)

            # Important rescaling step
            real_images = real_images * 2 - 1  # rescale input images from [0,1) to [-1, 1)

            # Generate fake images, used for both discriminator and generator
            z = np.random.uniform(-1, 1, size=(batch_size, z_size))
            z = torch.from_numpy(z).float()
            fake_images = G(z)

            real_labels = torch.ones(batch_size)
            fake_labels = torch.zeros(batch_size)

            # ============================================
            #            TRAIN THE DISCRIMINATOR
            # ============================================

            d_optimizer.zero_grad()

            # 1. Train with real images

            # Compute the discriminator losses on real images
            # (label smoothing on the real labels)
            D_real = D(real_images)
            d_real_loss = real_loss(criterion,
                                    D_real,
                                    real_labels,
                                    smooth=True)

            # 2. Train with fake images

            # Compute the discriminator losses on fake images
            # -------------------------------------------------------
            # ATTENTION:
            # *.detach(), thus, generator is fixed when we optimize
            # the discriminator
            # -------------------------------------------------------
            D_fake = D(fake_images.detach())
            d_fake_loss = fake_loss(criterion, D_fake, fake_labels)

            # 3. Add up loss and perform backprop
            d_loss = (d_real_loss + d_fake_loss) * 0.5
            d_loss.backward()
            d_optimizer.step()

            # =========================================
            #            TRAIN THE GENERATOR
            # =========================================

            g_optimizer.zero_grad()

            # Make the discriminator fixed when optimizing the generator
            set_model_gradient(D, False)

            # 1. Train with fake images and flipped labels

            # Compute the discriminator losses on fake images using flipped labels!
            G_D_fake = D(fake_images)
            g_loss = real_loss(criterion, G_D_fake,
                               real_labels)  # use real loss to flip labels

            # 2. Perform backprop
            g_loss.backward()
            g_optimizer.step()

            # Make the discriminator require_grad=True after optimizing the generator
            set_model_gradient(D, True)

            # =========================================
            #           Print some loss stats
            # =========================================
            if batch_i % print_every == 0:
                print(
                    'Epoch [{:5d}/{:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'.
                    format(epoch + 1, num_epochs, d_loss.item(),
                           g_loss.item()))

        # AFTER EACH EPOCH: record the last batch's losses
        losses.append((d_loss.item(), g_loss.item()))

        # generate and save sample, fake images
        G.eval()  # eval mode for generating samples
        # no_grad: inspection samples must not carry autograd graphs, which
        # would otherwise accumulate in `samples` (and be pickled) every epoch
        with torch.no_grad():
            samples_z = G(fixed_z)
        samples.append(samples_z)
        view_samples(-1, samples, "last_sample.png")
        G.train()  # back to train mode

    # Save models and training generator samples
    torch.save(G.state_dict(), "G.pth")
    torch.save(D.state_dict(), "D.pth")
    with open('train_samples.pkl', 'wb') as f:
        pkl.dump(samples, f)

    # Plot the loss curve
    fig, ax = plt.subplots()
    losses = np.array(losses)
    plt.plot(losses.T[0], label='Discriminator')
    plt.plot(losses.T[1], label='Generator')
    plt.title("Training Losses")
    plt.legend()
    plt.savefig("loss.png")
    plt.show()
Ejemplo n.º 20
0
# Main GAN training loop: alternate discriminator/generator updates per batch,
# record per-epoch losses, sample from the generator after each epoch, then
# persist weights, samples, and a loss plot.
for epoch in range(num_epochs):
    for batch_i, (real_images, _) in enumerate(train_loader):
        batch_size = real_images.size(0)

        # Rescale images into the range the models were built for.
        real_images = scale(real_images)

        d_loss = train_discriminator(real_images, d_optim, batch_size, z_size)
        g_loss = train_generator(g_optim, batch_size, z_size)

        # Print some loss stats
        if batch_i % print_every == 0:
            # print discriminator and generator loss
            print('Epoch [{:5d}/{:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'.
                  format(epoch + 1, num_epochs, d_loss.item(), g_loss.item()))

    # Record the last batch's losses for this epoch.
    losses.append((d_loss.item(), g_loss.item()))

    # generate and save sample, fake images
    G.eval()  # eval mode for generating samples
    # no_grad: inspection samples must not carry autograd graphs, which would
    # otherwise accumulate in `samples` (and be pickled below) every epoch
    with torch.no_grad():
        samples_z = G(fixed_z)
    samples.append(samples_z)
    G.train()  # back to train mode

# Persist final weights and the per-epoch generator samples.
torch.save(D.state_dict(), './D.state')
torch.save(G.state_dict(), './G.state')

with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)

generate_plot(losses)
Ejemplo n.º 21
0
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    print("======printing args========")
    print(args)
    print("=================================")

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        print("Loading bin dataset")
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
        #args.data, splits, args.src_lang, args.trg_lang)
    else:
        print(f"Loading raw text dataset {args.data}")
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
        #args.data, splits, args.src_lang, args.trg_lang)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst
    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # try to load generator model
    g_model_path = 'checkpoints/generator/best_gmodel.pt'
    if not os.path.exists(g_model_path):
        print("Start training generator!")
        train_g(args, dataset)
    assert os.path.exists(g_model_path)
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    model_dict = generator.state_dict()
    pretrained_dict = torch.load(g_model_path)
    #print(f"First dict: {pretrained_dict}")
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    #print(f"Second dict: {pretrained_dict}")
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    #print(f"model dict: {model_dict}")
    # 3. load the new state dict
    generator.load_state_dict(model_dict)

    print("Generator has successfully loaded!")

    # try to load discriminator model
    d_model_path = 'checkpoints/discriminator/best_dmodel.pt'
    if not os.path.exists(d_model_path):
        print("Start training discriminator!")
        train_d(args, dataset)
    assert os.path.exists(d_model_path)
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)
    model_dict = discriminator.state_dict()
    pretrained_dict = torch.load(d_model_path)
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    discriminator.load_state_dict(model_dict)

    print("Discriminator has successfully loaded!")

    #return
    print("starting main training loop")

    torch.autograd.set_detect_anomaly(True)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/joint'):
        os.makedirs('checkpoints/joint')
    checkpoints_path = 'checkpoints/joint/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(size_average=False,
                                   ignore_index=dataset.dst_dict.pad(),
                                   reduce=True)
    d_criterion = torch.nn.BCEWithLogitsLoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer
    g_optimizer = eval("torch.optim." + args.g_optimizer)(filter(
        lambda x: x.requires_grad, generator.parameters()),
                                                          args.g_learning_rate)

    d_optimizer = eval("torch.optim." + args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        # seed = args.seed + epoch_i
        # torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # set training mode
        generator.train()
        discriminator.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(itr):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator

            # use policy gradient training when rand > 50%
            rand = random.random()
            if rand >= 0.5:
                # policy gradient training
                generator.decoder.is_testing = True
                sys_out_batch, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
                with torch.no_grad():
                    n_i = sample['net_input']['src_tokens']
                    #print(f"net input:\n{n_i}, pred: \n{prediction}")
                    reward = discriminator(
                        sample['net_input']['src_tokens'],
                        prediction)  # dataset.dst_dict.pad())
                train_trg_batch = sample['target']
                #print(f"sys_out_batch: {sys_out_batch.shape}:\n{sys_out_batch}")
                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                # logging.debug("G policy gradient loss at batch {0}: {1:.3f}, lr={2}".format(i, pg_loss.item(), g_optimizer.param_groups[0]['lr']))
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm(generator.parameters(),
                                              args.clip_norm)
                g_optimizer.step()

                # oracle valid
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
            else:
                # MLE training
                #print(f"printing sample: \n{sample}")
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm(generator.parameters(),
                                              args.clip_norm)
                g_optimizer.step()
            num_update += 1

            # part II: train the discriminator
            bsz = sample['target'].size(0)
            src_sentence = sample['net_input']['src_tokens']
            # train with half human-translation and half machine translation

            true_sentence = sample['target']
            true_labels = Variable(
                torch.ones(sample['target'].size(0)).float())

            with torch.no_grad():
                generator.decoder.is_testing = True
                _, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
            fake_sentence = prediction
            fake_labels = Variable(
                torch.zeros(sample['target'].size(0)).float())

            trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
            labels = torch.cat([true_labels, fake_labels], dim=0)

            indices = np.random.permutation(2 * bsz)
            trg_sentence = trg_sentence[indices][:bsz]
            labels = labels[indices][:bsz]

            if use_cuda:
                labels = labels.cuda()

            disc_out = discriminator(src_sentence,
                                     trg_sentence)  #, dataset.dst_dict.pad())
            #print(f"disc out: {disc_out.shape}, labels: {labels.shape}")
            #print(f"labels: {labels}")
            d_loss = d_criterion(disc_out, labels.long())
            acc = torch.sum(torch.Sigmoid()
                            (disc_out).round() == labels).float() / len(labels)
            d_logging_meters['train_acc'].update(acc)
            d_logging_meters['train_loss'].update(d_loss)
            # logging.debug("D training loss {0:.3f}, acc {1:.3f} at batch {2}: ".format(d_logging_meters['train_loss'].avg,
            #                                                                            d_logging_meters['train_acc'].avg,
            #                                                                            i))
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # validation
        # set validation mode
        generator.eval()
        discriminator.eval()
        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        itr = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(itr):
            with torch.no_grad():
                if use_cuda:
                    sample['id'] = sample['id'].cuda()
                    sample['net_input']['src_tokens'] = sample['net_input'][
                        'src_tokens'].cuda()
                    sample['net_input']['src_lengths'] = sample['net_input'][
                        'src_lengths'].cuda()
                    sample['net_input']['prev_output_tokens'] = sample[
                        'net_input']['prev_output_tokens'].cuda()
                    sample['target'] = sample['target'].cuda()

                # generator validation
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                loss = loss / sample_size / math.log(2)
                g_logging_meters['valid_loss'].update(loss, sample_size)
                logging.debug("G dev loss at batch {0}: {1:.3f}".format(
                    i, g_logging_meters['valid_loss'].avg))

                # discriminator validation
                bsz = sample['target'].size(0)
                src_sentence = sample['net_input']['src_tokens']
                # train with half human-translation and half machine translation

                true_sentence = sample['target']
                true_labels = Variable(
                    torch.ones(sample['target'].size(0)).float())

                with torch.no_grad():
                    generator.decoder.is_testing = True
                    _, prediction, _ = generator(sample)
                    generator.decoder.is_testing = False
                fake_sentence = prediction
                fake_labels = Variable(
                    torch.zeros(sample['target'].size(0)).float())

                trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
                labels = torch.cat([true_labels, fake_labels], dim=0)

                indices = np.random.permutation(2 * bsz)
                trg_sentence = trg_sentence[indices][:bsz]
                labels = labels[indices][:bsz]

                if use_cuda:
                    labels = labels.cuda()

                disc_out = discriminator(src_sentence, trg_sentence,
                                         dataset.dst_dict.pad())
                d_loss = d_criterion(disc_out, labels)
                acc = torch.sum(torch.Sigmoid()(disc_out).round() ==
                                labels).float() / len(labels)
                d_logging_meters['valid_acc'].update(acc)
                d_logging_meters['valid_loss'].update(d_loss)
                # logging.debug("D dev loss {0:.3f}, acc {1:.3f} at batch {2}".format(d_logging_meters['valid_loss'].avg,
                #                                                                     d_logging_meters['valid_acc'].avg, i))

        torch.save(generator,
                   open(
                       checkpoints_path + "joint_{0:.3f}.epoch_{1}.pt".format(
                           g_logging_meters['valid_loss'].avg, epoch_i), 'wb'),
                   pickle_module=dill)

        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            torch.save(generator,
                       open(checkpoints_path + "best_gmodel.pt", 'wb'),
                       pickle_module=dill)
Ejemplo n.º 22
0
                   d_loss.data[0], g_loss.data[0], real_score, fake_score,
                   time.time() - start_time))

            record(
                name='loss_d',
                value=d_loss.data.cpu().numpy(),
                data_type='plot',
            )

            record(
                name='loss_g',
                value=g_loss.data.cpu().numpy(),
                data_type='plot',
            )
            log_visdom()

        counter += 1
    d_cost_avg /= num_batch
    g_cost_avg /= num_batch

    # Save weights every 3 epoch
    if (current_epoch + 1) % 3 == 0:
        print('Epoch:', current_epoch, ' train_loss->',
              (d_cost_avg, g_cost_avg))
        torch.save(generator.state_dict(), './generator.pkl')
        torch.save(discriminator.state_dict(), './discriminator.pkl')
    predict(generator, validation_sample, current_epoch, DIR_TO_SAVE)
torch.save(generator.state_dict(), './generator.pkl')
torch.save(discriminator.state_dict(), './discriminator.pkl')
print('Done')
def train(opt):
    """Train a CycleGAN-style model: two generators (A<->B) and two discriminators.

    Args:
        opt: option namespace with at least ``use_cuda``, ``lr``, ``dataroot``,
            ``batchSize``, ``n_epochs``, ``log_dir`` and ``save_dir``.

    Side effects:
        Logs scalars through a tensorboard ``SummaryWriter`` (also exported to
        ``./all_scalars.json``) and saves the four networks' state dicts under
        ``opt.save_dir``.
    """
    # Generators translate between domains; one discriminator per domain.
    netG_A2B = Unet2(3, 3)
    netG_B2A = Unet2(3, 3)
    netD_A = Discriminator(3)
    netD_B = Discriminator(3)

    if opt.use_cuda:
        netG_A2B = netG_A2B.cuda()
        netG_B2A = netG_B2A.cuda()
        netD_A = netD_A.cuda()
        netD_B = netD_B.cuda()

    # One Adam optimizer per network; beta1=0.5 is the usual GAN setting.
    netG_A2B_optimizer = optimizer.Adam(params=netG_A2B.parameters(),
                                        lr=opt.lr,
                                        betas=(0.5, 0.999))
    netG_B2A_optimizer = optimizer.Adam(params=netG_B2A.parameters(),
                                        lr=opt.lr,
                                        betas=(0.5, 0.999))
    netD_A_optimizer = optimizer.Adam(params=netD_A.parameters(),
                                      lr=opt.lr,
                                      betas=(0.5, 0.999))
    netD_B_optimizer = optimizer.Adam(params=netD_B.parameters(),
                                      lr=opt.lr,
                                      betas=(0.5, 0.999))

    # Keyed the way train_one_step expects them.
    optimizers = {
        'G1': netG_A2B_optimizer,
        'G2': netG_B2A_optimizer,
        'D1': netD_A_optimizer,
        'D2': netD_B_optimizer,
    }

    # Dataset loader: flip augmentation + normalization to [-1, 1].
    transforms_ = [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    train_dataloader = DataLoader(ImageDataset(opt.dataroot,
                                               transforms_=transforms_,
                                               unaligned=True),
                                  batch_size=opt.batchSize,
                                  shuffle=True)

    # Scalar logging
    writer = SummaryWriter(opt.log_dir)

    for epoch in range(0, opt.n_epochs):
        for ii, batch in enumerate(train_dataloader):
            # Set model input (unaligned pair from the two domains).
            real_A = Variable(batch['A'])
            real_B = Variable(batch['B'])

            if opt.use_cuda:
                real_A = real_A.cuda()
                real_B = real_B.cuda()

            train_one_step(use_cuda=opt.use_cuda,
                           netG_A2B=netG_A2B,
                           netG_B2A=netG_B2A,
                           netD_A=netD_A,
                           netD_B=netD_B,
                           real_A=real_A,
                           real_B=real_B,
                           optimizers=optimizers,
                           iteration=ii,
                           writer=writer)

            print("\nEpoch: %s Batch: %s" % (epoch, ii))

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
    # Persist final weights, one file per network.
    torch.save(netG_A2B.state_dict(), os.path.join(opt.save_dir, "netG_A2B"))
    torch.save(netG_B2A.state_dict(), os.path.join(opt.save_dir, "netG_B2A"))
    torch.save(netD_A.state_dict(), os.path.join(opt.save_dir, "netD_A"))
    torch.save(netD_B.state_dict(), os.path.join(opt.save_dir, "netD_B"))
Ejemplo n.º 24
0
                x = x.to(device)
                G_recon, _ = netG(x)
                result = torch.cat((x[0], G_recon[0]), 2)
                path = os.path.join(
                    name + '_results', 'Transfer',
                    str(epoch + 1) + '_epoch_' + name + '_test_' + str(n + 1) +
                    '.png')
                plt.imsave(path,
                           (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                if n == 4:
                    break

            torch.save(netG.state_dict(),
                       os.path.join(name + '_results', 'generator_latest.pkl'))
            torch.save(
                netD.state_dict(),
                os.path.join(name + '_results', 'discriminator_latest.pkl'))

total_time = time.time() - start_time
train_hist['total_time'].append(total_time)

print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (torch.mean(
    torch.FloatTensor(train_hist['per_epoch_time'])), num_epochs, total_time))
print("Training finish!... save training results")

torch.save(netG.state_dict(),
           os.path.join(name + '_results', 'generator_param.pkl'))
torch.save(netD.state_dict(),
           os.path.join(name + '_results', 'discriminator_param.pkl'))
with open(os.path.join(name + '_results', 'train_hist.pkl'), 'wb') as f:
    pickle.dump(train_hist, f)
Ejemplo n.º 25
0
            output = model_d(g_out, onehotv)
            errG = criterion(output, labelv)
            optim_g.zero_grad()
            errG.backward()
            optim_g.step()
            
            d_loss += errD.data[0]
            g_loss += errG.data[0]
            if batch_idx % args.print_every == 0:
                print(
                "\t{} ({} / {}) mean D(fake) = {:.4f}, mean D(real) = {:.4f}".
                    format(epoch_idx, batch_idx, len(train_loader), fakeD_mean,
                        realD_mean))

                g_out = model_g(fixed_noise, fixed_labels).data.view(
                    SAMPLE_SIZE, 1, 28,28).cpu()
                save_image(g_out,
                    '{}/{}_{}.png'.format(
                        args.samples_dir, epoch_idx, batch_idx))


        print('Epoch {} - D loss = {:.4f}, G loss = {:.4f}'.format(epoch_idx,
            d_loss, g_loss))
        if epoch_idx % args.save_every == 0:
            torch.save({'state_dict': model_d.state_dict()},
                        '{}/model_d_epoch_{}.pth'.format(
                            args.save_dir, epoch_idx))
            torch.save({'state_dict': model_g.state_dict()},
                        '{}/model_g_epoch_{}.pth'.format(
                            args.save_dir, epoch_idx))
Ejemplo n.º 26
0
class trainer(object):
    """Joint trainer for a label-correction GAN.

    Two U-Net generators (one reconstructing the old label map, one
    correcting it from the RGB image using the first one's side features)
    are trained adversarially against a patch discriminator. Training
    losses and validation metrics are plotted to a visdom server (module
    global ``vis``) and a full checkpoint (models + optimizers) is written
    every epoch.
    """

    def __init__(self, cfg):
        """Build models, losses, dataloaders, optimizers and schedulers.

        Args:
            cfg: configuration tree with DATASET / LOSS / TRAIN / OPTIMIZER
                sections (see usages below for the exact keys read).
        """
        self.cfg = cfg
        # Generator 1: one-hot old label map in, recovered label map out.
        self.OldLabel_generator = U_Net(in_ch=cfg.DATASET.N_CLASS,
                                        out_ch=cfg.DATASET.N_CLASS,
                                        side='out')
        # Generator 2: RGB image in (conditioned on generator 1 features),
        # corrected label map out.
        self.Image_generator = U_Net(in_ch=3,
                                     out_ch=cfg.DATASET.N_CLASS,
                                     side='in')
        # Discriminator input is image (3ch) concatenated with a one-hot
        # label map (N_CLASS ch).
        self.discriminator = Discriminator(cfg.DATASET.N_CLASS + 3,
                                           cfg.DATASET.IMGSIZE,
                                           patch=True)

        self.criterion_G = GeneratorLoss(cfg.LOSS.LOSS_WEIGHT[0],
                                         cfg.LOSS.LOSS_WEIGHT[1],
                                         cfg.LOSS.LOSS_WEIGHT[2],
                                         ignore_index=cfg.LOSS.IGNORE_INDEX)
        self.criterion_D = DiscriminatorLoss()

        train_dataset = BaseDataset(cfg, split='train')
        valid_dataset = BaseDataset(cfg, split='val')
        self.train_dataloader = data.DataLoader(
            train_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)
        self.valid_dataloader = data.DataLoader(
            valid_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)

        self.ckpt_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints')
        if not os.path.isdir(self.ckpt_outdir):
            os.mkdir(self.ckpt_outdir)
        self.val_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'val')
        if not os.path.isdir(self.val_outdir):
            os.mkdir(self.val_outdir)
        # RESUME >= 0 means "resume from the checkpoint of that epoch";
        # a negative value starts training from scratch.
        self.start_epoch = cfg.TRAIN.RESUME
        self.n_epoch = cfg.TRAIN.N_EPOCH

        # Single optimizer over both generators' parameters.
        self.optimizer_G = torch.optim.Adam(
            [{
                'params': self.OldLabel_generator.parameters()
            }, {
                'params': self.Image_generator.parameters()
            }],
            lr=cfg.OPTIMIZER.G_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        self.optimizer_D = torch.optim.Adam(
            [{
                'params': self.discriminator.parameters(),
                'initial_lr': cfg.OPTIMIZER.D_LR
            }],
            lr=cfg.OPTIMIZER.D_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        # Polynomial (power 0.9) per-iteration LR decay over the whole run.
        iter_per_epoch = len(train_dataset) // cfg.DATASET.BATCHSIZE
        lambda_poly = lambda iters: pow(
            (1.0 - iters / (cfg.TRAIN.N_EPOCH * iter_per_epoch)), 0.9)
        self.scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G,
            lr_lambda=lambda_poly,
        )
        self.scheduler_D = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D,
            lr_lambda=lambda_poly,
        )

        self.logger = logger(cfg.TRAIN.OUTDIR, name='train')
        self.running_metrics = runningScore(n_classes=cfg.DATASET.N_CLASS)

        if self.start_epoch >= 0:
            # Load the checkpoint once instead of re-reading the same file
            # from disk for every sub-state-dict (was five torch.load calls).
            ckpt_path = os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                                     '{}epoch.pth'.format(self.start_epoch))
            checkpoint = torch.load(ckpt_path)
            self.OldLabel_generator.load_state_dict(checkpoint['model_G_N'])
            self.Image_generator.load_state_dict(checkpoint['model_G_I'])
            self.discriminator.load_state_dict(checkpoint['model_D'])
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])

            log = "Using the {}th checkpoint".format(self.start_epoch)
            self.logger.info(log)
        self.Image_generator = self.Image_generator.cuda()
        self.OldLabel_generator = self.OldLabel_generator.cuda()
        self.discriminator = self.discriminator.cuda()
        self.criterion_G = self.criterion_G.cuda()
        self.criterion_D = self.criterion_D.cuda()

    def train(self):
        """Run adversarial training for ``n_epoch`` epochs with per-epoch
        validation, visdom plotting and checkpointing."""
        all_train_iter_total_loss = []
        all_train_iter_corr_loss = []
        all_train_iter_recover_loss = []
        all_train_iter_change_loss = []
        all_train_iter_gan_loss_gen = []
        all_train_iter_gan_loss_dis = []
        all_val_epo_iou = []
        all_val_epo_acc = []
        iter_num = [0]
        epoch_num = []
        num_batches = len(self.train_dataloader)

        for epoch_i in range(self.start_epoch + 1, self.n_epoch):
            iter_total_loss = AverageTracker()
            iter_corr_loss = AverageTracker()
            iter_recover_loss = AverageTracker()
            iter_change_loss = AverageTracker()
            iter_gan_loss_gen = AverageTracker()
            iter_gan_loss_dis = AverageTracker()
            batch_time = AverageTracker()
            tic = time.time()

            # train
            self.OldLabel_generator.train()
            self.Image_generator.train()
            self.discriminator.train()
            for i, meta in enumerate(self.train_dataloader):

                # meta: (image, old_label, new_label, ...) per BaseDataset.
                image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                ), meta[2].cuda()
                recover_pred, feats = self.OldLabel_generator(
                    label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                corr_pred = self.Image_generator(image, feats)

                # -------------------
                # Train Discriminator
                # -------------------
                self.discriminator.set_requires_grad(True)
                self.optimizer_D.zero_grad()

                # Detach so discriminator gradients don't reach the generators.
                fake_sample = torch.cat((image, corr_pred), 1).detach()
                # BUGFIX: previously referenced the global ``cfg`` instead of
                # ``self.cfg`` (NameError unless a global happened to exist).
                real_sample = torch.cat(
                    (image, label2onehot(new_label, self.cfg.DATASET.N_CLASS)),
                    1)

                score_fake_d = self.discriminator(fake_sample)
                score_real = self.discriminator(real_sample)

                gan_loss_dis = self.criterion_D(pred_score=score_fake_d,
                                                real_score=score_real)
                gan_loss_dis.backward()
                self.optimizer_D.step()
                self.scheduler_D.step()

                # ---------------
                # Train Generator
                # ---------------
                self.discriminator.set_requires_grad(False)
                self.optimizer_G.zero_grad()

                # Re-score the (non-detached) fake so gradients reach G.
                score_fake = self.discriminator(
                    torch.cat((image, corr_pred), 1))

                total_loss, corr_loss, recover_loss, change_loss, gan_loss_gen = self.criterion_G(
                    corr_pred, recover_pred, score_fake, old_label, new_label)

                total_loss.backward()
                self.optimizer_G.step()
                self.scheduler_G.step()

                iter_total_loss.update(total_loss.item())
                iter_corr_loss.update(corr_loss.item())
                iter_recover_loss.update(recover_loss.item())
                iter_change_loss.update(change_loss.item())
                iter_gan_loss_gen.update(gan_loss_gen.item())
                iter_gan_loss_dis.update(gan_loss_dis.item())
                batch_time.update(time.time() - tic)
                tic = time.time()

                log = '{}: Epoch: [{}][{}/{}], Time: {:.2f}, ' \
                      'Total Loss: {:.6f}, Corr Loss: {:.6f}, Recover Loss: {:.6f}, Change Loss: {:.6f}, GAN_G Loss: {:.6f}, GAN_D Loss: {:.6f}'.format(
                    datetime.now(), epoch_i, i, num_batches, batch_time.avg,
                    total_loss.item(), corr_loss.item(), recover_loss.item(), change_loss.item(), gan_loss_gen.item(), gan_loss_dis.item())
                print(log)

                # Every 10 iterations, push averaged losses to visdom and
                # reset the running trackers.
                if (i + 1) % 10 == 0:
                    all_train_iter_total_loss.append(iter_total_loss.avg)
                    all_train_iter_corr_loss.append(iter_corr_loss.avg)
                    all_train_iter_recover_loss.append(iter_recover_loss.avg)
                    all_train_iter_change_loss.append(iter_change_loss.avg)
                    all_train_iter_gan_loss_gen.append(iter_gan_loss_gen.avg)
                    all_train_iter_gan_loss_dis.append(iter_gan_loss_dis.avg)
                    iter_total_loss.reset()
                    iter_corr_loss.reset()
                    iter_recover_loss.reset()
                    iter_change_loss.reset()
                    iter_gan_loss_gen.reset()
                    iter_gan_loss_dis.reset()

                    vis.line(X=np.column_stack(
                        np.repeat(np.expand_dims(iter_num, 0), 6, axis=0)),
                             Y=np.column_stack((all_train_iter_total_loss,
                                                all_train_iter_corr_loss,
                                                all_train_iter_recover_loss,
                                                all_train_iter_change_loss,
                                                all_train_iter_gan_loss_gen,
                                                all_train_iter_gan_loss_dis)),
                             opts={
                                 'legend': [
                                     'total_loss', 'corr_loss', 'recover_loss',
                                     'change_loss', 'gan_loss_gen',
                                     'gan_loss_dis'
                                 ],
                                 'linecolor':
                                 np.array([[255, 0, 0], [0, 255, 0],
                                           [0, 0, 255], [255, 255, 0],
                                           [0, 255, 255], [255, 0, 255]]),
                                 'title':
                                 'Train loss of generator and discriminator'
                             },
                             win='Train loss of generator and discriminator')
                    iter_num.append(iter_num[-1] + 1)

            # eval
            self.OldLabel_generator.eval()
            self.Image_generator.eval()
            self.discriminator.eval()
            with torch.no_grad():
                for j, meta in enumerate(self.valid_dataloader):
                    image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                    ), meta[2].cuda()
                    recover_pred, feats = self.OldLabel_generator(
                        label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                    corr_pred = self.Image_generator(image, feats)
                    preds = np.argmax(corr_pred.cpu().detach().numpy().copy(),
                                      axis=1)
                    target = new_label.cpu().detach().numpy().copy()
                    self.running_metrics.update(target, preds)

                    # Dump a color-coded prediction image for the first batch.
                    if j == 0:
                        color_map1 = gen_color_map(preds[0, :]).astype(
                            np.uint8)
                        color_map2 = gen_color_map(preds[1, :]).astype(
                            np.uint8)
                        color_map = cv2.hconcat([color_map1, color_map2])
                        cv2.imwrite(
                            os.path.join(
                                self.val_outdir, '{}epoch*{}*{}.png'.format(
                                    epoch_i, meta[3][0], meta[3][1])),
                            color_map)

            score = self.running_metrics.get_scores()
            oa = score['Overall Acc: \t']
            precision = score['Precision: \t'][1]
            recall = score['Recall: \t'][1]
            iou = score['Class IoU: \t'][1]
            miou = score['Mean IoU: \t']
            self.running_metrics.reset()

            epoch_num.append(epoch_i)
            all_val_epo_acc.append(oa)
            all_val_epo_iou.append(miou)
            vis.line(X=np.column_stack(
                np.repeat(np.expand_dims(epoch_num, 0), 2, axis=0)),
                     Y=np.column_stack((all_val_epo_acc, all_val_epo_iou)),
                     opts={
                         'legend':
                         ['val epoch Overall Acc', 'val epoch Mean IoU'],
                         'linecolor': np.array([[255, 0, 0], [0, 255, 0]]),
                         'title': 'Validate Accuracy and IoU'
                     },
                     win='validate Accuracy and IoU')

            log = '{}: Epoch Val: [{}], ACC: {:.2f}, Recall: {:.2f}, mIoU: {:.4f}' \
                .format(datetime.now(), epoch_i, oa, recall, miou)
            self.logger.info(log)

            # Full checkpoint: metrics + all model and optimizer states.
            state = {
                'epoch': epoch_i,
                "acc": oa,
                "recall": recall,
                "iou": miou,
                'model_G_N': self.OldLabel_generator.state_dict(),
                'model_G_I': self.Image_generator.state_dict(),
                'model_D': self.discriminator.state_dict(),
                'optimizer_G': self.optimizer_G.state_dict(),
                'optimizer_D': self.optimizer_D.state_dict()
            }
            save_path = os.path.join(self.cfg.TRAIN.OUTDIR, 'checkpoints',
                                     '{}epoch.pth'.format(epoch_i))
            torch.save(state, save_path)
Ejemplo n.º 27
0
class GAIL(PPO):
    """Generative Adversarial Imitation Learning built on top of PPO.

    Adds a discriminator that is trained to separate learner state-action
    pairs from expert ones; the discriminator's log-probability of "expert"
    is then used as the reward signal for the underlying PPO update.
    """

    def __init__(
        self,
        state_dimension: Tuple,
        action_space: int,
        save_path: Path,
        hyp: HyperparametersGAIL,
        policy_params: namedtuple,
        discriminator_params: DiscrimParams,
        param_plot_num: int,
        ppo_type: str = "clip",
        adv_type: str = "monte_carlo",
        max_plot_size: int = 10000,
        policy_burn_in: int = 0,
        verbose: bool = False,
    ):
        """Construct the discriminator and its optimizer, then delegate the
        policy setup to the PPO base class (which also creates self.plotter
        used by record_nn_params)."""
        # Where _save_network/_load_network persist the discriminator weights.
        self.discrim_net_save = save_path / "GAIL_discrim.pth"

        self.discriminator = Discriminator(
            state_dimension,
            action_space,
            discriminator_params,
        ).to(device)

        self.discrim_optim = torch.optim.Adam(self.discriminator.parameters(),
                                              lr=hyp.discrim_lr)
        # Extra plot series registered with the base-class plotter.
        gail_plots = [("discrim_loss", np.float64)]

        super(GAIL, self).__init__(
            state_dimension,
            action_space,
            save_path,
            hyp,
            policy_params,
            param_plot_num,
            ppo_type,
            advantage_type=adv_type,
            neural_net_save=f"GAIL-{adv_type}",
            max_plot_size=max_plot_size,
            discrim_params=discriminator_params,
            policy_burn_in=policy_burn_in,
            verbose=verbose,
            additional_plots=gail_plots,
        )

        # NLLLoss expects log-probabilities; paired with
        # Discriminator.logprobs below.
        self.discrim_loss = torch.nn.NLLLoss()

    def update(self, buffer: GAILExperienceBuffer, ep_num: int):
        """Train the discriminator for num_discrim_epochs passes, then
        rewrite the buffer rewards from the discriminator and run the PPO
        policy update.

        NOTE(review): the slicing assumes buffer.state_actions holds the
        learner samples first, followed by all expert samples — layout set
        by GAILExperienceBuffer; confirm against its implementation.
        """
        # Update discriminator
        state_actions = buffer.state_actions.to(device)
        num_learner_samples = buffer.get_length()
        # Expert pool is split evenly across the discriminator epochs.
        expert_samples_per_epoch = int(
            (state_actions.size()[0] - num_learner_samples) /
            self.hyp.num_discrim_epochs)
        for epoch in range(self.hyp.num_discrim_epochs):
            # All learner samples + this epoch's slice of expert samples.
            step_state_actions = torch.cat(
                (
                    state_actions[:num_learner_samples],
                    state_actions[
                        num_learner_samples +
                        epoch * expert_samples_per_epoch:num_learner_samples +
                        (epoch + 1) * expert_samples_per_epoch],
                ),
                dim=0,
            )
            discrim_logprobs = self.discriminator.logprobs(
                step_state_actions).to(device)
            # NOTE(review): targets use the full discrim_labels tensor —
            # presumably sized to match step_state_actions; verify against
            # GAILExperienceBuffer.
            loss = self.discrim_loss(
                input=discrim_logprobs,
                target=buffer.discrim_labels.type(torch.long),
            )
            plotted_loss = loss.detach().cpu().numpy()
            self.plotter.record_data({"discrim_loss": plotted_loss})

            if self.verbose:
                # Mean predicted P(expert) for each half of the batch.
                print(
                    f"Learner labels {buffer.discrim_labels[:num_learner_samples].mean()}: "
                    f"\t{torch.exp(discrim_logprobs[:num_learner_samples]).t()[1].mean()}"
                )
                print(
                    f"Expert labels {buffer.discrim_labels[num_learner_samples:].mean()}: "
                    f"\t\t{torch.exp(discrim_logprobs[num_learner_samples:]).t()[1].mean()}"
                )

            self.discrim_optim.zero_grad()
            loss.backward()
            self.discrim_optim.step()
            self.record_nn_params()

        # Update policy: reward = discriminator's log P(expert) on the
        # learner's own state-action pairs.
        buffer.rewards = list(
            np.squeeze(
                self.discriminator.logprob_expert(
                    state_actions[:num_learner_samples]).float().detach().cpu(
                    ).numpy()))
        if self.verbose:
            print(
                "----------------------------------------------------------------------"
            )
        super(GAIL, self).update(buffer, ep_num)

    def record_nn_params(self):
        """Gets randomly sampled actor NN parameters from 1st layer."""
        names, x_params, y_params = self.plotter.get_param_plot_nums()
        sampled_params = {}

        for name, x_param, y_param in zip(names, x_params, y_params):
            # "discrim"-prefixed plot names sample the discriminator;
            # everything else samples the policy network.
            network_to_sample = (self.discriminator
                                 if name[:7] == "discrim" else self.policy)
            sampled_params[name] = (
                network_to_sample.state_dict()[name].cpu().numpy()[x_param,
                                                                   y_param])
        self.plotter.record_data(sampled_params)

    def _save_network(self):
        """Save PPO networks (base class) plus the discriminator weights."""
        super(GAIL, self)._save_network()
        torch.save(self.discriminator.state_dict(), f"{self.discrim_net_save}")

    def _load_network(self):
        """Restore PPO networks (base class) plus the discriminator weights."""
        super(GAIL, self)._load_network()
        print(
            f"Loading discriminator network saved at: {self.discrim_net_save}")
        net = torch.load(self.discrim_net_save, map_location=device)
        self.discriminator.load_state_dict(net)
Ejemplo n.º 28
0
class CycleGAN(AlignmentModel):
    """This class implements the alignment model for GAN networks with two generators and two discriminators
    (cycle GAN). For description of the implemented functions, refer to the alignment model."""

    def __init__(self,
                 device,
                 config,
                 generator_a=None,
                 generator_b=None,
                 discriminator_a=None,
                 discriminator_b=None):
        """Initialize two new generators and two discriminators from the config or use pre-trained ones and create
        optimizers for all models.

        Args:
            device: torch device all newly built models are moved to.
            config: hyper-parameter dict (dims, layer counts, norms, optimizer choice, ...).
            generator_a: optional pre-trained generator mapping dim_b -> dim_a.
            generator_b: optional pre-trained generator mapping dim_a -> dim_b.
            discriminator_a: optional pre-trained discriminator for space a.
            discriminator_b: optional pre-trained discriminator for space b.
        """
        super().__init__(device, config)
        # Running per-epoch losses: [gen_a, gen_b, disc_a, disc_b].
        self.epoch_losses = [0., 0., 0., 0.]

        # NOTE: pre-trained models passed in are used as-is (not moved to
        # `device` here), matching the construction path only for new models.
        if generator_a is None:
            generator_a = self._build_generator(config, device,
                                                config['dim_b'],
                                                config['dim_a'])
        self.generator_a = generator_a
        self.optimizer_g_a = self._make_optimizer(
            config, self.generator_a.parameters())

        if generator_b is None:
            generator_b = self._build_generator(config, device,
                                                config['dim_a'],
                                                config['dim_b'])
        self.generator_b = generator_b
        self.optimizer_g_b = self._make_optimizer(
            config, self.generator_b.parameters())

        if discriminator_a is None:
            discriminator_a = self._build_discriminator(
                config, device, config['dim_a'])
        self.discriminator_a = discriminator_a
        self.optimizer_d_a = self._make_optimizer(
            config, self.discriminator_a.parameters())

        if discriminator_b is None:
            discriminator_b = self._build_discriminator(
                config, device, config['dim_b'])
        self.discriminator_b = discriminator_b
        self.optimizer_d_b = self._make_optimizer(
            config, self.discriminator_b.parameters())

    @staticmethod
    def _build_generator(config, device, dim_in, dim_out):
        """Construct a new Generator mapping dim_in -> dim_out from the config
        and move it to `device`."""
        generator_conf = dict(
            dim_1=dim_in,
            dim_2=dim_out,
            layer_number=config['generator_layers'],
            layer_expansion=config['generator_expansion'],
            initialize_generator=config['initialize_generator'],
            norm=config['gen_norm'],
            batch_norm=config['gen_batch_norm'],
            activation=config['gen_activation'],
            dropout=config['gen_dropout'])
        generator = Generator(generator_conf, device)
        generator.to(device)
        return generator

    @staticmethod
    def _build_discriminator(config, device, dim):
        """Construct a new Discriminator for `dim`-dimensional inputs from the
        config and move it to `device`."""
        discriminator_conf = dict(
            dim=dim,
            layer_number=config['discriminator_layers'],
            layer_expansion=config['discriminator_expansion'],
            batch_norm=config['disc_batch_norm'],
            activation=config['disc_activation'],
            dropout=config['disc_dropout'])
        discriminator = Discriminator(discriminator_conf, device)
        discriminator.to(device)
        return discriminator

    @staticmethod
    def _make_optimizer(config, parameters):
        """Create an optimizer for `parameters` following the config.

        Preference order: explicit 'optimizer' entry, then
        'optimizer_default' (SGD additionally receives the learning rate),
        and finally Adam with the configured learning rate.
        """
        if 'optimizer' in config:
            return OPTIMIZERS[config['optimizer']](parameters,
                                                   config['learning_rate'])
        if 'optimizer_default' in config:
            optimizer_cls = OPTIMIZERS[config['optimizer_default']]
            if config['optimizer_default'] == 'sgd':
                return optimizer_cls(parameters, config['learning_rate'])
            return optimizer_cls(parameters)
        return torch.optim.Adam(parameters, config['learning_rate'])

    def train(self):
        """Put all four networks into training mode."""
        self.generator_a.train()
        self.generator_b.train()
        self.discriminator_a.train()
        self.discriminator_b.train()

    def eval(self):
        """Put all four networks into evaluation mode."""
        self.generator_a.eval()
        self.generator_b.eval()
        self.discriminator_a.eval()
        self.discriminator_b.eval()

    def zero_grad(self):
        """Clear gradients of all four optimizers."""
        self.optimizer_g_a.zero_grad()
        self.optimizer_g_b.zero_grad()
        self.optimizer_d_a.zero_grad()
        self.optimizer_d_b.zero_grad()

    def optimize_all(self):
        """Do the optimization step for generators and discriminators."""
        self.optimizer_g_a.step()
        self.optimizer_g_b.step()
        self.optimizer_d_a.step()
        self.optimizer_d_b.step()

    def optimize_generator(self):
        """Do the optimization step only for generators (e.g. when training generators and discriminators separately or
        in turns)."""
        self.optimizer_g_a.step()
        self.optimizer_g_b.step()

    def optimize_discriminator(self):
        """Do the optimization step only for discriminators (e.g. when training generators and discriminators separately
        or in turns)."""
        self.optimizer_d_a.step()
        self.optimizer_d_b.step()

    def change_lr(self, factor):
        """Scale the current learning rate by `factor`.

        NOTE(review): only the generator optimizers are updated here, not the
        discriminators — confirm this asymmetry is intended.
        """
        self.current_lr = self.current_lr * factor
        for param_group in self.optimizer_g_a.param_groups:
            param_group['lr'] = self.current_lr
        for param_group in self.optimizer_g_b.param_groups:
            param_group['lr'] = self.current_lr

    def update_losses_batch(self, *losses):
        """Accumulate one batch worth of (gen_a, gen_b, disc_a, disc_b) losses."""
        loss_g_a, loss_g_b, loss_d_a, loss_d_b = losses
        self.epoch_losses[0] += loss_g_a
        self.epoch_losses[1] += loss_g_b
        self.epoch_losses[2] += loss_d_a
        self.epoch_losses[3] += loss_d_b

    def complete_epoch(self, epoch_metrics):
        """Record epoch metrics plus the summed losses, then reset the accumulators."""
        self.metrics.append(epoch_metrics + [sum(self.epoch_losses)])
        self.losses.append(self.epoch_losses)
        self.epoch_losses = [0., 0., 0., 0.]

    def print_epoch_info(self):
        """Print epoch number, the four last losses, and the last metrics."""
        print(
            f"{len(self.metrics)} ### {self.losses[-1][0]:.2f} - {self.losses[-1][1]:.2f} "
            f"- {self.losses[-1][2]:.2f} - {self.losses[-1][3]:.2f} ### {self.metrics[-1]}"
        )

    def copy_model(self):
        """Snapshot deep copies of all four state dicts for later restore."""
        self.model_copy = (deepcopy(self.generator_a.state_dict()),
                           deepcopy(self.generator_b.state_dict()),
                           deepcopy(self.discriminator_a.state_dict()),
                           deepcopy(self.discriminator_b.state_dict()))

    def restore_model(self):
        """Load the state dicts saved by the last copy_model() call."""
        self.generator_a.load_state_dict(self.model_copy[0])
        self.generator_b.load_state_dict(self.model_copy[1])
        self.discriminator_a.load_state_dict(self.model_copy[2])
        self.discriminator_b.load_state_dict(self.model_copy[3])

    def export_model(self, test_results, description=None):
        """Export all networks, metrics and test results under `description`."""
        if description is None:
            description = f"CycleGAN_{self.config['evaluation']}_{self.config['subset']}"
        export_cyclegan_alignment(description, self.config, self.generator_a,
                                  self.generator_b, self.discriminator_a,
                                  self.discriminator_b, self.metrics)
        save_alignment_test_results(test_results, description)
        print(f"Saved model to directory {description}.")

    @classmethod
    def load_model(cls, name, device):
        """Alternate constructor: rebuild a CycleGAN from an exported directory."""
        generator_a, generator_b, discriminator_a, discriminator_b, config = load_cyclegan_alignment(
            name, device)
        model = cls(device, config, generator_a, generator_b, discriminator_a,
                    discriminator_b)
        return model
# Ejemplo n.º 29
# 0
class GAIL:
    """Generative Adversarial Imitation Learning with a deterministic actor
    network and a state-action discriminator trained with BCE."""

    def __init__(self,
                 exp_dir,
                 exp_thresh,
                 state_dim,
                 action_dim,
                 learn_rate,
                 betas,
                 _device,
                 _gamma,
                 load_weights=False):
        """
            exp_dir : directory containing the expert episodes
         exp_thresh : parameter to control number of episodes to load
                      as expert based on returns (lower means more episodes)
          state_dim : dimension of state
         action_dim : dimension of action
         learn_rate : learning rate for optimizer
              betas : Adam beta coefficients for both optimizers
            _device : GPU or cpu
            _gamma  : discount factor
       load_weights : load weights from directory instead of random init
        """

        # storing runtime device
        self.device = _device

        # discount factor
        self.gamma = _gamma

        # Expert trajectory
        self.expert = ExpertTrajectories(exp_dir, exp_thresh, gamma=self.gamma)

        # Defining the actor and its optimizer
        self.actor = ActorNetwork(state_dim).to(self.device)
        self.optim_actor = torch.optim.Adam(self.actor.parameters(),
                                            lr=learn_rate,
                                            betas=betas)

        # Defining the discriminator and its optimizer
        self.disc = Discriminator(state_dim, action_dim).to(self.device)
        self.optim_disc = torch.optim.Adam(self.disc.parameters(),
                                           lr=learn_rate,
                                           betas=betas)

        if not load_weights:
            self.actor.apply(init_weights)
            self.disc.apply(init_weights)
        else:
            self.load()

        # Loss function criterion
        self.criterion = torch.nn.BCELoss()

    def get_action(self, state):
        """
            obtain action for a given state using actor network
        """
        state = torch.tensor(state, dtype=torch.float,
                             device=self.device).view(1, -1)
        return self.actor(state).cpu().data.numpy().flatten()

    def update(self, n_iter, batch_size=100):
        """
            train discriminator and actor for mini-batch

            returns (act_losses, disc_losses) per-iteration loss arrays
        """
        # memory to store losses; np.float was removed in NumPy 1.24,
        # use the explicit 64-bit dtype instead.
        disc_losses = np.zeros(n_iter, dtype=np.float64)
        act_losses = np.zeros(n_iter, dtype=np.float64)

        for i in range(n_iter):

            # Get expert state and actions batch
            exp_states, exp_actions = self.expert.sample(batch_size)
            exp_states = torch.FloatTensor(exp_states).to(self.device)
            exp_actions = torch.FloatTensor(exp_actions).to(self.device)

            # Get state, and actions using actor
            states, _ = self.expert.sample(batch_size)
            states = torch.FloatTensor(states).to(self.device)
            actions = self.actor(states)

            # ---- train the discriminator ----
            self.optim_disc.zero_grad()

            # label tensors; a float fill value is required because an
            # integer fill yields a long tensor on recent PyTorch, which
            # BCELoss rejects.
            exp_labels = torch.full((batch_size, 1), 1., device=self.device)
            policy_labels = torch.full((batch_size, 1), 0., device=self.device)

            # with expert transitions
            prob_exp = self.disc(exp_states, exp_actions)
            exp_loss = self.criterion(prob_exp, exp_labels)

            # with policy actor transitions (detach: no actor grads here)
            prob_policy = self.disc(states, actions.detach())
            policy_loss = self.criterion(prob_policy, policy_labels)

            # use backprop
            disc_loss = exp_loss + policy_loss
            disc_losses[i] = disc_loss.mean().item()

            disc_loss.backward()
            self.optim_disc.step()

            # ---- train the actor ----
            # maximize discriminator output on policy transitions
            self.optim_actor.zero_grad()
            loss_actor = -self.disc(states, actions)
            act_losses[i] = loss_actor.mean().detach().item()

            loss_actor.mean().backward()
            self.optim_actor.step()

        print("Finished training minibatch")

        return act_losses, disc_losses

    def save(
            self,
            directory='/home/aman/Programming/RL-Project/Deterministic-GAIL/weights',
            name='GAIL'):
        """Save actor and discriminator weights under `directory`."""
        torch.save(self.actor.state_dict(),
                   '{}/{}_actor.pth'.format(directory, name))
        torch.save(self.disc.state_dict(),
                   '{}/{}_discriminator.pth'.format(directory, name))

    def load(
            self,
            directory='/home/aman/Programming/RL-Project/Deterministic-GAIL/weights',
            name='GAIL'):
        """Load actor and discriminator weights from `directory`."""
        print(os.getcwd())
        self.actor.load_state_dict(
            torch.load('{}/{}_actor.pth'.format(directory, name)))
        self.disc.load_state_dict(
            torch.load('{}/{}_discriminator.pth'.format(directory, name)))

    def set_mode(self, mode="train"):
        """Switch both networks between train and eval mode."""
        if mode == "train":
            self.actor.train()
            self.disc.train()
        else:
            self.actor.eval()
            self.disc.eval()
# Ejemplo n.º 30
# 0
def train_gan(args):
    """Train a DCGAN-style generator/discriminator pair.

    Args:
        args: namespace with dataloader settings, ngpu, nz, lr, beta1,
              num_epochs, save_every, optional checkpoint paths
              (netG/netD) and output paths (outputG/outputD).

    Returns:
        (img_list, G_losses, D_losses): fixed-noise sample grids and the
        per-iteration generator/discriminator losses for later analysis.
    """

    # prepare dataloader
    dataloader = create_data_loader(args)

    # set up device
    device = torch.device('cuda:0' if (
        torch.cuda.is_available() and args.ngpu > 0) else 'cpu')

    # Create & setup generator
    netG = Generator(args).to(device)

    # handle multiple gpus
    if (device.type == 'cuda' and args.ngpu > 1):
        netG = nn.DataParallel(netG, list(range(args.ngpu)))

    # load from checkpoint if available, else random init
    if args.netG:
        netG.load_state_dict(torch.load(args.netG))
    else:
        netG.apply(weights_init)

    # Create & setup discriminator
    netD = Discriminator(args).to(device)

    # handle multiple gpus
    if (device.type == 'cuda' and args.ngpu > 1):
        netD = nn.DataParallel(netD, list(range(args.ngpu)))

    # load from checkpoint if available, else random init
    if args.netD:
        netD.load_state_dict(torch.load(args.netD))
    else:
        netD.apply(weights_init)

    # setup up loss & optimizers
    criterion = nn.BCELoss()
    optimizerG = optim.Adam(netG.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, 0.999))
    optimizerD = optim.Adam(netD.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, 0.999))

    # For input of generator in testing
    fixed_noise = torch.randn(64, args.nz, 1, 1, device=device)

    # convention for training
    real_label = 1
    fake_label = 0

    # training data for later analysis
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    # NOTE: the epoch count comes from args.num_epochs; the previous
    # hard-coded `num_epochs = 150` local was dead code and was removed.

    print('Starting Training Loop....')
    # For each epoch
    for e in range(args.num_epochs):
        # for each batch in the dataloader
        for i, data in enumerate(dataloader, 0):
            ########## Training Discriminator ##########
            netD.zero_grad()

            # train with real data
            real_data = data[0].to(device)

            # make labels; dtype must be float for BCELoss (an integer
            # fill value yields a long tensor on recent PyTorch)
            batch_size = real_data.size(0)
            labels = torch.full((batch_size, ),
                                real_label,
                                dtype=torch.float,
                                device=device)

            # forward pass real data through D
            real_outputD = netD(real_data).view(-1)

            # calc error on real data
            errD_real = criterion(real_outputD, labels)

            # calc grad
            errD_real.backward()
            D_x = real_outputD.mean().item()

            # train with fake data
            noise = torch.randn(batch_size, args.nz, 1, 1, device=device)
            fake_data = netG(noise)
            labels.fill_(fake_label)

            # classify fake (detach: no generator grads in the D step)
            fake_outputD = netD(fake_data.detach()).view(-1)

            # calc error on fake data
            errD_fake = criterion(fake_outputD, labels)

            # calc grad
            errD_fake.backward()
            D_G_z1 = fake_outputD.mean().item()

            # add all grad and update D
            errD = errD_real + errD_fake
            optimizerD.step()

            ########################################
            ########## Training Generator ##########
            netG.zero_grad()

            # since aim is fooling the netD, labels should be flipped
            labels.fill_(real_label)

            # forward pass with updated netD
            fake_outputD = netD(fake_data).view(-1)

            # calc error
            errG = criterion(fake_outputD, labels)

            # calc grad
            errG.backward()

            D_G_z2 = fake_outputD.mean().item()

            # update G
            optimizerG.step()

            ########################################

            # output training stats (the old backslash-continued f-string
            # leaked source indentation into the log lines)
            if i % 500 == 0:
                print(f'[{e+1}/{args.num_epochs}][{i+1}/{len(dataloader)}]'
                      f'\tLoss_D:{errD.item():.4f}'
                      f'\tLoss_G:{errG.item():.4f}'
                      f'\tD(x):{D_x:.4f}'
                      f'\tD(G(z)):{D_G_z1:.4f}/{D_G_z2:.4f}')

            # for later plot
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # generate fake image on fixed noise for comparison
            if ((iters % 500 == 0) or ((e == args.num_epochs - 1) and
                                       (i == len(dataloader) - 1))):
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()
                    img_list.append(
                        vutils.make_grid(fake, padding=2, normalize=True))
            iters += 1

        if e % args.save_every == 0:
            # save at args.save_every epoch
            torch.save(netG.state_dict(), args.outputG)
            torch.save(netD.state_dict(), args.outputD)
            print(f'Made a New Checkpoint for {e+1}')
    # return training data for analysis
    return img_list, G_losses, D_losses