コード例 #1
0
 def __init__(self, num_classes, num_domains, pretrained=True, grl=True):
     """Build the AlexNet backbone, a domain discriminator, and split the
     backbone classifier into feature layers plus the final FC head.

     Args:
         num_classes: number of label classes for the backbone classifier.
         num_domains: number of domains the discriminator predicts.
         pretrained: load pretrained backbone weights when True.
         grl: forwarded to the Discriminator (presumably its
             gradient-reversal option — confirm in the Discriminator def).
     """
     super(DGalexnet, self).__init__()
     self.num_domains = num_domains
     self.base_model = alexnet(num_classes, pretrained=pretrained)
     # Domain classifier over the 4096-d penultimate features.
     self.discriminator = Discriminator([4096, 1024, 1024, num_domains],
                                        grl=grl,
                                        reverse=True)
     # Materialize the classifier layers once: everything but the last
     # layer produces features, the last layer is kept as the FC head.
     classifier_layers = list(self.base_model.classifier.children())
     self.feature_layers = nn.Sequential(*classifier_layers[:-1])
     self.fc = classifier_layers[-1]
コード例 #2
0
ファイル: dcgan.py プロジェクト: Angelowin/deeplearning
    '/media/jiming/_E/angelo/Paper/degan_indian/dcgan/s_samples30.png',
    nrow=6,
    normalize=True)

# Inspect the augmented data tensor before wrapping it in a dataset.
print(data_new.size())
# NOTE(review): the data_tensor/target_tensor keyword names belong to the
# legacy TensorDataset API (old torch versions); modern TensorDataset
# takes positional tensors only — confirm the pinned torch version.
dataset = Data.TensorDataset(data_tensor=data_new, target_tensor=train_y)
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=opt.batchSize,
                                     shuffle=True)

###############   MODEL   ####################
ndf = opt.ndf  # filter-count argument for the Discriminator
ngf = opt.ngf  # filter-count argument for the Generator
nc = 1  # channel count (sibling example switches this to 3 for CIFAR)

netD = Discriminator(nc, ndf)
netG = Generator(nc, ngf, opt.nz)
# NOTE(review): the CUDA guard is commented out, so .cuda() always runs
# and this script requires a GPU.  Sibling examples keep the
# `if (opt.cuda):` guard — confirm whether this was intentional.
#if(opt.cuda):
netD.cuda()
netG.cuda()

###########   LOSS & OPTIMIZER   ##########
criterion = nn.BCELoss()  # standard GAN real/fake loss
optimizerD = torch.optim.Adam(netD.parameters(),
                              lr=opt.lr,
                              betas=(opt.beta1, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(),
                              lr=opt.lr,
                              betas=(opt.beta1, 0.999))

##########   GLOBAL VARIABLES   ###########
コード例 #3
0
# custom weights initialization called on netG and netD
def weights_init(m):
    """Randomly initialize Conv and BatchNorm layers (DCGAN convention).

    Meant for use as ``net.apply(weights_init)``: Conv weights are drawn
    from N(0, 0.02), BatchNorm weights from N(1, 0.02), and both kinds
    get a zero bias.  Any other layer type is left untouched.
    """
    name = m.__class__.__name__
    if 'Conv' in name:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif 'BatchNorm' in name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


ndf = opt.ndf  # filter-count argument for the Discriminator
ngf = opt.ngf  # mirrors opt.ngf; opt.ngf itself is passed to Generator below
nc = 3  # defined but unused in this snippet

# Discriminator sees both input and output channel counts — presumably a
# conditional GAN (pix2pix-style) setup; confirm against the model defs.
netD = Discriminator(opt.input_nc, opt.output_nc, ndf)
netG = Generator(opt.input_nc, opt.output_nc, opt.ngf)
if (opt.cuda):
    netD.cuda()
    netG.cuda()

# DCGAN-style random init (see weights_init above), then echo the nets.
netG.apply(weights_init)
netD.apply(weights_init)
print(netD)
print(netG)

###########   LOSS & OPTIMIZER   ##########
criterion = nn.BCELoss()   # adversarial real/fake loss
criterionL1 = nn.L1Loss()  # L1 reconstruction loss
optimizerD = torch.optim.Adam(netD.parameters(),
                              lr=opt.lr,
コード例 #4
0
                    help='hyper parameters vgg layer 3 precept loss')

opt = parser.parse_args()
print(opt)

# Pin the CUDA device for everything that follows.
torch.cuda.set_device(opt.cuda_num)

# Output directory keyed by the run name.
if not os.path.exists('output_' + opt.output_str):
    os.makedirs('output_' + opt.output_str)

test_folder = 'output_' + opt.output_str

###### Definition of variables ######
# Networks
# NOTE(review): this rebinds the name `model` from the class/factory to
# the instance, so the original callable is no longer reachable here.
model = model(fea_channel=opt.fea_channel)
discriminator = Discriminator()

# Optionally warm-start the main model from a saved state dict.
if opt.load_model:
    load_path = opt.load_path
    model.load_state_dict(torch.load(load_path))
    print('model loaded')
    torch.cuda.empty_cache()

model = model.cuda()
discriminator = discriminator.cuda()

# Pretrained VGG16 backs the perceptual loss network.
vgg_model = models.vgg16(pretrained=True)
vgg_model.cuda()

loss_network = utils.LossNetwork(vgg_model)
loss_network.eval()  # eval mode: fixed batch-norm/dropout behavior
コード例 #5
0
ファイル: train.py プロジェクト: Emmmmmaa/STIG

###########   MODEL   ###########
# custom weights initialization called on netG and netD
def weights_init(m):
    """DCGAN-style layer initializer, applied via ``net.apply(weights_init)``.

    Conv* layers: weight ~ N(0, 0.02), bias = 0.
    BatchNorm* layers: weight ~ N(1, 0.02), bias = 0.
    All other layer types are left untouched.
    """
    layer_type = m.__class__.__name__
    if layer_type.find('Conv') >= 0:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif layer_type.find('BatchNorm') >= 0:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


# Build one generator and one discriminator from the CLI options.
generator = Generator(opt.input_nc, opt.output_nc, opt.ngf)
discriminator = Discriminator(opt.input_nc, opt.output_nc, opt.ndf)

# NOTE(review): all four names below are ALIASES of the same two module
# objects — netD1/netD2 share weights (as do netG1/netG2), and the paired
# load_state_dict calls below load the same file into the same object
# twice.  If two independent networks were intended, each must be
# constructed separately — confirm intent before changing.
netD1 = discriminator
netD2 = discriminator
netG1 = generator
netG2 = generator

if opt.netG != '':
    netG1.load_state_dict(torch.load(opt.netG))
    netG2.load_state_dict(torch.load(opt.netG))
    netD1.load_state_dict(torch.load(opt.netD))
    netD2.load_state_dict(torch.load(opt.netD))
if opt.cuda:
    netD1.cuda()
    netD2.cuda()
    netG1.cuda()
コード例 #6
0
ファイル: main.py プロジェクト: jungwon-choi/WGAN-pytorch
def main(args):
    """End-to-end WGAN/DCGAN training driver.

    Builds result/checkpoint paths from the CLI arguments, fixes random
    seeds, loads the dataset, constructs the generator/discriminator
    pair, trains with RMSprop, and pickles results + saves a checkpoint
    after every epoch.

    NOTE(review): relies on module-level names not visible in this chunk
    (RESULTS_PATH, CHECKPOINT_PATH, CelebA_Dataloader, Generator,
    Discriminator, StepCounter, train, val, and a `start` timestamp) —
    confirm they are defined at import time.
    """
    #===========================================================================
    # Set the file name format
    FILE_NAME_FORMAT = "{0}_{1}_{2:d}_{3:d}_{4:d}_{5:f}{6}".format(
        args.model, args.dataset, args.epochs, args.obj_step, args.batch_size,
        args.lr, args.flag)

    # Set the results file path
    RESULT_FILE_NAME = FILE_NAME_FORMAT + '_results.pkl'
    RESULT_FILE_PATH = os.path.join(RESULTS_PATH, RESULT_FILE_NAME)
    # Set the checkpoint file path
    CHECKPOINT_FILE_NAME = FILE_NAME_FORMAT + '.ckpt'
    CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH, CHECKPOINT_FILE_NAME)
    BEST_CHECKPOINT_FILE_NAME = FILE_NAME_FORMAT + '_best.ckpt'
    BEST_CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH,
                                             BEST_CHECKPOINT_FILE_NAME)

    # Set the random seed same for reproducibility
    random.seed(190811)
    torch.manual_seed(190811)
    torch.cuda.manual_seed_all(190811)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Step1 ====================================================================
    # Load dataset
    if args.dataset == 'CelebA':
        dataloader = CelebA_Dataloader()
    else:
        assert False, "Please select the proper dataset."

    train_loader = dataloader.get_train_loader(batch_size=args.batch_size,
                                               num_workers=args.num_workers)
    print('==> DataLoader ready.')

    # Step2 ====================================================================
    # Make the model (the *_noBN variants disable batch normalization)
    if args.model in ['WGAN', 'DCGAN']:
        generator = Generator(BN=True)
        discriminator = Discriminator(BN=True)
    elif args.model in ['WGAN_noBN', 'DCGAN_noBN']:
        generator = Generator(BN=False)
        discriminator = Discriminator(BN=False)
    else:
        assert False, "Please select the proper model."

    # Check DataParallel available
    if torch.cuda.device_count() > 1:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)

    # Check CUDA available
    if torch.cuda.is_available():
        generator.cuda()
        discriminator.cuda()
    print('==> Model ready.')

    # Step3 ====================================================================
    # Set loss function and optimizer.  DCGAN variants use BCE; WGAN
    # variants pass criterion=None (presumably the train routine applies
    # the Wasserstein objective instead — confirm in train()).
    if args.model in ['DCGAN', 'DCGAN_noBN']:
        criterion = nn.BCELoss()
    else:
        criterion = None
    optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=args.lr)
    optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=args.lr)
    step_counter = StepCounter(args.obj_step)
    print('==> Criterion and optimizer ready.')

    # Step4 ====================================================================
    # Train and validate the model
    start_epoch = 0
    best_metric = float("inf")
    # Fixed noise so validation samples are comparable across steps.
    validate_noise = torch.randn(args.batch_size, 100, 1, 1)

    # Initialize the result lists
    train_loss_G = []
    train_loss_D = []
    train_distance = []

    # Resume all training state from the latest checkpoint if requested.
    if args.resume:
        assert os.path.exists(CHECKPOINT_FILE_PATH), 'No checkpoint file!'
        checkpoint = torch.load(CHECKPOINT_FILE_PATH)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
        start_epoch = checkpoint['epoch']
        step_counter.current_step = checkpoint['current_step']
        train_loss_G = checkpoint['train_loss_G']
        train_loss_D = checkpoint['train_loss_D']
        train_distance = checkpoint['train_distance']
        best_metric = checkpoint['best_metric']

    # Save the training information
    result_data = {}
    result_data['model'] = args.model
    result_data['dataset'] = args.dataset
    result_data['target_epoch'] = args.epochs
    result_data['batch_size'] = args.batch_size

    # Check the directory of the file path
    if not os.path.exists(os.path.dirname(RESULT_FILE_PATH)):
        os.makedirs(os.path.dirname(RESULT_FILE_PATH))
    if not os.path.exists(os.path.dirname(CHECKPOINT_FILE_PATH)):
        os.makedirs(os.path.dirname(CHECKPOINT_FILE_PATH))

    print('==> Train ready.')

    # Validate before training (step 0)
    val(generator, validate_noise, step_counter, FILE_NAME_FORMAT)

    for epoch in range(args.epochs):
        # start after the checkpoint epoch (skip already-trained epochs)
        if epoch < start_epoch:
            continue
        print("\n[Epoch: {:3d}/{:3d}]".format(epoch + 1, args.epochs))
        epoch_time = time.time()
        #=======================================================================
        # train the model (+ validate the model)
        tloss_G, tloss_D, tdist = train(generator, discriminator, train_loader,
                                        criterion, optimizer_G, optimizer_D,
                                        args.clipping, args.num_critic,
                                        step_counter, validate_noise,
                                        FILE_NAME_FORMAT)
        train_loss_G.extend(tloss_G)
        train_loss_D.extend(tloss_D)
        train_distance.extend(tdist)
        #=======================================================================
        current = time.time()

        # Calculate average loss
        avg_loss_G = sum(tloss_G) / len(tloss_G)
        avg_loss_D = sum(tloss_D) / len(tloss_D)
        avg_distance = sum(tdist) / len(tdist)

        # Save the current result
        result_data['current_epoch'] = epoch
        result_data['train_loss_G'] = train_loss_G
        result_data['train_loss_D'] = train_loss_D
        result_data['train_distance'] = train_distance

        # Save result_data as pkl file
        with open(RESULT_FILE_PATH, 'wb') as pkl_file:
            pickle.dump(result_data,
                        pkl_file,
                        protocol=pickle.HIGHEST_PROTOCOL)

        # Save the best checkpoint
        # if avg_distance < best_metric:
        #     best_metric = avg_distance
        #     torch.save({
        #         'epoch': epoch+1,
        #         'generator_state_dict': generator.state_dict(),
        #         'discriminator_state_dict': discriminator.state_dict(),
        #         'optimizer_G_state_dict': optimizer_G.state_dict(),
        #         'optimizer_D_state_dict': optimizer_D.state_dict(),
        #         'current_step': step_counter.current_step,
        #         'best_metric': best_metric,
        #         }, BEST_CHECKPOINT_FILE_PATH)

        # Save the current checkpoint (full resume state, every epoch)
        torch.save(
            {
                'epoch': epoch + 1,
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'current_step': step_counter.current_step,
                'train_loss_G': train_loss_G,
                'train_loss_D': train_loss_D,
                'train_distance': train_distance,
                'best_metric': best_metric,
            }, CHECKPOINT_FILE_PATH)

        # Print the information on the console
        print("model                : {}".format(args.model))
        print("dataset              : {}".format(args.dataset))
        print("batch_size           : {}".format(args.batch_size))
        print("current step         : {:d}".format(step_counter.current_step))
        print("current lrate        : {:f}".format(args.lr))
        print("gen/disc loss        : {:f}/{:f}".format(
            avg_loss_G, avg_loss_D))
        print("distance metric      : {:f}".format(avg_distance))
        print("epoch time           : {0:.3f} sec".format(current -
                                                          epoch_time))
        # NOTE(review): `start` is not defined in this function — it is
        # presumably a module-level start timestamp; confirm.
        print("Current elapsed time : {0:.3f} sec".format(current - start))

        # If iteration step has been satisfied
        if step_counter.exit_signal:
            break

    print('==> Train done.')

    print(' '.join(['Results have been saved at', RESULT_FILE_PATH]))
    print(' '.join(['Checkpoints have been saved at', CHECKPOINT_FILE_PATH]))
コード例 #7
0
ファイル: main.py プロジェクト: djkim1991/DCGAN
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.autograd import Variable

from model.Discriminator import Discriminator
from model.Generator import Generator

from loaders.MNISTLoader import MNIST
from util.ImageUtil import ImageUtil

# create model objects
discriminator = Discriminator()
generator = Generator()

# set data loader (the MNIST wrapper exposes ready-made train/test loaders)
dataLoader = MNIST()
train_loader, test_loader = dataLoader.train_loader, dataLoader.test_loader

# optimizer: one Adam instance per network, fixed learning rate
D_optimizer = Adam(params=discriminator.parameters(), lr=0.001)
G_optimizer = Adam(params=generator.parameters(), lr=0.001)

# loss function
D_loss_function = nn.BCELoss()  # Binary Cross Entropy loss
G_loss_function = nn.BCELoss()  # Binary Cross Entropy loss

imageUtil = ImageUtil()

epoch_size = 10000  # total number of training epochs for the loop below
for epoch in range(epoch_size):
コード例 #8
0
 def __init__(self, num_classes, num_domains, pretrained=True, grl=True):
     """Build the ResNet backbone plus a domain discriminator.

     Args:
         num_classes: number of label classes for the backbone.
         num_domains: number of domains the discriminator predicts.
         pretrained: load pretrained backbone weights when True.
         grl: forwarded to the Discriminator (presumably its
             gradient-reversal option — confirm in the Discriminator def).
     """
     super(DGresnet, self).__init__()
     self.num_domains = num_domains
     self.base_model = resnet(num_classes=num_classes, pretrained=pretrained)
     # 512-d ResNet features -> two 1024-d hidden layers -> domain logits.
     self.discriminator = Discriminator([512, 1024, 1024, num_domains],
                                        grl=grl,
                                        reverse=True)
コード例 #9
0
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# Hard-coded hyper-parameters (no CLI parsing in this example).
ndf = 64          # filter-count argument for the Discriminators
ngf = 64          # filter-count argument for the Generators
input_nc = 3
output_nc = 3
fineSize = 128    # defined but unused in this snippet
batchSize = 1     # defined but unused in this snippet
cuda = 1          # truthy constant: the .cuda() block below always runs

# Two translation directions (A->B and B->A), each with its own
# discriminator — CycleGAN-style pairing (presumed; confirm model defs).
D_A = Discriminator(input_nc,ndf)
D_B = Discriminator(output_nc,ndf)
G_AB = Generator(input_nc, output_nc, ngf)
G_BA = Generator(output_nc, input_nc, ngf)

# DCGAN-style random init for all four networks.
G_AB.apply(weights_init)
G_BA.apply(weights_init)

D_A.apply(weights_init)
D_B.apply(weights_init)

if(cuda):
    D_A.cuda()
    D_B.cuda()
    G_AB.cuda()
    G_BA.cuda()
コード例 #10
0
###########   MODEL   ###########
# custom weights initialization called on netG and netD
def weights_init(m):
    """DCGAN-style initializer for Conv / BatchNorm layers.

    Conv* layers: weight ~ N(0, 0.02).  NOTE(review): unlike sibling
    variants of this helper elsewhere in the file, the Conv bias is left
    untouched here — behavior preserved; confirm whether that is desired.
    BatchNorm* layers: weight ~ N(1, 0.02), bias = 0.
    """
    kind = m.__class__.__name__
    if 'Conv' in kind:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


ndf = opt.ndf
ngf = opt.ngf  # defined but unused here; opt.ngf itself is passed below
nc = 3  # defined but unused in this snippet

# Two translation directions with a per-direction discriminator.
D_A = Discriminator(opt.input_nc, ndf)
D_B = Discriminator(opt.output_nc, ndf)
G_AB = Generator(opt.input_nc, opt.output_nc, opt.ngf)
G_BA = Generator(opt.output_nc, opt.input_nc, opt.ngf)

# Either resume the generators from checkpoints or randomly init them.
if (opt.G_AB != ''):
    print('Warning! Loading pre-trained weights.')
    G_AB.load_state_dict(torch.load(opt.G_AB))
    G_BA.load_state_dict(torch.load(opt.G_BA))
else:
    G_AB.apply(weights_init)
    G_BA.apply(weights_init)
# NOTE(review): D_A/D_B are never passed through weights_init in either
# branch — confirm whether discriminator init was intended here.
if (opt.cuda):
    D_A.cuda()
    D_B.cuda()
コード例 #11
0
ファイル: train.py プロジェクト: santolina/ESRGAN-pytorch
 def build_model(self):
     """Instantiate the ESRGAN generator (RGB in/out, 64 base features)
     and the discriminator on self.device, then attempt to restore any
     saved weights via load_model()."""
     self.generator = ESRGAN(3, 3, 64, scale_factor=self.scale_factor).to(self.device)
     self.discriminator = Discriminator().to(self.device)
     self.load_model()
コード例 #12
0
ファイル: train.py プロジェクト: santolina/ESRGAN-pytorch
class Trainer:
    """ESRGAN training harness.

    Builds the generator/discriminator pair, Adam optimizers with
    per-batch StepLR decay, and runs an adversarial training loop that
    combines a RaGAN-style relativistic adversarial loss with a VGG-based
    perceptual loss and an L1 content loss.
    """
    def __init__(self, config, data_loader):
        """Read hyper-parameters from `config` and build models/optimizers.

        Args:
            config: namespace with training hyper-parameters.
            data_loader: iterable yielding dicts with 'lr'/'hr' batches.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.num_epoch = config.num_epoch
        self.epoch = config.epoch  # starting epoch (non-zero when resuming)
        self.image_size = config.image_size
        self.data_loader = data_loader
        self.checkpoint_dir = config.checkpoint_dir
        self.batch_size = config.batch_size
        self.sample_dir = config.sample_dir
        self.nf = config.nf
        self.scale_factor = config.scale_factor

        # Two training profiles, each with its own lr / loss weights /
        # decay interval: perceptual-oriented (p_*) vs GAN-oriented (g_*).
        if config.is_perceptual_oriented:
            self.lr = config.p_lr
            self.content_loss_factor = config.p_content_loss_factor
            self.perceptual_loss_factor = config.p_perceptual_loss_factor
            self.adversarial_loss_factor = config.p_adversarial_loss_factor
            self.decay_batch_size = config.p_decay_batch_size
        else:
            self.lr = config.g_lr
            self.content_loss_factor = config.g_content_loss_factor
            self.perceptual_loss_factor = config.g_perceptual_loss_factor
            self.adversarial_loss_factor = config.g_adversarial_loss_factor
            self.decay_batch_size = config.g_decay_batch_size

        self.build_model()
        self.optimizer_generator = Adam(self.generator.parameters(), lr=self.lr, betas=(config.b1, config.b2),
                                        weight_decay=config.weight_decay)
        self.optimizer_discriminator = Adam(self.discriminator.parameters(), lr=self.lr, betas=(config.b1, config.b2),
                                            weight_decay=config.weight_decay)

        # decay_batch_size counts batches: train() steps these schedulers
        # once per batch, not once per epoch.
        self.lr_scheduler_generator = torch.optim.lr_scheduler.StepLR(self.optimizer_generator, self.decay_batch_size)
        self.lr_scheduler_discriminator = torch.optim.lr_scheduler.StepLR(self.optimizer_discriminator, self.decay_batch_size)

    def train(self):
        """Run the adversarial training loop over [self.epoch, num_epoch)."""
        total_step = len(self.data_loader)
        adversarial_criterion = nn.BCEWithLogitsLoss().to(self.device)
        content_criterion = nn.L1Loss().to(self.device)
        perception_criterion = PerceptualLoss().to(self.device)
        self.generator.train()
        self.discriminator.train()

        for epoch in range(self.epoch, self.num_epoch):
            # Per-epoch directory for sample images.
            if not os.path.exists(os.path.join(self.sample_dir, str(epoch))):
                os.makedirs(os.path.join(self.sample_dir, str(epoch)))

            for step, image in enumerate(self.data_loader):
                low_resolution = image['lr'].to(self.device)
                high_resolution = image['hr'].to(self.device)

                real_labels = torch.ones((high_resolution.size(0), 1)).to(self.device)
                fake_labels = torch.zeros((high_resolution.size(0), 1)).to(self.device)

                ##########################
                #   training generator   #
                ##########################
                self.optimizer_generator.zero_grad()
                fake_high_resolution = self.generator(low_resolution)

                # Relativistic scores: each batch's score relative to the
                # mean score of the opposite batch.
                score_real = self.discriminator(high_resolution)
                score_fake = self.discriminator(fake_high_resolution)
                discriminator_rf = score_real - score_fake.mean()
                discriminator_fr = score_fake - score_real.mean()

                # Generator target labels are swapped relative to the
                # discriminator step below (rf -> fake, fr -> real).
                adversarial_loss_rf = adversarial_criterion(discriminator_rf, fake_labels)
                adversarial_loss_fr = adversarial_criterion(discriminator_fr, real_labels)
                adversarial_loss = (adversarial_loss_fr + adversarial_loss_rf) / 2

                perceptual_loss = perception_criterion(high_resolution, fake_high_resolution)
                content_loss = content_criterion(fake_high_resolution, high_resolution)

                # Weighted sum of the three loss terms.
                generator_loss = adversarial_loss * self.adversarial_loss_factor + \
                                 perceptual_loss * self.perceptual_loss_factor + \
                                 content_loss * self.content_loss_factor

                generator_loss.backward()
                self.optimizer_generator.step()

                ##########################
                # training discriminator #
                ##########################

                self.optimizer_discriminator.zero_grad()

                # detach() keeps discriminator gradients from flowing back
                # into the generator.
                score_real = self.discriminator(high_resolution)
                score_fake = self.discriminator(fake_high_resolution.detach())
                discriminator_rf = score_real - score_fake.mean()
                discriminator_fr = score_fake - score_real.mean()

                adversarial_loss_rf = adversarial_criterion(discriminator_rf, real_labels)
                adversarial_loss_fr = adversarial_criterion(discriminator_fr, fake_labels)
                discriminator_loss = (adversarial_loss_fr + adversarial_loss_rf) / 2

                discriminator_loss.backward()
                self.optimizer_discriminator.step()

                # LR decay applied per batch (see __init__).
                self.lr_scheduler_generator.step()
                self.lr_scheduler_discriminator.step()
                if step % 1000 == 0:
                    print(f"[Epoch {epoch}/{self.num_epoch}] [Batch {step}/{total_step}] "
                          f"[D loss {discriminator_loss.item():.4f}] [G loss {generator_loss.item():.4f}] "
                          f"[adversarial loss {adversarial_loss.item() * self.adversarial_loss_factor:.4f}]"
                          f"[perceptual loss {perceptual_loss.item() * self.perceptual_loss_factor:.4f}]"
                          f"[content loss {content_loss.item() * self.content_loss_factor:.4f}]"
                          f"")
                    if step % 5000 == 0:
                        # Save a real-HR / generated-SR pair concatenated
                        # along the height dimension.
                        result = torch.cat((high_resolution, fake_high_resolution), 2)
                        save_image(result, os.path.join(self.sample_dir, str(epoch), f"SR_{step}.png"))

            # Checkpoint both networks after every epoch.
            torch.save(self.generator.state_dict(), os.path.join(self.checkpoint_dir, f"generator_{epoch}.pth"))
            torch.save(self.discriminator.state_dict(), os.path.join(self.checkpoint_dir, f"discriminator_{epoch}.pth"))

    def build_model(self):
        """Instantiate the networks on the device, then try to restore
        saved weights via load_model()."""
        self.generator = ESRGAN(3, 3, 64, scale_factor=self.scale_factor).to(self.device)
        self.discriminator = Discriminator().to(self.device)
        self.load_model()

    def load_model(self):
        """Restore weights for epoch `self.epoch - 1` from checkpoint_dir
        if present; otherwise leave the freshly built models untouched."""
        print(f"[*] Load model from {self.checkpoint_dir}")
        if not os.path.exists(self.checkpoint_dir):
            # NOTE(review): os.makedirs returns None, so this attribute
            # only records that the directory had to be created.
            self.makedirs = os.makedirs(self.checkpoint_dir)

        if not os.listdir(self.checkpoint_dir):
            print(f"[!] No checkpoint in {self.checkpoint_dir}")
            return

        generator = glob(os.path.join(self.checkpoint_dir, f'generator_{self.epoch - 1}.pth'))
        discriminator = glob(os.path.join(self.checkpoint_dir, f'discriminator_{self.epoch - 1}.pth'))

        if not generator:
            print(f"[!] No checkpoint in epoch {self.epoch - 1}")
            return

        self.generator.load_state_dict(torch.load(generator[0]))
        self.discriminator.load_state_dict(torch.load(discriminator[0]))
コード例 #13
0
ファイル: dcgan.py プロジェクト: zkghit/Paper-Implementations
                             transforms.ToTensor(),
                         ]),
                         download=True)

loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=opt.batchSize,
                                     shuffle=True)

###############   MODEL   ####################
ndf = opt.ndf  # filter-count argument for the Discriminator
ngf = opt.ngf  # filter-count argument for the Generator
# Channel count follows the dataset: 1 by default, 3 for CIFAR.
nc = 1
if (opt.dataset == 'CIFAR'):
    nc = 3

netD = Discriminator(nc, ndf)
netG = Generator(nc, ngf, opt.nz)
if (opt.cuda):
    netD.cuda()
    netG.cuda()

###########   LOSS & OPTIMIZER   ##########
criterion = nn.BCELoss()  # standard GAN real/fake loss
optimizerD = torch.optim.Adam(netD.parameters(),
                              lr=opt.lr,
                              betas=(opt.beta1, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(),
                              lr=opt.lr,
                              betas=(opt.beta1, 0.999))

##########   GLOBAL VARIABLES   ###########
コード例 #14
0
def train(**kwargs):
    opt._parse(kwargs)

    id_file_dir = 'ImageSets/Main/trainval_big_64.txt'
    img_dir = 'JPEGImages'
    anno_dir = 'AnnotationsBig'
    large_dataset = DatasetAugmented(opt, id_file=id_file_dir, img_dir=img_dir, anno_dir=anno_dir)
    dataloader_large = data_.DataLoader(large_dataset, \
                                        batch_size=1, \
                                        shuffle=True, \
                                        # pin_memory=True,
                                        num_workers=opt.num_workers)

    id_file_dir = 'ImageSets/Main/trainval_pcgan_generated_small.txt'
    img_dir = 'JPEGImagesPCGANGenerated'
    anno_dir = 'AnnotationsPCGANGenerated'

    small_dataset = DatasetAugmented(opt, id_file=id_file_dir, img_dir=img_dir, anno_dir=anno_dir)
    dataloader_small = data_.DataLoader(small_dataset, \
                                        batch_size=1, \
                                        shuffle=True, \
                                        # pin_memory=True,
                                        num_workers=opt.num_workers)

    small_test_dataset = SmallImageTestDataset(opt)
    dataloader_small_test = data_.DataLoader(small_test_dataset, \
                                             batch_size=1, \
                                             shuffle=True, \
                                             pin_memory=True,
                                             num_workers=opt.test_num_workers)

    print('{:d} roidb large entries'.format(len(dataloader_large)))
    print('{:d} roidb small entries'.format(len(dataloader_small)))
    print('{:d} roidb small test entries'.format(len(dataloader_small_test)))

    faster_rcnn = FasterRCNNVGG16_GAN()
    faster_rcnn_ = FasterRCNNVGG16()

    print('model construct completed')
    trainer_ = FasterRCNNTrainer(faster_rcnn_).cuda()

    netD = Discriminator()
    netD.apply(weights_init)

    faster_rcnn_.cuda()
    netD.cuda()

    lr = opt.LEARNING_RATE
    params_D = []
    for key, value in dict(netD.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                params_D += [{'params': [value], 'lr': lr * 2, \
                              'weight_decay': 0}]
            else:
                params_D += [{'params': [value], 'lr': lr, 'weight_decay': opt.weight_decay}]

    optimizerD = optim.SGD(params_D, momentum=0.9)
    # optimizerG = optim.Adam(faster_rcnn.parameters(), lr=lr, betas=(0.5, 0.999))

    if not opt.gan_load_path:
        trainer_.load(opt.load_path)
        print('load pretrained faster rcnn model from %s' % opt.load_path)

        # optimizer_ = trainer_.optimizer
        state_dict_ = faster_rcnn_.state_dict()
        state_dict = faster_rcnn.state_dict()

        # for k, i in state_dict_.items():
        #     icpu = i.cpu()
        #     b = icpu.data.numpy()
        #     sz = icpu.data.numpy().shape
        #     state_dict[k] = state_dict_[k]
        state_dict.update(state_dict_)
        faster_rcnn.load_state_dict(state_dict)
        faster_rcnn.cuda()

    trainer = FasterRCNNTrainer(faster_rcnn).cuda()

    if opt.gan_load_path:
        trainer.load(opt.gan_load_path, load_optimizer=True)
        print('load pretrained generator model from %s' % opt.gan_load_path)

    if opt.disc_load_path:
        state_dict_d = torch.load(opt.disc_load_path)
        netD.load_state_dict(state_dict_d['model'])
        optimizerD.load_state_dict(state_dict_d['optimizer'])
        print('load pretrained discriminator model from %s' % opt.disc_load_path)

    real_label = 1
    fake_label = 0

    # rpn_loc_loss = []
    # rpn_cls_loss = []
    # roi_loc_loss = []
    # roi_cls_loss = []
    # total_loss = []
    test_map_list = []

    criterion = nn.BCELoss()
    iters_per_epoch = min(len(dataloader_large), len(dataloader_small))
    best_map = 0
    device = torch.device("cuda:2" if (torch.cuda.is_available()) else "cpu")

    for epoch in range(1, opt.gan_epoch + 1):
        trainer.reset_meters()

        loss_temp_G = 0
        loss_temp_D = 0
        if epoch % (opt.lr_decay_step + 1) == 0:
            adjust_learning_rate(trainer.optimizer, opt.LEARNING_RATE_DECAY_GAMMA)
            adjust_learning_rate(optimizerD, opt.LEARNING_RATE_DECAY_GAMMA)
            lr *= opt.LEARNING_RATE_DECAY_GAMMA

        data_iter_large = iter(dataloader_large)
        data_iter_small = iter(dataloader_small)
        for step in tqdm(range(iters_per_epoch)):
            #####(1) Update Perceptual branch + generator(zero mapping)
            ####     Discriminator network: maximize log(D(x))+ log(1-D(G(z)))

            ##### Train with all_real batch
            ##### Format batch
            netD.zero_grad()
            data_large = next(data_iter_large)
            img, bbox_, label_, scale_ = data_large
            scale = at.scalar(scale_)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()

            ##### Forward pass real batch through D
            # faster_rcnn.zero_grad()
            # trainer.optimizer.zero_grad()
            # trainer.optimizer.zero_grad()

            losses, pooled_feat, rois_label, conv1_feat = trainer.train_step_gan(img, bbox, label, scale)

            # if step < 1:
            #     custom_viz(conv1_feat.cpu().detach(), 'results-gan/features/large_orig_%s' % str(epoch))
            #     custom_viz(pooled_feat.cpu().detach(), 'results-gan/features/large_scaled_%s' % str(epoch))

            keep = rois_label != 0
            pooled_feat = pooled_feat[keep]

            real_b_size = pooled_feat.size(0)
            real_labels = torch.full((real_b_size,), real_label, device=device)

            output = netD(pooled_feat.detach()).view(-1)
            # print(output)

            ##### Calculate loss on all-real batch

            errD_real = criterion(output, real_labels)
            errD_real.backward()
            D_x = output.mean().item()

            ##### Train with all_fake batch
            # Generate batch of fake images with G
            data_small = next(data_iter_small)
            img, bbox_, label_, scale_ = data_small
            scale = at.scalar(scale_)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
            trainer.optimizer.zero_grad()

            losses, fake_pooled_feat, rois_label, conv1_feat = trainer.train_step_gan_second(img, bbox, label, scale)

            # if step < 1:
            #     custom_viz(conv1_feat.cpu().detach(), 'results-gan/features/small_orig_%s' % str(epoch))
            #     custom_viz(fake_pooled_feat.cpu().detach(), 'results-gan/features/small_scaled_%s' % str(epoch))

            # --- (1b) Discriminator update on the generated (fake) batch ---
            # select fg rois
            # Keep only foreground ROIs (label != 0): the discriminator is
            # trained on object features only, not on background proposals.
            keep = rois_label != 0
            fake_pooled_feat = fake_pooled_feat[keep]
            # print(fake_pooled_feat)
            # print(torch.nonzero(torch.isnan(fake_pooled_feat.view(-1))))

            fake_b_size = fake_pooled_feat.size(0)
            # All-"fake" target labels for this batch of generated features.
            fake_labels = torch.full((fake_b_size,), fake_label, device=device)

            # optimizerD.zero_grad()
            # detach() blocks gradients from flowing back into the feature
            # generator (the detector backbone) during the D update.
            output = netD(fake_pooled_feat.detach()).view(-1)

            # calculate D's loss on the all_fake batch
            errD_fake = criterion(output, fake_labels)
            # retain_graph=True keeps the graph alive so it can be reused in
            # the generator update below.
            errD_fake.backward(retain_graph=True)
            # Mean D output on fakes *before* the G update (for logging).
            D_G_Z1 = output.mean().item()
            # add the gradients from the all-real and all-fake batches
            # NOTE: errD_real / D_x come from the all-real half of the D step,
            # which happens above this visible chunk.
            errD = errD_fake + errD_real
            # Update D
            optimizerD.step()

            ################################################
            #####(2) Update G network: maximize log(D(G(z)))
            ################################################
            faster_rcnn.zero_grad()

            # Generator objective: make D classify fake features as real.
            fake_labels.fill_(real_label)

            # No detach here — gradients must reach the feature generator.
            output = netD(fake_pooled_feat).view(-1)

            # calculate gradients for G
            errG = criterion(output, fake_labels)
            # Combine the adversarial loss with the detector's own supervised
            # losses so detection quality is optimised jointly.
            errG += losses.total_loss
            errG.backward()
            # Mean D output on fakes *after* the D update (for logging).
            D_G_Z2 = output.mean().item()

            # Clip detector gradients (max norm 10) to stabilise training.
            clip_gradient(faster_rcnn, 10.)

            trainer.optimizer.step()

            # Accumulate running losses for the periodic log below.
            loss_temp_G += errG.item()
            loss_temp_D += errD.item()

            # Periodic console logging of running-average losses.
            if step % opt.plot_every == 0:
                if step > 0:
                    # NOTE(review): divides by plot_every + 1 even though only
                    # plot_every steps were accumulated since the last reset —
                    # confirm whether this off-by-one is intentional.
                    loss_temp_G /= (opt.plot_every + 1)
                    loss_temp_D /= (opt.plot_every + 1)

                # losses_dict = trainer.get_meter_data()
                #
                # rpn_loc_loss.append(losses_dict['rpn_loc_loss'])
                # roi_loc_loss.append(losses_dict['roi_loc_loss'])
                # rpn_cls_loss.append(losses_dict['rpn_cls_loss'])
                # roi_cls_loss.append(losses_dict['roi_cls_loss'])
                # total_loss.append(losses_dict['total_loss'])
                #
                # save_losses('rpn_loc_loss', rpn_loc_loss, epoch)
                # save_losses('roi_loc_loss', roi_loc_loss, epoch)
                # save_losses('rpn_cls_loss', rpn_cls_loss, epoch)
                # save_losses('total_loss', total_loss, epoch)
                # save_losses('roi_cls_loss', roi_cls_loss, epoch)

                print("[epoch %2d] lossG: %.4f lossD: %.4f, lr: %.2e"
                      % (epoch, loss_temp_G, loss_temp_D, lr))
                print("\t\t\trcnn_cls: %.4f, rcnn_box %.4f"
                      % (losses.roi_cls_loss, losses.roi_loc_loss))

                print("\t\t\trpn_cls: %.4f, rpn_box %.4f"
                      % (losses.rpn_cls_loss, losses.rpn_loc_loss))

                print('\t\t\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (D_x, D_G_Z1, D_G_Z2))
                # Reset the running-average accumulators.
                loss_temp_D = 0
                loss_temp_G = 0

        # --- End-of-epoch: evaluate, log, and checkpoint on improvement ---
        eval_result = eval(dataloader_small_test, faster_rcnn, test_num=opt.test_num)
        test_map_list.append(eval_result['map'])
        save_map(test_map_list, epoch)

        lr_ = trainer.faster_rcnn.optimizer.param_groups[0]['lr']
        log_info = 'lr:{}, map:{}'.format(str(lr_),
                                                  str(eval_result['map']))
        print(log_info)

        # Save both the detector and the discriminator whenever the
        # evaluation mAP reaches a new best.
        if eval_result['map'] > best_map:
            best_map = eval_result['map']
            timestr = time.strftime('%m%d%H%M')
            trainer.save(best_map=best_map, save_path='checkpoints-pcgan-generated/gan_fasterrcnn_%s' % timestr)

            save_dict = dict()

            # Discriminator weights and optimizer state, saved together so
            # adversarial training can be resumed.
            save_dict['model'] = netD.state_dict()

            save_dict['optimizer'] = optimizerD.state_dict()
            save_path = 'checkpoints-pcgan-generated/discriminator_%s' % timestr
            torch.save(save_dict, save_path)
Code example #15
0
File: LeakGan.py  Project: Liugawa/GAN_Poem_Generate
    def __init__(self, wi_dict_path, iw_dict_path, train_data, val_data=None):
        """Set up a LeakGAN text-generation pipeline.

        Builds (or loads) the word<->index dictionaries, truncates the
        train/validation corpora to the configured input length, and
        constructs the discriminator, generator, TF savers, and BLEU
        evaluators.

        Args:
            wi_dict_path: pickle path for the word -> index dictionary.
            iw_dict_path: pickle path for the index -> word dictionary.
            train_data: path to the training text file.
            val_data: optional path to the validation text file.
        """
        super().__init__()

        # Model hyper-parameters. vocab_size (and, when dictionaries are
        # rebuilt, sequence_length) are overwritten below.
        self.vocab_size = 20
        self.emb_dim = 64
        self.hidden_dim = 64

        # input_length: prefix length kept by trunc_data; sequence_length:
        # full generated sequence length.
        self.input_length = 8
        self.sequence_length = 32
        # CNN discriminator conv filter widths and counts.
        self.filter_size = [2, 3]
        self.num_filters = [100, 200]
        self.l2_reg_lambda = 0.2
        self.dropout_keep_prob = 0.75
        self.batch_size = 64
        self.generate_num = 256
        self.start_token = 0
        self.dis_embedding_dim = 64
        self.goal_size = 16

        # Checkpoint locations: current model, best pre-trained generator,
        # and best adversarially-trained model.
        self.save_path = 'save/model/LeakGan/LeakGan'
        self.model_path = 'save/model/LeakGan'
        self.best_path_pre = 'save/model/best-pre-gen/best-pre-gen'
        self.best_path = 'save/model/best-leak-gan/best-leak-gan'
        self.best_model_path = 'save/model/best-leak-gan'

        # Working files for ground-truth, generated, and test samples.
        self.truth_file = 'save/truth.txt'
        self.generator_file = 'save/generator.txt'
        self.test_file = 'save/test_file.txt'

        # Truncate the corpora to input_length tokens (presumably the seed
        # prefix for generation — confirm against trunc_data's definition).
        self.trunc_train_file = 'save/trunc_train.txt'
        self.trunc_val_file = 'save/trunc_val.txt'
        trunc_data(train_data, self.trunc_train_file, self.input_length)
        trunc_data(val_data, self.trunc_val_file, self.input_length)

        # Build the word/index dictionaries from the corpora if either pickle
        # is missing; otherwise load both from disk.
        if not os.path.isfile(wi_dict_path) or not os.path.isfile(
                iw_dict_path):
            print('Building word/index dictionaries...')
            self.sequence_length, self.vocab_size, word_index_dict, index_word_dict = text_precess(
                train_data, val_data)
            print('Vocab Size: %d' % self.vocab_size)
            print('Saving dictionaries to ' + wi_dict_path + ' ' +
                  iw_dict_path + '...')
            with open(wi_dict_path, 'wb') as f:
                pickle.dump(word_index_dict, f)
            with open(iw_dict_path, 'wb') as f:
                pickle.dump(index_word_dict, f)
        else:
            print('Loading word/index dectionaries...')
            with open(wi_dict_path, 'rb') as f:
                word_index_dict = pickle.load(f)
            with open(iw_dict_path, 'rb') as f:
                index_word_dict = pickle.load(f)
            # +1 reserves an index (presumably for padding/unknown — verify).
            self.vocab_size = len(word_index_dict) + 1
            # NOTE(review): in this branch sequence_length keeps the class
            # default (32); confirm it matches the corpus the loaded
            # dictionaries were built from.
            print('Vocab Size: %d' % self.vocab_size)

        self.wi_dict = word_index_dict
        self.iw_dict = index_word_dict
        self.train_data = train_data
        self.val_data = val_data

        # The worker's goal vector has the same width as the discriminator's
        # concatenated feature maps.
        goal_out_size = sum(self.num_filters)
        self.discriminator = Discriminator(
            sequence_length=self.sequence_length,
            num_classes=2,
            vocab_size=self.vocab_size,
            dis_emb_dim=self.dis_embedding_dim,
            filter_sizes=self.filter_size,
            num_filters=self.num_filters,
            batch_size=self.batch_size,
            hidden_dim=self.hidden_dim,
            start_token=self.start_token,
            goal_out_size=goal_out_size,
            step_size=4,
            l2_reg_lambda=self.l2_reg_lambda)

        # The generator receives the discriminator (D_model) so it can use
        # the leaked feature signal during generation.
        self.generator = Generator(num_classes=2,
                                   num_vocabulary=self.vocab_size,
                                   batch_size=self.batch_size,
                                   emb_dim=self.emb_dim,
                                   dis_emb_dim=self.dis_embedding_dim,
                                   goal_size=self.goal_size,
                                   hidden_dim=self.hidden_dim,
                                   sequence_length=self.sequence_length,
                                   input_length=self.input_length,
                                   filter_sizes=self.filter_size,
                                   start_token=self.start_token,
                                   num_filters=self.num_filters,
                                   goal_out_size=goal_out_size,
                                   D_model=self.discriminator,
                                   step_size=4)

        # Separate savers so the latest, best-pretrained, and best
        # adversarial checkpoints never overwrite each other.
        self.saver = tf.train.Saver()
        self.best_pre_saver = tf.train.Saver()
        self.best_saver = tf.train.Saver()

        # BLEU-1/2 evaluators against the truncated validation corpus.
        self.val_bleu1 = Bleu(real_text=self.trunc_val_file, gram=1)
        self.val_bleu2 = Bleu(real_text=self.trunc_val_file, gram=2)